import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread

# Model and device configuration
phi4_model_path = "Compumacy/OpenBioLLm-70B"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
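# Note: with device_map="auto" the model itself may end up sharded across devices; this
# device string is only used to place the tokenized prompt before generation.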

# === EMPTY-WEIGHT INITIALIZATION ===
# from_pretrained handles meta-device initialization itself when device_map is set,
# so no explicit init_empty_weights() context is needed here.

# === CONFIGURE 4-BIT QUANTIZATION ===
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                        # store weights in 4-bit precision
    bnb_4bit_compute_dtype=torch.float16,     # run the matmuls in fp16
    bnb_4bit_use_double_quant=True,           # also quantize the quantization constants
    bnb_4bit_quant_type="nf4",                # NormalFloat4 data type
    llm_int8_enable_fp32_cpu_offload=True     # allow modules that don't fit on the GPU to stay on CPU
)
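# In 4-bit, a 70B-parameter model still needs roughly 35 GB for its weights alone, which is
# why the load below is also allowed to spill layers to CPU RAM and disk.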

# === LOAD MODEL WITH 4-BIT QUANTIZATION AND CPU/DISK OFFLOAD ===
# device_map="auto" lets accelerate place layers on the available GPUs and spill the
# remainder to CPU RAM and then to the "offload" folder on disk, within the max_memory limits.
model = AutoModelForCausalLM.from_pretrained(
    phi4_model_path,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
    offload_folder="offload",
    offload_state_dict=True,
    max_memory={**{i: "12GB" for i in range(torch.cuda.device_count())}, "cpu": "30GB"}
)

tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)

# Enable gradient checkpointing if ever fine-tuning
model.gradient_checkpointing_enable()

# Optionally compile for PyTorch >= 2.0
try:
    model = torch.compile(model)
except Exception:
    pass

# === RESPONSE GENERATOR ===
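# On Hugging Face ZeroGPU Spaces, @spaces.GPU() attaches a GPU to this function for the
# duration of each call.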
@spaces.GPU()
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    # This is a generator function, so an early exit must yield the unchanged state rather than return it.
    if not user_message.strip():
        yield history_state, history_state
        return

    # Prompt setup
    system_message = (
        "Your role as an assistant involves thoroughly exploring questions through a systematic thinking process..."
    )
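    # Phi-4-style chat markup: each turn is wrapped as <|im_start|>role<|im_sep|>content<|im_end|>.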
    start_tag, sep_tag, end_tag = "<|im_start|>", "<|im_sep|>", "<|im_end|>"
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
    for msg in history_state:
        tag = msg["role"]
        content = msg["content"]
        prompt += f"{start_tag}{tag}{sep_tag}{content}{end_tag}"
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Streaming setup
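    # skip_prompt=True keeps the echoed prompt out of the stream so only newly generated text reaches the UI.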
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": temperature,
        "top_k": int(top_k),
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer
    }

    # Run generation in thread
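    # model.generate() blocks, so it runs in a background thread while this generator drains the streamer.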
    Thread(target=model.generate, kwargs=generation_kwargs).start()

    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""}
    ]

    # Stream tokens
    for token in streamer:
        clean = token.replace(start_tag, "").replace(sep_tag, "").replace(end_tag, "")
        assistant_response += clean
        new_history[-1]["content"] = assistant_response
        yield new_history, new_history

    yield new_history, new_history

# === EXAMPLE MESSAGES ===
example_messages = {
    "Math reasoning": "If a rectangular prism has a length of 6 cm...",
    "Logic puzzle": "Four people (Alex, Blake, Casey, ...)",
    "Physics problem": "A ball is thrown upward with an initial velocity..."
}

# === GRADIO APP ===
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Phi-4 Chat
    Try the example problems below to see how the model breaks down complex reasoning.
    """ )

    history_state = gr.State([])  # per-session chat history: a list of {"role", "content"} message dicts
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=2048, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")
            with gr.Row():
                user_input = gr.Textbox(placeholder="Type your message...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                for name in example_messages:
                    btn = gr.Button(name)
                    btn.click(fn=lambda n=name: gr.update(value=example_messages[n]), inputs=None, outputs=user_input)

    # Stream updates into the chatbot, then clear the input box once generation finishes.
    submit_button.click(
        fn=generate_response,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state]
    ).then(lambda: gr.update(value=""), None, user_input)

    clear_button.click(lambda: ([], []), None, [chatbot, history_state])

demo.launch(ssr_mode=False)