Source code for sycamore.transforms.extract_schema
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Union, List
import json
import sycamore
import logging
from sycamore import ExecMode
from sycamore.data import Element, Document
from sycamore.schema import SchemaV2 as Schema, NamedProperty
from sycamore.llms import LLM
from sycamore.llms.prompts.default_prompts import (
PropertiesZeroShotJinjaPrompt,
PropertiesFromSchemaJinjaPrompt,
SchemaZeroShotJinjaPrompt,
)
from sycamore.llms.prompts import SycamorePrompt
from sycamore.plan_nodes import Node
from sycamore.transforms.base import CompositeTransform
from sycamore.transforms.map import Map
from sycamore.transforms.base_llm import LLMMap
from sycamore.transforms.property_extraction.prompts import format_schema_v2
from sycamore.utils.extract_json import extract_json
from sycamore.utils.time_trace import timetrace
from sycamore.transforms.embed import Embedder
from sycamore.llms.prompts.default_prompts import MetadataExtractorJinjaPrompt
import math
def _named_prop_to_dict(named_prop: NamedProperty) -> dict[str, Any]:
    """Flatten a named schema property into a plain dictionary.

    Args:
        named_prop: Property whose ``name`` and ``type`` attributes are read.

    Returns:
        A dict with the keys ``name``, ``type``, ``description``, ``default`` and
        ``examples``; everything except ``name`` is read from ``named_prop.type``.
    """
    flattened: dict[str, Any] = {"name": named_prop.name}
    # The remaining keys mirror attributes of the nested type object one-to-one.
    for attribute in ("type", "description", "default", "examples"):
        flattened[attribute] = getattr(named_prop.type, attribute)
    return flattened
def cluster_schema_json(schema: Schema, cluster_size: int, embedder: Optional[Embedder] = None) -> List[Document]:
    """Group the properties of ``schema`` into clusters based on their embeddings.

    Every property is rendered as a short "Field / Description" text, embedded in a
    local-mode Sycamore context, clustered with k-means, and collected into one
    Document per cluster label.

    Args:
        schema: Schema whose ``properties`` are clustered.
        cluster_size: Passed to k-means as ``K``. When falsy, the rounded square
            root of the number of properties is used instead.
        embedder: Forwarded unchanged to ``.embed(...)``.
            NOTE(review): the default is None and no fallback is applied here;
            confirm that callers always supply an embedder.

    Returns:
        One Document per cluster label encountered, in first-seen order. Each
        Document's ``elements`` holds an Element built from a clustered field document.
    """
    field_docs: List[Document] = [
        Document(
            text_representation=f"Field: {prop.name}\nDescription: {prop.type.description or ''}",
            **_named_prop_to_dict(prop),
        )
        for prop in schema.properties
    ]

    context = sycamore.init(exec_mode=ExecMode.LOCAL)
    embedded = context.read.document(field_docs).embed(embedder)
    num_clusters = cluster_size or round(math.sqrt(len(schema.properties)))
    centroids = embedded.kmeans(K=num_clusters, iterations=40)
    clustered_docs = embedded.clustering(centroids, cluster_field_name="cluster").take_all()

    by_cluster = {}
    for clustered in clustered_docs:
        label = clustered["cluster"]
        # Unwrap values that expose ``.item()`` so the dictionary key is a plain value.
        if hasattr(label, "item"):
            label = label.item()
        if label not in by_cluster:
            by_cluster[label] = Document()
        by_cluster[label].elements.append(Element(**clustered))
    return list(by_cluster.values())
