# dsp/modules/cohere.py
import math
from typing import Any, Optional
import backoff
from dsp.modules.lm import LM
try:
    import cohere

    cohere_api_error = cohere.CohereAPIError
except (ImportError, AttributeError):
    # Cohere is optional. Fall back to plain Exception when the SDK is not
    # installed, or when an installed SDK no longer exposes CohereAPIError
    # (removed in newer major versions — AttributeError, not ImportError).
    cohere_api_error = Exception
    # print("Not loading Cohere because it is not installed.")
def backoff_hdlr(details):
    """Log a retry notice; adapted from https://pypi.org/project/backoff/"""
    wait = details["wait"]
    tries = details["tries"]
    print(
        f"Backing off {wait:0.1f} seconds after {tries} tries "
        f"calling function {details['target']} with kwargs "
        f"{details['kwargs']}"
    )
def giveup_hdlr(details):
    """Decide whether to give up retrying after an API error.

    `backoff` calls this with the raised exception. Return False (keep
    retrying) only for rate-limit errors; give up on everything else.
    """
    # BUGFIX: `message` may be absent or None on some exception types;
    # the original `in details.message` would raise TypeError then.
    message = getattr(details, "message", None) or ""
    return "rate limits" not in message
class Cohere(LM):
    """Wrapper around Cohere's API.

    Currently supported models include `command`, `command-nightly`,
    `command-light`, `command-light-nightly`.
    """

    def __init__(
        self,
        model: str = "command-nightly",
        api_key: Optional[str] = None,
        stop_sequences: Optional[list[str]] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        model : str
            Which pre-trained model from Cohere to use?
            Choices are [`command`, `command-nightly`, `command-light`, `command-light-nightly`]
        api_key : str
            The API key for Cohere.
            It can be obtained from https://dashboard.cohere.ai/register.
        stop_sequences : list of str, optional
            Additional stop tokens to end generation. Defaults to no extra
            stop tokens.
        **kwargs: dict
            Additional arguments to pass to the API provider.
        """
        super().__init__(model)
        self.co = cohere.Client(api_key)
        self.provider = "cohere"
        self.kwargs = {
            "model": model,
            "temperature": 0.0,
            "max_tokens": 150,
            "p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "num_generations": 1,
            **kwargs,
        }
        # BUGFIX: the original default `stop_sequences=[]` is a mutable
        # default argument, shared by every call site that relies on the
        # default; use a None sentinel instead (backward compatible).
        self.stop_sequences = stop_sequences if stop_sequences is not None else []
        # Cohere caps num_generations per request; batching in __call__
        # works around this limit.
        self.max_num_generations = 5

        self.history: list[dict[str, Any]] = []

    def basic_request(self, prompt: str, **kwargs):
        """Issue one generate call to Cohere and record it in `self.history`."""
        raw_kwargs = kwargs
        kwargs = {
            **self.kwargs,
            "stop_sequences": self.stop_sequences,
            "prompt": prompt,
            **kwargs,
        }
        response = self.co.generate(**kwargs)

        history = {
            "prompt": prompt,
            "response": response,
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }
        self.history.append(history)

        return response

    @backoff.on_exception(
        backoff.expo,
        (cohere_api_error),
        max_time=1000,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of completions from Cohere whilst handling API errors"""
        return self.basic_request(prompt, **kwargs)

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ):
        """Return a list of `n` completion strings for `prompt` (n defaults to 1)."""
        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        # Cohere uses 'num_generations' whereas dsp.generate() uses 'n'
        n = kwargs.pop("n", 1)

        # Cohere can generate up to self.max_num_generations completions at a
        # time, so split n across several requests: full batches first, then
        # the remainder (a full batch when n divides evenly).
        choices = []
        num_iters = math.ceil(n / self.max_num_generations)
        remainder = n % self.max_num_generations
        for i in range(num_iters):
            if i == (num_iters - 1):
                kwargs["num_generations"] = (
                    remainder if remainder != 0 else self.max_num_generations
                )
            else:
                kwargs["num_generations"] = self.max_num_generations
            response = self.request(prompt, **kwargs)
            choices.extend(response.generations)
        completions = [c.text for c in choices]

        # NOTE(review): unreachable while the `return_sorted is False`
        # assert above stands; kept for parity with the intended interface.
        if return_sorted and kwargs.get("num_generations", 1) > 1:
            scored_completions = sorted(
                ((c.likelihood, c.text) for c in choices), reverse=True
            )
            completions = [text for _, text in scored_completions]

        return completions