# Source code for sycamore.llms.gemini

import datetime
import logging
from typing import Any, Optional, Union
import os
import io

from google.api_core import retry, retry_async

from sycamore.llms.config import GeminiModel, GeminiModels, LLMModel
from sycamore.llms.llms import LLM, LLMMode
from sycamore.llms.prompts.prompts import RenderedPrompt
from sycamore.utils.cache import Cache
from sycamore.utils.import_utils import requires_modules

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Base URL of the Helicone gateway. The Gemini client is pointed at this URL only when
# Helicone is not disabled and the SYCAMORE_HELICONE_API_KEY environment variable is set.
HELICONE_BASE_URL = "https://gateway.helicone.ai"


def gemini_deserializer(kwargs):
    """Rebuild a ``Gemini`` instance from a dict of constructor keyword arguments.

    Used as the reconstruction callable returned by ``Gemini.__reduce__``.
    """
    restored = Gemini(**kwargs)
    return restored


class Gemini(LLM):
    """This is an LLM implementation that uses the Google Gemini API to generate text.

    Args:
        model_name: The Gemini model to use: a ``GeminiModels`` member, a ``GeminiModel``,
            or a string name that is resolved through ``GeminiModels.from_name``.
        default_mode: The mode handed to the ``LLM`` base class as the default.
        cache: A cache object to use for caching results.
        api_key: API key for the Gemini client. When empty, the ``GEMINI_API_KEY``
            environment variable is used instead.
        default_llm_kwargs: Default keyword arguments handed to the ``LLM`` base class.
        disable_helicone: When False and ``SYCAMORE_HELICONE_API_KEY`` is set, requests
            are routed through the Helicone gateway.
    """

    @requires_modules("google.genai", extra="google-genai")
    def __init__(
        self,
        model_name: Union[GeminiModels, GeminiModel, str],
        default_mode: LLMMode = LLMMode.ASYNC,
        cache: Optional[Cache] = None,
        api_key: Optional[str] = None,
        default_llm_kwargs: Optional[dict[str, Any]] = None,
        disable_helicone: bool = True,
    ):
        from google.genai import Client
        from google.genai.types import HttpOptionsDict

        # Kept exactly as passed in so __reduce__ can replay the constructor call.
        # (Original author's note: is this supposed to be a string?)
        self.model_name = model_name
        # Remembered so that pickling preserves the Helicone setting.
        self._disable_helicone = disable_helicone

        if isinstance(model_name, GeminiModels):
            self.model = model_name.value
        elif isinstance(model_name, GeminiModel):
            self.model = model_name
        elif isinstance(model_name, str):
            # Check the lookup result before touching `.value`, so an unknown name
            # surfaces as ValueError rather than an AttributeError on None.
            known = GeminiModels.from_name(model_name)
            resolved = known.value if known is not None else None
            if resolved is None:
                raise ValueError(f"Invalid model name: {model_name}")
            self.model = resolved
        else:
            raise TypeError("model_name must be an instance of str, GeminiModel, or GeminiModels")

        api_key = api_key or os.getenv("GEMINI_API_KEY")

        # Helicone implementation from https://docs.helicone.ai/integrations/gemini/api/python
        http_options: Optional[HttpOptionsDict] = None
        if not disable_helicone and "SYCAMORE_HELICONE_API_KEY" in os.environ:
            http_options = {
                "base_url": HELICONE_BASE_URL,
                "headers": {
                    "helicone-auth": f"Bearer {os.environ['SYCAMORE_HELICONE_API_KEY']}",
                    "helicone-target-url": "https://generativelanguage.googleapis.com",
                },
            }
            if "SYCAMORE_HELICONE_TAG" in os.environ:
                assert http_options["headers"] is not None, "type checking, unreachable"
                http_options["headers"].update({"Helicone-Property-Tag": os.environ["SYCAMORE_HELICONE_TAG"]})

        self._client = Client(api_key=api_key, http_options=http_options)
        super().__init__(self.model.name, default_mode, cache, default_llm_kwargs=default_llm_kwargs)

    def __reduce__(self):
        """Pickle support: rebuild the instance by replaying the constructor.

        The API key is not part of the pickled state; the rebuilt instance falls back
        to the ``GEMINI_API_KEY`` environment variable.
        """
        kwargs = {
            "model_name": self.model_name,
            "cache": self._cache,
            "default_mode": self._default_mode,
            "default_llm_kwargs": self._default_llm_kwargs,
            # getattr keeps instances that lack the attribute picklable.
            "disable_helicone": getattr(self, "_disable_helicone", True),
        }
        return gemini_deserializer, (kwargs,)

    def default_mode(self) -> LLMMode:
        """Return the configured default mode, or ``LLMMode.ASYNC`` when none is set."""
        if self._default_mode is not None:
            return self._default_mode
        return LLMMode.ASYNC

    def is_chat_mode(self) -> bool:
        """Returns True if the LLM is in chat mode, False otherwise."""
        return True

    def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
        """Translate a rendered prompt and llm kwargs into Gemini request arguments.

        Args:
            prompt: The rendered prompt; system messages become the system instruction,
                all other messages become ``Content`` entries (images attached as PNG).
            llm_kwargs: Extra generation settings. ``thinking_budget`` and
                ``thinking_level`` are converted into a ``ThinkingConfig``.

        Returns:
            A dict with ``"config"`` (a ``GenerateContentConfig``) and ``"content"``
            (the list of ``Content`` objects).
        """
        from google.genai import types

        config = {
            "temperature": 0,
            "candidate_count": 1,
            **(llm_kwargs or {}),
        }
        if prompt.response_format:
            config["response_mime_type"] = "application/json"
            config["response_schema"] = prompt.response_format

        content_list: list[types.Content] = []
        for message in prompt.messages:
            if message.role == "system":
                config["system_instruction"] = message.content
                continue
            role = "model" if message.role == "assistant" else "user"
            content = types.Content(parts=[types.Part.from_text(text=message.content)], role=role)
            if message.images:
                for image in message.images:
                    buffered = io.BytesIO()
                    image.save(buffered, format="PNG")
                    image_bytes = buffered.getvalue()
                    assert content.parts is not None  # mypy
                    content.parts.append(types.Part.from_bytes(data=image_bytes, mime_type="image/png"))
            content_list.append(content)

        # Compare against None rather than truthiness: a budget of 0 is an explicit
        # request and must be forwarded, not silently discarded.
        thinking_budget = config.pop("thinking_budget", None)
        if thinking_budget is not None:
            config["thinking_config"] = types.ThinkingConfig(thinking_budget=thinking_budget)
        if thinking_level := config.pop("thinking_level", None):
            if "thinking_config" in config:
                logger.warning(f"Thinking level {thinking_level} overrides thinking budget {thinking_budget}")
            config["thinking_config"] = types.ThinkingConfig(thinking_level=types.ThinkingLevel(thinking_level))

        return {
            "config": types.GenerateContentConfig(**config),
            "content": content_list,
        }

    def _metadata_from_response(self, model: str, kwargs, response, starttime) -> dict:
        """Build the result/metadata dict for a Gemini response and record LLM metadata.

        Args:
            model: Name of the model that produced the response.
            kwargs: The request kwargs, used for logging and metadata.
            response: The Gemini response object.
            starttime: ``datetime`` taken just before the request was issued.

        Returns:
            A dict with keys ``model``, ``output``, ``wall_latency``, ``in_tokens``
            and ``out_tokens``. ``output`` is the empty string when the response
            carries no candidate content.
        """
        from google.genai.types import FinishReason

        wall_latency = datetime.datetime.now() - starttime
        md = response.usage_metadata
        in_tokens = int(md.prompt_token_count) if md and md.prompt_token_count else 0
        out_tokens = int(md.candidates_token_count) if md and md.candidates_token_count else 0

        # A response may come back without any candidates; treat that like a
        # candidate without content instead of failing on the index.
        candidates = response.candidates or []
        candidate = candidates[0] if candidates else None
        reason = candidate.finish_reason if candidate is not None else None
        if reason != FinishReason.STOP:
            logger.warning(
                f"Gemini model stopped for unexpected reason {reason}. Kwargs: {kwargs}. "
                f"Full response:\n{response}"
            )

        content = candidate.content if candidate is not None else None
        parts = content.parts if content is not None else None
        if parts is None:
            # Only serialize the response when the message will actually be emitted;
            # default=str keeps non-JSON values from raising inside a log call.
            if logger.isEnabledFor(logging.DEBUG):
                import json

                dumped = json.dumps(response.model_dump(), indent=4, default=str)
                logger.debug(f"Gemini model returned no content: {dumped}")
            output = ""
        else:
            # Parts without text have text=None, which str.join rejects; treat them as empty.
            output = " ".join((part.text or "") if part else "" for part in parts)

        ret = {
            "model": model,
            "output": output,
            "wall_latency": wall_latency,
            "in_tokens": in_tokens,
            "out_tokens": out_tokens,
        }
        self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens, model=model)
        return ret

    def generate_metadata(
        self, *, prompt: RenderedPrompt, model: Optional[LLMModel] = None, llm_kwargs: Optional[dict] = None
    ) -> dict:
        """Synchronously generate a response and return it with its metadata.

        Args:
            prompt: The rendered prompt to send.
            model: Optional ``GeminiModel`` overriding the instance's model for this call.
            llm_kwargs: Per-call kwargs, merged via ``_merge_llm_kwargs``.

        Returns:
            The dict produced by ``_metadata_from_response``, or the cached dict on a cache hit.
        """
        assert model is None or isinstance(
            model, GeminiModel
        ), f"model must be a GeminiModel, got {type(model)} from {model=}"
        if model is not None and model != self.model:
            logger.warning(f"Generating response using {model=} instead of {self.model=}")
        model_name = model.name if model else self.model.name

        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
        ret = self._llm_cache_get(prompt, llm_kwargs, model=model_name)
        if isinstance(ret, dict):
            return ret
        assert ret is None

        kwargs = self.get_generate_kwargs(prompt, llm_kwargs)
        start = datetime.datetime.now()
        response = self.generate_content(model=model_name, contents=kwargs["content"], config=kwargs["config"])
        ret = self._metadata_from_response(model_name, kwargs, response, start)
        self._llm_cache_set(prompt, llm_kwargs, ret, model=model_name)
        return ret

    def generate(
        self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> str:
        """Synchronously generate a response and return only its text output."""
        d = self.generate_metadata(prompt=prompt, model=model, llm_kwargs=llm_kwargs)
        return d["output"]

    async def generate_async(
        self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
    ) -> str:
        """Asynchronously generate a response and return only its text output.

        Consults the cache first; on a miss, the request is issued through
        ``generate_content_async`` and the result is stored in the cache.
        """
        llm_kwargs = self._merge_llm_kwargs(llm_kwargs)
        model_name = model.name if model else self.model.name
        if self.model.name != model_name:
            logger.info(f"Overriding Gemini model from {self.model.name} to {model_name}")

        ret = self._llm_cache_get(prompt, llm_kwargs, model=model_name)
        if isinstance(ret, dict):
            return ret["output"]
        assert ret is None

        kwargs = self.get_generate_kwargs(prompt, llm_kwargs)
        start = datetime.datetime.now()
        response = await self.generate_content_async(
            model=model_name, contents=kwargs["content"], config=kwargs["config"]
        )
        ret = self._metadata_from_response(model_name, kwargs, response, start)
        self._llm_cache_set(prompt, llm_kwargs, ret, model=model_name)
        return ret["output"]

    @retry.Retry(
        predicate=retry.if_transient_error,
        initial=1.0,
        maximum=60.0,
        multiplier=2.0,
        timeout=120.0,
    )
    def generate_content(self, model, contents, config):
        """Issue a synchronous generate_content request, wrapped in the retry policy above."""
        return self._client.models.generate_content(model=model, contents=contents, config=config)

    @retry_async.AsyncRetry(
        predicate=retry.if_transient_error,
        initial=1.0,
        maximum=60.0,
        multiplier=2.0,
        timeout=120.0,
    )
    async def generate_content_async(self, model: str, contents, config):
        """Issue an asynchronous generate_content request, wrapped in the retry policy above."""
        return await self._client.aio.models.generate_content(model=model, contents=contents, config=config)