Source code for sycamore.transforms.extract_entity

from abc import ABC, abstractmethod
from typing import Callable, Any, Optional, Union

from sycamore.context import Context, context_params, OperationTypes
from sycamore.data import Element, Document
from sycamore.llms import LLM

from sycamore.llms.prompts.prompts import (
    RenderedMessage,
    SycamorePrompt,
    RenderedPrompt,
    JinjaPrompt,
)
from sycamore.llms.prompts.jinja_fragments import (
    J_ELEMENT_BATCHED_LIST,
    J_ELEMENT_BATCHED_LIST_WITH_METADATA,
)
from sycamore.plan_nodes import Node
from sycamore.transforms.base import CompositeTransform, BaseMapTransform
from sycamore.transforms.base_llm import LLMMap
from sycamore.transforms.map import Map
from sycamore.utils.time_trace import timetrace
from sycamore.functions.tokenizer import Tokenizer
from sycamore.transforms.similarity import SimilarityScorer
from sycamore.utils.similarity import make_element_sorter_fn
from sycamore.utils.llm_utils import merge_elements


def element_list_formatter(elements: list[Element], field: str = "text_representation") -> str:
    """Render elements as a numbered, newline-terminated listing.

    Each element contributes one line of the form ``ELEMENT <n>: <value>``,
    where ``<n>`` is the 1-based position in ``elements`` and ``<value>`` is
    ``str(element.field_to_value(field))``.

    Args:
        elements: The elements to list, in the order they should be numbered.
        field: The field name handed to each element's ``field_to_value``.

    Returns:
        The concatenated listing; an empty string when ``elements`` is empty.
    """
    # Build all lines lazily and join once instead of growing a string with +=.
    return "".join(
        f"ELEMENT {position}: {str(element.field_to_value(field))}\n"
        for position, element in enumerate(elements, start=1)
    )


class FieldToValuePrompt(SycamorePrompt):
    """Prompt whose message templates are filled with one document field.

    Every message's ``content`` is treated as a ``str.format`` template and is
    rendered with a single keyword argument, ``value``, holding the result of
    ``doc.field_to_value(field)``.

    Args:
        messages: Template messages; their ``role`` is copied as-is and their
            ``content`` is formatted with ``value=...``.
        field: The field name looked up on the document at render time.
    """

    def __init__(self, messages: list[RenderedMessage], field: str):
        self.field = field
        self.messages = messages

    def render_document(self, doc: Document) -> RenderedPrompt:
        """Return a RenderedPrompt with ``value`` substituted into each message."""
        # Look the field up once; every message receives the same value.
        field_value = doc.field_to_value(self.field)
        filled = [
            RenderedMessage(role=template.role, content=template.content.format(value=field_value))
            for template in self.messages
        ]
        return RenderedPrompt(messages=filled)


class EntityExtractor(ABC):
    """Abstract base for objects that extract a named entity from documents.

    Subclasses must implement both ``as_llm_map`` and ``extract_entity``.

    Args:
        entity_name: Name of the entity; stored and reported by ``property()``.
    """

    def __init__(self, entity_name: str):
        self._entity_name = entity_name

    @abstractmethod
    def as_llm_map(
        self,
        child: Optional[Node],
        context: Optional[Context] = None,
        llm: Optional[LLM] = None,
        **kwargs,
    ) -> Node:
        """Return a plan Node that performs this extraction on top of ``child``.

        The base implementation does nothing and returns None.
        """

    @abstractmethod
    def extract_entity(
        self,
        document: Document,
        context: Optional[Context] = None,
        llm: Optional[LLM] = None,
    ) -> Document:
        """Extract the entity from a single ``document`` and return the document.

        The base implementation does nothing and returns None.
        """

    def property(self):
        """The name of the property added by calling extract_entity"""
        return self._entity_name


