Spaces:
Runtime error
Runtime error
import json | |
import numpy as np | |
from app.core.embedding import generate_embedding | |
from app.core.db_setup import conversations_collection | |
from app.core.logging_setup import logger | |
# DB_PATH was removed: conversations are now read from MongoDB (conversations_collection).
# Similarity search over conversations stored in MongoDB. Each stored document
# is expected to carry a "message" field and an "embedding" field; the
# embedding is normally a JSON-encoded list of numbers (decoded with json.loads).


def _load_conversations():
    """Fetch ``message`` and ``embedding`` for every stored conversation.

    Returns:
        list: The projected documents, or ``[]`` if the query failed. The
        error is logged, so callers degrade to "no matches".
    """
    try:
        # list() drains the cursor inside the try, so errors raised while
        # iterating it are handled here too, not only errors from find().
        return list(
            conversations_collection.find(
                {}, {"message": 1, "embedding": 1, "_id": 0}
            )
        )
    except Exception as e:  # boundary: any DB failure -> empty result
        logger.error(f"Unexpected error: {e}")
        return []


def _parse_embedding(raw):
    """Decode a stored embedding into a flat 1-D float array.

    Accepts the JSON-encoded string form as well as an already-decoded
    sequence of numbers.

    Raises:
        TypeError: If the embedding is missing or of an unconvertible type.
        ValueError: If the JSON is invalid or the values are not numeric.
    """
    if raw is None:
        raise TypeError("embedding is missing")
    if isinstance(raw, (str, bytes, bytearray)):
        raw = json.loads(raw)  # json.JSONDecodeError is a ValueError
    return np.asarray(raw, dtype=float).ravel()


def _rank_messages(query_embedding, conversations, top_n):
    """Return the messages of the ``top_n`` conversations closest to the query.

    Closeness is cosine similarity between ``query_embedding`` and each stored
    embedding, highest first. Equal scores keep database order.

    A document is skipped, with a warning, when any of these hold:

    * a field is missing;
    * its embedding cannot be decoded;
    * its embedding length differs from the query's;
    * its embedding norm is zero or non-finite.

    A NaN score would make the sort order meaningless, hence the last rule.
    Returns ``[]`` if the query vector itself has zero or non-finite norm.
    """
    # Flatten so a (1, d) row vector is treated the same as a (d,) vector.
    query = np.asarray(query_embedding, dtype=float).ravel()
    query_norm = np.linalg.norm(query)  # hoisted: invariant across documents
    if not (np.isfinite(query_norm) and query_norm > 0):
        logger.warning(
            "Query embedding has zero or non-finite norm; skipping similarity search"
        )
        return []

    scored = []
    for conversation in conversations:
        try:
            message = conversation["message"]
            embedding = _parse_embedding(conversation["embedding"])
        except (KeyError, TypeError, ValueError) as e:
            logger.warning(f"Skipping malformed conversation document: {e}")
            continue
        if embedding.shape != query.shape:
            logger.warning(
                f"Skipping embedding due to dimension mismatch: "
                f"user_embedding {query.shape[0]} vs embedding {embedding.shape[0]}"
            )
            continue
        norm = np.linalg.norm(embedding)
        if not (np.isfinite(norm) and norm > 0):
            logger.warning("Skipping embedding with zero or non-finite norm")
            continue
        similarity = float(np.dot(query, embedding) / (query_norm * norm))
        scored.append((message, similarity))

    # list.sort is stable even with reverse=True, so ties keep database order.
    scored.sort(key=lambda item: item[1], reverse=True)
    return [message for message, _ in scored[:top_n]]


def find_similar_conversations(query_message, top_n=5):
    """Find stored conversations similar to ``query_message``.

    Args:
        query_message: Text to embed and compare against stored conversations.
        top_n: Maximum number of messages to return (default 5).

    Returns:
        list: Stored messages ordered from most to least similar. Returns
        ``[]`` if no embedding could be generated for the query or the
        database read failed.
    """
    query_embedding = generate_embedding(query_message)
    if query_embedding is None:
        return []
    return _rank_messages(query_embedding, _load_conversations(), top_n)


def get_similar_conversations(user_input, top_n=3):
    """Retrieve past relevant messages using semantic similarity.

    Args:
        user_input: Text to embed and compare against stored conversations.
        top_n: Maximum number of messages to return (default 3).

    Returns:
        list: Stored messages ordered from most to least similar. Returns
        ``[]`` if no embedding could be generated for the input or the
        database read failed.
    """
    user_embedding = generate_embedding(user_input)
    if user_embedding is None:
        return []
    return _rank_messages(user_embedding, _load_conversations(), top_n)