# ThinkFlow-llama / app.py
import re
import threading
import gc
import os
import torch
import time
import signal
import gradio as gr
import spaces
import transformers
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
# Settings for model memory management and optimization
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
MAX_GPU_MEMORY = 80 * 1024 * 1024 * 1024  # based on an 80GB A100
# Available models - changed to start with the smaller models
available_models = {
    "google/gemma-2b": "Google Gemma (2B)",  # the smaller model is set as the default
"mistralai/Mistral-7B-Instruct-v0.2": "Mistral 7B Instruct v0.2",
"mistralai/Mistral-Small-3.1-24B-Base-2503": "Mistral Small 3.1 (24B)",
"google/gemma-3-27b-it": "Google Gemma 3 (27B)",
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen 2.5 Coder (32B)",
"open-r1/OlympicCoder-32B": "Olympic Coder (32B)"
}
# Default model - set to the smallest model
DEFAULT_MODEL_KEY = list(available_models.keys())[0]
DEFAULT_MODEL_VALUE = available_models[DEFAULT_MODEL_KEY]
# Global variables used for model loading
pipe = None
current_model_name = None
loading_in_progress = False
# Attempt to log in with a Hugging Face token
try:
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
        print("Successfully logged in to Hugging Face.")
    else:
        print("Warning: the HF_TOKEN environment variable is not set.")
except Exception as e:
    print(f"Hugging Face login error: {str(e)}")
# Marker used to detect the final answer
ANSWER_MARKER = "**Answer**"
# Sentences that start each step of the reasoning
rethink_prepends = [
    "OK, now I need to figure out the following ",
    "I think ",
    "Wait a moment, I think ",
    "Let me check whether the following is correct ",
    "Something else to keep in mind is ",
    "Another point worth noting is ",
    "And I also remember the following fact ",
    "Now I think I understand things well enough ",
    "Based on the information so far, I will answer in the language used in the original question:"
    "\n{question}\n"
    f"\n{ANSWER_MARKER}\n",
]
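# Note: the last three string literals above are adjacent, so Python concatenates them into a
# single final prompt that re-inserts the original question followed by the ANSWER_MARKER.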
# Settings to work around math rendering issues
latex_delimiters = [
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
]
# Model-size-based configuration - optimal settings per model size tier
MODEL_CONFIG = {
"small": { # <10B
"max_memory": {0: "10GiB"},
"offload": False,
"quantization": None
},
"medium": { # 10B-30B
"max_memory": {0: "30GiB"},
"offload": False,
"quantization": None
},
"large": { # >30B
"max_memory": {0: "60GiB"},
"offload": True,
"quantization": None
}
}
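# Note: "quantization" is None for every tier above, so the BitsAndBytes branch in
# load_model() is effectively disabled unless these settings are changed.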
def get_model_size_category(model_name):
"""๋ชจ๋ธ ํฌ๊ธฐ ์นดํ…Œ๊ณ ๋ฆฌ ๊ฒฐ์ •"""
if "2B" in model_name or "3B" in model_name or "7B" in model_name or "8B" in model_name:
return "small"
elif "15B" in model_name or "24B" in model_name or "27B" in model_name:
return "medium"
elif "32B" in model_name or "70B" in model_name:
return "large"
else:
# ๊ธฐ๋ณธ๊ฐ’์œผ๋กœ small ๋ฐ˜ํ™˜ (์•ˆ์ „์„ ์œ„ํ•ด)
return "small"
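# Example: get_model_size_category("mistralai/Mistral-7B-Instruct-v0.2") -> "small"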
def clear_gpu_memory():
"""GPU ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ"""
global pipe
if pipe is not None:
del pipe
pipe = None
    # Clear the CUDA cache
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
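        # empty_cache() releases the blocks held by PyTorch's caching allocator back to the GPU;
        # synchronize() below waits for pending GPU work to finish before a new model is loaded.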
torch.cuda.synchronize()
def reformat_math(text):
"""Gradio ๊ตฌ๋ฌธ(Katex)์„ ์‚ฌ์šฉํ•˜๋„๋ก MathJax ๊ตฌ๋ถ„ ๊ธฐํ˜ธ ์ˆ˜์ •."""
text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
return text
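# Example: reformat_math(r"\[ x^2 \] and \( y \)") -> "$$x^2$$ and $y$"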
def user_input(message, history: list):
"""์‚ฌ์šฉ์ž ์ž…๋ ฅ์„ ํžˆ์Šคํ† ๋ฆฌ์— ์ถ”๊ฐ€ํ•˜๊ณ  ์ž…๋ ฅ ํ…์ŠคํŠธ ์ƒ์ž ๋น„์šฐ๊ธฐ"""
return "", history + [
gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
]
def rebuild_messages(history: list):
"""์ค‘๊ฐ„ ์ƒ๊ฐ ๊ณผ์ • ์—†์ด ๋ชจ๋ธ์ด ์‚ฌ์šฉํ•  ํžˆ์Šคํ† ๋ฆฌ์—์„œ ๋ฉ”์‹œ์ง€ ์žฌ๊ตฌ์„ฑ"""
messages = []
for h in history:
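        # The history may contain plain dicts (the Chatbot "messages" format) as well as
        # gr.ChatMessage objects, so both forms are handled below.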
if isinstance(h, dict) and not h.get("metadata", {}).get("title", False):
messages.append(h)
elif (
isinstance(h, gr.ChatMessage)
and h.metadata.get("title")
and isinstance(h.content, str)
):
messages.append({"role": h.role, "content": h.content})
return messages
def load_model(model_names, status_callback=None):
"""์„ ํƒ๋œ ๋ชจ๋ธ ์ด๋ฆ„์— ๋”ฐ๋ผ ๋ชจ๋ธ ๋กœ๋“œ (A100์— ์ตœ์ ํ™”๋œ ์„ค์ • ์‚ฌ์šฉ)"""
global pipe, current_model_name, loading_in_progress
    # If another load is already in progress
    if loading_in_progress:
        return "Another model is already being loaded. Please wait a moment."
loading_in_progress = True
try:
        # Clean up any previously loaded model
        clear_gpu_memory()
        # Fall back to the default if no model was selected
        if not model_names:
            model_name = DEFAULT_MODEL_KEY
        else:
            # Use the first selected model
            model_name = model_names[0]
        # Determine the model size category
        size_category = get_model_size_category(model_name)
        config = MODEL_CONFIG[size_category]
        # Update the loading status
        if status_callback:
            status_callback(f"Loading model '{model_name}'... (size: {size_category})")
        # Load the model (apply settings optimized for its size)
        # Check the HF_TOKEN environment variable
        hf_token = os.getenv("HF_TOKEN")
        # Common parameters
        common_params = {
            "token": hf_token,  # token for gated models
            "trust_remote_code": True,
        }
        # Check whether BitsAndBytes is available
        try:
            import bitsandbytes
            has_bitsandbytes = True
        except ImportError:
            has_bitsandbytes = False
            if status_callback:
                status_callback("BitsAndBytes library not found; loading without quantization.")
        # Load timeout (varies with model size)
        if size_category == "small":
            load_timeout = 180  # 3 minutes
        elif size_category == "medium":
            load_timeout = 300  # 5 minutes
        else:
            load_timeout = 600  # 10 minutes
        # Record the load start time
        start_time = time.time()
        # If quantization is configured and BitsAndBytes is available
        if config["quantization"] and has_bitsandbytes:
            # Apply quantization
from transformers import BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(
load_in_4bit=config["quantization"] == "4bit",
bnb_4bit_compute_dtype=DTYPE
)
if status_callback:
                status_callback(f"Loading model '{model_name}'... (quantization enabled)")
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
max_memory=config["max_memory"],
torch_dtype=DTYPE,
quantization_config=quantization_config,
offload_folder="offload" if config["offload"] else None,
**common_params
)
tokenizer = AutoTokenizer.from_pretrained(model_name, **common_params)
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
torch_dtype=DTYPE,
device_map="auto"
)
else:
            # Load without quantization
            if status_callback:
                status_callback(f"Loading model '{model_name}'... (standard path)")
pipe = pipeline(
"text-generation",
model=model_name,
device_map="auto",
torch_dtype=DTYPE,
**common_params
)
        # Check whether the load exceeded the time limit
        elapsed_time = time.time() - start_time
        if elapsed_time > load_timeout:
            clear_gpu_memory()
            loading_in_progress = False
            return f"Model load timed out: {load_timeout} seconds have passed. Please try again."
        current_model_name = model_name
        loading_in_progress = False
        return f"Model '{model_name}' was loaded successfully. (optimization: {size_category}, elapsed: {elapsed_time:.1f}s)"
    except Exception as e:
        loading_in_progress = False
        return f"Failed to load the model: {str(e)}"
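# Note: the elapsed-time check in load_model() runs only after the blocking load call has
# returned, so it reports a load that took too long rather than interrupting one in progress.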
@spaces.GPU
def bot(
history: list,
max_num_tokens: int,
final_num_tokens: int,
do_sample: bool,
temperature: float,
):
"""๋ชจ๋ธ์ด ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•˜๋„๋ก ํ•˜๊ธฐ"""
global pipe, current_model_name
# ๋ชจ๋ธ์ด ๋กœ๋“œ๋˜์ง€ ์•Š์•˜๋‹ค๋ฉด ์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€ ํ‘œ์‹œ
if pipe is None:
history.append(
gr.ChatMessage(
role="assistant",
content="๋ชจ๋ธ์ด ๋กœ๋“œ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ํ•˜๋‚˜ ์ด์ƒ์˜ ๋ชจ๋ธ์„ ์„ ํƒํ•˜๊ณ  '๋ชจ๋ธ ๋กœ๋“œ' ๋ฒ„ํŠผ์„ ํด๋ฆญํ•ด ์ฃผ์„ธ์š”.",
)
)
yield history
return
    try:
        # Automatically adjust the token budget (based on model size)
        size_category = get_model_size_category(current_model_name)
        # Large models get a smaller token budget to improve memory efficiency
        if size_category == "large":
            max_num_tokens = min(max_num_tokens, 1000)
            final_num_tokens = min(final_num_tokens, 1500)
        # Streamer used to pull tokens from the generation thread later on
        streamer = transformers.TextIteratorStreamer(
            pipe.tokenizer,
            skip_special_tokens=True,
            skip_prompt=True,
        )
        # Keep the question so it can be re-inserted into the reasoning if needed
        question = history[-1]["content"]
        # Prepare the assistant message
        history.append(
            gr.ChatMessage(
                role="assistant",
                content=str(""),
                metadata={"title": "🧠 Thinking...", "status": "pending"},
            )
        )
        # The reasoning steps that will be shown in the current chat
messages = rebuild_messages(history)
        # Timeout setup
        class TimeoutError(Exception):
            pass
        def timeout_handler(signum, frame):
            raise TimeoutError("Request processing timed out.")
        # Allow up to 120 seconds per reasoning step
        timeout_seconds = 120
for i, prepend in enumerate(rethink_prepends):
if i > 0:
messages[-1]["content"] += "\n\n"
messages[-1]["content"] += prepend.format(question=question)
num_tokens = int(
max_num_tokens if ANSWER_MARKER not in prepend else final_num_tokens
)
            # Run the model in a background thread
            t = threading.Thread(
                target=pipe,
                args=(messages,),
                kwargs=dict(
                    max_new_tokens=num_tokens,
                    streamer=streamer,
                    do_sample=do_sample,
                    temperature=temperature,
                    # Extra parameters for memory efficiency
                    repetition_penalty=1.2,  # discourage repetition
                    use_cache=True,  # use the KV cache
                ),
            )
            t.daemon = True  # daemon thread so it terminates when the main thread exits
            t.start()
            # Rebuild the visible history with the new content
            history[-1].content += prepend.format(question=question)
            if ANSWER_MARKER in prepend:
                history[-1].metadata = {"title": "💭 Thought process", "status": "done"}
                # Thinking is done; what follows is the answer (no metadata for the intermediate step)
                history.append(gr.ChatMessage(role="assistant", content=""))
            # Timeout setup (SIGALRM is only available on Unix systems)
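            # Note: signal handlers can only be installed from the main thread; if this generator
            # runs in a Gradio worker thread, signal.signal() below raises ValueError, which is
            # handled by the outer "except Exception" block instead of this step's timeout path.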
try:
if hasattr(signal, 'SIGALRM'):
signal.signal(signal.SIGALRM, timeout_handler)
signal.alarm(timeout_seconds)
                # Stream tokens
                token_count = 0
                for token in streamer:
                    history[-1].content += token
                    history[-1].content = reformat_math(history[-1].content)
                    token_count += 1
                    # Yield every 10 tokens to keep the UI responsive
                    if token_count % 10 == 0:
                        yield history
                # Yield whatever is left
                yield history
                # Cancel the timeout
                if hasattr(signal, 'SIGALRM'):
                    signal.alarm(0)
            except TimeoutError:
                if hasattr(signal, 'SIGALRM'):
                    signal.alarm(0)
                history[-1].content += "\n\n⚠️ Response generation timed out. Moving on to the next step."
                yield history
                continue
            # Wait at most 30 seconds, then move on to the next step
            join_start_time = time.time()
            while t.is_alive() and (time.time() - join_start_time) < 30:
                t.join(1)  # check once per second
            # If the thread is still running, move on anyway
            if t.is_alive():
                history[-1].content += "\n\n⚠️ Response generation is taking longer than expected. Moving on to the next step."
                yield history
            # For large models, partially free memory after each step
if size_category == "large" and torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception as e:
        # Notify the user when an error occurs
        import traceback
        error_msg = f"\n\n⚠️ An error occurred during processing: {str(e)}\n{traceback.format_exc()}"
if len(history) > 0 and isinstance(history[-1], gr.ChatMessage) and history[-1].role == "assistant":
history[-1].content += error_msg
else:
history.append(gr.ChatMessage(role="assistant", content=error_msg))
yield history
yield history
# Helper to report the available GPU(s)
def get_gpu_info():
    if not torch.cuda.is_available():
        return "No GPU is available."
gpu_info = []
for i in range(torch.cuda.device_count()):
gpu_name = torch.cuda.get_device_name(i)
total_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
gpu_info.append(f"GPU {i}: {gpu_name} ({total_memory:.1f} GB)")
return "\n".join(gpu_info)
# Automatic model load function (with status updates)
def auto_load_model():
    # Auto-load the first model
    model_key = DEFAULT_MODEL_KEY
    try:
        # Return a status message so the UI shows progress
        return "Auto-loading the small model... please wait a moment."
    except Exception as e:
        return f"Automatic model load failed: {str(e)}"
# Actual model load function (asynchronous)
def load_model_async(model_status):
    # Load the model asynchronously (the actual load runs in the background)
    model_key = DEFAULT_MODEL_KEY
    def update_status(status):
        model_status.update(value=status)
    # Load in a separate thread
    def load_in_thread():
        try:
            result = load_model([model_key], update_status)
            model_status.update(value=result)
        except Exception as e:
            model_status.update(value=f"Model load failed: {str(e)}")
    threading.Thread(target=load_in_thread, daemon=True).start()
    return "Preparing to load the model... it will proceed automatically."
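# Note: Gradio passes component *values* into event handlers, so `model_status` above is the
# current string in the status textbox rather than the component itself; the .update() calls made
# from the background thread therefore cannot refresh the UI and serve only as a best-effort hook.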
# Gradio interface
with gr.Blocks(fill_height=True, title="ThinkFlow - Step-by-step Reasoning Service") as demo:
    # Title and description at the top
gr.Markdown("""
# ThinkFlow
## A thought amplification service that implants step-by-step reasoning abilities into LLMs without model modification
""")
with gr.Row(scale=1):
with gr.Column(scale=5):
            # Chat interface
chatbot = gr.Chatbot(
scale=1,
type="messages",
latex_delimiters=latex_delimiters,
height=600,
)
msg = gr.Textbox(
submit_btn=True,
label="",
show_label=False,
                placeholder="Type your question here.",
autofocus=True,
)
with gr.Column(scale=1):
            # Hardware info
            gpu_info = gr.Markdown(f"**Available hardware:**\n{get_gpu_info()}")
            # Model selection section
            gr.Markdown("""## Model Selection""")
            model_selector = gr.Radio(
                choices=list(available_models.values()),
                value=DEFAULT_MODEL_VALUE,
                label="Select the LLM model to use",
            )
            # Model load button
            load_model_btn = gr.Button("Load Model", variant="primary")
            model_status = gr.Textbox(label="Model Status", interactive=False)
            # GPU memory cleanup button
            clear_memory_btn = gr.Button("Clear GPU Memory", variant="secondary")
            gr.Markdown("""## Parameter Settings""")
            with gr.Accordion("Advanced Settings", open=False):
                num_tokens = gr.Slider(
                    50,
                    2000,
                    1000,
                    step=50,
                    label="Max tokens per reasoning step",
                    interactive=True,
                )
                final_num_tokens = gr.Slider(
                    50,
                    3000,
                    1500,
                    step=50,
                    label="Max tokens for the final answer",
                    interactive=True,
                )
                do_sample = gr.Checkbox(True, label="Use sampling")
                temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature")
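                # Note: bot() further caps these values at 1000 (per reasoning step) and
                # 1500 (final answer) tokens for models in the "large" size tier.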
    # Initialize automatically on startup
    demo.load(auto_load_model, [], [model_status])
    # Load the model asynchronously after startup (avoids delaying the initial page render)
    demo.load(lambda x: load_model_async(x), [model_status], [], _js="() => {}")
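    # Note: `_js` is the Gradio 3.x keyword for the client-side JS hook; Gradio 4 renamed it to
    # `js`, so this argument may need updating depending on the installed Gradio version.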
    # Wire up the load event for the selected model
    def get_model_names(selected_model):
        # Map the display name back to the original model id
        inverse_map = {v: k for k, v in available_models.items()}
        return [inverse_map[selected_model]] if selected_model else []
load_model_btn.click(
lambda selected: load_model(get_model_names(selected)),
inputs=[model_selector],
outputs=[model_status]
)
    # Wire up the GPU memory cleanup event
    clear_memory_btn.click(
        lambda: (clear_gpu_memory(), "GPU memory has been cleared.")[1],  # run cleanup, return only the status string
        inputs=[],
        outputs=[model_status]
    )
    # When the user submits a message, the bot responds
    msg.submit(
        user_input,
        [msg, chatbot],  # inputs
        [msg, chatbot],  # outputs
    ).then(
        bot,
        [
            chatbot,
            num_tokens,
            final_num_tokens,
            do_sample,
            temperature,
        ],  # in practice this is the "history" input
        chatbot,  # the new history is stored via the output
    )
if __name__ == "__main__":
    # Print debugging info
    print(f"GPU available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"Number of available GPUs: {torch.cuda.device_count()}")
        print(f"Current GPU: {torch.cuda.current_device()}")
        print(f"GPU name: {torch.cuda.get_device_name(0)}")
    # Check the HF_TOKEN environment variable
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("The HF_TOKEN environment variable is set.")
    else:
        print("Warning: the HF_TOKEN environment variable is not set. Gated models cannot be accessed.")
    # Enable queuing and launch the app
    demo.queue(max_size=10).launch()