def batch_schema_json(schema: Schema, batch_size: int) -> List[Document]:
    """Distribute the properties of ``schema`` round-robin over ``batch_size`` Documents.

    Despite its name, ``batch_size`` is the number of batches that are created:
    property ``i`` is placed in batch ``i % batch_size``.

    Args:
        schema: Schema whose ``properties`` are distributed.
        batch_size: Number of Documents to create.

    Returns:
        A list of ``batch_size`` Documents (empty for a non-positive ``batch_size``).
        Each Document's ``elements`` holds one Element per property assigned to it;
        batches that received no property stay empty.

    Raises:
        ValueError: If the schema has properties but ``batch_size`` is less than 1.
    """
    # Count and index the same attribute; the original measured ``schema.fields``
    # but indexed ``schema.properties``.
    properties = schema.properties
    if len(properties) > 0 and batch_size < 1:
        raise ValueError(f"batch_size must be at least 1 to batch schema properties, got {batch_size}")

    batches = [Document() for _ in range(batch_size)]
    for index, named_prop in enumerate(properties):
        batches[index % batch_size].elements.append(Element(**_named_prop_to_dict(named_prop)))
    return batches
def element_list_formatter(elements: list[Element]) -> str:
    """Render elements as numbered ``ELEMENT <n>: <text>`` lines.

    Args:
        elements: Elements whose ``text_representation`` is rendered.

    Returns:
        One newline-terminated line per element, numbered from 1, concatenated
        into a single string. An empty list yields an empty string.
    """
    # enumerate + join replaces the index loop and the repeated string concatenation.
    return "".join(
        f"ELEMENT {position}: {element.text_representation}\n"
        for position, element in enumerate(elements, start=1)
    )
class SchemaExtractor(ABC):
    """Abstract base class for schema extractors.

    Args:
        entity_name: Name of the entity; stored on the instance as ``_entity_name``.
    """

    def __init__(self, entity_name: str):
        self._entity_name = entity_name

    @abstractmethod
    def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
        """Return the plan node that applies this extractor on top of ``child``."""

    @abstractmethod
    def extract_schema(self, document: Document) -> Document:
        """Return the Document produced by running schema extraction on ``document``."""
class PropertyExtractor(ABC):
    """Abstract base class for property extractors."""

    def __init__(self):
        """Initialize the extractor; the base class keeps no state."""

    @abstractmethod
    def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
        """Return the plan node that applies this extractor on top of ``child``."""
class LLMSchemaExtractor(SchemaExtractor):
    """
    The LLMSchemaExtractor uses the specified LLM object to extract a schema.

    Args:
        entity_name: A natural-language name of the class to be extracted (e.g. `Corporation`)
        llm: An instance of an LLM for text processing.
        num_of_elements: The number of elements to consider for schema extraction. Default is 35.
        max_num_properties: Forwarded to the schema prompt as ``max_num_properties``. Default is 7.
        prompt_formatter: A callable function to format prompts based on document elements.
            It is forwarded to the prompt only when it differs from ``element_list_formatter``.

    Example:
        .. code-block:: python

            openai = OpenAI(OpenAIModels.GPT_3_5_TURBO.value)
            schema_extractor=LLMSchemaExtractor("Corporation", llm=openai, num_of_elements=35)

            context = sycamore.init()
            pdf_docset = context.read.binary(paths, binary_format="pdf")
                .partition(partitioner=ArynPartitioner())
                .extract_schema(schema_extractor=schema_extractor)
    """

    def __init__(
        self,
        entity_name: str,
        llm: LLM,
        num_of_elements: int = 35,
        max_num_properties: int = 7,
        prompt_formatter: Callable[[list[Element]], str] = element_list_formatter,
    ):
        super().__init__(entity_name)
        self._llm = llm
        self._num_of_elements = num_of_elements
        self._max_num_properties = max_num_properties
        self._prompt_formatter = prompt_formatter

    def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
        """Build the LLM call plus JSON post-processing as a CompositeTransform over ``child``.

        The LLM response is written to the ``_schema`` property; the post-processing
        step replaces it with the parsed JSON (or leaves the raw value when parsing
        fails) and records the entity name under ``_schema_class``.
        ``kwargs`` is accepted for interface compatibility and is not used here.
        """
        prompt = SchemaZeroShotJinjaPrompt.fork(
            entity=self._entity_name,
            max_num_properties=self._max_num_properties,
            num_elements=self._num_of_elements,
            field="text_representation",
        )
        uses_custom_formatter = self._prompt_formatter is not element_list_formatter
        if uses_custom_formatter:
            prompt = prompt.fork(prompt_formatter=self._prompt_formatter)

        def parse_json(doc: Document) -> Document:
            raw_schema = doc.properties.get("_schema", "{}")
            try:
                parsed_schema = extract_json(raw_schema)
            except (json.JSONDecodeError, AttributeError, ValueError):
                # Keep whatever the LLM returned when it cannot be parsed as JSON.
                parsed_schema = raw_schema
            doc.properties["_schema"] = parsed_schema
            doc.properties["_schema_class"] = self._entity_name
            return doc

        llm_map = LLMMap(child, prompt=prompt, output_field="_schema", llm=self._llm)
        json_map = Map(llm_map, f=parse_json)
        # The node list is assigned after construction, with the prebuilt nodes.
        comptransform = CompositeTransform(child, [])  # type: ignore
        comptransform.nodes = [llm_map, json_map]
        return comptransform

    @timetrace("ExtrSchema")
    def extract_schema(self, document: Document) -> Document:
        """Run schema extraction on a single document and return the processed document."""
        pipeline = self.as_llm_map(None)
        assert isinstance(pipeline, CompositeTransform)
        processed = pipeline._local_process([document])
        return processed[0]
