|
import gradio as gr |
|
import onnxruntime as ort |
|
import numpy as np |
|
from tokenizers import Tokenizer |
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# Fixed sequence length the ONNX model expects: token ids are truncated
# and zero-padded to exactly this many positions before inference.
MAX_LEN = 256


# Hard-coded case description fed into the model alongside the dropdown
# selections. The UI deliberately does not expose this field for editing
# (see the Interface description below).
DESCRIPTION_TEXT = (

    "I am raising this case to report a severe and ongoing issue of vermin infestation, "

    "specifically rats and mice, in my residential area. The problem appears to be directly linked "

)


# Option lists for the Gradio dropdowns. The selected strings are joined
# verbatim into the model input text, so the exact wording matters and
# must match what the classifier was trained on.
status_choices = ["Assess & Assign", "Generate Letter", "Site Inspection"]

category_choices = ["Litter and Nuisance"]

request_reason_choices = ["Nuisance"]

request_sub_reason_choices = ["Animals"]

additional_reason_choices = ["Vermin, Rats and Mice", "Dog", "Cat", "Horse"]

notification_method_choices = ["No Notification", "Email", "Phone"]

inspection_performed_choices = ["Yes", "No"]

letter_sent_choices = ["Yes", "No"]
|
|
|
|
|
# Hugging Face Hub repository that hosts both the ONNX model and the
# BPE tokenizer used at training time.
_REPO_ID = "iimran/Case-Next-Best-Action-Classifier"

# Download (or reuse the local cache of) the model artifacts.
onnx_model_path = hf_hub_download(repo_id=_REPO_ID, filename="nba.onnx")
tokenizer_path = hf_hub_download(repo_id=_REPO_ID, filename="train_bpe_tokenizer.json")

# Load the tokenizer and create the inference session once at startup;
# both are reused for every prediction request.
tokenizer = Tokenizer.from_file(tokenizer_path)
session = ort.InferenceSession(onnx_model_path)

# Cache the graph's input/output tensor names for session.run calls.
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
|
|
|
def predict_action(status, category, request_reason, request_sub_reason,
                   additional_reason, notification_method, inspection_performed, letter_sent):
    """Predict the next best action for a council case.

    Concatenates the dropdown selections with the fixed DESCRIPTION_TEXT,
    tokenizes the result, and runs the ONNX classifier.

    Parameters are the raw values of the eight Gradio dropdowns (strings,
    or None for dropdowns the user has not touched and that have no
    default value).

    Returns:
        str: human-readable label of the predicted next best action.
    """
    fields = [
        status,
        category,
        request_reason,
        request_sub_reason,
        additional_reason,
        notification_method,
        DESCRIPTION_TEXT,
        inspection_performed,
        letter_sent,
    ]
    # BUGFIX: six of the dropdowns have no `value=` default, so Gradio
    # passes None until the user picks something; joining None would raise
    # TypeError. Skip non-string fields instead of crashing.
    sample_text = " ".join(f for f in fields if isinstance(f, str))

    # Tokenize, then truncate and zero-pad to the fixed length the model
    # expects (shape: [1, MAX_LEN], int64).
    encoding = tokenizer.encode(sample_text)
    ids = encoding.ids[:MAX_LEN]
    ids = ids + [0] * (MAX_LEN - len(ids))
    input_ids = np.array([ids], dtype=np.int64)

    outputs = session.run([output_name], {input_name: input_ids})
    # argmax over the class axis; cast the NumPy scalar to a plain int
    # before using it as a list index.
    predicted_class = int(np.argmax(outputs[0], axis=1)[0])

    # Index order must match the label mapping used when the model was
    # trained — do not reorder.
    label_names = [
        "Assign Case Officer",
        "Generate Letter and Send By Post",
        "Generate Letter and Send Email",
        "Generate Letter and Send SMS",
        "Schedule Inspection",
        "Send Feedback Survey",
    ]
    return label_names[predicted_class]
|
|
|
|
|
# Build the input widgets first so the Interface call stays readable.
# Order must match predict_action's parameter order.
_input_components = [
    gr.Dropdown(choices=status_choices, label="Status"),
    gr.Dropdown(choices=category_choices, label="Category"),
    gr.Dropdown(choices=request_reason_choices, label="Request Reason"),
    gr.Dropdown(choices=request_sub_reason_choices, label="Request Sub Reason"),
    gr.Dropdown(choices=additional_reason_choices, label="Additional Reason"),
    gr.Dropdown(choices=notification_method_choices, label="Notification Method"),
    gr.Dropdown(choices=inspection_performed_choices, label="Inspection Performed", value="No"),
    gr.Dropdown(choices=letter_sent_choices, label="Letter Sent", value="No"),
]

demo = gr.Interface(
    fn=predict_action,
    inputs=_input_components,
    outputs=gr.Textbox(label="Predicted Action"),
    title="Council - Case Next Best Action Predictor",
    description="Select values from the dropdowns. The description field is fixed.",
)

demo.launch()
|
|