# ThinkFlow-llama / app.py
import re
import threading
import gc
import os
import torch
import gradio as gr
import spaces
import transformers
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
# ๋ชจ๋ธ ๋ฉ”๋ชจ๋ฆฌ ๊ด€๋ฆฌ ๋ฐ ์ตœ์ ํ™”๋ฅผ ์œ„ํ•œ ์„ค์ •
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
MAX_GPU_MEMORY = 80 * 1024 * 1024 * 1024 # 80GB A100 ๊ธฐ์ค€ (์‹ค์ œ ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ๋ฉ”๋ชจ๋ฆฌ๋Š” ์ด๋ณด๋‹ค ์ ์Œ)
# ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ๋ชจ๋ธ ๋ชฉ๋ก - A100์—์„œ ํšจ์œจ์ ์œผ๋กœ ์‹คํ–‰ ๊ฐ€๋Šฅํ•œ ๋ชจ๋ธ๋กœ ํ•„ํ„ฐ๋ง
available_models = {
"mistralai/Mistral-Small-3.1-24B-Base-2503": "Mistral Small 3.1 (24B)",
"bartowski/mistralai_Mistral-Small-3.1-24B-Instruct-2503-GGUF": "Mistral Small 3.1 GGUF (24B)",
"google/gemma-3-27b-it": "Google Gemma 3 (27B)",
"Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen 2.5 Coder (32B)",
"open-r1/OlympicCoder-32B": "Olympic Coder (32B)"
}
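# Note: the GGUF entry may need extra handling to load with transformers (for example a gguf_file argument).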
# Default model: the first entry in available_models
DEFAULT_MODEL_KEY = list(available_models.keys())[0]
DEFAULT_MODEL_VALUE = available_models[DEFAULT_MODEL_KEY]
# ๋ชจ๋ธ ๋กœ๋“œ์— ์‚ฌ์šฉ๋˜๋Š” ์ „์—ญ ๋ณ€์ˆ˜
pipe = None
current_model_name = None
# Try to log in with the Hugging Face token
try:
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
        print("Successfully logged in to Hugging Face.")
    else:
        print("Warning: the HF_TOKEN environment variable is not set.")
except Exception as e:
    print(f"Hugging Face login error: {str(e)}")
# Marker used to detect the final answer
ANSWER_MARKER = "**Answer**"
# Sentences that start each step of the reasoning
rethink_prepends = [
    "OK, now I need to figure out the following: ",
    "I think that ",
    "Wait a moment, I think ",
    "Let me check whether the following is correct: ",
    "Something else to keep in mind is ",
    "Another point worth noting is ",
    "And I also remember the following fact: ",
    "Now I think I understand the problem well enough. ",
    "Based on everything so far, I will answer the original question in the language it was asked in:"
    "\n{question}\n"
    f"\n{ANSWER_MARKER}\n",
]
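# The last entry above re-inserts the original question and appends ANSWER_MARKER, which bot() uses to detect the final-answer step.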
# Settings that fix LaTeX rendering in the chat
latex_delimiters = [
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
]
# ๋ชจ๋ธ ํฌ๊ธฐ ๊ธฐ๋ฐ˜ ๊ตฌ์„ฑ - ๋ชจ๋ธ ํฌ๊ธฐ์— ๋”ฐ๋ฅธ ์ตœ์  ์„ค์ • ์ •์˜
MODEL_CONFIG = {
"small": { # <10B
"max_memory": {0: "20GiB"},
"offload": False,
"quantization": None
    },
    "medium": {  # 10B-30B
        "max_memory": {0: "40GiB"},
        "offload": False,
        "quantization": None  # quantization disabled due to BitsAndBytes issues
    },
    "large": {  # >30B
        "max_memory": {0: "70GiB"},
        "offload": True,
        "quantization": None  # quantization disabled due to BitsAndBytes issues
}
}
def get_model_size_category(model_name):
    """Determine the model's size category from the parameter-count substring in its id (e.g. 24B, 32B)."""
if "3B" in model_name or "8B" in model_name:
return "small"
elif "24B" in model_name or "27B" in model_name:
return "medium"
elif "32B" in model_name or "70B" in model_name:
return "large"
else:
        # Default to medium
return "medium"
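# Unloading the previous pipeline before loading a new model avoids holding two models in GPU memory at once.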
def clear_gpu_memory():
    """Free the GPU memory held by the current pipeline."""
global pipe
if pipe is not None:
del pipe
pipe = None
    # Clear the CUDA cache
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
def reformat_math(text):
    """Rewrite MathJax-style delimiters so Gradio's KaTeX renderer (configured via latex_delimiters) can display the math."""
text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
return text
def user_input(message, history: list):
    """Append the user's message to the history and clear the input box."""
return "", history + [
gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
]
def rebuild_messages(history: list):
    """Rebuild the message list for the model from the history, without the intermediate thinking steps."""
messages = []
for h in history:
if isinstance(h, dict) and not h.get("metadata", {}).get("title", False):
messages.append(h)
elif (
isinstance(h, gr.ChatMessage)
and h.metadata.get("title")
and isinstance(h.content, str)
):
messages.append({"role": h.role, "content": h.content})
return messages
def load_model(model_names):
    """Load the selected model with settings tuned for an A100 GPU."""
    global pipe, current_model_name
    # Clean up any previously loaded model
    clear_gpu_memory()
    # Fall back to the default model if nothing was selected
    if not model_names:
        model_name = DEFAULT_MODEL_KEY  # use the first available model as the default
    else:
        # Use the first selected model
        model_name = model_names[0]
    # Determine the model's size category
    size_category = get_model_size_category(model_name)
    config = MODEL_CONFIG[size_category]
    # Load the model with settings optimized for its size
    try:
        # Check for the HF_TOKEN environment variable
        hf_token = os.getenv("HF_TOKEN")
        # Parameters shared by all loading paths
        common_params = {
            "token": hf_token,  # token for gated models
"trust_remote_code": True,
}
        # Check whether BitsAndBytes is available
        try:
            import bitsandbytes
            has_bitsandbytes = True
            print("BitsAndBytes library loaded successfully")
        except ImportError:
            has_bitsandbytes = False
            print("BitsAndBytes library not found. Loading the model without quantization.")
        # Quantization is configured and BitsAndBytes is available
        if config["quantization"] and has_bitsandbytes:
            # Apply quantization
from transformers import BitsAndBytesConfig
quantization_config = BitsAndBytesConfig(
load_in_4bit=config["quantization"] == "4bit",
bnb_4bit_compute_dtype=DTYPE
)
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
max_memory=config["max_memory"],
torch_dtype=DTYPE,
quantization_config=quantization_config,
offload_folder="offload" if config["offload"] else None,
**common_params
)
tokenizer = AutoTokenizer.from_pretrained(model_name, **common_params)
pipe = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
torch_dtype=DTYPE,
device_map="auto"
)
else:
            # Load without quantization
pipe = pipeline(
"text-generation",
model=model_name,
device_map="auto",
torch_dtype=DTYPE,
**common_params
)
current_model_name = model_name
        return f"Model '{model_name}' loaded successfully. (optimization category: {size_category})"
    except Exception as e:
        return f"Failed to load model: {str(e)}"
@spaces.GPU
def bot(
history: list,
max_num_tokens: int,
final_num_tokens: int,
do_sample: bool,
temperature: float,
):
    """Have the model answer the latest question, streaming its step-by-step reasoning."""
global pipe, current_model_name
# ๋ชจ๋ธ์ด ๋กœ๋“œ๋˜์ง€ ์•Š์•˜๋‹ค๋ฉด ์˜ค๋ฅ˜ ๋ฉ”์‹œ์ง€ ํ‘œ์‹œ
if pipe is None:
history.append(
gr.ChatMessage(
role="assistant",
content="๋ชจ๋ธ์ด ๋กœ๋“œ๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. ํ•˜๋‚˜ ์ด์ƒ์˜ ๋ชจ๋ธ์„ ์„ ํƒํ•˜๊ณ  '๋ชจ๋ธ ๋กœ๋“œ' ๋ฒ„ํŠผ์„ ํด๋ฆญํ•ด ์ฃผ์„ธ์š”.",
)
)
yield history
return
try:
        # Automatically adjust token limits based on model size
        size_category = get_model_size_category(current_model_name)
        # Large models get a smaller token budget to save memory
        if size_category == "large":
            max_num_tokens = min(max_num_tokens, 1000)
            final_num_tokens = min(final_num_tokens, 1500)
        # Streamer used to pull tokens from the generation thread
streamer = transformers.TextIteratorStreamer(
pipe.tokenizer,
skip_special_tokens=True,
skip_prompt=True,
)
        # Keep the question so it can be re-inserted into the reasoning prompts
        question = history[-1]["content"]
        # Prepare the assistant message
        history.append(
            gr.ChatMessage(
                role="assistant",
                content=str(""),
                metadata={"title": "🧠 Thinking...", "status": "pending"},
            )
        )
        # Reasoning steps shown in the current chat
        messages = rebuild_messages(history)
        # Timeout handling
import signal
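        # Timeout plumbing: the handler below raises TimeoutError when SIGALRM fires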
class TimeoutError(Exception):
pass
def timeout_handler(signum, frame):
            raise TimeoutError("Request processing timed out.")
        # SIGALRM-based timeouts only work on Unix and only from the main thread
        use_alarm = hasattr(signal, "SIGALRM") and threading.current_thread() is threading.main_thread()
        # Allow up to 120 seconds per reasoning step
        timeout_seconds = 120
for i, prepend in enumerate(rethink_prepends):
if i > 0:
messages[-1]["content"] += "\n\n"
messages[-1]["content"] += prepend.format(question=question)
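            # The final step (the one containing ANSWER_MARKER) gets the larger final_num_tokens budget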
num_tokens = int(
max_num_tokens if ANSWER_MARKER not in prepend else final_num_tokens
)
            # Run the model in a background thread
t = threading.Thread(
target=pipe,
args=(messages,),
kwargs=dict(
max_new_tokens=num_tokens,
streamer=streamer,
do_sample=do_sample,
temperature=temperature,
                    # Extra parameters for memory efficiency
                    repetition_penalty=1.2,  # discourage repetition
                    use_cache=True,  # use the KV cache
                ),
            )
            t.daemon = True  # daemon thread: exits together with the main thread
t.start()
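            # The pipeline call runs in the background thread; its tokens are consumed from the streamer below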
            # Mirror the new prompt into the visible history
            history[-1].content += prepend.format(question=question)
            if ANSWER_MARKER in prepend:
                history[-1].metadata = {"title": "💭 Thought process", "status": "done"}
                # Thinking is done; what follows is the answer (no metadata on it)
                history.append(gr.ChatMessage(role="assistant", content=""))
            # Per-step timeout (skipped on non-Unix systems and worker threads)
            try:
                if use_alarm:
                    signal.signal(signal.SIGALRM, timeout_handler)
                    signal.alarm(timeout_seconds)
                # Stream tokens as they arrive
token_count = 0
for token in streamer:
history[-1].content += token
history[-1].content = reformat_math(history[-1].content)
token_count += 1
                    # Yield every 10 tokens to keep the UI responsive
if token_count % 10 == 0:
yield history
                # Yield whatever is left
                yield history
                # Cancel the timeout
                if use_alarm:
                    signal.alarm(0)
            except TimeoutError:
                if use_alarm:
                    signal.alarm(0)
                history[-1].content += "\n\n⚠️ Response generation timed out. Moving on to the next step."
yield history
continue
            # Wait up to 30 seconds for the generation thread, then move on
            import time
            join_start_time = time.time()
            while t.is_alive() and (time.time() - join_start_time) < 30:
                t.join(1)  # check once per second
            # If the thread is still running, move on anyway
            if t.is_alive():
                history[-1].content += "\n\n⚠️ Response generation is taking longer than expected. Proceeding to the next step."
                yield history
            # For large models, partially free memory after each step
if size_category == "large" and torch.cuda.is_available():
torch.cuda.empty_cache()
except Exception as e:
        # Let the user know an error occurred
        import traceback
        error_msg = f"\n\n⚠️ An error occurred during processing: {str(e)}\n{traceback.format_exc()}"
if len(history) > 0 and isinstance(history[-1], gr.ChatMessage) and history[-1].role == "assistant":
history[-1].content += error_msg
else:
history.append(gr.ChatMessage(role="assistant", content=error_msg))
yield history
yield history
# Helper to report the available GPU hardware
def get_gpu_info():
    if not torch.cuda.is_available():
        return "No GPU available."
gpu_info = []
for i in range(torch.cuda.device_count()):
gpu_name = torch.cuda.get_device_name(i)
total_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
gpu_info.append(f"GPU {i}: {gpu_name} ({total_memory:.1f} GB)")
return "\n".join(gpu_info)
# Gradio interface
with gr.Blocks(fill_height=True, title="ThinkFlow - Step-by-step Reasoning Service") as demo:
    # Title and description at the top
gr.Markdown("""
# ThinkFlow
## A thought amplification service that implants step-by-step reasoning abilities into LLMs without model modification
""")
with gr.Row(scale=1):
with gr.Column(scale=5):
            # Chat interface
chatbot = gr.Chatbot(
scale=1,
type="messages",
latex_delimiters=latex_delimiters,
height=600,
)
msg = gr.Textbox(
submit_btn=True,
label="",
show_label=False,
                placeholder="Type your question here.",
autofocus=True,
)
with gr.Column(scale=1):
            # Display hardware info
            gpu_info = gr.Markdown(f"**Available hardware:**\n{get_gpu_info()}")
            # Model selection section
            gr.Markdown("""## Model Selection""")
            model_selector = gr.Radio(
                choices=list(available_models.values()),
                value=DEFAULT_MODEL_VALUE,  # correct default model
                label="Select the LLM model to use",
            )
            # Model load button
            load_model_btn = gr.Button("Load Model", variant="primary")
            model_status = gr.Textbox(label="Model Status", interactive=False)
            # Memory cleanup button
            clear_memory_btn = gr.Button("Clear GPU Memory", variant="secondary")
            gr.Markdown("""## Parameter Settings""")
            with gr.Accordion("Advanced Settings", open=False):
                num_tokens = gr.Slider(
                    50,
                    2000,
                    1000,  # reduced default
                    step=50,
                    label="Max tokens per reasoning step",
                    interactive=True,
                )
                final_num_tokens = gr.Slider(
                    50,
                    3000,
                    1500,  # reduced default
                    step=50,
                    label="Max tokens for the final answer",
                    interactive=True,
                )
                do_sample = gr.Checkbox(True, label="Use sampling")
                temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature")
    # Auto-load the default model at startup
    def auto_load_model():
        # Load the first model automatically
        model_key = DEFAULT_MODEL_KEY
        try:
            result = load_model([model_key])
            return result
        except Exception as e:
            return f"Automatic model load failed: {str(e)}"
    # Load the model automatically when the Space starts
    demo.load(auto_load_model, [], [model_status])
    # Wire up loading of the selected model
    def get_model_names(selected_model):
        # Map the display name back to the original model id
        inverse_map = {v: k for k, v in available_models.items()}
        return [inverse_map[selected_model]] if selected_model else []
load_model_btn.click(
lambda selected: load_model(get_model_names(selected)),
inputs=[model_selector],
outputs=[model_status]
)
    # Wire up the GPU memory cleanup button
    clear_memory_btn.click(
        lambda: (clear_gpu_memory(), "GPU memory has been cleared."),
inputs=[],
outputs=[model_status]
)
    # When the user submits a message, the bot responds
    msg.submit(
        user_input,
        [msg, chatbot],  # inputs
        [msg, chatbot],  # outputs
    ).then(
        bot,
        [
            chatbot,
            num_tokens,
            final_num_tokens,
            do_sample,
            temperature,
        ],  # effectively the "history" input
        chatbot,  # the new history is written back to the chatbot
)
if __name__ == "__main__":
    # Print debugging info
    print(f"GPU available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"Number of GPUs: {torch.cuda.device_count()}")
        print(f"Current GPU: {torch.cuda.current_device()}")
        print(f"GPU name: {torch.cuda.get_device_name(0)}")
    # Check the HF_TOKEN environment variable
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("The HF_TOKEN environment variable is set.")
    else:
        print("Warning: HF_TOKEN is not set; gated models will not be accessible.")
    # Enable request queuing and launch the app
    demo.queue(max_size=10).launch()