[docs]
class OpenAISchemaExtractor(LLMSchemaExtractor):
    """Backward-compatible alias of LLMSchemaExtractor for OpenAI models.

    The class adds no behavior of its own.

    .. deprecated:: 0.1.25
        Use LLMSchemaExtractor instead.
    """
[docs]
class LLMPropertyExtractor(PropertyExtractor):
    """
    The LLMPropertyExtractor uses an LLM to extract actual property values once
    a schema has been detected or provided.

    Args:
        llm: An instance of an LLM for text processing.
        schema_name: An optional natural-language name of the class to be extracted (e.g. `Corporation`)
            If not provided, will use the _schema_class property added by extract_schema.
        schema: An optional JSON-encoded schema, or Schema object to be used for property extraction.
            If not provided, will use the _schema property added by extract_schema.
        num_of_elements: The number of elements to consider for property extraction. Default is None,
            in which case no ``num_elements`` override is applied to the prompt.
        prompt_formatter: A callable function to format prompts based on document elements.
        metadata_extraction: When True, the schema (which must be a Schema object) is split into groups
            of properties and one LLM call is chained per group; values are merged under the schema
            name (or ``_entity``) and the second item of each LLM answer under ``<name>_metadata``.
        embedder: Embedder used to cluster schema properties when ``metadata_extraction`` and
            ``clustering`` are both enabled.
        group_size: Number of property groups for metadata extraction. When not set, the rounded
            square root of the number of schema properties is used.
        clustering: When True (default) properties are grouped by embedding clusters, otherwise they
            are distributed round-robin.

    Example:
        .. code-block:: python

            schema_name = "AircraftIncident"
            schema = {"location": "string", "aircraft": "string", "date_and_time": "string"}
            openai_llm = OpenAI(OpenAIModels.GPT_3_5_TURBO.value)

            property_extractor = LLMPropertyExtractor(
                llm=openai_llm, schema_name=schema_name, schema=schema, num_of_elements=35
            )

            docs_with_schema = ...
            docs_with_schema = docs_with_schema.extract_properties(property_extractor=property_extractor)
    """

    def __init__(
        self,
        llm: LLM,
        schema_name: Optional[str] = None,
        schema: Optional[Union[dict, Schema]] = None,
        num_of_elements: Optional[int] = None,
        prompt_formatter: Callable[[list[Element]], str] = element_list_formatter,
        metadata_extraction: bool = False,
        embedder: Optional[Embedder] = None,
        group_size: Optional[int] = None,
        clustering: bool = True,
    ):
        super().__init__()
        self._llm = llm
        self._schema_name = schema_name
        self._schema = schema
        self._num_of_elements = num_of_elements
        self._metadata_extraction = metadata_extraction
        self._prompt_formatter = prompt_formatter
        self._group_size = group_size
        self._embedder = embedder
        self._clustering = clustering

    def extract_docs(self, docs: list[Document]) -> list[Document]:
        """Run property extraction eagerly on ``docs`` and return the processed documents.

        The plan built by ``as_llm_map(None)`` is a Map on top of one LLMMap, or, in
        metadata-extraction mode, a chain of LLMMaps (one per property group). Every
        LLMMap in the chain is run, innermost first, before the final Map is applied.
        """
        jsonextract_node = self.as_llm_map(None)
        assert isinstance(jsonextract_node, Map)
        assert len(jsonextract_node.children) == 1

        # Collect the whole LLMMap chain; running only the outermost node would
        # skip every property group except the last one in metadata mode.
        llm_nodes: list = []
        node = jsonextract_node.children[0]
        while isinstance(node, LLMMap):
            llm_nodes.append(node)
            node = node.children[0] if node.children else None

        processed = docs
        for llm_node in reversed(llm_nodes):
            processed = llm_node.run(processed)
        return [jsonextract_node.run(d) for d in processed]

    @staticmethod
    def _to_bool(value: Any) -> bool:
        """Convert ``value`` to bool, parsing common textual spellings.

        ``bool("false")`` is True because any non-empty string is truthy, so textual
        answers are matched explicitly; everything else falls back to ``bool(value)``.
        """
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized in ("true", "yes", "1"):
                return True
            if normalized in ("false", "no", "0"):
                return False
        return bool(value)

    def cast_types(self, fields: dict) -> dict:
        """Cast extracted values to the types declared in the schema.

        Args:
            fields: Mapping of property name to extracted value.

        Returns:
            A new dict with one entry per schema property (the schema default, or None,
            when the value is missing), followed by any extra keys of ``fields`` that
            the schema does not define, copied unchanged.

        Raises:
            AssertionError: If the configured schema is not a Schema object.
        """
        import dateparser  # type: ignore # No type stubs available for 'dateparser'; ignoring for mypy

        assert self._schema is not None, "Schema must be provided for property standardization."
        assert isinstance(self._schema, Schema), "Schema object must be provided for property standardization."
        result: dict = {}

        type_cast_functions: dict[str, Callable] = {
            "int": int,
            "float": float,
            "string": str,
            "bool": self._to_bool,
            "date": lambda x: dateparser.parse(x),
            "datetime": lambda x: dateparser.parse(x),
            "array": list,  # TODO: Handle array types properly
        }

        for field in self._schema.properties:
            value = fields.get(field.name)
            if value is None:
                # Covers both "no default" (None) and an explicit schema default.
                result[field.name] = field.type.default
            else:
                # Unknown type names leave the value untouched.
                result[field.name] = type_cast_functions.get(field.type.type, lambda x: x)(value)

        # Include additional fields not defined in the schema
        for key, value in fields.items():
            if key not in result:
                result[key] = value

        return result

    def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
        """Build the plan nodes that perform property extraction on top of ``child``.

        In metadata-extraction mode one LLMMap is chained per property group and a final
        Map merges the per-group answers. Otherwise a single LLMMap is followed by a Map
        that parses the answer and, for Schema objects, casts the values.
        ``kwargs`` is forwarded to every LLMMap.
        """
        prompt: SycamorePrompt  # mypy grr
        if self._metadata_extraction:
            assert isinstance(self._schema, Schema), "check format of schema passed"
            # Count ``properties``, the attribute the grouping helpers iterate over.
            self._group_size = self._group_size or round(math.sqrt(len(self._schema.properties)))
            if self._clustering:
                clusters_docs = cluster_schema_json(
                    schema=self._schema, embedder=self._embedder, cluster_size=self._group_size
                )
            else:
                clusters_docs = batch_schema_json(schema=self._schema, batch_size=self._group_size)
            tmp_props: list[str] = []
            for idx, field_doc in enumerate(clusters_docs):
                schema = {}
                schema_name = f"_tmp_cluster_{idx}"
                tmp_props.append(schema_name)
                assert isinstance(field_doc, Document), "Expected field_doc to be a Document instance"
                for field in field_doc.elements:
                    schema[field["name"]] = {
                        "description": field["description"],
                        "type": field["type"],
                        "default": field.get("default"),
                        "examples": field.get("examples"),
                    }
                prompt = MetadataExtractorJinjaPrompt.fork(
                    entity_name=schema_name,
                    response_format=schema,
                    schema=schema,
                )
                # Each group gets its own LLM call, chained on top of the previous one.
                child = LLMMap(child, prompt=prompt, output_field=schema_name, llm=self._llm, **kwargs)

            def _merge(d: Document) -> Document:
                merged_metadata: dict = {}
                merged_provenance: dict = {}

                for tmp_prop in tmp_props:
                    part = d.properties.pop(tmp_prop, "{}")
                    try:
                        # A non-string part is used as-is instead of leaving
                        # ``part_json`` unbound or stale from a previous group.
                        part_json = extract_json(part) if isinstance(part, str) else part
                    except (json.JSONDecodeError, AttributeError, ValueError):
                        logging.error(f"Failed to decode JSON for property '{tmp_prop}': {part}")
                        continue
                    if not isinstance(part_json, dict):
                        continue
                    for name, answer in part_json.items():
                        if answer:
                            # NOTE(review): assumes every truthy answer is a
                            # (value, provenance) pair as requested by the prompt - confirm.
                            merged_metadata[name] = answer[0]
                            merged_provenance[name] = answer[1]
                        else:
                            merged_metadata[name] = None

                d.properties[self._schema_name or "_entity"] = merged_metadata
                d.properties[(self._schema_name or "_entity") + "_metadata"] = merged_provenance
                return d

            return Map(child, f=_merge)

        if isinstance(self._schema, Schema):
            prompt = PropertiesFromSchemaJinjaPrompt
            prompt = prompt.fork(
                schema_string=format_schema_v2(self._schema), response_format=self._schema.model_dump()
            )
        else:
            prompt = PropertiesZeroShotJinjaPrompt
            if self._schema is not None:
                prompt = prompt.fork(schema=self._schema)

        if self._schema_name is not None:
            prompt = prompt.fork(entity=self._schema_name)
        if self._num_of_elements is not None:
            prompt = prompt.fork(num_elements=self._num_of_elements)
        if self._prompt_formatter is not element_list_formatter:
            prompt = prompt.fork(prompt_formatter=self._prompt_formatter)

        def parse_json_and_cast(d: Document) -> Document:
            entity_name = self._schema_name or "_entity"
            entitystr = d.properties.get(entity_name, "{}")
            endkey = self._schema_name or d.properties.get("_schema_class", "entity")
            try:
                entity = extract_json(entitystr)
            except (json.JSONDecodeError, AttributeError, ValueError):
                entity = entitystr
            # If LLM couldn't do extract we instructed it to say "None"
            # So handle that
            if entity == "None":
                entity = {}
            # Only cast parsed objects; an unparsed answer is stored raw, as it is
            # when no Schema object is configured, instead of failing on ``str.get``.
            if isinstance(self._schema, Schema) and isinstance(entity, dict):
                entity = self.cast_types(entity)

            # If schema name wasn't provided we wrote stuff to a
            # temp "_entity" property
            if entity_name == "_entity":
                if endkey in d.properties:
                    d.properties[endkey].update(entity)
                else:
                    d.properties[endkey] = entity
                d.properties.pop("_entity", None)
                return d

            d.properties[endkey] = entity
            return d

        llm_map = LLMMap(child, prompt, output_field=self._schema_name or "_entity", llm=self._llm, **kwargs)
        parse_map = Map(llm_map, f=parse_json_and_cast)
        return parse_map
