# Source code for sycamore.plan_nodes

import sys
from abc import ABC, abstractmethod
from typing import Callable, Optional, TYPE_CHECKING

if TYPE_CHECKING:
    from ray.data import Dataset
    from sycamore.context import Context


class NodeTraverse:
    """Hooks used by ``Node.traverse`` to walk, inspect, or rewrite a node tree.

    For simple use cases, call ``node.traverse(before=fn)``,
    ``node.traverse(visit=fn)`` or ``node.traverse(after=fn)`` rather than
    subclassing.

    - ``before`` is called before traversing a node's children.
    - ``after`` is called after traversing a node's children.
    - ``visit`` is called over each node in an unspecified order; it is easier
      to use because the function returns nothing.
    - ``once`` is called one time at the very start, and enables multi-pass
      transforms.

    Subclasses may override any of these methods instead of passing callables.
    """

    def __init__(
        self,
        before: Optional[Callable[["Node"], "Node"]] = None,
        visit: Optional[Callable[["Node"], None]] = None,
        after: Optional[Callable[["Node"], "Node"]] = None,
    ):
        # Optional user callbacks; a missing callback means "leave the node as is".
        self.before_fn = before
        self.visit_fn = visit
        self.after_fn = after

    def once(self, context: "Context", node: "Node") -> "Node":
        # Hook for the root of the tree at the start of rewriting; enables
        # multi-pass traversals. The default implementation is the identity.
        return node

    def before(self, node: "Node") -> "Node":
        # Pre-order hook: return the node whose children should be traversed.
        hook = self.before_fn
        return node if hook is None else hook(node)

    def visit(self, node: "Node") -> None:
        # Convenience hook for single-node mutating operations; the callback's
        # return value is discarded.
        hook = self.visit_fn
        if hook is None:
            return
        hook(node)

    def after(self, node: "Node") -> "Node":
        # Post-order hook: return the node that replaces this one in the tree.
        hook = self.after_fn
        return node if hook is None else hook(node)
class Node(ABC):
    """Abstract base unit of a Sycamore plan.

    A DocSet is represented as a tree of Nodes. Each node keeps its inputs in
    ``children``, which links transforms to one another; the resulting tree
    is what gets executed to produce end results.
    """

    def __init__(
        self,
        children: list[Optional["Node"]],
        materialize: Optional[dict] = None,
        parallelism: Optional[int] = None,
        **resource_args,
    ):
        """
        Args:
            children: Input nodes of this node. ``None`` entries are accepted
                here but are dropped by the traversal helpers.
            materialize: Materialization settings, stored (as a copy) under
                ``self.properties["materialize"]``. ``None`` means empty.
            parallelism: Optional degree of parallelism; must be > 0 if given.
            **resource_args: Extra resource options, stored verbatim on
                ``self.resource_args``.
        """
        self.children = children
        assert parallelism is None or parallelism > 0
        self.parallelism = parallelism
        self.resource_args = resource_args
        self.properties: dict = {}
        # Copy so the caller's dict is never shared with (or mutated through)
        # this node's properties.
        self.properties["materialize"] = {} if materialize is None else materialize.copy()

    def __str__(self):
        return "node"

    @abstractmethod
    def execute(self, **kwargs) -> "Dataset":
        """Execute this node and return the resulting dataset."""
        pass

    def prepare(self) -> Optional[Callable]:
        """Override this method to run something at the beginning of execution
        after rules have been applied. The entire tree will be traversed in
        before mode and then any returned callables will be called in the order
        they were returned. Each callable can return another callable."""
        pass

    def finalize(self) -> None:
        """Override this method to run something at the end of execution after
        all documents have been returned."""
        pass

    def traverse_down(self, f: Callable[["Node"], "Node"]) -> "Node":
        """Apply ``f`` to this node first, then recursively to its children (pre-order).

        ``f`` is called for its side effects only; its return value is ignored.
        ``None`` entries are removed from ``children`` along the way.

        Returns:
            This node.
        """
        f(self)
        self.children = [c.traverse_down(f) for c in self.children if c is not None]
        return self

    def traverse_up(self, f: Callable[["Node"], "Node"]) -> "Node":
        """Apply ``f`` recursively to this node's children first, then to itself (post-order).

        ``f`` is called for its side effects only; its return value is ignored.
        ``None`` entries are removed from ``children`` along the way.

        Returns:
            This node.
        """
        self.children = [c.traverse_up(f) for c in self.children if c is not None]
        f(self)
        return self

    def traverse(
        self,
        obj: Optional["NodeTraverse"] = None,
        before: Optional[Callable[["Node"], "Node"]] = None,
        visit: Optional[Callable[["Node"], None]] = None,
        after: Optional[Callable[["Node"], "Node"]] = None,
    ) -> "Node":
        """Traverse (and possibly rewrite) the node tree.

        Either pass a ``NodeTraverse`` instance as ``obj``, or pass at least one
        of ``before``/``visit``/``after``, which are wrapped in a
        ``NodeTraverse``. Passing both forms is an error.

        Returns:
            The (possibly replaced) root node.
        """
        if obj is None:
            assert before is not None or visit is not None or after is not None
            obj = NodeTraverse(before=before, visit=visit, after=after)
        else:
            assert before is None and visit is None and after is None
        return self._traverse(obj)

    def _traverse(self, obj: "NodeTraverse") -> "Node":
        n = obj.before(self)
        # ``before`` may replace this node; visit the node that actually stays
        # in the tree (the one whose children are traversed below).
        obj.visit(n)
        n.children = [c._traverse(obj) for c in n.children if c is not None]
        return obj.after(n)

    def clone(self) -> "Node":
        """Return a copy of this node. Not implemented in the base class."""
        raise NotImplementedError("Unimplemented")
