from dataclasses import dataclass
from enum import Enum
from functools import wraps
import json
from typing import Any, Dict, List, MutableMapping, Optional
from hashlib import sha256
from pydantic import (
BaseModel,
ConfigDict,
Field,
SerializeAsAny,
model_validator,
field_serializer,
)
from sycamore import DocSet
def exclude_from_comparison(func):
@wraps(func)
def wrapper(*args, **kwargs):
return func(*args, **kwargs)
wrapper._exclude_from_comparison = True
return wrapper
# This is a mapping from class name to subclasses of Node, which is used for deserialization.
_NODE_SUBCLASSES: Dict[str, Any] = {}
[docs]
@dataclass
class NodeSchemaField:
field_name: str
description: Optional[str]
type_hint: str
[docs]
class Node(BaseModel):
"""Represents a node in a logical query plan.
Args:
node_id: The ID of the node.
_inputs: The nodes that this node depends on.
"""
# This allows pydantic to pick up field descriptions from
# docstrings.
model_config = ConfigDict(use_attribute_docstrings=True)
def __init_subclass__(cls, **kwargs: Any):
"""Called when subclasses of Node are initialized. Used to register all subclasses."""
super().__init_subclass__(**kwargs)
if cls.__name__ in _NODE_SUBCLASSES:
raise ValueError(f"Duplicate node type: {cls.__name__}")
_NODE_SUBCLASSES[cls.__name__] = cls
node_type: str = Field(default=None)
"""The type of this node."""
[docs]
@field_serializer("node_type")
def serialize_node_type(self, value: str) -> str:
"""Field serializer for node_type that returns the class name as a default."""
# We can't do this using the "default" argument to Field, because we don't have
# the class instance yet at the time the field is created.
return value or type(self).__name__
node_id: int
"""A unique integer ID representing this node."""
description: Optional[str] = Field(None, json_schema_extra={"exclude_from_comparison": True})
"""A detailed description of why this operator was chosen for this query plan."""
inputs: List[int] = []
"""A list of node IDs that this operation depends on."""
# The nodes that this node depends on. This should be populated externally
# when a LogicalPlan is created.
_input_nodes: Optional[List["Node"]] = None
@property
def input_types(self) -> set[type]:
"""The types of the input to this operator. Default operations accept DocSets"""
return {DocSet}
@property
def output_type(self) -> type:
"""The type of the output to this operator. Default operations return a DocSet"""
return DocSet
# The cache key for this node. Hidden so it is not included in serialization.
_cache_key: Optional[str] = None
def __str__(self) -> str:
return f"Id: {self.node_id} Op: {type(self).__name__}"
[docs]
def logical_compare(self, other: "Node") -> bool:
"""Logically compare two instances of a Node."""
if not isinstance(other, Node):
return False
def exclude_field(field: str):
"""Determine whether the given field should be excluded from comparison."""
if field not in self.model_fields:
return False
json_schema_extra = self.model_fields[field].json_schema_extra
if not json_schema_extra or not hasattr(json_schema_extra, "get"):
return False
return json_schema_extra.get("exclude_from_comparison", False)
# explicitly use dict to compare and exclude keys if needed
self_dict = {k: v for k, v in self.model_dump().items() if not exclude_field(k)}
other_dict = {k: v for k, v in other.model_dump().items() if not exclude_field(k)}
return self_dict == other_dict
def __hash__(self) -> int:
# Note that this hash value will change from run to run as Python's built-in hash()
# is not deterministic.
return hash(self.model_dump_json())
[docs]
def cache_dict(self) -> dict:
"""Returns a dict representation of this node that can be used for comparison."""
# We want to exclude fields that may change from plan to plan, but which do not
# affect the semantic equivalence of the plan.
retval = self.model_dump(exclude={"node_id", "inputs", "description"})
# Recursively include inputs.
retval["inputs"] = [dep.cache_dict() for dep in self.input_nodes()]
return retval
[docs]
def cache_key(self) -> str:
"""Returns the cache key of this node, used for caching intermediate query results during
execution."""
if self._cache_key:
return self._cache_key
cache_key = self.cache_dict()
self._cache_key = sha256(json.dumps(cache_key).encode()).hexdigest()
return self._cache_key
[docs]
@classmethod
def deserialize(cls, data: Dict[str, Any]) -> "Node":
"""Used to deserialize a Node from a dictionary, by returning the appropriate Node subclass."""
if "node_type" not in data:
raise ValueError("No node_type field found in serialized Node")
if data["node_type"] in _NODE_SUBCLASSES:
return _NODE_SUBCLASSES[data["node_type"]].model_validate(data)
else:
raise ValueError(f"Unknown node type: {data['node_type']}")
[docs]
@classmethod
def usage(cls) -> str:
"""Return a detailed description of the this query operator. Used by the planner."""
return f"""**{cls.__name__}**: {cls.__doc__}"""
[docs]
class LogicalNodeDiffType(Enum):
OPERATOR_TYPE = "operator_type"
OPERATOR_DATA = "operator_data"
PLAN_STRUCTURE = "plan_structure"
[docs]
class LogicalPlanDiffEntry(BaseModel):
node_a: SerializeAsAny[Node]
node_b: SerializeAsAny[Node]
diff_type: LogicalNodeDiffType
message: Optional[str] = None
[docs]
class LogicalPlan(BaseModel):
"""Represents a logical query plan.
Args:
query: The query that the plan is for.
nodes: A mapping of node IDs to nodes.
result_node: The node that is the result of the query.
llm_prompt: The LLM prompt that was used to generate this query plan.
llm_plan: The LLM plan that was used to generate this query plan.
"""
query: str
"""The query that the plan is for."""
nodes: MutableMapping[int, SerializeAsAny[Node]]
"""A mapping of node IDs to nodes in the query plan."""
result_node: int
"""The ID of the node that is the result of the query."""
llm_prompt: Optional[Any] = None
"""The LLM prompt that was used to generate this query plan."""
llm_plan: Optional[str] = None
"""The result generated by the LLM."""
def __init__(self, **kwargs):
# Ensure that the correct subclass of Node is used.
if "nodes" not in kwargs:
raise ValueError("No 'nodes' field provided for LogicalPlan")
if isinstance(kwargs["nodes"], dict):
for node_id, node in kwargs["nodes"].items():
if isinstance(node, dict):
kwargs["nodes"][node_id] = Node.deserialize(node)
super().__init__(**kwargs)
[docs]
def downstream_nodes(self, node_id: int) -> List[int]:
"""Return the IDs of all nodes that are downstream of the given node."""
return [n for n in self.nodes.keys() if node_id in self.nodes[n].inputs]
[docs]
def compare(self, other: "LogicalPlan") -> list[LogicalPlanDiffEntry]:
"""
A simple method to compare 2 logical plans. This comparator traverses a plan 'forward', i.e. it attempts to
start from node_id == 0 which is typically a data source query. This helps us detect differences in the plan
in the natural flow of data. If the plans diverge structurally, i.e. 2 nodes have different number of downstream
nodes we stop traversing.
@param other: plan to compare against
@return: List of comparison metrics.
"""
assert 0 in self.nodes, "Plan a requires at least 1 node with ID [0]"
assert 0 in other.nodes, "Plan b requires at least 1 node with ID [0]"
return compare_graphs(self, other, self.nodes[0].node_id, other.nodes[0].node_id, set(), set())
[docs]
def replace_node(self, node_id: int, new_node: Node) -> None:
"""
Replace the existing node at node_id with "new_node".
"""
# sanity check -- we can't create "new_node" without node_id, so we should have already set that properly
old_node = self.nodes[node_id]
assert new_node.node_id == old_node.node_id
# borrow _input_nodes and inputs from the old node
new_node.inputs = old_node.inputs
new_node._input_nodes = old_node._input_nodes
# for any other node that has old_node in its _input_nodes, replace with node
for node in self.nodes.values():
if node._input_nodes and old_node in node._input_nodes:
node._input_nodes = [new_node if x == old_node else x for x in node._input_nodes]
# update the nodes array
self.nodes[node_id] = new_node
[docs]
def insert_node(self, node_id: int, new_node: Node) -> None:
"""
Insert a node into the plan at the specified node_id.
Any nodes that depend on the current node_id are shifted to the right, and their node_ids are incremented.
Also, the input arrays of the affected nodes are updated.
Precondition: node_id must be greater than 0, and the current node at node_id must have exactly one input.
If there is no current node at node_id (i.e., the new node is being "appended"), we use the current result node
as the input to it.
"""
assert node_id > 0, f"Node ID must be greater than 0, got {node_id}"
assert (
len(self.nodes) == node_id or len(self.nodes[node_id].inputs) == 1
), f"""Current node at {node_id}
must have exactly one input, or there should be only one operator"""
if len(self.nodes) == node_id:
new_node._input_nodes = [self.nodes[self.result_node]]
else:
# Add the input node for the new node
new_node._input_nodes = self.nodes[node_id]._input_nodes
# Shift all nodes that are after the current node_id to the right
for nid in sorted(self.nodes.keys(), reverse=True):
if nid >= node_id:
self.nodes[nid].node_id += 1
self.nodes[nid].inputs = [x + 1 for x in self.nodes[nid].inputs]
self.nodes[nid + 1] = self.nodes[nid]
# Modify the _input_nodes for self.nodes[node_id+1]
self.nodes[node_id + 1]._input_nodes = [new_node]
# Insert the new node at the specified node_id
self.nodes[node_id] = new_node
# Modify the terminal node in the plan
self.result_node += 1
return
[docs]
def compare_graphs(
plan_a: LogicalPlan, plan_b: LogicalPlan, node_id_a: int, node_id_b: int, visited_a: set[int], visited_b: set[int]
) -> list[LogicalPlanDiffEntry]:
"""
Traverse and compare 2 graphs given a node pointer in each. Computes different comparison metrics per node.
The function will continue to traverse as long as the graph structure is identical, i.e. same number of outgoing
nodes per node. It also assumes that the "downstream nodes"/edges are ordered - this is the current logical
plan implementation to support operations like math.
@param plan_a: LogicalPlan a
@param plan_b: LogicalPlan b
@param node_id_a: graph node a
@param node_id_b: graph node b
@param visited_a: helper to track traversal in graph a
@param visited_b: helper to track traversal in graph b
@return: list of LogicalPlanDiffEntry
"""
diff_results: list[LogicalPlanDiffEntry] = []
if node_id_a in visited_a and node_id_b in visited_b:
return diff_results
visited_a.add(node_id_a)
visited_b.add(node_id_b)
node_a = plan_a.nodes[node_id_a]
node_b = plan_b.nodes[node_id_b]
# Compare node types
# pylint: disable=unidiomatic-typecheck
if type(node_a) != type(node_b):
diff_results.append(
LogicalPlanDiffEntry(node_a=node_a, node_b=node_b, diff_type=LogicalNodeDiffType.OPERATOR_TYPE)
)
# Compare node data
if not node_a.logical_compare(node_b):
diff_results.append(
LogicalPlanDiffEntry(node_a=node_a, node_b=node_b, diff_type=LogicalNodeDiffType.OPERATOR_DATA)
)
# Compare the structure (inputs)
a_downstream = plan_a.downstream_nodes(node_id_a)
b_downstream = plan_b.downstream_nodes(node_id_b)
if len(a_downstream) != len(b_downstream):
diff_results.append(
LogicalPlanDiffEntry(node_a=node_a, node_b=node_b, diff_type=LogicalNodeDiffType.PLAN_STRUCTURE)
)
else:
for ds1, ds2 in zip(a_downstream, b_downstream):
diff_results.extend(compare_graphs(plan_a, plan_b, ds1, ds2, visited_a, visited_b))
return diff_results