import inspect
from abc import ABC, abstractmethod
import copy
import datetime
import io
import json
import logging
import base64
from PIL import Image
from typing import Any, Optional
from sycamore.llms.config import LLMMode, LLMModel
from sycamore.utils.cache import Cache
from sycamore.utils.thread_local import ThreadLocalAccess, ADD_METADATA_TO_OUTPUT
from sycamore.data.metadata import add_metadata
from sycamore.llms.prompts import RenderedPrompt, RenderedMessage
from sycamore.utils.deprecate import deprecated
logger = logging.getLogger(__name__)


class LLM(ABC):
    """Abstract representation of an LLM instance.

    Subclass this to implement a specific LLM provider. Subclasses must
    implement :meth:`generate` and :meth:`is_chat_mode`. Asynchronous and
    batched generation are optional and raise ``NotImplementedError`` unless
    overridden. The base class also provides helpers for merging default LLM
    kwargs and for caching results keyed on prompt, LLM kwargs and model name.
    """

    # NOTE(review): declared but never assigned in this class; presumably set
    # by concrete subclasses -- confirm.
    model: LLMModel

    def __init__(
        self,
        model_name: str,
        default_mode: LLMMode,
        cache: Optional[Cache] = None,
        default_llm_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Initialize the shared LLM state.

        Args:
            model_name: Name of the model; used in ``__str__`` and as the
                default model name in cache keys and metadata.
            default_mode: Execution mode returned by :meth:`default_mode`.
            cache: Optional cache for results. When falsy, caching is disabled.
            default_llm_kwargs: Kwargs merged underneath per-call kwargs by
                :meth:`_merge_llm_kwargs`. ``None`` is treated as ``{}``.
        """
        self._model_name: str = model_name
        self._cache = cache
        self._default_mode = default_mode
        self._default_llm_kwargs = default_llm_kwargs or {}

    def default_mode(self) -> LLMMode:
        """Returns the default execution mode for the llm"""
        return self._default_mode

    def _merge_llm_kwargs(self, llm_kwargs: Optional[dict[str, Any]] = None) -> dict[str, Any]:
        """Merges the default LLM kwargs with any provided LLM kwargs.

        Prefers the passed in values if there is a conflict. Neither the
        defaults nor the passed-in dict is modified; a new dict is returned.
        """
        # Shallow copy so the update below cannot leak into the shared defaults.
        new_kwargs = copy.copy(self._default_llm_kwargs)
        new_kwargs.update(llm_kwargs or {})
        # Use the module logger (not the root logger) so log filtering by
        # module name works as expected.
        logger.debug("Merging LLM kwargs: %s", new_kwargs)
        return new_kwargs

    @abstractmethod
    def generate(
        self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> str:
        """Generates a response from the LLM for the given prompt and LLM parameters."""
        pass

    @staticmethod
    def _render_legacy_prompt_kwargs(prompt_kwargs: dict[str, Any]) -> RenderedPrompt:
        """Convert legacy ``prompt_kwargs`` into a ``RenderedPrompt``.

        Shared by :meth:`generate_old` and :meth:`generate_async_old`.

        * ``prompt`` is a ``SimplePrompt``: each message's content is formatted
          with ``prompt_kwargs`` as the format arguments.
        * ``prompt`` is anything else: it is stringified into a single
          ``user`` message.
        * otherwise ``messages`` must be present and is converted one-to-one.

        Raises:
            ValueError: if neither ``prompt`` nor ``messages`` is present.
        """
        # Imported lazily, as in the original methods.
        from sycamore.llms.prompts.default_prompts import SimplePrompt

        if "prompt" in prompt_kwargs:
            prompt = prompt_kwargs["prompt"]
            if isinstance(prompt, SimplePrompt):
                # Format into new RenderedMessage objects rather than writing
                # back into the dicts returned by as_messages(), so the
                # prompt's own data is never mutated.
                return RenderedPrompt(
                    messages=[
                        RenderedMessage(role=m["role"], content=m["content"].format(**prompt_kwargs))
                        for m in prompt.as_messages()
                    ]
                )
            return RenderedPrompt(messages=[RenderedMessage(role="user", content=f"{prompt}")])

        if "messages" in prompt_kwargs:
            return RenderedPrompt(
                messages=[RenderedMessage(role=m["role"], content=m["content"]) for m in prompt_kwargs["messages"]]
            )

        raise ValueError("Either 'prompt' or 'messages' must be specified in prompt_kwargs")

    @deprecated(version="0.1.31", reason="Use generate, with a RenderedPrompt, instead")
    def generate_old(self, *, prompt_kwargs: dict[str, Any], llm_kwargs: Optional[dict] = None) -> str:
        """Generates a response from the LLM from legacy ``prompt_kwargs``.

        Raises:
            ValueError: if ``prompt_kwargs`` has neither ``prompt`` nor ``messages``.
        """
        rendered = self._render_legacy_prompt_kwargs(prompt_kwargs)
        return self.generate(prompt=rendered, llm_kwargs=llm_kwargs)

    @abstractmethod
    def is_chat_mode(self) -> bool:
        """Returns True if the LLM is in chat mode, False otherwise."""
        pass

    async def generate_async(
        self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> str:
        """Generates a response from the LLM for the given prompt and LLM parameters asynchronously.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError("This LLM does not support asynchronous generation.")

    @deprecated(version="0.1.31", reason="Use generate_async, with a RenderedPrompt, instead")
    async def generate_async_old(self, *, prompt_kwargs: dict[str, Any], llm_kwargs: Optional[dict] = None) -> str:
        """Asynchronously generates a response from the LLM from legacy ``prompt_kwargs``.

        Raises:
            ValueError: if ``prompt_kwargs`` has neither ``prompt`` nor ``messages``.
        """
        rendered = self._render_legacy_prompt_kwargs(prompt_kwargs)
        return await self.generate_async(prompt=rendered, llm_kwargs=llm_kwargs)

    def generate_batch(
        self, *, prompts: list[RenderedPrompt], llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> list[str]:
        """Generates a series of responses from the LLM for the given series of prompts. Order is preserved.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError("This LLM does not support batched generation")

    def __str__(self):
        """Return ``ClassName(model_name)``."""
        return f"{self.__class__.__name__}({self._model_name})"

    @staticmethod
    def _jsonable_response_format(prompt: RenderedPrompt) -> Any:
        """Return the prompt's response format in a JSON-friendly form.

        A pydantic ``BaseModel`` subclass is replaced by its JSON schema; any
        other value is returned unchanged.
        """
        import pydantic

        response_format = prompt.response_format
        if inspect.isclass(response_format) and issubclass(response_format, pydantic.BaseModel):
            return response_format.model_json_schema()
        return response_format

    @staticmethod
    def _llm_cache_json_default(obj: Any) -> Any:
        """``json.dumps`` ``default`` hook for values stored in the LLM cache.

        Encodes ``RenderedPrompt``, ``RenderedMessage``, PIL images (as
        base64 PNG), ``datetime.timedelta`` and ``datetime.datetime`` as dicts
        tagged with ``"__type__"``, and pydantic model classes as their JSON
        schema.

        Raises:
            TypeError: for any other type, as the ``json`` module expects.
        """
        import pydantic

        if isinstance(obj, RenderedPrompt):
            return {
                "__type__": "RenderedPrompt",
                "messages": obj.messages,
                "response_format": LLM._jsonable_response_format(obj),
            }
        if isinstance(obj, RenderedMessage):
            return {"__type__": "RenderedMessage", "role": obj.role, "content": obj.content, "images": obj.images}
        if isinstance(obj, Image.Image):
            with io.BytesIO() as buffer:
                obj.save(buffer, format="PNG")
                # Read the bytes before the buffer is closed by the with-block.
                data = base64.b64encode(buffer.getvalue()).decode("ascii")
            return {"__type__": "PIL.Image", "format": "PNG", "data": data}
        if isinstance(obj, datetime.timedelta):
            return {
                "__type__": "datetime.timedelta",
                "days": obj.days,
                "seconds": obj.seconds,
                "microseconds": obj.microseconds,
            }
        if isinstance(obj, datetime.datetime):
            return {"__type__": "datetime.datetime", "value": obj.isoformat()}
        if inspect.isclass(obj) and issubclass(obj, pydantic.BaseModel):
            return obj.model_json_schema()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    @staticmethod
    def _llm_cache_json_object_hook(obj: dict[str, Any]) -> Any:
        """``json.loads`` ``object_hook`` that reverses :meth:`_llm_cache_json_default`.

        Dicts carrying a recognised ``"__type__"`` tag are rebuilt into the
        corresponding object; every other dict is returned unchanged.
        """
        obj_type = obj.get("__type__")
        if obj_type == "RenderedPrompt":
            return RenderedPrompt(messages=obj["messages"], response_format=obj.get("response_format"))
        if obj_type == "RenderedMessage":
            return RenderedMessage(role=obj["role"], content=obj["content"], images=obj.get("images"))
        if obj_type == "PIL.Image":
            image_data = base64.b64decode(obj["data"])
            with Image.open(io.BytesIO(image_data)) as image:
                # Copy so the returned image does not depend on the opened handle.
                return image.copy()
        if obj_type == "datetime.timedelta":
            return datetime.timedelta(
                days=obj["days"],
                seconds=obj["seconds"],
                microseconds=obj["microseconds"],
            )
        if obj_type == "datetime.datetime":
            return datetime.datetime.fromisoformat(obj["value"])
        return obj

    @staticmethod
    def _llm_cache_json_dumps(data: Any) -> str:
        """Serialize ``data`` to compact JSON with sorted keys.

        Sorted keys and fixed separators keep the output stable for a given
        input, which the cache key hash relies on.
        """
        return json.dumps(data, default=LLM._llm_cache_json_default, sort_keys=True, separators=(",", ":"))

    @staticmethod
    def _llm_cache_jsonable(data: Any) -> Any:
        """Round-trip ``data`` through JSON so only JSON-native types remain."""
        return json.loads(LLM._llm_cache_json_dumps(data))

    @staticmethod
    def _llm_cache_from_jsonable(data: Any) -> Any:
        """Rebuild tagged objects inside JSON-native ``data`` read from the cache."""
        return json.loads(json.dumps(data), object_hook=LLM._llm_cache_json_object_hook)

    def _llm_cache_key(
        self, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[str] = None
    ) -> str:
        """Return a cache key for the given prompt and LLM parameters.

        The key is the hex digest of the cache's hash context over the JSON
        form of the prompt messages, response format, LLM kwargs and model name.
        """
        assert self._cache
        # We now default this to an empty dict if None is passed in.
        llm_kwargs = llm_kwargs or {}
        model = model or self._model_name
        combined = {
            "prompt": RenderedPrompt(messages=prompt.messages),
            "prompt.response_format": self._jsonable_response_format(prompt),
            "llm_kwargs": llm_kwargs,
            "model_name": model,
        }
        data = self._llm_cache_json_dumps(combined).encode("utf-8")
        return self._cache.get_hash_context(data).hexdigest()

    def _use_caching(self, llm_kwargs: Optional[dict]) -> bool:
        """Return True if results for a call with ``llm_kwargs`` should be cached.

        Caching requires a configured cache and a temperature of zero; a
        missing temperature is treated as zero.
        """
        if not self._cache:
            return False
        if not llm_kwargs:
            return True
        # Only cache when temperature setting is zero.
        return llm_kwargs.get("temperature", 0) == 0

    def _llm_cache_get(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict], model: Optional[str] = None) -> Any:
        """Get a cached result for the given prompt and LLM parameters.

        Returns the cached result if found, or otherwise None.

        Raises:
            AssertionError: if an entry exists under the key but its stored
                prompt, response format, kwargs or model name do not match
                the request.
        """
        if not self._use_caching(llm_kwargs):
            return None
        assert self._cache is not None, "make mypy happy"

        model = model or self._model_name
        llm_kwargs = llm_kwargs or {}
        key = self._llm_cache_key(prompt, llm_kwargs, model=model)
        hit = self._cache.get(key)
        if not hit:
            return None

        hit = self._llm_cache_from_jsonable(hit)
        entry_matches = (
            len(hit) == 5
            and hit.get("prompt") == RenderedPrompt(messages=prompt.messages)
            and hit.get("prompt.response_format") == self._jsonable_response_format(prompt)
            and hit.get("llm_kwargs") == llm_kwargs
            and hit.get("model_name") == model
            and "result" in hit
        )
        if not entry_matches:
            # Raised explicitly instead of via an `assert` statement so the
            # integrity check still runs under `python -O`.
            raise AssertionError(
                "\n".join(
                    [
                        "",
                        "Found LLM cache content mismatch:",
                        f"key={key}",
                        f"prompt={prompt}, cached={hit.get('prompt')}",
                        f"cached_response_format={hit.get('prompt.response_format')}",
                        f"llm_kwargs={llm_kwargs}, cached={hit.get('llm_kwargs')}",
                        f"model_name={model}, cached={hit.get('model_name')}",
                        f"Complete hit: {hit}",
                    ]
                )
            )
        return hit.get("result")

    def _llm_cache_set(
        self, prompt: RenderedPrompt, llm_kwargs: Optional[dict], result: Any, model: Optional[str] = None
    ) -> None:
        """Store ``result`` in the cache for the given prompt and LLM parameters.

        Does nothing when caching is disabled for these kwargs (see
        :meth:`_use_caching`).
        """
        if not self._use_caching(llm_kwargs):
            return
        assert self._cache is not None, "make mypy happy"

        model = model or self._model_name
        llm_kwargs = llm_kwargs or {}
        key = self._llm_cache_key(prompt, llm_kwargs, model=model)
        # Stored alongside the result so _llm_cache_get can verify the entry.
        data = self._llm_cache_jsonable(
            {
                "prompt": RenderedPrompt(messages=prompt.messages),
                "prompt.response_format": self._jsonable_response_format(prompt),
                "llm_kwargs": llm_kwargs,
                "model_name": model,
                "result": result,
            }
        )
        logger.debug("Cache set using model=%r", model)
        self._cache.set(key, data)

    def add_llm_metadata(
        self, kwargs, output, wall_latency, in_tokens, out_tokens, model: Optional[str] = None
    ) -> None:
        """Record metadata for an LLM call when metadata collection is active.

        Nothing happens unless the ``ADD_METADATA_TO_OUTPUT`` thread-local is
        present. ``model`` defaults to this instance's model name.
        """
        tls = ThreadLocalAccess(ADD_METADATA_TO_OUTPUT)
        if not tls.present():
            return
        model = model or self._model_name
        # NOTE(review): get_metadata is not defined on this class in the
        # visible source; presumably provided elsewhere (e.g. by subclasses)
        # -- confirm.
        metadata = self.get_metadata(model, kwargs, output, wall_latency, in_tokens, out_tokens)
        add_metadata(**metadata)
class FakeLLM(LLM):
    """Useful for tests where the fake LLM needs to run in a ray function because mocks are not serializable

    Every generation call returns the fixed ``return_value`` supplied at
    construction; the prompt and kwargs are ignored.
    """

    def __init__(
        self,
        *,
        return_value="trivial",
        cache: Optional[Cache] = None,
        default_mode: LLMMode = LLMMode.SYNC,
        default_llm_kwargs: Optional[dict[str, Any]] = None,
    ):
        """Create a fake LLM whose model name is ``"trivial"``.

        Args:
            return_value: Value returned by every generate call.
            cache: Optional cache passed through to the base class.
            default_mode: Default execution mode passed to the base class.
            default_llm_kwargs: Default LLM kwargs passed to the base class.
        """
        super().__init__("trivial", cache=cache, default_mode=default_mode, default_llm_kwargs=default_llm_kwargs)
        self._return_value = return_value

    def generate(
        self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> str:
        """Return the configured value, ignoring all arguments."""
        return self._return_value

    def generate_metadata(
        self, *, prompt: RenderedPrompt, model: Optional[LLMModel] = None, llm_kwargs: Optional[dict] = None
    ) -> dict:
        """Return the configured value together with zeroed latency and token counts.

        The reported model is ``model.name`` when a model is given, otherwise
        this instance's model name.
        """
        model_name = model.name if model else self._model_name
        return {
            "model": model_name,
            "output": self._return_value,
            "wall_latency": datetime.timedelta(seconds=0),
            "in_tokens": 0,
            "out_tokens": 0,
        }

    async def generate_async(
        self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> str:
        """Asynchronous wrapper that delegates to :meth:`generate`."""
        # Forward `model` too, so the sync and async paths receive the same arguments.
        return self.generate(prompt=prompt, llm_kwargs=llm_kwargs, model=model)

    def is_chat_mode(self) -> bool:
        """Always False for the fake LLM."""
        return False