import math
from typing import Any, Optional

import backoff

from dsp.modules.lm import LM

try:
    import google.generativeai as genai
    from google.api_core.exceptions import GoogleAPICallError

    # Base class of the errors raised by Google API client calls; this is the
    # exception type the retry decorator on `Google.request` reacts to.
    google_api_error = GoogleAPICallError
except ImportError:
    # Keep this module importable without the optional dependency. `genai`
    # is set to None so `Google.__init__` can raise a clear error on use.
    genai = None
    google_api_error = Exception
    print("Not loading Google because it is not installed.")


def backoff_hdlr(details):
    """Handler from https://pypi.org/project/backoff/"""
    print(
        "Backing off {wait:0.1f} seconds after {tries} tries "
        "calling function {target} with kwargs "
        "{kwargs}".format(**details)
    )


def giveup_hdlr(details):
    """Decide whether `backoff` should stop retrying after an exception.

    `backoff` passes the raised exception instance to its `giveup`
    predicate (despite this parameter's name, it is not a details dict).

    Returns:
        False (keep retrying) when the error looks like a rate limit, either
        because its message mentions "rate limits" or because it carries
        HTTP status 429; True (give up) for every other error.
    """
    # Plain exceptions have no `.message` attribute in Python 3, so fall back
    # to str() of the exception itself.
    message = str(getattr(details, "message", details))
    if "rate limits" in message:
        return False
    # Google API errors expose their HTTP status as `.code`; 429 means
    # "too many requests" (e.g. ResourceExhausted), which is worth retrying.
    if getattr(details, "code", None) == 429:
        return False
    return True


class Google(LM):
    """Wrapper around Google's API.

    Currently supported models include `gemini-pro-1.0`.
    """

    def __init__(
        self,
        # NOTE(review): Google's published id appears to be "gemini-1.0-pro"
        # (or "gemini-pro"); confirm that "gemini-pro-1.0" resolves.
        model: str = "gemini-pro-1.0",
        api_key: Optional[str] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        model : str
            Which pre-trained model from Google to use?
            Choices are [`gemini-pro-1.0`]
        api_key : str
            The API key for Google.
            It can be obtained from https://cloud.google.com/generative-ai-studio
            When None, the key is left for `genai.configure` to resolve
            (presumably from the GOOGLE_API_KEY environment variable; verify).
        **kwargs: dict
            Additional arguments to pass to the API provider. They override
            the default generation settings below.

        Raises
        ------
        ImportError
            If the `google-generativeai` package is not installed.
        """
        if genai is None:
            raise ImportError(
                "The Google LM requires the `google-generativeai` package; "
                "install it with `pip install google-generativeai`."
            )
        super().__init__(model)
        # `genai.configure` sets up module-global client state and returns
        # None, so there is no client object to keep on the instance.
        genai.configure(api_key=api_key)
        self.provider = "google"
        # Defaults first so any caller-supplied kwargs (including temperature)
        # take precedence.
        self.kwargs = {
            "model_name": model,
            "temperature": 0.0,
            "max_output_tokens": 2048,
            "top_p": 1,
            "top_k": 1,
            **kwargs,
        }
        self.history: list[dict[str, Any]] = []

    def basic_request(self, prompt: str, **kwargs):
        """Send one generation request and record it in `self.history`.

        Per-call `kwargs` override the defaults stored in `self.kwargs`. The
        `model_name` entry selects which model to call; every other entry is
        sent as the generation config.

        Returns:
            The raw response object from `GenerativeModel.generate_content`.
        """
        raw_kwargs = kwargs
        kwargs = {**self.kwargs, **kwargs}
        # `model_name` belongs to the GenerativeModel constructor, not to the
        # generation config, so split it off before making the request.
        generation_config = {
            key: value for key, value in kwargs.items() if key != "model_name"
        }
        model = genai.GenerativeModel(model_name=kwargs["model_name"])
        response = model.generate_content(
            prompt, generation_config=generation_config
        )

        history = {
            "prompt": prompt,
            "response": response,
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }
        self.history.append(history)

        return response

    @backoff.on_exception(
        backoff.expo,
        (google_api_error),
        max_time=1000,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of completions from Google whilst handling API errors"""
        return self.basic_request(prompt, **kwargs)

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ):
        """Run `prompt` through the model, retrying on API errors.

        `only_completed` and `return_sorted` are accepted for signature
        compatibility but are not used here. Remaining `kwargs` are forwarded
        to `request` and override the instance's default settings.

        Returns:
            The raw response from `basic_request`.
            NOTE(review): other LM wrappers may be expected to return a list
            of completion strings; confirm against `LM.__call__`.
        """
        return self.request(prompt, **kwargs)