import gradio as gr
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM
# Load the model and tokenizer at import time (module-level side effect).
model_id = "hsuwill000/DeepSeek-R1-Distill-Qwen-1.5B-openvino"
print("Loading model...")
# NOTE(review): optimum-intel's OVModelForCausalLM typically takes `device`,
# not `device_map` — confirm this kwarg is accepted and not silently ignored.
model = OVModelForCausalLM.from_pretrained(model_id, device_map="auto")
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
def respond(prompt, history=None):
    """Generate a chat response for *prompt* with the OpenVINO model.

    Args:
        prompt: The user's message text.
        history: Unused; accepted for compatibility with Gradio chat
            callers that pass a history argument. It now defaults to
            ``None`` because the single-input ``gr.Interface`` below
            calls ``fn(prompt)`` with one argument — the previous
            required parameter raised ``TypeError`` at request time.

    Returns:
        The decoded model reply, with ``<think>``/``</think>`` markers
        rewritten to ``**THINK**`` and surrounding whitespace stripped.
    """
    messages = [
        {"role": "system", "content": "使用中文。"},
        {"role": "user", "content": prompt},
    ]
    # Render the chat template to a plain string; tokenization happens below.
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=4096,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )
    # Drop the echoed prompt tokens, keeping only the newly generated tail.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Surface the model's reasoning delimiters as markdown bold markers.
    response = response.replace("<think>", "**THINK**").replace("</think>", "**THINK**").strip()
    return response
# Build the Gradio UI. Component constructors have rendering side effects
# inside the Blocks context, so the layout follows statement order here.
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-R1-Distill-Qwen-1.5B-openvino")
    with gr.Tabs():
        with gr.TabItem("聊天"):
            # Embed a full Interface inside the tab: it wires `respond`
            # between the two Textboxes and registers the endpoint under
            # api_name "/hchat".
            # NOTE(review): `respond` requires a second `history` argument,
            # but only one input component is supplied — verify this call
            # signature actually works at request time.
            chat_if = gr.Interface(
                fn=respond,
                inputs=gr.Textbox(label="Prompt", placeholder="請輸入訊息..."),
                outputs=gr.Textbox(label="Response", interactive=False),
                api_name="/hchat",
                title="DeepSeek-R1-Distill-Qwen-1.5B-openvino",
                description="回傳輸入內容的測試 API",
            )
if __name__ == "__main__":
    print("Launching Gradio app...")
    # Bind to all interfaces on the standard Gradio/Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)