import spaces
import gradio as gr
from transformers import AutoTokenizer, TextIteratorStreamer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
import torch
from threading import Thread
# Model and device configuration
phi4_model_path = "Compumacy/OpenBioLLm-70B"
device = "cuda" if torch.cuda.is_available() else "cpu"
# === GPTQ 2-BIT QUANTIZATION CONFIG ===
# BaseQuantizeConfig takes GPTQ parameters (bits, group_size, desc_act);
# bitsandbytes-style flags such as load_in_4bit or quant_type="nf4" do not
# exist on it and would raise a TypeError.
quantize_config = BaseQuantizeConfig(
    bits=2,          # 2-bit quantized weights
    group_size=128,  # weights quantized in groups of 128 per scale/zero-point
    desc_act=False,  # skip activation-order reordering for faster inference
)
# === LOAD GPTQ-QUANTIZED MODEL ===
model = AutoGPTQForCausalLM.from_quantized(
phi4_model_path,
quantize_config=quantize_config,
device_map="auto",
use_safetensors=True,
)
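# Note: from_quantized() normally reads quantization settings from the repo's
# quantize_config.json; the explicit config above overrides it, so this only
# loads if the checkpoint actually ships GPTQ-quantized safetensors.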
tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
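# A minimal smoke test, kept as a comment so the Space starts fast (a sketch;
# run it manually to verify the quantized checkpoint decodes sensibly):
#   ids = tokenizer("Hello", return_tensors="pt").to(device)
#   print(tokenizer.decode(model.generate(**ids, max_new_tokens=8)[0]))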
# === OPTIONAL: TorchCompile for optimization (PyTorch >= 2.0) ===
try:
model = torch.compile(model)
except Exception:
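    # Best-effort optimization: fall back to the eager model if compilation
    # is unsupported (e.g. older PyTorch or the quantized kernels).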
pass
# === STREAMING RESPONSE GENERATOR ===
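# On Hugging Face ZeroGPU Spaces, @spaces.GPU() allocates a GPU for the
# duration of each call to the decorated function.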
@spaces.GPU()
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    if not user_message.strip():
        # This function is a generator (it yields below), so a plain
        # `return value` would be discarded; yield the unchanged history instead.
        yield history_state, history_state
        return
# System prompt prefix
system_message = (
"Your role as an assistant involves thoroughly exploring questions through a systematic thinking process..."
)
start_tag, sep_tag, end_tag = "<|im_start|>", "<|im_sep|>", "<|im_end|>"
# Build full prompt
prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
for msg in history_state:
prompt += f"{start_tag}{msg['role']}{sep_tag}{msg['content']}{end_tag}"
prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"
# Tokenize and move to device
inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Set up a streamer; skip_prompt=True keeps the echoed prompt out of the output
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
generation_kwargs = {
"input_ids": inputs.input_ids,
"attention_mask": inputs.attention_mask,
"max_new_tokens": int(max_tokens),
"do_sample": True,
"temperature": temperature,
"top_k": int(top_k),
"top_p": top_p,
"repetition_penalty": repetition_penalty,
"streamer": streamer
}
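    # Hypothetical refinement: if "<|im_end|>" is a real token in this
    # tokenizer's vocabulary, passing it as eos_token_id would stop generation
    # cleanly at the end of the assistant turn:
    # generation_kwargs["eos_token_id"] = tokenizer.convert_tokens_to_ids(end_tag)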
    # Launch generation in a background thread; the streamer below receives
    # tokens as they are produced.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
assistant_response = ""
new_history = history_state + [
{"role": "user", "content": user_message},
{"role": "assistant", "content": ""}
]
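    # The empty assistant turn above is filled in incrementally as tokens stream.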
    # Stream tokens back to Gradio, stripping any chat-format tags that leak through
    for token in streamer:
        clean = token.replace(start_tag, "").replace(sep_tag, "").replace(end_tag, "")
assistant_response += clean
new_history[-1]["content"] = assistant_response
yield new_history, new_history
    # Final yield commits the completed assistant message once the stream ends.
    yield new_history, new_history
# === EXAMPLE MESSAGES ===
example_messages = {
"Math reasoning": "If a rectangular prism has a length of 6 cm...",
"Logic puzzle": "Four people (Alex, Blake, Casey, ...)",
"Physics problem": "A ball is thrown upward with an initial velocity..."
}
# === GRADIO APP ===
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# Phi-4 Chat with GPTQ Quant
Try the example problems below to see how the model breaks down complex reasoning.
""" )
history_state = gr.State([])
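    # Per-session chat history, stored as OpenAI-style {"role", "content"} dicts.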
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("### Settings")
            # step=64 keeps the default (2048) on the slider's grid; step=1024
            # anchored at a minimum of 64 could never land on 2048.
            max_tokens_slider = gr.Slider(64, 32768, step=64, value=2048, label="Max Tokens")
with gr.Accordion("Advanced Settings", open=False):
temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")  # 1.0 = no penalty
with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")  # expects {"role", "content"} dicts
with gr.Row():
user_input = gr.Textbox(placeholder="Type your message...", scale=3)
submit_button = gr.Button("Send", variant="primary", scale=1)
clear_button = gr.Button("Clear", scale=1)
gr.Markdown("**Try these examples:**")
with gr.Row():
                for name, text in example_messages.items():
                    btn = gr.Button(name)
                    # Bind text via a default argument so each button keeps its own example
                    btn.click(fn=lambda t=text: gr.update(value=t), None, user_input)
submit_button.click(
fn=generate_response,
inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
outputs=[chatbot, history_state]
).then(lambda: gr.update(value=""), None, user_input)
clear_button.click(lambda: ([], []), None, [chatbot, history_state])
demo.launch(ssr_mode=False)  # disable server-side rendering (Gradio 5+)