# Source code for sycamore.context

import functools
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Optional, Union, List
import inspect

from sycamore.plan_nodes import Node, NodeTraverse


class ExecMode(Enum):
    """Execution mode stored on a :class:`Context` (``Context.exec_mode``)."""

    UNKNOWN = 0
    # Default mode for both ``Context`` and ``init()``.
    RAY = 1
    LOCAL = 2
class OperationTypes(Enum):
    """
    Names of parameter namespaces inside ``Context.params``.

    ``DEFAULT.value`` is the fallback namespace that ``get_val_from_context``
    consults when no more specific namespace supplies a value.
    """

    DEFAULT = "default"
    BINARY_CLASSIFIER = "binary_classifier"
    INFORMATION_EXTRACTOR = "information_extractor"
    TEXT_SIMILARITY = "text_similarity"
def _default_rewrite_rules():
    """Build a fresh list of the rewrite rules every new Context starts with."""
    # Imported lazily so this module does not import the rules package at load time.
    from sycamore.rules.optimize_resource_args import EnforceResourceUsage, OptimizeResourceArgs

    return [EnforceResourceUsage(), OptimizeResourceArgs()]
@dataclass
class Context:
    """
    A Sycamore Context.

    Holds the execution mode, optional Ray arguments, plan rewrite rules and
    global parameters for a pipeline. Its ``read`` property is the entry point
    for reading data into a DocSet.
    """

    # How plans are executed; Ray unless stated otherwise.
    exec_mode: ExecMode = ExecMode.RAY

    # Arguments for Ray. NOTE(review): presumably forwarded to Ray initialization
    # elsewhere; not visible from this module.
    ray_args: Optional[dict[str, Any]] = None

    # Rules registered in the Context that transform the nodes before execution.
    # They can optimize Ray execution or perform other manipulations. Each rule is
    # either a Node -> Node callable or a NodeTraverse.
    rewrite_rules: list[Union[Callable[[Node], Node], NodeTraverse]] = field(default_factory=_default_rewrite_rules)

    # Parameters for global usage, grouped by namespace (see get_val_from_context
    # and context_params for how they are looked up).
    params: dict[str, Any] = field(default_factory=dict)

    @property
    def read(self):
        """Return a DocSetReader bound to this Context."""
        from sycamore.reader import DocSetReader

        reader = DocSetReader(self)
        return reader
def get_val_from_context(
    context: "Context", val_key: str, param_names: Optional[List[str]] = None, ignore_default: bool = False
) -> Optional[Any]:
    """
    Look up a single value in ``context.params``.

    ``context.params`` is treated as a flat mapping of namespace -> dict of values
    (one level of nesting, no deeper). Namespaces in ``param_names`` are searched in
    order and the first non-None value wins. If none matches, the
    ``OperationTypes.DEFAULT`` namespace is used, unless ``ignore_default`` is set.

    @param context: Context to use
    @param val_key: Key for the value to be returned
    @param param_names: Parameter namespaces to search, in priority order.
    @param ignore_default: If True, never fall back to the OperationTypes.DEFAULT namespace.
    @return: The value found, or None.
    """
    params = context.params
    if not params:
        return None

    for namespace in param_names or []:
        found = params.get(namespace, {}).get(val_key)
        if found is not None:
            return found

    if ignore_default:
        return None
    return params.get(OperationTypes.DEFAULT.value, {}).get(val_key)
def context_params(*names):
    """
    Apply kwargs taken from a Context's ``params`` to a function call.

    The Context is taken from the ``context`` keyword argument. If that is absent,
    it comes from ``args[0].context`` or ``args[0]._context``, so decorated methods
    on objects holding a Context work. If no Context with params is found, the
    function is called unchanged.

    Candidate kwargs are merged from these ``ctx.params`` namespaces, later ones
    overriding earlier ones:

    1. ``"default"``
    2. the function's bare name (e.g. ``'opensearch'``)
    3. its qualified name (e.g. ``'DocSetWriter.opensearch'``)
    4. each name given to the decorator, in order

    Parameters already bound positionally are never overridden. If the function
    does not accept ``**kwargs``, only parameters that can be passed by keyword are
    filled, and only with non-None context values. Explicit call-site kwargs always
    win.

    Supported usages::

        @context_params
        @context_params()
        @context_params("template")
        @context_params("template1", "template2")
    """
    # Bare ``@context_params`` passes the decorated function as the sole argument.
    # It must not be treated as a params namespace.
    bare_usage = len(names) == 1 and callable(names[0])
    template_names = () if bare_usage else names

    def decorator(func: Callable) -> Callable:
        # Everything below depends only on ``func``, so compute it once at decoration time.
        sig = inspect.signature(func)
        params = list(sig.parameters.values())
        accepts_kwargs = any(p.kind == p.VAR_KEYWORD for p in params)
        # Params that a positional argument can bind to, in binding order.
        positional_names = [p.name for p in params if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)]
        # Params that may legally be supplied as keyword arguments.
        keyword_names = [p.name for p in params if p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)]
        qual_name = func.__qualname__  # e.g. 'DocSetWriter.opensearch'
        function_name_wo_class = qual_name.split(".")[-1]  # e.g. 'opensearch'

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            self = args[0] if len(args) > 0 else {}
            ctx = kwargs.get("context", getattr(self, "context", getattr(self, "_context", None)))
            if not (ctx and ctx.params):
                return func(*args, **kwargs)

            # Build candidate kwargs from the Context; later updates take priority.
            # "default" is the same string as OperationTypes.DEFAULT.value.
            candidate_kwargs = {}
            candidate_kwargs.update(ctx.params.get("default", {}))
            candidate_kwargs.update(ctx.params.get(function_name_wo_class, {}))
            candidate_kwargs.update(ctx.params.get(qual_name, {}))
            for template_name in template_names:
                candidate_kwargs.update(ctx.params.get(template_name, {}))

            # Parameters already filled positionally must not be passed again by keyword.
            bound_positionally = set(positional_names[: len(args)])
            for param in bound_positionally:
                candidate_kwargs.pop(param, None)

            if accepts_kwargs:
                # The function swallows arbitrary kwargs, so forward every candidate.
                new_kwargs = candidate_kwargs
            else:
                # Only forward candidates the signature can accept by keyword. Use an
                # explicit None check so falsy settings (0, False, "") still apply.
                new_kwargs = {}
                for param in keyword_names:
                    if param in bound_positionally:
                        continue
                    candidate_val = candidate_kwargs.get(param)
                    if candidate_val is not None:
                        new_kwargs[param] = candidate_val

            # User-provided kwargs from the call site always win.
            new_kwargs.update(kwargs)
            return func(*args, **new_kwargs)

        return wrapper

    if bare_usage:
        return decorator(names[0])
    return decorator
def init(exec_mode=ExecMode.RAY, ray_args: Optional[dict[str, Any]] = None, **kwargs) -> Context:
    """
    Create and return a new Context.

    Args:
        exec_mode: Execution mode for the new Context.
        ray_args: Arguments for Ray. None is replaced by an empty dict.
        **kwargs: Any other Context fields, forwarded to the constructor.
    """
    # Logging is set up here for the driver only. Worker application logging is
    # expected to come from worker_process_setup_hook or a runtime_env/config file.
    from sycamore.utils import sycamore_logger

    sycamore_logger.setup_logger()

    effective_ray_args = {} if ray_args is None else ray_args
    return Context(exec_mode=exec_mode, ray_args=effective_ray_args, **kwargs)
def shutdown() -> None:
    """Shut down Ray by delegating to ``ray.shutdown()``."""
    from ray import shutdown as ray_shutdown

    ray_shutdown()