"""Clarifai LM integration"""
from typing import Any, Optional
from dsp.modules.lm import LM
class ClarifaiLLM(LM):
"""Integration to call models hosted in clarifai platform.
Args:
model (str, optional): Clarifai URL of the model. Defaults to "Mistral-7B-Instruct".
api_key (Optional[str], optional): CLARIFAI_PAT token. Defaults to None.
**kwargs: Additional arguments to pass to the API provider.
Example:
import dspy
dspy.configure(lm=dspy.Clarifai(model=MODEL_URL,
api_key=CLARIFAI_PAT,
inference_params={"max_tokens":100,'temperature':0.6}))
"""
    def __init__(
        self,
        model: str = "https://clarifai.com/mistralai/completion/models/mistral-7B-Instruct",
        api_key: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(model)

        # The Clarifai SDK is an optional dependency.
        try:
            from clarifai.client.model import Model
        except ImportError as err:
            raise ImportError("ClarifaiLLM requires `pip install clarifai`.") from err

        self.provider = "clarifai"
        self.pat = api_key
        self._model = Model(url=model, pat=api_key)
        self.kwargs = {"n": 1, **kwargs}
        self.history: list[dict[str, Any]] = []

        # Mirror temperature and max_tokens from inference_params into the
        # top-level kwargs, falling back to defaults when they are absent.
        self.kwargs["temperature"] = (
            self.kwargs["inference_params"]["temperature"]
            if "inference_params" in self.kwargs
            and "temperature" in self.kwargs["inference_params"]
            else 0.0
        )
        self.kwargs["max_tokens"] = (
            self.kwargs["inference_params"]["max_tokens"]
            if "inference_params" in self.kwargs
            and "max_tokens" in self.kwargs["inference_params"]
            else 150
        )
    def basic_request(self, prompt, **kwargs):
        params = self.kwargs.get("inference_params", {})

        # Send the prompt as raw text bytes and read back the raw text output.
        response = (
            self._model.predict_by_bytes(
                input_bytes=prompt.encode(encoding="utf-8"),
                input_type="text",
                inference_params=params,
            )
            .outputs[0]
            .data.text.raw
        )

        kwargs = {**self.kwargs, **kwargs}
        history = {
            "prompt": prompt,
            "response": response,
            "kwargs": kwargs,
        }
        self.history.append(history)

        return response
    def request(self, prompt: str, **kwargs):
        return self.basic_request(prompt, **kwargs)
    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ):
        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        # Issue n independent requests and collect the completions.
        n = kwargs.pop("n", 1)
        completions = []
        for _ in range(n):
            response = self.request(prompt, **kwargs)
            completions.append(response)

        return completions