"""A generalized AWS LLM."""

from __future__ import annotations

import json
import logging
from abc import abstractmethod
from typing import Any, Literal

from dsp.modules.lm import LM

# Rough heuristic used for prompt-length estimates: one token is about four characters of text.
CHARS2TOKENS: int = 4


class AWSLM(LM):
    """Base class that adds support for an AWS-hosted model (Bedrock or SageMaker)."""

    def __init__(
        self,
        model: str,
        region_name: str,
        service_name: str,
        max_new_tokens: int,
        truncate_long_prompts: bool = False,
        input_output_ratio: int = 3,
        batch_n: bool = True,
    ) -> None:
        """Initialize an AWS-hosted LM.

        Args:
            model (str): An LM name, e.g., a Bedrock model ID or an AWS endpoint name.
            region_name (str): The AWS region where this LM is hosted.
            service_name (str): The boto3 service name used when invoking the model.
            max_new_tokens (int): The maximum number of tokens to be sampled from the LM.
            truncate_long_prompts (bool, optional): If True, truncate extremely long inputs to fit the context. Defaults to False.
            input_output_ratio (int, optional): Rough worst-case ratio of input tokens to output tokens. Defaults to 3.
            batch_n (bool, optional): If False, call the LM n times rather than batching, since not all AWS models support the n parameter. Defaults to True.
        """
        super().__init__(model=model)

        self.kwargs["max_tokens"] = max_new_tokens * input_output_ratio
        self._max_new_tokens: int = max_new_tokens
        self._model_name: str = model
        self._truncate_long_prompt_prompts: bool = truncate_long_prompts
        self._batch_n: bool = batch_n

        # boto3 is imported lazily so it is only required when an AWS LM is actually used.
        import boto3

        self.predictor = boto3.client(service_name, region_name=region_name)
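
    # For illustration only: a hypothetical concrete subclass might be constructed
    # like this (the subclass name and the "bedrock-runtime" service value are
    # assumptions, not part of this base class):
    #
    #     lm = MyBedrockLM(
    #         model="example-model-id",
    #         region_name="us-east-1",
    #         service_name="bedrock-runtime",
    #         max_new_tokens=500,
    #     )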

    @abstractmethod
    def _create_body(self, prompt: str, **kwargs):
        """Build the provider-specific request body for the given prompt."""
        pass
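
    # A minimal sketch of what a subclass might implement here; the "inputText" field
    # name is an assumption that varies by provider, not part of this base class:
    #
    #     def _create_body(self, prompt: str, **kwargs):
    #         query_kwargs = self._sanitize_kwargs(kwargs)
    #         return {"inputText": prompt, **query_kwargs}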

    def _sanitize_kwargs(self, query_kwargs: dict[str, Any]) -> dict[str, Any]:
        """Ensure that input kwargs can be used by Bedrock or SageMaker."""
        base_args: dict[str, Any] = {"temperature": self.kwargs["temperature"]}
        for k, v in base_args.items():
            if k not in query_kwargs:
                query_kwargs[k] = v

        # Clamp temperature into the range accepted by AWS-hosted models.
        if query_kwargs["temperature"] > 1.0:
            query_kwargs["temperature"] = 0.99
        if query_kwargs["temperature"] < 0.01:
            query_kwargs["temperature"] = 0.01

        return query_kwargs
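
    # For example, _sanitize_kwargs({"temperature": 1.7}) returns {"temperature": 0.99},
    # and _sanitize_kwargs({}) fills in the default temperature from self.kwargs.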

    @abstractmethod
    def _call_model(self, body: str) -> str | list[str]:
        """Call the model and return the generated output without the formatted prompt."""
        pass
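
    # A rough sketch of how a hypothetical Bedrock-backed subclass might implement this,
    # assuming self.predictor is a "bedrock-runtime" client; the response field name is
    # an assumption that varies by model family:
    #
    #     def _call_model(self, body: str) -> str:
    #         response = self.predictor.invoke_model(
    #             modelId=self._model_name,
    #             body=body,
    #             accept="application/json",
    #             contentType="application/json",
    #         )
    #         completion = json.loads(response["body"].read())
    #         return completion["generated_text"]  # assumed key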

    @abstractmethod
    def _extract_input_parameters(
        self, body: dict[Any, Any]
    ) -> dict[str, str | float | int]:
        pass

    def _simple_api_call(self, formatted_prompt: str, **kwargs) -> str | list[str]:
        body = self._create_body(formatted_prompt, **kwargs)
        json_body = json.dumps(body)
        llm_out: str | list[str] = self._call_model(json_body)
        # Some endpoints echo the prompt back; strip it from the generated text.
        if isinstance(llm_out, str):
            llm_out = llm_out.replace(formatted_prompt, "")
        else:
            llm_out = [generated.replace(formatted_prompt, "") for generated in llm_out]
        self.history.append(
            {"prompt": formatted_prompt, "response": llm_out, "kwargs": body}
        )
        return llm_out

    def basic_request(self, prompt, **kwargs) -> str | list[str]:
        """Query the endpoint."""
        formatted_prompt: str
        if self._truncate_long_prompt_prompts:
            truncated_prompt: str = self._truncate_prompt(prompt)
            formatted_prompt = self._format_prompt(truncated_prompt)
        else:
            formatted_prompt = self._format_prompt(prompt)

        llm_out: str | list[str]
        if "n" in kwargs:
            if self._batch_n:
                llm_out = self._simple_api_call(
                    formatted_prompt=formatted_prompt, **kwargs
                )
            else:
                # The model does not support batched sampling with n, so pop n and
                # call the LM that many times instead.
                n_requests: int = kwargs.pop("n")
                llm_out = []
                for _ in range(n_requests):
                    generated: str | list[str] = self._simple_api_call(
                        formatted_prompt=formatted_prompt, **kwargs
                    )
                    if isinstance(generated, str):
                        llm_out.append(generated)
                    else:
                        raise TypeError("Error, list type was returned from LM call")
        else:
            llm_out = self._simple_api_call(formatted_prompt=formatted_prompt, **kwargs)

        return llm_out
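
    # For example, with batch_n=False, basic_request(prompt, n=3) issues three separate
    # API calls and returns a list of three completions; with batch_n=True the n
    # parameter is passed through to the model in a single call.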

    def _estimate_tokens(self, text: str) -> int:
        # CHARS2TOKENS is a characters-per-token ratio, so divide the character count.
        return len(text) // CHARS2TOKENS
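
    # e.g., with CHARS2TOKENS = 4, a 10,000-character prompt is estimated at ~2,500 tokens.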

    @abstractmethod
    def _format_prompt(self, raw_prompt: str) -> str:
        """Format the raw prompt for the underlying model."""
        pass
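
    # A hypothetical subclass might wrap the raw prompt in a model-specific template,
    # for example (the template below is an assumption, not a required format):
    #
    #     def _format_prompt(self, raw_prompt: str) -> str:
    #         return f"Human: {raw_prompt}\n\nAssistant:"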

    def _truncate_prompt(
        self,
        input_text: str,
        remove_beginning_or_ending: Literal["beginning", "ending"] = "beginning",
        max_input_tokens: int = 2500,
    ) -> str:
        """Reformat inputs so that they do not overflow the context size limitation."""
        token_count = self._estimate_tokens(input_text)
        if token_count > self.kwargs["max_tokens"]:
            logging.info("Excessively long prompt found in LM input")
            logging.info("Truncating text to avoid an error")
            max_chars: int = CHARS2TOKENS * max_input_tokens
            truncated_text: str
            if remove_beginning_or_ending == "ending":
                truncated_text = input_text[0:max_chars]
            else:
                truncated_text = input_text[-max_chars:]
            return truncated_text
        return input_text
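
    # For example, with the defaults (remove_beginning_or_ending="beginning",
    # max_input_tokens=2500), an over-long prompt is cut down to its last
    # CHARS2TOKENS * 2500 = 10,000 characters, keeping the most recent context.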

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ) -> list[str]:
        """Query the AWS LLM.

        There is only support for only_completed=True and return_sorted=False
        right now.
        """
        if not only_completed:
            raise NotImplementedError("only_completed=False is not yet supported!")
        if return_sorted:
            raise NotImplementedError("return_sorted=True is not yet supported!")
        generated = self.basic_request(prompt, **kwargs)
        # basic_request may return a single string or a list of strings (when n is used).
        return generated if isinstance(generated, list) else [generated]
|