import datetime
import logging
from typing import Any, Optional, Union
import os
import io
from google.api_core import retry, retry_async
from sycamore.llms.config import GeminiModel, GeminiModels, LLMModel
from sycamore.llms.llms import LLM, LLMMode
from sycamore.llms.prompts.prompts import RenderedPrompt
from sycamore.utils.cache import Cache
from sycamore.utils.import_utils import requires_modules
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Base URL of the Helicone gateway. The Gemini client is pointed at it only when Helicone is not
# disabled and the SYCAMORE_HELICONE_API_KEY environment variable is set.
HELICONE_BASE_URL = "https://gateway.helicone.ai"
def gemini_deserializer(kwargs):
    """Rebuild a ``Gemini`` instance from a dict of constructor keyword arguments.

    Args:
        kwargs: Mapping of ``Gemini.__init__`` keyword arguments.

    Returns:
        A newly constructed ``Gemini`` object.
    """
    rebuilt = Gemini(**kwargs)
    return rebuilt
[docs]
class Gemini(LLM):
    """This is an LLM implementation that uses the Google Gemini API to generate text.

    Args:
        model_name: The Gemini model to use, given as a ``GeminiModels`` member, a ``GeminiModel``,
            or a model name string.
        default_mode: The ``LLMMode`` to use by default. Defaults to ``LLMMode.ASYNC``.
        cache: A cache object to use for caching results.
        api_key: The Gemini API key. When not given, the ``GEMINI_API_KEY`` environment variable is used.
        default_llm_kwargs: Default LLM keyword arguments, forwarded to the ``LLM`` base class.
        disable_helicone: When False and the ``SYCAMORE_HELICONE_API_KEY`` environment variable is set,
            the client is configured to go through the Helicone gateway. Defaults to True.
    """
@requires_modules("google.genai", extra="google-genai")
def __init__(
    self,
    model_name: Union[GeminiModels, GeminiModel, str],
    default_mode: LLMMode = LLMMode.ASYNC,
    cache: Optional[Cache] = None,
    api_key: Optional[str] = None,
    default_llm_kwargs: Optional[dict[str, Any]] = None,
    disable_helicone: bool = True,
):
    """Resolve the requested model and build the underlying google-genai client.

    Args:
        model_name: A ``GeminiModels`` member, a ``GeminiModel``, or a model name string.
        default_mode: The ``LLMMode`` passed to the ``LLM`` base class.
        cache: Optional cache passed to the ``LLM`` base class.
        api_key: Gemini API key; falls back to the ``GEMINI_API_KEY`` environment variable.
        default_llm_kwargs: Default LLM keyword arguments passed to the ``LLM`` base class.
        disable_helicone: When False and ``SYCAMORE_HELICONE_API_KEY`` is set in the environment,
            the client is pointed at the Helicone gateway.

    Raises:
        ValueError: If ``model_name`` is a string that does not resolve to a known model.
        TypeError: If ``model_name`` is not a str, ``GeminiModel``, or ``GeminiModels``.
    """
    from google.genai import Client
    from google.genai.types import HttpOptionsDict

    # Kept exactly as passed in (enum member, model object, or str); __reduce__ reuses it.
    self.model_name = model_name
    if isinstance(model_name, GeminiModels):
        self.model = model_name.value
    elif isinstance(model_name, GeminiModel):
        self.model = model_name
    elif isinstance(model_name, str):
        # Check the lookup result *before* dereferencing ``.value``: if from_name returns None
        # for an unknown name, ``None.value`` would raise AttributeError and the intended
        # ValueError below would never be reached.
        known_model = GeminiModels.from_name(model_name)
        resolved = known_model.value if known_model is not None else None
        if resolved is None:
            raise ValueError(f"Invalid model name: {model_name}")
        self.model = resolved
    else:
        raise TypeError("model_name must be an instance of str, GeminiModel, or GeminiModels")

    api_key = api_key if api_key else os.getenv("GEMINI_API_KEY")

    # Helicone implementation from https://docs.helicone.ai/integrations/gemini/api/python
    http_options: Optional[HttpOptionsDict] = None
    if not disable_helicone and "SYCAMORE_HELICONE_API_KEY" in os.environ:
        # Build the headers first so the optional tag can be added without a type-narrowing assert.
        headers: dict[str, str] = {
            "helicone-auth": f"Bearer {os.environ['SYCAMORE_HELICONE_API_KEY']}",
            "helicone-target-url": "https://generativelanguage.googleapis.com",
        }
        if "SYCAMORE_HELICONE_TAG" in os.environ:
            headers["Helicone-Property-Tag"] = os.environ["SYCAMORE_HELICONE_TAG"]
        http_options = {
            "base_url": HELICONE_BASE_URL,
            "headers": headers,
        }

    self._client = Client(api_key=api_key, http_options=http_options)
    super().__init__(self.model.name, default_mode, cache, default_llm_kwargs=default_llm_kwargs)
