Spaces:
Sleeping
Sleeping
import sys | |
import os | |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
import numpy as np | |
from transformers import pipeline | |
from typing import List | |
from utils.config import load_config | |
class EntailmentAnalyzer:
    """Score paraphrases of a sentence with an entailment (NLI) pipeline.

    The analyzer wraps a ``transformers`` pipeline built from a config mapping
    and uses the score of the ``entailment`` label to split paraphrases into
    selected and discarded sets.
    """

    def __init__(self, config):
        """
        Initialize the EntailmentAnalyzer from an already-loaded config mapping.
        Args:
            config: Mapping providing the 'task' and 'model' entries that are
                passed to ``transformers.pipeline``.
        """
        self.config = config
        self.entailment_pipeline = pipeline(task=self.config['task'], model=self.config['model'])

    def check_entailment(self, premise: str, hypothesis: str) -> float:
        """
        Check entailment between the premise and hypothesis.
        Args:
            premise: The premise sentence.
            hypothesis: The hypothesis sentence.
        Returns:
            float: The score the pipeline assigns to the entailment label.
        Raises:
            ValueError: If the pipeline output has no entailment label.
        """
        results = self.entailment_pipeline(f"{premise} [SEP] {hypothesis}", top_k=None)
        for item in results:
            # Compare case-insensitively: some models spell the label
            # 'ENTAILMENT' rather than 'entailment'.
            if str(item['label']).lower() == 'entailment':
                return item['score']
        # An explicit error beats a bare StopIteration leaking out of next().
        labels = [item['label'] for item in results]
        raise ValueError(f"No 'entailment' label in pipeline output; got labels: {labels}")

    def analyze_entailment(self, original_sentence: str, paraphrased_sentences: List[str], threshold: float) -> tuple:
        """
        Analyze entailment scores for paraphrased sentences. If no sentence reaches
        the threshold, lower it in steps of 0.1 until one does or it reaches 0.
        Args:
            original_sentence: The original sentence.
            paraphrased_sentences: List of paraphrased sentences.
            threshold: Minimum score to select a sentence.
        Returns:
            tuple: Three dictionaries mapping sentence -> score: all sentences,
                selected sentences, and discarded sentences. Selected and
                discarded are disjoint and together cover all sentences.
        """
        # Scores do not depend on the threshold, so run the model once per
        # sentence instead of once per sentence per threshold step.
        all_sentences = {
            sentence: self.check_entailment(original_sentence, sentence)
            for sentence in paraphrased_sentences
        }

        selected_sentences = {}
        # With no input sentences there is nothing a lower threshold could select.
        while all_sentences:
            selected_sentences = {
                sentence: score
                for sentence, score in all_sentences.items()
                if score >= threshold
            }
            if selected_sentences:
                break
            # Round so repeated subtraction cannot drift to a tiny positive
            # value (e.g. 2.8e-17) that would slip past the <= 0 stop check.
            lowered_threshold = round(threshold - 0.1, 10)
            print(f"No selected sentences found. Lowering the threshold by 0.1 (from {threshold} to {lowered_threshold}).")
            threshold = lowered_threshold
            if threshold <= 0:
                print("Threshold has reached 0. No sentences meet the criteria.")
                break

        # Derive the discarded set from the final selection so a sentence that
        # only passed after the threshold was lowered is not listed in both.
        discarded_sentences = {
            sentence: score
            for sentence, score in all_sentences.items()
            if sentence not in selected_sentences
        }
        return all_sentences, selected_sentences, discarded_sentences
if __name__ == "__main__":
    # Demo entry point: score a handful of paraphrases of one sentence and
    # print the resulting all / discarded / selected dictionaries.
    #
    # The config path is resolved relative to this file (or taken from the
    # first command-line argument) instead of a hard-coded absolute path into
    # one developer's home directory.
    # NOTE(review): the previous hard-coded path ended in
    # 'utils/config/config.yaml', so that layout is tried first, followed by
    # the older '../config/config.yaml' location — confirm against the repo.
    project_root = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')
    candidate_paths = [
        os.path.join(project_root, 'utils', 'config', 'config.yaml'),
        os.path.join(project_root, 'config', 'config.yaml'),
    ]
    if len(sys.argv) > 1:
        config_path = sys.argv[1]
    else:
        # Fall back to the first candidate so a missing file yields an error
        # that names a sensible path.
        config_path = next(
            (path for path in candidate_paths if os.path.exists(path)),
            candidate_paths[0],
        )

    config = load_config(config_path)
    entailment_analyzer = EntailmentAnalyzer(config['PECCAVI_TEXT']['Entailment'])
    all_sentences, selected_sentences, discarded_sentences = entailment_analyzer.analyze_entailment(
        "The weather is nice today",
        [
            "The climate is pleasant today",
            "It's a good day weather-wise",
            "Today, the weather is terrible",
            "What a beautiful day it is",
            "The sky is clear and the weather is perfect",
            "It's pouring rain outside today",
            "The weather isn't bad today",
            "A lovely day for outdoor activities",
        ],
        0.7,
    )
    print("----------------------- All Sentences -----------------------")
    print(all_sentences)
    print("----------------------- Discarded Sentences -----------------------")
    print(discarded_sentences)
    print("----------------------- Selected Sentences -----------------------")
    print(selected_sentences)