|
import torch |
|
import tensorflow as tf |
|
import tensorflow_hub as hub |
|
from transformers import GPT2TokenizerFast, GPT2LMHeadModel |
|
import language_tool_python |
|
from rouge_score import rouge_scorer |
|
import spacy |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
import numpy as np |
|
from typing import Dict |
|
from pipeline_config import PipelineConfig |
|
|
|
class QualityMetrics:
    """Measure the quality of an augmented text relative to its original.

    The constructor loads every scoring resource once (sentence encoder,
    GPT-2 language model, grammar checker, ROUGE scorer, spaCy pipeline);
    the methods then score individual texts or (original, augmented) pairs.
    """

    # Acceptance bounds used by ``meets_quality_threshold``. They are class
    # attributes so a subclass (or an instance) can override them without
    # touching the method; the values are the ones previously hard-coded.
    MIN_LENGTH_RATIO = 0.6
    MAX_LENGTH_RATIO = 1.4
    MIN_TYPE_TOKEN_RATIO = 0.4
    MIN_CONTENT_PRESERVATION = 0.6

    # The annotation is quoted so it is not evaluated when the class is defined.
    def __init__(self, config: "PipelineConfig"):
        """Store the configuration and load all scoring resources.

        Args:
            config: Pipeline configuration. ``meets_quality_threshold`` reads
                ``perplexity_threshold``, ``semantic_similarity_threshold``
                and ``grammar_error_threshold`` from it.
        """
        self.config = config

        # Universal Sentence Encoder, used for semantic similarity.
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')

        # GPT-2 tokenizer and language model, used for perplexity.
        self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')
        self.model.eval()

        # Grammar checker whose reported matches are counted as errors.
        self.language_tool = language_tool_python.LanguageTool('en-US')

        # ROUGE-1 / ROUGE-2 / ROUGE-L scorer with stemming enabled.
        self.rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

        # spaCy pipeline, used for tokenization and stop-word flags.
        self.nlp = spacy.load('en_core_web_sm')

    def compute_perplexity(self, text: str) -> float:
        """Return the GPT-2 perplexity of ``text``.

        The text is truncated to the model's maximum number of positions so
        that long inputs are scored on their leading tokens instead of
        failing inside the model.

        Args:
            text: Text to score.

        Returns:
            ``exp(loss)`` of the language model on the text, or
            ``float('inf')`` when the text has fewer than two tokens
            (perplexity is undefined without a next-token target) or when
            scoring raises an exception.
        """
        try:
            encodings = self.tokenizer(
                text,
                return_tensors='pt',
                truncation=True,
                max_length=self.model.config.n_positions,
            )
            input_ids = encodings['input_ids']

            # With fewer than two tokens there is no (context, target) pair;
            # the model loss would be NaN, so report the failure sentinel.
            if input_ids.shape[-1] < 2:
                return float('inf')

            with torch.no_grad():
                outputs = self.model(input_ids, labels=input_ids)

            return torch.exp(outputs.loss).item()
        except Exception as e:
            # Best-effort metric: report the problem and return a value that
            # can never satisfy a "perplexity <= threshold" check.
            print(f"Error computing perplexity for text '{text}': {e}")
            return float('inf')

    def compute_semantic_similarity(self, text1: str, text2: str) -> float:
        """Compute semantic similarity between two texts.

        Both texts are embedded with the Universal Sentence Encoder and the
        cosine similarity of the two embeddings is returned.

        Args:
            text1: First text.
            text2: Second text.

        Returns:
            Cosine similarity of the two embeddings as a built-in ``float``
            (higher means more similar).
        """
        embeddings = self.use_model([text1, text2])
        emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
        # cosine_similarity yields a NumPy scalar; convert it so the result
        # matches the annotated return type.
        return float(cosine_similarity([emb1], [emb2])[0][0])

    def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
        """Compute all quality metrics for an (original, augmented) pair.

        Args:
            original: Source text.
            augmented: Augmented variant of ``original``.

        Returns:
            A dict with the keys ``semantic_similarity``, ``perplexity``,
            ``grammar_errors``, ``type_token_ratio``,
            ``content_preservation``, ``rouge1_f1``, ``rouge2_f1``,
            ``rougeL_f1`` and ``length_ratio``.
        """
        metrics = {}

        # Semantic similarity between the two texts.
        metrics['semantic_similarity'] = self.compute_semantic_similarity(original, augmented)

        # Fluency and grammar of the augmented text only.
        metrics['perplexity'] = self.compute_perplexity(augmented)
        metrics['grammar_errors'] = len(self.language_tool.check(augmented))

        doc_orig = self.nlp(original)
        doc_aug = self.nlp(augmented)

        # Lexical diversity: distinct lower-cased tokens over all tokens
        # (the denominator is at least 1, so empty text yields 0.0).
        aug_tokens = [token.text.lower() for token in doc_aug]
        metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)

        # Content preservation: share of the original's distinct non-stop
        # tokens that also occur in the augmented text.
        orig_content = {token.text.lower() for token in doc_orig if not token.is_stop}
        aug_content = {token.text.lower() for token in doc_aug if not token.is_stop}

        if not orig_content:
            # Nothing to preserve: perfect only if the augmentation also has
            # no non-stop tokens.
            metrics['content_preservation'] = 1.0 if not aug_content else 0.0
        else:
            metrics['content_preservation'] = len(orig_content & aug_content) / len(orig_content)

        # ROUGE F-measures of the augmented text against the original.
        rouge_scores = self.rouge.score(original, augmented)
        metrics['rouge1_f1'] = rouge_scores['rouge1'].fmeasure
        metrics['rouge2_f1'] = rouge_scores['rouge2'].fmeasure
        metrics['rougeL_f1'] = rouge_scores['rougeL'].fmeasure

        # Length ratio based on whitespace-separated word counts.
        orig_words = len(original.split())
        aug_words = len(augmented.split())
        metrics['length_ratio'] = aug_words / max(orig_words, 1)

        return metrics

    def meets_quality_threshold(self, metrics: Dict[str, float]) -> bool:
        """Decide whether a metrics dict passes every quality check.

        Args:
            metrics: Dict containing at least ``perplexity``,
                ``semantic_similarity``, ``grammar_errors``,
                ``length_ratio``, ``type_token_ratio`` and
                ``content_preservation``.

        Returns:
            ``True`` only if the config thresholds (perplexity, semantic
            similarity, grammar errors) are met and the length ratio, type
            token ratio and content preservation lie within the class-level
            bounds; ``False`` otherwise.
        """
        basic_quality = (
            metrics['perplexity'] <= self.config.perplexity_threshold
            and metrics['semantic_similarity'] >= self.config.semantic_similarity_threshold
            and metrics['grammar_errors'] <= self.config.grammar_error_threshold
        )

        # The augmented text must not be much shorter or longer.
        length_ok = self.MIN_LENGTH_RATIO <= metrics['length_ratio'] <= self.MAX_LENGTH_RATIO

        # Reject highly repetitive output.
        diversity_ok = metrics['type_token_ratio'] >= self.MIN_TYPE_TOKEN_RATIO

        # Enough of the original's non-stop vocabulary must survive.
        content_ok = metrics['content_preservation'] >= self.MIN_CONTENT_PRESERVATION

        # bool() guarantees a built-in bool even when metrics hold NumPy scalars.
        return bool(basic_quality and length_ok and diversity_ok and content_ok)