kkandull committed on
Commit
c32e434
·
verified ·
1 Parent(s): 2bb7523
Files changed (1) hide show
  1. app.py +196 -0
app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import required libraries
2
+ import os
3
+ import pandas as pd
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from torch.utils.data import Dataset, DataLoader # DataLoader์™€ Dataset์€ ์ถ”๋ก  ์‹œ ์ง์ ‘ ์‚ฌ์šฉ๋˜์ง„ ์•Š์ง€๋งŒ, ๋ชจ๋ธ ์ •์˜์— ํ•„์š”ํ•  ์ˆ˜ ์žˆ์–ด ๋‚จ๊ฒจ๋‘ 
9
+ from transformers import LongformerForSequenceClassification, AutoTokenizer
10
+ import gradio as gr
11
+
12
# =======================================================
# 1. Global settings and constants
# =======================================================
# Hugging Face Hub identifier of the base checkpoint.
MODEL_NAME = 'kiddothe2b/longformer-mini-1024'
# Maximum number of tokens passed to the model per input.
MAX_LEN = 1024

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Tokenizer used to encode the input texts at inference time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
24
+
25
# =======================================================
# 2. PyTorch dataset definition
# =======================================================
# This class describes the data structure that was used while training.
# Inference receives a single text at a time, so no DataLoader is built from
# it here; it is kept because its encoding settings are the ones the
# inference path has to reproduce.
class DepressionDataset(Dataset):
    """Dataset that pairs raw texts with class labels and tokenizes on access."""

    def __init__(self, texts, labels, tokenizer, max_len):
        """Store the data and the tokenization settings.

        Args:
            texts: Indexable collection of input texts.
            labels: Indexable collection of labels, aligned with ``texts``.
            tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer).
            max_len: Length every encoded text is padded/truncated to.
        """
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, item):
        """Encode the text at index ``item``.

        Returns:
            A dict with flattened ``input_ids`` and ``attention_mask`` tensors
            and a ``labels`` tensor of dtype ``torch.long``.
        """
        # Calling the tokenizer directly replaces the deprecated
        # ``encode_plus`` method; the keyword arguments are unchanged.
        encoding = self.tokenizer(
            str(self.texts[item]),  # str() guards against non-string entries
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            # return_tensors='pt' adds a batch dimension; flatten() removes it
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(self.labels[item], dtype=torch.long),
        }
