# Hugging Face Space status banner captured when this file was scraped: "Spaces: Sleeping"
# Import required libraries
import os

import numpy as np
import pandas as pd  # not used directly in this file; kept as a data-processing utility
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader  # Dataset is the base of DepressionDataset; DataLoader is not used at inference
from transformers import LongformerForSequenceClassification, AutoTokenizer
import gradio as gr
# =======================================================
# 1. Global settings and constants
# =======================================================
# Hugging Face Hub checkpoint shared by both classifiers and the tokenizer.
MODEL_NAME = 'kiddothe2b/longformer-mini-1024'
# Fixed input length in tokens; longer texts are truncated, shorter ones padded.
MAX_LEN = 1024

# Run on the first CUDA GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Tokenizer matching MODEL_NAME; every inference call reuses this instance.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# =======================================================
# 2. PyTorch dataset definition (the class used during training; inference
#    does not build a DataLoader from it)
# =======================================================
# This class defines the data structure the models were trained with.
# At inference time a single text arrives, so no DataLoader is needed, but the
# same encoding settings are reused so the model sees inputs of the same form.
class DepressionDataset(Dataset):
    """Map-style dataset that tokenizes one text per item for sequence classification.

    Args:
        texts: Indexable collection of texts; each item is converted with ``str()``.
        labels: Indexable collection of integer class labels, parallel to ``texts``.
        tokenizer: Hugging Face tokenizer used to encode each text.
        max_len: Fixed sequence length; every text is padded/truncated to it.

    NOTE: items are fetched with ``texts[item]`` / ``labels[item]``. For a
    pandas Series with a non-default index that is a *label* lookup rather
    than a positional one, so pass lists or ``.to_numpy()`` arrays.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        """Return ``input_ids`` and ``attention_mask`` (1-D) plus a ``long`` ``labels`` tensor."""
        text = str(self.texts[item])
        label = self.labels[item]
        # Calling the tokenizer directly replaces the deprecated encode_plus();
        # for a single string the keyword arguments and outputs are the same.
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch dimension of 1; flatten it
        # away so the default collate function can stack items into a batch.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long),
        }
# =======================================================
# 3. Model loading (load the trained weights)
# =======================================================
print("\n--- Loading models for inference ---")
# Checkpoint paths. There is no saved_models folder, so the files are assumed
# to sit directly in the repository root (the former save_dir variable is no
# longer needed).
p_model_path = 'p_text_best_model.bin'  # participant-speech model (P-model)
e_model_path = 'e_text_best_model.bin'  # Ellie-speech model (E-model)


def _load_classifier(model_path, tag):
    """Build a 2-label Longformer classifier and load fine-tuned weights into it.

    Args:
        model_path: Path to a state_dict file consumed by ``load_state_dict``.
        tag: Display name used in log messages (e.g. ``"P-model"``).

    Returns:
        The model moved to ``device`` and switched to eval mode, or ``None``
        (after printing a warning) if ``model_path`` does not exist.
        Exceptions raised while loading propagate to the caller.
    """
    if not os.path.exists(model_path):
        print(f"Warning: {tag} file not found at {model_path}. Please ensure it's uploaded to the root directory.")
        return None
    model = LongformerForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
    # A state_dict only needs tensors and plain containers; weights_only=True
    # restricts unpickling to those types instead of running arbitrary pickle code.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()  # evaluation mode: disables dropout for inference
    print(f"{tag} loaded successfully from {model_path}")
    return model


p_model_for_inference = None
e_model_for_inference = None
try:
    p_model_for_inference = _load_classifier(p_model_path, "P-model")
    e_model_for_inference = _load_classifier(e_model_path, "E-model")
except Exception as e:
    print(f"Error loading models: {e}")
    # On any loading failure clear both models so the Gradio UI is not launched.
    p_model_for_inference = None
    e_model_for_inference = None
# =======================================================
# 4. Gradio prediction function
# =======================================================
def _predict_probs(model, text):
    """Return ``model``'s softmax probabilities for ``text`` as a 1-D array.

    Index 0 is Control and index 1 is Depressed. Tokenization uses the same
    settings as DepressionDataset: special tokens, padding/truncation to
    MAX_LEN, and no token_type_ids.
    """
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=MAX_LEN,
        return_token_type_ids=False,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)
    with torch.no_grad():  # inference only: no gradient computation needed
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        # A single text yields one row of logits; flatten it to a 1-D array.
        return F.softmax(outputs.logits, dim=1).cpu().numpy().flatten()


def predict_depression(participant_text, ellie_text):
    """Classify an interview as Control or Depressed with a two-model OR ensemble.

    Args:
        participant_text: The participant's speech, scored by the P-model.
        ellie_text: Ellie's (the virtual agent's) speech. Only the latter half
            of its whitespace-separated words is scored by the E-model.

    Returns:
        A Markdown report with the ensemble verdict and each model's label and
        probabilities, or an error message if either model failed to load.
    """
    # Both models are required; the loader leaves them as None on failure.
    if p_model_for_inference is None or e_model_for_inference is None:
        return "**์ค๋ฅ:** ๋ชจ๋ธ์ด ๋ก๋๋์ง ์์์ต๋๋ค. ๊ด๋ฆฌ์์๊ฒ ๋ฌธ์ํ๊ฑฐ๋ ๋ชจ๋ธ ํ์ผ ์ ๋ก๋ ์ฌ๋ถ๋ฅผ ํ์ธํด์ฃผ์ธ์."

    # Treat a missing (None) input as empty text instead of crashing.
    participant_text = participant_text or ""

    # Ellie preprocessing: keep only the latter half of her words. The
    # original comment states this is the same logic used at training time.
    # An empty word list joins to "", so no special case is needed.
    e_text_words = (ellie_text or "").split()
    ellie_text_processed = " ".join(e_text_words[len(e_text_words) // 2:])

    p_probs = _predict_probs(p_model_for_inference, participant_text)
    p_pred_label = int(np.argmax(p_probs))
    e_probs = _predict_probs(e_model_for_inference, ellie_text_processed)
    e_pred_label = int(np.argmax(e_probs))

    # Ensemble (OR strategy): report Depressed (1) if either model predicts it.
    ensemble_pred_label = 1 if p_pred_label == 1 or e_pred_label == 1 else 0

    labels = ['Control (๋น์ฐ์ธ)', 'Depressed (์ฐ์ธ)']
    ensemble_result = labels[ensemble_pred_label]
    p_model_result = labels[p_pred_label]
    e_model_result = labels[e_pred_label]
    return (f"**์ต์ข ์์๋ธ ์์ธก (OR ์ ๋ต): {ensemble_result}**\n\n"
            f" - ์ฐธ๊ฐ์ ๋ชจ๋ธ (P-longBERT) ์์ธก: {p_model_result} (ํ๋ฅ : Control={p_probs[0]:.2f}, Depressed={p_probs[1]:.2f})\n"
            f" - ์๋ฆฌ ๋ชจ๋ธ (E-longBERT) ์์ธก: {e_model_result} (ํ๋ฅ : Control={e_probs[0]:.2f}, Depressed={e_probs[1]:.2f})\n\n"
            f"**์ฐธ๊ณ :**\n"
            f"- ์์ธก์ ๊ฐ ๋ํ ๋ด์ฉ์๋ง ๊ธฐ๋ฐํ๋ฉฐ, ์ค์ ์ง๋จ์ ์ ๋ฌธ๊ฐ์ ์๋ดํด์ผ ํฉ๋๋ค.\n"
            f"- GPU ํ๊ฒฝ์์๋ ์์ธก์ด ๋น ๋ฅด๊ฒ ์ํ๋ฉ๋๋ค."
            )
# =======================================================
# 5. Build and launch the Gradio UI
# =======================================================
print("\n--- Setting up Gradio UI ---")
# The UI starts only when both classifiers loaded; otherwise report why and stop.
if p_model_for_inference is None or e_model_for_inference is None:
    print("\nGradio UI could not be launched because models failed to load. Please check model files.")
else:
    gr.Interface(
        fn=predict_depression,
        # Two free-text inputs, mapped in order onto (participant_text, ellie_text).
        inputs=[
            gr.Textbox(lines=10, label="์ฐธ๊ฐ์ ๋ฐํ ๋ด์ฉ (Participant's speech)", placeholder="์ฌ๊ธฐ์ ์ฐธ๊ฐ์์ ๋ฐํ ๋ด์ฉ์ ์ ๋ ฅํ์ธ์..."),
            gr.Textbox(lines=10, label="์๋ฆฌ ๋ฐํ ๋ด์ฉ (Ellie's speech)", placeholder="์ฌ๊ธฐ์ ์๋ฆฌ(๊ฐ์ ์์ด์ ํธ)์ ๋ฐํ ๋ด์ฉ์ ์ ๋ ฅํ์ธ์... (์ ์ฒด ๋ด์ฉ ์ค ํ๋ฐ๋ถ๋ง ์ฌ์ฉ๋จ)"),
        ],
        # predict_depression returns a Markdown-formatted report.
        outputs="markdown",
        title="DAIC-WOZ ์ฐ์ธ์ฆ ๊ฐ์ง ์์๋ธ ๋ชจ๋ธ (GPU ๊ฐ์)",
        description=f"""์ด ์ฑ์ DAIC-WOZ ๋ฐ์ดํฐ์ ์ ๊ธฐ๋ฐ์ผ๋ก ์ฐธ๊ฐ์์ ๊ฐ์ ์์ด์ ํธ(์๋ฆฌ)์ ๋ํ ๋ด์ฉ์ ๋ถ์ํ์ฌ ์ฐ์ธ์ฆ ์ฌ๋ถ๋ฅผ ์์ธกํฉ๋๋ค.
P-longBERT (์ฐธ๊ฐ์ ๋ฐํ)์ E-longBERT (์๋ฆฌ ๋ฐํ) ๋ชจ๋ธ์ ์์๋ธ (OR ์ ๋ต) ๊ฒฐ๊ณผ๋ฅผ ์ ๊ณตํฉ๋๋ค.
**GPU ํ๊ฒฝ์์๋ ์์ธก์ด ๋น ๋ฅด๊ฒ ์ํ๋ฉ๋๋ค.**
**์ฐธ๊ณ :** ์ด๋ AI ๋ชจ๋ธ์ ์์ธก์ผ ๋ฟ์ด๋ฉฐ, **์ค์ ์ํ์ ์ง๋จ์ ๋ฐ๋์ ์ ๋ฌธ๊ฐ์ ์๋ดํด์ผ ํฉ๋๋ค.**
์ฌ์ฉ ์ค์ธ ๋๋ฐ์ด์ค: {device}
""",
    ).launch()  # share=True is not needed on Hugging Face Spaces