from abc import ABC, abstractmethod
from typing import Callable, Optional, Union, Type
import copy
import textwrap
import itertools
import logging
from sycamore.data import Element, Document
from sycamore.functions.tokenizer import Tokenizer, CharacterTokenizer
from sycamore.llms.prompts.prompts import (
SycamorePrompt,
JinjaPrompt,
)
from sycamore.plan_nodes import NonCPUUser, NonGPUUser, Node
from sycamore.llms import LLM
from sycamore.llms.llms import LLMMode
from sycamore.transforms.map import Map
from sycamore.transforms.aggregation import Aggregation
from sycamore.transforms.base import CompositeTransform, BaseMapTransform
from sycamore.transforms.base_llm import LLMMapElements, LLMMap, _infer_prompts
# TODO: Rename this to DocumentListDocument or something less stupid-looking
# and move it somewhere more generally available
class SummaryDocument(Document):
    """A Document that groups other Documents under ``data["sub_docs"]``.

    Elements cannot be supplied at construction time. Reading ``elements`` returns
    the list explicitly assigned through the setter (stored under
    ``data["_elements"]``) when there is one; otherwise it returns the
    concatenation of the elements of every sub-document.
    """

    def __init__(self, document=None, **kwargs):
        """Create the document and validate ``sub_docs``.

        Args:
            document: Forwarded unchanged to ``Document.__init__``.
            **kwargs: Forwarded to ``Document.__init__``; must not contain ``elements``.

        Raises:
            ValueError: If ``elements`` is passed, if ``sub_docs`` is not a list,
                or if any entry of ``sub_docs`` is not a ``Document``.
        """
        if "elements" in kwargs:
            raise ValueError("Cannot set elements directly in a SummaryDocument")
        super().__init__(document, **kwargs)

        sub_docs = self.data.get("sub_docs")
        if sub_docs is None:
            self.data["sub_docs"] = []
            return
        if not isinstance(sub_docs, list):
            raise ValueError(f"sub_docs must be a list of Document, found {sub_docs}")
        for sd in sub_docs:
            if not isinstance(sd, Document):
                raise ValueError(f"sub_docs must be a list of Documents. Found nonmatching {sd}")
        # Re-wrap every entry as a plain Document once the whole list has been validated.
        self.data["sub_docs"] = [Document(sd) for sd in sub_docs]

    @property
    def sub_docs(self) -> list[Document]:
        """The list of documents grouped by this document."""
        return self.data["sub_docs"]

    @sub_docs.setter
    def sub_docs(self, sub_docs: list[Document]):
        """Replace the list of grouped documents."""
        self.data["sub_docs"] = sub_docs

    @sub_docs.deleter
    def sub_docs(self) -> None:
        """Reset the grouped documents to an empty list."""
        self.data["sub_docs"] = []

    @property
    def elements(self) -> list[Element]:
        """A list of elements belonging to this document. A document does not necessarily always have
        elements, for instance, before a document is chunked."""
        # Only flatten the sub-documents when no explicit list was assigned; the
        # flattened list is rebuilt on every such access.
        if "_elements" in self.data:
            return self.data["_elements"]
        return list(itertools.chain.from_iterable(d.elements for d in self.data["sub_docs"]))

    @elements.setter
    def elements(self, elements: list[Element]):
        """Set the elements for this document."""
        self.data["_elements"] = elements

    @elements.deleter
    def elements(self) -> None:
        """Delete the elements of this document."""
        self.data.pop("_elements", None)
[docs]
class Summarizer(ABC):
    """Abstract base for summarizers that can be expressed as a plan node."""

    def summarize(self, document: Document) -> Document:
        """Summarize a single document by locally executing this summarizer's node.

        The node comes from ``as_llm_map(None)`` and is run on a one-document list;
        exactly one document is expected back and is returned.
        """
        transform = self.as_llm_map(None)
        assert isinstance(transform, (BaseMapTransform, CompositeTransform))
        results = transform.local_execute([document], drop_metadata=True)
        assert len(results) == 1, f"Found more than one Document after summmarizing just one: {results}"
        return results[0]

    @abstractmethod
    def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
        """Return the plan node that performs the summarization on top of ``child``."""
