# app.py
# ===================== Imports =====================
from qwen_vl_utils import fetch_image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    TextIteratorStreamer,
)
import gradio as gr
import librosa
import torch
import numpy as np
import soundfile as sf
from threading import Thread
from copy import deepcopy
import os
# NEW: ZeroGPU requirement
import spaces
# ===================== Load Model (lazy move to GPU) =====================
model_path = "FreedomIntelligence/ShizhenGPT-7B-Omni"
# Load the weights on CPU first; move them to the GPU only when inference
# actually happens (managed by @spaces.GPU).
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,  # switch to torch.float16 if bfloat16 is unsupported
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(
    model_path,
    trust_remote_code=True
)
model.eval()
# Some checkpoints keep the chat_template on the tokenizer; copy it over for compatibility.
if hasattr(processor, "tokenizer") and hasattr(processor.tokenizer, "chat_template"):
    processor.chat_template = processor.tokenizer.chat_template
# Flag: move the model to the GPU only on the first inference call.
_MODEL_ON_CUDA = False
# ===================== Streaming Generation =====================
def generate_with_streaming(model, processor, text, images=None, audios=None, history=None):
    # Process images
    processed_images = None
    if images:
        text = "".join(["<|vision_start|><|image_pad|><|vision_end|>"] * len(images)) + text
        processed_images = [
            fetch_image({"type": "image", "image": img, "max_pixels": 360 * 420})
            for img in images
            if img is not None
        ]
    # Process audios
    processed_audios = None
    if audios:
        text = "".join(["<|audio_bos|><|AUDIO|><|audio_eos|>"] * len(audios)) + text
        processed_audios = [audio for audio in audios if audio is not None]
    # Build conversation history
    messages = []
    if history:
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
    # Clean multimodal tokens from previous history
    for m in messages:
        m["content"] = m["content"].replace("<|audio_bos|><|AUDIO|><|audio_eos|>", "").replace(
            "<|vision_start|><|image_pad|><|vision_end|>", ""
        )
    # Add current user input
    messages.append({"role": "user", "content": text})
    # Prepare model input
    templated = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) or ""
    input_data = processor(
        text=[templated],
        audios=processed_audios,
        images=processed_images,
        return_tensors="pt",
        padding=True,
    )
    # Move tensors to the current model device
    for k, v in input_data.items():
        if hasattr(v, "to"):
            input_data[k] = v.to(model.device)
    # Start streaming generation
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(
        **input_data,
        streamer=streamer,
        max_new_tokens=1500,
        do_sample=True,
        temperature=0.2,
        top_p=0.8,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Yield generated tokens in stream
    for new_text in streamer:
        yield new_text
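# Usage sketch (hypothetical, not executed by the app): the streamer can be driven
# directly from Python, assuming the model is already on a CUDA device and a local
# "tongue.jpg" exists:
#
#     for piece in generate_with_streaming(model, processor,
#                                          "What does this tongue image suggest?",
#                                          images=["tongue.jpg"]):
#         print(piece, end="", flush=True)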
# ===================== Audio Preprocessing =====================
def process_audio(audio):
    """Convert recorded/loaded audio into the required sampling rate."""
    if audio is None:
        return None
    try:
        sr, y = audio
        if y.ndim > 1:
            y = y[:, 0]  # Keep the first channel only
        save_path = "./temp.wav"
        sf.write(save_path, y, sr)
        # Some processors have no feature_extractor; fall back to 16 kHz.
        target_sr = getattr(
            getattr(processor, "feature_extractor", None), "sampling_rate", 16000
        )
        y_resampled, _ = librosa.load(save_path, sr=target_sr, mono=True)
        return y_resampled
    except Exception as e:
        print(f"Error processing audio: {e}")
        return None
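# Illustrative sketch (assumed values, not part of the app flow): gr.Audio(type="numpy")
# hands predict() a (sample_rate, waveform) tuple, which this helper resamples, e.g.:
#
#     dummy = (44100, np.zeros(44100, dtype=np.float32))  # one second of silence
#     wave = process_audio(dummy)  # mono waveform at the processor's sampling rate (16 kHz by default)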
# ===================== Prediction Function for Gradio (ZeroGPU) =====================
@spaces.GPU(duration=120)  # Key: lets ZeroGPU detect this as a GPU function and extends the per-call time budget.
def predict(message, image, audio, chatbox):
    global _MODEL_ON_CUDA, model
    # ZeroGPU only allocates a GPU when the decorated function is invoked;
    # move the model to CUDA on the first call.
    if not _MODEL_ON_CUDA:
        model.to("cuda")
        _MODEL_ON_CUDA = True
    chat_history = deepcopy(chatbox)
    processed_audio = [process_audio(audio)] if audio is not None else None
    processed_image = [image] if image is not None else None
    chatbox.append([message, ""])
    response = ""
    # Streaming model response
    for chunk in generate_with_streaming(model, processor, message, processed_image, processed_audio, chat_history):
        response += chunk
        chatbox[-1][1] = response
        # Yield as a generator so the Gradio Chatbot refreshes in real time.
        yield chatbox
    print("\n=== Complete Model Response ===")
    print(response)
    print("============================\n", flush=True)
    return chatbox
# ===================== CSS for UI =====================
css = """
.gradio-container {
background-color: #f7f7f7;
font-family: 'Arial', sans-serif;
}
.chat-message {
padding: 15px;
border-radius: 10px;
margin-bottom: 10px;
}
.user-message {
background-color: #e6f7ff;
border-left: 5px solid #1890ff;
}
.bot-message {
background-color: #f2f2f2;
border-left: 5px solid #52c41a;
}
.title {
text-align: center;
color: #1890ff;
font-size: 24px;
margin-bottom: 20px;
}
"""
# ===================== Gradio App =====================
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1 class='title'>ShizhenGPT: Multimodal LLM for Traditional Chinese Medicine (Vision, Audio, Text)</h1>")
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=500)
            message = gr.Textbox(label="Your Question", placeholder="Type your question here...")
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear")
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Upload Image")
            audio_input = gr.Audio(type="numpy", label="Record or Upload Audio")
    # Submit button: run prediction
    submit_btn.click(
        predict,
        inputs=[message, image_input, audio_input, chatbot],
        outputs=[chatbot],
        show_progress=True
    ).then(
        lambda: "",  # Clear textbox after sending
        outputs=[message]
    )
    # Clear button: reset inputs and chat history
    clear_btn.click(
        lambda: (None, None, None, []),
        outputs=[message, image_input, audio_input, chatbot]
    )
# ===================== Run App =====================
if __name__ == "__main__":
    # Under ZeroGPU, it is advisable to limit concurrency to avoid requesting the GPU repeatedly.
    demo.queue().launch(
        server_name="0.0.0.0", server_port=7860, share=True
    )