[docs]
class ExtractSchema(Map):
    """
    ExtractSchema is a transformation class for extracting schemas from documents using a SchemaExtractor.

    This method will extract a unique schema for each document in the DocSet independently. If the documents in the
    DocSet represent instances with a common schema, consider `ExtractBatchSchema` which will extract a common
    schema for all documents.

    The dataset is returned with an additional `_schema` property that contains JSON-encoded schema, if any
    is detected.

    Args:
        child: The source node or component that provides the dataset text for schema suggestion
        schema_extractor: An instance of a SchemaExtractor class that provides the schema extraction method
        resource_args: Additional resource-related arguments that can be passed to the extraction operation

    Example:
        .. code-block:: python

            custom_schema_extractor = ExampleSchemaExtractor(entity_extraction_params)

            documents = ...  # Define a source node or component that provides a dataset with text data.
            documents_with_schema = ExtractSchema(child=documents, schema_extractor=custom_schema_extractor)
            documents_with_schema = documents_with_schema.execute()
    """

    def __init__(self, child: Node, schema_extractor: SchemaExtractor, **resource_args):
        # The extractor's bound method is the per-document map function.
        per_document_fn = schema_extractor.extract_schema
        super().__init__(child, f=per_document_fn, **resource_args)
[docs]
class OpenAIPropertyExtractor(LLMPropertyExtractor):
    """Backward-compatible alias of LLMPropertyExtractor for OpenAI models.

    The class adds no behavior of its own.

    .. deprecated:: 0.1.25
        Use LLMPropertyExtractor instead.
    """
