# Streamlit demo: BERT text classification via adapters.
import streamlit as st
from transformers import AutoModelWithHeads, AutoTokenizer
import torch
# Page title for the Streamlit app.
st.title("Adapter Transformers for Text Classification")
@st.cache_resource
def load_model():
    """Load BERT with a fresh adapter and a binary classification head.

    Cached by Streamlit so the model is built once per server process.

    Returns:
        tuple: (model, tokenizer) — the adapter-equipped BERT model in
        eval mode and its matching tokenizer.

    NOTE(review): the adapter and head are randomly initialized and never
    trained here, so predictions are effectively arbitrary until real
    adapter weights are loaded (e.g. via ``model.load_adapter(...)``).
    """
    # NOTE(review): AutoModelWithHeads is deprecated in newer adapter
    # libraries; AutoAdapterModel (from the `adapters` package) is the
    # modern replacement — confirm which library version is pinned.
    model = AutoModelWithHeads.from_pretrained("bert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Register an adapter plus a matching 2-label classification head,
    # then make them the active forward path.
    adapter_name = "my_adapter"
    model.add_adapter(adapter_name)
    model.add_classification_head(adapter_name, num_labels=2)
    model.set_active_adapters(adapter_name)

    # Inference-only app: the original train_adapter() call left the model
    # in training mode (dropout active → non-deterministic predictions).
    # No training happens here, so switch to eval mode instead.
    model.eval()
    return model, tokenizer
# Build (or fetch the cached) model/tokenizer pair once at startup.
model, tokenizer = load_model()
# Text box with a default example; re-runs the script on every edit.
input_text = st.text_input("Enter text for classification:", "Steve Jobs founded Apple")
if input_text:
    # Tokenize, capping inputs at BERT's 512-token limit.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )

    # Forward pass without gradient tracking (inference only).
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = logits.argmax(dim=-1).item()

    # Map class index 0 -> Negative, anything else -> Positive, and show it.
    label = "Negative" if predicted_class == 0 else "Positive"
    st.write(f"Prediction: {label}")
|