from dsp.modules.lm import LM
from typing import Any, Literal

import datetime
import hashlib

import requests


def post_request_metadata(model_name, prompt):
    """Creates a serialized request object for the Ollama API."""
    timestamp = datetime.datetime.now().timestamp()
    id_string = str(timestamp) + model_name + prompt
    id_hash = hashlib.sha1(id_string.encode("utf-8")).hexdigest()
    return {"id": f"chatcmpl-{id_hash}", "object": "chat.completion", "created": int(timestamp), "model": model_name}


class OllamaLocal(LM):
    """Wrapper around a locally hosted Ollama model (API: https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values and https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion).
    Returns dictionary info in the OpenAI API style (https://platform.openai.com/docs/api-reference/chat/object).

    Args:
        model (str, optional): Name of Ollama model. Defaults to "llama2".
        model_type (Literal["chat", "text"], optional): The type of model that was specified. Mainly used to decide the optimal prompting strategy. Defaults to "text".
        base_url (str): Protocol, host name, and port of the served Ollama model. Defaults to "http://localhost:11434" as in the Ollama docs.
        timeout_s (float): Timeout period (in seconds) for the POST request to the LLM.
        **kwargs: Additional arguments to pass to the API.
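
    Example:
        A minimal usage sketch (assumes a local Ollama server is reachable at the
        default "http://localhost:11434" and that the "llama2" model has been pulled):

            ollama_lm = OllamaLocal(model="llama2", model_type="text", max_tokens=200)
            completions = ollama_lm("Briefly explain what a context window is.")
            print(completions[0])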
    """

    def __init__(
        self,
        model: str = "llama2",
        model_type: Literal["chat", "text"] = "text",
        base_url: str = "http://localhost:11434",
        timeout_s: float = 15,
        temperature: float = 0.0,
        max_tokens: int = 150,
        top_p: float = 1.0,
        top_k: int = 20,
        frequency_penalty: float = 0,
        presence_penalty: float = 0,
        n: int = 1,
        num_ctx: int = 1024,
        **kwargs,
    ):
        super().__init__(model)

        self.provider = "ollama"
        self.model_type = model_type
        self.base_url = base_url
        self.model_name = model
        self.timeout_s = timeout_s

        self.kwargs = {
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "top_k": top_k,
            "frequency_penalty": frequency_penalty,
            "presence_penalty": presence_penalty,
            "n": n,
            "num_ctx": num_ctx,
            **kwargs,
        }

        # Ollama uses num_predict instead of max_tokens
        self.kwargs["num_predict"] = self.kwargs["max_tokens"]

        self.history: list[dict[str, Any]] = []
        self.version = kwargs.get("version", "")

        # Ollama occasionally does not send `prompt_eval_count` in response body.
        # https://github.com/stanfordnlp/dspy/issues/293
        self._prev_prompt_eval_count = 0

    def basic_request(self, prompt: str, **kwargs):
        raw_kwargs = kwargs

        kwargs = {**self.kwargs, **kwargs}

        request_info = post_request_metadata(self.model_name, prompt)
        request_info["choices"] = []
        settings_dict = {
            "model": self.model_name,
            "options": {k: v for k, v in kwargs.items() if k not in ["n", "max_tokens"]},
            "stream": False,
        }
        if self.model_type == "chat":
            settings_dict["messages"] = [{"role": "user", "content": prompt}]
        else:
            settings_dict["prompt"] = prompt

        urlstr = f"{self.base_url}/api/chat" if self.model_type == "chat" else f"{self.base_url}/api/generate"
        tot_eval_tokens = 0
        for i in range(kwargs["n"]):
            response = requests.post(urlstr, json=settings_dict, timeout=self.timeout_s)

            # Check if the request was successful (HTTP status code 200)
            if response.status_code != 200:
                # If the request was not successful, print an error message
                print(f"Error: CODE {response.status_code} - {response.text}")

            response_json = response.json()

            text = (
                response_json.get("message").get("content")
                if self.model_type == "chat"
                else response_json.get("response")
            )
            request_info["choices"].append(
                {
                    "index": i,
                    "message": {
                        "role": "assistant",
                        "content": "".join(text),
                    },
                    "finish_reason": "stop",
                }
            )
            tot_eval_tokens += response_json.get("eval_count", 0)
        request_info["additional_kwargs"] = {k: v for k, v in response_json.items() if k not in ["response"]}

        request_info["usage"] = {
            "prompt_tokens": response_json.get("prompt_eval_count", self._prev_prompt_eval_count),
            "completion_tokens": tot_eval_tokens,
            "total_tokens": response_json.get("prompt_eval_count", self._prev_prompt_eval_count) + tot_eval_tokens,
        }

        history = {
            "prompt": prompt,
            "response": request_info,
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }
        self.history.append(history)

        return request_info

    def request(self, prompt: str, **kwargs):
        """Wrapper for requesting completions from the Ollama model."""
        if "model_type" in kwargs:
            del kwargs["model_type"]

        return self.basic_request(prompt, **kwargs)

    def _get_choice_text(self, choice: dict[str, Any]) -> str:
        return choice["message"]["content"]

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ) -> list[str]:
        """Retrieves completions from Ollama.

        Args:
            prompt (str): prompt to send to Ollama
            only_completed (bool, optional): return only completed responses, ignoring completions truncated due to length. Defaults to True.
            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

        Returns:
            list[str]: list of completion texts
        """

        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        response = self.request(prompt, **kwargs)

        choices = response["choices"]

        completed_choices = [c for c in choices if c["finish_reason"] != "length"]

        if only_completed and len(completed_choices):
            choices = completed_choices

        completions = [self._get_choice_text(c) for c in choices]

        return completions
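

if __name__ == "__main__":
    # Minimal smoke-test sketch: assumes an Ollama server is running locally on the
    # default port and that the "llama2" model has been pulled (`ollama pull llama2`).
    lm = OllamaLocal(model="llama2", model_type="text", max_tokens=100)
    outputs = lm("Name one advantage of running language models locally.")
    for output in outputs:
        print(output)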