[docs]
class ExtractBatchSchema(Map):
    """
    ExtractBatchSchema is a transformation class for extracting a schema from a dataset using a SchemaExtractor.
    This assumes all documents in the dataset share a common schema.

    If it is more appropriate to provide a unique schema for each document (such as in a heterogeneous PDF
    collection) consider using `ExtractSchema` instead.

    The dataset is returned with an additional `_schema` property that contains JSON-encoded schema, if any
    is detected. This schema will be the same for all elements of the dataset.

    Args:
        child: The source node or component that provides the dataset text for schema suggestion
        schema_extractor: An instance of a SchemaExtractor class that provides the schema extraction method
        resource_args: Additional resource-related arguments that can be passed to the extraction operation

    Example:
        .. code-block:: python

            custom_schema_extractor = ExampleSchemaExtractor(entity_extraction_params)

            documents = ...  # Define a source node or component that provides a dataset with text data.
            documents_with_schema = ExtractBatchSchema(child=documents, schema_extractor=custom_schema_extractor)
            documents_with_schema = documents_with_schema.execute()
    """

    def __init__(self, child: Node, schema_extractor: SchemaExtractor, **resource_args):
        # Must run on a single instance so that the cached calculation of the schema works
        resource_args.update(parallelism=1)
        super().__init__(
            child,
            f=ExtractBatchSchema.Extract,
            constructor_args=[schema_extractor],
            **resource_args,
        )

    class Extract:
        """Callable that extracts the schema from the first document it sees and reuses it afterwards."""

        def __init__(self, schema_extractor: SchemaExtractor):
            self._schema_extractor = schema_extractor
            # Filled on the first call with the ``_schema`` and ``_schema_class`` entries.
            self._schema: Optional[dict] = None

        def __call__(self, d: Document) -> Document:
            """Copy the cached schema properties onto ``d``, computing them on first use."""
            cached = self._schema
            if cached is None:
                extracted = self._schema_extractor.extract_schema(d).properties
                cached = {key: extracted[key] for key in ("_schema", "_schema_class")}
                self._schema = cached
            d.properties.update(cached)
            return d