File size: 4,990 Bytes
b4ed6e4
fa7e3c5
b5ca495
 
fa7e3c5
 
 
580e705
b2b704c
b5ca495
 
 
 
 
 
 
 
 
 
bb20016
 
b5ca495
 
54dd705
b5ca495
54dd705
b5ca495
54dd705
fa7e3c5
b5ca495
580e705
b5ca495
580e705
 
 
 
 
b5ca495
b4ed6e4
6d70605
fa7e3c5
 
580e705
b5ca495
580e705
 
 
 
b5ca495
 
6d70605
580e705
b5ca495
6d70605
fa7e3c5
b5ca495
fa7e3c5
 
b5ca495
fa7e3c5
 
580e705
 
fa7e3c5
1794ce2
580e705
fa7e3c5
580e705
fa7e3c5
580e705
fa7e3c5
 
b5ca495
580e705
fa7e3c5
 
 
 
 
 
580e705
b5ca495
580e705
 
 
 
fa7e3c5
 
 
 
580e705
fa7e3c5
580e705
 
 
fa7e3c5
 
580e705
fa7e3c5
580e705
b5ca495
580e705
 
fa7e3c5
580e705
fa7e3c5
 
 
580e705
fa7e3c5
580e705
 
 
 
fa7e3c5
 
 
580e705
fa7e3c5
 
 
 
b5ca495
580e705
b5ca495
fa7e3c5
 
 
6d70605
fa7e3c5
580e705
fa7e3c5
580e705
fa7e3c5
580e705
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import spaces
import gradio as gr
from transformers import AutoTokenizer, TextIteratorStreamer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import torch
from threading import Thread

# Model and device configuration
# NOTE(review): the variable name and the UI title say "Phi-4", but the
# checkpoint loaded here is OpenBioLLM-70B — confirm which model is intended.
phi4_model_path = "Compumacy/OpenBioLLm-70B"
device = "cuda" if torch.cuda.is_available() else "cpu"

# === GPTQ 2-bit QUANTIZATION CONFIG ===
# BaseQuantizeConfig is AutoGPTQ's GPTQ dataclass and only accepts GPTQ
# fields (bits, group_size, desc_act, sym, ...).  The bitsandbytes-style
# keywords used previously (load_in_4bit, quant_type="nf4",
# use_double_quant, compute_dtype, ...) are not fields of this class and made
# the constructor raise TypeError at import time.
# NOTE(review): group_size / desc_act / sym must match how the checkpoint was
# actually quantized; passing quantize_config=None instead lets
# from_quantized read the repository's own quantize_config.json — confirm.
quantize_config = BaseQuantizeConfig(bits=2)

# === LOAD GPTQ-QUANTIZED MODEL ===
model = AutoGPTQForCausalLM.from_quantized(
    phi4_model_path,
    quantize_config=quantize_config,
    device_map="auto",
    use_safetensors=True,
)

tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)

# === OPTIONAL: TorchCompile for optimization (PyTorch >= 2.0) ===
# Best-effort: torch.compile is missing on older PyTorch and can reject some
# models; keep serving the uncompiled model, but report why instead of
# failing silently.
# NOTE(review): generate() is reached through the compiled wrapper's
# attribute forwarding and may not use the compiled forward — confirm this
# gives any speed-up.
try:
    model = torch.compile(model)
except Exception as exc:
    print(f"torch.compile skipped: {exc}")

