File size: 2,785 Bytes
f5776d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import math
from typing import Any, Optional
import backoff
from dsp.modules.lm import LM
# Exception type handed to `backoff.on_exception` on `Google.request`.
# It has to be bound on every path: previously it was only assigned when the
# import failed, so a successful import left the name undefined.
google_api_error = Exception

try:
    import google.generativeai as genai
except ImportError:
    # Keep this module importable without the optional SDK; `genai` is bound
    # to None so the name exists and a missing install is easy to detect.
    genai = None
    print("Not loading Google because it is not installed.")
def backoff_hdlr(details):
    """Report an upcoming retry on stdout.

    Handler from https://pypi.org/project/backoff/

    Parameters
    ----------
    details : dict
        Mapping that must provide the keys ``wait``, ``tries``, ``target``
        and ``kwargs``; any other keys are ignored.
    """
    template = (
        "Backing off {wait:0.1f} seconds after {tries} tries "
        "calling function {target} with kwargs "
        "{kwargs}"
    )
    notice = template.format(**details)
    print(notice)
def giveup_hdlr(details):
    """Decide whether retrying should stop after a failed call.

    Parameters
    ----------
    details : Exception
        The object passed to the ``giveup=`` callback of
        ``backoff.on_exception``. Its ``message`` attribute is inspected
        when present; otherwise its ``str()`` form is used.

    Returns
    -------
    bool
        ``False`` (keep retrying) when the error text contains
        ``"rate limits"``, ``True`` (give up) for anything else.
    """
    # Plain Python 3 exceptions have no `.message`; reading it unconditionally
    # raised AttributeError inside the handler and masked the real error.
    message = getattr(details, "message", None)
    if message is None:
        message = str(details)
    return "rate limits" not in message
class Google(LM):
    """Wrapper around Google's API.

    Currently supported models include `gemini-pro-1.0`.
    """

    def __init__(
        self,
        model: str = "gemini-pro-1.0",
        api_key: Optional[str] = None,
        **kwargs
    ):
        """
        Parameters
        ----------
        model : str
            Which pre-trained model from Google to use?
            Choices are [`gemini-pro-1.0`]
        api_key : str
            The API key for Google.
            It can be obtained from https://cloud.google.com/generative-ai-studio
        **kwargs: dict
            Additional arguments to pass to the API provider.
        """
        super().__init__(model)
        # The key used to be read from `self.api_key`, which was never
        # assigned, so construction always failed. Store it, then configure.
        self.api_key = api_key
        genai.configure(api_key=api_key)
        # `configure()` only sets up the SDK; requests go through a model
        # object, so that is what `self.google` holds.
        self.google = genai.GenerativeModel(model_name=model)
        self.provider = "google"
        # Defaults first, caller overrides last (so a caller-supplied
        # `temperature` still wins over the 0.0 default).
        self.kwargs = {
            "model_name": model,
            "temperature": 0.0,
            "max_output_tokens": 2048,
            "top_p": 1,
            "top_k": 1,
            **kwargs,
        }
        self.history: list[dict[str, Any]] = []

    def basic_request(self, prompt: str, **kwargs):
        """Send one prompt to the model and record the exchange in history.

        Parameters
        ----------
        prompt : str
            Text sent to the model.
        **kwargs : dict
            Per-call overrides merged on top of ``self.kwargs``.

        Returns
        -------
        The response object returned by the SDK, unmodified.
        """
        raw_kwargs = kwargs
        kwargs = {
            **self.kwargs,
            "prompt": prompt,
            **kwargs,
        }
        # `model_name` and `prompt` are bookkeeping entries kept for the
        # history record; they are not generation settings.
        # NOTE(review): assumes every remaining key is a valid generation
        # config field -- confirm against what callers pass in.
        generation_config = {
            key: value
            for key, value in kwargs.items()
            if key not in ("model_name", "prompt")
        }
        # Previously called `self.co.generate`, an attribute this class never
        # defines.
        response = self.google.generate_content(
            prompt,
            generation_config=generation_config,
        )
        self.history.append(
            {
                "prompt": prompt,
                "response": response,
                "kwargs": kwargs,
                "raw_kwargs": raw_kwargs,
            }
        )
        return response

    @backoff.on_exception(
        backoff.expo,
        google_api_error,
        max_time=1000,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of completions from Google whilst handling API errors"""
        return self.basic_request(prompt, **kwargs)

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs
    ):
        """Run ``prompt`` through :meth:`request` and return its result.

        ``only_completed`` and ``return_sorted`` are accepted for signature
        compatibility but are not used by this implementation.
        """
        return self.request(prompt, **kwargs)
|