Update app.py
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ import gradio as gr
|
|
8 |
|
9 |
class ONNXInferencePipeline:
|
10 |
def __init__(self, repo_id):
|
|
|
11 |
self.onnx_path = hf_hub_download(repo_id=repo_id, filename="mindmeter.onnx")
|
12 |
self.tokenizer_path = hf_hub_download(repo_id=repo_id, filename="train_bpe_tokenizer.json")
|
13 |
self.config_path = hf_hub_download(repo_id=repo_id, filename="hyperparameters.json")
|
@@ -19,7 +20,7 @@ class ONNXInferencePipeline:
|
|
19 |
|
20 |
# Initialize the tokenizer from file.
|
21 |
self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
|
22 |
-
# Use the maximum sequence length from the
|
23 |
self.max_len = self.config["MAX_LEN"]
|
24 |
|
25 |
# Initialize the ONNX runtime session.
|
@@ -31,57 +32,46 @@ class ONNXInferencePipeline:
|
|
31 |
self.session.set_providers(self.providers)
|
32 |
|
33 |
def preprocess(self, text):
|
|
|
|
|
|
|
34 |
encoding = self.tokenizer.encode(text)
|
35 |
-
# Truncate to self.max_len tokens
|
36 |
ids = encoding.ids[:self.max_len]
|
37 |
-
# Pad with zeros if necessary
|
38 |
padding = [0] * (self.max_len - len(ids))
|
39 |
return np.array(ids + padding, dtype=np.int64).reshape(1, -1)
|
40 |
|
41 |
def predict(self, text):
|
42 |
"""
|
43 |
-
Given an input text string, run inference and return
|
44 |
The model outputs:
|
45 |
0 -> "Not Stressed"
|
46 |
1 -> "Stressed"
|
47 |
-
|
48 |
-
- confidence < 0.40: "Not Stressed" (fallback)
|
49 |
-
- 0.40 ≤ confidence < 0.65: "Low Stress"
|
50 |
-
- 0.65 ≤ confidence < 0.90: "Moderate Stress"
|
51 |
-
- 0.90 ≤ confidence: "High Stress"
|
52 |
"""
|
53 |
input_array = self.preprocess(text)
|
54 |
-
|
55 |
outputs = self.session.run(None, {"input": input_array})
|
56 |
logits = outputs[0]
|
57 |
|
|
|
58 |
exp_logits = np.exp(logits)
|
59 |
probabilities = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
|
60 |
predicted_class = int(np.argmax(probabilities))
|
|
|
|
|
61 |
class_labels = ["Not Stressed", "Stressed"]
|
62 |
predicted_label = class_labels[predicted_class]
|
63 |
-
confidence = float(probabilities[0][predicted_class])
|
64 |
-
|
65 |
-
if predicted_label == "Stressed":
|
66 |
-
# Use the confidence of the "Stressed" class
|
67 |
-
stress_confidence = confidence
|
68 |
-
if stress_confidence < 0.40:
|
69 |
-
stress_level = "Not Stressed" # Fallback (unlikely)
|
70 |
-
elif 0.40 <= stress_confidence < 0.60:
|
71 |
-
stress_level = "Low Stress"
|
72 |
-
elif 0.60 <= stress_confidence < 0.95:
|
73 |
-
stress_level = "Moderate Stress"
|
74 |
-
else: # 0.95 ≤ stress_confidence ≤ 1.00
|
75 |
-
stress_level = "High Stress"
|
76 |
-
else:
|
77 |
-
stress_level = "Not Stressed"
|
78 |
|
79 |
-
|
|
|
80 |
|
81 |
|
82 |
if __name__ == "__main__":
|
|
|
83 |
pipeline = ONNXInferencePipeline(repo_id="iimran/MindMeter")
|
84 |
|
|
|
85 |
text1 = "Yay! what a happy life"
|
86 |
text2 = "I’ve missed another loan payment, and I don’t know how I’m going to catch up. The pressure is unbearable."
|
87 |
text3 = "I am upset about how badly life is trating me these days, its shit."
|
@@ -94,23 +84,25 @@ if __name__ == "__main__":
|
|
94 |
print(f"Prediction for text 2: {result2}")
|
95 |
print(f"Prediction for text 3: {result3}")
|
96 |
|
|
|
97 |
def gradio_predict(text):
|
98 |
result = pipeline.predict(text)
|
99 |
return result["stress_level"]
|
100 |
|
|
|
101 |
iface = gr.Interface(
|
102 |
fn=gradio_predict,
|
103 |
inputs=gr.Textbox(lines=7, placeholder="Enter your text here..."),
|
104 |
outputs="text",
|
105 |
-
title="MindMeter –
|
106 |
description=(
|
107 |
-
"MindMeter is designed to swiftly assess the stress levels expressed in text communications, making it an invaluable tool for local councils, especially when addressing financial hardship cases. By analyzing the tone and wording in emails or chat messages, MindMeter categorizes the expressed sentiment into
|
108 |
" This agent will identify Risk from Text communication - Next Agent will read the financial reports/bank statements/loan statements and cross verify the finacnail hardship."
|
109 |
),
|
110 |
examples=[
|
111 |
"Yay! what a happy life",
|
112 |
"I’ve missed another loan payment, and I don’t know how I’m going to catch up. The pressure is unbearable.",
|
113 |
-
"I am upset about how badly life is trating me these days, its shit
|
114 |
]
|
115 |
)
|
116 |
|
|
|
8 |
|
9 |
class ONNXInferencePipeline:
|
10 |
def __init__(self, repo_id):
|
11 |
+
# Download the model files from Hugging Face Hub.
|
12 |
self.onnx_path = hf_hub_download(repo_id=repo_id, filename="mindmeter.onnx")
|
13 |
self.tokenizer_path = hf_hub_download(repo_id=repo_id, filename="train_bpe_tokenizer.json")
|
14 |
self.config_path = hf_hub_download(repo_id=repo_id, filename="hyperparameters.json")
|
|
|
20 |
|
21 |
# Initialize the tokenizer from file.
|
22 |
self.tokenizer = Tokenizer.from_file(self.tokenizer_path)
|
23 |
+
# Use the maximum sequence length from the configuration.
|
24 |
self.max_len = self.config["MAX_LEN"]
|
25 |
|
26 |
# Initialize the ONNX runtime session.
|
|
|
32 |
self.session.set_providers(self.providers)
|
33 |
|
34 |
def preprocess(self, text):
    """
    Tokenize *text*, clip the token ids to ``self.max_len``, right-pad
    with zeros, and return the result as a (1, max_len) int64 array.
    """
    token_ids = list(self.tokenizer.encode(text).ids[:self.max_len])
    # Right-pad with zeros so every input has the fixed model length.
    token_ids.extend([0] * (self.max_len - len(token_ids)))
    return np.array(token_ids, dtype=np.int64).reshape(1, -1)
