from typing import Optional
import logging
from sycamore.data import Document, Element, TableElement
from sycamore.functions.tokenizer import Tokenizer
from sycamore.plan_nodes import Node, SingleThreadUser, NonGPUUser
from sycamore.transforms.map import Map
from sycamore.utils.time_trace import timetrace
logger = logging.getLogger(__name__)
class SplitElements(SingleThreadUser, NonGPUUser, Map):
    """
    The SplitElements transform recursively divides elements such that no
    Element exceeds a maximum number of tokens.

    An element that still exceeds the limit after ``max_depth`` levels of
    splitting is kept unsplit, with its original text.

    Args:
        child: The source node or component that provides the elements to be split
        tokenizer: The tokenizer to use in counting tokens, should match embedder
        maximum: Maximum tokens allowed in any Element

    Example:
        .. code-block:: python

            node = ...  # Define a source node or component that provides hierarchical documents.
            xform = SplitElements(child=node, tokenizer=tokenizer, maximum=512)
            dataset = xform.execute()
    """

    def __init__(self, child: Node, tokenizer: Tokenizer, maximum: int, **kwargs):
        super().__init__(child, f=SplitElements.split_doc, args=[tokenizer, maximum], **kwargs)

    @staticmethod
    @timetrace("splitElem")
    def split_doc(
        parent: Document,
        tokenizer: Tokenizer,
        max: int,
        max_depth: int = 20,
        add_binary: bool = True,
    ) -> Document:
        """
        Split each element of ``parent`` into chunks of at most ``max`` tokens.

        Args:
            parent: the document that holds all the elements.
            tokenizer: tokenizer for computing the number of tokens in a chunk.
            max: maximum number of tokens allowed in a chunk as computed by the above tokenizer.
            max_depth: maximum depth of the binary tree that forms as we split each element into two recursively.
            add_binary: legacy feature to add text_representation as binary_representation as well.

        Returns: the same parent document with split elements.
        """
        result = []
        for elem in parent.elements:
            # Ensure the _header does not take up more than a third of the tokens
            # Also avoid max recursive depth error
            header = elem.get("_header")
            if header and len(tokenizer.tokenize(header)) / max > 0.33:
                logger.warning("Token limit exceeded, dropping _header: %s", header)
                del elem["_header"]
            logger.debug("Splitting element using max_depth of %s", max_depth)
            try:
                split_elements = SplitElements.split_one(
                    elem,
                    tokenizer,
                    max,
                    0,
                    max_depth=max_depth,
                    add_binary=add_binary,
                )
            except RecursionError:
                # split_one restores the element's text while unwinding, so the
                # element is kept whole (over the limit) rather than truncated.
                result.append(elem)
                continue
            if elem.type == "table" and isinstance(elem, TableElement) and elem.table is not None:
                # The first fragment is left as is; only the continuation
                # fragments get the table context prepended.
                SplitElements._prepend_table_context(elem, split_elements[1:], tokenizer, max, add_binary)
            result.extend(split_elements)
        parent.elements = result
        return parent

    @staticmethod
    def _prepend_table_context(
        table_elem: TableElement,
        fragments: list[Element],
        tokenizer: Tokenizer,
        max: int,
        add_binary: bool,
    ) -> None:
        """Prepend column headers, title and ``_header`` of ``table_elem`` to each fragment's text."""
        if not fragments:
            return

        def count(text: Optional[str]) -> int:
            return 0 if text is None else len(tokenizer.tokenize(text))

        # The context is the same for every fragment, so tokenize it only once.
        column_headers = table_elem.table.column_headers
        cheaders = "" if column_headers is None else ", ".join(column_headers)
        context = [
            cheaders,
            table_elem.data["properties"].get("title"),
            table_elem.get("_header"),
        ]
        context_counts = [count(piece) for piece in context]

        for ment in fragments:
            body = ment.text_representation
            two = ""
            tokens = 0
            for piece, num in zip([body, *context], [count(body), *context_counts]):
                if tokens == 0:
                    # Nothing with tokens accepted so far: start from this piece.
                    two = piece
                    tokens = num
                elif num == 0:
                    continue
                elif (tokens + num) < max:
                    # Each later piece goes in front of what has been built so far.
                    two = f"{piece}\n{two}"
                    tokens += num
            ment.text_representation = two
            if add_binary and two is not None:
                # split_one derived binary_representation from the old text;
                # keep it in sync with the augmented text.
                ment.binary_representation = bytes(two, "utf-8")

    @staticmethod
    def _newline_split_index(txt: str, half: int) -> Optional[int]:
        """Return the index just past the newline nearest the middle of ``txt``, or None if none is found."""
        left = half
        right = half + 1
        for _ in range(half // 2):  # stay near middle; avoid the ends
            if txt[left] == "\n":
                return left + 1
            if txt[right] == "\n":
                return right + 1
            left -= 1
            right += 1
        return None

    @staticmethod
    def _punctuation_split_index(txt: str, half: int) -> int:
        """Return the index just past the best punctuation/whitespace break near the middle of ``txt``."""
        # FIXME: make this work with asian languages
        predicates = [  # in precedence order
            lambda c: c in ".!?",
            lambda c: c == ";",
            lambda c: c in "()",
            lambda c: c == ":",
            lambda c: c == ",",
            str.isspace,
        ]
        # results[i] is the position closest to the middle that matched predicates[i].
        results: list[Optional[int]] = [None] * len(predicates)
        left = half
        right = half + 1
        for _ in range(half // 2):  # stay near middle; avoid the ends
            lchar = txt[left]
            rchar = txt[right]
            hit_top_rank = False
            for rank, predicate in enumerate(predicates):
                if predicate(lchar):
                    pos = left
                elif predicate(rchar):
                    pos = right
                else:
                    continue
                if results[rank] is None:
                    results[rank] = pos
                hit_top_rank = rank == 0
                break
            if hit_top_rank:
                # A sentence terminator cannot be beaten; stop searching.
                break
            left -= 1
            right += 1

        for res in results:
            if res is not None:
                return res + 1
        # No usable break character near the middle: cut right after the middle.
        return half + 1

    @staticmethod
    def split_one(
        elem: Element,
        tokenizer: Tokenizer,
        max: int,
        depth: int,
        max_depth: int,
        add_binary: bool = True,
    ) -> list[Element]:
        """
        Recursively split ``elem`` in two until every piece has at most ``max`` tokens.

        ``elem`` is modified in place to hold the first piece; the remaining
        pieces are copies of it holding the rest of the text.

        Args:
            elem: the element to split.
            tokenizer: tokenizer for computing the number of tokens in a chunk.
            max: maximum number of tokens allowed in a chunk.
            depth: current recursion depth; top-level callers pass 0.
            max_depth: maximum recursion depth allowed.
            add_binary: also store each piece's text, UTF-8 encoded, as binary_representation.

        Returns: the list of pieces, in text order, starting with ``elem``.

        Raises:
            RecursionError: if ``depth`` exceeds ``max_depth``. ``elem`` gets its
                text (and binary, if ``add_binary``) restored before the error
                reaches the caller.
        """
        if depth > max_depth:
            logger.warning("Max split depth exceeded, truncating the splitting")
            raise RecursionError()

        txt = elem.text_representation
        if not txt:
            return [elem]
        if len(tokenizer.tokenize(txt)) <= max:
            return [elem]

        half = len(txt) // 2

        # FIXME: The table object in the split elements would have the whole table structure rather than split
        idx = None
        if elem.type == "table":
            # Tables are preferably split at a row boundary.
            idx = SplitElements._newline_split_index(txt, half)
        if idx is None:
            idx = SplitElements._punctuation_split_index(txt, half)

        one = txt[:idx]
        two = txt[idx:]

        # Remember the binary payload so it can be put back if the split is abandoned.
        saved_binary = getattr(elem, "binary_representation", None) if add_binary else None

        ment = elem.copy()
        elem.text_representation = one
        ment.text_representation = two
        if add_binary:
            elem.binary_representation = bytes(one, "utf-8")
            ment.binary_representation = bytes(two, "utf-8")

        try:
            aa = SplitElements.split_one(
                elem,
                tokenizer,
                max,
                depth + 1,
                max_depth=max_depth,
                add_binary=add_binary,
            )
            bb = SplitElements.split_one(
                ment,
                tokenizer,
                max,
                depth + 1,
                max_depth=max_depth,
                add_binary=add_binary,
            )
        except RecursionError:
            # Undo the in-place truncation so the caller gets the element back
            # with the text it had when this call started.
            elem.text_representation = txt
            if add_binary:
                elem.binary_representation = saved_binary
            raise
        aa.extend(bb)
        return aa