import json
from typing import Any, Dict, Optional

from loguru import logger

try:
    import boto3
except ModuleNotFoundError:
    logger.warning("Couldn't load AWS or SageMaker imports. Run 'poetry install --with aws' to support AWS.")

from llm_engineering.domain.inference import Inference
from llm_engineering.settings import settings


class LLMInferenceSagemakerEndpoint(Inference):
    """
    Class for performing inference using a SageMaker endpoint for LLM schemas.
    """

    def __init__(
        self,
        endpoint_name: str,
        default_payload: Optional[Dict[str, Any]] = None,
        inference_component_name: Optional[str] = None,
    ) -> None:
        super().__init__()

        self.client = boto3.client(
            "sagemaker-runtime",
            region_name=settings.AWS_REGION,
            aws_access_key_id=settings.AWS_ACCESS_KEY,
            aws_secret_access_key=settings.AWS_SECRET_KEY,
        )
        self.endpoint_name = endpoint_name
        self.payload = default_payload if default_payload else self._default_payload()
        self.inference_component_name = inference_component_name

    def _default_payload(self) -> Dict[str, Any]:
        """
        Generates the default payload for the inference request.

        Returns:
            dict: The default payload.
        """

        return {
            "inputs": "How is the weather?",
            "parameters": {
                "max_new_tokens": settings.MAX_NEW_TOKENS_INFERENCE,
                "top_p": settings.TOP_P_INFERENCE,
                "temperature": settings.TEMPERATURE_INFERENCE,
                "return_full_text": False,
            },
        }

    def set_payload(self, inputs: str, parameters: Optional[Dict[str, Any]] = None) -> None:
        """
        Sets the payload for the inference request.

        Args:
            inputs (str): The input text for the inference.
            parameters (dict, optional): Additional parameters for the inference. Defaults to None.
        """

        self.payload["inputs"] = inputs
        if parameters:
            self.payload["parameters"].update(parameters)

    def inference(self) -> Dict[str, Any]:
        """
        Performs the inference request using the SageMaker endpoint.

        Returns:
            dict: The response from the inference request.

        Raises:
            Exception: If an error occurs during the inference request.
        """

        try:
            logger.info("Inference request sent.")
            invoke_args = {
                "EndpointName": self.endpoint_name,
                "ContentType": "application/json",
                "Body": json.dumps(self.payload),
            }
            if self.inference_component_name not in ["None", None]:
                invoke_args["InferenceComponentName"] = self.inference_component_name
            response = self.client.invoke_endpoint(**invoke_args)
            response_body = response["Body"].read().decode("utf8")

            return json.loads(response_body)
        except Exception:
            logger.exception("SageMaker inference failed.")

            raise
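
# Hedged usage sketch (illustrative only, not part of the original module): it assumes a
# SageMaker endpoint has already been deployed and that the AWS credentials held in
# `settings` are valid. The endpoint name, prompt, and parameter override below are
# hypothetical placeholders, not values taken from the source.
if __name__ == "__main__":
    llm = LLMInferenceSagemakerEndpoint(
        endpoint_name="twin-llm-endpoint",  # hypothetical endpoint name
    )

    # Replace the default prompt and cap generation length for this single call.
    llm.set_payload(
        inputs="Explain retrieval-augmented generation in two sentences.",
        parameters={"max_new_tokens": 100},
    )

    response = llm.inference()
    logger.info(response)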