Spaces:
Running
Running
File size: 3,863 Bytes
ec2757f 5d967c4 ec2757f 5d967c4 ec2757f 5d967c4 ec2757f 5d967c4 ec2757f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 |
import requests
import time
import gradio as gr
import os
##########################################
## Inference API configuration:         ##
##########################################
# The token is read from the environment (e.g. a Space secret) rather than
# being hard-coded in the source.
api_key = os.environ.get("HF_API_KEY_INFERENCE")
API_URL = "https://api-inference.huggingface.co/models/chernandezc/distilbert-base-uncased-finetuned-items-multi-label-21122023-AUGMENTED"  # API endpoint.
# Only send an Authorization header when a token is configured; otherwise the
# request would carry the literal, invalid token "Bearer None".
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
def query(payload, timeout=30):
    """POST *payload* as JSON to the Inference API endpoint and decode the reply.

    Args:
        payload: JSON-serializable request body, e.g. ``{"inputs": text}``.
        timeout: Seconds to wait for the server before giving up. Without an
            explicit timeout ``requests`` can block indefinitely on a stalled
            connection.

    Returns:
        The decoded JSON response. Error payloads (such as the dict with an
        ``error`` key returned while the model is loading) are returned as-is
        rather than raised, because ``classify_output`` inspects them to decide
        whether to wait and retry.

    Raises:
        requests.Timeout: if the server does not respond within *timeout*.
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()
#############################################################
# Function to process the output and return classifications #
#############################################################
# Maps the model's raw output labels to EEMM dimension names.
LABEL_NAMES = {
    'LABEL_0': 'Cognition',
    'LABEL_1': 'Affect',
    'LABEL_2': 'Self',
    'LABEL_3': 'Motivation',
    'LABEL_4': 'Attention',
    'LABEL_5': 'Overt_Behavior',
    'LABEL_6': 'Context',
}
# Labels scoring at or above this value are reported as classifications.
SCORE_THRESHOLD = 0.5
# How many times to wait for a loading model before giving up.
MAX_LOAD_RETRIES = 10


def _query_when_ready(item):
    """Query the endpoint for *item*, sleeping and retrying while the model loads.

    While the model is loading, the API answers with a dict holding ``error``
    and ``estimated_time`` (slept as seconds) instead of scores.

    Raises:
        RuntimeError: if the API reports an error without ``estimated_time``
            (retrying would not help), or the model is still loading after
            MAX_LOAD_RETRIES waits.
    """
    payload = {"inputs": item}
    output = query(payload)
    attempts = 0
    while isinstance(output, dict) and 'error' in output:
        if 'estimated_time' not in output:
            # Not a "model is loading" response, so waiting will not fix it.
            raise RuntimeError(f"Inference API error: {output['error']}")
        if attempts >= MAX_LOAD_RETRIES:
            raise RuntimeError(
                f"Model still loading after {MAX_LOAD_RETRIES} retries: {output['error']}"
            )
        time.sleep(output['estimated_time'])
        output = query(payload)
        attempts += 1
    return output


def classify_output(item):
    """Classify the text *item* into EEMM dimensions via the Inference API.

    Args:
        item: The text of the item to classify (the Gradio textbox value).

    Returns:
        A comma-separated string of ``"<Dimension> (<score>)"`` for every label
        scoring at least SCORE_THRESHOLD. When no label qualifies, returns a
        message naming the single highest-scoring dimension and its score.

    Raises:
        RuntimeError: if the API returns an error that is not a model-loading
            wait, or the model never finishes loading (see ``_query_when_ready``).
    """
    # The first element of the response holds the {'label', 'score'} dicts
    # for this input.
    scores = _query_when_ready(item)[0]
    best = max(scores, key=lambda entry: entry['score'])

    classifications = [
        (LABEL_NAMES[entry['label']], entry['score'])
        for entry in scores
        if entry['score'] >= SCORE_THRESHOLD
    ]
    if classifications:
        return ', '.join(f"{label} ({score:.2f})" for label, score in classifications)

    # Fall back to the single most likely dimension; report *its* score
    # (previously the last loop element's score was shown by mistake).
    return f"No classifications with a score of {SCORE_THRESHOLD} or higher were found. \n However, the highest probability was for: '{LABEL_NAMES[best['label']]}' ({round(best['score'],2)}) \n Use this classification with caution due to uncertainty"
#########################################
######## RUN GRADIO APP #################
#########################################
# `container` is a boolean option in Gradio; the original passed the string
# 'True', which only worked because non-empty strings are truthy.
txtbx = gr.Textbox(value='I would like to feel better', label='Please enter your item:', container=True)
txtbxopt = gr.Textbox(label='The item you provided was classified as:', container=True)
demo = gr.Interface(
    fn=classify_output,
    inputs=txtbx,
    outputs=txtbxopt,
    theme=gr.themes.Soft(primary_hue='orange'),
    title='EEMM Item Classification Machine V 0.5',
    # TODO: replace the "(ref)" placeholder with the actual EEMM reference.
    description='This machine is a fine tuned version of DistilBERT. It classifies items in 7 EEMM dimensions following (ref). **Please note that the machine goes idle after a period of inactivity. If this occurs, waking it up may take around 20 seconds. Be patient ;)**',
    # TODO: placeholder text; replace with a real description of the model.
    article='Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum',
)
# NOTE(review): share=True opens a public tunnel when run locally; on Hugging
# Face Spaces Gradio does not support it — confirm it is still wanted.
demo.launch(share=True)
|