# app.py
# ===================== Imports =====================
from qwen_vl_utils import fetch_image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    TextIteratorStreamer,
)
import gradio as gr
import librosa
import torch
import numpy as np
import soundfile as sf
from threading import Thread
from copy import deepcopy
import os
# NEW: ZeroGPU requirement
import spaces
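# ZeroGPU note (summarized from the Spaces ZeroGPU docs; check the current docs
# for specifics): `spaces` should be imported before any CUDA-touching code so
# its patches apply, and every GPU-using function must be marked with
# @spaces.GPU, as predict() is below.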
# ===================== Load Model (lazy move to GPU) =====================
model_path = "FreedomIntelligence/ShizhenGPT-7B-Omni"
# Load the weights on CPU first; move them to the GPU only when inference is
# actually requested (the allocation itself is managed by @spaces.GPU)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,  # switch to torch.float16 if bfloat16 is unsupported
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(
    model_path,
    trust_remote_code=True
)
model.eval()
# Some checkpoints keep the chat_template on the tokenizer; copy it over for compatibility
if hasattr(processor, "tokenizer") and hasattr(processor.tokenizer, "chat_template"):
    processor.chat_template = processor.tokenizer.chat_template
# Flag: move the model to the GPU only on the first inference call
_MODEL_ON_CUDA = False
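# Note (hedged): ZeroGPU examples also show an eager pattern where
# model.to("cuda") is called at module load time and the framework defers the
# real allocation until a @spaces.GPU-decorated function runs; the flag above
# is an explicit lazy variant of the same idea.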
# ===================== Streaming Generation =====================
def generate_with_streaming(model, processor, text, images=None, audios=None, history=None):
    # Process images: filter out empty slots first, then prepend one vision
    # placeholder per image actually kept, so token count and image count match
    processed_images = None
    if images:
        processed_images = [
            fetch_image({"type": "image", "image": img, "max_pixels": 360 * 420})
            for img in images
            if img is not None
        ]
        text = "<|vision_start|><|image_pad|><|vision_end|>" * len(processed_images) + text
    # Process audios: same filtering, one audio placeholder per remaining clip
    processed_audios = None
    if audios:
        processed_audios = [audio for audio in audios if audio is not None]
        text = "<|audio_bos|><|AUDIO|><|audio_eos|>" * len(processed_audios) + text
    # Build conversation history
    messages = []
    if history:
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
    # Strip multimodal placeholder tokens from previous turns
    for m in messages:
        m["content"] = m["content"].replace("<|audio_bos|><|AUDIO|><|audio_eos|>", "").replace(
            "<|vision_start|><|image_pad|><|vision_end|>", ""
        )
    # Add current user input
    messages.append({"role": "user", "content": text})
    # Prepare model input
    templated = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) or ""
    input_data = processor(
        text=[templated],
        audios=processed_audios,
        images=processed_images,
        return_tensors="pt",
        padding=True,
    )
    # Move tensors to the current model device
    for k, v in input_data.items():
        if hasattr(v, "to"):
            input_data[k] = v.to(model.device)
    # Start streaming generation
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_special_tokens=True, skip_prompt=True
    )
    generation_kwargs = dict(
        **input_data,
        streamer=streamer,
        max_new_tokens=1500,
        do_sample=True,
        temperature=0.2,
        top_p=0.8,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Yield generated text chunks as they stream in
    for new_text in streamer:
        yield new_text
    thread.join()  # ensure generation has fully finished before returning
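# Example (sketch, not part of the app): the generator can also be consumed
# outside Gradio, assuming model and processor are loaded as above:
#   for chunk in generate_with_streaming(model, processor, "What is qi deficiency?"):
#       print(chunk, end="", flush=True)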
# ===================== Audio Preprocessing =====================
def process_audio(audio):
    """Resample recorded/uploaded audio to the rate the processor expects."""
    if audio is None:
        return None
    try:
        sr, y = audio
        if y.ndim > 1:
            y = y[:, 0]  # Keep the first channel only
        save_path = "./temp.wav"
        sf.write(save_path, y, sr)
        # Some processors have no feature_extractor; fall back to 16 kHz
        target_sr = getattr(
            getattr(processor, "feature_extractor", None), "sampling_rate", 16000
        )
        y_resampled, _ = librosa.load(save_path, sr=target_sr, mono=True)
        return y_resampled
    except Exception as e:
        print(f"Error processing audio: {e}")
        return None
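# Gradio's gr.Audio(type="numpy") passes audio as a (sample_rate, np.ndarray)
# tuple, so a quick check with hypothetical values looks like:
#   y = process_audio((16000, np.zeros(16000, dtype=np.float32)))  # ~1 s of silence
#   # y is a mono float array at the processor's target rate, or None on failure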
# ===================== Prediction Function for Gradio (ZeroGPU) =====================
# Key: the @spaces.GPU decorator is what lets ZeroGPU detect this as a GPU
# function; duration raises the maximum time a single call may hold the GPU
@spaces.GPU(duration=120)
def predict(message, image, audio, chatbox):
    global _MODEL_ON_CUDA, model
    # ZeroGPU only allocates a GPU once the decorated function is actually
    # called; migrate the model at that point
    if not _MODEL_ON_CUDA:
        model.to("cuda")
        _MODEL_ON_CUDA = True
    chat_history = deepcopy(chatbox)
    processed_audio = [process_audio(audio)] if audio is not None else None
    processed_image = [image] if image is not None else None
    chatbox.append([message, ""])
    response = ""
    # Stream the model response into the last chat turn
    for chunk in generate_with_streaming(model, processor, message, processed_image, processed_audio, chat_history):
        response += chunk
        chatbox[-1][1] = response
        # Yielding from the generator makes the Gradio Chatbot refresh in real time
        yield chatbox
    print("\n=== Complete Model Response ===")
    print(response)
    print("============================\n", flush=True)
    # Final yield guarantees the Chatbot shows the complete response even if the
    # stream produced no intermediate chunks
    yield chatbox
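# Assumption: duration=120 in @spaces.GPU above caps a single call's GPU
# occupancy at 120 s (ZeroGPU ends tasks that run past their limit), so size it
# to the longest expected generation rather than relying on the default.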
# ===================== CSS for UI =====================
css = """
.gradio-container {
    background-color: #f7f7f7;
    font-family: 'Arial', sans-serif;
}
.chat-message {
    padding: 15px;
    border-radius: 10px;
    margin-bottom: 10px;
}
.user-message {
    background-color: #e6f7ff;
    border-left: 5px solid #1890ff;
}
.bot-message {
    background-color: #f2f2f2;
    border-left: 5px solid #52c41a;
}
.title {
    text-align: center;
    color: #1890ff;
    font-size: 24px;
    margin-bottom: 20px;
}
"""
# ===================== Gradio App =====================
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1 class='title'>ShizhenGPT: Multimodal LLM for Traditional Chinese Medicine (Vision, Audio, Text)</h1>")
    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=500)
            message = gr.Textbox(label="Your Question", placeholder="Type your question here...")
            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear")
        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Upload Image")
            audio_input = gr.Audio(type="numpy", label="Record or Upload Audio")
    # Submit button: run prediction
    submit_btn.click(
        predict,
        inputs=[message, image_input, audio_input, chatbot],
        outputs=[chatbot],
        show_progress="full",
    ).then(
        lambda: "",  # Clear textbox after sending
        outputs=[message]
    )
    # Clear button: reset inputs and chat history
    clear_btn.click(
        lambda: ("", None, None, []),
        outputs=[message, image_input, audio_input, chatbot]
    )
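    # Note: the .then() above fires only after predict finishes streaming, so
    # the textbox clears once the full response has rendered.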
# ===================== Run App =====================
if __name__ == "__main__":
    # On ZeroGPU, cap concurrency so parallel calls don't repeatedly request GPUs
    demo.queue(default_concurrency_limit=1).launch(
        server_name="0.0.0.0", server_port=7860, share=True
    )