MatteoFasulo's picture
Fix missing comma
b3b327d
raw
history blame
3.82 kB
import functools

import gradio as gr
import torch
import torch.nn as nn
from transformers import DebertaV2Model, DebertaV2Config, AutoTokenizer, PreTrainedModel, ContextPooler
from transformers import pipeline
# Model identifiers passed to `from_pretrained`:
# - `model_card`: base checkpoint the tokenizer is loaded from.
# - `finetuned_model`: fine-tuned checkpoint the config and classifier weights are loaded from.
model_card: str = "microsoft/mdeberta-v3-base"
finetuned_model: str = "MatteoFasulo/mdeberta-v3-base-subjectivity-sentiment-multilingual"
# Custom model class for combining sentiment analysis with subjectivity detection
class CustomModel(PreTrainedModel):
    """DeBERTa-v2 encoder + sentiment features + linear subjectivity classifier.

    The encoder output is pooled with ``ContextPooler``. Three sentiment scores
    (positive, neutral, negative) are appended to the pooled vector, and the
    result goes through dropout and a single linear layer.

    Args:
        config: ``DebertaV2Config`` used to build the encoder and the pooler.
        sentiment_dim: Number of sentiment features appended to the pooled
            vector (3: positive, neutral, negative).
        num_labels: Number of output classes.
    """

    config_class = DebertaV2Config

    def __init__(self, config, sentiment_dim=3, num_labels=2, *args, **kwargs):
        super().__init__(config, *args, **kwargs)
        self.num_labels = num_labels
        self.sentiment_dim = sentiment_dim
        self.deberta = DebertaV2Model(config)
        self.pooler = ContextPooler(config)
        output_dim = self.pooler.output_dim
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(output_dim + sentiment_dim, num_labels)
        # Final construction step expected of every PreTrainedModel subclass:
        # initialises weights and runs the library's post-construction hooks.
        self.post_init()

    def forward(self, input_ids, positive, neutral, negative, attention_mask=None, labels=None):
        """Encode ``input_ids``, append the sentiment scores and classify.

        Args:
            input_ids: Token ids for the encoder.
            positive, neutral, negative: Sentiment scores, one per batch item,
                normally tensors of shape ``(batch_size,)``. Scalars and 0-d
                tensors are treated as a batch of one.
            attention_mask: Optional encoder attention mask.
            labels: Optional class indices. When given, a cross-entropy loss
                is returned as well.

        Returns:
            ``{'logits': (batch_size, num_labels)}``, plus ``'loss'`` when
            ``labels`` is provided.
        """
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        encoder_layer = outputs[0]
        pooled_output = self.pooler(encoder_layer)

        # Build a (batch_size, 3) sentiment tensor. It is cast to the pooled
        # output's dtype/device so torch.cat and the Linear layer accept it even
        # when the scores arrive as float64 or as Python numbers.
        sentiment_features = torch.stack(
            [
                torch.as_tensor(
                    score, dtype=pooled_output.dtype, device=pooled_output.device
                ).reshape(-1)
                for score in (positive, neutral, negative)
            ],
            dim=1,
        )

        # Concatenate the pooled embedding with the sentiment features.
        combined_features = torch.cat((pooled_output, sentiment_features), dim=1)

        # Classification head
        logits = self.classifier(self.dropout(combined_features))

        if labels is None:
            return {'logits': logits}
        loss = nn.functional.cross_entropy(
            logits.view(-1, self.num_labels), labels.view(-1)
        )
        return {'loss': loss, 'logits': logits}
# Load the pre-trained tokenizer
@functools.lru_cache(maxsize=8)
def load_tokenizer(model_name: str):
    """Return the tokenizer for ``model_name``, loading it at most once.

    Results are cached, so repeated calls with the same ``model_name`` return
    the same tokenizer object instead of re-reading it on every request.

    Args:
        model_name: Identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        The tokenizer built by ``AutoTokenizer.from_pretrained(model_name)``.
    """
    return AutoTokenizer.from_pretrained(model_name)