[docs]
class LLMElementTextSummarizer(Summarizer):
    """
    LLMElementTextSummarizer uses a specified LLM to summarize text data within elements of a document.

    Args:
        llm: An instance of an LLM class to use for text summarization.
        element_filter: A predicate that takes an element and returns a bool. It is forwarded
            as the ``filter`` argument of ``LLMMapElements``. Default is None, in which case a
            predicate that returns True for every element is used.

    Example:
        .. code-block:: python

            llm_model = OpenAILanguageModel("gpt-3.5-turbo")
            element_filter = my_element_filter  # A custom predicate on elements
            summarizer = LLMElementTextSummarizer(llm_model, element_filter)

            context = sycamore.init()
            pdf_docset = (
                context.read.binary(paths, binary_format="pdf")
                .partition(partitioner=ArynPartitioner())
                .summarize(summarizer=summarizer)
            )
    """

    def __init__(self, llm: LLM, element_filter: Optional[Callable[[Element], bool]] = None):
        self._llm = llm
        self._element_filter = element_filter

    def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
        """Build an ``LLMMapElements`` node that writes its output to the ``summary`` field.

        ``**kwargs`` is accepted for interface compatibility and is not forwarded.
        """
        # Imported here rather than at module level, as in the original code.
        from sycamore.llms.prompts.default_prompts import TextSummarizerJinjaPrompt

        element_filter = self._element_filter or (lambda e: True)
        return LLMMapElements(
            child,
            TextSummarizerJinjaPrompt,
            output_field="summary",
            llm=self._llm,
            filter=element_filter,
        )
[docs]
class EtCetera:
    """Marker class placed last in a list of fields.

    The class object itself (not an instance) is the sentinel. It means: after the
    explicitly listed fields, include as many further properties as fit within
    the token limit.
    """
def _partition_fields(document: Document, fields: list[Union[str, Type[EtCetera]]]) -> tuple[list[str], list[str]]:
    """
    Split a list of fields into document and element fields - any fields in the list that
    are in the document properties are document fields, and everything else is an element field.
    EtCetera turns the list into 'every field' (with a prefix if early field order matters)

    Args:
        document: Document whose ``properties`` (and, with EtCetera, whose elements'
            properties) decide where each field goes.
        fields: Dotted field names (e.g. ``properties.title``), optionally ending with
            the ``EtCetera`` sentinel.

    Returns:
        ``(doc_fields, elt_fields)``. Explicitly requested fields come first, in the order
        given. With EtCetera, the remaining document properties are appended to
        ``doc_fields``, and element properties not already covered are appended to
        ``elt_fields`` in first-seen order.
    """
    # TODO: If property values are varied between document and elements we might
    # not want to drop them from the elements.
    prefix = "properties."
    doc_fields: list[str] = []
    elt_fields: list[str] = []
    if not fields:
        return doc_fields, elt_fields

    for f in fields:
        if f is EtCetera:
            continue
        assert isinstance(f, str), "fields must be strings or the EtCetera sentinel"
        # Only a field that really carries the "properties." prefix can name a document
        # property. Slicing unconditionally would let an unrelated field match by accident.
        if f.startswith(prefix) and f[len(prefix) :] in document.properties:
            doc_fields.append(f)
        else:
            elt_fields.append(f)

    if fields[-1] is EtCetera:
        requested = set(fields)
        for name in document.properties:
            qualified = f"{prefix}{name}"
            if qualified not in requested:
                doc_fields.append(qualified)

        # Compare prefixed names with prefixed names, so that element properties already
        # covered by a document field or by an explicit request are not added again.
        # Appending in first-seen order keeps the result deterministic.
        seen = set(doc_fields) | requested
        for element in document.elements:
            for name in element.properties:
                qualified = f"{prefix}{name}"
                if qualified not in seen:
                    seen.add(qualified)
                    elt_fields.append(qualified)
    return doc_fields, elt_fields
