Source code for sycamore.transforms.basics

from typing import Callable, TYPE_CHECKING

from sycamore.data import Document, MetadataDocument
from sycamore.plan_nodes import Node, NonGPUUser, NonCPUUser, Transform
from sycamore.transforms.map import MapBatch

if TYPE_CHECKING:
    from ray.data import Dataset


class Limit(NonCPUUser, NonGPUUser, Transform):
    """
    Limit is a transformation that restricts the size of a dataset to a specified number of records.

    Only entries that are not ``MetadataDocument`` instances count toward the limit. Metadata
    entries are passed through uncounted for as long as the limit has not been exceeded, which
    means metadata that follows the last counted record is kept until the next regular record
    is encountered.

    Args:
        child: The source node or component that provides the dataset to be limited.
        limit: The maximum number of records to include in the resulting dataset.

    Example:
        .. code-block:: python

            source_node = ...  # Define a source node or component that provides a dataset.
            limit_transform = Limit(child=source_node, limit=100)
            limited_dataset = limit_transform.execute()
    """

    def __init__(self, child: Node, limit: int):
        super().__init__(child)
        self._limit = limit

    def _take_within_limit(self, items, as_document: Callable[..., Document]):
        """
        Yield ``items`` in order until more than ``self._limit`` non-metadata documents are seen.

        Args:
            items: Iterable of entries to scan (dataset rows or Document objects).
            as_document: Callable that turns an entry into the Document used for the
                ``MetadataDocument`` check.

        Yields:
            The original entries (not the converted documents), stopping just before the
            entry that would push the count of non-metadata documents past the limit.
        """
        count = 0
        for item in items:
            if not isinstance(as_document(item), MetadataDocument):
                count += 1
            # The check runs for every entry, including metadata, so a negative limit
            # stops on the very first entry regardless of its type.
            if count > self._limit:
                return
            yield item

    def execute(self, **kwargs) -> "Dataset":
        """
        Execute the child node and return a new dataset holding at most ``limit`` records.

        The selected rows are collected by iterating the child's dataset and are then
        rebuilt into a dataset with ``ray.data.from_items``.

        Args:
            **kwargs: Execution options; forwarded unchanged to the child node's ``execute``.

        Returns:
            A dataset built from the rows that fall within the limit.
        """
        import ray

        # Forward kwargs so options supplied by the caller also reach upstream nodes.
        dataset = self.child().execute(**kwargs)
        ray_docs = list(
            self._take_within_limit(
                dataset.iter_rows(),
                # Rows carry the serialized document under the "doc" key.
                lambda row: Document.deserialize(row["doc"]),
            )
        )
        return ray.data.from_items(ray_docs)

    def local_execute(self, all_docs: list[Document]) -> list[Document]:
        """
        Apply the limit to an in-memory list of documents.

        Args:
            all_docs: Documents to limit, in order.

        Returns:
            A new list with the leading documents that fall within the limit.
        """
        return list(self._take_within_limit(all_docs, lambda doc: doc))
class Filter(MapBatch):
    """
    Filter is a transformation that keeps only the documents accepted by a user-defined predicate.

    Args:
        child: The source node or component that provides the dataset to be filtered.
        f: A callable that takes a Document object and returns a boolean; documents for which
            the result is truthy are retained, all others are dropped.
        resource_args: Additional resource-related arguments that are passed through to the
            underlying MapBatch operation.

    Example:
        .. code-block:: python

            source_node = ...  # Define a source node or component that provides a dataset.

            def custom_filter(doc: Document) -> bool:
                # Define your custom filtering logic here.
                return doc.some_property == some_value

            filter_transform = Filter(child=source_node, f=custom_filter)
            filtered_dataset = filter_transform.execute()
    """

    def __init__(self, child: Node, *, f: Callable[[Document], bool], **resource_args):
        # The callable given to MapBatch receives a collection of documents, so the
        # per-document predicate is lifted into a function that filters that collection.
        super().__init__(child, f=lambda docs: list(filter(f, docs)), **resource_args)