# quality_metrics.py
from typing import Dict

import spacy
import tensorflow_hub as hub
from sklearn.metrics.pairwise import cosine_similarity

from pipeline_config import PipelineConfig


class QualityMetrics:
    """
    Quality metrics focusing on semantic similarity and basic lexical stats.
    """

    def __init__(self, config: PipelineConfig):
        self.config = config
        # Universal Sentence Encoder (v4) for sentence-level embeddings.
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
        # Medium English spaCy model: tokenization plus stop-word flags.
        self.nlp = spacy.load('en_core_web_md')

    def compute_semantic_similarity(self, text1: str, text2: str) -> float:
        """Cosine similarity between the USE embeddings of two texts."""
        embeddings = self.use_model([text1, text2])
        emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
        # Cast to a plain float so the return type matches the annotation.
        return float(cosine_similarity([emb1], [emb2])[0][0])

    def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
        """Compare an augmented text against its original."""
        metrics = {}

        # Semantic similarity: reuse the helper above rather than
        # duplicating the embedding-and-cosine logic inline.
        metrics['semantic_similarity'] = self.compute_semantic_similarity(original, augmented)

        # Lexical diversity & content preservation
        doc_orig = self.nlp(original)
        doc_aug = self.nlp(augmented)

        # Type-token ratio: unique tokens / total tokens in the augmented text.
        aug_tokens = [token.text.lower() for token in doc_aug]
        metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)

        # Content preservation: fraction of the original's non-stop-word
        # types that also appear in the augmented text.
        orig_content = {token.text.lower() for token in doc_orig if not token.is_stop}
        aug_content = {token.text.lower() for token in doc_aug if not token.is_stop}
        if len(orig_content) == 0:
            # Degenerate case: an all-stop-word original counts as preserved
            # only if the augmented text is also all stop words.
            metrics['content_preservation'] = 1.0 if len(aug_content) == 0 else 0.0
        else:
            metrics['content_preservation'] = len(orig_content & aug_content) / len(orig_content)

        # Length ratio: augmented word count relative to the original's,
        # guarding against division by zero for empty input.
        orig_words = len(original.split())
        aug_words = len(augmented.split())
        metrics['length_ratio'] = aug_words / max(orig_words, 1)

        return metrics
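

# --- Usage sketch ---
# A minimal example of how QualityMetrics might be driven.
# Assumption: PipelineConfig is default-constructible; its actual fields
# live in pipeline_config.py and are not shown here.
if __name__ == '__main__':
    config = PipelineConfig()  # assumed default-constructible
    qm = QualityMetrics(config)
    metrics = qm.compute_metrics(
        original='The quick brown fox jumps over the lazy dog.',
        augmented='A fast brown fox leaps over a lazy dog.',
    )
    for name, value in metrics.items():
        print(f'{name}: {value:.3f}')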