| """ | |
| Translates from OpenAI's `/v1/chat/completions` endpoint to Triton's `/generate` endpoint. | |
| """ | |
| import json | |
| from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union | |
| from httpx import Headers, Response | |
| from litellm.constants import DEFAULT_MAX_TOKENS_FOR_TRITON | |
| from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory | |
| from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator | |
| from litellm.llms.base_llm.chat.transformation import ( | |
| BaseConfig, | |
| BaseLLMException, | |
| LiteLLMLoggingObj, | |
| ) | |
| from litellm.types.llms.openai import AllMessageValues | |
| from litellm.types.utils import ( | |
| ChatCompletionToolCallChunk, | |
| ChatCompletionUsageBlock, | |
| Choices, | |
| GenericStreamingChunk, | |
| Message, | |
| ModelResponse, | |
| ) | |
| from ..common_utils import TritonError | |


class TritonConfig(BaseConfig):
    """
    Base class for Triton configurations.

    Handles routing between the /infer and /generate triton completion endpoints.
    """

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
    ) -> BaseLLMException:
        return TritonError(
            status_code=status_code, message=error_message, headers=headers
        )

    def validate_environment(
        self,
        headers: Dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Dict:
        return {"Content-Type": "application/json"}

    def get_supported_openai_params(self, model: str) -> List:
        return ["max_tokens", "max_completion_tokens"]

    def map_openai_params(
        self,
        non_default_params: Dict,
        optional_params: Dict,
        model: str,
        drop_params: bool,
    ) -> Dict:
        for param, value in non_default_params.items():
            if param in ("max_tokens", "max_completion_tokens"):
                optional_params[param] = value
        return optional_params
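
    # For example (illustrative values), {"max_tokens": 32} or
    # {"max_completion_tokens": 32} is passed through to optional_params
    # unchanged; any other key in non_default_params is ignored here.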

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        if api_base is None:
            raise ValueError("api_base is required")
        llm_type = self._get_triton_llm_type(api_base)
        if llm_type == "generate" and stream:
            return api_base + "_stream"
        return api_base
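
    # Illustrative routing (hypothetical URLs, assuming Triton's standard
    # HTTP endpoint layout):
    #   ".../v2/models/m/generate" + stream=True  -> ".../v2/models/m/generate_stream"
    #   ".../v2/models/m/generate" + stream=False -> ".../v2/models/m/generate"
    #   ".../v2/models/m/infer" is returned unchanged; only /generate has a
    #   streaming variant here.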

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        api_base = litellm_params.get("api_base", "")
        llm_type = self._get_triton_llm_type(api_base)
        if llm_type == "generate":
            return TritonGenerateConfig().transform_response(
                model=model,
                raw_response=raw_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                api_key=api_key,
                json_mode=json_mode,
            )
        elif llm_type == "infer":
            return TritonInferConfig().transform_response(
                model=model,
                raw_response=raw_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                api_key=api_key,
                json_mode=json_mode,
            )
        return model_response

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        api_base = litellm_params.get("api_base", "")
        llm_type = self._get_triton_llm_type(api_base)
        if llm_type == "generate":
            return TritonGenerateConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        elif llm_type == "infer":
            return TritonInferConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        return {}

    def _get_triton_llm_type(self, api_base: str) -> Literal["generate", "infer"]:
        if api_base.endswith("/generate"):
            return "generate"
        elif api_base.endswith("/infer"):
            return "infer"
        else:
            raise ValueError(f"Invalid Triton API base: {api_base}")

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        return TritonResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )


class TritonGenerateConfig(TritonConfig):
    """
    Transformations for the triton /generate endpoint (a TensorRT-LLM model).
    """

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        inference_params = optional_params.copy()
        stream = inference_params.pop("stream", False)
        # Pop max_tokens so the update() below cannot overwrite the int-cast
        # value with a raw (possibly non-int) one.
        max_tokens = int(
            inference_params.pop("max_tokens", DEFAULT_MAX_TOKENS_FOR_TRITON)
        )
        data_for_triton: Dict[str, Any] = {
            "text_input": prompt_factory(model=model, messages=messages),
            "parameters": {
                "max_tokens": max_tokens,
            },
            "stream": bool(stream),
        }
        data_for_triton["parameters"].update(inference_params)
        return data_for_triton
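
    # A request built above looks roughly like this (prompt text and token
    # limit are illustrative):
    # {
    #     "text_input": "Hello, how are you?",
    #     "parameters": {"max_tokens": 256},
    #     "stream": False,
    # }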

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise TritonError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        model_response.choices = [
            Choices(index=0, message=Message(content=raw_response_json["text_output"]))
        ]
        return model_response
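
    # A successful /generate response is expected to carry the completion in
    # "text_output", e.g. (illustrative payload):
    #   {"model_name": "llm", "text_output": "I'm doing well, thanks!"}
    # Only "text_output" is read; other fields are ignored.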


class TritonInferConfig(TritonConfig):
    """
    Transformations for the triton /infer endpoint (a custom model hosted on
    triton, using the KServe v2 inference protocol).
    """

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # Only the first message's content is sent as the text input.
        text_input = messages[0].get("content", "")
        data_for_triton = {
            "inputs": [
                {
                    "name": "text_input",
                    "shape": [1],
                    "datatype": "BYTES",
                    "data": [text_input],
                }
            ]
        }

        for k, v in optional_params.items():
            if k not in ("stream", "max_retries"):
                # Map Python types onto KServe v2 datatypes.
                if isinstance(v, float):
                    datatype = "FP32"
                elif isinstance(v, int):
                    datatype = "INT32"
                else:
                    datatype = "BYTES"
                data_for_triton["inputs"].append(
                    {"name": k, "shape": [1], "datatype": datatype, "data": [v]}
                )

        # Fall back to a small token limit when the caller does not set one.
        if "max_tokens" not in optional_params:
            data_for_triton["inputs"].append(
                {
                    "name": "max_tokens",
                    "shape": [1],
                    "datatype": "INT32",
                    "data": [20],
                }
            )
        return data_for_triton
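
    # A request built above follows the KServe v2 inference protocol, e.g.
    # (illustrative values):
    # {
    #     "inputs": [
    #         {"name": "text_input", "shape": [1], "datatype": "BYTES", "data": ["Hello"]},
    #         {"name": "temperature", "shape": [1], "datatype": "FP32", "data": [0.7]},
    #         {"name": "max_tokens", "shape": [1], "datatype": "INT32", "data": [20]},
    #     ]
    # }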

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: Dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise TritonError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        _triton_response_data = raw_response_json["outputs"][0]["data"]
        triton_response_data: Optional[str] = None
        if isinstance(_triton_response_data, list):
            triton_response_data = "".join(_triton_response_data)
        else:
            triton_response_data = _triton_response_data

        model_response.choices = [
            Choices(
                index=0,
                message=Message(content=triton_response_data),
            )
        ]
        return model_response
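
    # The /infer response is expected to follow the KServe v2 shape, e.g.
    # (illustrative):
    #   {"outputs": [{"name": "text_output", "datatype": "BYTES",
    #                 "shape": [1], "data": ["I'm doing well"]}]}
    # Only outputs[0]["data"] is consumed; list payloads are joined into a
    # single string.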


class TritonResponseIterator(BaseModelResponseIterator):
    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        try:
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None
            provider_specific_fields = None
            index = int(chunk.get("index", 0))

            # set values
            text = chunk.get("text_output", "")
            finish_reason = chunk.get("stop_reason", "")
            is_finished = chunk.get("is_finished", False)

            return GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )
        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")