File size: 3,953 Bytes
f498762
9d9cc80
 
 
f498762
9d9cc80
 
 
f498762
9d9cc80
 
f498762
9d9cc80
 
ece9655
9d9cc80
 
 
 
ece9655
9d9cc80
 
 
 
 
ece9655
9d9cc80
 
ece9655
9d9cc80
 
 
 
ece9655
9d9cc80
f498762
9d9cc80
 
 
 
 
f498762
 
9d9cc80
 
 
 
f498762
9d9cc80
 
 
 
 
 
f498762
9d9cc80
 
 
 
 
 
 
 
 
 
f498762
9d9cc80
 
 
f498762
9d9cc80
f498762
9d9cc80
 
f498762
9d9cc80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ece9655
 
9d9cc80
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# --- ๋ชจ๋ธ ๋กœ๋“œ ---
# ๋ชจ๋ธ ๊ฒฝ๋กœ ์„ค์ • (Hugging Face ๋ชจ๋ธ ID)
model_id = "microsoft/bitnet-b1.58-2B-4T"

# ๋ชจ๋ธ ๋กœ๋“œ ์‹œ ๊ฒฝ๊ณ  ๋ฉ”์‹œ์ง€๋ฅผ ์ตœ์†Œํ™”ํ•˜๊ธฐ ์œ„ํ•ด ๋กœ๊น… ๋ ˆ๋ฒจ ์„ค์ •
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

# AutoModelForCausalLM๊ณผ AutoTokenizer๋ฅผ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
# BitNet ๋ชจ๋ธ์€ trust_remote_code=True๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.
# GitHub ํŠน์ • ๋ธŒ๋žœ์น˜์—์„œ ์„ค์น˜ํ•œ transformers๋ฅผ ์‚ฌ์šฉํ•ฉ๋‹ˆ๋‹ค.
try:
    print(f"๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘: {model_id}...")
    # GPU๊ฐ€ ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•˜๋ฉด bf16 ์‚ฌ์šฉ
    if torch.cuda.is_available():
        # torch_dtype์„ ๋ช…์‹œ์ ์œผ๋กœ ์„ค์ •ํ•˜์—ฌ ๋กœ๋“œ ์˜ค๋ฅ˜ ๋ฐฉ์ง€ ์‹œ๋„
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        ).to("cuda") # GPU๋กœ ๋ชจ๋ธ ์ด๋™
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        print("GPU๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ.")
    else:
        # CPU ์‚ฌ์šฉ ์‹œ torch_dtype ์ƒ๋žต ๋˜๋Š” float32
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        print("CPU๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ. ์„ฑ๋Šฅ์ด ๋А๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")

except Exception as e:
    print(f"๋ชจ๋ธ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
    tokenizer = None
    model = None
    print("๋ชจ๋ธ ๋กœ๋“œ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค. ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜์ด ์ œ๋Œ€๋กœ ๋™์ž‘ํ•˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")


# --- ํ…์ŠคํŠธ ์ƒ์„ฑ ํ•จ์ˆ˜ ---
def generate_text(prompt, max_length=100, temperature=0.7):
    """Generate a continuation for *prompt* using the globally loaded model.

    Args:
        prompt: Input text to continue.
        max_length: Maximum number of newly generated tokens.
        temperature: Sampling temperature (higher = more random output).

    Returns:
        The generated text with the prompt portion stripped, or an error
        message string when the model is unavailable or generation fails.
    """
    if model is None or tokenizer is None:
        return "๋ชจ๋ธ ๋กœ๋“œ์— ์‹คํŒจํ•˜์—ฌ ํ…์ŠคํŠธ ์ƒ์„ฑ์„ ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."

    try:
        # Tokenize the prompt and move the tensors to the model's own
        # device. Keying off the model (rather than CUDA availability)
        # keeps placement correct regardless of where the model lives.
        inputs = tokenizer(prompt, return_tensors="pt")
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: disable autograd to save memory and time.
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,  # enable sampling
                pad_token_id=tokenizer.eos_token_id,  # explicit pad token id
            )

        # Decode only the newly generated tokens, skipping the prompt part.
        prompt_len = inputs["input_ids"].shape[-1]
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    except Exception as e:
        return f"ํ…์ŠคํŠธ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}"

# --- Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ • ---
if model is not None and tokenizer is not None:
    # Assemble the demo UI: a prompt box and two sliders wired to
    # generate_text, with a single textbox for the result.
    prompt_box = gr.Textbox(
        lines=2, placeholder="ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”...", label="์ž…๋ ฅ ํ”„๋กฌํ”„ํŠธ"
    )
    length_slider = gr.Slider(minimum=10, maximum=500, value=100, label="์ตœ๋Œ€ ์ƒ์„ฑ ๊ธธ์ด")
    temperature_slider = gr.Slider(
        minimum=0.1, maximum=1.0, value=0.7, label="Temperature (์ฐฝ์˜์„ฑ)"
    )
    output_box = gr.Textbox(label="์ƒ์„ฑ๋œ ํ…์ŠคํŠธ")

    interface = gr.Interface(
        fn=generate_text,
        inputs=[prompt_box, length_slider, temperature_slider],
        outputs=output_box,
        title="BitNet b1.58-2B-4T ํ…์ŠคํŠธ ์ƒ์„ฑ ๋ฐ๋ชจ",
        description="BitNet b1.58-2B-4T ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ํ…์ŠคํŠธ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.",
    )

    # Start the app (Hugging Face Spaces manages sharing automatically).
    interface.launch()
else:
    print("๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ๋กœ ์ธํ•ด Gradio ์ธํ„ฐํŽ˜์ด์Šค๋ฅผ ์‹คํ–‰ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")