def make_max_tokens_heirarchy_prompt() -> JinjaPrompt:
    """Build the Jinja prompt used for round-by-round summarization.

    The template reads these variables:

    * ``question`` - defaults to "What is the summary of this data?" via the
      keyword argument passed to ``JinjaPrompt`` below.
    * ``data_description`` - optional; a generic description is used when undefined.
    * ``round`` - in round 0 each element is rendered from ``elt_fields`` and its
      ``text_representation``; in later rounds each element is rendered from
      ``element.properties["summary"]``.
    * ``doc_fields`` / ``elt_fields`` - names passed to ``field_to_value`` on the
      document and on each element respectively.
    * ``element_testing`` - when defined, the system text and the instructions and
      headers in the user text are skipped, so only the element loop is rendered.
    * ``doc`` - the document whose ``elements`` are enumerated.
    """
    return JinjaPrompt(
        # System message: the role depends on whether `question` is defined, and the
        # whole message is skipped while `element_testing` is defined.
        system=textwrap.dedent("""
{%- if element_testing is not defined -%}{# element_testing means only render an element, to get token count #}
{% if question is defined %}You are a helpful research assistant. You answer questions based on
text you are presented with.
{% else %}You are a helpful data summarizer. You concisely summarize text you are presented with,
including as much detail as possible.
{% endif %}{% endif %}
"""),
        # User message: macros first, then the instructions and headers (skipped while
        # `element_testing` is defined), then one numbered entry per element.
        user=textwrap.dedent("""
{%- macro get_text_fields(element, fields) %}
{% for f in fields %}
{{ f }}: {{ element.field_to_value(f) }}
{%- endfor %}
{% endmacro -%}
{%- macro get_text_base(element) %}
{{ get_text_fields(element, elt_fields) }}
Text: {{ element.text_representation }}
{% endmacro -%}
{%- macro get_text(element) %}
{%- if round == 0 -%}
{{ get_text_base(element) }}
{%- else -%}
{{ element.properties["summary"] }}
{% endif -%}
{% endmacro -%}
{%- macro get_data_description() -%}
{%- if data_description is defined -%}
{{ data_description }}
{%- else -%}
a set of documents with properties for each document
{%- endif -%}
{%- endmacro -%}
{%- if element_testing is not defined -%}{# element_testing means only render an element, to get token count #}
{% if round == 0 -%}
You are given {{ get_data_description() }}. Please use only the information found in these elements
to determine an answer to the question "{{ question }}". If you cannot answer the question based on
the data provided, instead respond with any data that might be relevant to the question.
{% else %}
You are given a list of partial answers to the question "{{ question }}" based on {{ get_data_description() }}.
Please combine these partial answers into a coherent single answer to the question "{{ question }}".
Include the parts of the partial answers that are relevant, ignore irrelevant parts.
{%- endif %}
{% if doc_fields|count > 0 -%}
Shared Properties:
{{ get_text_fields(doc, doc_fields) }}
{%- endif %}
{% if round == 0 -%}
Elements:
{% elif round > 0 -%}
Answers:
{% endif %}
{%- endif -%}{# end of element_testing check. Stuff inside this block was constant across elements #}
{%- for e in doc.elements %}
{{ loop.index }}: {{ get_text(e) }}
{% endfor %}
"""),
        question="What is the summary of this data?",
    )
