Spaces: Running on Zero
import spaces
import gradio as gr
from transformers import AutoTokenizer, TextIteratorStreamer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import torch
from threading import Thread

# Model and device configuration
phi4_model_path = "Compumacy/OpenBioLLm-70B"
device = "cuda" if torch.cuda.is_available() else "cpu"
# === GPTQ 2-bit QUANTIZATION CONFIG ===
# BaseQuantizeConfig takes GPTQ-specific fields; bitsandbytes-style arguments
# (load_in_4bit, quant_type="nf4", ...) are not valid here.
quantize_config = BaseQuantizeConfig(
    bits=2,          # 2-bit GPTQ weights
    group_size=128,  # assumed group size; adjust to match the checkpoint
    desc_act=False,  # assumed: activation-order quantization disabled
)
# === LOAD GPTQ-QUANTIZED MODEL ===
model = AutoGPTQForCausalLM.from_quantized(
    phi4_model_path,
    quantize_config=quantize_config,
    device_map="auto",
    use_safetensors=True,
)
tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)

# === OPTIONAL: TorchCompile for optimization (PyTorch >= 2.0) ===
try:
    model = torch.compile(model)
except Exception:
    pass
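# Compilation is best-effort: if the quantized backend or PyTorch version does
# not support torch.compile, the model is simply used uncompiled.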
# === STREAMING RESPONSE GENERATOR ===
@spaces.GPU  # request ZeroGPU hardware for the duration of this call
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    if not user_message.strip():
        # This function is a generator, so yield (not return) the unchanged history
        yield history_state, history_state
        return
    # System prompt prefix
    system_message = (
        "Your role as an assistant involves thoroughly exploring questions through a systematic thinking process..."
    )
    start_tag, sep_tag, end_tag = "<|im_start|>", "<|im_sep|>", "<|im_end|>"
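    # Each turn is serialized as <|im_start|>{role}<|im_sep|>{content}<|im_end|>,
    # with an open assistant tag at the end so the model continues from there.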
    # Build full prompt
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
    for msg in history_state:
        prompt += f"{start_tag}{msg['role']}{sep_tag}{msg['content']}{end_tag}"
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"

    # Tokenize and move to device
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Set up streamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
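    # skip_prompt=True makes the streamer yield only newly generated text,
    # not the echoed prompt tokens.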
    generation_kwargs = {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": temperature,
        "top_k": int(top_k),
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }

    # Launch generation
    Thread(target=model.generate, kwargs=generation_kwargs).start()
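    # generate() runs in a background thread so the loop below can consume the
    # streamer and update the UI while tokens are still being produced.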
    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]

    # Stream tokens back to Gradio
    for token in streamer:
        clean = token.replace(start_tag, "").replace(sep_tag, "").replace(end_tag, "")
        assistant_response += clean
        new_history[-1]["content"] = assistant_response
        yield new_history, new_history
    yield new_history, new_history

# === EXAMPLE MESSAGES ===
example_messages = {
    "Math reasoning": "If a rectangular prism has a length of 6 cm...",
    "Logic puzzle": "Four people (Alex, Blake, Casey, ...)",
    "Physics problem": "A ball is thrown upward with an initial velocity..."
}
# === GRADIO APP ===
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# Phi-4 Chat with GPTQ Quant
Try the example problems below to see how the model breaks down complex reasoning.
""")

    history_state = gr.State([])

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=2048, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")
            with gr.Row():
                user_input = gr.Textbox(placeholder="Type your message...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                for name, text in example_messages.items():
                    btn = gr.Button(name)
                    btn.click(fn=lambda t=text: gr.update(value=t), None, user_input)
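                    # The default argument (t=text) binds each example's text at loop
                    # time, so every button fills the textbox with its own example.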
    submit_button.click(
        fn=generate_response,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state]
    ).then(lambda: gr.update(value=""), None, user_input)

    clear_button.click(lambda: ([], []), None, [chatbot, history_state])

demo.launch(ssr_mode=False)