import logging
import math
from typing import Dict

import language_tool_python
import numpy as np
import spacy
import tensorflow as tf
import tensorflow_hub as hub
import torch
from rouge_score import rouge_scorer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

from pipeline_config import PipelineConfig

logger = logging.getLogger(__name__)


class QualityMetrics:
    """Measure augmented text quality relative to the original text.

    Combines semantic similarity (Universal Sentence Encoder), fluency
    (GPT-2 perplexity, LanguageTool grammar matches), lexical diversity /
    overlap (type-token ratio, content-word overlap, ROUGE) and length
    preservation, and checks a metrics dict against quality thresholds.
    """

    # Defaults for the secondary thresholds used by meets_quality_threshold.
    # A PipelineConfig may override any of them by defining an attribute
    # named like the lower-cased suffix (e.g. `config.min_length_ratio`).
    DEFAULT_MIN_LENGTH_RATIO = 0.6
    DEFAULT_MAX_LENGTH_RATIO = 1.4
    DEFAULT_MIN_TYPE_TOKEN_RATIO = 0.4
    DEFAULT_MIN_CONTENT_PRESERVATION = 0.6

    def __init__(self, config: PipelineConfig):
        """Load every model the metrics need.

        Args:
            config: Pipeline configuration; must provide
                ``perplexity_threshold``, ``semantic_similarity_threshold``
                and ``grammar_error_threshold``.
        """
        self.config = config

        # Semantic similarity
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')

        # Fluency metrics
        self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')
        self.model.eval()  # inference only: disable dropout

        # Grammar
        self.language_tool = language_tool_python.LanguageTool('en-US')

        # Lexical similarity
        self.rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

        # Diversity
        self.nlp = spacy.load('en_core_web_sm')

    def compute_perplexity(self, text: str) -> float:
        """Return the GPT-2 perplexity of *text* (lower means more fluent).

        Text longer than the model's context window (``n_positions``) is
        truncated to it, because GPT-2 cannot score more positions than that.

        Returns:
            ``exp(mean next-token cross-entropy)``. Returns ``float('inf')``
            in three cases. The first is when perplexity is undefined: fewer
            than two tokens leaves no next-token prediction to average. The
            second is when the result is not finite. The third is when
            scoring raises. ``inf`` always exceeds the perplexity threshold,
            so such text is treated as poor quality.
        """
        try:
            encodings = self.tokenizer(
                text,
                return_tensors='pt',
                truncation=True,
                max_length=self.model.config.n_positions,
            )
            input_ids = encodings['input_ids']
            # The LM loss is computed on shifted targets; with < 2 tokens it
            # averages over nothing and would come out NaN.
            if input_ids.size(1) < 2:
                return float('inf')
            with torch.no_grad():
                outputs = self.model(input_ids, labels=input_ids)
            perplexity = torch.exp(outputs.loss).item()
        except Exception:
            # Best-effort metric: log and fall back to the "poor quality" value.
            logger.exception("Error computing perplexity for text %.200r", text)
            return float('inf')
        return perplexity if math.isfinite(perplexity) else float('inf')

    def compute_semantic_similarity(self, text1: str, text2: str) -> float:
        """Compute the semantic similarity of two texts with the Universal Sentence Encoder.

        Args:
            text1: First text.
            text2: Second text.

        Returns:
            Cosine similarity of the two sentence embeddings, in [-1, 1].
            Higher values mean the texts are closer in meaning.
        """
        embeddings = self.use_model([text1, text2])
        emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
        return float(cosine_similarity([emb1], [emb2])[0][0])

    @staticmethod
    def _content_words(doc) -> set:
        """Lower-cased texts of tokens that are not stop words, punctuation or whitespace."""
        return {
            token.text.lower()
            for token in doc
            if not (token.is_stop or token.is_punct or token.is_space)
        }

    def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
        """Compute quality metrics for *augmented* with respect to *original*.

        Returns:
            Dict with keys:
            ``semantic_similarity``, ``perplexity`` (of *augmented*),
            ``grammar_errors`` (LanguageTool match count for *augmented*),
            ``type_token_ratio`` (unique / total lower-cased spaCy tokens of
            *augmented*), ``content_preservation`` (fraction of the
            original's content words also present in *augmented*),
            ``rouge1_f1``, ``rouge2_f1``, ``rougeL_f1`` and ``length_ratio``
            (whitespace-word count of *augmented* / *original*).
        """
        metrics: Dict[str, float] = {}

        # 1. Semantic preservation
        metrics['semantic_similarity'] = self.compute_semantic_similarity(original, augmented)

        # 2. Fluency & naturalness
        metrics['perplexity'] = self.compute_perplexity(augmented)
        metrics['grammar_errors'] = len(self.language_tool.check(augmented))

        # 3. Lexical diversity
        doc_orig = self.nlp(original)
        doc_aug = self.nlp(augmented)

        # Type-token ratio; max(..., 1) guards against empty text.
        aug_tokens = [token.text.lower() for token in doc_aug]
        metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)

        # Content-word overlap. Punctuation/whitespace are excluded so that
        # shared symbols like "." do not count as preserved content.
        orig_content = self._content_words(doc_orig)
        aug_content = self._content_words(doc_aug)
        if not orig_content:
            # Nothing to preserve: perfect only if augmented added no content either.
            metrics['content_preservation'] = 1.0 if not aug_content else 0.0
        else:
            metrics['content_preservation'] = len(orig_content & aug_content) / len(orig_content)

        # 4. Structural preservation (original is the ROUGE target)
        rouge_scores = self.rouge.score(original, augmented)
        metrics['rouge1_f1'] = rouge_scores['rouge1'].fmeasure
        metrics['rouge2_f1'] = rouge_scores['rouge2'].fmeasure
        metrics['rougeL_f1'] = rouge_scores['rougeL'].fmeasure

        # 5. Length preservation; max(..., 1) guards against empty original.
        orig_words = len(original.split())
        aug_words = len(augmented.split())
        metrics['length_ratio'] = aug_words / max(orig_words, 1)

        return metrics

    def meets_quality_threshold(self, metrics: Dict[str, float]) -> bool:
        """Return True if *metrics* (from compute_metrics) pass every quality check.

        Core thresholds come from the config. The length, diversity and
        content thresholds default to the class-level DEFAULT_* values. The
        config may override them through ``min_length_ratio``,
        ``max_length_ratio``, ``min_type_token_ratio`` and
        ``min_content_preservation`` attributes.

        Raises:
            KeyError: if a required metric is missing from *metrics*.
        """
        cfg = self.config
        min_length_ratio = getattr(cfg, 'min_length_ratio', self.DEFAULT_MIN_LENGTH_RATIO)
        max_length_ratio = getattr(cfg, 'max_length_ratio', self.DEFAULT_MAX_LENGTH_RATIO)
        min_ttr = getattr(cfg, 'min_type_token_ratio', self.DEFAULT_MIN_TYPE_TOKEN_RATIO)
        min_content = getattr(cfg, 'min_content_preservation', self.DEFAULT_MIN_CONTENT_PRESERVATION)

        # Core quality checks
        basic_quality = (
            metrics['perplexity'] <= cfg.perplexity_threshold
            and metrics['semantic_similarity'] >= cfg.semantic_similarity_threshold
            and metrics['grammar_errors'] <= cfg.grammar_error_threshold
        )
        length_ok = min_length_ratio <= metrics['length_ratio'] <= max_length_ratio
        diversity_ok = metrics['type_token_ratio'] >= min_ttr
        content_ok = metrics['content_preservation'] >= min_content

        return basic_quality and length_ok and diversity_ok and content_ok