# Load the pre-trained model
@functools.lru_cache(maxsize=2)
def load_model(model_card: str, finetuned_model: str):
    """Load the fine-tuned subjectivity classifier, at most once per argument pair.

    Args:
        model_card: Base checkpoint id. Not used by this function; it is kept
            so existing callers keep working.
        finetuned_model: Identifier of the fine-tuned checkpoint. Both the
            config and the weights are loaded from it.

    Returns:
        A ``CustomModel`` in eval mode whose config maps 0 -> 'OBJ' and
        1 -> 'SUBJ'. The result is cached, so repeated calls with the same
        arguments return the same model object.
    """
    config = DebertaV2Config.from_pretrained(
        finetuned_model,
        num_labels=2,
        id2label={0: 'OBJ', 1: 'SUBJ'},
        label2id={'OBJ': 0, 'SUBJ': 1},
        output_attentions=False,
        output_hidden_states=False,
    )
    # `from_pretrained` is a classmethod. Call it on the class and pass `config`
    # explicitly so the OBJ/SUBJ label mapping above is actually used. The
    # remaining keyword arguments are forwarded to CustomModel.__init__.
    model = CustomModel.from_pretrained(
        finetuned_model,
        config=config,
        sentiment_dim=3,
        num_labels=2,
    )
    model.eval()  # inference only: disable dropout
    return model
# Get sentiment values using a pre-trained sentiment analysis model
def get_sentiment_values(text: str):
pipe = pipeline("sentiment-analysis", model="cardiffnlp/twitter-xlm-roberta-base-sentiment", tokenizer="cardiffnlp/twitter-xlm-roberta-base-sentiment", top_k=None)
sentiments = pipe(text)[0]
return {k:v for k,v in [(list(sentiment.values())[0], list(sentiment.values())[1]) for sentiment in sentiments]}
# Predict the subjectivity of a sentence
def predict_subjectivity(text):
    """Classify ``text`` as subjective ('SUBJ') or objective ('OBJ').

    The fine-tuned model takes the sentence tokens plus three sentiment scores
    (positive / neutral / negative) produced by ``get_sentiment_values``.

    Args:
        text: The sentence to classify.

    Returns:
        The predicted label string from ``model.config.id2label``.
    """
    sentiment_values = get_sentiment_values(text)
    # NOTE(review): assumes the sentiment model's labels are 'positive',
    # 'neutral' and 'negative' (any casing). Confirm against its config.
    scores = {label.lower(): score for label, score in sentiment_values.items()}

    model = load_model(model_card, finetuned_model)
    tokenizer = load_tokenizer(model_card)
    inputs = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')

    device = model.device
    # One score per batch item; a single sentence is a batch of one.
    sentiment_inputs = {
        name: torch.tensor([scores[name]], dtype=torch.float32, device=device)
        for name in ('positive', 'neutral', 'negative')
    }
    attention_mask = inputs.get('attention_mask')
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)

    # CustomModel.forward accepts only input_ids, attention_mask and the
    # sentiment tensors, so pass them explicitly rather than **inputs. The
    # tokenizer may also return e.g. token_type_ids, which forward would reject.
    with torch.no_grad():
        outputs = model(
            input_ids=inputs['input_ids'].to(device),
            attention_mask=attention_mask,
            **sentiment_inputs,
        )

    logits = outputs['logits']
    predicted_class_idx = logits.argmax(dim=-1)[0].item()
    return model.config.id2label[predicted_class_idx]
# Create a Gradio interface: one text box in, the predicted label out.
demo = gr.Interface(
    fn=predict_subjectivity,
    inputs=gr.Textbox(
        label='Input sentence',
        placeholder='Enter a sentence from a news article',
        info='Paste a sentence from a news article to determine if it is subjective or objective.',
    ),
    outputs=gr.Text(
        label="Prediction",
        info="Whether the sentence is subjective or objective.",
    ),
    title='Subjectivity Detection',
    description='Detect if a sentence is subjective or objective using a pre-trained model.',
    theme='huggingface',
)

# Launch the interface only when run as a script (e.g. `python app.py`), so
# importing this module does not start a web server as a side effect.
if __name__ == "__main__":
    demo.launch()