[docs]
class MultiStepDocumentSummarizer(Summarizer):
    """
    Summarizes a document by constructing a tree of summaries. Each leaf contains as many consecutive
    elements as possible within the token limit, and each vertex of the tree contains as many sub-
    summaries as possible within the token limit. e.g with max_tokens=10

    .. code-block::

        Elements: (3 tokens) - (3 tokens) - (5 tokens) - (8 tokens)
                       |            |            |            |
                  (4 token summary) - (3 token summary) - (2 token summary)
                           \\               |               /
                                   (5 token summary)

    Args:
        llm: LLM to use for summarization
        llm_mode: How to call the LLM - SYNC, ASYNC, BATCH. Async is faster but not all llms support it.
            Default is None, which uses ``llm.default_mode()``.
        question: Optional question to use as context for the summarization. If set, the llm will
            attempt to answer the question with the data provided
        data_description: Optional string describing the input documents.
        prompt: Prompt to use for each summarization. Caution: The default (built by
            ``make_max_tokens_heirarchy_prompt``) has some fairly complicated logic encoded in it
            to make the tree construction work correctly.
        fields: List of fields to include in each element's representation in the prompt. Specify
            with dotted notation (e.g. properties.title). End the list with `EtCetera` to add all fields
            (previously specified fields go first). Default is [] which includes no fields.
        tokenizer: tokenizer to use when computing how many tokens a prompt will take. Default is
            CharacterTokenizer. Its ``max_tokens`` is the token limit; 10,000 is used when that is unset.
    """

    def __init__(
        self,
        llm: LLM,
        llm_mode: Optional[LLMMode] = None,
        question: Optional[str] = None,
        data_description: Optional[str] = None,
        prompt: SycamorePrompt = make_max_tokens_heirarchy_prompt(),
        fields: list[Union[str, Type[EtCetera]]] = [],
        tokenizer: Tokenizer = CharacterTokenizer(),
    ):
        self.llm = llm
        self.llm_mode = llm_mode if llm_mode is not None else llm.default_mode()
        self.prompt = prompt
        assert EtCetera not in fields[:-1], "EtCetera must be at the end of the list of fields if provided"
        # Copy so that instances never alias the shared default list or the caller's list.
        self.fields = list(fields)
        self.question = question
        self.data_description = data_description
        self.max_tokens = tokenizer.max_tokens or 10_000
        self.tokenizer = tokenizer

    def as_llm_map(self, child: Optional[Node], **kwargs) -> Node:
        """Return a plain ``Map`` over ``summarize``; ``**kwargs`` is not forwarded."""
        # MultiStepDocumentSummarizer doesn't use LLMMap - it doesn't work very cleanly
        return Map(child, f=self.summarize)

    def summarize(self, document: Document) -> Document:
        """Summarize a document by summarizing groups of elements iteratively
        in rounds until only one element remains; that's our new summary

        The result is stored in ``document.properties["summary"]``. The intermediate
        ``summary`` properties attached to elements during the rounds are removed
        again before the document is returned.
        """
        doc_fields, elt_fields = _partition_fields(document, self.fields)
        base_prompt = self.prompt.fork(
            ignore_none=True,
            doc_fields=doc_fields,
            elt_fields=elt_fields,
            question=self.question,
            data_description=self.data_description,
        )
        etk_prompt = base_prompt.fork(element_testing=True)

        # Work on a copy, because each round reassigns the working document's `elements`.
        dummy_doc = document.copy()
        remaining_elements = dummy_doc.elements
        round_num = 0
        last_elt_len = len(remaining_elements)
        # Round 0 always runs, so a single-element document still gets summarized.
        while len(remaining_elements) > 1 or round_num == 0:
            round_prompt = base_prompt.fork(round=round_num)
            round_etk_prompt = etk_prompt.fork(round=round_num)
            remaining_elements = self.summarize_one_round(dummy_doc, remaining_elements, round_prompt, round_etk_prompt)
            # A later round that does not reduce the element count will never converge.
            if len(remaining_elements) == last_elt_len and round_num > 0:
                logging.warning("Detected likely infinite summary loop. Exiting with incomplete summary")
                break
            last_elt_len = len(remaining_elements)
            round_num += 1

        if remaining_elements:
            document.properties["summary"] = remaining_elements[0].properties["summary"]
        else:
            document.properties["summary"] = "Empty Summary Document, nothing to summarize"
        for e in document.elements:
            e.properties.pop("summary", None)
        return document

    def summarize_one_round(
        self,
        document: Document,
        elements: list[Element],
        base_prompt: SycamorePrompt,
        etk_prompt: SycamorePrompt,
    ) -> list[Element]:
        """Perform a 'round' of element summarization: Assemble batches of maximal amounts
        of elements and summarize them, attaching the resulting summaries to the first
        element of each batch and returning only those elements.

        ``document.elements`` is reassigned repeatedly while prompts are rendered.
        """
        # Token cost of the prompt with no elements at all.
        document.elements = []
        baseline_tks = base_prompt.render_document(document).token_count(self.tokenizer)

        # Batch elements and make prompts out of them
        elt_batches = self.batch_elements(baseline_tks, elements, etk_prompt, document)
        final_elements = []
        to_infer = []
        for batch in elt_batches:
            if not batch:
                # batch_elements returns one empty batch when there are no elements.
                continue
            document.elements = batch
            final_elements.append(batch[0])
            to_infer.append(base_prompt.render_document(document))

        # Invoke the llm and attach summaries
        # TODO: Use run_coros_threadsafe here instead
        summaries = _infer_prompts(prompts=to_infer, llm=self.llm, llm_mode=self.llm_mode)
        for e, s in zip(final_elements, summaries):
            e.properties["summary"] = s
        return final_elements

    def batch_elements(
        self,
        baseline_tokens: int,
        elements: list[Element],
        etk_prompt: SycamorePrompt,
        document: Document,
    ) -> list[list[Element]]:
        """Group consecutive elements into batches that stay within the token limit.

        Args:
            baseline_tokens: Token count of the prompt rendered with no elements.
            elements: Elements to batch, in order.
            etk_prompt: Prompt used to measure the token count of a single element.
            document: Working document; its ``elements`` is overwritten while measuring.

        Returns:
            Batches of consecutive elements. For each batch, ``baseline_tokens`` plus the
            per-element token counts does not exceed ``self.max_tokens``. The last batch
            is always appended, so empty input yields one empty batch.

        Raises:
            ValueError: If a single element plus the baseline exceeds the limit.
        """
        limit = self.max_tokens
        result = []
        curr_tks = baseline_tokens
        curr_batch: list[Element] = []
        for e in elements:
            document.elements = [e]
            etks = etk_prompt.render_document(document).token_count(self.tokenizer)
            if etks + curr_tks > limit:
                if etks + baseline_tokens > limit:
                    raise ValueError(
                        "An element was too big to fit within the specified max tokens. "
                        "Please run `docset.split_elements` to break it up or limit the"
                        f" properties used in the prompt.\n\nElement: {e}"
                    )
                # Close the current batch and start a new one with this element.
                result.append(curr_batch)
                curr_batch = [e]
                curr_tks = baseline_tokens + etks
            else:
                curr_batch.append(e)
                curr_tks += etks
        result.append(curr_batch)
        return result
