File size: 3,280 Bytes
613b160
 
da51d36
8eecde5
 
da51d36
f11da2a
da51d36
f11da2a
 
 
da51d36
f11da2a
 
 
 
 
da51d36
f11da2a
 
 
 
8eecde5
f11da2a
 
3da047b
 
 
 
f11da2a
 
 
 
 
 
da51d36
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoModelWithLMHead, AutoTokenizer

# Hugging Face Hub id of the T5 question-generation checkpoint (per its Hub
# name); both the tokenizer and the model are loaded from it.
MODEL_NAME = "mrm8488/t5-base-finetuned-question-generation-ap"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# T5 is an encoder-decoder model, so it is loaded through AutoModelForSeq2SeqLM;
# AutoModelWithLMHead is deprecated in transformers for this purpose.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

def _extract_question(decoded):
    """Return the bare question text from a decoded model generation.

    Strips surrounding whitespace and, when present, a leading ``"question:"``
    label. Text without the label is returned stripped but otherwise intact.

    Args:
        decoded: Text from ``tokenizer.decode(..., skip_special_tokens=True)``.

    Returns:
        The question string without the label.
    """
    text = decoded.strip()
    label = "question:"
    if text.startswith(label):
        text = text[len(label):]
    return text.strip()


def get_questions(input_data, max_length=64):
    """Generate one question per answer span with the loaded T5 model.

    Args:
        input_data: Sequence whose first element is the context passage and
            whose remaining elements are answer strings for that context.
        max_length: Maximum generation length passed to ``model.generate``.

    Returns:
        A list of generated question strings, one per non-blank answer, in
        input order. Blank answers (``None``, ``""`` or whitespace only) are
        skipped rather than sent to the model.
    """
    context = input_data[0]
    answers = input_data[1:]
    generated_questions = []

    for answer in answers:
        # Unfilled answer textboxes arrive as empty strings; asking the model
        # for a question about an empty answer only yields noise.
        if answer is None or not str(answer).strip():
            continue

        # Prompt template used with this checkpoint (kept verbatim).
        input_text = "answer: %s  context: %s </s>" % (answer, context)
        features = tokenizer([input_text], return_tensors='pt')

        output = model.generate(
            input_ids=features['input_ids'],
            attention_mask=features['attention_mask'],
            max_length=max_length,
        )

        # Drop special tokens (<pad>, </s>) via the tokenizer rather than
        # slicing fixed character offsets, which cut real text whenever
        # generation stopped at max_length before emitting </s>.
        decoded = tokenizer.decode(output[0], skip_special_tokens=True)
        generated_questions.append(_extract_question(decoded))

    return generated_questions

# Context passages shown in the Gradio examples panel.
_UZBEKISTAN_CONTEXT = "Uzbekistan is a Central Asian nation and former Soviet republic. It's known for its mosques, mausoleums and other sites linked to the Silk Road, the ancient trade route between China and the Mediterranean. Samarkand, a major city on the route, contains a landmark of Islamic architecture."
_REEF_CONTEXT = "The Great Barrier Reef is the world's largest coral reef system composed of over 2,900 individual reefs and 900 islands stretching for over 2,300 kilometers. The reef is located in the Coral Sea, off the coast of Australia's state of Queensland."

# Each row is [context, answer]: it supplies the context input and the first
# answer input only.
examples = [
    [_UZBEKISTAN_CONTEXT, "Silk Road"],
    [_REEF_CONTEXT, "Great Barrier Reef"],
]

def generate_questions(context, *answers):
    """Gradio callback: generate questions for a context and its answers.

    Args:
        context: Text from the context input.
        *answers: Values of the answer inputs, in order.

    Returns:
        The result of ``get_questions`` for ``[context, *answers]``.
    """
    return get_questions([context, *answers])

# One multi-line box for the context passage plus three answer boxes; their
# values are passed positionally to generate_questions(context, *answers).
inputs = [
    gr.Textbox(lines=3, placeholder="Enter context here", label="Input - Context"),
    gr.Textbox(lines=1, label="Input - Answer 1"),
    gr.Textbox(lines=1, label="Input - Answer 2"),
    gr.Textbox(lines=1, label="Input - Answer 3"),
]

outputs = gr.Textbox(lines=5, label="Output - Generated Questions")

# Custom styling: hides .output-markdown elements and restyles the primary
# button and its hover state.
css = """ .output-markdown{display:none !important} .gr-button-primary { z-index: 14; height: 43px; width: 130px; left: 0px; top: 0px; padding: 0px; cursor: pointer !important; background: none rgb(17, 20, 45) !important; border: none !important; text-align: center !important; font-family: Poppins !important; font-size: 14px !important; font-weight: 500 !important; color: rgb(255, 255, 255) !important; line-height: 1 !important; border-radius: 12px !important; transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important; box-shadow: none !important; } .gr-button-primary:hover{ z-index: 14; height: 43px; width: 130px; left: 0px; top: 0px; padding: 0px; cursor: pointer !important; background: none rgb(37, 56, 133) !important; border: none !important; text-align: center !important; font-family: Poppins !important; font-size: 14px !important; font-weight: 500 !important; color: rgb(255, 255, 255) !important; line-height: 1 !important; border-radius: 12px !important; transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important; box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important; } .hover:bg-orange-50:hover { --tw-bg-opacity: 1 !important; background-color: rgb(229,225,255) !important; } """

# The interface is built at import time so `demo` is available to tools that
# import this module; the server only starts when run as a script.
demo = gr.Interface(
    fn=generate_questions,
    inputs=inputs,
    outputs=outputs,
    title="Question Generator | Data Science Dojo",
    examples=examples,
    css=css,
)

if __name__ == "__main__":
    demo.launch()