import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig  
import torch
import os

# Pre-defined Hugging Face models
MODEL_NAMES = {
    "DeepSeek-R1-Distill-Qwen-7B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    "DeepSeek-R1-Distill-Llama-8B": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
}

HF_TOKEN = os.getenv("HF_TOKEN")

def load_model(model_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, token=HF_TOKEN)
    
    # Load the config first and manually drop the quantization settings to avoid FP8 issues
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, token=HF_TOKEN)
    if hasattr(config, "quantization_config"):
        del config.quantization_config  # remove the quantization config so FP8 is not used

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        config=config,  # use the config with quantization removed
        trust_remote_code=True,
        token=HF_TOKEN,
        torch_dtype=torch.float16,  # force FP16 to avoid FP8
        device_map="auto",
    )
    return model, tokenizer
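
# Assumption, not in the original: the R1 distill checkpoints are chat-tuned,
# so passing a raw string bypasses their chat template. A minimal sketch of
# building a templated prompt instead; `build_inputs` is a hypothetical helper
# that the app below does not call, and it assumes Gradio's tuple-style history.
def build_inputs(tokenizer, message, history, device):
    messages = []
    for user_msg, bot_msg in history:  # each entry is a (user, assistant) pair
        messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})
    # apply_chat_template wraps each turn in the model's special tokens
    return tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)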


# Load DeepSeek-R1 by default, tracking the display name so chat() can tell
# when the dropdown selection has changed
current_model_name = "DeepSeek-R1-Distill-Llama-8B"
current_model, current_tokenizer = load_model(MODEL_NAMES[current_model_name])


def chat(message, history, model_name):
    """Handle a chat message."""
    global current_model, current_tokenizer, current_model_name

    # Switch models if a different one was selected in the dropdown
    if model_name != current_model_name:
        current_model, current_tokenizer = load_model(MODEL_NAMES[model_name])
        current_model_name = model_name

    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = current_tokenizer(message, return_tensors="pt").to(device)
    outputs = current_model.generate(**inputs, max_new_tokens=1024)
    # Decode only the newly generated tokens, not the echoed prompt
    response = current_tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )

    return response
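
# The original code's `streaming=True` flag suggests streamed output was the
# intent. A minimal sketch of how that could look with transformers'
# TextIteratorStreamer: the handler becomes a generator, and gr.ChatInterface
# streams whatever it yields. `chat_streaming` is a hypothetical alternative,
# not wired into the app below.
def chat_streaming(message, history, model_name):
    from threading import Thread
    from transformers import TextIteratorStreamer

    global current_model, current_tokenizer, current_model_name
    if model_name != current_model_name:
        current_model, current_tokenizer = load_model(MODEL_NAMES[model_name])
        current_model_name = model_name

    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = current_tokenizer(message, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        current_tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # generate() blocks, so run it in a background thread and yield as text arrives
    thread = Thread(
        target=current_model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=1024),
    )
    thread.start()

    partial = ""
    for piece in streamer:
        partial += piece
        yield partial  # ChatInterface re-renders the growing reply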


with gr.Blocks() as app:
    gr.Markdown("## Chatbot with DeepSeek Models")

    model_selector = gr.Dropdown(
        choices=list(MODEL_NAMES.keys()),
        value="DeepSeek-R1-Distill-Llama-8B",
        label="Select Model",
    )
    # additional_inputs delivers the dropdown value to chat() as its third argument
    gr.ChatInterface(chat, additional_inputs=[model_selector], save_history=True)

app.launch()