import tensorflow_hub as hub
import spacy
from sklearn.metrics.pairwise import cosine_similarity
from typing import Dict

from pipeline_config import PipelineConfig


class QualityMetrics:
    """
    Quality metrics focusing on semantic similarity and basic lexical stats.
    """

    def __init__(self, config: PipelineConfig):
        self.config = config
        # Universal Sentence Encoder for sentence-level embeddings.
        self.use_model = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
        # spaCy pipeline for tokenization and stop-word detection.
        self.nlp = spacy.load('en_core_web_md')

    def compute_semantic_similarity(self, text1: str, text2: str) -> float:
        """Cosine similarity between USE embeddings of the two texts."""
        embeddings = self.use_model([text1, text2])
        emb1, emb2 = embeddings[0].numpy(), embeddings[1].numpy()
        return float(cosine_similarity([emb1], [emb2])[0][0])

    def compute_metrics(self, original: str, augmented: str) -> Dict[str, float]:
        metrics = {}

        # Semantic similarity between original and augmented text.
        metrics['semantic_similarity'] = self.compute_semantic_similarity(original, augmented)

        # spaCy parses for token-level statistics.
        doc_orig = self.nlp(original)
        doc_aug = self.nlp(augmented)

        # Type-token ratio of the augmented text (lexical diversity).
        aug_tokens = [token.text.lower() for token in doc_aug]
        metrics['type_token_ratio'] = len(set(aug_tokens)) / max(len(aug_tokens), 1)

        # Content preservation: fraction of the original's non-stop-word tokens
        # that also appear in the augmented text.
        orig_content = {token.text.lower() for token in doc_orig if not token.is_stop}
        aug_content = {token.text.lower() for token in doc_aug if not token.is_stop}
        if len(orig_content) == 0:
            metrics['content_preservation'] = 1.0 if len(aug_content) == 0 else 0.0
        else:
            metrics['content_preservation'] = len(orig_content.intersection(aug_content)) / len(orig_content)

        # Length ratio: augmented word count relative to the original (whitespace split).
        orig_words = len(original.split())
        aug_words = len(augmented.split())
        metrics['length_ratio'] = aug_words / max(orig_words, 1)

        return metrics
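

# Minimal usage sketch. Assumes PipelineConfig() can be constructed with no
# arguments (hypothetical; adjust to the real config signature). Shows the
# metrics returned by compute_metrics for an original/augmented sentence pair.
if __name__ == '__main__':
    config = PipelineConfig()  # assumption: no-argument construction
    qm = QualityMetrics(config)
    metrics = qm.compute_metrics(
        original="The quick brown fox jumps over the lazy dog.",
        augmented="A fast brown fox leaps over a sleepy dog.",
    )
    for name, value in metrics.items():
        print(f"{name}: {value:.3f}")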