# app.py
# ===================== Imports =====================
from qwen_vl_utils import fetch_image
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    TextIteratorStreamer,
)
import gradio as gr
import librosa
import torch
import numpy as np
import soundfile as sf
from threading import Thread
from copy import deepcopy
import os

# ZeroGPU requirement: provides the @spaces.GPU decorator
import spaces


# ===================== Load Model (lazy move to GPU) =====================
model_path = "FreedomIntelligence/ShizhenGPT-7B-Omni"

# Load the weights on CPU first; migrate to the GPU only when inference actually runs (managed by @spaces.GPU)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,   # fall back to torch.float16 if bfloat16 is unsupported
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(
    model_path,
    trust_remote_code=True
)

model.eval()
# Some checkpoints keep the chat_template on the tokenizer; copy it over for compatibility
if hasattr(processor, "tokenizer") and hasattr(processor.tokenizer, "chat_template"):
    processor.chat_template = processor.tokenizer.chat_template

# Flag: migrate the model to the GPU only on the first inference call
_MODEL_ON_CUDA = False


# ===================== Streaming Generation =====================
def generate_with_streaming(model, processor, text, images=None, audios=None, history=None):
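    """Stream a chat response from the model.

    Prepends the multimodal placeholder tokens the processor expects, folds the
    (user, assistant) history into a chat-template prompt, launches generation
    on a worker thread, and yields decoded text chunks as they arrive.
    """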
    # Process images: filter out Nones first so the number of placeholder
    # tokens always matches the number of image tensors handed to the processor
    processed_images = None
    if images:
        processed_images = [
            fetch_image({"type": "image", "image": img, "max_pixels": 360 * 420})
            for img in images
            if img is not None
        ] or None
        if processed_images:
            text = "".join(["<|vision_start|><|image_pad|><|vision_end|>"] * len(processed_images)) + text

    # Process audios: same filtering rule as for images
    processed_audios = None
    if audios:
        processed_audios = [audio for audio in audios if audio is not None] or None
        if processed_audios:
            text = "".join(["<|audio_bos|><|AUDIO|><|audio_eos|>"] * len(processed_audios)) + text

    # Build conversation history
    messages = []
    if history:
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})

    # Clean multimodal tokens from previous history
    for m in messages:
        m["content"] = m["content"].replace("<|audio_bos|><|AUDIO|><|audio_eos|>", "").replace(
            "<|vision_start|><|image_pad|><|vision_end|>", ""
        )

    # Add current user input
    messages.append({"role": "user", "content": text})

    # Prepare model input
    templated = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) or ""
    input_data = processor(
        text=[templated],
        audios=processed_audios,
        images=processed_images,
        return_tensors="pt",
        padding=True,
    )

    # Move tensors to the current model device
    for k, v in input_data.items():
        if hasattr(v, "to"):
            input_data[k] = v.to(model.device)

    # Start streaming generation
    streamer = TextIteratorStreamer(
        processor.tokenizer, skip_special_tokens=True, skip_prompt=True
    )
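    # skip_prompt drops the echoed input tokens; skip_special_tokens strips control tokens from the streamed text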
    generation_kwargs = dict(
        **input_data,
        streamer=streamer,
        max_new_tokens=1500,
        do_sample=True,
        temperature=0.2,
        top_p=0.8,
    )

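    # model.generate blocks until completion, so run it on a worker thread and consume the streamer here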
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield generated tokens in stream
    for new_text in streamer:
        yield new_text


# ===================== Audio Preprocessing =====================
def process_audio(audio):
    """Convert recorded/loaded audio into required sampling rate."""
    if audio is None:
        return None
    try:
        sr, y = audio
        if y.ndim > 1:
            y = y[:, 0]  # Keep the first channel only

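        # Round-trip through a temp WAV so librosa can resample uniformly regardless of the input dtype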
        save_path = "./temp.wav"
        sf.write(save_path, y, sr)

        # Some processors lack a feature_extractor; fall back to 16 kHz
        target_sr = getattr(
            getattr(processor, "feature_extractor", None), "sampling_rate", 16000
        )
        y_resampled, _ = librosa.load(save_path, sr=target_sr, mono=True)
        return y_resampled
    except Exception as e:
        print(f"Error processing audio: {e}")
        return None


# ===================== Prediction Function for Gradio (ZeroGPU) =====================
@spaces.GPU(duration=120)  # Required so ZeroGPU detects this as a GPU function; duration raises the per-call GPU time limit
def predict(message, image, audio, chatbox):
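    """Gradio handler: stream the model's reply into the Chatbot.

    Runs under @spaces.GPU, so a GPU is attached only for the duration of the
    call; the model is migrated to CUDA lazily on the first invocation.
    """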
    global _MODEL_ON_CUDA, model

    # ZeroGPU allocates a GPU only when the decorated function is first called; migrate the model then
    if not _MODEL_ON_CUDA:
        model.to("cuda")
        _MODEL_ON_CUDA = True

    chat_history = deepcopy(chatbox)
    processed_audio = [process_audio(audio)] if audio is not None else None
    processed_image = [image] if image is not None else None

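    # Append a placeholder turn; it is filled in as tokens stream back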
    chatbox.append([message, ""])
    response = ""

    # Streaming model response
    for chunk in generate_with_streaming(model, processor, message, processed_image, processed_audio, chat_history):
        response += chunk
        chatbox[-1][1] = response
        # Yielding turns this handler into a generator, so the Gradio Chatbot refreshes in real time
        yield chatbox

    print("\n=== Complete Model Response ===")
    print(response)
    print("============================\n", flush=True)

    return chatbox


# ===================== CSS for UI =====================
css = """
.gradio-container {
    background-color: #f7f7f7;
    font-family: 'Arial', sans-serif;
}
.chat-message {
    padding: 15px;
    border-radius: 10px;
    margin-bottom: 10px;
}
.user-message {
    background-color: #e6f7ff;
    border-left: 5px solid #1890ff;
}
.bot-message {
    background-color: #f2f2f2;
    border-left: 5px solid #52c41a;
}
.title {
    text-align: center;
    color: #1890ff;
    font-size: 24px;
    margin-bottom: 20px;
}
"""


# ===================== Gradio App =====================
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1 class='title'>ShizhenGPT: Multimodal LLM for Traditional Chinese Medicine (Vision, Audio, Text)</h1>")

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(height=500)
            message = gr.Textbox(label="Your Question", placeholder="Type your question here...")

            with gr.Row():
                submit_btn = gr.Button("Submit", variant="primary")
                clear_btn = gr.Button("Clear")

        with gr.Column(scale=1):
            image_input = gr.Image(type="filepath", label="Upload Image")
            audio_input = gr.Audio(type="numpy", label="Record or Upload Audio")
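            # type="filepath" hands predict a path string (which fetch_image can open);
            # type="numpy" delivers the (sample_rate, ndarray) tuple process_audio expects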

    # Submit button: run prediction
    submit_btn.click(
        predict,
        inputs=[message, image_input, audio_input, chatbot],
        outputs=[chatbot],
        show_progress=True
    ).then(
        lambda: "",  # Clear textbox after sending
        outputs=[message]
    )

    # Clear button: reset inputs and chat history
    clear_btn.click(
        lambda: (None, None, None, []),
        outputs=[message, image_input, audio_input, chatbot]
    )


# ===================== Run App =====================
if __name__ == "__main__":
    # Under ZeroGPU, limit concurrency to avoid repeatedly requesting a GPU
    demo.queue().launch(
        server_name="0.0.0.0", server_port=7860, share=True
    )