MatteoFasulo commited on
Commit
1132a05
·
verified ·
1 Parent(s): 79352ba

Update with examples and cleaner logic

Browse files
Files changed (1) hide show
  1. app.py +127 -68
app.py CHANGED
@@ -2,80 +2,132 @@ import gradio as gr
2
  import torch
3
  from transformers import DebertaV2Model, DebertaV2Config, AutoTokenizer, PreTrainedModel
4
  from transformers.models.deberta.modeling_deberta import ContextPooler
5
- from transformers import pipeline
6
  import torch.nn as nn
7
 
8
- # -- Model definitions
9
- BASE_MODEL = "microsoft/mdeberta-v3-base"
10
- SENT_SUBJ_MODEL = "MatteoFasulo/mdeberta-v3-base-subjectivity-sentiment-multilingual-no-arabic"
11
- SUBJ_ONLY_MODEL = "MatteoFasulo/mdeberta-v3-base-subjectivity-multilingual-no-arabic"
12
-
13
- # -- Custom model builder
14
- from functools import partial
15
-
16
- def build_custom_model(sentiment_dim=0):
17
- class CustomModel(PreTrainedModel):
18
- config_class = DebertaV2Config
19
- def __init__(self, config, *args, **kwargs):
20
- super().__init__(config, *args, **kwargs)
21
- self.deberta = DebertaV2Model(config)
22
- self.pooler = ContextPooler(config)
23
- self.dropout = nn.Dropout(0.1)
24
- hidden_dim = self.pooler.output_dim + sentiment_dim
25
- self.classifier = nn.Linear(hidden_dim, config.num_labels)
26
- def forward(self, input_ids, attention_mask=None, **sent_kwargs):
27
- x = self.deberta(input_ids=input_ids, attention_mask=attention_mask)[0]
28
- pooled = self.pooler(x)
29
- if sentiment_dim:
30
- sent_feats = torch.stack((sent_kwargs['positive'], sent_kwargs['neutral'], sent_kwargs['negative']), dim=1)
31
- pooled = torch.cat((pooled, sent_feats), dim=1)
32
- return self.classifier(self.dropout(pooled))
33
- return CustomModel
34
-
35
- # -- Load models and tokenizer
36
- tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
37
-
38
- # sentiment+subjectivity
39
- cfg1 = DebertaV2Config.from_pretrained(SENT_SUBJ_MODEL, num_labels=2, id2label={0:'OBJ',1:'SUBJ'}, label2id={'OBJ':0,'SUBJ':1})
40
- Model1Cls = build_custom_model(sentiment_dim=3)
41
- model1 = Model1Cls.from_pretrained(SENT_SUBJ_MODEL, config=cfg1, ignore_mismatched_sizes=True)
42
-
43
- # subjectivity-only
44
- cfg2 = DebertaV2Config.from_pretrained(SUBJ_ONLY_MODEL, num_labels=2, id2label={0:'OBJ',1:'SUBJ'}, label2id={'OBJ':0,'SUBJ':1})
45
- Model2Cls = build_custom_model(sentiment_dim=0)
46
- model2 = Model2Cls.from_pretrained(SUBJ_ONLY_MODEL, config=cfg2)
47
-
48
- # sentiment pipeline
49
- sentiment_pipe = pipeline("sentiment-analysis", model="cardiffnlp/twitter-xlm-roberta-base-sentiment", tokenizer="cardiffnlp/twitter-xlm-roberta-base-sentiment", top_k=None)
50
-
51
- def get_sentiment_scores(text):
52
- out = sentiment_pipe(text)[0]
53
- return {list(d.keys())[0]: list(d.values())[0] for d in out}
54
-
55
- # -- Prediction logic
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  def analyze(text):
 
 
 
 
 
 
 
 
57
  # Tokenize
58
- inputs = tokenizer(text, truncation=True, padding=True, max_length=256, return_tensors='pt')
59
- # Sentiment
60
- scores = get_sentiment_scores(text)
61
- pos, neu, neg = scores['positive'], scores['neutral'], scores['negative']
62
- # Model1
63
- logits1 = model1(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, positive=torch.tensor([pos]), neutral=torch.tensor([neu]), negative=torch.tensor([neg]))
64
- p1 = torch.softmax(logits1, dim=1)[0]
65
- # Model2
66
- logits2 = model2(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)
67
- p2 = torch.softmax(logits2, dim=1)[0]
68
- # Build results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  return {
70
- 'Positive': f"{pos:.2%}", 'Neutral': f"{neu:.2%}", 'Negative': f"{neg:.2%}",
71
  'Sent-Subj OBJ': f"{p1[0]:.2%}", 'Sent-Subj SUBJ': f"{p1[1]:.2%}",
