File size: 2,015 Bytes
3190e1e
 
 
 
 
 
 
 
300fe5d
3190e1e
 
 
 
300fe5d
 
3190e1e
 
 
 
 
 
 
300fe5d
3190e1e
 
 
 
300fe5d
3190e1e
 
300fe5d
3190e1e
 
 
300fe5d
 
3190e1e
 
 
 
 
300fe5d
3190e1e
 
 
 
300fe5d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import tensorflow_hub as hub
import spacy
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict
from pipeline_config import PipelineConfig

class QualityMetrics:
    """
    Quality metrics focusing on semantic similarity and basic lexical stats.

    Embeddings come from the Universal Sentence Encoder (TF Hub);
    tokenization and stop-word filtering come from spaCy.
    """
    def __init__(self, config: PipelineConfig):
        self.config = config
        # Heavy model loads — done once per instance (TF Hub may hit the
        # network/cache on first use; spaCy loads the md English pipeline).
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
        self.nlp = spacy.load('en_core_web_md')

    def compute_semantic_similarity(self, text1: str, text2: str) -> float:
        """Return the cosine similarity of the USE embeddings of the two texts.

        Args:
            text1: First text.
            text2: Second text.

        Returns:
            Cosine similarity in [-1, 1] as a plain Python float.
        """
        embeddings = self.use_model([text1, text2])
        emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
        # float(): cosine_similarity yields a numpy scalar; the annotation
        # promises a Python float, so cast explicitly.
        return float(cosine_similarity([emb1], [emb2])[0][0])

    def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
        """Compute quality metrics comparing an augmented text to its original.

        Args:
            original: Source text.
            augmented: Augmented/paraphrased text.

        Returns:
            Dict with keys:
              - 'semantic_similarity': USE-embedding cosine similarity.
              - 'type_token_ratio': unique/total tokens of the augmented text.
              - 'content_preservation': fraction of original non-stopword
                types retained in the augmented text.
              - 'length_ratio': augmented word count / original word count.
        """
        metrics = {}
        # Reuse the helper instead of duplicating the embedding + cosine
        # logic inline (keeps the two code paths from drifting apart).
        metrics['semantic_similarity'] = self.compute_semantic_similarity(original, augmented)

        # Lexical diversity & content preservation
        doc_orig = self.nlp(original)
        doc_aug = self.nlp(augmented)

        # Type/token ratio of the augmented text; max(..., 1) guards empty input.
        aug_tokens = [token.text.lower() for token in doc_aug]
        metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)

        # Content words = non-stopword token types (lowercased).
        orig_content = {token.text.lower() for token in doc_orig if not token.is_stop}
        aug_content = {token.text.lower() for token in doc_aug if not token.is_stop}
        if len(orig_content) == 0:
            # No content words to preserve: perfect score only if the
            # augmented text also has none.
            metrics['content_preservation'] = 1.0 if len(aug_content) == 0 else 0.0
        else:
            metrics['content_preservation'] = len(orig_content.intersection(aug_content)) / len(orig_content)

        # Length ratio on whitespace-split words; max(..., 1) guards empty original.
        orig_words = len(original.split())
        aug_words = len(augmented.split())
        metrics['length_ratio'] = aug_words / max(orig_words, 1)

        return metrics