import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import os
# Predefined Hugging Face models
MODEL_NAMES = {
    "DeepSeek-R1-Distill-Qwen-7B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    "DeepSeek-R1-Distill-Llama-8B": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
}
HF_TOKEN = os.getenv("HF_TOKEN")
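# Note: HF_TOKEN is optional for these public checkpoints, but gated or private
# repos require a valid token (set it as an environment variable or Space secret).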
def load_model(model_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, token=HF_TOKEN)
    # Load the config first and manually drop its quantization settings to avoid FP8 issues
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, token=HF_TOKEN)
    if hasattr(config, "quantization_config"):
        del config.quantization_config  # remove the quantization config so FP8 is not used
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        config=config,  # use the config with quantization removed
        trust_remote_code=True,
        token=HF_TOKEN,
        torch_dtype=torch.float16,  # force FP16 to avoid FP8
        device_map="auto",
    )
    return model, tokenizer
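# Note: device_map="auto" relies on the `accelerate` package to place the model's
# weights across the available GPU(s) and CPU automatically.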
# Load DeepSeek-R1-Distill-Llama-8B by default and remember which repo is active
current_model_name = MODEL_NAMES["DeepSeek-R1-Distill-Llama-8B"]
current_model, current_tokenizer = load_model(current_model_name)
def chat(message, history, model_name):
    """Handle a chat message, switching models when the dropdown selection changes."""
    global current_model, current_tokenizer, current_model_name
    # Resolve the dropdown's display name to a Hugging Face repo id, then reload
    # only if a different model was selected; comparing the name against the
    # model object itself would never match and would reload on every call
    model_path = MODEL_NAMES.get(model_name, model_name)
    if model_path != current_model_name:
        current_model, current_tokenizer = load_model(model_path)
        current_model_name = model_path
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # `history` is received from ChatInterface but is not folded into the prompt here
    inputs = current_tokenizer(message, return_tensors="pt").to(device)
    outputs = current_model.generate(**inputs, max_new_tokens=1024)  # cap new tokens, not total length
    response = current_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response
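# A minimal streaming sketch (assumptions: gr.ChatInterface streams automatically
# when its fn is a generator, and transformers' TextIteratorStreamer is available).
# Swap `chat_stream` in for `chat` below to get token-by-token output.
from threading import Thread
from transformers import TextIteratorStreamer

def chat_stream(message, history, model_name):
    """Generator variant of chat() that yields the partial response as it grows."""
    global current_model, current_tokenizer, current_model_name
    model_path = MODEL_NAMES.get(model_name, model_name)
    if model_path != current_model_name:
        current_model, current_tokenizer = load_model(model_path)
        current_model_name = model_path
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = current_tokenizer(message, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so run it in a background thread and drain the streamer here
    Thread(target=current_model.generate, kwargs={**inputs, "max_new_tokens": 1024, "streamer": streamer}).start()
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial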
with gr.Blocks() as app:
    gr.Markdown("## Chatbot with DeepSeek Models")
    model_selector = gr.Dropdown(
        choices=list(MODEL_NAMES.keys()),
        value="DeepSeek-R1-Distill-Llama-8B",
        label="Select Model",
    )
    # ChatInterface has no `streaming` argument (it streams automatically when fn
    # is a generator) and no `.append()` method; the dropdown is wired in as an
    # additional input so its value is passed to chat() as `model_name`.
    chat_interface = gr.ChatInterface(chat, additional_inputs=[model_selector], save_history=True)
app.launch()
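# Note: launch() with defaults is sufficient on Hugging Face Spaces; locally,
# launch(share=True) would expose a temporary public link.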