import gradio as gr
import torch
import re
from transformers import BertTokenizer, BertForSequenceClassification
# import nltk
# from nltk.tokenize import word_tokenize
# from nltk.corpus import stopwords
# from nltk.stem import WordNetLemmatizer

# Download required NLTK data
# nltk.download("stopwords")
# nltk.download("punkt")
# nltk.download("wordnet")

# Load the fine-tuned BERT model and tokenizer from the local ./model directory
model_name = "./model"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# NLP tools
# stop_words = set(stopwords.words("english"))
# lemmatizer = WordNetLemmatizer()

# MBTI Labels
MBTI_CLASSES = [
    "ISTJ", "ISFJ", "INFJ", "INTJ",
    "ISTP", "ISFP", "INFP", "INTP",
    "ESTP", "ESFP", "ENFP", "ENTP",
    "ESTJ", "ESFJ", "ENFJ", "ENTJ"
]
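
# Note: this ordering is assumed to match the label-to-index mapping used when
# the checkpoint in ./model was fine-tuned; the argmax index from the model's
# logits is looked up directly in this list.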

# Preprocess text
def preprocess_text(text):
    text = text.lower()
    text = re.sub(r"http\S+|www\.\S+", "", text)  # strip URLs
    text = re.sub(r"[^a-zA-Z\s]", "", text)       # keep letters and whitespace only
    # tokens = word_tokenize(text)
    # tokens = [lemmatizer.lemmatize(word) for word in tokens]
    # return " ".join(tokens)
    return text
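
# Illustrative example (hypothetical input):
#   preprocess_text("Check https://example.com, I LOVE hiking!!")
#   -> roughly "check i love hiking" (URL and punctuation removed, lowercased)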

# Inference function
def predict_mbti(text):
    cleaned = preprocess_text(text)
    inputs = tokenizer(
        cleaned,
        max_length=512,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        pred_idx = torch.argmax(outputs.logits, dim=1).item()
        return MBTI_CLASSES[pred_idx]
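
# Illustrative example (hypothetical input, single-label prediction):
#   predict_mbti("I prefer quiet evenings and making long-term plans ...")
#   -> one of the 16 labels in MBTI_CLASSES, e.g. "INTJ"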

# Gradio interface
interface = gr.Interface(
    fn=predict_mbti,
    inputs=gr.Textbox(lines=12, label="Enter Combined Answers (Q1 A1 Q2 A2 ...)"),
    outputs=gr.Textbox(label="Predicted MBTI Type"),
    title="MBTI Personality Predictor (BERT)",
    description="Paste your combined answers to get your MBTI personality type. Powered by Sid26Roy/mbti"
)

if __name__ == "__main__":
    interface.launch()
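
# Run locally with `python app.py` (the filename is an assumption); by default
# Gradio serves the UI at http://127.0.0.1:7860.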