Spaces:
Sleeping
Sleeping
File size: 3,383 Bytes
d5fc19a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import streamlit as st
import torch
import pandas as pd
from transformers import DistilBertTokenizer, DistilBertConfig, DistilBertModel
from .torch_primitives import PaperClassifierV1, PaperClassifierDatasetV1
@st.cache_resource
def load_everything(checkpoint_path='best_model.pt'):
    """Build the classifier, tokenizer and device once per Streamlit process.

    Wrapped in ``st.cache_resource`` so script reruns reuse the same objects
    instead of rebuilding the model and reloading the checkpoint each time.

    Args:
        checkpoint_path: Path to the saved ``state_dict`` of the classifier.
            Defaults to ``'best_model.pt'``; a relative path is resolved
            against the current working directory.

    Returns:
        tuple: ``(model, tokenizer, device)``. The model is already moved to
        ``device`` and switched to eval mode.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

    # The backbone is built from a default DistilBertConfig instead of being
    # downloaded; every weight is then overwritten by the checkpoint below, so
    # pretrained backbone weights are not needed here.
    # NOTE(review): presumably PaperClassifierV1.__init__ downloads pretrained
    # DistilBERT weights (which fails on the author's machine) — confirm.
    # Assumes the default DistilBertConfig matches the architecture the
    # checkpoint was trained with.
    class EmptyPaperClassifier(PaperClassifierV1):
        """PaperClassifierV1 shell whose weights come only from the checkpoint."""

        def __init__(self, n_classes):
            # Deliberately skip PaperClassifierV1.__init__: super() with an
            # explicit PaperClassifierV1 starts the MRO lookup *after* it, so
            # only the next base class is initialised. The `backbone` and
            # `head` attributes are recreated here so the checkpoint's
            # state_dict keys have something to load into.
            super(PaperClassifierV1, self).__init__()
            self.backbone = DistilBertModel(DistilBertConfig())
            self.head = torch.nn.Linear(
                in_features=self.backbone.config.hidden_size,
                out_features=n_classes,
            )

    model = EmptyPaperClassifier(n_classes=len(PaperClassifierDatasetV1.MAJORS))
    # torch.load uses pickle-based deserialisation: only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # inference mode (e.g. disables dropout)
    return model, tokenizer, device
def classify_paper(title, abstract, model, tokenizer, device, max_length: int = 512) -> pd.DataFrame:
    """Score a paper against every category in ``PaperClassifierDatasetV1.MAJORS``.

    When the abstract is missing or blank, only the title is encoded.
    Otherwise the title and abstract are encoded together as a sequence pair.

    Args:
        title: Paper title.
        abstract: Paper abstract. ``None``, ``""`` or whitespace-only all mean
            "no abstract".
        model: Classifier (as returned by ``load_everything``). It is called
            with the tokenizer's output as keyword arguments.
        tokenizer: Hugging Face tokenizer matching ``model``.
        device: ``torch.device`` the model lives on.
        max_length: Maximum number of tokens; longer inputs are truncated.
            Defaults to 512.

    Returns:
        DataFrame with columns ``Category`` and ``Probability``, sorted by
        probability, highest first. Probabilities are independent sigmoid
        scores, so they need not sum to 1.
    """
    has_abstract = abstract is not None and abstract.strip() != ""
    # text_pair=None encodes the title on its own; a string encodes a
    # title/abstract pair. Either way return_tensors='pt' yields a batch of one.
    inputs = tokenizer(
        title,
        text_pair=abstract if has_abstract else None,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs)
    # Sigmoid rather than softmax: each category is scored independently,
    # so several categories can pass the caller's threshold at once.
    probabilities = torch.sigmoid(logits).cpu().numpy()[0]

    return pd.DataFrame({
        'Category': PaperClassifierDatasetV1.MAJORS,
        'Probability': probabilities,
    }).sort_values('Probability', ascending=False)
def main(threshold: float = 0.5):
    """Render the Streamlit UI for classifying an arXiv paper.

    Args:
        threshold: A category is reported as predicted when its probability
            is strictly greater than this value. Defaults to 0.5.
    """
    # Configure the page before drawing anything else.
    st.set_page_config(page_title="ArXiv Paper Classifier", page_icon="🦈")
    st.title("ArXiv Paper Classifier")

    model, tokenizer, device = load_everything()

    col1, col2 = st.columns([1, 1])
    with col1:
        title = st.text_area("Title", height=200, placeholder="Enter paper title here...", )
    with col2:
        abstract = st.text_area("Abstract (optional)", height=200, placeholder="Enter paper abstract here...")

    if not st.button("Classify", type='primary', use_container_width=True):
        return
    # Reject whitespace-only titles too; they would otherwise reach the model
    # as an effectively empty input.
    if not title or not title.strip():
        st.error("Please enter a paper title")
        return

    with st.spinner('In progress...'):
        results = classify_paper(title, abstract, model, tokenizer, device)

    st.subheader("Results")
    predicted = results.loc[results['Probability'] > threshold, 'Category'].tolist()
    # Format as percentages only after thresholding, which needs numeric values.
    results['Probability'] = results['Probability'].apply(lambda x: f"{x:.2%}")
    if not predicted:
        st.info("Hmm, I am not sure about this one.")
    else:
        st.success(f"Predicted categories: {', '.join(predicted)}")
    with st.expander("Show details"):
        st.dataframe(results, use_container_width=True, hide_index=True)
        st.caption("All categories with their confidence scores")
if __name__ == "__main__":
    # Launch the app with the default decision threshold when run directly.
    # NOTE(review): the file uses a relative import (`.torch_primitives`), so
    # direct script execution may fail without package context; presumably a
    # launcher imports this module and calls main() — confirm.
    main()
|