import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# --- Model loading ---
# Model path (Hugging Face model ID)
model_id = "microsoft/bitnet-b1.58-2B-4T"

# Reduce logging verbosity to suppress warnings during model loading
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

# Load AutoModelForCausalLM and AutoTokenizer.
# The BitNet model requires trust_remote_code=True.
# Uses a transformers build installed from a specific GitHub branch.
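# Note: BitNet support may not be in a stable transformers release, so installing
# from source is one option. The exact command below is an assumption, not a
# confirmed requirement for this model:
#   pip install git+https://github.com/huggingface/transformers.git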
try:
    print(f"Loading model: {model_id}...")
    # Use bf16 when a GPU is available
    if torch.cuda.is_available():
        # Set torch_dtype explicitly to avoid dtype-related load errors
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        ).to("cuda")  # Move the model to the GPU
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        print("Model loaded on GPU.")
    else:
        # On CPU, omit torch_dtype (defaults to float32)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        print("Model loaded on CPU. Generation may be slow.")
except Exception as e:
    print(f"Error while loading the model: {e}")
    tokenizer = None
    model = None
    print("Model loading failed. The application may not work correctly.")
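
# Alternative loading path (a sketch, not used above): let the `accelerate`
# package pick device placement and dtype automatically. Assumes accelerate
# is installed; behavior with this model is untested here.
# model = AutoModelForCausalLM.from_pretrained(
#     model_id, torch_dtype="auto", device_map="auto", trust_remote_code=True
# )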

# --- Text generation function ---
def generate_text(prompt, max_length=100, temperature=0.7):
    if model is None or tokenizer is None:
        return "Model loading failed, so text generation is unavailable."
    try:
        # Tokenize the prompt
        inputs = tokenizer(prompt, return_tensors="pt")

        # Move inputs to the GPU when available
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Generate text.
        # The model uses a LLaMA 3 tokenizer, so a chat template could be applied (optional).
        # The code below feeds the raw prompt directly instead of a message format.
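        # Chat-template variant (a sketch, not exercised here): wrap the prompt
        # in a message list and let the tokenizer format it. apply_chat_template
        # is the standard transformers API; its behavior with this specific
        # model's template is an assumption.
        # messages = [{"role": "user", "content": prompt}]
        # inputs = tokenizer.apply_chat_template(
        #     messages, add_generation_prompt=True,
        #     return_tensors="pt", return_dict=True
        # )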
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,  # Enable sampling
            pad_token_id=tokenizer.eos_token_id  # Set pad token ID (needed if none is defined)
        )

        # Decode only the newly generated tokens, excluding the input prompt
        generated_text = tokenizer.decode(outputs[0][inputs['input_ids'].shape[-1]:], skip_special_tokens=True)
        return generated_text
    except Exception as e:
        return f"Error during text generation: {e}"
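
# Quick smoke test (hypothetical usage): uncomment to try generation without Gradio.
# if model is not None and tokenizer is not None:
#     print(generate_text("Once upon a time", max_length=50, temperature=0.7))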

# --- Gradio interface setup ---
if model is not None and tokenizer is not None:
    interface = gr.Interface(
        fn=generate_text,
        inputs=[
            gr.Textbox(lines=2, placeholder="Enter your text...", label="Input prompt"),
            gr.Slider(minimum=10, maximum=500, value=100, label="Max generation length"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature (creativity)")
        ],
        outputs=gr.Textbox(label="Generated text"),
        title="BitNet b1.58-2B-4T Text Generation Demo",
        description="Generates text with the BitNet b1.58-2B-4T model."
    )

    # Launch the Gradio app.
    # On Hugging Face Spaces, launch() is served automatically; share=True is not needed.
    interface.launch()
else:
    print("Cannot launch the Gradio interface because model loading failed.")
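
# Running outside Spaces (a sketch): bind the server explicitly, e.g. in Docker.
# interface.launch(server_name="0.0.0.0", server_port=7860)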