|
44 |
|
45 |
def predict(self, text):
    """
    Run ONNX inference on *text* and return the predicted stress label.

    The model outputs two classes:
        0 -> "Not Stressed"
        1 -> "Stressed"

    Returns:
        dict: ``{"stress_level": <label>}`` where ``<label>`` is one of
        the two class names above.
    """
    input_array = self.preprocess(text)
    outputs = self.session.run(None, {"input": input_array})
    logits = outputs[0]

    # Numerically stable softmax: subtract the per-row max before
    # exponentiating so large logits cannot overflow to inf/nan.
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exp_logits = np.exp(shifted)
    probabilities = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)
    predicted_class = int(np.argmax(probabilities))

    # Map the predicted index to a human-readable label.
    class_labels = ["Not Stressed", "Stressed"]
    predicted_label = class_labels[predicted_class]

    # Return only the predicted stress label.
    return {"stress_level": predicted_label}
|
68 |
|
69 |
|
70 |
if __name__ == "__main__":
|
71 |
+
# Initialize the pipeline with the Hugging Face repository ID.
|
72 |
pipeline = ONNXInferencePipeline(repo_id="iimran/MindMeter")
|
73 |
|
74 |
+
# Example input texts for local testing.
|
75 |
text1 = "Yay! what a happy life"
|
76 |
text2 = "I’ve missed another loan payment, and I don’t know how I’m going to catch up. The pressure is unbearable."
|
77 |
text3 = "I am upset about how badly life is trating me these days, its shit."
|
|
|
84 |
print(f"Prediction for text 2: {result2}")
|
85 |
print(f"Prediction for text 3: {result3}")
|
86 |
|
87 |
+
# Define a function for Gradio to use.
|
88 |
def gradio_predict(text):
    """Gradio callback: run the pipeline on *text* and return its stress label."""
    return pipeline.predict(text)["stress_level"]
|
91 |
|
92 |
+
# Create the Gradio interface.
|
93 |
# Create the Gradio interface.  Fixes the misspelled user-facing word
# "finacnail" -> "financial" in the description; example texts are kept
# verbatim (they are deliberately informal user-style inputs).
iface = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Textbox(lines=7, placeholder="Enter your text here..."),
    outputs="text",
    title="MindMeter – Stress Detection Agent",
    description=(
        "MindMeter is designed to swiftly assess the stress levels expressed in text communications, making it an invaluable tool for local councils, especially when addressing financial hardship cases. By analyzing the tone and wording in emails or chat messages, MindMeter categorizes the expressed sentiment into 'Stressed' or 'Not Stressed' outputs. This allows council representatives to quickly identify residents who might be under significant stress due to financial challenges. In turn, councils can prioritize outreach and tailor support services to address urgent concerns effectively, ensuring that vulnerable community members receive the timely assistance they need."
        " This agent will identify Risk from Text communication - Next Agent will read the financial reports/bank statements/loan statements and cross verify the financial hardship."
    ),
    examples=[
        "Yay! what a happy life",
        "I’ve missed another loan payment, and I don’t know how I’m going to catch up. The pressure is unbearable.",
        "I am upset about how badly life is trating me these days, its shit."
    ]
)
|
108 |
|