Gemini
feat: add detailed logging
01d5a5d
from lpm_kernel.api.services.user_llm_config_service import UserLLMConfigService
from lpm_kernel.configs.config import Config
from typing import List, Union
import requests
import numpy as np
from lpm_kernel.configs.logging import get_train_process_logger
logger = get_train_process_logger()
import lpm_kernel.common.strategy.classification as classification
from sentence_transformers import SentenceTransformer
import json
class EmbeddingError(Exception):
    """Error raised when an embedding operation fails.

    Args:
        message: Human-readable description of the failure; passed on to
            ``Exception`` so it becomes ``str(error)``.
        original_error: Optional underlying exception that caused the
            failure. Defaults to ``None``.

    Attributes:
        original_error: The value given for ``original_error``.
    """

    def __init__(self, message, original_error=None):
        # Keep the root cause reachable for callers that want to inspect it.
        self.original_error = original_error
        super().__init__(message)
class LLMClient:
    """LLM client utility class.

    The user's LLM configuration is not cached on the instance: it is looked
    up through ``UserLLMConfigService.get_available_llm()`` every time it is
    needed, so each call sees whatever configuration the service returns at
    that moment.

    Attributes:
        config: Configuration object returned by ``Config.from_env()``.
        user_llm_config_service: Service used to look up the LLM configuration.
        embedding_max_text_length: Maximum number of characters sent as one
            embedding input; longer texts are split into chunks of this size.
            Read from ``EMBEDDING_MAX_TEXT_LENGTH`` (default ``8000``).
    """

    def __init__(self):
        self.config = Config.from_env()
        self.user_llm_config_service = UserLLMConfigService()
        self.embedding_max_text_length = int(
            self.config.get('EMBEDDING_MAX_TEXT_LENGTH', 8000)
        )

    def get_embedding(self, texts: Union[str, List[str]]) -> np.ndarray:
        """Calculate text embeddings.

        Texts longer than ``embedding_max_text_length`` characters are split
        into consecutive chunks, every chunk is embedded, and the chunk
        embeddings of one text are averaged (unweighted) so that the result
        contains exactly one embedding per input text.

        Args:
            texts (str or list): Input text or list of texts. A single string
                is treated as a one-element list.

        Returns:
            numpy.ndarray: One embedding per input text, in input order.

        Raises:
            EmbeddingError: If no LLM configuration is available, or if
                computing the embeddings fails. In the latter case the
                underlying exception is available both as ``original_error``
                and as ``__cause__``.
        """
        # Ensure texts is in list format
        if isinstance(texts, str):
            texts = [texts]

        chunked_texts, text_chunk_counts = self._split_into_chunks(texts)

        user_llm_config = self.user_llm_config_service.get_available_llm()
        if not user_llm_config:
            raise EmbeddingError("No LLM configuration found")

        try:
            # Send request to embedding endpoint
            embeddings_array = classification.strategy_classification(
                user_llm_config, chunked_texts
            )
            # If we split any texts, we need to merge their embeddings back
            if sum(text_chunk_counts) > len(texts):
                return self._merge_chunk_embeddings(
                    embeddings_array, text_chunk_counts
                )
            # Match the declared return type on this path too; an existing
            # ndarray is passed through unchanged by np.asarray.
            return np.asarray(embeddings_array)
        except requests.exceptions.RequestException as e:
            # Handle request errors
            error_msg = f"Request error getting embeddings: {str(e)}"
            logger.error(error_msg)
            raise EmbeddingError(error_msg, e) from e
        except json.JSONDecodeError as e:
            # Handle JSON parsing errors
            error_msg = f"Invalid JSON response from embedding API: {str(e)}"
            logger.error(error_msg)
            raise EmbeddingError(error_msg, e) from e
        except (KeyError, IndexError, ValueError) as e:
            # Handle response structure errors
            error_msg = f"Invalid response structure from embedding API: {str(e)}"
            logger.error(error_msg)
            raise EmbeddingError(error_msg, e) from e
        except Exception as e:
            # Fallback for any other errors; include the traceback in the log
            error_msg = f"Unexpected error getting embeddings: {str(e)}"
            logger.error(error_msg, exc_info=True)
            raise EmbeddingError(error_msg, e) from e

    def _split_into_chunks(self, texts):
        """Split over-long texts into chunks of ``embedding_max_text_length``.

        Returns:
            tuple: ``(chunked_texts, text_chunk_counts)`` where
            ``chunked_texts`` is the flat list of chunks in input order and
            ``text_chunk_counts[i]`` is the number of chunks produced for
            ``texts[i]`` (``1`` for texts that were not split).
        """
        max_length = self.embedding_max_text_length
        chunked_texts = []
        text_chunk_counts = []  # How many chunks each text was split into
        for text in texts:
            if len(text) > max_length:
                chunks = [
                    text[i:i + max_length]
                    for i in range(0, len(text), max_length)
                ]
            else:
                chunks = [text]
            chunked_texts.extend(chunks)
            text_chunk_counts.append(len(chunks))
        return chunked_texts, text_chunk_counts

    @staticmethod
    def _merge_chunk_embeddings(embeddings_array, text_chunk_counts) -> np.ndarray:
        """Collapse per-chunk embeddings back to one embedding per text.

        ``embeddings_array`` holds one embedding per chunk, in the order the
        chunks were produced; ``text_chunk_counts`` says how many consecutive
        entries belong to each original text.
        """
        final_embeddings = []
        start_idx = 0
        for chunk_count in text_chunk_counts:
            if chunk_count > 1:
                # Average embeddings for split text
                chunk_embeddings = embeddings_array[start_idx:start_idx + chunk_count]
                final_embeddings.append(np.mean(chunk_embeddings, axis=0))
            else:
                final_embeddings.append(embeddings_array[start_idx])
            start_idx += chunk_count
        return np.array(final_embeddings)

    @property
    def chat_credentials(self):
        """Get LLM authentication information.

        Returns:
            dict: ``{"api_key": ..., "base_url": ...}`` for the chat endpoint.

        Raises:
            AttributeError: If neither explicit instance attributes nor an
                LLM configuration are available.
        """
        # Values assigned directly on the instance keep working and win over
        # the configuration lookup.
        if hasattr(self, "chat_api_key") and hasattr(self, "chat_base_url"):
            return {"api_key": self.chat_api_key, "base_url": self.chat_base_url}

        # __init__ does not set chat_api_key / chat_base_url, so resolve them
        # from the currently available configuration instead of failing.
        user_llm_config = self.user_llm_config_service.get_available_llm()
        if not user_llm_config:
            # Same exception type the property raised before when the
            # attributes were missing, but with an actionable message.
            raise AttributeError(
                "No LLM configuration found; chat credentials are unavailable"
            )
        # NOTE(review): field names follow the mapping the constructor used
        # before the lookup became lazy (chat_api_key / chat_endpoint) --
        # confirm against the configuration model.
        return {
            "api_key": user_llm_config.chat_api_key,
            "base_url": user_llm_config.chat_endpoint,
        }