72
  'TextOnly OBJ': f"{p2[0]:.2%}", 'TextOnly SUBJ': f"{p2[1]:.2%}"
73
  }
74
 
75
- # -- Build Gradio Dashboard with Blocks
76
- theme = gr.themes.Soft()
77
-
78
- with gr.Blocks(theme=theme, css="""
79
  #result_table td { padding: 8px; font-size: 1rem; }
80
  #header { text-align: center; font-size: 2rem; font-weight: bold; margin-bottom: 10px; }
81
  """) as demo:
@@ -90,9 +142,16 @@ with gr.Blocks(theme=theme, css="""
90
  table = gr.Dataframe(headers=["Metric", "Value"], datatype=["str","str"], interactive=False, elem_id="result_table")
91
  with gr.TabItem("About ℹ️"):
92
  gr.Markdown("This dashboard uses two DeBERTa-based models (with and without sentiment integration) to detect subjectivity, alongside sentiment scores from an XLM-RoBERTa model.")
93
- gr.Markdown("**Threshold** for subjective classification is adjustable in code (default: 0.65). Feel free to fork and customize! 🚀")
 
 
 
 
 
 
 
 
94
  # Link inputs to outputs
95
  btn.click(fn=analyze, inputs=txt, outputs=[chart, table])
96
 
97
- # -- Launch
98
- demo.queue().launch(server_name="0.0.0.0", share=True)
 
2
  import torch
3
  from transformers import DebertaV2Model, DebertaV2Config, AutoTokenizer, PreTrainedModel
4
  from transformers.models.deberta.modeling_deberta import ContextPooler
5
+ from transformers import pipeline, AutoModelForSequenceClassification
6
  import torch.nn as nn
7
 
8
+ # Define the model and tokenizer
9
+ model_card = "microsoft/mdeberta-v3-base"
10
+ subjectivity_only_model = "MatteoFasulo/mdeberta-v3-base-subjectivity-multilingual-no-arabic"
11
+ sentiment_model = "MatteoFasulo/mdeberta-v3-base-subjectivity-sentiment-multilingual-no-arabic"
12
+
13
+ # Define some examples for the Gradio interface (cached to run on-the-fly)
14
+ examples = [
15
+ ['Example1'],
16
+ ['Example2'],
17
+ ['Example3'],
18
+ ]
19
+
20
+ # Custom model class for combining sentiment analysis with subjectivity detection
21
+ class CustomModel(PreTrainedModel):
22
+ config_class = DebertaV2Config
23
+
24
+ def __init__(self, config, sentiment_dim=3, num_labels=2, *args, **kwargs):
25
+ super().__init__(config, *args, **kwargs)
26
+ self.deberta = DebertaV2Model(config)
27
+ self.pooler = ContextPooler(config)
28
+ output_dim = self.pooler.output_dim
29
+ self.dropout = nn.Dropout(0.1)
30
+
31
+ self.classifier = nn.Linear(output_dim + sentiment_dim, num_labels)
32
+
33
+ def forward(self, input_ids, positive, neutral, negative, token_type_ids=None, attention_mask=None, labels=None):
34
+ outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
35
+
36
+ encoder_layer = outputs[0]
37
+ pooled_output = self.pooler(encoder_layer)
38
+
39
+ # Sentiment features as a single tensor
40
+ sentiment_features = torch.stack((positive, neutral, negative), dim=1) # Shape: (batch_size, 3)
41
+
42
+ # Combine CLS embedding with sentiment features
43
+ combined_features = torch.cat((pooled_output, sentiment_features), dim=1)
44
+
45
+ # Classification head
46
+ logits = self.classifier(self.dropout(combined_features))
47
+
48
+ return {'logits': logits}
49
+
50
+ # Load the pre-trained tokenizer
51
+ def load_tokenizer(model_name: str):
52
+ return AutoTokenizer.from_pretrained(model_name)
53
+
54
+ # Load the pre-trained model
55
+ def load_model(model_name: str):
56
+
57
+ if 'sentiment' in model_name:
58
+ config = DebertaV2Config.from_pretrained(
59
+ model_name,
60
+ num_labels=2,
61
+ id2label={0: 'OBJ', 1: 'SUBJ'},
62
+ label2id={'OBJ': 0, 'SUBJ': 1},
63
+ output_attentions=False,
64
+ output_hidden_states=False
65
+ )
66
+
67
+ model = CustomModel(config=config, sentiment_dim=3, num_labels=2).from_pretrained(model_name)
68
+
69
+ else:
70
+ model = AutoModelForSequenceClassification.from_pretrained(
71
+ model_name,
72
+ num_labels=2,
73
+ id2label={0: 'OBJ', 1: 'SUBJ'},
74
+ label2id={'OBJ': 0, 'SUBJ': 1},
75
+ output_attentions=False,
76
+ output_hidden_states=False
77
+ )
78
+
79
+ return model
80
+
81
+ # Get sentiment values using a pre-trained sentiment analysis model
82
+ def get_sentiment_values(text: str):
83
+ pipe = pipeline("sentiment-analysis", model="cardiffnlp/twitter-xlm-roberta-base-sentiment", tokenizer="cardiffnlp/twitter-xlm-roberta-base-sentiment", top_k=None)
84
+ sentiments = pipe(text)[0]
85
+ return {k:v for k,v in [(list(sentiment.values())[0], list(sentiment.values())[1]) for sentiment in sentiments]}
86
+
87
+ # Modify the predict_subjectivity function to return additional information
88
  def analyze(text):
89
+ # Extract sentiment values
90
+ sentiment_values = get_sentiment_values(text)
91
+
92
+ # Load the tokenizer and model
93
+ tokenizer = load_tokenizer(model_card)
94
+ sentiment_model = load_model(sentiment_model)
95
+ subjectivity_model = load_model(subjectivity_only_model)
96
+
97
  # Tokenize
98
+ inputs = tokenizer(text, padding=True, truncation=True, max_length=256, return_tensors='pt')
99
+
100
+ # Get the sentiment values
101
+ positive = sentiment_values['positive']
102
+ neutral = sentiment_values['neutral']
103
+ negative = sentiment_values['negative']
104
+ # Convert sentiment values to tensors
105
+ inputs['positive'] = torch.tensor(positive).unsqueeze(0)
106
+ inputs['neutral'] = torch.tensor(neutral).unsqueeze(0)
107
+ inputs['negative'] = torch.tensor(negative).unsqueeze(0)
108
+
109
+ # Get the sentiment model outputs
110
+ outputs1 = sentiment_model(**inputs)
111
+ logits1 = outputs1.get('logits')
112
+
113
+ # Calculate probabilities using softmax
114
+ p1 = torch.nn.functional.softmax(logits1, dim=1)[0]
115
+
116
+ # Get the subjectivity model outputs
117
+ outputs2 = subjectivity_model(**inputs)
118
+ logits2 = outputs2.get('logits')
119
+ # Calculate probabilities using softmax
120
+ p2 = torch.nn.functional.softmax(logits2, dim=1)[0]
121
+
122
+ # Format the output
123
  return {
124
+ 'Positive': f"{positive:.2%}", 'Neutral': f"{neutral:.2%}", 'Negative': f"{negative:.2%}",
125
  'Sent-Subj OBJ': f"{p1[0]:.2%}", 'Sent-Subj SUBJ': f"{p1[1]:.2%}",
126
  'TextOnly OBJ': f"{p2[0]:.2%}", 'TextOnly SUBJ': f"{p2[1]:.2%}"
127
  }
128
 
129
+ # Update the Gradio interface
130
+ with gr.Blocks(theme=gr.themes.Soft(), css="""
 
 
131
  #result_table td { padding: 8px; font-size: 1rem; }
132
  #header { text-align: center; font-size: 2rem; font-weight: bold; margin-bottom: 10px; }
133
  """) as demo:
 
142
  table = gr.Dataframe(headers=["Metric", "Value"], datatype=["str","str"], interactive=False, elem_id="result_table")
143
  with gr.TabItem("About ℹ️"):
144
  gr.Markdown("This dashboard uses two DeBERTa-based models (with and without sentiment integration) to detect subjectivity, alongside sentiment scores from an XLM-RoBERTa model.")
145
+ with gr.Row():
146
+ gr.Markdown("### Examples:")
147
+ gr.Examples(
148
+ examples=examples,
149
+ inputs=txt,
150
+ label="Examples",
151
+ elem_id="example_list",
152
+ cache_examples=True,
153
+ )
154
  # Link inputs to outputs
155
  btn.click(fn=analyze, inputs=txt, outputs=[chart, table])
156
 
157
+ demo.queue().launch(server_name="0.0.0.0", share=True)