# oncall-guide-ai / tests/test_embedding_validation.py
# YanBoChen
# Remove obsolete embedding and index files; add comprehensive embedding test analysis and validation suite
# 775f8ea
# raw | history | blame | 9.27 kB
"""
Test suite for validating embeddings and ANNOY functionality.
This module ensures the quality of embeddings and the correctness of ANNOY search.
"""
import numpy as np
import json
import logging
import os
from pathlib import Path
from typing import Tuple, List, Optional
from annoy import AnnoyIndex
from sentence_transformers import SentenceTransformer
class TestEmbeddingValidation:
def setup_class(self):
"""Initialize test environment with necessary data and models."""
# Setup logging
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(levelname)s - %(message)s',
filename='embedding_validation.log'
)
self.logger = logging.getLogger(__name__)
# Define base paths
self.project_root = Path(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
self.models_dir = self.project_root / "models"
self.embeddings_dir = self.models_dir / "embeddings"
self.indices_dir = self.models_dir / "indices" / "annoy"
self.logger.info(f"Project root: {self.project_root}")
self.logger.info(f"Models directory: {self.models_dir}")
self.logger.info(f"Embeddings directory: {self.embeddings_dir}")
try:
# Check directory existence
if not self.embeddings_dir.exists():
raise FileNotFoundError(f"Embeddings directory not found at: {self.embeddings_dir}")
if not self.indices_dir.exists():
raise FileNotFoundError(f"Indices directory not found at: {self.indices_dir}")
# Load embeddings
self.emergency_emb = np.load(self.embeddings_dir / "emergency_embeddings.npy")
self.treatment_emb = np.load(self.embeddings_dir / "treatment_embeddings.npy")
# Load chunks
with open(self.embeddings_dir / "emergency_chunks.json", 'r') as f:
self.emergency_chunks = json.load(f)
with open(self.embeddings_dir / "treatment_chunks.json", 'r') as f:
self.treatment_chunks = json.load(f)
# Initialize model
self.model = SentenceTransformer("NeuML/pubmedbert-base-embeddings")
self.logger.info("Test environment initialized successfully")
self.logger.info(f"Emergency embeddings shape: {self.emergency_emb.shape}")
self.logger.info(f"Treatment embeddings shape: {self.treatment_emb.shape}")
except FileNotFoundError as e:
self.logger.error(f"File not found: {e}")
raise
except Exception as e:
self.logger.error(f"Error during initialization: {e}")
raise
def _safe_search(
self,
index: AnnoyIndex,
query_vector: np.ndarray,
k: int = 5
) -> Tuple[Optional[List[int]], Optional[List[float]]]:
"""Safe search wrapper with error handling"""
try:
indices, distances = index.get_nns_by_vector(
query_vector, k, include_distances=True
)
self.logger.debug(f"Search successful: found {len(indices)} results")
return indices, distances
except Exception as e:
self.logger.error(f"Search failed: {str(e)}")
return None, None
def test_embedding_dimensions(self):
"""Test embedding dimensions and data quality."""
self.logger.info("\n=== Embedding Validation Report ===")
try:
# Basic dimension checks
assert self.emergency_emb.shape[1] == 768, "Emergency embedding dimension should be 768"
assert self.treatment_emb.shape[1] == 768, "Treatment embedding dimension should be 768"
# Count verification
assert len(self.emergency_chunks) == self.emergency_emb.shape[0], \
"Emergency chunks count mismatch"
assert len(self.treatment_chunks) == self.treatment_emb.shape[0], \
"Treatment chunks count mismatch"
# Data quality checks
for name, emb in [("Emergency", self.emergency_emb),
("Treatment", self.treatment_emb)]:
# Check for NaN and Inf
assert not np.isnan(emb).any(), f"{name} contains NaN values"
assert not np.isinf(emb).any(), f"{name} contains Inf values"
# Value distribution analysis
self.logger.info(f"\n{name} Embeddings Statistics:")
self.logger.info(f"- Range: {np.min(emb):.3f} to {np.max(emb):.3f}")
self.logger.info(f"- Mean: {np.mean(emb):.3f}")
self.logger.info(f"- Std: {np.std(emb):.3f}")
self.logger.info("\n✅ All embedding validations passed")
except AssertionError as e:
self.logger.error(f"Validation failed: {str(e)}")
raise
def test_multiple_known_item_search(self):
"""Test ANNOY search with multiple random samples."""
self.logger.info("\n=== Multiple Known-Item Search Test ===")
emergency_index = AnnoyIndex(768, 'angular')
emergency_index.load(str(self.indices_dir / "emergency_index.ann"))
# Test 20 random samples
test_indices = np.random.choice(
self.emergency_emb.shape[0],
size=20,
replace=False
)
success_count = 0
for test_idx in test_indices:
try:
test_emb = self.emergency_emb[test_idx]
indices, distances = self._safe_search(emergency_index, test_emb)
if indices is None:
continue
# Verify self-retrieval
assert indices[0] == test_idx, f"Self-retrieval failed for index {test_idx}"
assert distances[0] < 0.0001, f"Self-distance too large for index {test_idx}"
success_count += 1
except AssertionError as e:
self.logger.warning(f"Test failed for index {test_idx}: {str(e)}")
self.logger.info(f"\n✅ {success_count}/20 self-retrieval tests passed")
assert success_count >= 18, "Less than 90% of self-retrieval tests passed"
def test_balanced_cross_dataset_search(self):
"""Test search across both emergency and treatment datasets."""
self.logger.info("\n=== Balanced Cross-Dataset Search Test ===")
# Initialize indices
emergency_index = AnnoyIndex(768, 'angular')
treatment_index = AnnoyIndex(768, 'angular')
try:
emergency_index.load(str(self.indices_dir / "emergency_index.ann"))
treatment_index.load(str(self.indices_dir / "treatment_index.ann"))
# Test queries
test_queries = [
"What is the treatment protocol for acute myocardial infarction?",
"How to manage severe chest pain with difficulty breathing?",
"What are the emergency procedures for anaphylactic shock?"
]
for query in test_queries:
print(f"\n\n=== Query: {query} ===")
# Generate query vector
query_emb = self.model.encode([query])[0]
# Get top-5 results from each dataset
e_indices, e_distances = self._safe_search(emergency_index, query_emb, k=5)
t_indices, t_distances = self._safe_search(treatment_index, query_emb, k=5)
if None in [e_indices, e_distances, t_indices, t_distances]:
self.logger.error("Search failed for one or both datasets")
continue
# Print first sentence of each result
print("\nEmergency Dataset Results:")
for i, (idx, dist) in enumerate(zip(e_indices, e_distances), 1):
text = self.emergency_chunks[idx]['text']
first_sentence = text.split('.')[0] + '.'
print(f"\nE-{i} (distance: {dist:.3f}):")
print(first_sentence)
print("\nTreatment Dataset Results:")
for i, (idx, dist) in enumerate(zip(t_indices, t_distances), 1):
text = self.treatment_chunks[idx]['text']
first_sentence = text.split('.')[0] + '.'
print(f"\nT-{i} (distance: {dist:.3f}):")
print(first_sentence)
except Exception as e:
self.logger.error(f"Test failed: {str(e)}")
raise
else:
self.logger.info("\n✅ Cross-dataset search test completed")
if __name__ == "__main__":
    # Manual run without pytest: prepare the shared fixtures once, then
    # execute every check in the order the class defines them.
    suite = TestEmbeddingValidation()
    suite.setup_class()
    for check in (
        suite.test_embedding_dimensions,
        suite.test_multiple_known_item_search,
        suite.test_balanced_cross_dataset_search,
    ):
        check()