File size: 1,749 Bytes
bf2494e
 
 
aad2b57
bf2494e
 
 
aad2b57
 
bf2494e
 
 
aad2b57
 
 
 
 
 
 
 
bf2494e
 
 
 
 
aad2b57
 
bf2494e
 
 
aad2b57
bf2494e
 
 
 
 
aad2b57
bf2494e
 
 
 
aad2b57
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import streamlit as st
import numpy as np
import torch
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering

def _model_cache():
    """Return Streamlit's decorator for caching shared, unhashable resources.

    ``st.cache`` is deprecated (since Streamlit 1.18) in favour of
    ``st.cache_resource``; fall back to the old API only on versions that
    lack the new one.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource
    return st.cache(allow_output_mutation=True)


@_model_cache()
def load_model(model_name="distilbert-base-uncased-distilled-squad"):
    """Load an extractive question-answering DistilBERT model and its tokenizer.

    The result is cached by Streamlit, so the weights are loaded once rather
    than on every script rerun.

    Args:
        model_name: Hugging Face checkpoint used for both model and tokenizer.
            Defaults to the SQuAD-distilled DistilBERT checkpoint.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = DistilBertForQuestionAnswering.from_pretrained(model_name)
    tokenizer = DistilBertTokenizer.from_pretrained(model_name)
    return model, tokenizer

def get_answer(question, text, tokenizer, model, max_answer_len=None):
    """Extract the answer to *question* from *text* with an extractive QA model.

    The model scores every token as a possible answer start and end. The span
    is chosen jointly (highest ``start + end`` logit) among spans that lie
    inside *text* and end at or after their start. Picking start and end
    independently can yield an inverted (empty) span or one pointing into the
    question.

    Args:
        question: The question string.
        text: The context passage to search.
        tokenizer: Tokenizer called as ``tokenizer(question, text, ...)``; its
            ``sep_token_id`` marks the question/context boundary.
        model: QA model whose output has ``start_logits`` and ``end_logits``.
        max_answer_len: Optional positive cap on the answer length in tokens;
            ``None`` (the default) allows any length.

    Returns:
        The decoded answer string, or ``""`` when the (possibly truncated)
        context contains no tokens.
    """
    inputs = tokenizer(question, text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)

    # A single question/context pair is encoded, so use the only batch row.
    input_ids = inputs["input_ids"][0]
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # Assumes the BERT-style pair layout [CLS] question [SEP] context [SEP]:
    # the context lies strictly between the first and the last separator.
    sep_positions = (input_ids == tokenizer.sep_token_id).nonzero(as_tuple=True)[0].tolist()
    ctx_start = sep_positions[0] + 1 if sep_positions else 0
    ctx_end = sep_positions[-1] if len(sep_positions) > 1 else len(input_ids)
    n = ctx_end - ctx_start
    if n <= 0:
        return ""

    # scores[i, j] is the score of the span from context token i to token j.
    scores = start_logits[ctx_start:ctx_end, None] + end_logits[None, ctx_start:ctx_end]
    positions = torch.arange(n)
    span_len = positions[None, :] - positions[:, None] + 1
    valid = span_len >= 1  # end must not precede start
    if max_answer_len is not None:
        valid &= span_len <= max_answer_len
    scores = scores.masked_fill(~valid, float("-inf"))

    best_start, best_end = divmod(int(torch.argmax(scores)), n)
    ans_tokens = input_ids[ctx_start + best_start: ctx_start + best_end + 1]
    return tokenizer.decode(ans_tokens, skip_special_tokens=True)

def main():
    """Render the question-answering page and answer submitted questions.

    The model and tokenizer come from ``load_model`` (cached by Streamlit).
    When the form is submitted, the answer, or a fallback message, replaces
    the "thinking" placeholder.
    """
    st.set_page_config(page_title="Question Answering Tool", page_icon=":mag_right:")

    st.write("# Question Answering Tool \n"
             "This tool will help you find answers to your questions about the text you provide. \n"
             "Please enter your question and the text you want to search in the boxes below.")
    model, tokenizer = load_model()

    with st.form("qa_form"):
        text = st.text_area("Enter your text here")
        question = st.text_input("Enter your question here")

        if st.form_submit_button("Submit"):
            # Running the model on blank input produces a meaningless span,
            # so ask for both fields instead of calling it.
            if not text.strip() or not question.strip():
                st.warning("Please enter both a text and a question.")
            else:
                data_load_state = st.text('Let me think about that...')
                answer = get_answer(question, text, tokenizer, model)
                if not answer.strip():
                    data_load_state.text("Sorry but I don't know the answer to that question")
                else:
                    data_load_state.text(answer)

# `streamlit run` executes this script as __main__, so the app still renders;
# merely importing the module (e.g. from tests) no longer launches the UI.
if __name__ == "__main__":
    main()