# Source code for sycamore.transforms.llm_query

from typing import Optional, Any, Union

from sycamore.data import Element, Document
from sycamore.plan_nodes import NonCPUUser, NonGPUUser, Node
from sycamore.llms import LLM
from sycamore.transforms.map import Map
from sycamore.utils.time_trace import timetrace
from jinja2.sandbox import SandboxedEnvironment


class LLMTextQueryAgent:
    """
    LLMTextQueryAgent uses a specified LLM to execute LLM queries about a document or its child elements.

    Args:
        prompt: A prompt to be passed into the underlying LLM execution engine. When ``format_kwargs``
            is set, the prompt is rendered as a Jinja template instead of being concatenated with text.
        llm: The LLM Client to be used here. It is defined as an instance of the LLM class in Sycamore.
        output_property: (Optional, default="llm_response") The output property of the document or element
            to add results in.
        format_kwargs: (Optional, default=None) If passed in (and non-empty), these are supplied as globals
            to the underlying Jinja sandbox when rendering ``prompt``; the queried object is available to
            the template as ``doc``.
        number_of_elements: (Optional, default=None) When "per_element" is true, limits the number of
            elements to add an "output_property". Otherwise, the response is added to the entire document
            using a limited prefix subset of the elements.
        llm_kwargs: (Optional, default=None) Keyword arguments to be passed into the underlying LLM
            execution engine. ``None`` is treated as an empty dict.
        per_element: (Optional, default=True) Whether to execute the query on each element individually.
            When false, either the text of the first ``number_of_elements`` matching elements is appended
            to the prompt and queried once, or (if ``number_of_elements`` is unset) the document's own
            text representation is queried.
        element_type: (Optional) Parameter to only execute the LLM query on a particular element type. If
            not specified, the query will be executed on all elements.
        table_cont: (Optional, default=False) When true, the LLM response is stored under the
            "table_continuation" property instead of ``output_property``. In per-element mode without
            ``format_kwargs``, each matching element after the first is sent to the LLM together with the
            previously matched element.

    Example:
         .. code-block:: python

            prompt = "Tell me the important numbers from this element"
            # llm: any sycamore.llms.LLM client instance
            llm_query_agent = LLMTextQueryAgent(prompt=prompt, llm=llm)

            context = sycamore.init()
            pdf_docset = (
                context.read.binary(paths, binary_format="pdf")
                .partition(partitioner=ArynPartitioner())
                .llm_query(query_agent=llm_query_agent)
            )
    """

    def __init__(
        self,
        prompt: str,
        llm: LLM,
        output_property: str = "llm_response",
        format_kwargs: Optional[dict[str, Any]] = None,
        number_of_elements: Optional[int] = None,
        llm_kwargs: Optional[dict] = None,
        per_element: bool = True,
        element_type: Optional[str] = None,
        table_cont: Optional[bool] = False,
    ):
        self._llm = llm
        self._prompt = prompt
        self._output_property = output_property
        # A fresh dict per instance: a `{}` default in the signature would be one
        # object shared (and potentially mutated) across every agent.
        self._llm_kwargs = {} if llm_kwargs is None else llm_kwargs
        self._per_element = per_element
        self._format_kwargs = format_kwargs
        self._number_of_elements = number_of_elements
        self._element_type = element_type
        self._table_cont = table_cont

    def execute_query(self, document: Document) -> Document:
        """
        Run the configured query against ``document`` and return the updated document.

        Three modes, selected by the constructor arguments:

        * ``per_element`` true: every element matching ``element_type`` (up to
          ``number_of_elements``) is queried on its own and replaced in ``document.elements``.
        * ``per_element`` false with ``number_of_elements`` set: the text of the first matching
          elements is appended to the prompt, the LLM is called once, and the response is stored
          in ``document["properties"][output_property]``.
        * neither: the document itself is queried if it has a text representation.
        """
        if not (self._per_element or self._number_of_elements):
            # Whole-document mode: no element iteration at all.
            if document.text_representation:
                document = self._query_text_object(document)
            return document

        prompt_parts = [self._prompt]
        element_count = 0
        prev_table: Optional[int] = None  # index of the previously queried element (table_cont mode)

        for idx, element in enumerate(document.elements):
            if self._element_type and element.type != self._element_type:
                continue

            if self._per_element:
                if self._table_cont and prev_table is not None:
                    document.elements[idx] = self._query_text_object(element, document.elements[prev_table])
                else:
                    document.elements[idx] = self._query_text_object(element)
                if self._table_cont:
                    prev_table = idx
            else:
                # Elements without text would make the string concatenation raise
                # TypeError; they contribute nothing to the prompt, so skip them.
                text = element.text_representation
                if text:
                    prompt_parts.append(text)

            # Counted per matching element, whether or not it carried any text.
            if self._number_of_elements:
                element_count += 1
                if element_count >= self._number_of_elements:
                    break

        if not self._per_element:
            prompt_kwargs = {"prompt": "\n".join(prompt_parts)}
            llm_resp = self._llm.generate_old(prompt_kwargs=prompt_kwargs, llm_kwargs=self._llm_kwargs)
            document["properties"][self._output_property] = llm_resp
        return document

    @timetrace("LLMQueryText")
    def _query_text_object(
        self, object: Union[Document, Element], objectPrev: Optional[Element] = None
    ) -> Union[Document, Element]:
        """
        Query the LLM about a single document or element and store the response on it.

        Args:
            object: The document or element to query. Returned unchanged when it has no
                text representation. (The name shadows the builtin; it is kept so existing
                keyword callers continue to work.)
            objectPrev: Optional previous element. When it has text and no ``format_kwargs``
                are configured, both texts are sent, labelled "<KIND> 1" and "<KIND> 2".

        Returns:
            The same ``object``, with the LLM response written to
            ``properties["table_continuation"]`` when ``table_cont`` is set, otherwise to
            ``properties[output_property]``.
        """
        if not object.text_representation:
            return object

        if self._format_kwargs:
            # Jinja mode: the prompt is a template; objectPrev is not used here.
            prompt = (
                SandboxedEnvironment()
                .from_string(source=self._prompt, globals=self._format_kwargs)
                .render(doc=object)
            )
        elif objectPrev and objectPrev.text_representation:
            object_name = "ELEMENT" if isinstance(object, Element) else "DOCUMENT"
            prompt = (
                self._prompt
                + "\n"
                + f"{object_name} 1: \n\n"
                + objectPrev.text_representation
                + "\n\n"
                + f"{object_name} 2: \n"
                + object.text_representation
            )
        else:
            prompt = self._prompt + "\n" + object.text_representation

        prompt_kwargs = {"prompt": prompt}
        llm_resp = self._llm.generate_old(prompt_kwargs=prompt_kwargs, llm_kwargs=self._llm_kwargs)

        target_property = "table_continuation" if self._table_cont else self._output_property
        object["properties"][target_property] = llm_resp
        return object
class LLMQuery(NonCPUUser, NonGPUUser, Map):
    """
    Transform that applies an LLMTextQueryAgent's user-defined query to a document
    or to the elements within it.

    Args:
        child: The plan node this transform is attached to; passed to ``Map``.
        query_agent: The agent whose ``execute_query`` method is handed to ``Map`` as ``f``.
        **kwargs: Forwarded unchanged to ``Map.__init__``.
    """

    def __init__(self, child: Node, query_agent: LLMTextQueryAgent, **kwargs):
        # The bound method carries the agent's configuration (prompt, LLM, options) with it.
        query_fn = query_agent.execute_query
        super().__init__(child, f=query_fn, **kwargs)