from abc import ABC, abstractmethod
from typing import Any, Callable, Optional

from sycamore.data import Document
from sycamore.plan_nodes import Node, NonCPUUser, NonGPUUser
from sycamore.transforms.map import Map
class TextAugmentor(ABC):
    """
    Abstract base for objects that compute a new text representation for a
    document.

    Subclasses implement `augment_text`. An instance can also be called like
    a function, which simply delegates to `augment_text`.
    """

    @abstractmethod
    def augment_text(self, doc: Document) -> str:
        """Return the string to use as the text representation of `doc`."""

    def __call__(self, doc: Document) -> str:
        """Shorthand for `augment_text(doc)`."""
        augmented = self.augment_text(doc)
        return augmented

    def augment_text_for_map(self, doc: Document) -> Document:
        """
        Assign the result of `augment_text(doc)` to `doc.text_representation`
        and return that same document object.
        """
        new_text = self.augment_text(doc)
        doc.text_representation = new_text
        return doc
[docs]
class UDFTextAugmentor(TextAugmentor):
    """
    A TextAugmentor backed by a user-defined function (UDF): the function
    receives a document and returns the string that becomes its text.

    Args:
        fn (Callable[[Document], str]): A function that maps a document to the
            string to use as the new `text_representation`

    Example:
        .. code-block:: python

            def aug_text_fn(doc: Document) -> str:
                return " ".join([
                    f"This pertains to the part {doc.properties['part_name']}.",
                    f"{doc.text_representation}"
                ])

            augmentor = UDFTextAugmentor(aug_text_fn)
            context = sycamore.init()
            pdf_docset = (
                context.read.binary(paths, binary_format="pdf")
                .augment_text(augmentor)
            )
    """

    def __init__(self, fn: Callable[[Document], str]):
        super().__init__()
        # The UDF is stored as-is and invoked once per document.
        self._fn = fn

    def augment_text(self, doc: Document) -> str:
        """Return whatever the wrapped function produces for `doc`."""
        user_fn = self._fn
        return user_fn(doc)
[docs]
class JinjaTextAugmentor(TextAugmentor):
    """
    JinjaTextAugmentor renders a jinja2 template in a SandboxedEnvironment to
    build the new text representation, so the template can combine the
    existing text with metadata from the document.

    Args:
        template (str): A jinja2 template for the new text representation. Can contain
            references to `doc` and to any modules passed in the `modules` param
        modules (Optional[dict[str, Any]]): A mapping of module names to module
            objects, exposed to the template as globals. Defaults to None,
            meaning no extra modules are exposed.

    Example:
        .. code-block:: python

            from sycamore.transforms.augment_text import JinjaTextAugmentor
            from sycamore.transforms.regex_replace import COALESCE_WHITESPACE
            import pathlib

            template = '''This document is from {{ pathlib.Path(doc.properties['path']).name }}.
            The title is {{ doc.properties['title'] }}.
            The authors are {{ doc.properties['authors'] }}.
            {% if doc.text_representation %}
            {{ doc.text_representation }}
            {% else %}
            There is no text representation for this
            {% endif %}
            '''

            aug = JinjaTextAugmentor(template=template, modules={"pathlib": pathlib})
            aug_docset = exp_docset.augment_text(aug).regex_replace(COALESCE_WHITESPACE)
            aug_docset.show(show_binary=False, truncate_content=False)
    """

    def __init__(self, template: str, modules: Optional[dict[str, Any]] = None):
        # Imported inside the constructor rather than at module level, so the
        # jinja2 import only happens when an instance is actually created.
        from jinja2.sandbox import SandboxedEnvironment

        super().__init__()
        self._env = SandboxedEnvironment()
        # None is the default instead of a mutable `{}` literal, so instances
        # created without `modules` never share a single dict object. A dict
        # supplied by the caller is stored by reference, not copied.
        self._modules = {} if modules is None else modules
        self._template = template

    def augment_text(self, doc: Document) -> str:
        """
        Render the template with `doc` bound to the given document and return
        the rendered string.
        """
        # The template source is compiled on each call; the module mapping is
        # handed to jinja2 as the template's globals.
        compiled = self._env.from_string(source=self._template, globals=self._modules)
        return compiled.render(doc=doc)
[docs]
class AugmentText(NonCPUUser, NonGPUUser, Map):
    """
    Map-based transform that rewrites each document's text representation
    using a TextAugmentor, typically to fold metadata into the text for
    better embedding and search quality.

    Args:
        child (Node): The plan node whose output documents are augmented.
        text_augmentor (TextAugmentor): Supplies the per-document function,
            `augment_text_for_map`, that is handed to Map.
        **kwargs: Forwarded unchanged to the Map constructor.
    """

    def __init__(self, child: Node, text_augmentor: TextAugmentor, **kwargs):
        # Map receives the bound method that updates and returns each document.
        per_doc_fn = text_augmentor.augment_text_for_map
        super().__init__(child, f=per_doc_fn, **kwargs)