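# ThinkFlow: a Hugging Face Space that adds step-by-step reasoning to off-the-shelf
# LLMs without modifying the models themselves. It loads a text-generation pipeline,
# walks the conversation through a fixed sequence of "thinking" prompts, and streams
# both the reasoning trace and the final answer into a Gradio chat UI.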
import re
import threading
import gc
import os
import torch
import time
import signal
import gradio as gr
import spaces
import transformers
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
# Settings for model memory management and optimization
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
MAX_GPU_MEMORY = 80 * 1024 * 1024 * 1024  # based on an 80GB A100

# List of available models - ordered so that smaller models come first
available_models = {
    "google/gemma-2b": "Google Gemma (2B)",  # smaller model used as the default
    "mistralai/Mistral-7B-Instruct-v0.2": "Mistral 7B Instruct v0.2",
    "mistralai/Mistral-Small-3.1-24B-Base-2503": "Mistral Small 3.1 (24B)",
    "google/gemma-3-27b-it": "Google Gemma 3 (27B)",
    "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen 2.5 Coder (32B)",
    "open-r1/OlympicCoder-32B": "Olympic Coder (32B)"
}

# Default model - set to the smallest one
DEFAULT_MODEL_KEY = list(available_models.keys())[0]
DEFAULT_MODEL_VALUE = available_models[DEFAULT_MODEL_KEY]

# Global variables used for model loading
pipe = None
current_model_name = None
loading_in_progress = False

# Try to log in with the Hugging Face token
try:
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
        print("Successfully logged in to Hugging Face.")
    else:
        print("Warning: the HF_TOKEN environment variable is not set.")
except Exception as e:
    print(f"Hugging Face login error: {str(e)}")
# Marker used to detect the final answer
ANSWER_MARKER = "**Answer**"

# Sentences that open each step of the reasoning chain
rethink_prepends = [
    "OK, now I need to figure out the following ",
    "In my opinion ",
    "Wait a moment, I think ",
    "Let me check whether the following is correct ",
    "Another thing to keep in mind is ",
    "One more point worth noting is ",
    "And I also remember the following fact ",
    "Now I think I understand it well enough ",
    "Based on the information so far, I will answer the original question in the language it was asked in:"
    "\n{question}\n"
    f"\n{ANSWER_MARKER}\n",
]
# Settings to fix math rendering issues
latex_delimiters = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
]

# Size-based model configuration - optimal settings for each size bucket
MODEL_CONFIG = {
    "small": {  # <10B
        "max_memory": {0: "10GiB"},
        "offload": False,
        "quantization": None
    },
    "medium": {  # 10B-30B
        "max_memory": {0: "30GiB"},
        "offload": False,
        "quantization": None
    },
    "large": {  # >30B
        "max_memory": {0: "60GiB"},
        "offload": True,
        "quantization": None
    }
}
def get_model_size_category(model_name):
    """Determine the model's size category."""
    if "2B" in model_name or "3B" in model_name or "7B" in model_name or "8B" in model_name:
        return "small"
    elif "15B" in model_name or "24B" in model_name or "27B" in model_name:
        return "medium"
    elif "32B" in model_name or "70B" in model_name:
        return "large"
    else:
        # Default to small (for safety)
        return "small"


def clear_gpu_memory():
    """Free the GPU memory held by the current pipeline."""
    global pipe
    if pipe is not None:
        del pipe
        pipe = None
    # Clear the CUDA cache
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
def reformat_math(text):
    """Rewrite MathJax delimiters to the KaTeX syntax Gradio expects."""
    text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
    text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
    return text


def user_input(message, history: list):
    """Append the user's input to the history and clear the input textbox."""
    return "", history + [
        gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
    ]
def rebuild_messages(history: list):
    """Rebuild the messages the model will see, dropping the intermediate thinking steps."""
    messages = []
    for h in history:
        if isinstance(h, dict) and not h.get("metadata", {}).get("title", False):
            messages.append(h)
        elif (
            isinstance(h, gr.ChatMessage)
            and not h.metadata.get("title")
            and isinstance(h.content, str)
        ):
            messages.append({"role": h.role, "content": h.content})
    return messages
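
# Model loading. The size bucket from MODEL_CONFIG controls the memory cap and
# offloading; note that every bucket currently sets "quantization" to None, so the
# BitsAndBytes branch below stays inactive unless that config is changed.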
def load_model(model_names, status_callback=None):
    """Load the selected model (with settings tuned for an A100)."""
    global pipe, current_model_name, loading_in_progress

    # Bail out if another load is already in progress
    if loading_in_progress:
        return "Another model is already being loaded. Please wait."

    loading_in_progress = True

    try:
        # Clean up the previously loaded model
        clear_gpu_memory()

        # Fall back to the default model if nothing was selected
        if not model_names:
            model_name = DEFAULT_MODEL_KEY
        else:
            # Use the first selected model
            model_name = model_names[0]

        # Determine the model's size category
        size_category = get_model_size_category(model_name)
        config = MODEL_CONFIG[size_category]

        # Update the loading status
        if status_callback:
            status_callback(f"Loading model '{model_name}'... (size: {size_category})")

        # Load the model with settings optimized for its size
        # Check the HF_TOKEN environment variable
        hf_token = os.getenv("HF_TOKEN")

        # Parameters shared by all loading paths
        common_params = {
            "token": hf_token,  # token for gated models
            "trust_remote_code": True,
        }

        # Check whether BitsAndBytes is available
        try:
            import bitsandbytes
            has_bitsandbytes = True
        except ImportError:
            has_bitsandbytes = False
            if status_callback:
                status_callback("BitsAndBytes is not installed. Loading without quantization.")

        # Time limit for loading, scaled with model size
        if size_category == "small":
            load_timeout = 180  # 3 minutes
        elif size_category == "medium":
            load_timeout = 300  # 5 minutes
        else:
            load_timeout = 600  # 10 minutes

        # Record when loading started
        start_time = time.time()
        # If quantization is configured and BitsAndBytes is available
        if config["quantization"] and has_bitsandbytes:
            # Apply quantization
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=config["quantization"] == "4bit",
                bnb_4bit_compute_dtype=DTYPE
            )

            if status_callback:
                status_callback(f"Loading model '{model_name}'... (with quantization)")

            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                max_memory=config["max_memory"],
                torch_dtype=DTYPE,
                quantization_config=quantization_config,
                offload_folder="offload" if config["offload"] else None,
                **common_params
            )
            tokenizer = AutoTokenizer.from_pretrained(model_name, **common_params)

            pipe = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                torch_dtype=DTYPE,
                device_map="auto"
            )
        else:
            # Load without quantization
            if status_callback:
                status_callback(f"Loading model '{model_name}'... (standard path)")

            pipe = pipeline(
                "text-generation",
                model=model_name,
                device_map="auto",
                torch_dtype=DTYPE,
                **common_params
            )

        # Check whether the time limit was exceeded
        elapsed_time = time.time() - start_time
        if elapsed_time > load_timeout:
            clear_gpu_memory()
            loading_in_progress = False
            return f"Model loading timed out after {load_timeout} seconds. Please try again."

        current_model_name = model_name
        loading_in_progress = False
        return f"Model '{model_name}' loaded successfully. (optimization: {size_category}, elapsed: {elapsed_time:.1f}s)"

    except Exception as e:
        loading_in_progress = False
        return f"Failed to load model: {str(e)}"
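
# The bot() generator below implements the staged reasoning loop: for each entry in
# rethink_prepends it extends the prompt, runs the pipeline in a background thread
# with a TextIteratorStreamer, and streams the generated tokens into the chat history.
# The final entry re-inserts the question together with ANSWER_MARKER, so everything
# before the marker is shown as the "thinking" block and everything after it as the answer.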
def bot(
    history: list,
    max_num_tokens: int,
    final_num_tokens: int,
    do_sample: bool,
    temperature: float,
):
    """Have the model answer the question step by step."""
    global pipe, current_model_name

    # If no model is loaded, show an error message
    if pipe is None:
        history.append(
            gr.ChatMessage(
                role="assistant",
                content="No model is loaded. Please select a model and click the 'Load model' button.",
            )
        )
        yield history
        return

    try:
        # Adjust the token budgets automatically based on model size
        size_category = get_model_size_category(current_model_name)

        # For large models, reduce the token counts to save memory
        if size_category == "large":
            max_num_tokens = min(max_num_tokens, 1000)
            final_num_tokens = min(final_num_tokens, 1500)

        # Streamer used to pull tokens out of the generation thread
        streamer = transformers.TextIteratorStreamer(
            pipe.tokenizer,
            skip_special_tokens=True,
            skip_prompt=True,
        )
        # Keep the question so it can be re-inserted during the reasoning steps
        question = history[-1]["content"]

        # Prepare the assistant message
        history.append(
            gr.ChatMessage(
                role="assistant",
                content="",
                metadata={"title": "🧠 Thinking...", "status": "pending"},
            )
        )

        # Reasoning trace shown in the current chat
        messages = rebuild_messages(history)

        # Timeout handling
        class TimeoutError(Exception):
            pass

        def timeout_handler(signum, frame):
            raise TimeoutError("Request processing timed out.")

        # Allow at most 120 seconds per reasoning step
        timeout_seconds = 120

        for i, prepend in enumerate(rethink_prepends):
            if i > 0:
                messages[-1]["content"] += "\n\n"
            messages[-1]["content"] += prepend.format(question=question)

            num_tokens = int(
                max_num_tokens if ANSWER_MARKER not in prepend else final_num_tokens
            )

            # Run the model in a background thread
            t = threading.Thread(
                target=pipe,
                args=(messages,),
                kwargs=dict(
                    max_new_tokens=num_tokens,
                    streamer=streamer,
                    do_sample=do_sample,
                    temperature=temperature,
                    # Extra parameters for memory efficiency
                    repetition_penalty=1.2,  # discourage repetition
                    use_cache=True,  # use the KV cache
                ),
            )
            t.daemon = True  # daemon thread, so it exits with the main thread
            t.start()
            # Rebuild the history with the new content
            history[-1].content += prepend.format(question=question)
            if ANSWER_MARKER in prepend:
                history[-1].metadata = {"title": "💭 Reasoning", "status": "done"}
                # Thinking is over; the answer starts here (no metadata for intermediate steps)
                history.append(gr.ChatMessage(role="assistant", content=""))

            # Set up the timeout (only works on Unix systems)
            try:
                if hasattr(signal, 'SIGALRM'):
                    signal.signal(signal.SIGALRM, timeout_handler)
                    signal.alarm(timeout_seconds)

                # Stream tokens into the chat
                token_count = 0
                for token in streamer:
                    history[-1].content += token
                    history[-1].content = reformat_math(history[-1].content)
                    token_count += 1

                    # Yield every 10 tokens to keep the UI responsive
                    if token_count % 10 == 0:
                        yield history

                # Yield whatever is left
                yield history

                # Cancel the timeout
                if hasattr(signal, 'SIGALRM'):
                    signal.alarm(0)

            except TimeoutError:
                if hasattr(signal, 'SIGALRM'):
                    signal.alarm(0)
                history[-1].content += "\n\n⚠️ Response generation timed out. Moving on to the next step."
                yield history
                continue

            # Wait up to 30 seconds before moving on to the next step
            join_start_time = time.time()
            while t.is_alive() and (time.time() - join_start_time) < 30:
                t.join(1)  # check once per second

            # If the thread is still running, move on anyway
            if t.is_alive():
                history[-1].content += "\n\n⚠️ Response generation is taking longer than expected. Moving on to the next step."
                yield history

            # For large models, free some memory after each step
            if size_category == "large" and torch.cuda.is_available():
                torch.cuda.empty_cache()

    except Exception as e:
        # Notify the user when an error occurs
        import traceback
        error_msg = f"\n\n⚠️ An error occurred during processing: {str(e)}\n{traceback.format_exc()}"

        if len(history) > 0 and isinstance(history[-1], gr.ChatMessage) and history[-1].role == "assistant":
            history[-1].content += error_msg
        else:
            history.append(gr.ChatMessage(role="assistant", content=error_msg))
        yield history

    yield history
# Report the available GPUs
def get_gpu_info():
    if not torch.cuda.is_available():
        return "No GPU available."

    gpu_info = []
    for i in range(torch.cuda.device_count()):
        gpu_name = torch.cuda.get_device_name(i)
        total_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
        gpu_info.append(f"GPU {i}: {gpu_name} ({total_memory:.1f} GB)")

    return "\n".join(gpu_info)


# Automatic model load on startup (returns a status message immediately)
def auto_load_model():
    # Auto-load the first (smallest) model
    model_key = DEFAULT_MODEL_KEY
    try:
        # Return an interim message so the UI shows progress right away
        return "Auto-loading the small model... please wait."
    except Exception as e:
        return f"Automatic model load failed: {str(e)}"


# Actual model load function (asynchronous)
def load_model_async(model_status):
    # Kick off the real load in the background
    model_key = DEFAULT_MODEL_KEY

    def update_status(status):
        model_status.update(value=status)

    # Load in a separate thread
    def load_in_thread():
        try:
            result = load_model([model_key], update_status)
            model_status.update(value=result)
        except Exception as e:
            model_status.update(value=f"Failed to load model: {str(e)}")

    threading.Thread(target=load_in_thread, daemon=True).start()
    return "Preparing to load the model... it will continue automatically."
# Gradio interface
with gr.Blocks(fill_height=True, title="ThinkFlow - Step-by-step Reasoning Service") as demo:
    # Title and description at the top
    gr.Markdown("""
    # ThinkFlow
    ## A thought amplification service that implants step-by-step reasoning abilities into LLMs without model modification
    """)

    with gr.Row(scale=1):
        with gr.Column(scale=5):
            # Chat interface
            chatbot = gr.Chatbot(
                scale=1,
                type="messages",
                latex_delimiters=latex_delimiters,
                height=600,
            )
            msg = gr.Textbox(
                submit_btn=True,
                label="",
                show_label=False,
                placeholder="Type your question here.",
                autofocus=True,
            )
        with gr.Column(scale=1):
            # Hardware info
            gpu_info = gr.Markdown(f"**Available hardware:**\n{get_gpu_info()}")

            # Model selection section
            gr.Markdown("""## Model selection""")
            model_selector = gr.Radio(
                choices=list(available_models.values()),
                value=DEFAULT_MODEL_VALUE,
                label="Select the LLM model to use",
            )

            # Model load button
            load_model_btn = gr.Button("Load model", variant="primary")
            model_status = gr.Textbox(label="Model status", interactive=False)

            # Memory cleanup button
            clear_memory_btn = gr.Button("Clear GPU memory", variant="secondary")

            gr.Markdown("""## Parameter tuning""")
            with gr.Accordion("Advanced settings", open=False):
                num_tokens = gr.Slider(
                    50,
                    2000,
                    1000,
                    step=50,
                    label="Max tokens per reasoning step",
                    interactive=True,
                )
                final_num_tokens = gr.Slider(
                    50,
                    3000,
                    1500,
                    step=50,
                    label="Max tokens for the final answer",
                    interactive=True,
                )
                do_sample = gr.Checkbox(True, label="Use sampling")
                temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature")
    # Initialize automatically at startup
    demo.load(auto_load_model, [], [model_status])

    # Then load the model fully asynchronously so the first render is not delayed
    demo.load(lambda x: load_model_async(x), [model_status], [], _js="() => {}")

    # Wire up loading of the selected model
    def get_model_names(selected_model):
        # Map the display name back to the original model id
        inverse_map = {v: k for k, v in available_models.items()}
        return [inverse_map[selected_model]] if selected_model else []

    load_model_btn.click(
        lambda selected: load_model(get_model_names(selected)),
        inputs=[model_selector],
        outputs=[model_status]
    )

    # Wire up GPU memory cleanup: clear the memory, then return only the status string
    clear_memory_btn.click(
        lambda: (clear_gpu_memory(), "GPU memory has been cleared.")[1],
        inputs=[],
        outputs=[model_status]
    )
    # When the user submits a message, the bot responds
    msg.submit(
        user_input,
        [msg, chatbot],  # inputs
        [msg, chatbot],  # outputs
    ).then(
        bot,
        [
            chatbot,
            num_tokens,
            final_num_tokens,
            do_sample,
            temperature,
        ],  # in effect the "history" input plus generation settings
        chatbot,  # the new history is written back to this output
    )
if __name__ == "__main__":
    # Print debugging info
    print(f"GPU available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"Number of GPUs: {torch.cuda.device_count()}")
        print(f"Current GPU: {torch.cuda.current_device()}")
        print(f"GPU name: {torch.cuda.get_device_name(0)}")

    # Check the HF_TOKEN environment variable
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("The HF_TOKEN environment variable is set.")
    else:
        print("Warning: the HF_TOKEN environment variable is not set. Gated models will not be accessible.")

    # Enable the request queue and launch the app
    demo.queue(max_size=10).launch()