amasood commited on
Commit
bd5c18d
·
verified ·
1 Parent(s): e3f3424

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -0
app.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import AutoModelWithHeads, AutoTokenizer
3
+ import torch
4
+
5
+ # Load pre-trained BERT model with adapter support
6
+ st.title("Adapter Transformers for Text Classification")
7
+
8
+ @st.cache_resource
9
+ def load_model():
10
+ model = AutoModelWithHeads.from_pretrained("bert-base-uncased")
11
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
12
+
13
+ # Add and activate an adapter
14
+ adapter_name = "my_adapter"
15
+ model.add_adapter(adapter_name)
16
+ model.train_adapter(adapter_name)
17
+ model.set_active_adapters(adapter_name)
18
+
19
+ # Add a classification head (binary classification)
20
+ model.add_classification_head(adapter_name, num_labels=2)
21
+ return model, tokenizer
22
+
23
+ # Load the model
24
+ model, tokenizer = load_model()
25
+
26
+ # Streamlit input
27
+ input_text = st.text_input("Enter text for classification:", "Steve Jobs founded Apple")
28
+
29
+ if input_text:
30
+ # Tokenize the input
31
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True, max_length=512)
32
+
33
+ # Make the prediction
34
+ with torch.no_grad():
35
+ outputs = model(**inputs)
36
+ logits = outputs.logits
37
+ predicted_class = logits.argmax(dim=-1).item()
38
+
39
+ # Display the prediction
40
+ if predicted_class == 0:
41
+ st.write("Prediction: Negative")
42
+ else:
43
+ st.write("Prediction: Positive")