def make_onestep_summarizer_prompt() -> JinjaPrompt:
    """Build the Jinja prompt used for single-call summarization.

    The template reads these variables:

    * ``question`` - the question the summary should answer.
    * ``doc`` - rendered as one "Entry" per item of ``doc.data["sub_docs"]``, or as a
      single entry for ``doc`` itself when that key is absent.
    * ``doc_fields_key``, ``elt_fields_key``, ``numel_key``, ``startel_key`` - names of
      the entries in ``doc.properties`` that hold, respectively, the document fields to
      print, the element fields to print, how many elements to print per entry, and the
      index of the first element to print.

    Document fields whose name starts with ``_`` are skipped. The element section is
    rendered only when the element count is not none and is greater than zero.
    """
    return JinjaPrompt(
        system="You are a helpful text summarizer",
        user=textwrap.dedent("""
You are given a series of database entries that answer the question "{{ question }}".
Generate a concise, conversational summary of the data to answer the question.
{%- for subdoc in doc.data.get("sub_docs", [doc]) %}
Entry {{ loop.index }}:
{% for f in doc.properties[doc_fields_key] %}{% if f.startswith("_") %}{% continue %}{% endif %}
{{ f }}: {{ subdoc.field_to_value(f) }}
{% endfor -%}
{%- if doc.properties[numel_key] is not none and doc.properties[numel_key] > 0 %}
Elements:
{%- set start = doc.properties[startel_key] -%}
{%- set end = doc.properties[startel_key] + doc.properties[numel_key] -%}
{%- for subel in subdoc.elements[start:end] -%}
{#- Removed {loop.index} from here because it blows up the token count. For an element token count, the index is 0 but when we count the tokens for all the elements included, it becomes like (0,1,2...) which results in a different tokenization from how we tokenize 1 element at a time. -#}
{%- for f in doc.properties[elt_fields_key] %}
{{ f }}: {{ subel.field_to_value(f) }}
{%- endfor %}
Text: {{ subel.text_representation }}
{% endfor %}
{% endif -%}
{% endfor %}
"""),
    )
