Cristóbal Hernández commited on
Commit
ec2757f
·
1 Parent(s): a77a997

Add application file

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import time
3
+ import gradio as gr
4
+ import os
5
+
6
+ ########################
7
+ ## Loading the model: ##
8
+ #######################
9
+ api_key = os.environ.get("HF_API_KEY_INFERENCE")
10
+ API_URL = "https://api-inference.huggingface.co/models/chernandezc/distilbert-base-uncased-finetuned-items-multi-label-21122023-AUGMENTED" #Api endpoint.
11
+ headers = {"Authorization": f"Bearer {api_key}"} #This is a read only API key, do not share please :D.
12
+ #This is visible just to experiment, I cannot share environment variables.
13
+
14
+ def query(payload): #Function to use the API.
15
+ response = requests.post(API_URL, headers=headers, json=payload)
16
+ return response.json() #Return Json.
17
+
18
+ # ##########################################################
19
+ # Function to process the output and print classifications #
20
+ ############################################################
21
+ def classify_output(item):
22
+ label_dict = {
23
+ 'LABEL_0': 'Cognition',
24
+ 'LABEL_1': 'Affect',
25
+ 'LABEL_2': 'Self',
26
+ 'LABEL_3': 'Motivation',
27
+ 'LABEL_4': 'Attention',
28
+ 'LABEL_5': 'Overt_Behavior',
29
+ 'LABEL_6': 'Context'
30
+ }
31
+
32
+ output = query({ #Try to query the endpoint.
33
+ "inputs": item,
34
+ })
35
+
36
+ # If the model is loading, wait and try again.
37
+ while 'error' in output:
38
+ time.sleep(output["estimated_time"]) #Sleep the estimated time and try again.
39
+ output = query({
40
+ "inputs": item,
41
+ })
42
+
43
+ # Store classifications in a list
44
+ classifications = []
45
+
46
+ # Find the item with the highest score
47
+ highest_score_item = max(output[0], key=lambda x: x['score'])
48
+
49
+ for item in output[0]:
50
+ # Check if the score is greater than or equal to 0.5
51
+ if item['score'] >= 0.5:
52
+ # Append the category and score to the classifications list
53
+ classifications.append((label_dict[item['label']], item['score']))
54
+
55
+ # Construct and print the classification message
56
+ if classifications:
57
+ classification_str = ', '.join([f"'{label}' ({score:.2f})" for label, score in classifications])
58
+
59
+ output1 = f"The item you provided is classified as: {classification_str}"
60
+ return output1
61
+ else:
62
+ output2 = f"No classifications with a score of 0.5 or higher were found. \n However, the highest probability was for: '{label_dict[highest_score_item['label']]}' ({round(item['score'],2)}) \n Use this classification with caution due to uncertainty"
63
+ return output2
64
+
65
+ #########################################
66
+ ######## RUN GRADIO APP #################
67
+ #########################################
68
+
69
+ demo = gr.Interface(fn=classify_output, inputs="text", outputs="text")
70
+ demo.launch()
71
+
72
+