# Source code for sycamore.llms.anthropic

from datetime import datetime
import logging
from typing import Any, Optional, Union
import asyncio
import random
import time

from PIL import Image

from sycamore.llms.config import AnthropicModels, AnthropicModel, LLMModel
from sycamore.llms.llms import LLM, LLMMode
from sycamore.llms.prompts import RenderedPrompt
from sycamore.utils.cache import Cache
from sycamore.utils.image_utils import base64_data
from sycamore.utils.import_utils import requires_modules

# Default ``max_tokens`` for a generation request when the caller does not supply one.
DEFAULT_MAX_TOKENS: int = 1_000
# Base delay, in seconds, of the exponential retry backoff used for async generation.
INITIAL_BACKOFF: int = 1
# Seconds to wait between status checks while a message batch is still processing.
BATCH_POLL_INTERVAL: int = 10

logger = logging.getLogger(__name__)


def rewrite_system_messages(messages: Optional[list[dict]]) -> Optional[list[dict]]:
    """Fold "system" messages into the message that follows them.

    Anthropic models don't accept messages whose "role" is "system" and require
    alternation between "user" and "assistant" roles. Each run of consecutive
    "system" messages has its contents concatenated (with no separator). That text
    is prepended, followed by a newline, to the content of the next non-system
    message, whatever that message's role is. The system messages are then dropped.

    The input list and the dicts inside it are never modified. A message whose
    content changes is returned as a shallow copy with the new content.

    Args:
        messages: Chat messages as dicts with "role" and "content" keys, or None.

    Returns:
        ``messages`` itself if it is None or empty. Otherwise, a new list that
        contains no "system" messages. System messages with no following
        non-system message are discarded.
    """
    if not messages:
        return messages

    result: list[dict] = []
    pending_system = ""
    for message in messages:
        if message.get("role") == "system":
            pending_system += message.get("content", "")
            continue
        if pending_system:
            # Copy instead of assigning into the caller's dict.
            message = {**message, "content": pending_system + "\n" + message.get("content", "")}
            pending_system = ""
        result.append(message)

    # NOTE(review): system text left in ``pending_system`` here (no following
    # message) is dropped, as before; confirm this is intended.
    return result


def get_generate_kwargs(prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
    """Build the keyword arguments for an Anthropic ``messages.create`` call.

    Args:
        prompt: The rendered prompt whose messages are sent to the model.
        llm_kwargs: Extra request parameters. They override the defaults of
            ``temperature=0`` and ``max_tokens=DEFAULT_MAX_TOKENS``.

    Returns:
        A new dict of request parameters with a ``messages`` entry. Anthropic models
        require _exact_ alternation between "user" and "assistant" roles. So
        "system" messages are treated as "user", and each run of consecutive
        messages with the same (normalized) role is merged into one message.
        Merged contents are joined with newlines. If any message in a run has
        images, the merged content is a list holding a text block followed by one
        image block per image.
    """
    from itertools import groupby

    kwargs: dict[str, Any] = {"temperature": 0, **(llm_kwargs or {})}
    kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS)

    def _role(message) -> str:
        # The Anthropic message list has no "system" role; fold it into "user".
        return "user" if message.role == "system" else message.role

    messages: list[dict[str, Any]] = []
    for role, run in groupby(prompt.messages, key=_role):
        group = list(run)
        text = "\n".join(m.content for m in group)
        if any(m.images is not None for m in group):
            contents: list[dict[str, Any]] = [{"type": "text", "text": text}]
            # Reuse the module-level helper so the image payload shape lives in one place.
            contents.extend(format_image(im) for m in group if m.images is not None for im in m.images)
            messages.append({"role": role, "content": contents})
        else:
            messages.append({"role": role, "content": text})

    kwargs["messages"] = messages
    return kwargs


def format_image(image: Image.Image) -> dict[str, Any]:
    """Wrap a PIL image as an Anthropic base64 image content block.

    Args:
        image: The image to embed. Its payload comes from ``base64_data``.

    Returns:
        ``{"type": "image", "source": {...}}``, with the source labeled as
        base64 data of media type ``image/png``.
    """
    source = {
        "type": "base64",
        "media_type": "image/png",
        "data": base64_data(image),
    }
    return {"type": "image", "source": source}


def anthropic_deserializer(kwargs):
    """Reconstruct an ``Anthropic`` LLM from the kwargs dict produced by ``Anthropic.__reduce__``."""
    llm = Anthropic(**kwargs)
    return llm


