Amit Kumar commited on
Commit
cfde818
·
1 Parent(s): 5cc802c

initial commit

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. BERT_sentiment_analysis.pth +3 -0
  3. app.py +65 -0
  4. requirements.txt +5 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ BERT_sentiment_analysis.pth filter=lfs diff=lfs merge=lfs -text
BERT_sentiment_analysis.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7336efc4b2f17e373689bf5cef49d7a9c8eee0288c98a7179be8fdacc0297316
3
+ size 267884656
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### 1. Imports and class names setup ###
2
+ import gradio as gr
3
+ import os
4
+ import torch
5
+
6
+ from transformers import DistilBertTokenizerFast
7
+ from timeit import default_timer as timer
8
+
9
+ # Setup class names
10
+ class_names = ["Positive", "Negative",]
11
+
12
+ ### 2. Load the model ###
13
+
14
+ model = torch.load(f="bert_sentiment_classifier/BERT_sentiment_analysis.pth",
15
+ map_location=torch.device("cpu")) # load to CPU)
16
+
17
+ ### 3. Predict function ###
18
+
19
+ # Create predict function
20
+ def predict(text: str):
21
+ """Transforms and performs a prediction on img and returns prediction and time taken.
22
+ """
23
+ # Start the timer
24
+ start_time = timer()
25
+
26
+ tokenizer = DistilBertTokenizerFast.from_pretrained(
27
+ 'distilbert-base-uncased'
28
+ )
29
+
30
+ input = tokenizer(text, return_tensors="pt").to(DEVICE)
31
+
32
+ model.eval()
33
+ with torch.inference_mode():
34
+
35
+ logits = model(**input).logits
36
+ predicted_class_id = logits.argmax().item()
37
+
38
+ if predicted_class_id == 1:
39
+ result = "Positive 😊"
40
+ else:
41
+ result = "Negative 🙁"
42
+
43
+ # Calculate the prediction time
44
+ pred_time = round(timer() - start_time, 5)
45
+
46
+ # Return the prediction dictionary and prediction time
47
+ return result
48
+
49
+ ### 4. Gradio app ###
50
+
51
+ # Create title, description and article strings
52
+ title = "Sentiment Classifier"
53
+ description = "A Sentiment Classifier trained by fine-tuning [DistilBert](https://huggingface.co/docs/transformers/v4.42.0/en/model_doc/distilbert#transformers.DistilBertForSequenceClassification) Transformer model using hugging face [transformers](https://huggingface.co/docs/transformers/en/index) library."
54
+ article = "The model classifies sentiment of an input text (whether the text shows a positive or negative sentiment)."
55
+
56
+ #Create the Gradio demo
57
+ demo = gr.Interface(fn=predict, # mapping function from input to output
58
+ inputs=[gr.Textbox(label="Input")],
59
+ outputs=[gr.Label(label="Prediction")],
60
+ title=title,
61
+ description=description,
62
+ article=article)
63
+
64
+ # Launch the demo!
65
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ gradio
3
+ pandas
4
+ transformers
5
+ torchtext