import json
from typing import List, Union

import numpy as np
import requests
from sentence_transformers import SentenceTransformer

import lpm_kernel.common.strategy.classification as classification
from lpm_kernel.api.services.user_llm_config_service import UserLLMConfigService
from lpm_kernel.configs.config import Config
from lpm_kernel.configs.logging import get_train_process_logger

logger = get_train_process_logger()


class EmbeddingError(Exception):
    """Custom exception class for embedding-related errors"""

    def __init__(self, message, original_error=None):
        super().__init__(message)
        self.original_error = original_error
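

# Usage sketch (illustrative): callers can inspect the underlying cause of an
# EmbeddingError through its original_error attribute, e.g.
#
#     try:
#         vectors = client.get_embedding(documents)  # hypothetical call site
#     except EmbeddingError as err:
#         logger.error(f"embedding failed: {err} (cause: {err.original_error!r})")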


class LLMClient:
    """LLM client utility class"""

    def __init__(self):
        self.config = Config.from_env()
        self.user_llm_config_service = UserLLMConfigService()
        self.embedding_max_text_length = int(self.config.get('EMBEDDING_MAX_TEXT_LENGTH', 8000))
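
    # Note (assumption): Config.from_env is expected to read process environment
    # variables, so e.g. setting EMBEDDING_MAX_TEXT_LENGTH=4000 would lower the
    # chunking threshold used in get_embedding below from its 8000 default.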
    def get_embedding(self, texts: Union[str, List[str]]) -> np.ndarray:
        """Calculate text embeddings.

        Args:
            texts (str or list): Input text or list of texts.

        Returns:
            numpy.ndarray: Embedding vectors, one row per input text.

        Raises:
            EmbeddingError: If no LLM configuration is available or the
                embedding request fails.
        """
        # Ensure texts is in list format
        if isinstance(texts, str):
            texts = [texts]

        # Split long texts into chunks using the configured max length
        chunked_texts = []
        text_chunk_counts = []  # How many chunks each input text was split into
        for text in texts:
            if len(text) > self.embedding_max_text_length:
                # Split into consecutive chunks of embedding_max_text_length characters
                chunks = [text[i:i + self.embedding_max_text_length]
                          for i in range(0, len(text), self.embedding_max_text_length)]
                chunked_texts.extend(chunks)
                text_chunk_counts.append(len(chunks))
            else:
                chunked_texts.append(text)
                text_chunk_counts.append(1)
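
        # Worked example (hypothetical numbers): with the default max length of
        # 8000 characters, a 20,000-character text becomes 3 chunks
        # (8000 + 8000 + 4000), so chunked_texts grows by three entries and the
        # corresponding text_chunk_counts entry is 3.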

        user_llm_config = self.user_llm_config_service.get_available_llm()
        if not user_llm_config:
            raise EmbeddingError("No LLM configuration found")

        try:
            # Delegate to the provider-specific embedding strategy, which
            # sends the actual request to the configured embedding endpoint
            embeddings_array = classification.strategy_classification(user_llm_config, chunked_texts)
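
            # Illustrative shape (hypothetical dim): for inputs ["short", long_text]
            # where long_text was split into 3 chunks, the strategy returns one row
            # per chunk, e.g. shape (4, 1536); the merge below averages rows 1..3
            # back into a single row so the output is again one row per input text.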

            # If we split any texts, merge their chunk embeddings back together
            if sum(text_chunk_counts) > len(texts):
                final_embeddings = []
                start_idx = 0
                for chunk_count in text_chunk_counts:
                    if chunk_count > 1:
                        # Average the embeddings of all chunks of this text
                        chunk_embeddings = embeddings_array[start_idx:start_idx + chunk_count]
                        avg_embedding = np.mean(chunk_embeddings, axis=0)
                        final_embeddings.append(avg_embedding)
                    else:
                        final_embeddings.append(embeddings_array[start_idx])
                    start_idx += chunk_count
                return np.array(final_embeddings)

            return embeddings_array
        except requests.exceptions.RequestException as e:
            # Handle request errors
            error_msg = f"Request error getting embeddings: {str(e)}"
            logger.error(error_msg)
            raise EmbeddingError(error_msg, e)
        except json.JSONDecodeError as e:
            # Handle JSON parsing errors
            error_msg = f"Invalid JSON response from embedding API: {str(e)}"
            logger.error(error_msg)
            raise EmbeddingError(error_msg, e)
        except (KeyError, IndexError, ValueError) as e:
            # Handle response structure errors
            error_msg = f"Invalid response structure from embedding API: {str(e)}"
            logger.error(error_msg)
            raise EmbeddingError(error_msg, e)
        except Exception as e:
            # Fallback for any other errors
            error_msg = f"Unexpected error getting embeddings: {str(e)}"
            logger.error(error_msg, exc_info=True)
            raise EmbeddingError(error_msg, e)

    def chat_credentials(self):
        """Get chat LLM authentication information"""
        # Look up the active config lazily, mirroring get_embedding above;
        # credentials are not cached on self at construction time
        user_llm_config = self.user_llm_config_service.get_available_llm()
        if not user_llm_config:
            raise ValueError("No LLM configuration found")
        return {"api_key": user_llm_config.chat_api_key, "base_url": user_llm_config.chat_endpoint}