app.py
ADDED
@@ -0,0 +1,196 @@
# Import required libraries
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader  # not used directly at inference, but kept for the Dataset definition below
from transformers import LongformerForSequenceClassification, AutoTokenizer
import gradio as gr

# =======================================================
# 1. Global settings and constants
# =======================================================
MODEL_NAME = 'kiddothe2b/longformer-mini-1024'  # Hugging Face model name
MAX_LEN = 1024  # maximum model input length

# Check GPU availability and set the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer (required for inference)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

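# For intuition: with padding='max_length' and truncation=True, encode_plus
# pads or truncates every input to exactly MAX_LEN tokens. A minimal sketch
# (illustrative only, not executed by this app):
#   enc = tokenizer.encode_plus("hello", max_length=MAX_LEN,
#                               padding='max_length', truncation=True,
#                               return_tensors='pt')
#   enc['input_ids'].shape  # -> torch.Size([1, 1024])
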
# =======================================================
# 2. PyTorch Dataset definition (class used during training; no DataLoader is built for inference)
# =======================================================
# This class defines the data structure the models were trained on.
# At inference time a single text comes in, so no DataLoader is needed,
# but the same encoding settings are reused below to match the input
# format the models expect.
class DepressionDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        text = str(self.texts[item])
        label = self.labels[item]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

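# For reference, a minimal sketch of how this class would have been used
# during training (train_texts/train_labels are hypothetical names, not
# defined in this inference app):
#   train_ds = DepressionDataset(train_texts, train_labels, tokenizer, MAX_LEN)
#   train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
#   batch = next(iter(train_loader))  # dict with input_ids, attention_mask, labels
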
# =======================================================
# 3. Model loading (load the trained weights)
# =======================================================
print("\n--- Loading models for inference ---")

# Directory holding the model weights (the folder created when uploading to Hugging Face Spaces)
save_dir = "saved_models"

# Model file paths
p_model_path = os.path.join(save_dir, 'p_text_best_model.bin')
e_model_path = os.path.join(save_dir, 'e_text_best_model.bin')

# Load the models and put them in evaluation mode
p_model_for_inference = None
e_model_for_inference = None

try:
    # Load the participant-speech model (P-model)
    if os.path.exists(p_model_path):
        p_model_for_inference = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
        p_model_for_inference.load_state_dict(torch.load(p_model_path, map_location=device))
        p_model_for_inference.to(device)
        p_model_for_inference.eval()  # evaluation mode
        print(f"P-model loaded successfully from {p_model_path}")
    else:
        print(f"Warning: P-model file not found at {p_model_path}. Please ensure it's uploaded.")

    # Load the Ellie-speech model (E-model)
    if os.path.exists(e_model_path):
        e_model_for_inference = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
        e_model_for_inference.load_state_dict(torch.load(e_model_path, map_location=device))
        e_model_for_inference.to(device)
        e_model_for_inference.eval()  # evaluation mode
        print(f"E-model loaded successfully from {e_model_path}")
    else:
        print(f"Warning: E-model file not found at {e_model_path}. Please ensure it's uploaded.")

except Exception as e:
    print(f"Error loading models: {e}")
    # If loading fails, make sure the UI does not launch
    p_model_for_inference = None
    e_model_for_inference = None

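# Assumption: the .bin files above are plain state_dicts, i.e. they were
# produced during training with something like
#   torch.save(model.state_dict(), os.path.join(save_dir, 'p_text_best_model.bin'))
# load_state_dict() above expects exactly that format; a pickled full model
# or a Trainer checkpoint directory would need different loading code.
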
# =======================================================
# 4. Gradio prediction function
# =======================================================
def predict_depression(participant_text, ellie_text):
    # Check that the models loaded properly
    if p_model_for_inference is None or e_model_for_inference is None:
        return "**Error:** The models are not loaded. Contact the administrator or check that the model files were uploaded."

    # Preprocess Ellie's speech (same logic as used during training: keep only the second half)
    e_text_words = ellie_text.split()
    if len(e_text_words) > 0:
        ellie_text_processed = " ".join(e_text_words[len(e_text_words) // 2:])
    else:
        ellie_text_processed = ""

    # P-model prediction
    p_encoding = tokenizer.encode_plus(
        participant_text,
        add_special_tokens=True,
        max_length=MAX_LEN,
        return_token_type_ids=False,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    p_input_ids = p_encoding['input_ids'].to(device)
    p_attention_mask = p_encoding['attention_mask'].to(device)

    with torch.no_grad():  # no gradient computation needed at inference
        p_outputs = p_model_for_inference(input_ids=p_input_ids, attention_mask=p_attention_mask)
        p_probs = F.softmax(p_outputs.logits, dim=1).cpu().numpy().flatten()
        p_pred_label = np.argmax(p_probs)

    # E-model prediction
    e_encoding = tokenizer.encode_plus(
        ellie_text_processed,
        add_special_tokens=True,
        max_length=MAX_LEN,
        return_token_type_ids=False,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    e_input_ids = e_encoding['input_ids'].to(device)
    e_attention_mask = e_encoding['attention_mask'].to(device)

    with torch.no_grad():  # no gradient computation needed at inference
        e_outputs = e_model_for_inference(input_ids=e_input_ids, attention_mask=e_attention_mask)
        e_probs = F.softmax(e_outputs.logits, dim=1).cpu().numpy().flatten()
        e_pred_label = np.argmax(e_probs)

    # Ensemble (OR strategy): if either model predicts depression (1), classify as depressed
    ensemble_pred_label = 1 if p_pred_label == 1 or e_pred_label == 1 else 0

    labels = ['Control (non-depressed)', 'Depressed']
    ensemble_result = labels[ensemble_pred_label]
    p_model_result = labels[p_pred_label]
    e_model_result = labels[e_pred_label]

    return (f"**Final ensemble prediction (OR strategy): {ensemble_result}**\n\n"
            f" - Participant model (P-longBERT) prediction: {p_model_result} (probabilities: Control={p_probs[0]:.2f}, Depressed={p_probs[1]:.2f})\n"
            f" - Ellie model (E-longBERT) prediction: {e_model_result} (probabilities: Control={e_probs[0]:.2f}, Depressed={e_probs[1]:.2f})\n\n"
            f"**Note:**\n"
            f"- Predictions are based solely on the conversation text; for an actual diagnosis, consult a professional.\n"
            f"- Inference runs faster in a GPU environment."
            )

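# Quick smoke test (hypothetical inputs; uncomment to verify the pipeline
# end-to-end before launching the UI):
#   print(predict_depression(
#       "I have had trouble sleeping and I don't enjoy things anymore.",
#       "How have you been feeling over the last two weeks?"))
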
# =======================================================
# 5. Create and launch the Gradio UI
# =======================================================
print("\n--- Setting up Gradio UI ---")

# Launch the Gradio UI only if both models loaded successfully
if p_model_for_inference is not None and e_model_for_inference is not None:
    gr.Interface(
        fn=predict_depression,
        inputs=[
            gr.Textbox(lines=10, label="Participant's speech", placeholder="Enter the participant's speech here..."),
            gr.Textbox(lines=10, label="Ellie's speech", placeholder="Enter the speech of Ellie (the virtual agent) here... (only the second half of the text is used)")
        ],
        outputs="markdown",
        title="DAIC-WOZ Depression Detection Ensemble Model (GPU accelerated)",
        description=f"""This app analyzes the conversation between a participant and the virtual agent (Ellie), based on the DAIC-WOZ dataset, and predicts whether the participant is depressed.
It reports the ensemble (OR strategy) of the P-longBERT (participant speech) and E-longBERT (Ellie speech) models.
**Inference runs faster in a GPU environment.**
**Note:** This is only an AI model's prediction; **always consult a professional for an actual medical diagnosis.**
Device in use: {device}
"""
    ).launch()  # share=True is not needed on Hugging Face Spaces
else:
    print("\nGradio UI could not be launched because models failed to load. Please check model files.")

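# Deployment note (assumption, not part of the original file): on Hugging Face
# Spaces this script runs as-is at startup; a requirements.txt next to it
# would typically list the dependencies imported above, e.g.:
#   torch
#   transformers
#   gradio
#   numpy
#   pandas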