from abc import abstractmethod, ABC
from typing import Any
from sycamore.utils.import_utils import requires_modules
from sycamore.data import OpenSearchQueryResult, Element, OpenSearchQuery
from sycamore.plan_nodes import Node, NonCPUUser, NonGPUUser
from sycamore.transforms.map import Map
import logging
logger = logging.getLogger("ray")
class QueryExecutor(ABC):
@abstractmethod
def query(self, query: Any) -> Any:
pass
def __call__(self, query: Any) -> Any:
return self.query(query)
[docs]
class OpenSearchQueryExecutor(QueryExecutor):
def __init__(self, os_client_args: dict) -> None:
super().__init__()
self._os_client_args = os_client_args
@requires_modules("opensearchpy", extra="opensearch")
def query(self, query: OpenSearchQuery) -> OpenSearchQueryResult:
from sycamore.connectors.opensearch.utils import OpenSearchClientWithLogging
logger.debug("Executing OS query: " + str(query))
client = OpenSearchClientWithLogging(**self._os_client_args)
os_result = client.transport.perform_request(
"POST",
url=f"/{query['index']}/_search",
params=query.get("params", None),
headers=query.get("headers", None),
body=query["query"],
)
result = OpenSearchQueryResult(query)
result.result = os_result
result.hits = [Element(hit["_source"]) for hit in os_result["hits"]["hits"]]
if "ext" in os_result and "retrieval_augmented_generation" in os_result["ext"]:
result.generated_answer = os_result["ext"]["retrieval_augmented_generation"]["answer"]
return result
[docs]
class Query(NonCPUUser, NonGPUUser, Map):
"""
Given a DocSet of queries, executes them and generates a DocSet of query results.
"""
def __init__(self, child: Node, query_executor: QueryExecutor, **kwargs):
super().__init__(child, f=query_executor.query, **kwargs)