import json

import numpy as np
from tokenizers import Tokenizer
import onnxruntime as ort
from huggingface_hub import hf_hub_download
import gradio as gr


class ONNXInferencePipeline:
    def __init__(self, repo_id):
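        # Fetch the exported ONNX model, the trained BPE tokenizer, and the
        # hyperparameter file from the Hugging Face Hub.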
        self.onnx_path = hf_hub_download(repo_id=repo_id, filename="mindmeter.onnx")
        self.tokenizer_path = hf_hub_download(repo_id=repo_id, filename="train_bpe_tokenizer.json")
        self.config_path = hf_hub_download(repo_id=repo_id, filename="hyperparameters.json")

        with open(self.config_path, "r") as f:
            self.config = json.load(f)

        self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
        self.max_len = self.config["MAX_LEN"]

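        # Prefer the CUDA execution provider when it is available; otherwise fall back to CPU.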
        self.providers = ["CPUExecutionProvider"]
        if "CUDAExecutionProvider" in ort.get_available_providers():
            self.providers = ["CUDAExecutionProvider"]
        self.session = ort.InferenceSession(self.onnx_path, providers=self.providers)

    def preprocess(self, text):
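        # Tokenize, truncate to MAX_LEN, and right-pad with zeros so the input always has
        # the fixed shape (1, MAX_LEN). This assumes the tokenizer's padding id is 0.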
        encoding = self.tokenizer.encode(text)
        ids = encoding.ids[:self.max_len]
        padding = [0] * (self.max_len - len(ids))
        return np.array(ids + padding, dtype=np.int64).reshape(1, -1)

    def predict(self, text):
        """
        Run inference on an input text string and return only the granular stress level.

        The model outputs:
            0 -> "Not Stressed"
            1 -> "Stressed"

        When the model predicts "Stressed", confidence-based thresholding is applied:
            - confidence < 0.40: "Not Stressed" (fallback)
            - 0.40 ≤ confidence < 0.65: "Low Stress"
            - 0.65 ≤ confidence < 0.90: "Moderate Stress"
            - confidence ≥ 0.90: "High Stress"
        """
        input_array = self.preprocess(text)

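        # The feed name "input" must match the input name chosen when the model was exported
        # to ONNX; if unsure, it can be checked with self.session.get_inputs()[0].name.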
        outputs = self.session.run(None, {"input": input_array})
        logits = outputs[0]

        # Softmax over the class dimension (shifted by the row max for numerical stability).
        exp_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
        probabilities = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
        predicted_class = int(np.argmax(probabilities))
        class_labels = ["Not Stressed", "Stressed"]
        predicted_label = class_labels[predicted_class]
        confidence = float(probabilities[0][predicted_class])

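        # Map a "Stressed" prediction onto a granular level using the thresholds documented above.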
        if predicted_label == "Stressed":
            if confidence < 0.40:
                stress_level = "Not Stressed"
            elif confidence < 0.65:
                stress_level = "Low Stress"
            elif confidence < 0.90:
                stress_level = "Moderate Stress"
            else:
                stress_level = "High Stress"
        else:
            stress_level = "Not Stressed"

        return {"stress_level": stress_level}

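
# Quick smoke test of the pipeline, followed by a small Gradio demo UI, when the
# script is run directly.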
if __name__ == "__main__":
    pipeline = ONNXInferencePipeline(repo_id="iimran/MindMeter")

    text1 = "Yay! what a happy life"
    text2 = "I’ve missed another loan payment, and I don’t know how I’m going to catch up. The pressure is unbearable."
    text3 = "I am upset about how badly life is treating me these days, it's shit."

    result1 = pipeline.predict(text1)
    result2 = pipeline.predict(text2)
    result3 = pipeline.predict(text3)

    print(f"Prediction for text 1: {result1}")
    print(f"Prediction for text 2: {result2}")
    print(f"Prediction for text 3: {result3}")

    def gradio_predict(text):
        result = pipeline.predict(text)
        return result["stress_level"]

    iface = gr.Interface(
        fn=gradio_predict,
        inputs=gr.Textbox(lines=7, placeholder="Enter your text here..."),
        outputs="text",
        title="MindMeter – Granular Stress Level Detection Agent",
        description=(
            "MindMeter is designed to swiftly assess the stress levels expressed in text communications, "
            "making it an invaluable tool for local councils, especially when addressing financial hardship cases. "
            "By analyzing the tone and wording in emails or chat messages, MindMeter categorizes the expressed "
            "sentiment into one of four levels: Not Stressed, Low Stress, Moderate Stress, or High Stress. "
            "This allows council representatives to quickly identify residents who might be under significant "
            "stress due to financial challenges. In turn, councils can prioritize outreach and tailor support "
            "services to address urgent concerns effectively, ensuring that vulnerable community members receive "
            "the timely assistance they need. "
            "This agent identifies risk from text communication; a follow-up agent will read financial reports, "
            "bank statements, and loan statements to cross-verify the financial hardship."
        ),
        examples=[
            "Yay! what a happy life",
            "I’ve missed another loan payment, and I don’t know how I’m going to catch up. The pressure is unbearable.",
            "I am upset about how badly life is treating me these days, it's shit and I wanna end it"
        ]
    )
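
    # launch() serves the app locally (http://127.0.0.1:7860 by default); pass share=True
    # to launch() if a temporary public link is needed.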
    iface.launch()