59
+
60
+ # =======================================================
61
+ # 3. ๋ชจ๋ธ ๋กœ๋”ฉ (ํ•™์Šต๋œ ๊ฐ€์ค‘์น˜๋ฅผ ๋กœ๋“œ)
62
+ # =======================================================
63
+ print("\n--- Loading models for inference ---")
64
+
65
+ # ๋ชจ๋ธ ๊ฐ€์ค‘์น˜๊ฐ€ ์ €์žฅ๋œ ๋””๋ ‰ํ† ๋ฆฌ (Hugging Face Spaces์— ์—…๋กœ๋“œํ•  ๋•Œ ์ƒ์„ฑํ•  ํด๋”)
66
+ save_dir = "saved_models"
67
+
68
+ # ๋ชจ๋ธ ํŒŒ์ผ ๊ฒฝ๋กœ
69
+ p_model_path = os.path.join(save_dir, 'p_text_best_model.bin')
70
+ e_model_path = os.path.join(save_dir, 'e_text_best_model.bin')
71
+
72
+ # ๋ชจ๋ธ ๋กœ๋”ฉ ๋ฐ ํ‰๊ฐ€ ๋ชจ๋“œ ์„ค์ •
73
+ p_model_for_inference = None
74
+ e_model_for_inference = None
75
+
76
+ try:
77
+ # ์ฐธ๊ฐ€์ž ๋ฐœํ™” ๋ชจ๋ธ (P-model) ๋กœ๋“œ
78
+ if os.path.exists(p_model_path):
79
+ p_model_for_inference = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
80
+ p_model_for_inference.load_state_dict(torch.load(p_model_path, map_location=device))
81
+ p_model_for_inference.to(device)
82
+ p_model_for_inference.eval() # ํ‰๊ฐ€ ๋ชจ๋“œ ์„ค์ •
83
+ print(f"P-model loaded successfully from {p_model_path}")
84
+ else:
85
+ print(f"Warning: P-model file not found at {p_model_path}. Please ensure it's uploaded.")
86
+
87
+ # ์—˜๋ฆฌ ๋ฐœํ™” ๋ชจ๋ธ (E-model) ๋กœ๋“œ
88
+ if os.path.exists(e_model_path):
89
+ e_model_for_inference = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
90
+ e_model_for_inference.load_state_dict(torch.load(e_model_path, map_location=device))
91
+ e_model_for_inference.to(device)
92
+ e_model_for_inference.eval() # ํ‰๊ฐ€ ๋ชจ๋“œ ์„ค์ •
93
+ print(f"E-model loaded successfully from {e_model_path}")
94
+ else:
95
+ print(f"Warning: E-model file not found at {e_model_path}. Please ensure it's uploaded.")
96
+
97
+ except Exception as e:
98
+ print(f"Error loading models: {e}")
99
+ # ๋ชจ๋ธ ๋กœ๋”ฉ ์‹คํŒจ ์‹œ, UI๊ฐ€ ์‹คํ–‰๋˜์ง€ ์•Š๋„๋ก ์„ค์ •
100
+ p_model_for_inference = None
101
+ e_model_for_inference = None
102
+
103
# =======================================================
# 4. Gradio prediction function
# =======================================================
def _latter_half(text):
    """Return the second half of ``text``, counted in whitespace-separated words."""
    words = text.split()
    return " ".join(words[len(words) // 2:])


def _class_probabilities(model, text):
    """Return the softmax class probabilities of ``model`` for ``text`` as a 1-D array."""
    # Calling the tokenizer directly replaces the deprecated ``encode_plus``
    # method; the keyword arguments are unchanged.
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=MAX_LEN,
        return_token_type_ids=False,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    with torch.no_grad():  # no gradients are needed for inference
        outputs = model(
            input_ids=encoding['input_ids'].to(device),
            attention_mask=encoding['attention_mask'].to(device),
        )
    return F.softmax(outputs.logits, dim=1).cpu().numpy().flatten()


def predict_depression(participant_text, ellie_text):
    """Classify a conversation with both models and report the OR-ensemble.

    Args:
        participant_text: Participant utterances; ``None`` is treated as "".
        ellie_text: Ellie (virtual agent) utterances; only the second half of
            its words is passed to the E-model. ``None`` is treated as "".

    Returns:
        A Markdown-formatted string with the ensemble label and each model's
        label and class probabilities, or an error message when either model
        is not loaded.
    """
    # Refuse to predict unless both models were loaded.
    if p_model_for_inference is None or e_model_for_inference is None:
        return "**오류:** 모델이 로드되지 않았습니다. 관리자에게 문의하거나 모델 파일 업로드 여부를 확인해주세요."

    participant_text = participant_text or ""
    # Ellie preprocessing (the source comment says this mirrors training).
    ellie_text_processed = _latter_half(ellie_text or "")

    p_probs = _class_probabilities(p_model_for_inference, participant_text)
    e_probs = _class_probabilities(e_model_for_inference, ellie_text_processed)
    p_pred_label = int(np.argmax(p_probs))
    e_pred_label = int(np.argmax(e_probs))

    # Ensemble (OR strategy): depressed (1) if either model predicts it.
    ensemble_pred_label = 1 if p_pred_label == 1 or e_pred_label == 1 else 0

    labels = ['Control (비우울)', 'Depressed (우울)']
    ensemble_result = labels[ensemble_pred_label]
    p_model_result = labels[p_pred_label]
    e_model_result = labels[e_pred_label]

    return (f"**최종 앙상블 예측 (OR 전략): {ensemble_result}**\n\n"
            f" - 참가자 모델 (P-longBERT) 예측: {p_model_result} (확률: Control={p_probs[0]:.2f}, Depressed={p_probs[1]:.2f})\n"
            f" - 엘리 모델 (E-longBERT) 예측: {e_model_result} (확률: Control={e_probs[0]:.2f}, Depressed={e_probs[1]:.2f})\n\n"
            f"**참고:**\n"
            f"- 예측은 각 대화 내용에만 기반하며, 실제 진단은 전문가와 상담해야 합니다.\n"
            f"- GPU 환경에서는 예측이 빠르게 수행됩니다."
            )
171
+
172
# =======================================================
# 5. Build and launch the Gradio UI
# =======================================================
print("\n--- Setting up Gradio UI ---")

# Launch the UI only when both models were loaded successfully.
if p_model_for_inference is not None and e_model_for_inference is not None:
    gr.Interface(
        fn=predict_depression,
        inputs=[
            gr.Textbox(
                lines=10,
                label="참가자 발화 내용 (Participant's speech)",
                placeholder="여기에 참가자의 발화 내용을 입력하세요...",
            ),
            gr.Textbox(
                lines=10,
                label="엘리 발화 내용 (Ellie's speech)",
                placeholder="여기에 엘리(가상 에이전트)의 발화 내용을 입력하세요... (전체 내용 중 후반부만 사용됨)",
            ),
        ],
        outputs="markdown",
        title="DAIC-WOZ 우울증 감지 앙상블 모델 (GPU 가속)",
        # Built from one-line pieces so the rendered text does not depend on
        # how continuation lines of a triple-quoted string are indented.
        description=(
            "이 앱은 DAIC-WOZ 데이터셋을 기반으로 참가자와 가상 에이전트(엘리)의 대화 내용을 분석하여 우울증 여부를 예측합니다.\n"
            "P-longBERT (참가자 발화)와 E-longBERT (엘리 발화) 모델의 앙상블 (OR 전략) 결과를 제공합니다.\n"
            "**GPU 환경에서는 예측이 빠르게 수행됩니다.**\n"
            "**참고:** 이는 AI 모델의 예측일 뿐이며, **실제 의학적 진단은 반드시 전문가와 상담해야 합니다.**\n"
            f"사용 중인 디바이스: {device}\n"
        ),
    ).launch()  # share=True is not needed on Hugging Face Spaces
else:
    print("\nGradio UI could not be launched because models failed to load. Please check model files.")
196
+