import gradio as gr
import torch
import torch.nn as nn
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DebertaV2Config,
    DebertaV2Model,
    PreTrainedModel,
    pipeline,
)
from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler
# Model identifiers: the shared tokenizer card and the two fine-tuned subjectivity checkpoints
model_card = "microsoft/mdeberta-v3-base"
subjectivity_only_model = "MatteoFasulo/mdeberta-v3-base-subjectivity-multilingual-no-arabic"
sentiment_model = "MatteoFasulo/mdeberta-v3-base-subjectivity-sentiment-multilingual-no-arabic"
# Example inputs for the Gradio interface (cached so clicking them returns precomputed results)
examples = [
["But then Trump came to power and sidelined the defense hawks, ushering in a dramatic shift in Republican sentiment toward America's allies and adversaries."],
["Boxing Day ambush & flagship attack Putin has long tried to downplay the true losses his army has faced in the Black Sea."],
]
class CustomModel(PreTrainedModel):
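    """mDeBERTa-v3 encoder whose pooled sentence representation is concatenated
    with three externally computed sentiment scores (positive, neutral, negative)
    before the final subjectivity classification layer."""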
config_class = DebertaV2Config
def __init__(self, config, sentiment_dim=3, num_labels=2, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self.deberta = DebertaV2Model(config)
self.pooler = ContextPooler(config)
output_dim = self.pooler.output_dim
self.dropout = nn.Dropout(0.1)
self.classifier = nn.Linear(output_dim + sentiment_dim, num_labels)
    def forward(self, input_ids, positive, neutral, negative, token_type_ids=None, attention_mask=None, labels=None):
        # token_type_ids and labels are accepted for API compatibility but unused here
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        encoder_layer = outputs[0]
        pooled_output = self.pooler(encoder_layer)
        # Stack the three sentiment scores into a (batch, 3) block and append
        # it to the pooled text representation before classification
        sentiment_features = torch.stack((positive, neutral, negative), dim=1).to(pooled_output.dtype)
        combined_features = torch.cat((pooled_output, sentiment_features), dim=1)
        logits = self.classifier(self.dropout(combined_features))
        return {'logits': logits}
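# Shape sketch (illustrative, assuming batch size B and sequence length T):
#   input_ids: (B, T) -> pooled_output: (B, hidden_size)
#   sentiment_features: (B, 3) -> combined_features: (B, hidden_size + 3)
#   logits: (B, 2) over {OBJ, SUBJ}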
def load_tokenizer(model_name: str):
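    """Fetch the tokenizer for the given model card (thin convenience wrapper)."""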
return AutoTokenizer.from_pretrained(model_name)
# Cache loaded models so each checkpoint is only read from disk once
load_model_cache = {}
def load_model(model_name: str):
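    """Load and cache a subjectivity checkpoint.

    Checkpoints with 'sentiment' in their name need the CustomModel wrapper;
    the text-only checkpoint is a standard sequence-classification head.
    """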
if model_name not in load_model_cache:
print(f"Loading model: {model_name}")
if 'sentiment' in model_name:
config = DebertaV2Config.from_pretrained(
model_name, num_labels=2, id2label={0: 'OBJ', 1: 'SUBJ'}, label2id={'OBJ': 0, 'SUBJ': 1},
output_attentions=False, output_hidden_states=False
)
            # from_pretrained is a classmethod: call it on the class, passing
            # the custom config and the extra constructor kwargs
            model_instance = CustomModel.from_pretrained(model_name, config=config, sentiment_dim=3, num_labels=2)
else:
model_instance = AutoModelForSequenceClassification.from_pretrained(
model_name, num_labels=2, id2label={0: 'OBJ', 1: 'SUBJ'}, label2id={'OBJ': 0, 'SUBJ': 1},
output_attentions=False, output_hidden_states=False
)
load_model_cache[model_name] = model_instance
return load_model_cache[model_name]
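# Illustrative usage: load_model(sentiment_model) yields the CustomModel branch,
# while load_model(subjectivity_only_model) yields a plain
# AutoModelForSequenceClassification; repeated calls hit the cache.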
# Lazily initialised sentiment pipeline (created on first use)
sentiment_pipeline_cache = None
def get_sentiment_values(text: str):
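    """Score `text` with the multilingual sentiment model and return a
    label -> probability dict (empty dict if the pipeline yields nothing)."""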
global sentiment_pipeline_cache
if sentiment_pipeline_cache is None:
print("Loading sentiment pipeline...")
sentiment_pipeline_cache = pipeline(
"sentiment-analysis",
model="cardiffnlp/twitter-xlm-roberta-base-sentiment",
tokenizer="cardiffnlp/twitter-xlm-roberta-base-sentiment",
top_k=None
)
    # top_k=None makes the pipeline return scores for every sentiment class
    sentiments_output = sentiment_pipeline_cache(text)
if sentiments_output and isinstance(sentiments_output, list) and sentiments_output[0]:
sentiments = sentiments_output[0]
return {s['label'].lower(): s['score'] for s in sentiments}
return {}
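# Illustrative output (scores are hypothetical and depend on the checkpoint):
#   get_sentiment_values("Great news!") -> {'positive': 0.93, 'neutral': 0.05, 'negative': 0.02}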
def analyze(text):
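    """Run both subjectivity models on `text` and return [metric, value] rows
    for the results table (empty values when `text` is blank)."""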
if not text or not text.strip():
empty_data = [
["Positive", ""], ["Neutral", ""], ["Negative", ""],
["Sent-Subj OBJ", ""], ["Sent-Subj SUBJ", ""],
["TextOnly OBJ", ""], ["TextOnly SUBJ", ""]
]
return empty_data
sentiment_values = get_sentiment_values(text)
tokenizer = load_tokenizer(model_card)
model_with_sentiment = load_model(sentiment_model)
model_without_sentiment = load_model(subjectivity_only_model)
inputs_dict = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
    device = next(model_without_sentiment.parameters()).device
    inputs_dict_on_device = {k: v.to(device) for k, v in inputs_dict.items()}
    # Text-only baseline: inference with gradient tracking disabled
    with torch.no_grad():
        outputs_base = model_without_sentiment(**inputs_dict_on_device)
    logits_base = outputs_base.get('logits')
    prob_base = torch.nn.functional.softmax(logits_base, dim=1)[0]
    # Fall back to 0.0 for any label the sentiment pipeline did not return
    positive = sentiment_values.get('positive', 0.0)
    neutral = sentiment_values.get('neutral', 0.0)
    negative = sentiment_values.get('negative', 0.0)
    # The sentiment-aware model expects the three scores as extra (batch,) float tensors
    current_inputs_for_sentiment_model = inputs_dict_on_device.copy()
    current_inputs_for_sentiment_model['positive'] = torch.tensor(positive, device=device).unsqueeze(0).float()
    current_inputs_for_sentiment_model['neutral'] = torch.tensor(neutral, device=device).unsqueeze(0).float()
    current_inputs_for_sentiment_model['negative'] = torch.tensor(negative, device=device).unsqueeze(0).float()
    # Sentiment-augmented model: same text plus the three sentiment scores
    with torch.no_grad():
        outputs_sentiment = model_with_sentiment(**current_inputs_for_sentiment_model)
    logits_sentiment = outputs_sentiment.get('logits')
    prob_sentiment = torch.nn.functional.softmax(logits_sentiment, dim=1)[0]
table_data = [
["Positive", f"{positive:.2%}"],
["Neutral", f"{neutral:.2%}"],
["Negative", f"{negative:.2%}"],
["Sent-Subj OBJ", f"{prob_sentiment[0]:.2%}"],
["Sent-Subj SUBJ", f"{prob_sentiment[1]:.2%}"],
["TextOnly OBJ", f"{prob_base[0]:.2%}"],
["TextOnly SUBJ", f"{prob_base[1]:.2%}"]
]
return table_data
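# Illustrative return value (percentages are hypothetical):
#   [["Positive", "12.34%"], ["Neutral", "80.00%"], ["Negative", "7.66%"],
#    ["Sent-Subj OBJ", "55.10%"], ["Sent-Subj SUBJ", "44.90%"],
#    ["TextOnly OBJ", "60.00%"], ["TextOnly SUBJ", "40.00%"]]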
def load_default_example_on_startup():
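    """Pre-fill the textbox with the first example and its analysis when the app loads."""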
print("Loading default example on startup...")
    if examples and isinstance(examples[0], list) and examples[0]:
default_text = examples[0][0]
default_analysis_results = analyze(default_text)
return default_text, default_analysis_results
print("Warning: No valid default example found. Loading empty.")
empty_text = ""
empty_results = analyze(empty_text)
return empty_text, empty_results
with gr.Blocks(theme=gr.themes.Ocean(), title="Subjectivity & Sentiment Dashboard") as demo:
gr.Markdown("# πŸš€ Subjectivity & Sentiment Analysis Dashboard πŸš€")
with gr.Column():
txt = gr.Textbox(
label="Enter text to analyze",
placeholder="Paste news sentence here...",
lines=2,
)
        with gr.Row():
            # Empty flexible column acts as a spacer, pushing the button to the right
            gr.Column(scale=1, min_width=0)
btn = gr.Button(
"Analyze πŸ”",
variant="primary",
size="md",
scale=0
)
with gr.Tabs():
with gr.TabItem("Raw Scores πŸ“‹"):
table = gr.Dataframe(
headers=["Metric", "Value"],
datatype=["str", "str"],
interactive=False
)
with gr.TabItem("About ℹ️"):
gr.Markdown(
"This dashboard uses two DeBERTa-based models (with and without sentiment integration) "
"to detect subjectivity, alongside sentiment scores from an XLM-RoBERTa model."
)
with gr.Row():
gr.Markdown("### Examples:")
gr.Examples(
examples=examples,
inputs=txt,
outputs=[table],
fn=analyze,
label="Click an example to analyze",
cache_examples=True,
)
btn.click(fn=analyze, inputs=txt, outputs=[table])
demo.load(
fn=load_default_example_on_startup,
inputs=None,
outputs=[txt, table]
)
# queue() enables request queuing for concurrent users; share=True additionally
# creates a public link when the app is run outside Hugging Face Spaces
demo.queue().launch(share=True)