# Finance QA chatbot — Hugging Face Spaces app (Gradio + fine-tuned T5).
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# Fine-tuned T5 seq2seq model for financial question answering.
model_name = "NinaMwangi/T5_finbot"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

# 10-K financial QA pairs; used to look up context for an incoming question.
dataset = load_dataset("virattt/financial-qa-10K")["train"]
# Function to retrieve context
def get_context_for_question(question, records=None):
    """Return the stored context whose question matches *question*.

    Matching is case-insensitive and ignores leading/trailing whitespace.

    Args:
        question: The user's question text.
        records: Optional iterable of mappings with "question" and
            "context" keys. Defaults to the module-level ``dataset``.

    Returns:
        The matching record's context string, or the fallback message
        "No relevant context found." when nothing matches.
    """
    if records is None:
        records = dataset
    # Normalize the query once instead of on every iteration.
    needle = question.strip().lower()
    for item in records:
        if item["question"].strip().lower() == needle:
            return item["context"]
    return "No relevant context found."
# Predict function
def generate_answer(question):
    """Answer a finance question with the fine-tuned T5 model.

    Looks up the dataset context matching the question, builds a
    "Q: ... Context: ... A:" prompt, and decodes a beam-search
    generation into a plain-text answer.
    """
    context = get_context_for_question(question)
    prompt = f"Q: {question} Context: {context} A:"
    # Tokenize to fixed-length TF tensors for the seq2seq model.
    encoded = tokenizer(
        prompt,
        return_tensors="tf",
        truncation=True,
        padding="max_length",
        max_length=256,
    )
    # Beam search (4 beams) with early stopping, capped at 64 new tokens.
    generated = model.generate(
        **encoded,
        num_beams=4,
        max_new_tokens=64,
        early_stopping=True,
    )
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return decoded.strip()
# Interface
# Single-textbox Gradio UI: question in, plain-text answer out.
interface = gr.Interface(
    fn=generate_answer,
    inputs=gr.Textbox(lines=2, placeholder="Ask a finance question..."),
    outputs="text",
    title="Finance QA Chatbot",
    description=(
        "Built using a fine-tuned T5 Transformer. Ask a finance-related "
        "question and get an accurate, concise answer."
    ),
)

interface.launch()