lamhieu's picture
perf: improvement requests handling with async mode
716eebd
raw
history blame
17.8 kB
from __future__ import annotations
import asyncio
import logging
from enum import Enum
from typing import List, Union, Dict, Optional, NamedTuple, Any
from dataclasses import dataclass
from pathlib import Path
from io import BytesIO
from hashlib import md5
from cachetools import LRUCache
import httpx
import numpy as np
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoProcessor, AutoModel
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class TextModelType(str, Enum):
"""
Enumeration of supported text models.
"""
MULTILINGUAL_E5_SMALL = "multilingual-e5-small"
MULTILINGUAL_E5_BASE = "multilingual-e5-base"
MULTILINGUAL_E5_LARGE = "multilingual-e5-large"
SNOWFLAKE_ARCTIC_EMBED_L_V2 = "snowflake-arctic-embed-l-v2.0"
PARAPHRASE_MULTILINGUAL_MINILM_L12_V2 = "paraphrase-multilingual-MiniLM-L12-v2"
PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2 = "paraphrase-multilingual-mpnet-base-v2"
BGE_M3 = "bge-m3"
GTE_MULTILINGUAL_BASE = "gte-multilingual-base"
class ImageModelType(str, Enum):
"""
Enumeration of supported image models.
"""
SIGLIP_BASE_PATCH16_256_MULTILINGUAL = "siglip-base-patch16-256-multilingual"
class ModelInfo(NamedTuple):
"""
Container mapping a model type to its model identifier and optional ONNX file.
"""
model_id: str
onnx_file: Optional[str] = None
@dataclass
class ModelConfig:
"""
Configuration for text and image models.
"""
text_model_type: TextModelType = TextModelType.MULTILINGUAL_E5_SMALL
image_model_type: ImageModelType = (
ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL
)
logit_scale: float = 4.60517 # Example scale used in cross-modal similarity
@property
def text_model_info(self) -> ModelInfo:
"""
Return model information for the configured text model.
"""
text_configs = {
TextModelType.MULTILINGUAL_E5_SMALL: ModelInfo(
model_id="Xenova/multilingual-e5-small",
onnx_file="onnx/model_quantized.onnx",
),
TextModelType.MULTILINGUAL_E5_BASE: ModelInfo(
model_id="Xenova/multilingual-e5-base",
onnx_file="onnx/model_quantized.onnx",
),
TextModelType.MULTILINGUAL_E5_LARGE: ModelInfo(
model_id="Xenova/multilingual-e5-large",
onnx_file="onnx/model_quantized.onnx",
),
TextModelType.SNOWFLAKE_ARCTIC_EMBED_L_V2: ModelInfo(
model_id="Snowflake/snowflake-arctic-embed-l-v2.0",
onnx_file="onnx/model_quantized.onnx",
),
TextModelType.PARAPHRASE_MULTILINGUAL_MINILM_L12_V2: ModelInfo(
model_id="Xenova/paraphrase-multilingual-MiniLM-L12-v2",
onnx_file="onnx/model_quantized.onnx",
),
TextModelType.PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2: ModelInfo(
model_id="Xenova/paraphrase-multilingual-mpnet-base-v2",
onnx_file="onnx/model_quantized.onnx",
),
TextModelType.BGE_M3: ModelInfo(
model_id="Xenova/bge-m3",
onnx_file="onnx/model_quantized.onnx",
),
TextModelType.GTE_MULTILINGUAL_BASE: ModelInfo(
model_id="onnx-community/gte-multilingual-base",
onnx_file="onnx/model_quantized.onnx",
),
}
return text_configs[self.text_model_type]
@property
def image_model_info(self) -> ModelInfo:
"""
Return model information for the configured image model.
"""
image_configs = {
ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL: ModelInfo(
model_id="google/siglip-base-patch16-256-multilingual"
),
}
return image_configs[self.image_model_type]
class ModelKind(str, Enum):
"""
Indicates the type of model: text or image.
"""
TEXT = "text"
IMAGE = "image"
def detect_model_kind(model_id: str) -> ModelKind:
"""
Detect whether the model identifier corresponds to a text or image model.
Raises:
ValueError: If the model identifier is unrecognized.
"""
if model_id in [m.value for m in TextModelType]:
return ModelKind.TEXT
elif model_id in [m.value for m in ImageModelType]:
return ModelKind.IMAGE
else:
raise ValueError(
f"Unrecognized model ID: {model_id}.\n"
f"Valid text: {[m.value for m in TextModelType]}\n"
f"Valid image: {[m.value for m in ImageModelType]}"
)
class EmbeddingsService:
"""
Service for generating text/image embeddings and performing similarity ranking.
Asynchronous methods are used to maximize throughput and avoid blocking the event loop.
"""
def __init__(self, config: Optional[ModelConfig] = None):
"""
Initialize the service by setting up model caches, device configuration,
and asynchronous HTTP client.
"""
self.lru_cache = LRUCache(maxsize=10_000)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.config = config or ModelConfig()
# Dictionaries to hold preloaded models.
self.text_models: Dict[TextModelType, SentenceTransformer] = {}
self.image_models: Dict[ImageModelType, AutoModel] = {}
self.image_processors: Dict[ImageModelType, AutoProcessor] = {}
# Create a persistent asynchronous HTTP client.
self.async_http_client = httpx.AsyncClient(timeout=10)
# Preload all models.
self._load_all_models()
def _load_all_models(self) -> None:
"""
Pre-load all text and image models to minimize latency at request time.
"""
try:
# Preload text models.
for t_model_type in TextModelType:
info = ModelConfig(text_model_type=t_model_type).text_model_info
logger.info("Loading text model: %s", info.model_id)
if info.onnx_file:
logger.info("Using ONNX file: %s", info.onnx_file)
self.text_models[t_model_type] = SentenceTransformer(
info.model_id,
device=self.device,
backend="onnx",
model_kwargs={
"provider": "CPUExecutionProvider",
"file_name": info.onnx_file,
},
trust_remote_code=True,
)
else:
self.text_models[t_model_type] = SentenceTransformer(
info.model_id,
device=self.device,
trust_remote_code=True,
)
# Preload image models.
for i_model_type in ImageModelType:
model_id = ModelConfig(
image_model_type=i_model_type
).image_model_info.model_id
logger.info("Loading image model: %s", model_id)
model = AutoModel.from_pretrained(model_id).to(self.device)
model.eval() # Set the model to evaluation mode.
processor = AutoProcessor.from_pretrained(model_id)
self.image_models[i_model_type] = model
self.image_processors[i_model_type] = processor
logger.info("All models loaded successfully.")
except Exception as e:
msg = f"Error loading models: {str(e)}"
logger.error(msg)
raise RuntimeError(msg) from e
@staticmethod
def _validate_text_list(input_text: Union[str, List[str]]) -> List[str]:
"""
Validate and convert text input into a non-empty list of strings.
Raises:
ValueError: If the input is invalid.
"""
if isinstance(input_text, str):
if not input_text.strip():
raise ValueError("Text input cannot be empty.")
return [input_text]
if not isinstance(input_text, list) or not all(
isinstance(x, str) for x in input_text
):
raise ValueError("Text input must be a string or a list of strings.")
if len(input_text) == 0:
raise ValueError("Text input list cannot be empty.")
return input_text
@staticmethod
def _validate_image_list(input_images: Union[str, List[str]]) -> List[str]:
"""
Validate and convert image input into a non-empty list of image paths/URLs.
Raises:
ValueError: If the input is invalid.
"""
if isinstance(input_images, str):
if not input_images.strip():
raise ValueError("Image input cannot be empty.")
return [input_images]
if not isinstance(input_images, list) or not all(
isinstance(x, str) for x in input_images
):
raise ValueError("Image input must be a string or a list of strings.")
if len(input_images) == 0:
raise ValueError("Image input list cannot be empty.")
return input_images
async def _fetch_image(self, path_or_url: str) -> Image.Image:
"""
Asynchronously fetch an image from a URL or load from a local path.
Args:
path_or_url: The URL or file path of the image.
Returns:
A PIL Image in RGB mode.
Raises:
ValueError: If image fetching or processing fails.
"""
try:
if path_or_url.startswith("http"):
# Asynchronously fetch the image bytes.
response = await self.async_http_client.get(path_or_url)
response.raise_for_status()
# Offload the blocking I/O (PIL image opening) to a thread.
img = await asyncio.to_thread(Image.open, BytesIO(response.content))
else:
# Offload file I/O to a thread.
img = await asyncio.to_thread(Image.open, Path(path_or_url))
return img.convert("RGB")
except Exception as e:
raise ValueError(f"Error fetching image '{path_or_url}': {str(e)}") from e
async def _process_image(self, path_or_url: str) -> Dict[str, torch.Tensor]:
"""
Asynchronously load and process a single image.
Args:
path_or_url: The image URL or local path.
Returns:
A dictionary of processed tensors ready for model input.
Raises:
ValueError: If image processing fails.
"""
img = await self._fetch_image(path_or_url)
processor = self.image_processors[self.config.image_model_type]
# Note: Processor may perform CPU-intensive work; if needed, offload to thread.
processed_data = processor(images=img, return_tensors="pt").to(self.device)
return processed_data
def _generate_text_embeddings(
self,
model_id: TextModelType,
texts: List[str],
) -> np.ndarray:
"""
Generate text embeddings using the SentenceTransformer model.
Single-text requests are cached using an LRU cache.
Returns:
A NumPy array of text embeddings.
Raises:
RuntimeError: If text embedding generation fails.
"""
try:
if len(texts) == 1:
single_text = texts[0]
key = md5(f"{model_id}:{single_text}".encode("utf-8")).hexdigest()[:8]
if key in self.lru_cache:
return self.lru_cache[key]
model = self.text_models[model_id]
emb = model.encode([single_text])
self.lru_cache[key] = emb
return emb
model = self.text_models[model_id]
return model.encode(texts)
except Exception as e:
raise RuntimeError(
f"Error generating text embeddings with model '{model_id}': {e}"
) from e
async def _async_generate_image_embeddings(
self,
model_id: ImageModelType,
images: List[str],
) -> np.ndarray:
"""
Asynchronously generate image embeddings.
This method concurrently processes multiple images and offloads
the blocking model inference to a separate thread.
Returns:
A NumPy array of image embeddings.
Raises:
RuntimeError: If image embedding generation fails.
"""
try:
# Concurrently process all images.
processed_tensors = await asyncio.gather(
*[self._process_image(img_path) for img_path in images]
)
# Assume all processed outputs have the same keys.
keys = processed_tensors[0].keys()
combined = {
k: torch.cat([pt[k] for pt in processed_tensors], dim=0) for k in keys
}
def infer():
with torch.no_grad():
embeddings = self.image_models[model_id].get_image_features(
**combined
)
return embeddings.cpu().numpy()
return await asyncio.to_thread(infer)
except Exception as e:
raise RuntimeError(
f"Error generating image embeddings with model '{model_id}': {e}"
) from e
async def generate_embeddings(
self,
model: str,
inputs: Union[str, List[str]],
) -> np.ndarray:
"""
Asynchronously generate embeddings for text or image inputs based on model type.
Args:
model: The model identifier.
inputs: The text or image input(s).
Returns:
A NumPy array of embeddings.
"""
modality = detect_model_kind(model)
if modality == ModelKind.TEXT:
text_model_id = TextModelType(model)
text_list = self._validate_text_list(inputs)
return await asyncio.to_thread(
self._generate_text_embeddings, text_model_id, text_list
)
elif modality == ModelKind.IMAGE:
image_model_id = ImageModelType(model)
image_list = self._validate_image_list(inputs)
return await self._async_generate_image_embeddings(
image_model_id, image_list
)
async def rank(
self,
model: str,
queries: Union[str, List[str]],
candidates: Union[str, List[str]],
) -> Dict[str, Any]:
"""
Asynchronously rank candidate texts/images against the provided queries.
Embeddings for queries and candidates are generated concurrently.
Returns:
A dictionary containing probabilities, cosine similarities, and usage statistics.
"""
modality = detect_model_kind(model)
if modality == ModelKind.TEXT:
model_enum = TextModelType(model)
else:
model_enum = ImageModelType(model)
# Concurrently generate embeddings.
query_task = asyncio.create_task(self.generate_embeddings(model, queries))
candidate_task = asyncio.create_task(
self.generate_embeddings(model, candidates)
)
query_embeds, candidate_embeds = await asyncio.gather(
query_task, candidate_task
)
# Compute cosine similarity.
sim_matrix = self.cosine_similarity(query_embeds, candidate_embeds)
scaled = np.exp(self.config.logit_scale) * sim_matrix
probs = self.softmax(scaled)
if modality == ModelKind.TEXT:
query_tokens = self.estimate_tokens(queries)
candidate_tokens = self.estimate_tokens(candidates)
total_tokens = query_tokens + candidate_tokens
else:
total_tokens = 0
usage = {
"prompt_tokens": total_tokens,
"total_tokens": total_tokens,
}
return {
"probabilities": probs.tolist(),
"cosine_similarities": sim_matrix.tolist(),
"usage": usage,
}
def estimate_tokens(self, input_data: Union[str, List[str]]) -> int:
"""
Estimate the token count for the given text input using the SentenceTransformer tokenizer.
Returns:
The total number of tokens.
"""
texts = self._validate_text_list(input_data)
model = self.text_models[self.config.text_model_type]
tokenized = model.tokenize(texts)
return sum(len(ids) for ids in tokenized["input_ids"])
@staticmethod
def softmax(scores: np.ndarray) -> np.ndarray:
"""
Compute the softmax over the last dimension of the input array.
Returns:
The softmax probabilities.
"""
exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
return exps / np.sum(exps, axis=-1, keepdims=True)
@staticmethod
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
"""
Compute the pairwise cosine similarity between all rows of arrays a and b.
Returns:
A (N x M) matrix of cosine similarities.
"""
a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-9)
b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-9)
return np.dot(a_norm, b_norm.T)
async def close(self) -> None:
"""
Close the asynchronous HTTP client.
"""
await self.async_http_client.aclose()