# app.py — BitNet b1.58-2B-4T text-generation demo (Hugging Face Space "bitnet" by kimhyunwoo)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os
# --- Model loading ---
# Hugging Face Hub ID of the BitNet checkpoint.
model_id = "microsoft/bitnet-b1.58-2B-4T"

# Silence non-error transformers logging during the (noisy) model load.
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

# BitNet ships custom modeling code with the checkpoint, so
# trust_remote_code=True is required for both model and tokenizer.
# NOTE(review): this Space installs transformers from a specific GitHub
# branch that supports BitNet — confirm against requirements.
try:
    print(f"모델 로딩 중: {model_id}...")
    use_cuda = torch.cuda.is_available()

    # On GPU load in bfloat16 to cut memory; on CPU omit torch_dtype so
    # transformers falls back to its default (float32) for compatibility.
    load_kwargs = {"trust_remote_code": True}
    if use_cuda:
        load_kwargs["torch_dtype"] = torch.bfloat16

    model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
    if use_cuda:
        model = model.to("cuda")  # move weights onto the GPU

    # The tokenizer load is identical on both paths, so do it once.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    if use_cuda:
        print("GPU를 사용하여 모델 로드 완료.")
    else:
        print("CPU를 사용하여 모델 로드 완료. 성능이 느릴 수 있습니다.")
except Exception as e:
    # Keep the module importable even when the download/load fails;
    # the UI code below checks for None and degrades gracefully.
    print(f"모델 로드 중 오류 발생: {e}")
    tokenizer = None
    model = None
    print("모델 로드에 실패했습니다. 애플리케이션이 제대로 동작하지 않을 수 있습니다.")
# --- ํ…์ŠคํŠธ ์ƒ์„ฑ ํ•จ์ˆ˜ ---
def generate_text(prompt, max_length=100, temperature=0.7):
    """Generate a continuation of *prompt* with the loaded BitNet model.

    Args:
        prompt: Input text to continue.
        max_length: Maximum number of NEW tokens to generate
            (passed to ``max_new_tokens``, not total length).
        temperature: Sampling temperature; higher values are more random.

    Returns:
        The generated text with the prompt stripped off, or a
        human-readable error message string if the model is unavailable
        or generation fails (Gradio displays it directly).
    """
    if model is None or tokenizer is None:
        return "모델 로드에 실패하여 텍스트 생성을 할 수 없습니다."
    try:
        # Tokenize the prompt and move tensors to the model's own device.
        # Using the model's device (rather than torch.cuda.is_available())
        # stays correct even if CUDA exists but the model is on CPU.
        inputs = tokenizer(prompt, return_tensors="pt")
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        outputs = model.generate(
            **inputs,
            # Gradio sliders deliver floats; generate() expects ints here.
            max_new_tokens=int(max_length),
            temperature=float(temperature),
            do_sample=True,  # enable sampling so temperature has effect
            pad_token_id=tokenizer.eos_token_id,  # avoid pad-token warning
        )

        # Decode only the newly generated tokens, excluding the prompt.
        prompt_len = inputs["input_ids"].shape[-1]
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    except Exception as e:
        return f"텍스트 생성 중 오류 발생: {e}"
# --- Gradio interface setup ---
# Only build and launch the UI when the model actually loaded; otherwise
# exit with a log message instead of serving a broken demo.
if model is not None and tokenizer is not None:
    interface = gr.Interface(
        fn=generate_text,
        inputs=[
            gr.Textbox(lines=2, placeholder="텍스트를 입력하세요...", label="입력 프롬프트"),
            # step=1 keeps the token count an integer; without it the
            # slider emits floats for a whole-token quantity.
            gr.Slider(minimum=10, maximum=500, value=100, step=1, label="최대 생성 길이"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature (창의성)")
        ],
        outputs=gr.Textbox(label="생성된 텍스트"),
        title="BitNet b1.58-2B-4T 텍스트 생성 데모",
        description="BitNet b1.58-2B-4T 모델을 사용하여 텍스트를 생성합니다."
    )
    # Launch the Gradio app.
    # On Hugging Face Spaces, sharing/serving is handled automatically.
    interface.launch()
else:
    print("모델 로드 실패로 인해 Gradio 인터페이스를 실행할 수 없습니다.")