import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
import torch
from threading import Thread
# Model and device configuration
phi4_model_path = "Compumacy/OpenBioLLm-70B"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# === CONFIGURE 4-BIT QUANTIZATION ===
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)
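# Rough sizing note: a 70B-parameter model in 4-bit NF4 needs on the order of
# 35 GB for the weights alone (70e9 params * 0.5 bytes), plus KV-cache and
# activation overhead, which is why the load below caps per-GPU memory and
# offloads the remainder to CPU.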
# === LOAD MODEL WITH 4-BIT QUANTIZATION AND CPU OFFLOAD ===
model = AutoModelForCausalLM.from_pretrained(
    phi4_model_path,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
    offload_folder="offload",
    max_memory={**{i: "12GB" for i in range(torch.cuda.device_count())}, "cpu": "30GB"}
)
tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
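# The resulting placement can be inspected after loading if needed, e.g.:
# print(model.hf_device_map)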
# Enable gradient checkpointing (only relevant if the model is ever fine-tuned)
model.gradient_checkpointing_enable()
# Optionally compile for PyTorch >= 2.0
try:
    model = torch.compile(model)
except Exception:
    pass
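# torch.compile is best-effort here: compiling a quantized, partially offloaded
# model may fall back or fail, which is why any error is silently ignored.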
# === RESPONSE GENERATOR ===
@spaces.GPU  # request a ZeroGPU slot while this function runs
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    if not user_message.strip():
        yield history_state, history_state
        return
    # Prompt setup
    system_message = (
        "Your role as an assistant involves thoroughly exploring questions through a systematic thinking process..."
    )
    start_tag, sep_tag, end_tag = "<|im_start|>", "<|im_sep|>", "<|im_end|>"
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
    for msg in history_state:
        tag = msg["role"]
        content = msg["content"]
        prompt += f"{start_tag}{tag}{sep_tag}{content}{end_tag}"
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Streaming setup
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": temperature,
        "top_k": int(top_k),
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer
    }
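    # model.generate() blocks until generation finishes, so it runs in a
    # background thread; TextIteratorStreamer then hands decoded text chunks to
    # the loop below as they are produced.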
    # Run generation in thread
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""}
    ]
    # Stream tokens
    for token in streamer:
        clean = token.replace(start_tag, "").replace(sep_tag, "").replace(end_tag, "")
        assistant_response += clean
        new_history[-1]["content"] = assistant_response
        yield new_history, new_history
    yield new_history, new_history
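    # The final yield re-emits the completed history so the stored chat state
    # matches what the Chatbot displays once streaming ends.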
# === EXAMPLE MESSAGES ===
example_messages = {
    "Math reasoning": "If a rectangular prism has a length of 6 cm...",
    "Logic puzzle": "Four people (Alex, Blake, Casey, ...)",
    "Physics problem": "A ball is thrown upward with an initial velocity..."
}
# === GRADIO APP ===
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# Phi-4 Chat
Try the example problems below to see how the model breaks down complex reasoning.
""")
    history_state = gr.State([])
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=2048, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")
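            # type="messages" lets the Chatbot consume the same
            # [{"role": ..., "content": ...}] dicts kept in history_state.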
            with gr.Row():
                user_input = gr.Textbox(placeholder="Type your message...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                for name in example_messages:
                    btn = gr.Button(name)
                    btn.click(fn=lambda n=name: gr.update(value=example_messages[n]), inputs=None, outputs=user_input)
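                    # The n=name default argument binds the current loop value;
                    # a plain closure over name would make every button insert
                    # the last example.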
    submit_button.click(
        fn=generate_response,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state]
    ).then(lambda: gr.update(value=""), None, user_input)
    clear_button.click(lambda: ([], []), None, [chatbot, history_state])

demo.launch(ssr_mode=False)