class Anthropic(LLM):
    """LLM implementation that calls the Anthropic Messages API directly.

    Args:
        model_name: The Claude model to use: an ``AnthropicModels`` member, an
            ``AnthropicModel``, or a name string resolvable by ``AnthropicModels.from_name``.
        default_mode: The default execution mode (``LLMMode.ASYNC`` unless overridden).
        cache: Optional cache used to store and reuse generation results.
        default_llm_kwargs: Default keyword arguments for generation requests (passed
            to the base ``LLM``).
        client_args: Keyword arguments forwarded to both the synchronous and the
            asynchronous Anthropic client constructors. They are kept on the instance
            so that unpickled copies build equivalent clients.
    """

    # Upper bound, in seconds, on a single backoff sleep in generate_async.
    _MAX_BACKOFF = 60
    # Retries generate_async attempts on rate-limit / connection errors before
    # re-raising the last error.
    _MAX_RETRIES = 10

    @requires_modules("anthropic", extra="anthropic")
    def __init__(
        self,
        model_name: Union[AnthropicModels, AnthropicModel, str],
        default_mode: LLMMode = LLMMode.ASYNC,
        cache: Optional[Cache] = None,
        default_llm_kwargs: Optional[dict[str, Any]] = None,
        client_args: Optional[dict[str, Any]] = None,
    ):
        # We import this here so we can share utility code with the Bedrock
        # LLM implementation without requiring an Anthropic dependency.
        from anthropic import Anthropic as AnthropicClient
        from anthropic import AsyncAnthropic as AsyncAnthropicClient

        self.model_name = model_name

        if isinstance(model_name, AnthropicModels):
            self.model = model_name.value
        elif isinstance(model_name, AnthropicModel):
            self.model = model_name
        elif isinstance(model_name, str):
            model = AnthropicModels.from_name(name=model_name)
            if model is None:
                raise ValueError(f"Invalid model name: {model_name}")
            self.model = model.value
        else:
            # Previously this fell through and failed later with AttributeError.
            raise TypeError(f"Unsupported model_name type: {type(model_name)}")

        # Copy so later changes to the caller's dict don't affect this instance.
        self._client_args: dict[str, Any] = dict(client_args or {})
        self._client = AnthropicClient(**self._client_args)
        self._async_client = AsyncAnthropicClient(**self._client_args)

        super().__init__(self.model.name, default_mode, cache, default_llm_kwargs=default_llm_kwargs)

    def __reduce__(self):
        # NOTE(review): client_args (which may contain credentials) is now part of
        # the pickled state; without it, unpickled copies lose custom client settings.
        kwargs = {
            "model_name": self.model_name,
            "cache": self._cache,
            "default_mode": self._default_mode,
            "default_llm_kwargs": self._default_llm_kwargs,
            "client_args": self._client_args,
        }
        return anthropic_deserializer, (kwargs,)

    def default_mode(self) -> LLMMode:
        """Return the configured default mode, or ``LLMMode.ASYNC`` when none is set."""
        if self._default_mode is not None:
            return self._default_mode
        return LLMMode.ASYNC

    def is_chat_mode(self) -> bool:
        """Returns True if the LLM is in chat mode, False otherwise."""
        return True

    def format_image(self, image: Image.Image) -> dict[str, Any]:
        """Return the Anthropic image content block for ``image``."""
        return format_image(image)

    def _metadata_from_response(self, model: str, kwargs, response, starttime) -> dict:
        """Build the result dict for a finished request and record it via ``add_llm_metadata``.

        Only the text of the first content block of ``response`` is used as the output.
        """
        wall_latency = datetime.now() - starttime
        in_tokens = response.usage.input_tokens
        out_tokens = response.usage.output_tokens
        output = response.content[0].text
        ret = {
            "model": model,
            "output": output,
            "wall_latency": wall_latency,
            "in_tokens": in_tokens,
            "out_tokens": out_tokens,
        }
        self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens, model=model)
        return ret

    def generate_metadata(
        self, *, prompt: RenderedPrompt, model: Optional[LLMModel] = None, llm_kwargs: Optional[dict] = None
    ) -> dict:
        """Synchronously generate a response and return it with usage metadata.

        Cached results are returned without calling the API. Fresh results are cached.
        """
        assert model is None or isinstance(
            model, AnthropicModel
        ), f"model must be a AnthropicModel, got {type(model)} from {model=}"
        if model is not None and model != self.model:
            logger.info(f"Generating response using {model=} instead of {self.model=}")
        model_name = model.name if model else self.model.name

        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
        ret = self._llm_cache_get(prompt, llm_kwargs, model=model_name)
        if isinstance(ret, dict):
            return ret

        kwargs = get_generate_kwargs(prompt, llm_kwargs)
        start = datetime.now()
        response = self._client.messages.create(model=model_name, **kwargs)
        ret = self._metadata_from_response(model_name, kwargs, response, start)
        logger.debug("Generated response from Anthropic model: %s", ret)

        self._llm_cache_set(prompt, llm_kwargs, ret, model=model_name)
        return ret

    def generate(
        self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> str:
        """Synchronously generate a response and return only its text."""
        d = self.generate_metadata(prompt=prompt, model=model, llm_kwargs=llm_kwargs)
        return d["output"]

    async def generate_async(
        self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> str:
        """Asynchronously generate a response and return its text.

        Retries on ``RateLimitError`` / ``APIConnectionError`` with jittered
        exponential backoff. Each sleep is capped at ``_MAX_BACKOFF`` seconds, and
        the last error is re-raised after ``_MAX_RETRIES`` retries.
        """
        from anthropic import RateLimitError, APIConnectionError

        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
        model_name: str = model.name if model else self.model.name
        if self.model.name != model_name:
            logger.info("Overriding Anthropic model from %s to %s", self.model.name, model_name)

        ret = self._llm_cache_get(prompt, llm_kwargs, model=model_name)
        if isinstance(ret, dict):
            return ret["output"]

        kwargs = get_generate_kwargs(prompt, llm_kwargs)
        start = datetime.now()
        retries = 0
        while True:
            try:
                response = await self._async_client.messages.create(model=model_name, **kwargs)
                break
            except (RateLimitError, APIConnectionError):
                if retries >= self._MAX_RETRIES:
                    raise
                backoff = min(INITIAL_BACKOFF * (2**retries), self._MAX_BACKOFF)
                jitter = random.uniform(0, 0.1 * backoff)
                logger.debug("Anthropic request failed; retry %d after %.1fs", retries + 1, backoff + jitter)
                await asyncio.sleep(backoff + jitter)
                retries += 1

        ret = self._metadata_from_response(model_name, kwargs, response, start)
        logger.debug("Generated response from Anthropic model: %s", ret)

        self._llm_cache_set(prompt, llm_kwargs, ret, model=model_name)
        return ret["output"]

    def generate_batch(
        self, *, prompts: list[RenderedPrompt], llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> list[str]:
        """Generate responses for ``prompts`` using the Message Batches API.

        Cached prompts are served from the cache. The rest are submitted as one
        batch, which is polled every ``BATCH_POLL_INTERVAL`` seconds until it ends.

        Raises:
            ValueError: If any request in the batch did not succeed.
        """
        from anthropic.types.message_create_params import MessageCreateParamsNonStreaming
        from anthropic.types.messages.batch_create_params import Request

        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
        model_name: str = model.name if model else self.model.name
        if self.model.name != model_name:
            logger.info("Overriding Anthropic model from %s to %s", self.model.name, model_name)

        cache_hits = [self._llm_cache_get(p, llm_kwargs, model=model_name) for p in prompts]

        # Request kwargs (without "model") for each cache miss, keyed by prompt index,
        # reused below for metadata instead of being rebuilt per result.
        pending: dict[int, dict] = {}
        calls = []
        for i, (p, cached) in enumerate(zip(prompts, cache_hits)):
            if isinstance(cached, dict):
                continue
            kwargs = get_generate_kwargs(p, llm_kwargs)
            pending[i] = kwargs
            params = dict(kwargs)
            params["model"] = model_name  # must be the model name string, not the model object
            mparams = MessageCreateParamsNonStreaming(**params)  # type: ignore
            calls.append(Request(custom_id=str(i), params=mparams))

        if not calls:
            # Everything was cached; don't create an empty batch.
            return [ch["output"] for ch in cache_hits]

        starttime = datetime.now()
        batch = self._client.messages.batches.create(requests=calls)
        # Results are only complete once processing has ended; "canceling" is
        # also a non-terminal state.
        while batch.processing_status != "ended":
            time.sleep(BATCH_POLL_INTERVAL)
            batch = self._client.messages.batches.retrieve(batch.id)

        for rs in self._client.messages.batches.results(batch.id):
            if rs.result.type != "succeeded":
                raise ValueError(f"Call failed: {rs}")
            idx = int(rs.custom_id)
            ret = self._metadata_from_response(model_name, pending[idx], rs.result.message, starttime)
            cache_hits[idx] = ret
            self._llm_cache_set(prompts[idx], llm_kwargs, ret, model=model_name)

        return [ch["output"] for ch in cache_hits]