import datetime
import hashlib
from typing import Any, Literal

import requests

from dsp.modules.lm import LM


def post_request_metadata(model_name, prompt):
"""Creates a serialized request object for the Ollama API."""
timestamp = datetime.datetime.now().timestamp()
id_string = str(timestamp) + model_name + prompt
hashlib.sha1().update(id_string.encode("utf-8"))
id_hash = hashlib.sha1().hexdigest()
return {"id": f"chatcmpl-{id_hash}", "object": "chat.completion", "created": int(timestamp), "model": model_name}
class OllamaLocal(LM):
"""Wrapper around a locally hosted Ollama model (API: https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values and https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion).
Returns dictionary info in the OpenAI API style (https://platform.openai.com/docs/api-reference/chat/object).
Args:
model (str, optional): Name of Ollama model. Defaults to "llama2".
model_type (Literal["chat", "text"], optional): The type of model that was specified. Mainly to decide the optimal prompting strategy. Defaults to "text".
base_url (str): Protocol, host name, and port to the served ollama model. Defaults to "http://localhost:11434" as in ollama docs.
timeout_s (float): Timeout period (in seconds) for the post request to llm.
**kwargs: Additional arguments to pass to the API.
"""
def __init__(
self,
model: str = "llama2",
model_type: Literal["chat", "text"] = "text",
base_url: str = "http://localhost:11434",
timeout_s: float = 15,
temperature: float = 0.0,
max_tokens: int = 150,
top_p: int = 1,
top_k: int = 20,
frequency_penalty: float = 0,
presence_penalty: float = 0,
n: int = 1,
num_ctx: int = 1024,
**kwargs,
):
super().__init__(model)
self.provider = "ollama"
self.model_type = model_type
self.base_url = base_url
self.model_name = model
self.timeout_s = timeout_s
self.kwargs = {
"temperature": temperature,
"max_tokens": max_tokens,
"top_p": top_p,
"top_k": top_k,
"frequency_penalty": frequency_penalty,
"presence_penalty": presence_penalty,
"n": n,
"num_ctx": num_ctx,
**kwargs,
}
# Ollama uses num_predict instead of max_tokens
self.kwargs["num_predict"] = self.kwargs["max_tokens"]
self.history: list[dict[str, Any]] = []
self.version = kwargs["version"] if "version" in kwargs else ""
# Ollama occasionally does not send `prompt_eval_count` in response body.
# https://github.com/stanfordnlp/dspy/issues/293
self._prev_prompt_eval_count = 0
def basic_request(self, prompt: str, **kwargs):
raw_kwargs = kwargs
kwargs = {**self.kwargs, **kwargs}
request_info = post_request_metadata(self.model_name, prompt)
request_info["choices"] = []
settings_dict = {
"model": self.model_name,
"options": {k: v for k, v in kwargs.items() if k not in ["n", "max_tokens"]},
"stream": False,
}
if self.model_type == "chat":
settings_dict["messages"] = [{"role": "user", "content": prompt}]
else:
settings_dict["prompt"] = prompt
urlstr = f"{self.base_url}/api/chat" if self.model_type == "chat" else f"{self.base_url}/api/generate"
tot_eval_tokens = 0
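        # Each POST returns a single completion, so issue one request per requested
        # completion (`n`) and collect the choices.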
for i in range(kwargs["n"]):
response = requests.post(urlstr, json=settings_dict, timeout=self.timeout_s)
            # Report unsuccessful requests (any HTTP status code other than 200).
            if response.status_code != 200:
                print(f"Error: CODE {response.status_code} - {response.text}")

            response_json = response.json()
text = (
response_json.get("message").get("content")
if self.model_type == "chat"
else response_json.get("response")
)
request_info["choices"].append(
{
"index": i,
"message": {
"role": "assistant",
"content": "".join(text),
},
"finish_reason": "stop",
}
)
tot_eval_tokens += response_json.get("eval_count")
request_info["additional_kwargs"] = {k: v for k, v in response_json.items() if k not in ["response"]}
request_info["usage"] = {
"prompt_tokens": response_json.get("prompt_eval_count", self._prev_prompt_eval_count),
"completion_tokens": tot_eval_tokens,
"total_tokens": response_json.get("prompt_eval_count", self._prev_prompt_eval_count) + tot_eval_tokens,
}
history = {
"prompt": prompt,
"response": request_info,
"kwargs": kwargs,
"raw_kwargs": raw_kwargs,
}
self.history.append(history)
        return request_info

    def request(self, prompt: str, **kwargs):
        """Wrapper for requesting completions from the Ollama model."""
        if "model_type" in kwargs:
            del kwargs["model_type"]

        return self.basic_request(prompt, **kwargs)

    def _get_choice_text(self, choice: dict[str, Any]) -> str:
        return choice["message"]["content"]

    def __call__(
self,
prompt: str,
only_completed: bool = True,
return_sorted: bool = False,
**kwargs,
    ) -> list[str]:
        """Retrieves completions from Ollama.

        Args:
            prompt (str): prompt to send to Ollama.
            only_completed (bool, optional): return only completed responses, ignoring completions truncated due to length. Defaults to True.
            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

        Returns:
            list[str]: list of completion strings
        """
assert only_completed, "for now"
assert return_sorted is False, "for now"
response = self.request(prompt, **kwargs)
choices = response["choices"]
completed_choices = [c for c in choices if c["finish_reason"] != "length"]
if only_completed and len(completed_choices):
choices = completed_choices
completions = [self._get_choice_text(c) for c in choices]
return completions
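

# A minimal usage sketch, assuming an Ollama server is running locally at
# http://localhost:11434 and the "llama2" model has already been pulled
# (`ollama pull llama2`); adjust `model` and `base_url` for your setup.
if __name__ == "__main__":
    lm = OllamaLocal(model="llama2", model_type="text", max_tokens=64)
    completions = lm("Briefly explain what a large language model is.")
    print(completions[0])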