[docs]
class OneStepDocumentSummarizer(Summarizer):
    """
    Summarizes a document in a single LLM call by taking as much data as possible
    from every element, spread across them evenly. Intended for use with summarize_data,
    where a summarizer is used to summarize an entire docset.

    Args:
        llm: LLM to use for summarization
        question: Question to use as context for the summary. The llm will attempt to
            use the data provided to answer the question.
        tokenizer: Tokenizer to use to count tokens (to not exceed the token limit).
            Default is CharacterTokenizer. Its ``max_tokens`` is the token limit;
            10,000 is used when that is unset.
        fields: List of fields to include from every element. To include any additional
            fields (after the ones specified), end the list with `EtCetera`. Default is
            empty list, which stands for 'no properties'
    """

    def __init__(
        self,
        llm: LLM,
        question: str,
        tokenizer: Tokenizer = CharacterTokenizer(),
        fields: list[Union[str, Type[EtCetera]]] = [],
    ):
        self.llm = llm
        self.question = question
        self.token_limit = tokenizer.max_tokens or 10_000
        self.tokenizer = tokenizer
        assert EtCetera not in fields[:-1], "EtCetera must be at the end of the list of fields if provided"
        # Copy so that instances never alias the shared default list or the caller's list.
        self.fields = list(fields)
        self.prompt = make_onestep_summarizer_prompt().fork(**self.get_const_vars())

    @staticmethod
    def get_const_vars() -> dict[str, str]:
        """Map each prompt variable name to the ``doc.properties`` key that carries its value."""
        return {
            "doc_fields_key": "_doc_fields",
            "elt_fields_key": "_elt_fields",
            "numel_key": "_num_elements",
            "startel_key": "_start_element",
        }

    def _maximize_fields(
        self,
        doc: Document,
        data_independent_ntk: int,
        curr_ntk: int,
        partitioned_fields: list[str],
        initial_fieldset: set[Union[str, Type[EtCetera]]],
        field_key: str,
        prompt: SycamorePrompt,
    ) -> tuple[bool, int, list[str]]:
        """Stuff as many fields into the plan as can fit in the token limit.

        Fields in ``initial_fieldset`` are always kept. Every other field is measured
        alone and added while the running total stays below the limit. This method
        overwrites ``doc.properties`` at the key selected by ``field_key``.

        Args:
            doc: The document to operate on
            data_independent_ntk: How many tokens are in the prompt regardless of the data
            curr_ntk: Current token count before adding stuff
            partitioned_fields: list of fields from _partition_fields - either the element or document fields
            initial_fieldset: the set of fields specified by the user
            field_key: either "doc_fields_key" or "elt_fields_key", depending on whether we're adding doc or elt fields
            prompt: the sycamore prompt to use to render and count tokens

        Returns:
            (bool, int, list[str]): Whether we filled up the token limit, the total tokens after adding fields,
                the finalized list of fields to add
        """
        property_key = self.get_const_vars()[field_key]
        final_fields = [f for f in partitioned_fields if f in initial_fieldset]
        for f in partitioned_fields:
            if f in initial_fieldset:
                continue
            # Render with only this field, to measure its token cost in isolation.
            doc.properties[property_key] = [f]
            ntk = prompt.render_document(doc).token_count(self.tokenizer) - data_independent_ntk
            if curr_ntk + ntk < self.token_limit:
                final_fields.append(f)
                curr_ntk += ntk
            else:
                doc.properties[property_key] = final_fields
                return True, curr_ntk, final_fields
        return False, curr_ntk, final_fields

    def maximize_elements(
        self,
        doc: Document,
        data_independent_ntk: int,
        curr_ntk: int,
        prompt: SycamorePrompt,
    ) -> tuple[bool, int, int]:
        """Stuff as many elements as possible into the prompt.

        This method overwrites the element-count and start-element entries of
        ``doc.properties``.

        Args:
            doc: The document to operate on
            data_independent_ntk: How many tokens are in the prompt regardless of data
            curr_ntk: Current token count before adding elements
            prompt: the sycamore prompt to use to render and count tokens

        Returns:
            (bool, int, int): Whether we filled up the token limit, the total tokens after adding
                elements, the number of elements to use
        """
        keys = self.get_const_vars()
        numel_key = keys["numel_key"]
        startel_key = keys["startel_key"]

        # This is complicated bc we might get a SummaryDocument or a Document
        # NOTE(review): a document without a non-empty "sub_docs" entry gets max_numel == 0,
        # so none of its own elements are ever included, even though the prompt template
        # falls back to [doc] when "sub_docs" is absent. Confirm this is intended.
        sub_docs = doc.data.get("sub_docs")
        max_numel = max(len(d.data.get("elements", [])) for d in sub_docs) if sub_docs else 0

        # If elements can fit there's a little additional fluff added, so recompute baseline tokens
        # with no elements (but the element introduction fluff)
        doc.properties[numel_key] = 1
        doc.properties[startel_key] = max_numel + 1
        data_independent_ntk_with_fluff = prompt.render_document(doc).token_count(self.tokenizer)
        curr_ntk += data_independent_ntk_with_fluff - data_independent_ntk

        final_numel = 0
        for i in range(max_numel):
            # Measure the cost of taking the i-th element from every entry.
            doc.properties[startel_key] = i
            ntk = prompt.render_document(doc).token_count(self.tokenizer) - data_independent_ntk_with_fluff
            if curr_ntk + ntk < self.token_limit:
                final_numel += 1
                curr_ntk += ntk
            else:
                return True, curr_ntk, final_numel
        return False, curr_ntk, final_numel

    def preprocess(self, doc: Document) -> Document:
        """Compute which fields and how many elements to include in the prompt.

        First: If specified fields has an EtCetera, add as many fields as possible.
        Second: Add as many elements as possible, taking evenly from each document.
        Third: If we can add all the elements and specified fields has an EtCetera,
        add as many element fields as possible

        The decisions are written into ``doc.properties`` under the keys from
        ``get_const_vars``. For a ``SummaryDocument``, ``doc.properties`` is first
        replaced by a dict that maps every sub-document property key to ``True``.
        """
        keys = self.get_const_vars()
        doc_fields_key = keys["doc_fields_key"]
        elt_fields_key = keys["elt_fields_key"]
        numel_key = keys["numel_key"]
        startel_key = keys["startel_key"]

        prompt = self.prompt.fork(ignore_none=True, question=self.question)
        fields = copy.deepcopy(self.fields)
        if isinstance(doc, SummaryDocument):
            doc.properties = {k: True for d in doc.sub_docs for k in d.properties.keys()}
        doc_fields, elt_fields = _partition_fields(doc, fields)
        fieldset = {f for f in fields if f is not EtCetera}
        etc = len(fields) > 0 and fields[-1] is EtCetera

        # Compute baseline 'fluff' tokens by setting fields and elements to 'no fields'
        # and 'no elements'. Use this later to figure out how many tokens adding a field adds
        doc.properties[doc_fields_key] = []
        doc.properties[elt_fields_key] = []
        doc.properties[numel_key] = 0
        doc.properties[startel_key] = 0
        data_independent_ntk = prompt.render_document(doc).token_count(self.tokenizer)

        # If fields is specified these are always included, so this will be our starting token total
        final_docfields = [f for f in doc_fields if f in fieldset]
        doc.properties[doc_fields_key] = final_docfields
        curr_ntks = prompt.render_document(doc).token_count(self.tokenizer)

        finished = False
        if etc:
            finished, curr_ntks, final_docfields = self._maximize_fields(
                doc,
                data_independent_ntk,
                curr_ntks,
                doc_fields,
                fieldset,
                "doc_fields_key",
                prompt,
            )

        # We added all the fields, now add as many elements as possible
        final_eltfields = [f for f in elt_fields if f in fieldset]
        if not finished:
            doc.properties[doc_fields_key] = []
            doc.properties[elt_fields_key] = final_eltfields
            finished, curr_ntks, final_numel = self.maximize_elements(doc, data_independent_ntk, curr_ntks, prompt)
            doc.properties[numel_key] = final_numel
            doc.properties[startel_key] = 0

        # If we're supposed to add as many fields as possible and we still have room,
        # try adding element fields until we run out of space. This feels computationally
        # expensive but I think it's just a 'for each element and for each field' which
        # seems like the optimum for the intended behavior.
        if etc and not finished:
            doc.properties[elt_fields_key] = []
            total_ntk_with_no_fields = prompt.render_document(doc).token_count(self.tokenizer)
            finished, curr_ntks, final_eltfields = self._maximize_fields(
                doc,
                total_ntk_with_no_fields,
                curr_ntks,
                elt_fields,
                fieldset,
                "elt_fields_key",
                prompt,
            )

        doc.properties[doc_fields_key] = final_docfields
        doc.properties[elt_fields_key] = final_eltfields
        return doc

    def cleanup(self, doc: Document) -> Document:
        """Remove the bookkeeping properties written by ``preprocess`` and return ``doc``."""
        for property_key in self.get_const_vars().values():
            doc.properties.pop(property_key, None)
        return doc

    def as_llm_map(self, child: Optional[Node], **kwargs):
        """Chain preprocess -> LLMMap -> cleanup and wrap the three nodes in a CompositeTransform.

        ``**kwargs`` is forwarded to ``LLMMap``; the LLM output goes to the ``summary`` field.
        """
        prompt = self.prompt
        if self.question is not None:
            prompt = prompt.fork(question=self.question)
        preprocess = Map(child, f=self.preprocess)
        llm_map = LLMMap(preprocess, prompt=prompt, output_field="summary", llm=self.llm, **kwargs)
        postprocess = Map(llm_map, f=self.cleanup)
        comptransform = CompositeTransform(child, nodes=[preprocess, llm_map, postprocess])  # type: ignore
        return comptransform
