File size: 2,009 Bytes
5a789fb c3a7c62 5342693 46ea1cb 5a789fb c3a7c62 5342693 5a789fb c3a7c62 5a789fb c3a7c62 5a789fb c3a7c62 5342693 84efe55 c3a7c62 5a789fb c3a7c62 5a789fb 46ea1cb 5a789fb c3a7c62 5a789fb c3a7c62 5a789fb c3a7c62 5a789fb c3a7c62 5a789fb c3a7c62 5a789fb c3a7c62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
import gradio as gr
import torch
import re
from transformers import BertTokenizer, BertForSequenceClassification
# import nltk
# from nltk.tokenize import word_tokenize
# from nltk.corpus import stopwords
# from nltk.stem import WordNetLemmatizer
# Download required NLTK data
# nltk.download("stopwords")
# nltk.download("punkt")
# nltk.download("wordnet")
# Load the fine-tuned BERT classifier and its tokenizer from the local
# "./model" directory.
model_name = "./model"
tokenizer = BertTokenizer.from_pretrained(model_name)

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# nn.Module.to() returns the module itself, so the device move is chained
# directly onto loading; eval() switches off training-only behavior such as
# dropout so predictions are deterministic.
model = BertForSequenceClassification.from_pretrained(model_name).to(device)
model.eval()
# NLP tools
# stop_words = set(stopwords.words("english"))
# lemmatizer = WordNetLemmatizer()
# The 16 MBTI type labels; the classifier's argmax index is used to look up
# an entry in this list.
# NOTE(review): this order must match the label ids the model in ./model was
# fine-tuned with (e.g. id2label in its config) — TODO confirm.
MBTI_CLASSES = (
    "ISTJ ISFJ INFJ INTJ "
    "ISTP ISFP INFP INTP "
    "ESTP ESFP ENFP ENTP "
    "ESTJ ESFJ ENFJ ENTJ"
).split()
def preprocess_text(text):
    """Normalize raw user text before it is handed to the BERT tokenizer.

    Lower-cases the input, removes URLs, drops every character that is not
    an ASCII letter or whitespace, and collapses runs of whitespace into
    single spaces.

    Args:
        text: Raw text entered by the user. ``None`` or an empty string
            yields an empty result instead of raising.

    Returns:
        The cleaned text as a ``str`` (possibly empty).
    """
    if not text:
        return ""
    text = text.lower()
    # Strip URLs. The dot after "www" is escaped so only real "www." prefixes
    # match; unescaped, the pattern also deleted words like "wwwhat".
    text = re.sub(r"http\S+|www\.\S+", "", text)
    # Keep only letters and whitespace; digits, punctuation and emoji go.
    text = re.sub(r"[^a-zA-Z\s]", "", text)
    # The NLTK tokenize/lemmatize step is disabled, which previously left the
    # function without a return (it yielded None and broke the tokenizer call).
    # Joining on whitespace mirrors the old `" ".join(tokens)` output shape.
    return " ".join(text.split())
def predict_mbti(text):
    """Classify free-form answers into one of the 16 MBTI types.

    The text is cleaned with ``preprocess_text``, tokenized (truncated and
    padded to exactly 512 tokens), and run through the BERT classifier on
    ``device`` without gradient tracking. The highest-scoring logit selects
    the label.

    Args:
        text: The user's combined questionnaire answers.

    Returns:
        The predicted label, taken from ``MBTI_CLASSES``.
    """
    encoded = tokenizer(
        preprocess_text(text),
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    encoded = encoded.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    best = logits.argmax(dim=1).item()
    return MBTI_CLASSES[best]
# Gradio UI: a multi-line text box feeds predict_mbti, and the predicted type
# is shown in a single-line text box.
_answers_input = gr.Textbox(lines=12, label="Enter Combined Answers (Q1 A1 Q2 A2 ...)")
_prediction_output = gr.Textbox(label="Predicted MBTI Type")

interface = gr.Interface(
    fn=predict_mbti,
    inputs=_answers_input,
    outputs=_prediction_output,
    title="MBTI Personality Predictor (BERT)",
    description="Paste your combined answers to get your MBTI personality type. Powered by Sid26Roy/mbti",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    interface.launch()
|