NinaMwangi committed on
Commit
6b5d959
·
verified ·
1 Parent(s): 1070f39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -40
app.py CHANGED
@@ -7,20 +7,17 @@ model_name = "NinaMwangi/T5_finbot"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
9
 
10
- # Load dataset for context retrieval
11
  dataset = load_dataset("virattt/financial-qa-10K")["train"]
12
 
13
- # Global chat history
14
- chat_history = []
15
-
16
- # Context lookup
17
def get_context_for_question(question):
    """Look up the supporting 10-K passage for an exact question match.

    Comparison is case-insensitive and ignores surrounding whitespace.
    Returns the first matching row's context, or a fallback message when
    no stored question matches.
    """
    needle = question.strip().lower()
    for row in dataset:
        if row["question"].strip().lower() == needle:
            return row["context"]
    return "No relevant context found."
22
 
23
- # Inference function
24
  def generate_answer(question):
25
  context = get_context_for_question(question)
26
  prompt = f"Q: {question} Context: {context} A:"
@@ -41,37 +38,15 @@ def generate_answer(question):
41
  )
42
 
43
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
44
- chat_history.append((question, answer))
45
- return "", chat_history
46
-
47
- # Clear history function
48
- def clear_chat():
49
- global chat_history
50
- chat_history = []
51
- return chat_history
52
-
53
- # Gradio UI
54
- with gr.Blocks(theme=gr.themes.Base()) as interface:
55
- gr.Markdown(
56
- """
57
- # 💬 Finance QA Chatbot
58
- Ask a finance-related question and get an accurate, concise response.
59
- Built using a fine-tuned T5 Transformer on financial Q&A data.
60
- """
61
- )
62
-
63
- chatbot = gr.Chatbot(label="Finance Chatbot", height=400, bubble_full_width=False)
64
-
65
- with gr.Row():
66
- with gr.Column(scale=8):
67
- question_box = gr.Textbox(placeholder="Ask a finance question...", show_label=False, lines=2)
68
- with gr.Column(scale=1):
69
- submit_btn = gr.Button("Send")
70
-
71
- clear_btn = gr.Button("Clear Chat")
72
-
73
- submit_btn.click(fn=generate_answer, inputs=question_box, outputs=[question_box, chatbot])
74
- clear_btn.click(fn=clear_chat, outputs=chatbot)
75
-
76
- # Launch app
77
- interface.launch()
 
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
9
 
10
+ # Load dataset
11
  dataset = load_dataset("virattt/financial-qa-10K")["train"]
12
 
13
+ # Function to retrieve context
 
 
 
14
def get_context_for_question(question):
    """Return the 10-K context passage whose question exactly matches.

    Matching is case-insensitive and ignores surrounding whitespace.
    The original implementation rescanned (and re-normalized) the whole
    dataset on every call; here the dataset is indexed lazily on first
    call and cached on the function, so subsequent lookups are O(1).

    Parameters:
        question: the user's question text.

    Returns:
        The matching context string, or "No relevant context found."
    """
    index = getattr(get_context_for_question, "_index", None)
    if index is None:
        # Build {normalized question -> context} exactly once.
        # setdefault keeps the FIRST occurrence on duplicate questions,
        # matching the original linear scan's first-match behavior.
        index = {}
        for item in dataset:
            index.setdefault(item["question"].strip().lower(), item["context"])
        get_context_for_question._index = index
    return index.get(question.strip().lower(), "No relevant context found.")
19
 
20
+ # Predict function
21
  def generate_answer(question):
22
  context = get_context_for_question(question)
23
  prompt = f"Q: {question} Context: {context} A:"
 
38
  )
39
 
40
  answer = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
41
+ return answer
42
+
43
# Gradio front-end: a single-turn Q&A form wired to generate_answer.
question_input = gr.Textbox(lines=2, placeholder="Ask a finance question...")

interface = gr.Interface(
    fn=generate_answer,
    inputs=question_input,
    outputs="text",
    title="Finance QA Chatbot",
    description=(
        "Built using a fine-tuned T5 Transformer. "
        "Ask a finance-related question and get an accurate, concise answer."
    ),
)

# Start the web server (blocking call).
interface.launch()