naveenus commited on
Commit
cd09eef
Β·
verified Β·
1 Parent(s): 8f108e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -15
app.py CHANGED
@@ -2,6 +2,11 @@ import json, time, csv, os
2
  import gradio as gr
3
  from transformers import pipeline
4
 
 
 
 
 
 
5
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
6
  # Load taxonomies
7
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
@@ -38,28 +43,38 @@ for fn, hdr in [
38
  # Inference functions
39
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
40
  def run_stage1(question, model_name):
 
 
41
  start = time.time()
42
- clf = pipeline("zero-shot-classification", model=model_name)
43
  out = clf(question, candidate_labels=coarse_labels)
44
  labels, scores = out["labels"][:3], out["scores"][:3]
45
  duration = round(time.time() - start, 3)
46
 
47
- # Build the label mapping
48
- label_dict = {lbl: round(score,3) for lbl,score in zip(labels, scores)}
49
- # Prepare Radio update
50
  radio_update = gr.update(choices=labels, value=labels[0])
 
 
 
51
 
52
- return label_dict, radio_update, f"⏱ {duration}s"
53
 
54
  def run_stage2(question, model_name, subject):
55
- """Return top3 fine topics + duration."""
56
- start = time.time()
57
- clf = pipeline("zero-shot-classification", model=model_name)
58
  fine_labels = fine_map.get(subject, [])
 
 
 
 
 
 
59
  out = clf(question, candidate_labels=fine_labels)
60
  labels, scores = out["labels"][:3], out["scores"][:3]
61
- duration = round(time.time()-start,3)
62
- # Log combined run
 
63
  with open(LOG_FILE, "a", newline="") as f:
64
  csv.writer(f).writerow([
65
  time.strftime("%Y-%m-%d %H:%M:%S"),
@@ -69,7 +84,11 @@ def run_stage2(question, model_name, subject):
69
  ";".join(labels),
70
  duration
71
  ])
72
- return {lbl: round(score,3) for lbl,score in zip(labels, scores)}, f"⏱ {duration}s"
 
 
 
 
73
 
74
  def submit_feedback(question, subject_fb, topic_fb):
75
  with open(FEEDBACK_FILE, "a", newline="") as f:
@@ -92,10 +111,10 @@ with gr.Blocks() as demo:
92
  model_input = gr.Dropdown(choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Choose model")
93
  go_button = gr.Button("Run Stage 1")
94
 
95
- subject_out = gr.Label(num_top_classes=3, label="Top-3 Subjects")
96
- subj_radio = gr.Radio(choices=[], label="Select Subject for Stage 2")
97
- stage1_time = gr.Textbox(label="Stage 1 Time")
98
-
99
  go_button.click(
100
  fn=run_stage1,
101
  inputs=[question_input, model_input],
 
2
  import gradio as gr
3
  from transformers import pipeline
4
 
5
+ PIPELINES = {
6
+ name: pipeline("zero-shot-classification", model=name)
7
+ for name in MODEL_CHOICES
8
+ }
9
+
10
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
11
  # Load taxonomies
12
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
 
43
  # Inference functions
44
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
45
  def run_stage1(question, model_name):
46
+ if not question or not question.strip():
47
+ return {}, gr.update(choices=[]), ""
48
  start = time.time()
49
+ clf = PIPELINES[model_name]
50
  out = clf(question, candidate_labels=coarse_labels)
51
  labels, scores = out["labels"][:3], out["scores"][:3]
52
  duration = round(time.time() - start, 3)
53
 
54
+ # Prepare outputs
55
+ subject_dict = {lbl: round(score,3) for lbl,score in zip(labels, scores)}
 
56
  radio_update = gr.update(choices=labels, value=labels[0])
57
+ time_str = f"⏱ {duration}s"
58
+
59
+ return subject_dict, radio_update, time_str
60
 
 
61
 
62
  def run_stage2(question, model_name, subject):
63
+ # 1) Validate inputs
64
+ if not question or not question.strip():
65
+ return {}, "No question provided", ""
66
  fine_labels = fine_map.get(subject, [])
67
+ if not fine_labels:
68
+ return {}, f"No topics found for '{subject}'", ""
69
+
70
+ # 2) Inference (fast, using preloaded pipeline)
71
+ start = time.time()
72
+ clf = PIPELINES[model_name]
73
  out = clf(question, candidate_labels=fine_labels)
74
  labels, scores = out["labels"][:3], out["scores"][:3]
75
+ duration = round(time.time() - start, 3)
76
+
77
+ # 3) Logging
78
  with open(LOG_FILE, "a", newline="") as f:
79
  csv.writer(f).writerow([
80
  time.strftime("%Y-%m-%d %H:%M:%S"),
 
84
  ";".join(labels),
85
  duration
86
  ])
87
+
88
+ # 4) Return topics + time
89
+ topic_dict = {lbl: round(score,3) for lbl,score in zip(labels, scores)}
90
+ return topic_dict, f"⏱ {duration}s"
91
+
92
 
93
  def submit_feedback(question, subject_fb, topic_fb):
94
  with open(FEEDBACK_FILE, "a", newline="") as f:
 
111
  model_input = gr.Dropdown(choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="Choose model")
112
  go_button = gr.Button("Run Stage 1")
113
 
114
+ subject_out = gr.Label(num_top_classes=3, label="Top-3 Subjects")
115
+ subj_radio = gr.Radio(choices=[], label="Select Subject for Stage 2")
116
+ stage1_time = gr.Textbox(label="Stage 1 Time")
117
+
118
  go_button.click(
119
  fn=run_stage1,
120
  inputs=[question_input, model_input],