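"""Gradio app that matches a free-text job title to the closest known role.

Role embeddings are pre-computed once with a MiniLM sentence-transformer and
compared against the embedded input title via cosine similarity at request time.
"""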
# imports
import json
import time

import gradio as gr
from transformers import AutoTokenizer, AutoModel

# pytorch library
import torch
import torch.nn.functional as F

from roles_list import roles
# Load the SBERT model and tokenizer from the Hugging Face Hub
embed_store = {}
model_name = 'sentence-transformers/all-MiniLM-L12-v2'
sbert_model = AutoModel.from_pretrained(model_name)
sbert_tokenizer = AutoTokenizer.from_pretrained(model_name)
# Pre-compute a normalized embedding for every known role
for role in roles:
    encoding = sbert_tokenizer(role,                 # the text to be tokenized
                               max_length=10,
                               padding="max_length",
                               truncation=True,
                               return_tensors='pt')  # return tensors (not lists)
    with torch.no_grad():
        # get the model embeddings
        embed = sbert_model(**encoding)
        embed = embed.pooler_output
    embed_store[role] = F.normalize(embed, p=2, dim=1)

print("Model is ready for inference")
def get_role_from_sbert(title):
    start_time = time.time()
    encoding = sbert_tokenizer(title,
                               max_length=10,
                               padding="max_length",
                               truncation=True,
                               return_tensors='pt')
    # Run the model prediction on the input data
    with torch.no_grad():
        # get the model embeddings
        embed = sbert_model(**encoding)
        embed = embed.pooler_output
    embed = F.normalize(embed, p=2, dim=1)

    # Cosine similarity between the input title and every stored role embedding
    store_cos = {}
    for role in embed_store:
        cos_sim = F.cosine_similarity(embed, embed_store[role])
        store_cos[role] = round(cos_sim.item(), 3)

    # Keep the top 3 roles with the highest cosine similarity
    top_3_keys_values = sorted(store_cos.items(), key=lambda item: item[1], reverse=True)[:3]
    job_scores_str = '\n'.join([f"{job}: {score}" for job, score in top_3_keys_values])

    execution_time = time.time() - start_time
    return job_scores_str + f"\nExecution time: {execution_time:.3f}s"
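# Expose the classifier through a minimal Gradio UI: a single textbox takes the
# job title, and the ranked role scores (plus timing) come back as text.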
demo = gr.Interface(fn=get_role_from_sbert,
                    inputs=gr.Textbox(label="Job Title"),
                    outputs=gr.Textbox(label="Role"),
                    title="HackerRank Role Classifier")

gr.close_all()
demo.launch()