[docs]
class Summarize(NonCPUUser, NonGPUUser, Map):
    """
    Transform that produces summaries of documents or elements by mapping a
    summarizer's ``summarize`` method over its input.
    """

    def __init__(self, child: Node, summarizer: Summarizer, **kwargs):
        summarize_fn = summarizer.summarize
        super().__init__(child, f=summarize_fn, **kwargs)
class CollectToSummaryDoc(Aggregation):
    """Aggregation that gathers documents into one SummaryDocument."""

    def __init__(self):
        super().__init__(name="collect_to_summary_doc")

    def accumulate(self, docs: list[Document]) -> Document:
        """Wrap a list of documents as the sub-documents of a new SummaryDocument."""
        collected = SummaryDocument(sub_docs=docs)
        return collected

    def combine(self, doc1: Document, doc2: Document) -> Document:
        """Append the sub-documents of ``doc2`` to ``doc1`` in place and return ``doc1``."""
        assert isinstance(doc1, SummaryDocument)
        assert isinstance(doc2, SummaryDocument)
        target = doc1.sub_docs
        target.extend(doc2.sub_docs)
        return doc1

    def finalize(self, doc: Document) -> Document:
        """Sort the sub-documents in place by ``doc_id``, treating a falsy id as ""."""
        assert isinstance(doc, SummaryDocument)

        def sort_key(d):
            return d.doc_id or ""

        doc.sub_docs.sort(key=sort_key)
        return doc

    def zero_factory(self) -> Document:
        """Return an empty SummaryDocument to start the aggregation from."""
        return SummaryDocument()