# Spaces: Runtime error
import gradio as gr | |
from transformers import AutoModelForQuestionAnswering, AutoTokenizer | |
import torch | |
import torch.nn.functional as F | |
# Checkpoint identifier handed to both ``from_pretrained`` calls below, so the
# tokenizer and the question-answering model always come from the same source.
MODEL_NAME = "S-Dreamer/raft-qa-space"

# Both objects are created once, at import time, and reused by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForQuestionAnswering.from_pretrained(MODEL_NAME)
def answer_question(context, question):
    """Extract the answer to ``question`` from ``context`` with the QA model.

    The context is split into overlapping windows (at most 512 tokens each,
    128 tokens of overlap) so that long passages are fully covered. Every
    window is scored, and the best ``(start, end)`` span over all windows is
    decoded and returned.

    Args:
        context: Background text that may contain the answer.
        question: The question to answer from ``context``.

    Returns:
        The decoded answer text, or ``"No answer found."`` when either input
        is blank or the best span decodes to an empty string.
    """
    no_answer = "No answer found."
    if not (context and context.strip() and question and question.strip()):
        return no_answer

    inputs = tokenizer(
        question,
        context,
        return_tensors="pt",
        # Only the context may be truncated/windowed; the question must stay
        # intact in every window.
        truncation="only_second",
        max_length=512,
        stride=128,
        return_overflowing_tokens=True,
        # Windows have different lengths (the last one is usually shorter);
        # without padding they cannot be stacked into a single tensor.
        padding=True,
    )
    # Bookkeeping emitted because of return_overflowing_tokens; it is not a
    # model input and the forward pass rejects it as an unexpected keyword.
    inputs.pop("overflow_to_sample_mapping", None)

    with torch.no_grad():
        outputs = model(**inputs)

    # Padding positions must never be picked as a span boundary.
    is_padding = inputs["attention_mask"] == 0
    neg_inf = float("-inf")
    start_scores = torch.log_softmax(
        outputs.start_logits.masked_fill(is_padding, neg_inf), dim=-1
    )
    end_scores = torch.log_softmax(
        outputs.end_logits.masked_fill(is_padding, neg_inf), dim=-1
    )

    # span_scores[c, s, e] = log P(start = s) + log P(end = e) in window c.
    span_scores = start_scores.unsqueeze(2) + end_scores.unsqueeze(1)
    seq_len = span_scores.size(-1)
    # Rule out spans whose end comes before their start (strictly lower
    # triangle of the start x end grid).
    end_before_start = torch.ones(
        seq_len, seq_len, dtype=torch.bool, device=span_scores.device
    ).tril(diagonal=-1)
    span_scores = span_scores.masked_fill(end_before_start, neg_inf)

    # argmax without ``dim`` indexes the flattened tensor; unravel it back to
    # (window, start, end) so the tokens are sliced from the right window.
    flat_best = int(torch.argmax(span_scores))
    chunk_idx, within_chunk = divmod(flat_best, seq_len * seq_len)
    start_idx, end_idx = divmod(within_chunk, seq_len)

    answer_ids = inputs["input_ids"][chunk_idx][start_idx:end_idx + 1]
    answer = tokenizer.decode(answer_ids, skip_special_tokens=True)
    return answer if answer.strip() else no_answer
# Build the web UI: two text inputs side by side, a read-only answer box and
# a button that runs ``answer_question`` on the current inputs.
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 RAFT: Retrieval-Augmented Fine-Tuning for QA")
    gr.Markdown("Ask a question based on the provided context and see how RAFT improves response accuracy!")

    with gr.Row():
        context_input = gr.Textbox(
            lines=5,
            label="Context",
            placeholder="Enter background text here...",
        )
        question_input = gr.Textbox(
            lines=2,
            label="Question",
            placeholder="What is the main idea?",
        )

    answer_output = gr.Textbox(label="Answer", interactive=False)
    submit_btn = gr.Button("Generate Answer")

    # Input order matches the (context, question) parameter order of the
    # callback.
    submit_btn.click(
        fn=answer_question,
        inputs=[context_input, question_input],
        outputs=answer_output,
    )

# Start the Gradio server.
demo.launch()