naveenus commited on
Commit
e5b7bb1
·
verified ·
1 Parent(s): 1ffab5b

Added an option to override default tags

Browse files
Files changed (1) hide show
  1. app.py +26 -12
app.py CHANGED
@@ -1,27 +1,41 @@
 
 
1
  import gradio as gr
2
  from transformers import pipeline
3
 
4
- # 1) Initialize zero-shot classifier
 
 
 
 
 
 
 
5
  classifier = pipeline(
6
  task="zero-shot-classification",
7
  model="facebook/bart-large-mnli"
8
  )
9
 
10
- # 2) Define candidate labels
11
- LABELS = ["linear algebra", "calculus", "probability", "geometry"]
12
-
13
- # 3) Gradio interface function
14
- def tag_question(question):
15
- result = classifier(question, candidate_labels=LABELS)
16
- return {lbl: round(score, 3) for lbl, score in zip(result["labels"], result["scores"])}
 
17
 
18
- # 4) Build UI
19
  iface = gr.Interface(
20
  fn=tag_question,
21
- inputs=gr.Textbox(lines=3, placeholder="Enter your MCQ here..."),
 
 
 
 
22
  outputs=gr.Label(num_top_classes=3),
23
- title="Zero-Shot Question Tagger",
24
- description="Classify questions into math topics without any training data."
25
  )
26
 
27
  if __name__ == "__main__":
 
1
+ # app.py
2
+ import json
3
  import gradio as gr
4
  from transformers import pipeline
5
 
6
+ # 1) Load base labels from JSON
7
+ with open("labels.json", "r") as f:
8
+ base_labels = json.load(f)
9
+
10
+ # 2) Prepare default textbox value
11
+ default_label_str = ", ".join(base_labels)
12
+
13
+ # 3) Initialize zero-shot classifier
14
  classifier = pipeline(
15
  task="zero-shot-classification",
16
  model="facebook/bart-large-mnli"
17
  )
18
 
19
+ # 4) Interface function that merges runtime labels
20
+ def tag_question(question: str, labels_str: str):
21
+ # Split & clean the user-supplied string
22
+ labels = [lbl.strip() for lbl in labels_str.split(",") if lbl.strip()]
23
+ # Zero-shot classify
24
+ out = classifier(question, candidate_labels=labels)
25
+ # Return top-3 labels with scores
26
+ return {lbl: round(score,3) for lbl, score in zip(out["labels"], out["scores"])}
27
 
28
+ # 5) Build the Gradio UI
29
  iface = gr.Interface(
30
  fn=tag_question,
31
+ inputs=[
32
+ gr.Textbox(lines=3, label="Question"),
33
+ gr.Textbox(lines=2, label="Candidate Labels (comma-separated)",
34
+ value=default_label_str)
35
+ ],
36
  outputs=gr.Label(num_top_classes=3),
37
+ title="Hybrid Zero-Shot Question Tagger",
38
+ description="Loaded labels from `labels.json`, editable at runtime."
39
  )
40
 
41
  if __name__ == "__main__":