"""
A generalized AWS LLM.
"""
from __future__ import annotations
from abc import abstractmethod
import logging
from typing import Any, Literal
import json
from dsp.modules.lm import LM
# Heuristic translating number of chars to tokens
# ~4 chars = 1 token
CHARS2TOKENS: int = 4
class AWSLM(LM):
"""
    This abstract class adds support for AWS-hosted language models such as Bedrock and SageMaker endpoints.
"""
def __init__(
self,
model: str,
region_name: str,
service_name: str,
max_new_tokens: int,
truncate_long_prompts: bool = False,
input_output_ratio: int = 3,
batch_n: bool = True,
) -> None:
"""_summary_
Args:
service_name (str): Used in context of invoking the boto3 API.
region_name (str, optional): The AWS region where this LM is hosted.
model (str, optional): An LM name, e.g., a bedrock name or an AWS endpoint.
max_new_tokens (int, optional): The maximum number of tokens to be sampled from the LM.
input_output_ratio (int, optional): The rough size of the number of input tokens to output tokens in the worst case. Defaults to 3.
temperature (float, optional): _description_. Defaults to 0.0.
truncate_long_prompts (bool, optional): If True, remove extremely long inputs to context. Defaults to False.
batch_n (bool, False): If False, call the LM N times rather than batching. Not all AWS models support the n parameter.
"""
super().__init__(model=model)
        # AWS doesn't expose an exact equivalent of max_tokens, so budget the
        # prompt at roughly input_output_ratio times as many tokens as the output
self.kwargs["max_tokens"] = max_new_tokens * input_output_ratio
self._max_new_tokens: int = max_new_tokens
self._model_name: str = model
        self._truncate_long_prompts: bool = truncate_long_prompts
self._batch_n: bool = batch_n
import boto3
self.predictor = boto3.client(service_name, region_name=region_name)
@abstractmethod
def _create_body(self, prompt: str, **kwargs):
pass
def _sanitize_kwargs(self, query_kwargs: dict[str, Any]) -> dict[str, Any]:
"""Ensure that input kwargs can be used by Bedrock or Sagemaker."""
base_args: dict[str, Any] = {"temperature": self.kwargs["temperature"]}
for k, v in base_args.items():
if k not in query_kwargs:
query_kwargs[k] = v
if query_kwargs["temperature"] > 1.0:
query_kwargs["temperature"] = 0.99
if query_kwargs["temperature"] < 0.01:
query_kwargs["temperature"] = 0.01
return query_kwargs
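        # Illustration of the clamping above (made-up values): if
        # self.kwargs["temperature"] is 0.0, then _sanitize_kwargs({"max_tokens": 50})
        # returns {"max_tokens": 50, "temperature": 0.01} -- the default
        # temperature is filled in and then clamped into [0.01, 0.99].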
@abstractmethod
def _call_model(self, body: str) -> str | list[str]:
"""Call model, get generated input without the formatted prompt"""
pass
@abstractmethod
def _extract_input_parameters(
self, body: dict[Any, Any]
) -> dict[str, str | float | int]:
pass
def _simple_api_call(self, formatted_prompt: str, **kwargs) -> str | list[str]:
body = self._create_body(formatted_prompt, **kwargs)
json_body = json.dumps(body)
llm_out: str | list[str] = self._call_model(json_body)
if isinstance(llm_out, str):
llm_out = llm_out.replace(formatted_prompt, "")
else:
llm_out = [generated.replace(formatted_prompt, "") for generated in llm_out]
self.history.append(
{"prompt": formatted_prompt, "response": llm_out, "kwargs": body}
)
return llm_out
    def basic_request(self, prompt: str, **kwargs) -> str | list[str]:
        """Query the endpoint."""
        # Optionally truncate prompts that are too long for the model's context
formatted_prompt: str
        if self._truncate_long_prompts:
truncated_prompt: str = self._truncate_prompt(prompt)
formatted_prompt = self._format_prompt(truncated_prompt)
else:
            formatted_prompt = self._format_prompt(prompt)
llm_out: str | list[str]
if "n" in kwargs.keys():
if self._batch_n:
llm_out = self._simple_api_call(
formatted_prompt=formatted_prompt, **kwargs
)
else:
del kwargs["n"]
llm_out = []
for _ in range(0, kwargs["n"]):
generated: str | list[str] = self._simple_api_call(
formatted_prompt=formatted_prompt, **kwargs
)
if isinstance(generated, str):
llm_out.append(generated)
else:
raise TypeError("Error, list type was returned from LM call")
else:
llm_out = self._simple_api_call(formatted_prompt=formatted_prompt, **kwargs)
return llm_out
def _estimate_tokens(self, text: str) -> int:
        # ~4 characters per token, so divide rather than multiply
        return len(text) // CHARS2TOKENS
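        # Under this heuristic, a 10,000-character prompt is estimated at
        # roughly 2,500 tokens.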
@abstractmethod
def _format_prompt(self, raw_prompt: str) -> str:
pass
def _truncate_prompt(
self,
input_text: str,
remove_beginning_or_ending: Literal["beginning", "ending"] = "beginning",
max_input_tokens: int = 2500,
) -> str:
"""Reformat inputs such that they do not overflow context size limitation."""
token_count = self._estimate_tokens(input_text)
if token_count > self.kwargs["max_tokens"]:
logging.info("Excessive prompt found in llm input")
logging.info("Truncating texts to avoid error")
max_chars: int = CHARS2TOKENS * max_input_tokens
truncated_text: str
if remove_beginning_or_ending == "ending":
truncated_text = input_text[0:max_chars]
else:
truncated_text = input_text[-max_chars:]
return truncated_text
return input_text
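        # Example of the truncation above (illustrative numbers only): with
        # max_input_tokens=2500 the character budget is 4 * 2500 = 10,000, so an
        # over-long prompt keeps its last 10,000 characters when
        # remove_beginning_or_ending="beginning" (the default) and its first
        # 10,000 characters when it is "ending".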
def __call__(
self,
prompt: str,
only_completed: bool = True,
return_sorted: bool = False,
**kwargs,
) -> list[str]:
"""
Query the AWS LLM.
        Only only_completed=True and return_sorted=False are supported
        right now.
"""
if not only_completed:
raise NotImplementedError("Error, only_completed not yet supported!")
if return_sorted:
raise NotImplementedError("Error, return_sorted not yet supported!")
        generated = self.basic_request(prompt, **kwargs)
        # basic_request returns a list when multiple generations were requested
        return generated if isinstance(generated, list) else [generated]
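

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the abstract interface above): one way the
# abstract hooks could be filled in for a Bedrock-hosted text-completion
# model. The "bedrock-runtime" service name is the standard boto3 client for
# Bedrock inference, but the request/response body shape ("prompt" in,
# "completion" out), the class name, and the model id below are hypothetical
# placeholders that differ between model providers.
class _ExampleBedrockLM(AWSLM):
    """A minimal, hypothetical AWSLM subclass used only as a usage sketch."""

    def __init__(
        self,
        model: str = "example.text-model-v1",
        region_name: str = "us-east-1",
        **kwargs,
    ) -> None:
        super().__init__(
            model=model,
            region_name=region_name,
            service_name="bedrock-runtime",
            max_new_tokens=200,
            **kwargs,
        )

    def _format_prompt(self, raw_prompt: str) -> str:
        # Real models usually expect provider-specific chat/instruction markers.
        return raw_prompt

    def _create_body(self, prompt: str, **kwargs) -> dict[str, Any]:
        # Returned dict is JSON-serialized by _simple_api_call before sending.
        query_args = self._sanitize_kwargs(kwargs)
        query_args["prompt"] = prompt
        query_args["max_tokens"] = self._max_new_tokens
        return query_args

    def _extract_input_parameters(
        self, body: dict[Any, Any]
    ) -> dict[str, str | float | int]:
        return body

    def _call_model(self, body: str) -> str:
        # invoke_model is the standard bedrock-runtime boto3 call; the
        # "completion" response key is an assumption about this example model.
        response = self.predictor.invoke_model(
            modelId=self._model_name,
            body=body,
            accept="application/json",
            contentType="application/json",
        )
        return json.loads(response["body"].read())["completion"]


# Example usage (requires AWS credentials and access to a matching model):
#
#     lm = _ExampleBedrockLM(region_name="us-east-1")
#     completions = lm("Write a haiku about the cloud.", temperature=0.7)
#     print(completions[0])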