# csc525_retrieval_based_chatbot / quality_metrics.py
import torch
import tensorflow as tf
import tensorflow_hub as hub
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import language_tool_python
from rouge_score import rouge_scorer
import spacy
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from typing import Dict
from pipeline_config import PipelineConfig


class QualityMetrics:
    """
    Measure the quality of augmented text against the original.
    """
def __init__(self, config: PipelineConfig):
self.config = config
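        # Note: hub.load() and from_pretrained() download and cache their models on
        # first use; the spaCy model 'en_core_web_sm' must already be installed.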
# Semantic similarity
self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
# Fluency metrics
self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
self.model = GPT2LMHeadModel.from_pretrained('gpt2')
self.model.eval()
# Grammar
self.language_tool = language_tool_python.LanguageTool('en-US')
# Lexical similarity
self.rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
# Diversity
self.nlp = spacy.load('en_core_web_sm')

    def compute_perplexity(self, text: str) -> float:
        """
        Compute GPT-2 perplexity for a text; lower values indicate more fluent output.
        """
try:
encodings = self.tokenizer(text, return_tensors='pt')
input_ids = encodings['input_ids']
with torch.no_grad():
outputs = self.model(input_ids, labels=input_ids)
loss = outputs.loss
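                # Perplexity is exp(mean token cross-entropy); lower means more fluent text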
perplexity = torch.exp(loss)
return perplexity.item()
except Exception as e:
print(f"Error computing perplexity for text '{text}': {e}")
return float('inf') # High perplexity value == poor quality

    def compute_semantic_similarity(self, text1: str, text2: str) -> float:
"""
Compute semantic similarity between two texts using the Universal Sentence Encoder.
Args:
text1 (str): First text
text2 (str): Second text
Returns:
            float: Cosine similarity between the two embeddings (in [-1, 1]; higher means more similar)
"""
embeddings = self.use_model([text1, text2])
emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
return cosine_similarity([emb1], [emb2])[0][0]

    def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
        """
        Compute quality metrics comparing the augmented text with the original:
        semantic similarity, fluency (perplexity, grammar errors), lexical diversity,
        content preservation, ROUGE overlap, and length ratio.
        """
metrics = {}
        # 1. Semantic Preservation (reuse the USE-based similarity helper)
        metrics['semantic_similarity'] = self.compute_semantic_similarity(original, augmented)
# 2. Fluency & Naturalness
metrics['perplexity'] = self.compute_perplexity(augmented)
metrics['grammar_errors'] = len(self.language_tool.check(augmented))
# 3. Lexical Diversity
doc_orig = self.nlp(original)
doc_aug = self.nlp(augmented)
# Type-token ratio with safety check
aug_tokens = [token.text.lower() for token in doc_aug]
metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)
# Content word overlap with safety checks
orig_content = set([token.text.lower() for token in doc_orig if not token.is_stop])
aug_content = set([token.text.lower() for token in doc_aug if not token.is_stop])
# Safety check for empty content sets
if len(orig_content) == 0:
metrics['content_preservation'] = 1.0 if len(aug_content) == 0 else 0.0
else:
metrics['content_preservation'] = len(orig_content.intersection(aug_content)) / len(orig_content)
# 4. Structural Preservation
rouge_scores = self.rouge.score(original, augmented)
metrics['rouge1_f1'] = rouge_scores['rouge1'].fmeasure
metrics['rouge2_f1'] = rouge_scores['rouge2'].fmeasure
metrics['rougeL_f1'] = rouge_scores['rougeL'].fmeasure
# 5. Length Preservation with safety check
orig_words = len(original.split())
aug_words = len(augmented.split())
metrics['length_ratio'] = aug_words / max(orig_words, 1)
return metrics

    def meets_quality_threshold(self, metrics: Dict[str, float]) -> bool:
        """
        Check whether the computed metrics satisfy all configured quality thresholds.
        """
# Core quality checks
basic_quality = (
metrics['perplexity'] <= self.config.perplexity_threshold and
metrics['semantic_similarity'] >= self.config.semantic_similarity_threshold and
metrics['grammar_errors'] <= self.config.grammar_error_threshold
)
# Length preservation check
length_ok = 0.6 <= metrics['length_ratio'] <= 1.4
# Diversity check
diversity_ok = metrics['type_token_ratio'] >= 0.4
# Content preservation check
content_ok = metrics['content_preservation'] >= 0.6
return all([basic_quality, length_ok, diversity_ok, content_ok])
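

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original pipeline).
# PipelineConfig's constructor is not visible in this file, so a
# SimpleNamespace with hypothetical threshold values stands in for it; only
# the three attributes read by meets_quality_threshold() are assumed.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    demo_config = SimpleNamespace(
        perplexity_threshold=100.0,          # hypothetical threshold
        semantic_similarity_threshold=0.75,  # hypothetical threshold
        grammar_error_threshold=2,           # hypothetical threshold
    )
    qm = QualityMetrics(config=demo_config)  # type: ignore[arg-type]

    original = "How do I reset my password?"
    augmented = "What steps should I follow to reset my password?"

    scores = qm.compute_metrics(original, augmented)
    for name, value in scores.items():
        print(f"{name}: {value:.3f}")
    print("meets_quality_threshold:", qm.meets_quality_threshold(scores))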