# Finance-chatbot / app.py
import gradio as gr
from datasets import load_dataset
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

# Load model and tokenizer
model_name = "NinaMwangi/T5_finbot"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)

# Load the QA dataset used for exact-match context retrieval
dataset = load_dataset("virattt/financial-qa-10K")["train"]


# Retrieve the stored context whose question exactly matches the user's question
def get_context_for_question(question):
    for item in dataset:
        if item["question"].strip().lower() == question.strip().lower():
            return item["context"]
    return "No relevant context found."


# Build the prompt, run beam-search generation, and decode the answer
def generate_answer(question):
    context = get_context_for_question(question)
    prompt = f"Q: {question} Context: {context} A:"
    inputs = tokenizer(
        prompt,
        return_tensors="tf",
        padding="max_length",
        truncation=True,
        max_length=256,
    )
    outputs = model.generate(
        **inputs,
        max_new_tokens=64,
        num_beams=4,
        early_stopping=True,
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    return answer
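
# Quick sanity check outside the UI (assumption: run locally before launching;
# the question must match a dataset entry verbatim for context retrieval to
# return a passage, otherwise the model answers without grounding context):
#   print(generate_answer("<a question copied verbatim from the dataset>"))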


# Interface
interface = gr.Interface(
    fn=generate_answer,
    inputs=gr.Textbox(lines=2, placeholder="Ask a finance question..."),
    outputs="text",
    title="Finance QA Chatbot",
    description="Built using a fine-tuned T5 Transformer. Ask a finance-related question and get an accurate, concise answer.",
)
interface.launch()