NinaMwangi commited on
Commit
e89aaa7
·
verified ·
1 Parent(s): e39318b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -40
app.py CHANGED
@@ -2,50 +2,49 @@ import gradio as gr
2
  from datasets import load_dataset
3
  from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
4
 
5
-
6
- # Load model and tokenizer from your Hugging Face model repo
7
  model_name = "NinaMwangi/T5_finbot"
8
  tokenizer = AutoTokenizer.from_pretrained(model_name)
9
  model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
10
 
11
- # Load dataset ===
12
  dataset = load_dataset("virattt/financial-qa-10K")["train"]
13
 
14
- # Function to retrieve matching context
15
def get_context_for_question(question):
    """Return the 10-K context passage whose stored question matches *question*.

    Matching is exact after trimming whitespace and lower-casing both sides.
    Returns a fallback sentinel string when no dataset row matches.

    Parameters
    ----------
    question : str
        The user-supplied question to look up.
    """
    # Hoist the normalization out of the loop: the dataset scan is O(n)
    # and the query string never changes between iterations.
    target = question.strip().lower()
    for item in dataset:
        if item["question"].strip().lower() == target:
            return item["context"]
    return "No relevant context found."
20
 
21
- # Define the prediction function (inference)
22
def generate_answer(question, chat_history):
    """Answer *question* with the T5 model and append the turn to the chat.

    Retrieves the matching 10-K context, builds a "Q: ... Context: ... A:"
    prompt, and decodes a beam-search answer. Returns ``("", chat_history)``
    so the Gradio binding clears the textbox and refreshes the chatbot in
    a single call.

    Parameters
    ----------
    question : str
        The user's question from the textbox.
    chat_history : list[list[str]] | None
        Accumulated [question, answer] pairs held in gr.State; may be None
        before the first interaction.
    """
    # gr.State can hand us None on the very first call — normalize once here.
    if chat_history is None:
        chat_history = []
    try:
        context = get_context_for_question(question)
        prompt = f"Q: {question} Context: {context} A:"

        inputs = tokenizer(
            prompt,
            return_tensors="tf",
            padding="max_length",
            truncation=True,
            max_length=256,
        )
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            num_beams=4,
            early_stopping=True,
        )
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        chat_history.append([question, answer])
    except Exception as e:
        # UI boundary: show inference failures in the chat instead of
        # letting the Gradio callback crash silently.
        chat_history.append([question, f"Error: {e}"])
    return "", chat_history
 
 
 
 
 
47
 
48
- # === Gradio UI ===
49
  with gr.Blocks(theme=gr.themes.Base()) as interface:
50
  gr.Markdown(
51
  """
@@ -65,25 +64,27 @@ with gr.Blocks(theme=gr.themes.Base()) as interface:
65
  submit_btn = gr.Button("Send")
66
 
67
  clear_btn = gr.Button("Clear Chat")
68
-
69
- # Chat state
70
  state = gr.State([])
71
 
72
- # Bind functionality
73
  submit_btn.click(
74
- generate_answer,
75
- inputs=[question_box, state],
76
- outputs=[question_box, chatbot],
77
- )
 
 
 
 
 
 
78
 
79
  clear_btn.click(
80
  lambda: ("", [], []),
81
  inputs=[],
82
  outputs=[question_box, chatbot, state],
83
- )
84
-
85
 
86
- # === Launch the app ===
87
  interface.launch(share=True)
88
 
89
 
 
 
2
  from datasets import load_dataset
3
  from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
4
 
 
 
5
  model_name = "NinaMwangi/T5_finbot"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = TFAutoModelForSeq2SeqLM.from_pretrained(model_name)
8
 
 
9
  dataset = load_dataset("virattt/financial-qa-10K")["train"]
10
 
 
11
def get_context_for_question(question):
    """Find the context passage paired with an exactly-matching question.

    Comparison trims surrounding whitespace and is case-insensitive.
    Yields a fallback message when the dataset holds no matching row.
    """
    wanted = question.strip().lower()
    found = next(
        (
            row["context"]
            for row in dataset
            if row["question"].strip().lower() == wanted
        ),
        None,
    )
    return found if found is not None else "No relevant context found."
16
 
 
17
def generate_answer(question, chat_history):
    """Answer *question* with the T5 model and append the turn to the chat.

    Retrieves the matching 10-K context, builds a "Q: ... Context: ... A:"
    prompt, and decodes a beam-search answer. Returns ``("", chat_history)``
    so the Gradio binding clears the textbox and refreshes the chatbot in
    a single call.

    Parameters
    ----------
    question : str
        The user's question from the textbox.
    chat_history : list[list[str]] | None
        Accumulated [question, answer] pairs held in gr.State; may be None
        before the first interaction.
    """
    # Normalize the None case exactly once, before the try — the original
    # repeated this guard in both the try body and the except handler.
    if chat_history is None:
        chat_history = []
    try:
        context = get_context_for_question(question)
        prompt = f"Q: {question} Context: {context} A:"

        inputs = tokenizer(
            prompt,
            return_tensors="tf",
            padding="max_length",
            truncation=True,
            max_length=256,
        )
        outputs = model.generate(
            **inputs,
            max_new_tokens=64,
            num_beams=4,
            early_stopping=True,
        )
        answer = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        chat_history.append([question, answer])
    except Exception as e:
        # UI boundary: surface inference failures in the chat rather than
        # crashing the Gradio callback.
        chat_history.append([question, f"Error: {str(e)}"])
    return "", chat_history
47
 
 
48
  with gr.Blocks(theme=gr.themes.Base()) as interface:
49
  gr.Markdown(
50
  """
 
64
  submit_btn = gr.Button("Send")
65
 
66
  clear_btn = gr.Button("Clear Chat")
 
 
67
  state = gr.State([])
68
 
 
69
  submit_btn.click(
70
+ generate_answer,
71
+ inputs=[question_box, state],
72
+ outputs=[question_box, chatbot],
73
+ )
74
+
75
+ question_box.submit(
76
+ generate_answer,
77
+ inputs=[question_box, state],
78
+ outputs=[question_box, chatbot]
79
+ )
80
 
81
  clear_btn.click(
82
  lambda: ("", [], []),
83
  inputs=[],
84
  outputs=[question_box, chatbot, state],
85
+ )
 
86
 
 
87
  interface.launch(share=True)
88
 
89
 
90
+