Source code for sycamore.llms.llms

import inspect
from abc import ABC, abstractmethod
import copy
import datetime
import io
import json
import logging
import base64
from PIL import Image
from typing import Any, Optional

from sycamore.llms.config import LLMMode, LLMModel
from sycamore.utils.cache import Cache
from sycamore.utils.thread_local import ThreadLocalAccess, ADD_METADATA_TO_OUTPUT
from sycamore.data.metadata import add_metadata
from sycamore.llms.prompts import RenderedPrompt, RenderedMessage

from sycamore.utils.deprecate import deprecated

logger = logging.getLogger(__name__)


class LLM(ABC):
    """Abstract representation of an LLM instance.

    This class should be subclassed to implement specific LLM providers.
    Subclasses must implement :meth:`generate` and :meth:`is_chat_mode`; the
    remaining generation entry points raise ``NotImplementedError`` unless
    overridden. The class also provides helpers for caching responses in an
    optional :class:`Cache` and for recording per-call metadata.
    """

    model: LLMModel

    def __init__(
        self,
        model_name: str,
        default_mode: LLMMode,
        cache: Optional[Cache] = None,
        default_llm_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Store the model name, execution mode, optional cache and default kwargs.

        Args:
            model_name: Name used in cache keys, metadata and ``__str__``.
            default_mode: Mode returned by :meth:`default_mode`.
            cache: Optional cache for responses; caching is disabled when falsy.
            default_llm_kwargs: Defaults merged under per-call ``llm_kwargs``.
        """
        self._model_name: str = model_name
        self._cache = cache
        self._default_mode = default_mode
        self._default_llm_kwargs = default_llm_kwargs or {}

    def default_mode(self) -> LLMMode:
        """Returns the default execution mode for the llm"""
        return self._default_mode

    def _merge_llm_kwargs(self, llm_kwargs: Optional[dict[str, Any]] = None) -> dict[str, Any]:
        """Merges the default LLM kwargs with any provided LLM kwargs.

        Prefers the passed in values if there is a conflict. The defaults are
        shallow-copied, so the stored default dict itself is never modified.
        """
        new_kwargs = copy.copy(self._default_llm_kwargs)
        new_kwargs.update(llm_kwargs or {})
        # Use the module logger (not the root logger) with lazy formatting.
        logger.debug("Merging LLM kwargs: %s", new_kwargs)
        return new_kwargs

    @abstractmethod
    def generate(
        self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> str:
        """Generates a response from the LLM for the given prompt and LLM parameters."""
        pass

    def generate_metadata(
        self, *, prompt: RenderedPrompt, model: Optional[LLMModel] = None, llm_kwargs: Optional[dict] = None
    ) -> dict:
        """Generates a response from the LLM for the given prompt and LLM parameters and returns metadata.

        TODO: Implement generic_generate(model: LLMModel, ...) and
        generic_generate_args(model_class, kwargs) to specify default arguments
        during model construction. The former should cache the client if
        possible. Then we can call generate on any model rather than only ones
        in the same family. Potentially get rid of the model argument to
        generate* at the same time to simplify the implementations.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("This LLM does not support metadata generation")

    @staticmethod
    def _render_legacy_prompt_kwargs(prompt_kwargs: dict[str, Any]) -> RenderedPrompt:
        """Convert legacy ``prompt_kwargs`` into a :class:`RenderedPrompt`.

        Shared by :meth:`generate_old` and :meth:`generate_async_old`.

        * If ``"prompt"`` is present and is a ``SimplePrompt``, its messages are
          taken from ``as_messages()`` and each message's content is formatted
          with ``prompt_kwargs``. Any other ``"prompt"`` value is stringified
          into a single ``user`` message.
        * Otherwise, if ``"messages"`` is present, each entry's ``role`` and
          ``content`` are copied into a ``RenderedMessage``.

        Raises:
            ValueError: If neither ``"prompt"`` nor ``"messages"`` is present.
        """
        # Imported locally, as the original methods did.
        from sycamore.llms.prompts.default_prompts import SimplePrompt

        if "prompt" in prompt_kwargs:
            prompt = prompt_kwargs.get("prompt")
            if isinstance(prompt, SimplePrompt):
                # Build the formatted messages directly instead of writing the
                # formatted content back into the dicts from as_messages().
                return RenderedPrompt(
                    messages=[
                        RenderedMessage(role=m["role"], content=m["content"].format(**prompt_kwargs))
                        for m in prompt.as_messages()
                    ]
                )
            return RenderedPrompt(messages=[RenderedMessage(role="user", content=f"{prompt}")])

        if "messages" in prompt_kwargs:
            ms = prompt_kwargs.get("messages", [])
            return RenderedPrompt(messages=[RenderedMessage(role=m["role"], content=m["content"]) for m in ms])

        raise ValueError("Either 'prompt' or 'messages' must be specified in prompt_kwargs")

    @deprecated(version="0.1.31", reason="Use generate, with a RenderedPrompt, instead")
    def generate_old(self, *, prompt_kwargs: dict[str, Any], llm_kwargs: Optional[dict] = None) -> str:
        """Generates a response from the LLM"""
        rendered = self._render_legacy_prompt_kwargs(prompt_kwargs)
        return self.generate(prompt=rendered, llm_kwargs=llm_kwargs)

    @abstractmethod
    def is_chat_mode(self) -> bool:
        """Returns True if the LLM is in chat mode, False otherwise."""
        pass

    def format_image(self, image: Image.Image) -> dict[str, Any]:
        """Returns a dictionary containing the specified image suitable for use in an LLM message.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("This LLM does not support images.")

    async def generate_async(
        self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> str:
        """Generates a response from the LLM for the given prompt and LLM parameters asynchronously.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("This LLM does not support asynchronous generation.")

    @deprecated(version="0.1.31", reason="Use generate_async, with a RenderedPrompt, instead")
    async def generate_async_old(self, *, prompt_kwargs: dict[str, Any], llm_kwargs: Optional[dict] = None) -> str:
        """Asynchronous counterpart of :meth:`generate_old`."""
        rendered = self._render_legacy_prompt_kwargs(prompt_kwargs)
        return await self.generate_async(prompt=rendered, llm_kwargs=llm_kwargs)

    def generate_batch(
        self, *, prompts: list[RenderedPrompt], llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> list[str]:
        """Generates a series of responses from the LLM for the given series of prompts. Order is preserved.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("This LLM does not support batched generation")

    def __str__(self):
        return f"{self.__class__.__name__}({self._model_name})"

    @staticmethod
    def _jsonable_response_format(prompt: RenderedPrompt) -> Any:
        """Return the prompt's response format in a JSON-friendly form.

        A pydantic model class is replaced by its JSON schema; any other value
        is returned unchanged.
        """
        import pydantic

        if inspect.isclass(prompt.response_format) and issubclass(prompt.response_format, pydantic.BaseModel):
            return prompt.response_format.model_json_schema()
        return prompt.response_format

    @staticmethod
    def _llm_cache_json_default(obj: Any) -> Any:
        """``json.dumps`` ``default`` hook for values stored in the LLM cache.

        Handles ``RenderedPrompt``, ``RenderedMessage``, PIL images (encoded as
        base64 PNG), ``datetime.timedelta``, ``datetime.datetime`` and pydantic
        model classes (emitted as their JSON schema). All but the last are
        tagged with a ``"__type__"`` key so the object hook can rebuild them.

        Raises:
            TypeError: If ``obj`` is of any other type.
        """
        import pydantic

        if isinstance(obj, RenderedPrompt):
            return {
                "__type__": "RenderedPrompt",
                "messages": obj.messages,
                "response_format": LLM._jsonable_response_format(obj),
            }
        if isinstance(obj, RenderedMessage):
            return {"__type__": "RenderedMessage", "role": obj.role, "content": obj.content, "images": obj.images}
        if isinstance(obj, Image.Image):
            with io.BytesIO() as buffer:
                obj.save(buffer, format="PNG")
                data = base64.b64encode(buffer.getvalue()).decode("ascii")
            return {"__type__": "PIL.Image", "format": "PNG", "data": data}
        if isinstance(obj, datetime.timedelta):
            return {
                "__type__": "datetime.timedelta",
                "days": obj.days,
                "seconds": obj.seconds,
                "microseconds": obj.microseconds,
            }
        if isinstance(obj, datetime.datetime):
            return {"__type__": "datetime.datetime", "value": obj.isoformat()}
        if inspect.isclass(obj) and issubclass(obj, pydantic.BaseModel):
            return obj.model_json_schema()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    @staticmethod
    def _llm_cache_json_object_hook(obj: dict[str, Any]) -> Any:
        """``json.loads`` ``object_hook`` that rebuilds objects tagged with ``"__type__"``.

        Dicts without a recognized ``"__type__"`` are returned unchanged.
        """
        obj_type = obj.get("__type__")
        if obj_type == "RenderedPrompt":
            return RenderedPrompt(messages=obj["messages"], response_format=obj.get("response_format"))
        if obj_type == "RenderedMessage":
            return RenderedMessage(role=obj["role"], content=obj["content"], images=obj.get("images"))
        if obj_type == "PIL.Image":
            image_data = base64.b64decode(obj["data"])
            # Copy so the returned image does not depend on the closed file object.
            with Image.open(io.BytesIO(image_data)) as image:
                return image.copy()
        if obj_type == "datetime.timedelta":
            return datetime.timedelta(
                days=obj["days"],
                seconds=obj["seconds"],
                microseconds=obj["microseconds"],
            )
        if obj_type == "datetime.datetime":
            return datetime.datetime.fromisoformat(obj["value"])
        return obj

    @staticmethod
    def _llm_cache_json_dumps(data: Any) -> str:
        """Serialize ``data`` to compact JSON with sorted keys."""
        return json.dumps(data, default=LLM._llm_cache_json_default, sort_keys=True, separators=(",", ":"))

    @staticmethod
    def _llm_cache_jsonable(data: Any) -> Any:
        """Round-trip ``data`` through JSON so only plain JSON types remain."""
        return json.loads(LLM._llm_cache_json_dumps(data))

    @staticmethod
    def _llm_cache_from_jsonable(data: Any) -> Any:
        """Rebuild tagged objects in a value produced by :meth:`_llm_cache_jsonable`."""
        return json.loads(json.dumps(data), object_hook=LLM._llm_cache_json_object_hook)

    def _llm_cache_key(
        self, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[str] = None
    ) -> str:
        """Return a cache key for the given prompt and LLM parameters."""
        assert self._cache
        # We now default this to an empty dict if None is passed in.
        llm_kwargs = llm_kwargs or {}
        model = model or self._model_name
        rf = self._jsonable_response_format(prompt)
        ms = prompt.messages
        combined = {
            "prompt": RenderedPrompt(messages=ms),
            "prompt.response_format": rf,
            "llm_kwargs": llm_kwargs,
            "model_name": model,
        }
        data = self._llm_cache_json_dumps(combined).encode("utf-8")
        return self._cache.get_hash_context(data).hexdigest()

    def _use_caching(self, llm_kwargs: Optional[dict]):
        """Return whether results for ``llm_kwargs`` should be read from / written to the cache.

        False when no cache is configured; True when ``llm_kwargs`` is empty;
        otherwise True only if ``temperature`` (default 0) equals 0.
        """
        if not self._cache:
            return False
        if not llm_kwargs:
            return True
        # Only cache when temperature setting is zero.
        return llm_kwargs.get("temperature", 0) == 0

    def _llm_cache_get(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict], model: Optional[str] = None) -> Any:
        """Get a cached result for the given prompt and LLM parameters.

        Returns the cached result if found, or otherwise None. A found entry
        is checked against the request and an ``AssertionError`` is raised if
        its stored prompt, response format, kwargs or model name differ.
        """
        if not self._use_caching(llm_kwargs):
            return None
        assert self._cache is not None, "make mypy happy"
        model = model or self._model_name
        llm_kwargs = llm_kwargs or {}
        key = self._llm_cache_key(prompt, llm_kwargs, model=model)
        hit = self._cache.get(key)
        if hit:
            hit = self._llm_cache_from_jsonable(hit)
            assert (
                len(hit) == 5
                and hit.get("prompt") == RenderedPrompt(messages=prompt.messages)
                and hit.get("prompt.response_format") == self._jsonable_response_format(prompt)
                and hit.get("llm_kwargs") == llm_kwargs
                and hit.get("model_name") == model
                and "result" in hit
            ), f"""
            Found LLM cache content mismatch:
            key={key}
            prompt={prompt}, cached={hit.get("prompt")}
            cached_response_format={hit.get("prompt.response_format")}
            llm_kwargs={llm_kwargs}, cached={hit.get("llm_kwargs")}
            model_name={model}, cached={hit.get("model_name")}
            Complete hit: {hit}"""
            return hit.get("result")
        return None

    def _llm_cache_set(
        self, prompt: RenderedPrompt, llm_kwargs: Optional[dict], result: Any, model: Optional[str] = None
    ) -> None:
        """Set a cached result for the given key.

        Does nothing when :meth:`_use_caching` rejects ``llm_kwargs``.
        """
        if not self._use_caching(llm_kwargs):
            return
        assert self._cache is not None, "make mypy happy"
        model = model or self._model_name
        llm_kwargs = llm_kwargs or {}
        key = self._llm_cache_key(prompt, llm_kwargs, model=model)
        data = self._llm_cache_jsonable(
            {
                "prompt": RenderedPrompt(messages=prompt.messages),
                "prompt.response_format": self._jsonable_response_format(prompt),
                "llm_kwargs": llm_kwargs,
                "model_name": model,
                "result": result,
            }
        )
        # Previously a bare print(); route through the module logger instead.
        logger.debug("Cache set using model=%r", model)
        self._cache.set(key, data)

    def get_metadata(self, model, kwargs, response_text, wall_latency, in_tokens, out_tokens) -> dict:
        """Generate metadata for the LLM response.

        The ``prompt`` entry is taken from ``kwargs["prompt"]``, falling back
        to ``kwargs["messages"]``; ``temperature`` is ``None`` when not set.
        """
        return {
            "model": model,
            "temperature": kwargs.get("temperature", None),
            "usage": {
                "completion_tokens": out_tokens,
                "prompt_tokens": in_tokens,
                "total_tokens": in_tokens + out_tokens,
            },
            "wall_latency": wall_latency,
            "prompt": kwargs.get("prompt") or kwargs.get("messages"),
            "output": response_text,
        }

    def add_llm_metadata(
        self, kwargs, output, wall_latency, in_tokens, out_tokens, model: Optional[str] = None
    ) -> None:
        """Record response metadata via ``add_metadata`` when the thread-local flag is present.

        Does nothing unless ``ThreadLocalAccess(ADD_METADATA_TO_OUTPUT).present()``
        is true. ``model`` defaults to this instance's model name.
        """
        tls = ThreadLocalAccess(ADD_METADATA_TO_OUTPUT)
        if tls.present():
            model = model or self._model_name
            metadata = self.get_metadata(model, kwargs, output, wall_latency, in_tokens, out_tokens)
            add_metadata(**metadata)
class FakeLLM(LLM):
    """Useful for tests where the fake LLM needs to run in a ray function because mocks are not serializable"""

    def __init__(
        self,
        *,
        return_value="trivial",
        cache: Optional[Cache] = None,
        default_mode: LLMMode = LLMMode.SYNC,
        default_llm_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Create a fake LLM named ``"trivial"`` that always answers with ``return_value``."""
        super().__init__("trivial", cache=cache, default_mode=default_mode, default_llm_kwargs=default_llm_kwargs)
        self._return_value = return_value

    def generate(
        self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> str:
        """Return the configured value; ``prompt``, ``llm_kwargs`` and ``model`` are ignored."""
        return self._return_value

    def generate_metadata(
        self, *, prompt: RenderedPrompt, model: Optional[LLMModel] = None, llm_kwargs: Optional[dict] = None
    ) -> dict:
        """Return a metadata dict with the configured output, zero latency and zero token counts.

        The ``model`` entry is ``model.name`` when a model is given, otherwise
        this instance's model name.
        """
        model_name = model.name if model else self._model_name
        return {
            "model": model_name,
            "output": self._return_value,
            "wall_latency": datetime.timedelta(seconds=0),
            "in_tokens": 0,
            "out_tokens": 0,
        }

    async def generate_async(
        self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> str:
        """Asynchronous wrapper that delegates to :meth:`generate`."""
        # Forward `model` too, so the async path passes the same arguments as the sync one.
        return self.generate(prompt=prompt, llm_kwargs=llm_kwargs, model=model)

    def is_chat_mode(self) -> bool:
        """Always returns False."""
        return False