# app.py

# ===================== Imports =====================
from qwen_vl_utils import fetch_image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    TextIteratorStreamer,
)
import gradio as gr
import librosa
import torch
import numpy as np
import soundfile as sf
from threading import Thread
from copy import deepcopy
import os

# NEW: ZeroGPU requirement
import spaces

# ===================== Load Model (lazy move to GPU) =====================
model_path = "FreedomIntelligence/ShizhenGPT-7B-Omni"

# Load the weights on CPU first; migrate to the GPU only when inference
# actually runs (managed by @spaces.GPU).
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,  # Switch to torch.float16 if bfloat16 is unsupported
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(
    model_path,
    trust_remote_code=True,
)
model.eval()

# Some checkpoints keep the chat_template on the tokenizer; copy it over for compatibility
if hasattr(processor, "tokenizer") and hasattr(processor.tokenizer, "chat_template"):
    processor.chat_template = processor.tokenizer.chat_template

# Flag: move the model to the GPU only on the first inference call
_MODEL_ON_CUDA = False


# ===================== Streaming Generation =====================
def generate_with_streaming(model, processor, text, images=None, audios=None, history=None):
    # Process images: prepend one vision placeholder per image, then fetch/resize
    processed_images = None
    if images:
        text = "".join(["<|vision_start|><|image_pad|><|vision_end|>"] * len(images)) + text
        processed_images = [
            fetch_image({"type": "image", "image": img, "max_pixels": 360 * 420})
            for img in images if img is not None
        ]

    # Process audios: prepend one audio placeholder per clip
    processed_audios = None
    if audios:
        text = "".join(["<|audio_bos|><|AUDIO|><|audio_eos|>"] * len(audios)) + text
        processed_audios = [audio for audio in audios if audio is not None]

    # Build conversation history
    messages = []
    if history:
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        # Clean multimodal tokens from previous history
        for m in messages:
            m["content"] = m["content"].replace("<|audio_bos|><|AUDIO|><|audio_eos|>", "").replace(
                "<|vision_start|><|image_pad|><|vision_end|>", ""
            )

    # Add current user input
    messages.append({"role": "user", "content": text})

    # Prepare model input
    templated = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) or ""
    input_data = processor(
        text=[templated],
        audios=processed_audios,
        images=processed_images,
        return_tensors="pt",
        padding=True,
    )
    # Move tensors to the current model device
    for k, v in input_data.items():
        if hasattr(v, "to"):
            input_data[k] = v.to(model.device)

    # Start streaming generation in a background thread
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,
    )
    generation_kwargs = dict(
        **input_data,
        streamer=streamer,
        max_new_tokens=1500,
        do_sample=True,
        temperature=0.2,
        top_p=0.8,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield generated text chunks as they arrive
    for new_text in streamer:
        yield new_text


# ===================== Audio Preprocessing =====================
def process_audio(audio):
    """Convert recorded/loaded audio to the sampling rate the processor expects."""
    if audio is None:
        return None
    try:
        sr, y = audio
        if y.ndim > 1:
            y = y[:, 0]  # Keep the first channel only
        save_path = "./temp.wav"
        sf.write(save_path, y, sr)
        # Some processors have no feature_extractor; fall back to 16 kHz
        target_sr = getattr(
            getattr(processor, "feature_extractor", None),
            "sampling_rate",
            16000,
        )
        y_resampled, _ = librosa.load(save_path, sr=target_sr, mono=True)
        return y_resampled
    except Exception as e:
        print(f"Error processing audio: {e}")
        return None
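
# Standalone usage sketch (hypothetical, outside the Gradio UI; assumes the
# model already sits on a suitable device). generate_with_streaming is a
# generator of text chunks, so a full reply is just their concatenation:
#
#   reply = "".join(
#       generate_with_streaming(model, processor, "请简述气虚的常见症状。", history=[])
#   )
#   print(reply)
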
# ===================== Prediction Function for Gradio (ZeroGPU) =====================
@spaces.GPU(duration=120)  # Key: lets ZeroGPU detect this GPU function, and lengthens the maximum GPU hold for a single call
def predict(message, image, audio, chatbox):
    global _MODEL_ON_CUDA, model
    # ZeroGPU only allocates a GPU once the decorated function is first called;
    # migrate the model at that point.
    if not _MODEL_ON_CUDA:
        model.to("cuda")
        _MODEL_ON_CUDA = True

    chat_history = deepcopy(chatbox)
    processed_audio = [process_audio(audio)] if audio is not None else None
    processed_image = [image] if image is not None else None

    chatbox.append([message, ""])
    response = ""
    # Stream the model response
    for chunk in generate_with_streaming(model, processor, message, processed_image, processed_audio, chat_history):
        response += chunk
        chatbox[-1][1] = response
        # Yielding makes this a generator, so the Gradio Chatbot refreshes live
        yield chatbox

    print("\n=== Complete Model Response ===")
    print(response)
    print("============================\n", flush=True)
    return chatbox


# ===================== CSS for UI =====================
css = """
.gradio-container {
    background-color: #f7f7f7;
    font-family: 'Arial', sans-serif;
}
.chat-message {
    padding: 15px;
    border-radius: 10px;
    margin-bottom: 10px;
}
.user-message {
    background-color: #e6f7ff;
    border-left: 5px solid #1890ff;
}
.bot-message {
    background-color: #f2f2f2;
    border-left: 5px solid #52c41a;
}
.title {
    text-align: center;
    color: #1890ff;
    font-size: 24px;
    margin-bottom: 20px;
}
"""

# ===================== Gradio App =====================
with gr.Blocks(css=css) as demo:
    gr.HTML(
        '<div class="title">ShizhenGPT: Multimodal LLM for Traditional Chinese Medicine (Vision, Audio, Text)</div>'
    )
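    # Note: the Chatbot state here is a list of [user, assistant] string pairs
    # (the classic "tuples" format), e.g. [["你好", "您好,请问有什么可以帮您?"]];
    # predict() appends to and mutates this structure while streaming.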
") with gr.Row(): with gr.Column(scale=2): chatbot = gr.Chatbot(height=500) message = gr.Textbox(label="Your Question", placeholder="Type your question here...") with gr.Row(): submit_btn = gr.Button("Submit", variant="primary") clear_btn = gr.Button("Clear") with gr.Column(scale=1): image_input = gr.Image(type="filepath", label="Upload Image") audio_input = gr.Audio(type="numpy", label="Record or Upload Audio") # examples_list = [ # ["段某,男,49岁。平素性急,时而头晕,有高血压史。今日中午突然昏仆,不省人事,牙关紧闭,两手握固,肢体强痉,面赤身热,苔黄腻,脉弦滑而数。请你分析病情,并给出可能的医治方案。", None, None, []], # ["请分析图中所示舌象,并告诉我哪种舌苔表现与之相符?", "examples/images.JPG", None, []], # 确保该图片路径存在 # ] # gr.Examples( # examples=examples_list, # inputs=[message, image_input, audio_input, chatbot], # === FIX: 绑定真实 inputs # label="Examples", # cache_examples=False, # 可避免在某些环境下的缓存报错 # ) # Submit button: run prediction submit_btn.click( predict, inputs=[message, image_input, audio_input, chatbot], outputs=[chatbot], show_progress=True ).then( lambda: "", # Clear textbox after sending outputs=[message] ) # Clear button: reset inputs and chat history clear_btn.click( lambda: (None, None, None, []), outputs=[message, image_input, audio_input, chatbot] ) # ===================== Run App ===================== if __name__ == "__main__": # ZeroGPU 下建议限制并发,避免重复申请 GPU demo.queue().launch( server_name="0.0.0.0", server_port=7860, share=True )