# === STREAMING RESPONSE GENERATOR ===
@spaces.GPU()
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    """Stream an assistant reply to *user_message*, growing the chat history.

    This is a generator: every yield is a ``(chat_messages, history_state)``
    pair for the two Gradio outputs; both items are the same list of
    ``{"role": ..., "content": ...}`` dicts.

    Args:
        user_message: Text typed by the user. Blank (or missing) input
            yields the history unchanged and stops.
        max_tokens: Upper bound on newly generated tokens (cast to int).
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff (cast to int).
        top_p: Nucleus-sampling threshold.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.
        history_state: Prior messages. Not mutated — a new list is yielded.
    """
    # A bare `return value` inside a generator emits nothing to Gradio, so
    # yield the unchanged history explicitly before stopping.
    if not user_message or not user_message.strip():
        yield history_state, history_state
        return

    # System prompt prefix
    system_message = (
        "Your role as an assistant involves thoroughly exploring questions through a systematic thinking process..."
    )
    # NOTE(review): these are Phi-4-style chat tags while the loaded checkpoint
    # is OpenBioLLM-70B; if its tokenizer ships a chat template,
    # tokenizer.apply_chat_template may be the correct format — confirm.
    start_tag, sep_tag, end_tag = "<|im_start|>", "<|im_sep|>", "<|im_end|>"

    # Build full prompt: system turn, every prior turn, the new user turn, then
    # an open assistant turn for the model to complete.
    turns = [f"{start_tag}system{sep_tag}{system_message}{end_tag}"]
    turns.extend(
        f"{start_tag}{msg['role']}{sep_tag}{msg['content']}{end_tag}"
        for msg in history_state
    )
    turns.append(f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}")
    prompt = "".join(turns)

    # Tokenize and move to device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # skip_prompt: don't echo the prompt back; skip_special_tokens: keep the
    # tokenizer's special tokens (e.g. EOS) out of the displayed reply.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": temperature,
        "top_k": int(top_k),
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }

    def _run_generation():
        # Runs generate() off the main thread. On failure the streamer would
        # never receive its stop signal and the consumer loop below would
        # block forever, so end the stream before re-raising.
        try:
            model.generate(**generation_kwargs)
        except Exception:
            streamer.end()
            raise

    worker = Thread(target=_run_generation)
    worker.start()

    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]

    # Stream decoded text chunks back to Gradio, stripping any chat tags the
    # model echoes.
    for token in streamer:
        clean = token.replace(start_tag, "").replace(sep_tag, "").replace(end_tag, "")
        assistant_response += clean
        new_history[-1]["content"] = assistant_response
        yield new_history, new_history

    worker.join()
    # Final yield guarantees the user turn is shown even if nothing streamed.
    yield new_history, new_history

# === EXAMPLE MESSAGES ===
# Maps each example button's label to the text it places in the input box.
# Insertion order determines the order in which the buttons are rendered.
# NOTE(review): the example texts end in "..." and look truncated — confirm
# the full questions.
example_messages = dict([
    ("Math reasoning", "If a rectangular prism has a length of 6 cm..."),
    ("Logic puzzle", "Four people (Alex, Blake, Casey, ...)"),
    ("Physics problem", "A ball is thrown upward with an initial velocity..."),
])

# === GRADIO APP ===
# Builds the chat UI (settings column + chat column), wires the Send / Clear /
# example buttons, and launches the server.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # Phi-4 Chat with GPTQ Quant
    Try the example problems below to see how the model breaks down complex reasoning.
    """ )

    # Per-session chat history (list of {"role", "content"} dicts).
    history_state = gr.State([])
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            # step=64 keeps the minimum (64), the default (2048) and the
            # maximum (32768) all on the slider's step grid; with step=1024
            # measured from 64, neither 2048 nor 32768 was selectable.
            max_tokens_slider = gr.Slider(64, 32768, step=64, value=2048, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")
            with gr.Row():
                user_input = gr.Textbox(placeholder="Type your message...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                for name, text in example_messages.items():
                    btn = gr.Button(name)
                    # `t=text` binds this iteration's text (avoids the
                    # late-binding closure pitfall). inputs/outputs must be
                    # keywords here: positional args after `fn=` are a
                    # SyntaxError.
                    btn.click(fn=lambda t=text: gr.update(value=t), inputs=None, outputs=user_input)

    # Stream the reply into the chat, then clear the input box.
    submit_button.click(
        fn=generate_response,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state]
    ).then(lambda: gr.update(value=""), None, user_input)

    # Reset both the visible chat and the stored history.
    clear_button.click(lambda: ([], []), None, [chatbot, history_state])

demo.launch(ssr_mode=False)