shukdevdatta123 committed
Commit 6e57278 · verified · 1 Parent(s): 8d2d415

Update app.py

Files changed (1)
  1. app.py +23 -17
app.py CHANGED
@@ -1,23 +1,31 @@
 import torch
 import gradio as gr
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+from peft import PeftModel, PeftConfig
 
-# Load model and tokenizer from Hugging Face Hub
-model_name = "shukdevdatta123/twitter-distilbert-base-uncased-sentiment-analysis-lora-text-classification"
-model = AutoModelForSequenceClassification.from_pretrained(model_name)
-tokenizer = AutoTokenizer.from_pretrained(model_name)
+# Load PEFT adapter config
+adapter_name = "shukdevdatta123/twitter-distilbert-base-uncased-sentiment-analysis-lora-text-classification"
+config = PeftConfig.from_pretrained(adapter_name)
 
-# Define the label mapping (binary sentiment: 0 = Negative, 1 = Positive)
-id2label = {
-    0: "Negative",
-    1: "Positive"
-}
+# Load base model and tokenizer
+base_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)
+tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
+
+# Load the LoRA adapter into the base model
+model = PeftModel.from_pretrained(base_model, adapter_name)
 
 # Set device
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 model.to(device)
+model.eval()
+
+# Label mapping
+id2label = {
+    0: "Negative",
+    1: "Positive"
+}
 
-# Define the prediction function
+# Prediction function
 def predict_sentiment(text):
     inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
 
@@ -26,21 +34,19 @@ def predict_sentiment(text):
 
     predicted_class = torch.argmax(logits, dim=1).item()
     label = id2label[predicted_class]
-
-    # Optional: add confidence score
+
     probs = torch.nn.functional.softmax(logits, dim=1)
     confidence = probs[0][predicted_class].item()
 
     return f"{label} (Confidence: {confidence:.2f})"
 
-# Create Gradio Interface
+# Gradio UI
 interface = gr.Interface(
     fn=predict_sentiment,
     inputs=gr.Textbox(lines=2, placeholder="Enter a sentence to analyze sentiment..."),
     outputs="text",
-    title="Twitter Sentiment Classifier",
-    description="This app uses a fine-tuned DistilBERT model with LoRA adapters to predict whether a tweet or sentence is Positive or Negative."
+    title="Twitter Sentiment Classifier (LoRA + DistilBERT)",
+    description="This app uses a DistilBERT model with LoRA adapters to classify tweet sentiment as Positive or Negative."
 )
 
-# Launch the app
 interface.launch()
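
For quick local verification of the new loading path, below is a minimal sketch (not part of this commit) that rebuilds the model the same way the updated app.py does and runs a single prediction without Gradio. It assumes torch, transformers, and peft are installed; the merge_and_unload() step is an optional extra that only applies if the adapter is a LoRA adapter, as the repo name suggests, and folds the adapter weights back into the base model.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftConfig, PeftModel

# Same adapter repo used in app.py
adapter_name = "shukdevdatta123/twitter-distilbert-base-uncased-sentiment-analysis-lora-text-classification"
config = PeftConfig.from_pretrained(adapter_name)

# Rebuild the model exactly as app.py now does: base checkpoint + LoRA adapter
base_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(base_model, adapter_name)

# Optional (assumption, not in the commit): merge the LoRA weights into the base
# model so inference runs on a plain transformers model without the PEFT wrapper.
model = model.merge_and_unload()
model.eval()

# One prediction, mirroring predict_sentiment() without the Gradio layer
inputs = tokenizer("What a great day!", return_tensors="pt", truncation=True, padding=True)
with torch.no_grad():
    logits = model(**inputs).logits
pred = logits.argmax(dim=-1).item()
confidence = logits.softmax(dim=-1)[0, pred].item()
label = {0: "Negative", 1: "Positive"}[pred]
print(f"{label} (Confidence: {confidence:.2f})")

If the merge were adopted in app.py itself, the rest of the script (device placement, predict_sentiment, the Gradio interface) should not need to change, since merge_and_unload() returns a regular transformers model.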