File size: 1,963 Bytes
2c8d1ab b06ae2b 2c8d1ab 7b5edcc 2c8d1ab 5cd7fd5 2c8d1ab 7b5edcc 705c0ca 0de4c24 b82cda0 705c0ca 0de4c24 b06ae2b 0de4c24 b06ae2b b82cda0 0de4c24 2c8d1ab 5cd7fd5 2c8d1ab 705c0ca 2c8d1ab b06ae2b 705c0ca 3eecdbe 2c8d1ab 705c0ca 2c8d1ab 705c0ca 2c8d1ab b06ae2b 2c8d1ab b06ae2b 2c8d1ab 705c0ca b06ae2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import os
# Organisation prefix shared by every model identifier below.
_HUB_ORG = "deepseek-ai"

# Maps the label shown to the user to the identifier handed to load_model().
MODEL_NAMES = {
    label: f"{_HUB_ORG}/{label}"
    for label in (
        "DeepSeek-R1-Distill-Qwen-7B",
        "DeepSeek-R1-Distill-Llama-8B",
    )
}

# Access token taken from the environment; None when HF_TOKEN is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
def load_model(model_path):
    """Load a causal language model and its tokenizer in float16.

    Args:
        model_path: Model identifier forwarded to every ``from_pretrained``
            call.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Options shared by the tokenizer, config and model loaders.
    hub_kwargs = {"trust_remote_code": True, "token": HF_TOKEN}

    tokenizer = AutoTokenizer.from_pretrained(model_path, **hub_kwargs)
    config = AutoConfig.from_pretrained(model_path, **hub_kwargs)

    # Remove the quantization config, to avoid using FP8.
    if hasattr(config, "quantization_config"):
        delattr(config, "quantization_config")

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        config=config,
        torch_dtype=torch.float16,
        device_map="auto",
        **hub_kwargs,
    )
    return model, tokenizer
# Label (a key of MODEL_NAMES) of the checkpoint currently held in memory.
# Tracked separately because the model object cannot be compared with a label.
current_model_name = "DeepSeek-R1-Distill-Llama-8B"
current_model, current_tokenizer = load_model(MODEL_NAMES[current_model_name])


def chat(message, history, model_name=None):
    """Generate a reply to ``message`` with the selected model.

    Args:
        message: The user's latest message, used as the prompt.
        history: Previous conversation turns supplied by the chat UI.
            Currently unused; only ``message`` is sent to the model.
        model_name: A key of ``MODEL_NAMES`` choosing which model answers.
            ``None`` (the default) keeps the model that is already loaded.

    Returns:
        The generated text, without the echoed prompt.

    Raises:
        KeyError: If ``model_name`` is not a key of ``MODEL_NAMES``.
    """
    global current_model, current_tokenizer, current_model_name

    # Reload only when a *different* model is requested. The previous check
    # compared the label with the model object, which is never equal, so the
    # full checkpoint was reloaded on every single message.
    if model_name is not None and model_name != current_model_name:
        current_model, current_tokenizer = load_model(MODEL_NAMES[model_name])
        current_model_name = model_name

    # Place the inputs wherever the model actually lives instead of guessing
    # from CUDA availability (the model is loaded with device_map="auto").
    inputs = current_tokenizer(message, return_tensors="pt").to(current_model.device)
    outputs = current_model.generate(**inputs, max_length=1024)

    # generate() returns prompt + completion; decode only the new tokens so
    # the reply does not repeat the user's message.
    prompt_length = inputs["input_ids"].shape[-1]
    return current_tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )
# Build the UI: a model picker plus a chat panel whose handler is chat().
with gr.Blocks() as app:
    gr.Markdown("## Chatbot with DeepSeek Models")
    with gr.Row():
        # The dropdown is created before the ChatInterface so it can be passed
        # as an additional input; its value reaches chat() as ``model_name``.
        # Previously it was never connected, so chat() was called without it.
        model_selector = gr.Dropdown(
            choices=list(MODEL_NAMES.keys()),
            value="DeepSeek-R1-Distill-Llama-8B",
            label="Select Model",
        )
        chat_interface = gr.ChatInterface(
            chat,
            type="messages",
            flagging_mode="manual",
            save_history=True,
            additional_inputs=[model_selector],
        )

app.launch()