|
import math |
|
from typing import Any, Optional |
|
import backoff |
|
|
|
from dsp.modules.lm import LM |
|
|
|
try:
    import cohere

    # Exception type the ``request`` retry decorator backs off on. Older cohere
    # SDKs expose ``cohere.CohereAPIError``; newer ones dropped it, and a plain
    # attribute access would raise AttributeError (not caught below) and break
    # importing this module. Fall back to ``Exception`` in that case.
    cohere_api_error = getattr(cohere, "CohereAPIError", Exception)

except ImportError:
    # cohere is an optional dependency: keep this module importable and give
    # the retry decorator a generic exception type. Note that the name
    # ``cohere`` is left undefined on this path.
    cohere_api_error = Exception
|
|
|
|
|
|
|
def backoff_hdlr(details):
    """Print a notice each time backoff schedules another attempt.

    Adapted from the handler example at https://pypi.org/project/backoff/.
    ``details`` is the event mapping passed to ``on_backoff`` callbacks; only
    its ``wait``, ``tries``, ``target`` and ``kwargs`` entries are read.
    """
    message = (
        f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries "
        f"calling function {details['target']} with kwargs "
        f"{details['kwargs']}"
    )
    print(message)
|
|
|
|
|
def giveup_hdlr(details):
    """Decide whether backoff should stop retrying after an exception.

    Used as the ``giveup`` callback of ``backoff.on_exception``, which calls it
    with the raised exception. Despite the parameter name, ``details`` is that
    exception object, not backoff's event dict.

    Returns:
        False (keep retrying) when the error message mentions "rate limits",
        True (give up) otherwise.
    """
    # Cohere API errors carry a ``.message`` attribute; other exceptions (e.g.
    # when the retried exception type falls back to ``Exception``) may not, so
    # use the exception's string form instead of raising AttributeError here,
    # which would mask the original error.
    message = getattr(details, "message", None)
    if message is None:
        message = str(details)
    return "rate limits" not in str(message)
|
|
|
|
|
class Cohere(LM):
    """Wrapper around Cohere's API.

    Currently supported models include `command`, `command-nightly`, `command-light`, `command-light-nightly`.
    """

    def __init__(
        self,
        model: str = "command-nightly",
        api_key: Optional[str] = None,
        stop_sequences: Optional[list[str]] = None,
        **kwargs,
    ):
        """
        Parameters
        ----------
        model : str
            Which pre-trained model from Cohere to use?
            Choices are [`command`, `command-nightly`, `command-light`, `command-light-nightly`]
        api_key : str
            The API key for Cohere.
            It can be obtained from https://dashboard.cohere.ai/register.
        stop_sequences : list of str, optional
            Additional stop tokens to end generation. Defaults to none; a
            fresh empty list is created per instance when omitted.
        **kwargs: dict
            Additional arguments to pass to the API provider. They override
            the default generation settings below.

        Raises
        ------
        ImportError
            If the optional ``cohere`` package is not installed.
        """
        super().__init__(model)
        # ``cohere`` is imported at module level inside a try/except, so the
        # name is undefined when the package is missing. Turn the resulting
        # NameError into an actionable ImportError.
        try:
            client_cls = cohere.Client
        except NameError as err:
            raise ImportError(
                "The `cohere` package is required to use the Cohere LM; "
                "install it with `pip install cohere`."
            ) from err
        self.co = client_cls(api_key)
        self.provider = "cohere"
        # Default generation settings; caller-supplied kwargs win.
        self.kwargs = {
            "model": model,
            "temperature": 0.0,
            "max_tokens": 150,
            "p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "num_generations": 1,
            **kwargs,
        }
        # Avoid a shared mutable default: each instance gets its own list
        # unless the caller passes one explicitly.
        self.stop_sequences = [] if stop_sequences is None else stop_sequences
        # Upper bound on ``num_generations`` sent in a single request;
        # ``__call__`` splits larger ``n`` into several requests.
        self.max_num_generations = 5

        self.history: list[dict[str, Any]] = []

    def basic_request(self, prompt: str, **kwargs):
        """Send one ``generate`` call and record it in ``self.history``.

        The request arguments are built from ``self.kwargs``, then
        ``stop_sequences`` and ``prompt``, then the per-call ``kwargs`` (later
        entries override earlier ones). Returns the raw response object from
        ``self.co.generate``.
        """
        raw_kwargs = kwargs
        request_kwargs = {
            **self.kwargs,
            "stop_sequences": self.stop_sequences,
            "prompt": prompt,
            **kwargs,
        }
        response = self.co.generate(**request_kwargs)

        history = {
            "prompt": prompt,
            "response": response,
            "kwargs": request_kwargs,
            "raw_kwargs": raw_kwargs,
        }
        self.history.append(history)

        return response

    @backoff.on_exception(
        backoff.expo,
        cohere_api_error,
        max_time=1000,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of completions from Cohere whilst handling API errors"""
        return self.basic_request(prompt, **kwargs)

    def _generation_batch_sizes(self, n):
        """Split ``n`` requested generations into per-request batch sizes.

        Every batch holds ``self.max_num_generations`` except possibly the
        last, which holds the remainder; ``n <= 0`` yields no batches.
        """
        num_batches = math.ceil(n / self.max_num_generations)
        # A negative/zero count multiplies to an empty list.
        sizes = [self.max_num_generations] * num_batches
        remainder = n % self.max_num_generations
        if sizes and remainder != 0:
            sizes[-1] = remainder
        return sizes

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ):
        """Return the text of ``n`` completions for ``prompt``.

        ``n`` is read from ``kwargs`` (default 1) and is fetched in batches of
        at most ``self.max_num_generations`` via ``self.request``. The
        remaining ``kwargs`` are forwarded to each request.

        Raises
        ------
        AssertionError
            If ``only_completed`` is falsy or ``return_sorted`` is not False
            (those modes are not supported yet).
        """
        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        n = kwargs.pop("n", 1)

        choices = []
        for batch_size in self._generation_batch_sizes(n):
            kwargs["num_generations"] = batch_size
            response = self.request(prompt, **kwargs)
            choices.extend(response.generations)

        completions = [c.text for c in choices]

        # Only reachable once the assertion above is lifted (or under -O).
        # Decide on the total number of collected generations: the previous
        # check used the last batch's ``num_generations``, which is 1 for
        # e.g. n=6 and silently skipped sorting.
        # NOTE(review): ``likelihood`` may be None unless likelihoods are
        # requested from the API — confirm before enabling return_sorted.
        if return_sorted and len(choices) > 1:
            scored_completions = sorted(
                ((c.likelihood, c.text) for c in choices), reverse=True
            )
            completions = [text for _, text in scored_completions]

        return completions
|
|