arina-hf-spaces-api / app /core /conversation_retrieval.py
adsurkasur's picture
clone from arina-hf-spaces
68964c2
import json
import numpy as np
from app.core.embedding import generate_embedding
from app.core.db_setup import conversations_collection
from app.core.logging_setup import logger
# DB_PATH has been removed; it is no longer needed.
# find_similar_conversations reads stored conversations from MongoDB.
def find_similar_conversations(query_message, top_n=5):
    """Return the stored messages most similar to ``query_message``.

    The query is embedded with ``generate_embedding`` and compared, by cosine
    similarity, against every document returned by
    ``conversations_collection.find``. Each document's ``embedding`` field is
    decoded with ``json.loads`` before comparison.

    Args:
        query_message: Input passed to ``generate_embedding``.
        top_n: Maximum number of messages to return. Defaults to 5.

    Returns:
        A list of at most ``top_n`` ``message`` values, ordered from most to
        least similar. An empty list is returned when no embedding could be
        generated, when the query embedding has a zero or non-finite norm, or
        when reading the collection fails (the failure is logged).
    """
    query_embedding = generate_embedding(query_message)
    if query_embedding is None:
        return []

    try:
        conversations = list(
            conversations_collection.find(
                {}, {"message": 1, "embedding": 1, "_id": 0}
            )
        )
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        return []

    # The query vector and its norm are identical for every document, so they
    # are computed once. ravel() lets a (1, d) row vector be compared like a
    # flat (d,) vector.
    query_vector = np.asarray(query_embedding, dtype=float).ravel()
    query_norm = np.linalg.norm(query_vector)
    if not (np.isfinite(query_norm) and query_norm > 0):
        # Cosine similarity is undefined for a zero (or non-finite) vector.
        logger.warning("Query embedding has zero or non-finite norm; nothing to compare")
        return []

    similarities = []
    for conversation in conversations:
        # One malformed document must not abort the whole search.
        try:
            message = conversation["message"]
            embedding = np.asarray(
                json.loads(conversation["embedding"]), dtype=float
            ).ravel()
        except (KeyError, TypeError, ValueError) as e:
            logger.warning(f"Skipping malformed conversation record: {e}")
            continue

        # np.dot raises on vectors of different length; skip those instead.
        if embedding.shape[0] != query_vector.shape[0]:
            logger.warning(
                f"Skipping embedding due to dimension mismatch: query_embedding {query_vector.shape[0]} vs embedding {embedding.shape[0]}"
            )
            continue

        denominator = query_norm * np.linalg.norm(embedding)
        if not (np.isfinite(denominator) and denominator > 0):
            # A zero or non-finite stored vector would yield NaN, which makes
            # the ordering produced by sort() unreliable.
            logger.warning("Skipping embedding with zero or non-finite norm")
            continue

        similarity = float(np.dot(query_vector, embedding) / denominator)
        similarities.append((message, similarity))

    similarities.sort(key=lambda x: x[1], reverse=True)
    return [msg for msg, _ in similarities[:top_n]]
# get_similar_conversations reads stored conversations from MongoDB.
def get_similar_conversations(user_input, top_n=3):
    """Retrieve past relevant messages using semantic similarity.

    The input is embedded with ``generate_embedding`` and compared, by cosine
    similarity, against every document returned by
    ``conversations_collection.find``. Each document's ``embedding`` field is
    decoded with ``json.loads`` before comparison. Stored embeddings whose
    length differs from the query embedding are skipped with a warning.

    Args:
        user_input: Input passed to ``generate_embedding``.
        top_n: Maximum number of messages to return. Defaults to 3.

    Returns:
        A list of at most ``top_n`` ``message`` values, ordered from most to
        least similar. An empty list is returned when no embedding could be
        generated, when the query embedding has a zero or non-finite norm, or
        when reading the collection fails (the failure is logged).
    """
    user_embedding = generate_embedding(user_input)
    if user_embedding is None:
        return []

    try:
        conversations = list(
            conversations_collection.find(
                {}, {"message": 1, "embedding": 1, "_id": 0}
            )
        )
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        return []

    # Convert the query once instead of on every iteration; its norm is also
    # loop-invariant. ravel() lets a (1, d) row vector be compared like a
    # flat (d,) vector.
    user_embedding = np.asarray(user_embedding, dtype=float).ravel()
    user_norm = np.linalg.norm(user_embedding)
    if not (np.isfinite(user_norm) and user_norm > 0):
        # Cosine similarity is undefined for a zero (or non-finite) vector.
        logger.warning("Query embedding has zero or non-finite norm; nothing to compare")
        return []

    similarities = []
    for conversation in conversations:
        # One malformed document must not abort the whole search.
        try:
            message = conversation["message"]
            embedding = np.asarray(
                json.loads(conversation["embedding"]), dtype=float
            ).ravel()
        except (KeyError, TypeError, ValueError) as e:
            logger.warning(f"Skipping malformed conversation record: {e}")
            continue

        # Validate embedding dimensions. Mismatched vectors are skipped, so
        # no padding or truncation is ever needed afterwards.
        if user_embedding.shape[0] != len(embedding):
            logger.warning(f"Skipping embedding due to dimension mismatch: user_embedding {user_embedding.shape[0]} vs embedding {len(embedding)}")
            continue

        denominator = user_norm * np.linalg.norm(embedding)
        if not (np.isfinite(denominator) and denominator > 0):
            # A zero or non-finite stored vector would yield NaN, which makes
            # the ordering produced by sort() unreliable.
            logger.warning("Skipping embedding with zero or non-finite norm")
            continue

        similarity = float(np.dot(user_embedding, embedding) / denominator)
        similarities.append((message, similarity))

    similarities.sort(key=lambda x: x[1], reverse=True)
    return [msg for msg, _ in similarities[:top_n]]