import huggingface_hub as hf_hub
import time
import openvino_genai as ov_genai
import numpy as np
import gradio as gr
import re
import threading

# Download the models
model_ids = [
    "OpenVINO/Qwen3-0.6B-int4-ov",
    "OpenVINO/Qwen3-1.7B-int4-ov",
    # "OpenVINO/Qwen3-4B-int4-ov",  # not usable
    "OpenVINO/Qwen3-8B-int4-ov",
    "OpenVINO/Qwen3-14B-int4-ov",
]

model_name_to_full_id = {model_id.split("/")[-1]: model_id for model_id in model_ids}  # map model name -> full repo id
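# e.g. {"Qwen3-0.6B-int4-ov": "OpenVINO/Qwen3-0.6B-int4-ov", ...}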

for model_id in model_ids:
    model_path = model_id.split("/")[-1]  # Extract model name
    try:
        hf_hub.snapshot_download(model_id, local_dir=model_path, local_dir_use_symlinks=False)
        print(f"Successfully downloaded {model_id} to {model_path}")  # Optional: print confirmation
    except Exception as e:
        print(f"Error downloading {model_id}: {e}")  # Handle download errors gracefully
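
# Note: each local_dir is expected to hold the OpenVINO IR and tokenizer files
# (e.g. openvino_model.xml/.bin) that ov_genai.LLMPipeline loads below; this is
# the assumed layout of the OpenVINO/*-int4-ov repositories.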

# Build the inference pipeline (initialize with a default model first)
device = "CPU"
default_model_name = "Qwen3-0.6B-int4-ov"  # Choose a default model
# Global state: inference pipeline, tokenizer, Markdown component, and accumulated streamed text
pipe = None
tokenizer = None
markdown_component = None  # placeholder; no Markdown component is created in the UI below
accumulated_text = ""


# Synchronously update the Markdown component (if one exists)
def update_markdown(text):
    global markdown_component
    if markdown_component:
        markdown_component.update(value=text)

# Streamer callback invoked by the pipeline for each generated subword (keeps the original structure)
def streamer(subword):
    global accumulated_text
    accumulated_text += subword
    print(subword, end='', flush=True)  # also print to the console
    # Update the Markdown component asynchronously on a worker thread
    threading.Thread(target=update_markdown, args=(accumulated_text,)).start()
    return ov_genai.StreamingStatus.RUNNING  # tell the pipeline to keep generating


def generate_response(prompt, model_name):
    global pipe, tokenizer, accumulated_text  # Access the global variables

    model_path = model_name
    accumulated_text = ""  # reset the streamed text for this request

    print(f"Switching to model: {model_name}")
    pipe = ov_genai.LLMPipeline(model_path, device)
    tokenizer = pipe.get_tokenizer()
    tokenizer.set_chat_template(tokenizer.chat_template)

    try:
        # generated = pipe.generate([prompt], max_length=1024)
        generated = pipe.generate(prompt, streamer=streamer, max_new_tokens=100)
        tokenpersec = f'{generated.perf_metrics.get_throughput().mean:.2f}'

        return tokenpersec, generated
    except Exception as e:
        # The interface has two outputs, so return exactly two values on error as well
        return "發生錯誤", f"生成回應時發生錯誤:{e}"

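# Standalone usage sketch (without the Gradio UI), assuming the models above have
# already been downloaded into the working directory; the prompt is just an example:
#   tps, answer = generate_response("What is OpenVINO?", "Qwen3-0.6B-int4-ov")
#   print(tps, answer)
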
# Build the Gradio interface
model_choices = list(model_name_to_full_id.keys())

demo = gr.Interface(
    fn=generate_response,
    inputs=[
        gr.Textbox(lines=5, label="輸入提示 (Prompt)"),
        gr.Dropdown(choices=model_choices, value=default_model_name, label="選擇模型") # Added dropdown
    ],
    outputs=[
        gr.Textbox(label="tokens/sec"),
        gr.Textbox(label="回應"),
    ],
    title="Qwen3 Model Inference",
    description="基於 Qwen3 推理應用,支援思考過程分離與 GUI。"
)

if __name__ == "__main__":
    demo.launch()
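    # Optional (sketch): launch() also accepts parameters such as
    # server_name="0.0.0.0" or share=True to expose the app beyond localhost, e.g.
    # demo.launch(server_name="0.0.0.0")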