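"""Quality metrics for augmented text.

Scores an augmented text against its original on semantic similarity
(Universal Sentence Encoder), fluency (GPT-2 perplexity), grammar
(LanguageTool), lexical diversity, and ROUGE overlap, then applies the
thresholds defined in PipelineConfig.
"""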
import torch
import tensorflow_hub as hub
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import language_tool_python
from rouge_score import rouge_scorer
import spacy
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict
from pipeline_config import PipelineConfig

class QualityMetrics:
    """
    Measure the quality of augmented text relative to its original across
    semantic, fluency, grammar, diversity, and structural dimensions.
    """
    def __init__(self, config: PipelineConfig):
        self.config = config
        
        # Semantic similarity
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
        
        # Fluency metrics
        self.tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
        self.model = GPT2LMHeadModel.from_pretrained('gpt2')
        self.model.eval()
        
        # Grammar
        self.language_tool = language_tool_python.LanguageTool('en-US')
        
        # Lexical similarity
        self.rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        
        # Diversity
        self.nlp = spacy.load('en_core_web_sm')
        
    def compute_perplexity(self, text: str) -> float:
        """
        Compute GPT-2 perplexity for a text; lower values indicate more
        fluent, natural-sounding text.
        """
        try:
            encodings = self.tokenizer(text, return_tensors='pt')
            input_ids = encodings['input_ids']
            with torch.no_grad():
                # Passing the inputs as labels makes the model return the
                # cross-entropy loss; perplexity is exp(loss).
                outputs = self.model(input_ids, labels=input_ids)
                perplexity = torch.exp(outputs.loss)
            return perplexity.item()
        except Exception as e:
            print(f"Error computing perplexity for text '{text}': {e}")
            return float('inf')  # Treat failures as maximally poor fluency
        
    def compute_semantic_similarity(self, text1: str, text2: str) -> float:
        """
        Compute semantic similarity between two texts using the Universal Sentence Encoder.
        Args:
            text1 (str): First text
            text2 (str): Second text
        Returns:
            float: Cosine similarity between the two embeddings, in [-1, 1]
        """
        embeddings = self.use_model([text1, text2])
        emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
        return float(cosine_similarity([emb1], [emb2])[0][0])

    def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
        """
        Compute all quality metrics for an original/augmented text pair.
        """
        metrics = {}
        
        # 1. Semantic Preservation
        metrics['semantic_similarity'] = self.compute_semantic_similarity(original, augmented)
        
        # 2. Fluency & Naturalness
        metrics['perplexity'] = self.compute_perplexity(augmented)
        metrics['grammar_errors'] = len(self.language_tool.check(augmented))
        
        # 3. Lexical Diversity
        doc_orig = self.nlp(original)
        doc_aug = self.nlp(augmented)
        
        # Type-token ratio with safety check
        aug_tokens = [token.text.lower() for token in doc_aug]
        metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)
        
        # Content word overlap with safety checks
        orig_content = {token.text.lower() for token in doc_orig if not token.is_stop}
        aug_content = {token.text.lower() for token in doc_aug if not token.is_stop}
        
        # Safety check for empty content sets
        if len(orig_content) == 0:
            metrics['content_preservation'] = 1.0 if len(aug_content) == 0 else 0.0
        else:
            metrics['content_preservation'] = len(orig_content.intersection(aug_content)) / len(orig_content)
        
        # 4. Structural Preservation
        rouge_scores = self.rouge.score(original, augmented)
        metrics['rouge1_f1'] = rouge_scores['rouge1'].fmeasure
        metrics['rouge2_f1'] = rouge_scores['rouge2'].fmeasure
        metrics['rougeL_f1'] = rouge_scores['rougeL'].fmeasure
        
        # 5. Length Preservation with safety check
        orig_words = len(original.split())
        aug_words = len(augmented.split())
        metrics['length_ratio'] = aug_words / max(orig_words, 1)
        
        return metrics

    def meets_quality_threshold(self, metrics: Dict[str, float]) -> bool:
        """
        Return True when all quality gates pass: the configured perplexity,
        semantic-similarity, and grammar thresholds, plus fixed bounds on
        length ratio, lexical diversity, and content preservation.
        """
        # Core quality checks
        basic_quality = (
            metrics['perplexity'] <= self.config.perplexity_threshold and
            metrics['semantic_similarity'] >= self.config.semantic_similarity_threshold and
            metrics['grammar_errors'] <= self.config.grammar_error_threshold
        )
        
        # Length preservation check
        length_ok = 0.6 <= metrics['length_ratio'] <= 1.4
        
        # Diversity check
        diversity_ok = metrics['type_token_ratio'] >= 0.4
        
        # Content preservation check
        content_ok = metrics['content_preservation'] >= 0.6
        
        return all([basic_quality, length_ok, diversity_ok, content_ok])
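
# A minimal usage sketch (illustrative, not part of the original module):
# drives compute_metrics and meets_quality_threshold end to end. The
# PipelineConfig keyword arguments below are hypothetical; only the three
# threshold attributes read above are known to exist, and the values chosen
# here are assumptions, not values from the source.
if __name__ == '__main__':
    config = PipelineConfig(
        perplexity_threshold=200.0,          # assumed illustrative value
        semantic_similarity_threshold=0.8,   # assumed illustrative value
        grammar_error_threshold=2,           # assumed illustrative value
    )
    qm = QualityMetrics(config)
    original = "The quick brown fox jumps over the lazy dog."
    augmented = "A quick brown fox leaps over the lazy dog."
    metrics = qm.compute_metrics(original, augmented)
    for name, value in metrics.items():
        print(f"{name}: {value:.3f}")
    print("meets quality threshold:", qm.meets_quality_threshold(metrics))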