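"""Gradio chat demo for hsuwill000/DeepSeek-R1-Distill-Qwen-1.5B-openvino,
served through optimum-intel's OpenVINO runtime. The model is loaded once at
startup and answers single-turn prompts via gr.ChatInterface."""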
import gradio as gr
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer
# Load the model and tokenizer
model_id = "hsuwill000/DeepSeek-R1-Distill-Qwen-1.5B-openvino"
print("Loading model...")
# optimum-intel selects the OpenVINO target via `device=` ("CPU" by default);
# transformers-style device_map="auto" does not apply to OVModel classes
model = OVModelForCausalLM.from_pretrained(model_id, device="CPU")
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
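# Both objects live at module level, so every chat request reuses the
# already-compiled OpenVINO model instead of reloading it.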
def respond(prompt, history):  # history is required by gr.ChatInterface but unused here
    # Build the chat messages. The system prompt is kept in Chinese; it says:
    # "The user is a Traditional Chinese user. Limit the answer, including the
    # think part, to 1024 tokens."
    messages = [
        {"role": "system", "content": "用戶是繁體中文使用者. 包括think 回答限縮在1024token"},
        {"role": "user", "content": prompt}
    ]
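    # With tokenize=False, apply_chat_template returns a single formatted
    # prompt string; add_generation_prompt=True appends the assistant-turn
    # prefix so the model continues as the assistant.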
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    print("Chat template text:", text)
    # Tokenize the prompt into model inputs
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    print("Model inputs:", model_inputs)
    # Generate the response; max_new_tokens is a hard cap on top of the
    # 1024-token limit requested in the system prompt
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=2048
    )
    print("Generated IDs:", generated_ids)
    # Strip the prompt and decode only the newly generated token IDs
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    print("Decoded response:", response)
# **去除 `<think>` 及其他無用內容**
response = response.replace("<think>", "**THINK**").replace("</think>", "**THINK**").strip()
# 返回回應
return response
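
# Optional streaming variant (a sketch, not wired into the app below):
# OVModelForCausalLM generates through transformers' GenerationMixin, so
# TextIteratorStreamer should work the same way it does for plain
# transformers models. To try it, pass fn=respond_stream to gr.ChatInterface.
def respond_stream(prompt, history):
    from threading import Thread
    from transformers import TextIteratorStreamer

    messages = [
        {"role": "system", "content": "用戶是繁體中文使用者. 包括think 回答限縮在1024token"},
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    # skip_prompt drops the echoed input; generation runs in a background thread
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(target=model.generate, kwargs=dict(**model_inputs, streamer=streamer, max_new_tokens=2048))
    thread.start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial  # gr.ChatInterface treats a generator fn as a streaming reply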
# Set up the Gradio chat interface
demo = gr.ChatInterface(
    fn=respond,
    title="DeepSeek-R1-Distill-Qwen-1.5B-openvino",
    description="hsuwill000/DeepSeek-R1-Distill-Qwen-1.5B-openvino"
)
if __name__ == "__main__":
    print("Launching Gradio app...")
    # demo.launch(server_name="0.0.0.0", server_port=7860)
    demo.launch()