# Spaces: Running
import requests | |
import time | |
import gradio as gr | |
import os | |
########################
## Loading the model: ##
########################
# Token for the Hugging Face Inference API, read from the environment.
# os.environ.get returns None when HF_API_KEY_INFERENCE is not set, in which
# case the header below will literally read "Bearer None".
api_key = os.environ.get("HF_API_KEY_INFERENCE")

# 7 dimensions:
API_URL1 = "https://api-inference.huggingface.co/models/chernandezc/EEMM_7_categories_WB"  # API endpoint.

# Authorization header built from the API key; equal for both endpoints.
headers = {"Authorization": f"Bearer {api_key}"}
def query1(payload, timeout=60):
    """Post *payload* to the 7-dimension model endpoint and return its JSON.

    Args:
        payload: JSON-serialisable request body, sent unchanged to API_URL1.
        timeout: Seconds requests waits for the server before raising a
            timeout error. Without it a stalled connection would block the
            caller indefinitely.

    Returns:
        The decoded JSON body of the response, whatever its shape.

    Raises:
        requests.exceptions.RequestException: On connection problems or when
            the timeout expires.
        ValueError: If the response body is not valid JSON.
    """
    response = requests.post(
        API_URL1, headers=headers, json=payload, timeout=timeout
    )
    return response.json()
# 3 levels: endpoint of the second model, the one posted to by query2.
API_URL2 = "https://api-inference.huggingface.co/models/chernandezc/EEMM_3_dimensions_1201"  # API endpoint.
def query2(payload, timeout=60):
    """Post *payload* to the 3-level model endpoint and return its JSON.

    Args:
        payload: JSON-serialisable request body, sent unchanged to API_URL2.
        timeout: Seconds requests waits for the server before raising a
            timeout error. Without it a stalled connection would block the
            caller indefinitely.

    Returns:
        The decoded JSON body of the response, whatever its shape.

    Raises:
        requests.exceptions.RequestException: On connection problems or when
            the timeout expires.
        ValueError: If the response body is not valid JSON.
    """
    response = requests.post(
        API_URL2, headers=headers, json=payload, timeout=timeout
    )
    return response.json()
##################################################################
# Function to query both models and format the classifications   #
##################################################################
# Readable names for the raw labels returned by the 7-dimension model.
_DIMENSION_LABELS = {
    'LABEL_0': 'Cognition',
    'LABEL_1': 'Affect',
    'LABEL_2': 'Self',
    'LABEL_3': 'Motivation',
    'LABEL_4': 'Attention',
    'LABEL_5': 'Overt_Behavior',
    'LABEL_6': 'Context',
}

# Readable names for the raw labels returned by the 3-level model.
_LEVEL_LABELS = {
    'LABEL_0': 'Social',
    'LABEL_1': 'Psychological',
    'LABEL_2': 'Physical',
}

# Scores at or above this value are reported as classifications. The
# user-facing messages below spell the same number out as "0.5".
_SCORE_THRESHOLD = 0.5

# Longest pause, in seconds, between two attempts while an endpoint reports
# an error.
_POLL_CAP = 1


def _has_error(response):
    """Return True when *response* is an error payload instead of scores."""
    return isinstance(response, dict) and 'error' in response


def _retry_delay(response):
    """Return the pause before retrying: the reported estimate, capped."""
    return min(response.get("estimated_time", _POLL_CAP), _POLL_CAP)


def _split_scores(scores, label_names):
    """Return (confident (name, score) pairs, top name, top rounded score)."""
    confident = [
        (label_names[entry['label']], entry['score'])
        for entry in scores
        if entry['score'] >= _SCORE_THRESHOLD
    ]
    top = max(scores, key=lambda entry: entry['score'])
    return confident, label_names[top['label']], round(top['score'], 2)


def _join_pairs(pairs):
    """Render (name, score) pairs as 'Name (0.91), Other (0.55)'."""
    return ', '.join(f"{name} ({score:.2f})" for name, score in pairs)


