|
import gradio as gr |
|
import torch |
|
import re |
|
from transformers import BertTokenizer, BertForSequenceClassification |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Directory holding the saved tokenizer files and classifier weights.
model_name = "./model"

# Prefer the GPU when one is visible to torch; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the tokenizer and the sequence-classification model from the same
# checkpoint directory, then move the model onto the selected device.
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name).to(device)

# Evaluation mode: the model is only used for inference in this script.
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
# Class-index -> MBTI label lookup table. predict_mbti indexes this list with
# the argmax of the model's logits, so position matters.
# NOTE(review): the order presumably mirrors the label ids used when the
# checkpoint was trained -- confirm against the model's config before editing.
MBTI_CLASSES = [
    "ISTJ",
    "ISFJ",
    "INFJ",
    "INTJ",
    "ISTP",
    "ISFP",
    "INFP",
    "INTP",
    "ESTP",
    "ESFP",
    "ENFP",
    "ENTP",
    "ESTJ",
    "ESFJ",
    "ENFJ",
    "ENTJ",
]
|
|
|
|
|
def preprocess_text(text):
    """Clean raw user text before it is passed to the tokenizer.

    The cleaning steps run in this order:

    1. Lowercase the text.
    2. Remove URLs: any run of non-whitespace starting with ``http`` or
       with a literal ``www.``.
    3. Remove every character that is neither an ASCII letter nor
       whitespace (digits, punctuation, emoji, ...).

    Whitespace is not collapsed, so removed spans can leave consecutive
    spaces behind.

    Args:
        text: Raw input string. ``None`` is treated as an empty string.

    Returns:
        The cleaned string (possibly empty).
    """
    if text is None:
        return ""

    text = text.lower()
    # The dot after "www" is escaped so it matches a literal "."; unescaped
    # it matches any character and eats ordinary words such as "awwww so".
    text = re.sub(r"http\S+|www\.\S+", "", text)
    text = re.sub(r"[^a-zA-Z\s]", "", text)
    # Hand the cleaned text back; without this the caller receives None.
    return text
|
|
|
|
|
|
|
|
|
|
|
def predict_mbti(text):
    """Return the MBTI label the classifier assigns to ``text``.

    The text is cleaned with ``preprocess_text``, tokenized as a single
    sequence that is truncated/padded to exactly 512 tokens, and run through
    the module-level ``model`` with gradient tracking disabled.

    Args:
        text: Raw answer text as entered by the user.

    Returns:
        The entry of ``MBTI_CLASSES`` at the index of the highest logit.
    """
    encoded = tokenizer(
        preprocess_text(text),
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt",
    )
    # Input tensors must live on the same device as the model.
    encoded = encoded.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**encoded).logits

    # One input sequence -> one row of logits, so .item() yields a plain int.
    best_class = logits.argmax(dim=1).item()
    return MBTI_CLASSES[best_class]
|
|
|
|
|
# Gradio UI: a 12-line textbox feeds predict_mbti and the returned label is
# shown in a second textbox.
interface = gr.Interface(
    title="MBTI Personality Predictor (BERT)",
    description="Paste your combined answers to get your MBTI personality type. Powered by Sid26Roy/mbti",
    fn=predict_mbti,
    inputs=gr.Textbox(label="Enter Combined Answers (Q1 A1 Q2 A2 ...)", lines=12),
    outputs=gr.Textbox(label="Predicted MBTI Type"),
)


# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()
|
|