File size: 4,341 Bytes
f5776d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import math
from typing import Any, Optional
import backoff

from dsp.modules.lm import LM

# Optional dependency: resolve the exception class that the retry decorator
# on `Cohere.request` should treat as retryable.
try:
    import cohere
    cohere_api_error = cohere.CohereAPIError
except ImportError:
    # The `cohere` package is not installed. Fall back to the generic
    # `Exception` so this module can still be imported and the backoff
    # decorator below still receives a valid exception class.
    cohere_api_error = Exception


def backoff_hdlr(details):
    """Report a pending retry on stdout.

    Adapted from the handler example at https://pypi.org/project/backoff/.

    Parameters
    ----------
    details : dict
        Mapping supplied to the `on_backoff` hook. It must provide the
        `wait`, `tries`, `target` and `kwargs` keys used by the message.
    """
    template = (
        "Backing off {wait:0.1f} seconds after {tries} tries "
        "calling function {target} with kwargs "
        "{kwargs}"
    )
    message = template.format(**details)
    print(message)


def giveup_hdlr(details):
    """Decide whether the retry loop should stop after a failed call.

    This is used as the `giveup` predicate of `backoff.on_exception`, which
    invokes it with the exception that was raised (the parameter keeps its
    historical name `details`).

    Parameters
    ----------
    details : Exception
        The exception raised by the wrapped call.

    Returns
    -------
    bool
        `False` (keep retrying) when the error text mentions "rate limits",
        `True` (give up) for every other error.
    """
    # Not every exception carries a `.message` attribute, and it may be None;
    # fall back to the exception's string form instead of failing inside the
    # retry machinery with an AttributeError/TypeError.
    message = getattr(details, "message", None)
    if not isinstance(message, str):
        message = str(details)
    return "rate limits" not in message


class Cohere(LM):
    """Wrapper around Cohere's API.

    Currently supported models include `command`, `command-nightly`, `command-light`, `command-light-nightly`.
    """

    def __init__(
        self,
        model: str = "command-nightly",
        api_key: Optional[str] = None,
        stop_sequences: Optional[list[str]] = None,
        **kwargs
    ):
        """
        Parameters
        ----------
        model : str
            Which pre-trained model from Cohere to use?
            Choices are [`command`, `command-nightly`, `command-light`, `command-light-nightly`]
        api_key : str
            The API key for Cohere.
            It can be obtained from https://dashboard.cohere.ai/register.
        stop_sequences : list of str, optional
            Additional stop tokens to end generation. Defaults to no extra
            stop sequences.
        **kwargs: dict
            Additional arguments to pass to the API provider. They override
            the default generation settings below.
        """
        super().__init__(model)
        self.co = cohere.Client(api_key)
        self.provider = "cohere"
        # Default generation settings; caller-supplied kwargs take precedence.
        self.kwargs = {
            "model": model,
            "temperature": 0.0,
            "max_tokens": 150,
            "p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "num_generations": 1,
            **kwargs
        }
        # Create a fresh list per instance when none is given: a mutable
        # default argument would be shared by every instance built with it.
        self.stop_sequences = stop_sequences if stop_sequences is not None else []
        # Upper bound on `num_generations` sent in a single request.
        self.max_num_generations = 5

        # One record per request issued through `basic_request`.
        self.history: list[dict[str, Any]] = []

    def basic_request(self, prompt: str, **kwargs):
        """Send a single generate request and record it in `self.history`.

        Parameters
        ----------
        prompt : str
            The prompt to complete.
        **kwargs : dict
            Per-call overrides; they win over the instance defaults, the
            instance stop sequences and the prompt entry.

        Returns
        -------
        The response object returned by `self.co.generate`.
        """
        raw_kwargs = kwargs
        # Later entries win: defaults < stop sequences / prompt < per-call kwargs.
        kwargs = {
            **self.kwargs,
            "stop_sequences": self.stop_sequences,
            "prompt": prompt,
            **kwargs,
        }
        response = self.co.generate(**kwargs)

        self.history.append(
            {
                "prompt": prompt,
                "response": response,
                "kwargs": kwargs,
                "raw_kwargs": raw_kwargs,
            }
        )

        return response

    @backoff.on_exception(
        backoff.expo,
        (cohere_api_error),
        max_time=1000,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of completions from Cohere whilst handling API errors"""
        return self.basic_request(prompt, **kwargs)

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs
    ) -> list[str]:
        """Generate completions for `prompt` and return their texts.

        Parameters
        ----------
        prompt : str
            The prompt to complete.
        only_completed : bool
            Must be True for now.
        return_sorted : bool
            Must be False for now.
        **kwargs : dict
            Extra request arguments. `n` (default 1) is the total number of
            completions wanted; it is split into batches of at most
            `self.max_num_generations`.

        Returns
        -------
        list of str
            The text of every generation, in the order received.
        """
        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        # Cohere uses 'num_generations' whereas dsp.generate() uses 'n'
        n = kwargs.pop("n", 1)

        # Cohere can generate up to self.max_num_generations completions at a
        # time, so request full batches followed by one smaller final batch.
        choices = []
        remaining = n
        while remaining > 0:
            batch_size = min(remaining, self.max_num_generations)
            kwargs["num_generations"] = batch_size
            response = self.request(prompt, **kwargs)
            choices.extend(response.generations)
            remaining -= batch_size
        completions = [c.text for c in choices]

        # Not reachable while the `return_sorted` assertion above is active.
        # Ranking is decided by the total number of completions collected,
        # not by the size of the last batch.
        if return_sorted and len(choices) > 1:
            ranked = sorted(
                ((c.likelihood, c.text) for c in choices), reverse=True
            )
            completions = [text for _, text in ranked]

        return completions