ElizabethSrgh commited on
Commit
dbf0b81
·
verified ·
1 Parent(s): 1ebcb15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -16
app.py CHANGED
@@ -3,21 +3,21 @@ import torch.nn as nn
3
  from transformers import AutoTokenizer, AutoModel
4
  import gradio as gr
5
 
6
- # Model multitask (Topik & Sentimen)
7
  class MultiTaskModel(nn.Module):
8
  def __init__(self, base_model_name, num_topic_classes, num_sentiment_classes):
9
  super(MultiTaskModel, self).__init__()
10
  self.encoder = AutoModel.from_pretrained(base_model_name)
11
  hidden_size = self.encoder.config.hidden_size
12
- self.topic_classifier = nn.Linear(hidden_size, num_topic_classes)
13
- self.sentiment_classifier = nn.Linear(hidden_size, num_sentiment_classes)
14
 
15
  def forward(self, input_ids, attention_mask):
16
  outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
17
  pooled_output = outputs.last_hidden_state[:, 0]
18
- topic_logits = self.topic_classifier(pooled_output)
19
- sentiment_logits = self.sentiment_classifier(pooled_output)
20
- return topic_logits, sentiment_logits
21
 
22
  # Load tokenizer & model
23
  tokenizer = AutoTokenizer.from_pretrained("tokenizer")
@@ -26,21 +26,19 @@ model.load_state_dict(torch.load("model.pt", map_location=torch.device("cpu")))
26
  model.eval()
27
 
28
  # Label mapping
29
- topic_labels = ["Produk", "Layanan", "Pengiriman", "Lainnya"]
30
- sentiment_labels = ["Negatif", "Netral", "Positif"]
31
 
32
- # Fungsi klasifikasi
33
  def klasifikasi(text):
34
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
35
  with torch.no_grad():
36
- topic_logits, sentiment_logits = model(**inputs)
37
- topic_probs = torch.softmax(topic_logits, dim=-1).squeeze()
38
- sentiment_probs = torch.softmax(sentiment_logits, dim=-1).squeeze()
39
 
40
- topic_result = {label: float(prob) for label, prob in zip(topic_labels, topic_probs)}
41
- sentiment_result = {label: float(prob) for label, prob in zip(sentiment_labels, sentiment_probs)}
42
- return {"Topik": topic_result, "Sentimen": sentiment_result}
43
 
44
- # Gradio UI
45
  demo = gr.Interface(fn=klasifikasi, inputs="text", outputs="json", title="Klasifikasi Topik dan Sentimen Pelanggan")
46
  demo.launch()
 
3
  from transformers import AutoTokenizer, AutoModel
4
  import gradio as gr
5
 
6
+ # Model multitask (Topik & Sentimen) dengan nama layer sesuai model.pt
7
  class MultiTaskModel(nn.Module):
8
  def __init__(self, base_model_name, num_topic_classes, num_sentiment_classes):
9
  super(MultiTaskModel, self).__init__()
10
  self.encoder = AutoModel.from_pretrained(base_model_name)
11
  hidden_size = self.encoder.config.hidden_size
12
+ self.topik_classifier = nn.Linear(hidden_size, num_topic_classes)
13
+ self.sentimen_classifier = nn.Linear(hidden_size, num_sentiment_classes)
14
 
15
  def forward(self, input_ids, attention_mask):
16
  outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
17
  pooled_output = outputs.last_hidden_state[:, 0]
18
+ topik_logits = self.topik_classifier(pooled_output)
19
+ sentimen_logits = self.sentimen_classifier(pooled_output)
20
+ return topik_logits, sentimen_logits
21
 
22
  # Load tokenizer & model
23
  tokenizer = AutoTokenizer.from_pretrained("tokenizer")
 
26
  model.eval()
27
 
28
  # Label mapping
29
+ topik_labels = ["Produk", "Layanan", "Pengiriman", "Lainnya"]
30
+ sentimen_labels = ["Negatif", "Netral", "Positif"]
31
 
 
32
  def klasifikasi(text):
33
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
34
  with torch.no_grad():
35
+ topik_logits, sentimen_logits = model(**inputs)
36
+ topik_probs = torch.softmax(topik_logits, dim=-1).squeeze()
37
+ sentimen_probs = torch.softmax(sentimen_logits, dim=-1).squeeze()
38
 
39
+ topik_result = {label: float(prob) for label, prob in zip(topik_labels, topik_probs)}
40
+ sentimen_result = {label: float(prob) for label, prob in zip(sentimen_labels, sentimen_probs)}
41
+ return {"Topik": topik_result, "Sentimen": sentimen_result}
42
 
 
43
  demo = gr.Interface(fn=klasifikasi, inputs="text", outputs="json", title="Klasifikasi Topik dan Sentimen Pelanggan")
44
  demo.launch()