naveenus commited on
Commit
42f0920
·
verified ·
1 Parent(s): 90db17e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -32
app.py CHANGED
@@ -1,42 +1,40 @@
1
  # app.py
2
- import json
3
- import gradio as gr
4
  from transformers import pipeline
5
 
6
- # 1) Load base labels from JSON
7
- with open("labels.json", "r") as f:
8
- base_labels = json.load(f)
 
 
9
 
10
- # 2) Prepare default textbox value
11
- default_label_str = ", ".join(base_labels)
 
12
 
13
- # 3) Initialize zero-shot classifier
14
- classifier = pipeline(
15
- task="zero-shot-classification",
16
- model="facebook/bart-large-mnli"
17
- )
18
-
19
- # 4) Interface function that merges runtime labels
20
- def tag_question(question: str, labels_str: str):
21
- # Split & clean the user-supplied string
22
- labels = [lbl.strip() for lbl in labels_str.split(",") if lbl.strip()]
23
- # Zero-shot classify
24
- out = classifier(question, candidate_labels=labels)
25
- # Return top-3 labels with scores
26
- return {lbl: round(score,3) for lbl, score in zip(out["labels"], out["scores"])}
27
 
28
- # 5) Build the Gradio UI
29
  iface = gr.Interface(
30
- fn=tag_question,
31
- inputs=[
32
- gr.Textbox(lines=3, label="Question"),
33
- gr.Textbox(lines=2, label="Candidate Labels (comma-separated)",
34
- value=default_label_str)
35
- ],
36
- outputs=gr.Label(num_top_classes=3),
37
- title="Hybrid Zero-Shot Question Tagger",
38
- description="Loaded labels from `labels.json`, editable at runtime."
39
  )
40
 
41
- if __name__ == "__main__":
42
  iface.launch()
 
1
  # app.py
2
+ import json, gradio as gr
 
3
  from transformers import pipeline
4
 
5
+ # 1) Load taxonomies
6
+ with open("coarse_labels.json") as f:
7
+ coarse_labels = json.load(f) # :contentReference[oaicite:4]{index=4}
8
+ with open("fine_labels.json") as f:
9
+ fine_map = json.load(f) # :contentReference[oaicite:5]{index=5}
10
 
11
+ # 2) Init classifier
12
+ classifier = pipeline("zero-shot-classification",
13
+ model="facebook/bart-large-mnli")
14
 
15
+ # 3) Tagging fn
16
+ def hierarchical_tag(question):
17
+ # Stage 1: pick coarse subject
18
+ coarse_out = classifier(question, candidate_labels=coarse_labels)
19
+ chosen = coarse_out["labels"][0]
20
+ # Stage 2: fine-grained tags within that subject
21
+ fine_labels = fine_map.get(chosen, [])
22
+ fine_out = classifier(question, candidate_labels=fine_labels)
23
+ # Return both
24
+ return {
25
+ "Subject": chosen,
26
+ **{lbl: round(score,3)
27
+ for lbl, score in zip(fine_out["labels"], fine_out["scores"])}
28
+ }
29
 
30
+ # 4) Build UI
31
  iface = gr.Interface(
32
+ fn=hierarchical_tag,
33
+ inputs=gr.Textbox(lines=3, label="Enter your question"),
34
+ outputs=gr.JSON(label="Hierarchical Tags"),
35
+ title="Two-Stage Zero-Shot Question Tagger",
36
+ description="Stage 1: classify subject; Stage 2: classify topic within subject."
 
 
 
 
37
  )
38
 
39
+ if __name__=="__main__":
40
  iface.launch()