import spaces import gradio as gr from transformers import AutoTokenizer, TextIteratorStreamer from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig import torch from threading import Thread # Model and device configuration phi4_model_path = "Compumacy/OpenBioLLm-70B" device = "cuda" if torch.cuda.is_available() else "cpu" # === GPTQ 2-bit QUANTIZATION CONFIG === quantize_config = BaseQuantizeConfig( load_in_4bit=False, load_in_8bit=False, quantization_bit=2, compute_dtype=torch.float16, use_double_quant=True, quant_type="nf4" ) # === LOAD GPTQ-QUANTIZED MODEL === model = AutoGPTQForCausalLM.from_quantized( phi4_model_path, quantize_config=quantize_config, device_map="auto", use_safetensors=True, ) tokenizer = AutoTokenizer.from_pretrained(phi4_model_path) # === OPTIONAL: TorchCompile for optimization (PyTorch >= 2.0) === try: model = torch.compile(model) except Exception: pass # === STREAMING RESPONSE GENERATOR === @spaces.GPU() def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state): if not user_message.strip(): return history_state, history_state # System prompt prefix system_message = ( "Your role as an assistant involves thoroughly exploring questions through a systematic thinking process..." ) start_tag, sep_tag, end_tag = "<|im_start|>", "<|im_sep|>", "<|im_end|>" # Build full prompt prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}" for msg in history_state: prompt += f"{start_tag}{msg['role']}{sep_tag}{msg['content']}{end_tag}" prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}" # Tokenize and move to device inputs = tokenizer(prompt, return_tensors="pt").to(device) # Set up streamer streamer = TextIteratorStreamer(tokenizer, skip_prompt=True) generation_kwargs = { "input_ids": inputs.input_ids, "attention_mask": inputs.attention_mask, "max_new_tokens": int(max_tokens), "do_sample": True, "temperature": temperature, "top_k": int(top_k), "top_p": top_p, "repetition_penalty": repetition_penalty, "streamer": streamer } # Launch generation Thread(target=model.generate, kwargs=generation_kwargs).start() assistant_response = "" new_history = history_state + [ {"role": "user", "content": user_message}, {"role": "assistant", "content": ""} ] # Stream tokens back to Gradio for token in streamer: clean = token.replace(start_tag, "").replace(sep_tag, "").replace(end_tag, "") assistant_response += clean new_history[-1]["content"] = assistant_response yield new_history, new_history yield new_history, new_history # === EXAMPLE MESSAGES === example_messages = { "Math reasoning": "If a rectangular prism has a length of 6 cm...", "Logic puzzle": "Four people (Alex, Blake, Casey, ...)", "Physics problem": "A ball is thrown upward with an initial velocity..." } # === GRADIO APP === with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown(""" # Phi-4 Chat with GPTQ Quant Try the example problems below to see how the model breaks down complex reasoning. """ ) history_state = gr.State([]) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Settings") max_tokens_slider = gr.Slider(64, 32768, step=1024, value=2048, label="Max Tokens") with gr.Accordion("Advanced Settings", open=False): temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature") top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k") top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p") repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty") with gr.Column(scale=4): chatbot = gr.Chatbot(label="Chat", type="messages") with gr.Row(): user_input = gr.Textbox(placeholder="Type your message...", scale=3) submit_button = gr.Button("Send", variant="primary", scale=1) clear_button = gr.Button("Clear", scale=1) gr.Markdown("**Try these examples:**") with gr.Row(): for name, text in example_messages.items(): btn = gr.Button(name) btn.click(fn=lambda t=text: gr.update(value=t), inputs=None, outputs=user_input) submit_button.click( fn=generate_response, inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state], outputs=[chatbot, history_state] ).then(lambda: gr.update(value=""), None, user_input) clear_button.click(lambda: ([], []), None, [chatbot, history_state]) demo.launch(ssr_mode=False)