Sid26Roy committed on
Commit
5a789fb
·
verified ·
1 Parent(s): cd622b9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import BertTokenizer, BertForSequenceClassification
4
+ import re
5
+ import nltk
6
+ from nltk.tokenize import word_tokenize
7
+ from nltk.corpus import stopwords
8
+ from nltk.stem import WordNetLemmatizer
9
+
10
# NLTK setup
# Fetch the tokenizer models and corpora that preprocess() relies on.
# "punkt_tab" is the resource word_tokenize loads on NLTK >= 3.8.2; older
# releases load "punkt". Both are requested so the app works on either side
# of that change. nltk.download() reports an unknown or unavailable package
# by returning False instead of raising, so the extra request is harmless on
# old installs.
for _resource in ("punkt", "punkt_tab", "stopwords", "wordnet"):
    nltk.download(_resource)

# English stop words as a set for O(1) membership tests during filtering.
stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()
17
+
18
# Load model & tokenizer from the local checkpoint directory
model_dir = "./model"

# Prefer the GPU when torch can see one, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

tokenizer = BertTokenizer.from_pretrained(model_dir)

model = BertForSequenceClassification.from_pretrained(model_dir)
model.to(device)  # moves the module's parameters in place
model.eval()  # inference mode (e.g. dropout disabled)
25
+
26
# The 16 MBTI type labels. predict_mbti() indexes this list with the class id
# that has the highest logit.
# NOTE(review): the order here must match the label encoding used when the
# model was trained -- confirm against the training code.
MBTI_CLASSES = (
    "ISTJ ISFJ INFJ INTJ "
    "ISTP ISFP INFP INTP "
    "ESTP ESFP ENFP ENTP "
    "ESTJ ESFJ ENFJ ENTJ"
).split()
32
+
33
def preprocess(text):
    """Normalise raw text into a space-joined string of lemmatised tokens.

    Steps, in order: lower-case, turn the "|||" separator into a space,
    drop URLs, drop every character that is not an ASCII letter or
    whitespace, tokenise with NLTK, remove English stop words, and
    lemmatise what is left.

    Args:
        text: The raw input string.

    Returns:
        The cleaned tokens joined by single spaces. The result is an empty
        string when no token survives the filtering.
    """
    text = text.lower()
    # Replace the "|||" separator first. Doing it after the non-letter strip
    # (as before) meant the pipes were already gone, so the replace never
    # matched and the words on either side were fused together. Running it
    # before URL removal also stops `http\S+` from swallowing the text that
    # follows a separator.
    text = text.replace("|||", " ")
    # The dot after "www" is escaped: unescaped it matches any character, so
    # ordinary words such as "awww" followed by another word were deleted.
    text = re.sub(r"http\S+|www\.\S+", "", text)
    text = re.sub(r"[^a-zA-Z\s]", "", text)
    tokens = word_tokenize(text)
    tokens = [lemmatizer.lemmatize(word) for word in tokens if word not in stop_words]
    return " ".join(tokens)
41
+
42
def predict_mbti(passage):
    """Predict the MBTI type label for a passage of text.

    Args:
        passage: The text to classify. ``None``, an empty string, or a
            whitespace-only string is treated as missing input.

    Returns:
        The entry of ``MBTI_CLASSES`` at the index of the highest logit, or
        the prompt "Please enter your text." when there is nothing to
        classify.
    """
    # Guard against None as well as blank strings, so a cleared input field
    # produces the prompt instead of an AttributeError on .strip().
    if not passage or not passage.strip():
        return "Please enter your text."

    cleaned = preprocess(passage)
    if not cleaned:
        # Everything was filtered out (e.g. only digits, punctuation, URLs or
        # stop words). Classifying an empty string would return an arbitrary
        # label, so ask for more text instead.
        return "Please enter your text."

    inputs = tokenizer(
        cleaned,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    ).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)
    pred = torch.argmax(outputs.logits, dim=1).item()

    return MBTI_CLASSES[pred]
54
+
55
# Gradio Interface
# Build the input and output widgets first, then wire them to predict_mbti.
_passage_box = gr.Textbox(
    lines=10,
    label="Combined Response Passage",
    placeholder="Paste all question-answer text here...",
)
_result_box = gr.Textbox(label="Predicted MBTI Type")

demo = gr.Interface(
    fn=predict_mbti,
    inputs=_passage_box,
    outputs=_result_box,
    title="🔮 MBTI Personality Classifier API",
    description="Pass a single combined text input (like 'Q1 A1 Q2 A2...') to get back the MBTI type.",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()