class OpenAIEntityExtractor(EntityExtractor):
    """
    OpenAIEntityExtractor uses one of OpenAI's language models (LLMs) for entity extraction.

    This class inherits from EntityExtractor and is designed for extracting a specific entity from a document using
    OpenAI's language model. It can use either zero-shot prompting or few-shot prompting to extract the entity.
    The extracted entities from the input document are put into the document properties.

    Args:
        entity_name: The name of the entity to be extracted.
        entity_type: Optional type name for the extracted value. In the ``as_llm_map`` pipeline, ``"int"`` and
            ``"float"`` cause the extracted value to be converted with that constructor (``None`` if the
            conversion fails); other values are ignored.
        llm: An instance of an OpenAI language model for text processing.
        prompt_template: A template for constructing prompts for few-shot prompting. Default is None.
        num_of_elements: The number of elements to consider for entity extraction. Default is 10.
        prompt_formatter: A callable function to format prompts based on document elements.
        use_elements: When falsy, the prompt is built from the document-level value of ``field`` instead of
            from elements; ``prompt`` must then be specified.
        prompt: A custom prompt: a string, a list of message dicts with ``"role"`` and ``"content"`` keys, or
            a SycamorePrompt. A SycamorePrompt is only usable through ``as_llm_map``.
        field: The field read from elements/documents. Default is "text_representation".
        max_tokens: Token limit used to split elements into batches when ``tokenizer`` is set. Default is 512.
        tokenizer: Optional tokenizer; when set, elements are grouped into token-limited batches.
        similarity_query: Optional query; used together with ``similarity_scorer`` to order elements.
        similarity_scorer: Optional scorer; used together with ``similarity_query`` to order elements.

    Example:
        .. code-block:: python

            title_context_template = "template"

            openai_llm = OpenAI(OpenAIModels.GPT_3_5_TURBO.value)
            entity_extractor = OpenAIEntityExtractor("title", llm=openai_llm, prompt_template=title_context_template)

            context = sycamore.init()
            pdf_docset = context.read.binary(paths, binary_format="pdf")
                .partition(partitioner=ArynPartitioner())
                .extract_entity(entity_extractor=entity_extractor)

    """

    def __init__(
        self,
        entity_name: str,
        entity_type: Optional[str] = None,
        llm: Optional[LLM] = None,
        prompt_template: Optional[str] = None,
        num_of_elements: int = 10,
        prompt_formatter: Callable[[list[Element], str], str] = element_list_formatter,
        use_elements: Optional[bool] = True,
        prompt: Optional[Union[list[dict], str, SycamorePrompt]] = None,
        field: str = "text_representation",
        max_tokens: int = 512,
        tokenizer: Optional[Tokenizer] = None,
        similarity_query: Optional[str] = None,
        similarity_scorer: Optional[SimilarityScorer] = None,
    ):
        super().__init__(entity_name)
        self._entity_name = entity_name
        self._entity_type = entity_type
        self._llm = llm
        self._num_of_elements = num_of_elements
        self._prompt_template = prompt_template
        self._prompt_formatter = prompt_formatter
        self._use_elements = use_elements
        self._prompt = prompt
        self._field = field
        self._max_tokens = max_tokens
        self._tokenizer = tokenizer
        self._similarity_query = similarity_query
        self._similarity_scorer = similarity_scorer

    def _get_const_variables(self) -> dict[str, str]:
        """Return the property/variable names derived from the field and entity name."""
        # These kept popping up in various places across the transforms
        return {
            "similarity_field_name": f"{self._field}_similarity_score",
            "source_idx_key": f"{self._entity_name}_source_indices",
            "batch_key": f"{self._entity_name}_batches",
            "iteration_var_name": f"{self._entity_name}_i",
        }

    def _get_prompt(self) -> SycamorePrompt:
        """Build the SycamorePrompt used by the ``as_llm_map`` pipeline.

        Selection order: an explicit ``prompt`` (SycamorePrompt, string, or message list), then the
        few-shot default when ``prompt_template`` is set, otherwise the zero-shot default.

        Raises:
            ValueError: If ``use_elements`` is falsy and no ``prompt`` was given.
        """
        # there's like a million paths to cover but I think I have
        # them all
        consts = self._get_const_variables()

        # Pick the Jinja fragment that renders the document content into the user message.
        if self._prompt_formatter is not element_list_formatter:
            # NOTE(review): `formatter` is not among the parameters passed below; presumably it has to
            # be made available to the template elsewhere - confirm.
            j_elements = "{{ formatter(doc.elements) }}"
        elif self._tokenizer is not None:
            j_elements = J_ELEMENT_BATCHED_LIST_WITH_METADATA
        else:
            j_elements = J_ELEMENT_BATCHED_LIST
        if not self._use_elements:
            if self._prompt is None:
                raise ValueError("prompt must be specified if use_elements is False")
            j_elements = "{{ doc.field_to_value(field) }}"

        common_params = {
            "field": self._field,
            "num_elements": self._num_of_elements,
            "batch_key": consts["batch_key"],
            "iteration_var": consts["iteration_var_name"],
            "entity": self._entity_name,
            "use_elements": self._use_elements,
        }

        if self._prompt is not None:
            if isinstance(self._prompt, SycamorePrompt):
                return self._prompt.fork(**common_params)
            if isinstance(self._prompt, str):
                return JinjaPrompt(
                    system=None,
                    user=self._prompt + "\n" + j_elements,
                    response_format=None,
                    **common_params,
                )
            else:
                # Message list: a leading "system" message becomes the system prompt; every other
                # message's content becomes a user message, followed by the content fragment.
                system = None
                if len(self._prompt) > 0 and self._prompt[0]["role"] == "system":
                    system = self._prompt[0]["content"]
                    user = [p["content"] for p in self._prompt[1:]] + [j_elements]
                else:
                    user = [p["content"] for p in self._prompt] + [j_elements]
                return JinjaPrompt(system=system, user=user, response_format=None, **common_params)
        elif self._prompt_template is not None:
            from sycamore.llms.prompts.default_prompts import EntityExtractorFewShotJinjaPrompt

            return EntityExtractorFewShotJinjaPrompt.fork(examples=self._prompt_template, **common_params)
        else:
            from sycamore.llms.prompts.default_prompts import EntityExtractorZeroShotJinjaPrompt

            return EntityExtractorZeroShotJinjaPrompt.fork(**common_params)

    def _make_preprocess_fn(self, prompt: SycamorePrompt) -> Callable[[Document], Document]:
        """Return a function that orders a document's elements and groups them into batches.

        Args:
            prompt: The prompt that will be rendered to count tokens when a tokenizer is configured.

        Returns:
            A ``Document -> Document`` function that stores the batches (lists of element indices) in
            ``doc.properties["<entity>_batches"]`` and tags batched elements with
            ``properties["<entity>_source_indices"]``.
        """
        consts = self._get_const_variables()

        def sort_and_batch_elements(doc: Document) -> Document:
            if self._similarity_query is not None and self._similarity_scorer is not None:
                # If we did similarity scoring sort the elements (keep track of their original
                # locations though)
                elements = sorted(
                    [(e, i) for i, e in enumerate(doc.elements)],
                    key=(lambda e_i: e_i[0].properties.get(consts["similarity_field_name"], float("-inf"))),
                    reverse=True,
                )
            else:
                elements = [(e, i) for i, e in enumerate(doc.elements)]

            batches = []
            if self._tokenizer is not None:
                curr_club = []
                # We'll create a dummy document and consecutively
                # add more elements to it, rendering out to a prompt
                # at each step and counting tokens to find breakpoints.
                dummy = doc.copy()
                dummy.properties = doc.properties.copy()
                dummy.properties[consts["iteration_var_name"]] = 0
                dummy.elements = []
                for e, i in elements:
                    dummy.elements.append(e)
                    curr_club.append(i)
                    dummy.properties[consts["batch_key"]] = [curr_club]
                    rendered = prompt.render_document(dummy)
                    tks = rendered.token_count(self._tokenizer)
                    if tks > self._max_tokens:
                        # Over budget: close the batch without this element and start a new
                        # batch that contains only this element.
                        curr_club.pop()
                        if len(curr_club) > 0:
                            batches.append(curr_club)
                        curr_club = [i]
                        e.properties[consts["source_idx_key"]] = curr_club
                        # NOTE(review): dummy.elements is not reset here (the reset below is
                        # disabled), so it keeps growing across batches.
                        # dummy.elements = [e]
                    else:
                        # The element stores a reference to the (still growing) batch list.
                        e.properties[consts["source_idx_key"]] = curr_club
                if len(curr_club) > 0:
                    batches.append(curr_club)
            else:
                # If no tokenizer, we run a single batch with the first num_of_elements.
                batches = [[i for e, i in elements[: self._num_of_elements]]]
                for i in batches[0]:
                    doc.elements[i].properties[consts["source_idx_key"]] = batches[0]

            doc.properties[consts["batch_key"]] = batches
            return doc

        return sort_and_batch_elements

    @context_params(OperationTypes.INFORMATION_EXTRACTOR)
    def as_llm_map(
        self,
        child: Optional[Node],
        context: Optional[Context] = None,
        llm: Optional[LLM] = None,
        **kwargs,
    ) -> Node:
        """Represent this extractor as a CompositeTransform built around an LLMMap.

        Args:
            child: The upstream node the sub-pipeline is attached to.
            context: Not used in this method body.
            llm: LLM to use; falls back to the LLM given at construction time.
            **kwargs: Forwarded to ``LLMMap``.

        Returns:
            A CompositeTransform whose nodes are, in order: an optional ScoreSimilarity, the
            preprocessing Map, the LLMMap, and the postprocessing Map.

        Raises:
            AssertionError: If no LLM is available.
            ValueError: Propagated from ``_get_prompt`` when ``use_elements`` is falsy and no
                prompt was given.
        """
        # represent this EntityExtractor as a CompositeTransform consisting of some
        # preprocessing (set up batches, sort elements, etc), the central LLMMap,
        # and some postprocessing (derive the source_indices property)
        if llm is None:
            llm = self._llm
        assert llm is not None, "Could not find an LLM to use"
        prompt = self._get_prompt()
        preprocess = self._make_preprocess_fn(prompt)
        consts = self._get_const_variables()

        def validate(d: Document) -> bool:
            # Without a tokenizer there is only one batch, so any answer is accepted. Otherwise a
            # missing or "None" answer is rejected (presumably so LLMMap tries the next batch).
            return self._tokenizer is None or d.properties.get(self._entity_name, "None") != "None"

        def postprocess(d: Document) -> Document:
            target_club_idx = d.properties[consts["iteration_var_name"]]
            if target_club_idx >= len(d.properties[consts["batch_key"]]):
                # The iteration index points past the last batch; leave the document untouched.
                return d
            batch = d.properties[consts["batch_key"]][target_club_idx]
            d.properties[consts["source_idx_key"]] = batch
            if d.properties[self._entity_name] == "None":
                d.properties[self._entity_name] = None
            elif self._entity_type is not None and self._entity_type in [
                "int",
                "float",
            ]:
                try:
                    conversion_func = {"int": int, "float": float}[self._entity_type]
                    d.properties[self._entity_name] = conversion_func(d.properties[self._entity_name])
                except (ValueError, TypeError):
                    # ValueError: unparsable text; TypeError: non-string/non-numeric value such as
                    # None. Either way the value cannot be represented in the requested type.
                    d.properties[self._entity_name] = None
            return d

        nodes: list[BaseMapTransform] = []
        head_node: Node
        if self._similarity_query is not None and self._similarity_scorer is not None:
            # If similarity we add a ScoreSimilarity node to the sub-pipeline
            from sycamore.transforms.similarity import ScoreSimilarity

            head_node = ScoreSimilarity(
                child,  # type: ignore
                similarity_scorer=self._similarity_scorer,
                query=self._similarity_query,
                score_property_name=consts["similarity_field_name"],
            )
            nodes.append(head_node)
        else:
            head_node = child  # type: ignore

        head_node = Map(head_node, f=preprocess)
        nodes.append(head_node)
        head_node = LLMMap(
            head_node,
            prompt,
            self._entity_name,
            llm,
            validate=validate,
            iteration_var=consts["iteration_var_name"],
            max_tries=100,
            **kwargs,
        )
        nodes.append(head_node)
        head_node = Map(head_node, f=postprocess)
        nodes.append(head_node)
        comptransform = CompositeTransform(child, [])  # type: ignore
        comptransform.nodes = nodes
        return comptransform

    @context_params(OperationTypes.INFORMATION_EXTRACTOR)
    @timetrace("OaExtract")
    def extract_entity(
        self,
        document: Document,
        context: Optional[Context] = None,
        llm: Optional[LLM] = None,
    ) -> Document:
        """Extract the entity from ``document`` and store it in ``document.properties``.

        Args:
            document: The document to process; it is modified in place and returned.
            context: Not used in this method body.
            llm: LLM to use; when truthy it replaces the extractor's stored LLM.

        Returns:
            The same document with ``properties[<entity_name>]`` set. When a tokenizer is
            configured, ``properties["<entity_name>_source_element_index"]`` is set as well.

        Raises:
            ValueError: If ``use_elements`` is falsy and no prompt was given.
        """
        self._llm = llm or self._llm

        if self._use_elements:
            element_sorter = make_element_sorter_fn(self._field, self._similarity_query, self._similarity_scorer)
            element_sorter(document)
            if self._tokenizer is not None:
                entities, window_indices = self._handle_element_chunking(document)
                document.properties[f"{self._entity_name}_source_element_index"] = window_indices
            else:
                entities = self._handle_element_prompting(document)
        else:
            if self._prompt is None:
                # ValueError (a subclass of Exception) matches what _get_prompt raises for the
                # same condition.
                raise ValueError("prompt must be specified if use_elements is False")
            entities = self._handle_document_field_prompting(document)

        document.properties.update({f"{self._entity_name}": entities})

        return document

    def _handle_element_prompting(self, document: Document) -> Any:
        """Prompt the LLM once with the first ``num_of_elements`` elements of the document.

        Uses the default few-shot/zero-shot guidance prompts when no custom prompt is set,
        otherwise delegates to ``_get_entities``.
        """
        assert self._llm is not None
        sub_elements = [document.elements[i] for i in range((min(self._num_of_elements, len(document.elements))))]
        content = self._prompt_formatter(sub_elements, self._field)
        if self._prompt is None:
            prompt: Any = None
            if self._prompt_template:
                from sycamore.llms.prompts.default_prompts import _EntityExtractorFewShotGuidancePrompt

                prompt = _EntityExtractorFewShotGuidancePrompt()
            else:
                from sycamore.llms.prompts.default_prompts import _EntityExtractorZeroShotGuidancePrompt

                prompt = _EntityExtractorZeroShotGuidancePrompt()
            entities = self._llm.generate_old(
                prompt_kwargs={
                    "prompt": prompt,
                    "entity": self._entity_name,
                    "query": content,
                    "examples": self._prompt_template,
                }
            )
            return entities
        else:
            return self._get_entities(content)

    def _handle_element_chunking(self, document: Document) -> Any:
        """Prompt the LLM over successive windows of elements until one yields an answer.

        Returns:
            ``(entity, window_indices)`` for the first window whose answer is not ``"None"``, or
            ``("None", None)`` if no window produced an answer.
        """
        assert self._tokenizer is not None
        ind = 0
        while ind < len(document.elements):
            ind, combined_text, window_indices = merge_elements(
                ind, document.elements, self._field, self._tokenizer, self._max_tokens
            )
            entity = self._get_entities(combined_text)
            # Tag the len(window_indices) elements immediately before the updated `ind` with the
            # answer and the window they were queried with.
            for i in range(0, len(window_indices)):
                document.elements[ind - i - 1]["properties"][f"{self._entity_name}"] = entity
                document.elements[ind - i - 1]["properties"][
                    f"{self._entity_name}_source_element_index"
                ] = window_indices
            if entity != "None":
                return entity, window_indices
        return "None", None

    def _handle_document_field_prompting(self, document: Document) -> Any:
        """Prompt the LLM with the string value of ``field`` taken from the document itself."""
        assert self._llm is not None

        if self._field is None:
            self._field = "text_representation"

        value = str(document.field_to_value(self._field))

        return self._get_entities(value)

    def _get_entities(self, content: str, prompt: Optional[Union[list[dict], str]] = None):
        """Send ``content`` to the LLM using a string or message-list prompt.

        Args:
            content: The text to extract from. It is appended to a string prompt, or added as a
                final ``user`` message to a message-list prompt.
            prompt: Optional override; when falsy, the extractor's own prompt is used.

        Returns:
            The response returned by ``generate_old``.

        Raises:
            AssertionError: If no LLM is set, if the extractor's prompt is a SycamorePrompt, or if
                no prompt is available at all.
        """
        assert self._llm is not None
        assert not isinstance(
            self._prompt, SycamorePrompt
        ), f"cannot use old extract_entity interface with a SycamorePrompt: {self._prompt}"
        prompt = prompt or self._prompt
        assert prompt is not None, "No prompt found for entity extraction"
        # Branch on the resolved prompt so that an explicit `prompt` argument is honoured rather
        # than being silently replaced by self._prompt.
        if isinstance(prompt, str):
            response = self._llm.generate_old(prompt_kwargs={"prompt": prompt + content}, llm_kwargs={})
        else:
            messages = list(prompt) + [{"role": "user", "content": content}]
            response = self._llm.generate_old(prompt_kwargs={"messages": messages}, llm_kwargs={})
        return response
