Spaces:
Runtime error
Runtime error
Commit
·
e4296b4
1
Parent(s):
ad42ce4
Fix model combination
Browse files
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
from datasets import load_dataset
|
| 3 |
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
|
|
@@ -34,15 +35,20 @@ def predict(model_abstract, model_claims, tokenizer_abstract, tokenizer_claims,
|
|
| 34 |
|
| 35 |
abstract, claims = input['abstract'], input['claims']
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
with torch.no_grad():
|
| 41 |
-
outputs_abstract = model_abstract(
|
| 42 |
-
outputs_claims = model_claims(
|
| 43 |
|
| 44 |
combined_prob = (outputs_abstract.logits.softmax(dim=1) + outputs_claims.logits.softmax(dim=1)) / 2
|
| 45 |
-
label = torch.argmax(combined_prob,
|
| 46 |
|
| 47 |
return label, combined_prob
|
| 48 |
|
|
@@ -53,7 +59,7 @@ if __name__ == '__main__':
|
|
| 53 |
form = st.form('patent-prediction-form')
|
| 54 |
dropdown = [example['patent_number'] for example in dataset]
|
| 55 |
|
| 56 |
-
input_application = form.selectbox('Select a patent\'s application number',
|
| 57 |
submit = form.form_submit_button("Submit")
|
| 58 |
|
| 59 |
if submit:
|
|
@@ -62,6 +68,6 @@ if __name__ == '__main__':
|
|
| 62 |
label, prob = predict(model_abstract, model_claims, tokenizer_abstract, tokenizer_claims, input)
|
| 63 |
|
| 64 |
st.write(label)
|
| 65 |
-
st.write(
|
| 66 |
st.write(input['decision'])
|
| 67 |
|
|
|
|
| 1 |
+
import torch
|
| 2 |
import streamlit as st
|
| 3 |
from datasets import load_dataset
|
| 4 |
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer
|
|
|
|
| 35 |
|
| 36 |
abstract, claims = input['abstract'], input['claims']
|
| 37 |
|
| 38 |
+
encoding_abstract = tokenizer_abstract(abstract, return_tensors='pt', truncation=True, padding='max_length')
|
| 39 |
+
encoding_claims = tokenizer_claims(claims, return_tensors='pt', truncation=True, padding='max_length')
|
| 40 |
+
|
| 41 |
+
input_abstract = encoding_abstract['input_ids'].to(device)
|
| 42 |
+
attention_mask_abstract = encoding_abstract['attention_mask'].to(device)
|
| 43 |
+
input_claims = encoding_claims['input_ids'].to(device)
|
| 44 |
+
attention_mask_claims = encoding_claims['attention_mask'].to(device)
|
| 45 |
|
| 46 |
with torch.no_grad():
|
| 47 |
+
outputs_abstract = model_abstract(input_ids=input_abstract, attention_mask=attention_mask_abstract)
|
| 48 |
+
outputs_claims = model_claims(input_ids=input_claims, attention_mask=attention_mask_claims)
|
| 49 |
|
| 50 |
combined_prob = (outputs_abstract.logits.softmax(dim=1) + outputs_claims.logits.softmax(dim=1)) / 2
|
| 51 |
+
label = torch.argmax(combined_prob, axis=1).flatten()
|
| 52 |
|
| 53 |
return label, combined_prob
|
| 54 |
|
|
|
|
| 59 |
form = st.form('patent-prediction-form')
|
| 60 |
dropdown = [example['patent_number'] for example in dataset]
|
| 61 |
|
| 62 |
+
input_application = form.selectbox('Select a patent\'s application number', dropdown)
|
| 63 |
submit = form.form_submit_button("Submit")
|
| 64 |
|
| 65 |
if submit:
|
|
|
|
| 68 |
label, prob = predict(model_abstract, model_claims, tokenizer_abstract, tokenizer_claims, input)
|
| 69 |
|
| 70 |
st.write(label)
|
| 71 |
+
st.write(prob)
|
| 72 |
st.write(input['decision'])
|
| 73 |
|