Dorothy Oduor committed
Commit a65f9fb · 1 Parent(s): 21312ac

file uploads

Files changed (1)
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
+ import os
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import torch
+
+
+ # Use a BERT model already fine-tuned for sentiment analysis
+ # (bert-base-uncased fine-tuned on the SST-2 dataset)
+ model_name = "textattack/bert-base-uncased-SST-2"
+
+ # Load model and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
+
+ # Device configuration
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = model.to(device)
+
+
+ def analyze_sentiment(text):
+     # Tokenize input
+     inputs = tokenizer(text,
+                        return_tensors="pt",
+                        truncation=True,
+                        padding=True,
+                        max_length=512).to(device)
+
+     # Get predictions
+     with torch.no_grad():
+         outputs = model(**inputs)
+         probs = torch.softmax(outputs.logits, dim=1)
+
+     # Process results
+     labels = ["NEGATIVE", "POSITIVE"]
+     confidence, pred_class = torch.max(probs, dim=1)
+
+     return {
+         "text": text,
+         "prediction": labels[pred_class.item()],
+         "confidence": confidence.item(),
+         "probabilities": dict(zip(labels, probs.tolist()[0]))
+     }
+
+
+ # Example usage
+ text = "Hugging Face is amazing!"
+ result = analyze_sentiment(text)
+
+ print(f"\nInput: {result['text']}")
+ print(f"Model: {model_name}")
+ print(f"Prediction: {result['prediction']} ({result['confidence']:.2%})")
+ print("Probabilities:")
+ for label, prob in result['probabilities'].items():
+     print(f"  {label}: {prob:.4f}")
+
+ # Additional tokenization info
+ print("\nTokenization details:")
+ tokens = tokenizer.tokenize(text)
+ print(f"Tokens: {tokens}")
+ print(f"Token IDs: {tokenizer.encode(text)}")