File size: 3,855 Bytes
f498762
9d9cc80
 
 
f498762
9d9cc80
 
 
f498762
9d9cc80
 
f498762
9d9cc80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f498762
9d9cc80
 
 
 
 
f498762
 
9d9cc80
 
 
 
f498762
9d9cc80
 
 
 
 
 
f498762
9d9cc80
 
 
 
 
 
 
 
 
 
f498762
9d9cc80
 
 
f498762
9d9cc80
f498762
9d9cc80
 
f498762
9d9cc80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# --- ๋ชจ๋ธ ๋กœ๋“œ ---
# ๋ชจ๋ธ ๊ฒฝ๋กœ ์„ค์ • (Hugging Face ๋ชจ๋ธ ID)
model_id = "microsoft/bitnet-b1.58-2B-4T"

# ๋ชจ๋ธ ๋กœ๋“œ ์‹œ ๊ฒฝ๊ณ  ๋ฉ”์‹œ์ง€๋ฅผ ์ตœ์†Œํ™”ํ•˜๊ธฐ ์œ„ํ•ด ๋กœ๊น… ๋ ˆ๋ฒจ ์„ค์ •
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

# AutoModelForCausalLM๊ณผ AutoTokenizer๋ฅผ ๋กœ๋“œํ•ฉ๋‹ˆ๋‹ค.
# BitNet ๋ชจ๋ธ์€ trust_remote_code=True๊ฐ€ ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค.
# bf16์€ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์„ ์ค„์ด๊ณ  ์†๋„๋ฅผ ํ–ฅ์ƒ์‹œํ‚ฌ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค (GPU ์ง€์› ์‹œ).
# CPU๋งŒ ์‚ฌ์šฉํ•˜๋Š” ๊ฒฝ์šฐ torch_dtype์„ ์ƒ๋žตํ•˜๊ฑฐ๋‚˜ torch.float32๋กœ ์„ค์ •ํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
try:
    print(f"๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘: {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # GPU๊ฐ€ ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•˜๋ฉด bf16 ์‚ฌ์šฉ
    if torch.cuda.is_available():
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        ).to("cuda") # GPU๋กœ ๋ชจ๋ธ ์ด๋™
        print("GPU๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ.")
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True
        )
        print("CPU๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ. ์„ฑ๋Šฅ์ด ๋А๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")

except Exception as e:
    print(f"๋ชจ๋ธ ๋กœ๋“œ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
    tokenizer = None
    model = None
    print("๋ชจ๋ธ ๋กœ๋“œ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค. ์• ํ”Œ๋ฆฌ์ผ€์ด์…˜์ด ์ œ๋Œ€๋กœ ๋™์ž‘ํ•˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.")


# --- ํ…์ŠคํŠธ ์ƒ์„ฑ ํ•จ์ˆ˜ ---
def generate_text(prompt, max_length=100, temperature=0.7):
    """Generate a text continuation for *prompt* using the globally loaded model.

    Args:
        prompt: Input text to continue.
        max_length: Maximum number of NEW tokens to generate (slider may pass a float).
        temperature: Sampling temperature; higher values give more random output.

    Returns:
        The generated continuation only (prompt excluded), or an error-message
        string if the model is unavailable or generation fails.
    """
    if model is None or tokenizer is None:
        return "모델 로드에 실패하여 텍스트 생성을 할 수 없습니다."

    try:
        # Tokenize the prompt and move the tensors to the model's own device.
        # Using model.device (instead of re-checking cuda availability) avoids a
        # device mismatch if the model was placed differently at load time.
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Sampling-based generation. The tokenizer (LLaMA 3 based) has no
        # dedicated pad token, so EOS is used for padding.
        # max_length is cast to int because Gradio sliders deliver floats.
        outputs = model.generate(
            **inputs,
            max_new_tokens=int(max_length),
            temperature=temperature,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens, skipping the prompt portion.
        prompt_len = inputs["input_ids"].shape[-1]
        generated_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        return generated_text

    except Exception as e:
        return f"텍스트 생성 중 오류 발생: {e}"

# --- Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ • ---
if model is not None and tokenizer is not None:
    interface = gr.Interface(
        fn=generate_text,
        inputs=[
            gr.Textbox(lines=2, placeholder="ํ…์ŠคํŠธ๋ฅผ ์ž…๋ ฅํ•˜์„ธ์š”...", label="์ž…๋ ฅ ํ”„๋กฌํ”„ํŠธ"),
            gr.Slider(minimum=10, maximum=500, value=100, label="์ตœ๋Œ€ ์ƒ์„ฑ ๊ธธ์ด"),
            gr.Slider(minimum=0.1, maximum=1.0, value=0.7, label="Temperature (์ฐฝ์˜์„ฑ)")
        ],
        outputs=gr.Textbox(label="์ƒ์„ฑ๋œ ํ…์ŠคํŠธ"),
        title="BitNet b1.58-2B-4T ํ…์ŠคํŠธ ์ƒ์„ฑ ๋ฐ๋ชจ",
        description="BitNet b1.58-2B-4T ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•˜์—ฌ ํ…์ŠคํŠธ๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค."
    )

    # Gradio ์•ฑ ์‹คํ–‰
    # share=True๋ฅผ ํ•˜๋ฉด ์ž„์‹œ ๊ณต๊ฐœ ๋งํฌ๊ฐ€ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค.
    interface.launch(share=False)
else:
    print("๋ชจ๋ธ ๋กœ๋“œ ์‹คํŒจ๋กœ ์ธํ•ด Gradio ์ธํ„ฐํŽ˜์ด์Šค๋ฅผ ์‹คํ–‰ํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.")