def __reduce__(self):
    """Describe how to rebuild this object: call ``gemini_deserializer`` with constructor kwargs.

    Only ``model_name``, ``cache``, ``default_mode`` and ``default_llm_kwargs`` are carried over.
    The ``api_key`` and ``disable_helicone`` constructor arguments are not included, so a rebuilt
    instance gets their constructor defaults.

    Returns:
        A ``(callable, args)`` tuple where ``args`` holds a single dict of keyword arguments.
    """
    ctor_kwargs = dict(
        model_name=self.model_name,
        cache=self._cache,
        default_mode=self._default_mode,
        default_llm_kwargs=self._default_llm_kwargs,
    )
    return (gemini_deserializer, (ctor_kwargs,))
[docs]
def default_mode(self) -> LLMMode:
    """Return the configured default mode, or ``LLMMode.ASYNC`` when none is configured."""
    configured = self._default_mode
    return LLMMode.ASYNC if configured is None else configured
[docs]
def is_chat_mode(self) -> bool:
    """Report whether this LLM is in chat mode; this implementation always answers True."""
    chat_mode = True
    return chat_mode
def get_generate_kwargs(self, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None) -> dict:
    """Translate a rendered prompt and LLM kwargs into arguments for the Gemini client.

    Args:
        prompt: The rendered prompt. Its messages become the request contents, a "system"
            message becomes the system instruction, and a truthy ``response_format`` switches
            the response to JSON with that schema.
        llm_kwargs: Extra generation settings layered over the defaults (temperature 0, one
            candidate). The ``thinking_budget`` and ``thinking_level`` keys are pulled out and
            converted into a ``thinking_config``.

    Returns:
        A dict with ``"config"`` (a ``GenerateContentConfig``) and ``"content"`` (a list of
        ``Content`` objects).
    """
    from google.genai import types

    config: dict[str, Any] = {
        "temperature": 0,
        "candidate_count": 1,
        **(llm_kwargs or {}),
    }
    if prompt.response_format:
        config["response_mime_type"] = "application/json"
        config["response_schema"] = prompt.response_format

    content_list: list[types.Content] = []
    for message in prompt.messages:
        if message.role == "system":
            # System messages are not part of the conversation; the last one seen wins.
            config["system_instruction"] = message.content
            continue
        role = "model" if message.role == "assistant" else "user"
        # Collect all parts before building the Content so nothing has to be appended to it afterwards.
        parts = [types.Part.from_text(text=message.content)]
        for image in message.images or ():
            buffered = io.BytesIO()
            image.save(buffered, format="PNG")
            parts.append(types.Part.from_bytes(data=buffered.getvalue(), mime_type="image/png"))
        content_list.append(types.Content(parts=parts, role=role))

    # Compare against None rather than truthiness: a budget of 0 is an explicit request
    # and must still reach the API instead of being silently dropped.
    thinking_budget = config.pop("thinking_budget", None)
    if thinking_budget is not None:
        config["thinking_config"] = types.ThinkingConfig(thinking_budget=thinking_budget)

    thinking_level = config.pop("thinking_level", None)
    if thinking_level:
        if "thinking_config" in config:
            logger.warning(f"Thinking level {thinking_level} overrides thinking budget {thinking_budget}")
        config["thinking_config"] = types.ThinkingConfig(thinking_level=types.ThinkingLevel(thinking_level))

    return {
        "config": types.GenerateContentConfig(**config),
        "content": content_list,
    }
