File size: 1,963 Bytes
2c8d1ab
b06ae2b
2c8d1ab
7b5edcc
2c8d1ab
 
5cd7fd5
 
2c8d1ab
 
7b5edcc
 
705c0ca
0de4c24
b82cda0
 
 
 
705c0ca
0de4c24
b06ae2b
0de4c24
 
b06ae2b
b82cda0
0de4c24
2c8d1ab
 
5cd7fd5
2c8d1ab
 
 
705c0ca
2c8d1ab
b06ae2b
705c0ca
3eecdbe
 
2c8d1ab
 
705c0ca
2c8d1ab
 
 
 
705c0ca
2c8d1ab
b06ae2b
 
 
 
 
 
2c8d1ab
b06ae2b
 
 
2c8d1ab
705c0ca
b06ae2b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import os

# Display name (shown in the UI dropdown) -> Hugging Face repo ID.
# Both distilled models live under the deepseek-ai organization.
MODEL_NAMES = {
    display_name: f"deepseek-ai/{display_name}"
    for display_name in (
        "DeepSeek-R1-Distill-Qwen-7B",
        "DeepSeek-R1-Distill-Llama-8B",
    )
}

# Optional Hugging Face access token read from the environment
# (None when unset); forwarded to hub downloads for gated/private repos.
HF_TOKEN = os.getenv("HF_TOKEN")

def load_model(model_path):
    """Load a causal LM and its tokenizer from *model_path* in float16.

    Any quantization config baked into the hub checkpoint (e.g. FP8) is
    removed before loading so the weights materialize as plain fp16 tensors.

    Returns a ``(model, tokenizer)`` pair; the model is placed automatically
    across available devices via ``device_map="auto"``.
    """
    # Shared kwargs for every hub download in this function.
    hub_kwargs = {"trust_remote_code": True, "token": HF_TOKEN}

    tokenizer = AutoTokenizer.from_pretrained(model_path, **hub_kwargs)

    config = AutoConfig.from_pretrained(model_path, **hub_kwargs)
    if hasattr(config, "quantization_config"):
        # Delete the quantization config to avoid loading in FP8.
        del config.quantization_config

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        config=config,
        torch_dtype=torch.float16,
        device_map="auto",
        **hub_kwargs,
    )
    return model, tokenizer

current_model, current_tokenizer = load_model("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")

def chat(message, history, model_name):
    """Generate a reply to *message* with the currently selected model.

    Parameters:
        message: the user's latest message (plain text).
        history: chat history supplied by gr.ChatInterface (unused here;
            each turn is generated from *message* alone).
        model_name: display-name key into MODEL_NAMES from the dropdown.

    Returns the model's generated reply text (prompt tokens stripped).
    """
    global current_model, current_tokenizer

    # Bug fix: the old check compared the display-name *string* against the
    # loaded model *object* (always unequal), so the model was re-downloaded
    # and reloaded on every single message.  Track the loaded display name
    # on the function itself instead; the module loads Llama-8B at startup.
    loaded_name = getattr(chat, "_loaded_name", "DeepSeek-R1-Distill-Llama-8B")
    if model_name in MODEL_NAMES and model_name != loaded_name:
        current_model, current_tokenizer = load_model(MODEL_NAMES[model_name])
        chat._loaded_name = model_name

    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = current_tokenizer(message, return_tensors="pt").to(device)

    # max_new_tokens (not max_length) so a long prompt cannot consume the
    # entire budget and leave zero room for the reply; no_grad skips autograd
    # bookkeeping during pure inference.
    with torch.no_grad():
        outputs = current_model.generate(**inputs, max_new_tokens=1024)

    # generate() echoes the prompt tokens at the start of outputs[0];
    # decode only the newly generated suffix so the user is not shown
    # their own message again.
    prompt_len = inputs["input_ids"].shape[1]
    response = current_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    return response

with gr.Blocks() as app:
    gr.Markdown("## Chatbot with DeepSeek Models")

    with gr.Row():
        # Bug fix: the dropdown was previously created but never wired to
        # chat(); gr.ChatInterface would call chat(message, history) and
        # crash on the missing required model_name argument.  Creating the
        # dropdown first and passing it via additional_inputs makes the
        # selection reach chat() as its third parameter.
        model_selector = gr.Dropdown(
            choices=list(MODEL_NAMES.keys()),
            value="DeepSeek-R1-Distill-Llama-8B",
            label="Select Model",
        )
        chat_interface = gr.ChatInterface(
            chat,
            type="messages",
            flagging_mode="manual",
            save_history=True,
            additional_inputs=[model_selector],
        )


app.launch()