File size: 1,947 Bytes
01fd84a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import gradio as gr
from transformers import pipeline, RobertaTokenizer, RobertaForQuestionAnswering
import torch

# Identifier handed to from_pretrained() for both the tokenizer and the
# question-answering weights.
model_name = "AventIQ-AI/roberta-chatbot"

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = RobertaTokenizer.from_pretrained(model_name)
model = RobertaForQuestionAnswering.from_pretrained(model_name).to(device)

# Wrap model + tokenizer in a question-answering pipeline; the pipeline takes
# its device as an index (0 = first GPU, -1 = CPU).
qa_pipeline = pipeline(
    "question-answering",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)

# Define the function for the Gradio interface
def roberta_chatbot(context: str, question: str) -> str:
    """Answer ``question`` from ``context`` using the module-level QA pipeline.

    Args:
        context: Passage the answer should be extracted from.
        question: Question to ask about the passage.

    Returns:
        The ``answer`` field of the pipeline result; a prompt asking for both
        inputs when either one is missing or blank; or an apology message
        when the pipeline result carries no usable answer.
    """
    # Treat None, "" and whitespace-only input alike: a blank textbox gives
    # the model nothing to work with, so do not forward it to the pipeline.
    if not (context or "").strip() or not (question or "").strip():
        return "Please provide both context and a question."

    # The inputs are forwarded unmodified so the model sees exactly what the
    # user typed.
    result = qa_pipeline(question=question, context=context)

    # Check the value rather than relying on a .get() default, so that a
    # present-but-empty answer also triggers the fallback message.
    answer = result.get("answer")
    if not answer or not answer.strip():
        return "Sorry, I could not find an answer."
    return answer

# Create the Gradio interface
iface = gr.Interface(
    # Called with the two textbox values, in order: (context, question).
    fn=roberta_chatbot,
    inputs=[
        gr.Textbox(label="📄 Context", placeholder="Enter the context here...", lines=5),
        gr.Textbox(label="❓ Question", placeholder="Enter your question here...", lines=2)
    ],
    outputs=gr.Textbox(label="🤖 Answer"),
    title="🧠 RoBERTa-Powered Chatbot",
    description="Provide a context and ask a question. The RoBERTa-based chatbot will find the answer based on the given context.",
    # Each example is a [context, question] pair matching the inputs above.
    examples=[
        ["Flight AI101 departs from New York at 10:00 AM and arrives in San Francisco at 1:30 PM. The flight duration is 5 hours and 30 minutes.", "What is the duration of Flight AI101?"],
        ["The Great Wall of China was built over several centuries to protect China's northern borders.", "Why was the Great Wall of China built?"]
    ],
    # NOTE(review): `theme="compact"` and `allow_flagging` depend on the
    # Gradio release in use -- confirm both are still accepted by the
    # version this app is deployed with.
    theme="compact",
    allow_flagging="never"
)

# Start the Gradio app only when this file is executed as a script, so that
# importing the module does not launch the interface.
if __name__ == "__main__":
    iface.launch()