class LeafNode(Node):
    """A node with no children."""

    def __init__(self, **resource_args):
        super().__init__([], **resource_args)

    # __str__ previously accepted a stray ``**resource_args`` copied from
    # __init__; str() never passes arguments, so it is dropped.
    def __str__(self):
        return "leaf"
class UnaryNode(Node):
    """A node built from exactly one input slot (the child may be ``None``)."""

    def __init__(self, child: Optional[Node], **resource_args):
        super().__init__([child], **resource_args)

    def __str__(self):
        return "unary"

    def child(self) -> Node:
        """Return the single child; asserts that it is present."""
        only = self.children[0]
        assert only is not None
        return only
class NonCPUUser:
    """Mixin that only forwards constructor arguments along the MRO.

    NOTE(review): presumably inspected elsewhere to tag a node's resource use.
    """

    def __init__(self, *positional, **keyword):
        super().__init__(*positional, **keyword)


class SingleThreadUser:
    """Mixin that only forwards constructor arguments along the MRO.

    NOTE(review): presumably inspected elsewhere to tag a node's resource use.
    """

    def __init__(self, *positional, **keyword):
        super().__init__(*positional, **keyword)


class NonGPUUser:
    """Mixin that only forwards constructor arguments along the MRO.

    NOTE(review): presumably inspected elsewhere to tag a node's resource use.
    """

    def __init__(self, *positional, **keyword):
        super().__init__(*positional, **keyword)
class Scan(SingleThreadUser, NonGPUUser, LeafNode):
    """Abstract leaf node tagged with the SingleThreadUser and NonGPUUser mixins.

    NOTE(review): presumably the input/source node of a plan; concrete
    subclasses define what is read.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __str__(self):
        return "scan"

    @abstractmethod
    def format(self):
        """Must be implemented by subclasses; meaning is subclass-defined."""
class Transform(UnaryNode):
    """A single-input node; concrete subclasses define the operation."""

    def __init__(self, child: Node, **kwargs):
        super().__init__(child, **kwargs)

    def __str__(self):
        return "transform"
class Write(SingleThreadUser, NonGPUUser, UnaryNode):
    """Single-input node tagged with the SingleThreadUser and NonGPUUser mixins.

    NOTE(review): presumably the sink node that writes results out; concrete
    subclasses define the destination.
    """

    def __init__(self, child: Node, **kwargs):
        super().__init__(child, **kwargs)

    def __str__(self):
        return "write"