import huggingface_hub as hf_hub
import time
import openvino_genai as ov_genai
import numpy as np
import gradio as gr
import re
import threading

# Download the models
model_ids = [
    "OpenVINO/Qwen3-0.6B-int4-ov",
    "OpenVINO/Qwen3-1.7B-int4-ov",
    #"OpenVINO/Qwen3-4B-int4-ov",  #不可用
    "OpenVINO/Qwen3-8B-int4-ov",
    "OpenVINO/Qwen3-14B-int4-ov",
]

model_name_to_full_id = {model_id.split("/")[-1]: model_id for model_id in model_ids}  # map short model names to full repo IDs

def download_model(model_id):
    model_path = model_id.split("/")[-1]  # Extract model name
    try:
        hf_hub.snapshot_download(model_id, local_dir=model_path, local_dir_use_symlinks=False)
        print(f"Successfully downloaded {model_id} to {model_path}")
        # Check that the model files are complete (specific file checks could be added here)
        # e.g. verify that the required files exist, or validate file sizes
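        # Hedged sketch of such a completeness check, assuming the usual
        # OpenVINO export layout (openvino_model.xml / .bin); the file names
        # below are illustrative and should be adjusted to the actual snapshot.
        from pathlib import Path  # local import keeps this sketch self-contained
        expected_files = ["openvino_model.xml", "openvino_model.bin"]
        missing = [f for f in expected_files if not (Path(model_path) / f).exists()]
        if missing:
            print(f"Warning: {model_id} may be incomplete, missing files: {missing}")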
        return True
    except Exception as e:
        print(f"Error downloading {model_id}: {e}")
        return False

# Download every model in the list
for model_id in model_ids:
    if not download_model(model_id):
        print(f"Failed to download {model_id}, skipping.")

# Set up the inference pipeline (initialize with a default model first)
device = "CPU"
default_model_name = "Qwen3-0.6B-int4-ov"  # Choose a default model

# Global state: the inference pipeline, tokenizer, name of the currently loaded model,
# the Markdown component, and the accumulated streamed text
pipe = None
tokenizer = None
current_model_name = None
accumulated_text = ""

# Markdown component; declared globally here and created inside the Blocks context below
markdown_component = None

# Build the Gradio interface
model_choices = list(model_name_to_full_id.keys())

# Streamer callback invoked by openvino_genai for every generated subword
def streamer(subword):
    global accumulated_text
    accumulated_text += subword
    print(subword, end='', flush=True)  # also echo each subword to the console
    # Return False so generation continues; the accumulated text is read back
    # by generate_response once generation has finished.
    return False

# Load a model into the global pipeline
def load_model(model_name):
    global pipe, tokenizer, current_model_name
    model_path = model_name  # each model was downloaded into a directory named after it
    print(f"Loading model: {model_name}")
    try:
        pipe = ov_genai.LLMPipeline(model_path, device)
        tokenizer = pipe.get_tokenizer()
        tokenizer.set_chat_template(tokenizer.chat_template)  # make sure the chat template is applied
        current_model_name = model_name
        print(f"Model {model_name} loaded successfully.")
        return True
    except Exception as e:
        print(f"Error loading model {model_name}: {e}")
        return False

# Generate a response for the given prompt
def generate_response(prompt, model_name):
    global pipe, tokenizer, accumulated_text

    # Load the model if none is loaded yet, or if the user switched to a different one
    if pipe is None or current_model_name != model_name:
        if not load_model(model_name):
            return "Model loading failed", "Model loading failed"

    accumulated_text = ""  # reset the accumulated text

    try:
        generated = pipe.generate(prompt, streamer=streamer, max_new_tokens=100)
        tokenpersec = f'{generated.perf_metrics.get_throughput().mean:.2f}'
        return tokenpersec, accumulated_text
    except Exception as e:
        error_message = f"Error while generating the response: {e}"
        print(error_message)
        return "Error", error_message

with gr.Blocks() as demo:
    markdown_component = gr.Markdown(label="Response")  # created inside the Blocks context
    with gr.Row():
        prompt_textbox = gr.Textbox(lines=5, label="Prompt")
        model_dropdown = gr.Dropdown(choices=model_choices, value=default_model_name, label="Select model")
    with gr.Row():
        token_per_sec_textbox = gr.Textbox(label="tokens/sec")

    def process_input(prompt, model_name):
        tokens_sec, response = generate_response(prompt, model_name)
        return tokens_sec, response

    prompt_textbox.submit(
        fn=process_input,
        inputs=[prompt_textbox, model_dropdown],
        outputs=[token_per_sec_textbox, markdown_component]
    )

demo.launch()