import spaces
import gradio as gr
import torch
from threading import Thread
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

# Model and device configuration
phi4_model_path = "Compumacy/OpenBioLLm-70B"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
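# device_map="auto" below decides where the model weights actually live; this
# device is only used to place the tokenized inputs for generation.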
# === CONFIGURE 4-BIT QUANTIZATION ===
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
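# Rough footprint, assuming the checkpoint really has ~70B parameters: NF4
# stores weights at ~0.5 bytes per parameter (~35 GB before overhead), so the
# model still needs GPU sharding plus CPU/disk offload below.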
# === LOAD MODEL WITH QUANTIZATION AND CPU/DISK OFFLOAD ===
# device_map="auto" shards the 4-bit weights across the available GPUs and
# spills whatever does not fit to CPU RAM, then to the offload folder on disk.
model = AutoModelForCausalLM.from_pretrained(
    phi4_model_path,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    device_map="auto",
    offload_folder="offload",
    max_memory={**{i: "12GB" for i in range(torch.cuda.device_count())}, "cpu": "30GB"},
)
tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
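# Optional sanity check: print(model.hf_device_map) shows where accelerate
# placed each module; entries on "cpu" or "disk" were offloaded and run slower.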
# Enable gradient checkpointing in case the model is ever fine-tuned
model.gradient_checkpointing_enable()

# Optionally compile for PyTorch >= 2.0; fall back to the eager model on failure
try:
    model = torch.compile(model)
except Exception:
    pass
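# Note: torch.compile support for 4-bit bitsandbytes layers varies across
# versions, which is why the broad except clause keeps the eager model.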
# === RESPONSE GENERATOR ===
@spaces.GPU()
def generate_response(user_message, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    # This function is a generator, so even the empty-input case must yield
    # once; a plain return value inside a generator never reaches Gradio.
    if not user_message.strip():
        yield history_state, history_state
        return
    # Prompt setup using phi-4 style chat tags
    system_message = (
        "Your role as an assistant involves thoroughly exploring questions through a systematic thinking process..."
    )
    start_tag, sep_tag, end_tag = "<|im_start|>", "<|im_sep|>", "<|im_end|>"
    prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
    for msg in history_state:
        prompt += f"{start_tag}{msg['role']}{sep_tag}{msg['content']}{end_tag}"
    prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"
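    # Assembled prompt shape (one user turn shown):
    #   <|im_start|>system<|im_sep|>...<|im_end|>
    #   <|im_start|>user<|im_sep|>Hi<|im_end|><|im_start|>assistant<|im_sep|>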
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Streaming setup: the streamer yields decoded text pieces as they are generated
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "max_new_tokens": int(max_tokens),
        "do_sample": True,
        "temperature": temperature,
        "top_k": int(top_k),
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }
    # Run generation in a background thread; model.generate blocks until done,
    # leaving the main thread free to drain the streamer below
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]
    # Stream tokens, stripping any chat tags that leak into the output
    for token in streamer:
        clean = token.replace(start_tag, "").replace(sep_tag, "").replace(end_tag, "")
        assistant_response += clean
        new_history[-1]["content"] = assistant_response
        yield new_history, new_history
    # Final yield commits the completed message
    yield new_history, new_history
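# Because generate_response is a generator, Gradio streams each yielded
# (chatbot, state) pair to the browser as it arrives.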
# === EXAMPLE MESSAGES ===
example_messages = {
    "Math reasoning": "If a rectangular prism has a length of 6 cm...",
    "Logic puzzle": "Four people (Alex, Blake, Casey, ...)",
    "Physics problem": "A ball is thrown upward with an initial velocity...",
}
# === GRADIO APP ===
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# Phi-4 Chat
Try the example problems below to see how the model breaks down complex reasoning.
""")
    history_state = gr.State([])
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            max_tokens_slider = gr.Slider(64, 32768, step=1024, value=2048, label="Max Tokens")
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(0.1, 2.0, value=0.8, label="Temperature")
                top_k_slider = gr.Slider(1, 100, step=1, value=50, label="Top-k")
                top_p_slider = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                repetition_penalty_slider = gr.Slider(1.0, 2.0, value=1.0, label="Repetition Penalty")
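                # A repetition penalty of 1.0 is a no-op; values slightly above
                # 1.0 (e.g. 1.05-1.2) discourage loops at some cost to fluency.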
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")
            with gr.Row():
                user_input = gr.Textbox(placeholder="Type your message...", scale=3)
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                for name in example_messages:
                    btn = gr.Button(name)
                    # Bind name as a default argument so each button keeps its
                    # own example instead of closing over the loop variable
                    btn.click(fn=lambda n=name: gr.update(value=example_messages[n]), inputs=None, outputs=user_input)
    submit_button.click(
        fn=generate_response,
        inputs=[user_input, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state],
    ).then(lambda: gr.update(value=""), None, user_input)
    clear_button.click(lambda: ([], []), None, [chatbot, history_state])
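# Recent Gradio versions queue requests by default, which is what lets the
# generator above stream partial responses; ssr_mode=False disables
# server-side rendering of the UI.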
demo.launch(ssr_mode=False)