import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DebertaV2Config,
    DebertaV2Model,
    PreTrainedModel,
    pipeline,
)
from transformers.models.deberta_v2.modeling_deberta_v2 import ContextPooler
# Model identifiers: the base tokenizer plus the two fine-tuned subjectivity checkpoints
model_card = "microsoft/mdeberta-v3-base"
subjectivity_only_model = "MatteoFasulo/mdeberta-v3-base-subjectivity-multilingual-no-arabic"
sentiment_model = "MatteoFasulo/mdeberta-v3-base-subjectivity-sentiment-multilingual-no-arabic"
# Placeholder example inputs for the Gradio interface (cached by gr.Examples so they run instantly)
examples = [
    ['Example1'],
    ['Example2'],
    ['Example3'],
]
# Custom model class for combining sentiment analysis with subjectivity detection
class CustomModel(PreTrainedModel):
    config_class = DebertaV2Config

    def __init__(self, config, sentiment_dim=3, num_labels=2, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.deberta = DebertaV2Model(config)
        self.pooler = ContextPooler(config)
        output_dim = self.pooler.output_dim
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(output_dim + sentiment_dim, num_labels)

    def forward(self, input_ids, positive, neutral, negative, token_type_ids=None, attention_mask=None, labels=None):
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        encoder_layer = outputs[0]
        pooled_output = self.pooler(encoder_layer)
        # Sentiment features as a single tensor
        sentiment_features = torch.stack((positive, neutral, negative), dim=1)  # Shape: (batch_size, 3)
        # Combine the pooled CLS embedding with the sentiment features
        combined_features = torch.cat((pooled_output, sentiment_features), dim=1)
        # Classification head
        logits = self.classifier(self.dropout(combined_features))
        return {'logits': logits}
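# Shape sketch (illustrative, assuming the default mdeberta hidden size of 768):
# for a batch of 2 sentences, pooled_output is (2, 768), the stacked sentiment
# scores are (2, 3), so the classifier input is (2, 771) and the returned
# logits are (2, 2) for the OBJ/SUBJ classes.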
# Load the pre-trained tokenizer
def load_tokenizer(model_name: str):
    return AutoTokenizer.from_pretrained(model_name)
# Load a pre-trained model, using the custom fused head for the sentiment variant
def load_model(model_name: str):
    if 'sentiment' in model_name:
        config = DebertaV2Config.from_pretrained(
            model_name,
            num_labels=2,
            id2label={0: 'OBJ', 1: 'SUBJ'},
            label2id={'OBJ': 0, 'SUBJ': 1},
            output_attentions=False,
            output_hidden_states=False
        )
        # from_pretrained is a classmethod: call it on the class, not on a
        # throwaway instance, so the checkpoint weights are actually loaded
        model = CustomModel.from_pretrained(model_name, config=config, sentiment_dim=3, num_labels=2)
    else:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=2,
            id2label={0: 'OBJ', 1: 'SUBJ'},
            label2id={'OBJ': 0, 'SUBJ': 1},
            output_attentions=False,
            output_hidden_states=False
        )
    return model
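# Usage sketch (hypothetical call sites): both branches return a model with the
# same OBJ/SUBJ label mapping, e.g.
#   model = load_model(sentiment_model)           # CustomModel with fused head
#   model = load_model(subjectivity_only_model)   # plain sequence classifier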
# Get sentiment scores from a pre-trained multilingual sentiment analysis model
def get_sentiment_values(text: str):
    pipe = pipeline(
        "sentiment-analysis",
        model="cardiffnlp/twitter-xlm-roberta-base-sentiment",
        tokenizer="cardiffnlp/twitter-xlm-roberta-base-sentiment",
        top_k=None,
    )
    sentiments = pipe(text)[0]
    # Map each label ('positive'/'neutral'/'negative') to its score
    return {s['label']: s['score'] for s in sentiments}
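# With top_k=None the pipeline scores all three classes, so the returned dict
# looks like (illustrative numbers, not real output):
#   {'positive': 0.10, 'neutral': 0.25, 'negative': 0.65}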
# Run both subjectivity models on the text and format the results for the dashboard
def analyze(text):
    # Extract sentiment values
    sentiment_values = get_sentiment_values(text)
    # Load the tokenizer and both models (distinct local names are used so the
    # global `sentiment_model` identifier string is not shadowed)
    tokenizer = load_tokenizer(model_card)
    model_with_sentiment = load_model(sentiment_model)
    model_text_only = load_model(subjectivity_only_model)
    # Tokenize
    inputs = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
    # Unpack the sentiment scores
    positive = sentiment_values['positive']
    neutral = sentiment_values['neutral']
    negative = sentiment_values['negative']
    # The sentiment-augmented model takes the scores as extra tensor inputs;
    # the text-only model must not receive them
    sentiment_inputs = dict(inputs)
    sentiment_inputs['positive'] = torch.tensor(positive).unsqueeze(0)
    sentiment_inputs['neutral'] = torch.tensor(neutral).unsqueeze(0)
    sentiment_inputs['negative'] = torch.tensor(negative).unsqueeze(0)
    with torch.no_grad():
        # Sentiment-augmented model probabilities via softmax
        logits1 = model_with_sentiment(**sentiment_inputs)['logits']
        p1 = torch.nn.functional.softmax(logits1, dim=1)[0]
        # Text-only model probabilities via softmax
        logits2 = model_text_only(**inputs).logits
        p2 = torch.nn.functional.softmax(logits2, dim=1)[0]
    # Build the two outputs the interface expects: a dataframe for the bar
    # plot and [metric, value] rows for the table
    scores = {
        'Positive': positive, 'Neutral': neutral, 'Negative': negative,
        'Sent-Subj OBJ': p1[0].item(), 'Sent-Subj SUBJ': p1[1].item(),
        'TextOnly OBJ': p2[0].item(), 'TextOnly SUBJ': p2[1].item(),
    }
    chart_data = pd.DataFrame({'category': list(scores.keys()), 'value': list(scores.values())})
    table_data = [[k, f"{v:.2%}"] for k, v in scores.items()]
    return chart_data, table_data
# Build the Gradio interface
with gr.Blocks(theme=gr.themes.Soft(), css="""
    #result_table td { padding: 8px; font-size: 1rem; }
    #header { text-align: center; font-size: 2rem; font-weight: bold; margin-bottom: 10px; }
""") as demo:
    gr.Markdown("<div id='header'>Advanced Subjectivity & Sentiment Dashboard</div>")
    with gr.Row():
        txt = gr.Textbox(label="Enter text to analyze", placeholder="Paste news sentence here...", lines=2)
        btn = gr.Button("Analyze", variant="primary")
    with gr.Tabs():
        with gr.TabItem("Overview"):
            chart = gr.BarPlot(x="category", y="value", label="Results", elem_id="result_chart")
        with gr.TabItem("Raw Scores"):
            table = gr.Dataframe(headers=["Metric", "Value"], datatype=["str", "str"], interactive=False, elem_id="result_table")
        with gr.TabItem("About ℹ️"):
            gr.Markdown("This dashboard uses two DeBERTa-based models (with and without sentiment integration) to detect subjectivity, alongside sentiment scores from an XLM-RoBERTa model.")
    with gr.Row():
        gr.Markdown("### Examples:")
    gr.Examples(
        examples=examples,
        inputs=txt,
        outputs=[chart, table],
        fn=analyze,
        label="Examples",
        elem_id="example_list",
        cache_examples=True,
    )
    # Link inputs to outputs
    btn.click(fn=analyze, inputs=txt, outputs=[chart, table])

demo.queue().launch(server_name="0.0.0.0", share=True)