iimran commited on
Commit
b02596a
·
verified ·
1 Parent(s): 80613bb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import onnxruntime as ort
3
+ import numpy as np
4
+ from tokenizers import Tokenizer
5
+ from huggingface_hub import hf_hub_download
6
+
7
+ # Configuration parameters
8
+ MAX_LEN = 256
9
+
10
+ # Hardcoded description value
11
+ DESCRIPTION_TEXT = (
12
+ "I am raising this case to report a severe and ongoing issue of vermin infestation, "
13
+ "specifically rats and mice, in my residential area. The problem appears to be directly linked "
14
+ )
15
+
16
+ # Define possible choices for each field
17
+ status_choices = ["Assess & Assign", "Generate Letter", "Site Inspection"]
18
+ category_choices = ["Litter and Nuisance"]
19
+ request_reason_choices = ["Nuisance"]
20
+ request_sub_reason_choices = ["Animals"]
21
+ additional_reason_choices = ["Vermin, Rats and Mice", "Dog", "Cat", "Horse"]
22
+ notification_method_choices = ["No Notification", "Email", "Phone"]
23
+ inspection_performed_choices = ["Yes", "No"]
24
+ letter_sent_choices = ["Yes", "No"]
25
+
26
+ # Download the ONNX model and tokenizer from Hugging Face Hub
27
+ onnx_model_path = hf_hub_download(
28
+ repo_id="iimran/Case-Next-Best-Action-Classifier", filename="moodlens.onnx"
29
+ )
30
+ tokenizer_path = hf_hub_download(
31
+ repo_id="iimran/Case-Next-Best-Action-Classifier", filename="train_bpe_tokenizer.json"
32
+ )
33
+
34
+ # Load the tokenizer and ONNX model once outside the function for efficiency.
35
+ tokenizer = Tokenizer.from_file(tokenizer_path)
36
+ session = ort.InferenceSession(onnx_model_path)
37
+ input_name = session.get_inputs()[0].name
38
+ output_name = session.get_outputs()[0].name
39
+
40
+ def predict_action(status, category, request_reason, request_sub_reason,
41
+ additional_reason, notification_method, inspection_performed, letter_sent):
42
+ # Combine fields into one input string.
43
+ fields = [
44
+ status,
45
+ category,
46
+ request_reason,
47
+ request_sub_reason,
48
+ additional_reason,
49
+ notification_method,
50
+ DESCRIPTION_TEXT,
51
+ inspection_performed,
52
+ letter_sent
53
+ ]
54
+ sample_text = " ".join(fields)
55
+
56
+ # Tokenize and pad the input
57
+ encoding = tokenizer.encode(sample_text)
58
+ ids = encoding.ids[:MAX_LEN]
59
+ padding = [0] * (MAX_LEN - len(ids))
60
+ input_ids = np.array([ids + padding], dtype=np.int64)
61
+
62
+ # Run inference
63
+ outputs = session.run([output_name], {input_name: input_ids})
64
+ predicted_class = np.argmax(outputs[0], axis=1)[0]
65
+
66
+ # Map predicted index to the actual action labels
67
+ label_names = [
68
+ "Assign Case Officer",
69
+ "Generate Letter and Send By Post",
70
+ "Generate Letter and Send Email",
71
+ "Generate Letter and Send SMS",
72
+ "Schedule Inspection",
73
+ "Send Feedback Survey"
74
+ ]
75
+ predicted_label = label_names[predicted_class]
76
+ return predicted_label
77
+
78
+ # Create the Gradio Interface using the updated API.
79
+ demo = gr.Interface(
80
+ fn=predict_action,
81
+ inputs=[
82
+ gr.Dropdown(choices=status_choices, label="Status"),
83
+ gr.Dropdown(choices=category_choices, label="Category"),
84
+ gr.Dropdown(choices=request_reason_choices, label="Request Reason"),
85
+ gr.Dropdown(choices=request_sub_reason_choices, label="Request Sub Reason"),
86
+ gr.Dropdown(choices=additional_reason_choices, label="Additional Reason"),
87
+ gr.Dropdown(choices=notification_method_choices, label="Notification Method"),
88
+ gr.Dropdown(choices=inspection_performed_choices, label="Inspection Performed"),
89
+ gr.Dropdown(choices=letter_sent_choices, label="Letter Sent")
90
+ ],
91
+ outputs=gr.Textbox(label="Predicted Action"),
92
+ title="MoodleLens Action Predictor",
93
+ description="Select values from the dropdowns. The description field is fixed."
94
+ )
95
+
96
+ demo.launch()