def _metadata_from_response(self, model: str, kwargs, response, starttime) -> dict:
    """Extract the output text and usage statistics from a Gemini response.

    Args:
        model: Name of the model that produced the response.
        kwargs: The generate kwargs used for the request (logged and recorded as metadata).
        response: The response object returned by the Gemini client.
        starttime: ``datetime`` at which the request was started.

    Returns:
        A dict with the keys ``model``, ``output``, ``wall_latency``, ``in_tokens`` and
        ``out_tokens``. ``output`` is the empty string when the response carries no content.
    """
    from google.genai.types import FinishReason

    wall_latency = datetime.datetime.now() - starttime
    md = response.usage_metadata
    in_tokens = int(md.prompt_token_count) if md and md.prompt_token_count else 0
    out_tokens = int(md.candidates_token_count) if md and md.candidates_token_count else 0

    # A response may carry no candidates at all (None or an empty list); treat that like a
    # response without content instead of failing on ``candidates[0]``.
    candidate = response.candidates[0] if response.candidates else None
    reason = candidate.finish_reason if candidate is not None else None
    if reason != FinishReason.STOP:
        logger.warning(
            f"Gemini model stopped for unexpected reason {reason}. Kwargs: {kwargs}. Full response:\n{response}"
        )

    content = candidate.content if candidate is not None else None
    if content is None or content.parts is None:
        # Only serialize the response when the message will actually be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            import json

            # default=str keeps this diagnostic from raising on values json cannot encode.
            logger.debug(
                f"Gemini model returned no content: {json.dumps(response.model_dump(), indent=4, default=str)}"
            )
        output = ""
    else:
        # A part's text can be None (non-text parts); str.join would raise on it, so map it to "".
        output = " ".join(part.text if part and part.text else "" for part in content.parts)

    ret = {
        "model": model,
        "output": output,
        "wall_latency": wall_latency,
        "in_tokens": in_tokens,
        "out_tokens": out_tokens,
    }
    self.add_llm_metadata(kwargs, output, wall_latency, in_tokens, out_tokens, model=model)
    return ret
[docs]
def generate(
    self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
) -> str:
    """Generate text for ``prompt`` and return only the output string.

    Delegates to ``generate_metadata`` and returns the ``"output"`` entry of its result.

    Args:
        prompt: The rendered prompt to send.
        llm_kwargs: Optional generation settings, forwarded unchanged.
        model: Optional model override, forwarded unchanged.

    Returns:
        The generated text.
    """
    metadata = self.generate_metadata(prompt=prompt, model=model, llm_kwargs=llm_kwargs)
    output = metadata["output"]
    return output
[docs]
async def generate_async(
    self, *, prompt: RenderedPrompt, llm_kwargs: Optional[dict] = None, model: Optional[LLMModel] = None
) -> str:
    """Asynchronously generate text for ``prompt``, consulting the LLM cache first.

    Args:
        prompt: The rendered prompt to send.
        llm_kwargs: Optional generation settings; passed through ``_merge_llm_kwargs`` before use.
        model: Optional model override; when falsy, this instance's model is used.

    Returns:
        The generated text, taken from the cache entry when one exists and otherwise from a
        fresh request whose result is then stored in the cache.
    """
    llm_kwargs = self._merge_llm_kwargs(llm_kwargs)

    default_name = self.model.name
    model_name = model.name if model else default_name
    if model_name != default_name:
        logger.info(f"Overriding Gemini model from {default_name} to {model_name}")

    # Cache hit: a dict entry already holds the output.
    cached = self._llm_cache_get(prompt, llm_kwargs, model=model_name)
    if isinstance(cached, dict):
        return cached["output"]
    assert cached is None

    # Cache miss: issue the request, time it, and record the result.
    request = self.get_generate_kwargs(prompt, llm_kwargs)
    started = datetime.datetime.now()
    response = await self.generate_content_async(
        model=model_name,
        contents=request["content"],
        config=request["config"],
    )
    result = self._metadata_from_response(model_name, request, response, started)
    self._llm_cache_set(prompt, llm_kwargs, result, model=model_name)
    return result["output"]
# Retry configuration: delays start at 1.0 and are scaled by 2.0 up to 60.0, with a 120.0 timeout.
# NOTE(review): retry.if_transient_error is google.api_core's predicate; confirm that the
# google-genai client raises exception types it recognizes, otherwise no retry will happen.
@retry.Retry(
    predicate=retry.if_transient_error,
    initial=1.0,
    maximum=60.0,
    multiplier=2.0,
    timeout=120.0,
)
def generate_content(self, model, contents, config):
    """Call the client's synchronous ``models.generate_content`` and return its response.

    Args:
        model: Name of the model to call.
        contents: The request contents.
        config: The generation config.
    """
    models_api = self._client.models
    return models_api.generate_content(model=model, contents=contents, config=config)
# Retry configuration: delays start at 1.0 and are scaled by 2.0 up to 60.0, with a 120.0 timeout.
# NOTE(review): retry.if_transient_error is google.api_core's predicate; confirm that the
# google-genai client raises exception types it recognizes, otherwise no retry will happen.
@retry_async.AsyncRetry(
    predicate=retry.if_transient_error,
    initial=1.0,
    maximum=60.0,
    multiplier=2.0,
    timeout=120.0,
)
async def generate_content_async(self, model: str, contents, config):
    """Await the client's asynchronous ``aio.models.generate_content`` and return its response.

    Args:
        model: Name of the model to call.
        contents: The request contents.
        config: The generation config.
    """
    async_models = self._client.aio.models
    response = await async_models.generate_content(model=model, contents=contents, config=config)
    return response