def classify_output(item, query_dimensions=None, query_levels=None,
                    max_wait=120):
    """Classify *item* with both models and describe the result as text.

    Both endpoints are queried once up front. While either one answers with
    an error payload (a dict holding an 'error' key) it is asked again after
    a short pause, until it succeeds or *max_wait* seconds have passed.

    Args:
        item: The text to classify.
        query_dimensions: Callable taking a ``{"inputs": text}`` payload and
            returning the 7-dimension model's response. Defaults to query1.
        query_levels: Same, for the 3-level model. Defaults to query2.
        max_wait: Upper bound, in seconds, on the time spent retrying.
            Without a bound a permanent error would keep the loop running
            forever.

    Returns:
        A human-readable string. It lists every label whose score is at
        least 0.5 for each model; where a model has none, it names that
        model's highest-scoring label instead. If an endpoint is still
        failing after *max_wait*, the string reports that endpoint's error.

    Raises:
        KeyError, IndexError, TypeError, ValueError: If a successful
            response is not a list whose first element is a non-empty list
            of ``{'label': ..., 'score': ...}`` dicts with known labels.
    """
    if query_dimensions is None:
        query_dimensions = query1
    if query_levels is None:
        query_levels = query2

    deadline = time.monotonic() + max_wait
    # Ask both endpoints before waiting on either one, so neither request is
    # delayed by the other's retries.
    dim_response = query_dimensions({"inputs": item})
    lvl_response = query_levels({"inputs": item})

    while ((_has_error(dim_response) or _has_error(lvl_response))
           and time.monotonic() < deadline):
        if _has_error(dim_response):
            time.sleep(_retry_delay(dim_response))
            dim_response = query_dimensions({"inputs": item})
        if _has_error(lvl_response):
            time.sleep(_retry_delay(lvl_response))
            lvl_response = query_levels({"inputs": item})

    # Still failing after the deadline: report it instead of looping on.
    for model_name, response in (("dimensions", dim_response),
                                 ("levels", lvl_response)):
        if _has_error(response):
            return (
                f"The classification service returned an error for "
                f"{model_name}: {response['error']}"
            )

    dim_hits, dim_top_name, dim_top_score = _split_scores(
        dim_response[0], _DIMENSION_LABELS
    )
    lvl_hits, lvl_top_name, lvl_top_score = _split_scores(
        lvl_response[0], _LEVEL_LABELS
    )

    if dim_hits and lvl_hits:
        return (
            f"For dimensions: {_join_pairs(dim_hits)}\n"
            f"For levels: {_join_pairs(lvl_hits)}"
        )

    if dim_hits:
        return (
            f"For dimensions: {_join_pairs(dim_hits)}\n"
            "For levels: No classifications with a score of 0.5 or higher were found.\n"
            f"However, the highest probability was for: '{lvl_top_name}' ({lvl_top_score})\n"
            " Use this classification with caution due to uncertainty"
        )

    if lvl_hits:
        return (
            f"For levels: {_join_pairs(lvl_hits)}\n"
            "For dimensions: No classifications with a score of 0.5 or higher were found.\n"
            f"However, the highest probability was for: '{dim_top_name}' ({dim_top_score}) \n"
            " Use this classification with caution due to uncertainty"
        )

    # Neither model was confident. Each score's "(" is now closed, which the
    # previous version of this message failed to do.
    return (
        "No classification with a score of 0.5 or higher were found for both levels and dimensions\n"
        f"The highest probability for dimensions was: '{dim_top_name}' ({dim_top_score})\n"
        f"The highest probability for level was: '{lvl_top_name}' ({lvl_top_score}) \n"
        " Use this classification with caution due to uncertainty"
    )
#########################################
######## RUN GRADIO APP #################
#########################################
# Input box, pre-filled with an example item.
txtbx = gr.Textbox(
    value='I would like to feel better',
    label='Please enter your item:',
    container=True,  # boolean flag; previously passed as the string 'True'
)

# Output box that shows the text returned by classify_output.
txtbxopt = gr.Textbox(
    label='The item you provided was classified as:',
    container=True,
)

# Callback used for manual flagging; it is given the API key and the name
# 'flagging_EEMM_V05'.
hf_writer = gr.HuggingFaceDatasetSaver(api_key, 'flagging_EEMM_V05')

# NOTE(review): `article` is still placeholder text and `description` still
# contains a "(ref)" placeholder - replace both before publishing.
demo = gr.Interface(
    fn=classify_output,
    inputs=txtbx,
    outputs=txtbxopt,
    theme=gr.themes.Soft(primary_hue='orange'),
    title='EEMM Item Classification Machine V 0.5',
    description='This machine is a fine tuned version of DistillBERT. It classifies items in 7 EEMM dimensions following (ref). **Please note that the machine goes idle after a period of inactivity. If this occurs, waking it up may take around 20 seconds. Be patient ;)**',
    article='Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum',
    allow_flagging='manual',
    flagging_options=['Wrong category', 'Lacks a category'],
    flagging_callback=hf_writer,
)

demo.launch()