class ExtractEntity(Map):
    """
    ExtractEntity is a transformation class for extracting entities from a dataset using an EntityExtractor.

    The Extract Entity Transform extracts semantically meaningful information from your documents. These extracted
    entities are then incorporated as properties into the document structure.

    Args:
        child: The source node or component that provides the dataset containing text data.
        entity_extractor: An instance of an EntityExtractor class that defines the entity extraction method to be
            applied.
        context: Optional Context handed to the extractor's ``extract_entity`` as a keyword argument.
        resource_args: Additional resource-related arguments that can be passed to the extraction operation.

    Example:
        .. code-block:: python

            source_node = ...  # Define a source node or component that provides a dataset with text data.
            custom_entity_extractor = MyEntityExtractor(entity_extraction_params)
            extraction_transform = ExtractEntity(child=source_node, entity_extractor=custom_entity_extractor)
            extracted_entities_dataset = extraction_transform.execute()

    """

    def __init__(
        self,
        child: Node,
        entity_extractor: EntityExtractor,
        context: Optional[Context] = None,
        **resource_args,
    ):
        # The mapped function is the extractor's bound extract_entity; the context is passed to
        # it through Map's `kwargs` argument.
        extract_fn = entity_extractor.extract_entity
        call_kwargs = {"context": context}
        super().__init__(child, f=extract_fn, kwargs=call_kwargs, **resource_args)