hsuwill000 committed
Commit 3f48b5b · verified · 1 Parent(s): a9a7c72

Update app.py

Files changed (1)
  1. app.py +41 -24
app.py CHANGED
@@ -1,30 +1,43 @@
  import gradio as gr
  from transformers import AutoTokenizer
  from optimum.intel import OVModelForCausalLM
- from fastapi import FastAPI
- import uvicorn
- from pydantic import BaseModel
  import warnings
- warnings.filterwarnings("ignore", category=DeprecationWarning, message="__array__ implementation doesn't accept a copy keyword")

+ warnings.filterwarnings("ignore", category=DeprecationWarning, message="__array__ implementation doesn't accept a copy keyword")

- # Load the model and tokenizer (your original code)
+ # Load the model and tokenizer
  model_id = "hsuwill000/DeepSeek-R1-Distill-Qwen-1.5B-openvino"
  print("Loading model...")
  model = OVModelForCausalLM.from_pretrained(model_id, device_map="auto")
  print("Loading tokenizer...")
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

- def respond(prompt, history):
-     messages = [
-         {"role": "system", "content": "Answer the question in English only."},
-         {"role": "user", "content": prompt}
-     ]
+ # Conversation history
+ history = []
+
+ # Response function
+ def respond(prompt):
+     global history  # keep the history in a global variable
+
+     # Convert the history into the messages format
+     messages = [{"role": "system", "content": "Answer the question in English only."}]
+
+     # Append past turns
+     for user_text, assistant_text in history:
+         messages.append({"role": "user", "content": user_text})
+         messages.append({"role": "assistant", "content": assistant_text})
+
+     # Append the current input
+     messages.append({"role": "user", "content": prompt})
+
+     # Convert to the format the tokenizer expects
      text = tokenizer.apply_chat_template(
          messages,
          tokenize=False,
          add_generation_prompt=True
      )
+
+     # Run model inference
      model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
      generated_ids = model.generate(
          **model_inputs,
@@ -38,10 +51,22 @@ def respond(prompt, history):
      ]
      response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
      response = response.replace("<think>", "**THINK**").replace("</think>", "**THINK**").strip()
+
+     # Update the history
+     history.append((prompt, response))
+
      return response

+ # Clear the conversation history
+ def clear_history():
+     global history
+     history = []
+     return "History cleared!"
+
+ # Gradio interface
  with gr.Blocks() as demo:
      gr.Markdown("# DeepSeek-R1-Distill-Qwen-1.5B-openvino")
+
      with gr.Tabs():
          with gr.TabItem("Chat"):
              chat_if = gr.Interface(
@@ -49,24 +74,16 @@ with gr.Blocks() as demo:
                  inputs=gr.Textbox(label="Prompt", placeholder="Enter a message..."),
                  outputs=gr.Textbox(label="Response", interactive=False),
                  api_name="hchat",
-                 title="DeepSeek-R1-Distill-Qwen-1.5B-openvino",
+                 title="DeepSeek-R1-Distill-Qwen-1.5B-openvino (with history)",
                  description="A test API that returns the input content",
              )
-

- app = FastAPI()
- class Prompt(BaseModel):
-     prompt: str
+     with gr.Row():
+         clear_button = gr.Button("🧹 Clear History")

- def maxtest(prompt: str) -> str:
-     # Implement your logic here
-     return f"Your input was: {prompt}"
-
- @app.post("/maxtest")
- async def call_maxtest(prompt: Prompt):
-     response = maxtest(prompt.prompt)
-     return {"response": response}
+     # Clear the history when the button is clicked
+     clear_button.click(fn=clear_history, inputs=[], outputs=[])

  if __name__ == "__main__":
      print("Launching Gradio app...")
-     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
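
Because the Interface registers api_name="hchat", the history-aware endpoint can also be exercised programmatically. A minimal sketch using the gradio_client package; the Space id below is an assumption (it mirrors the model repo name) and should be replaced with the actual Space:

    from gradio_client import Client

    # Hypothetical Space id -- substitute the real Space running this app.
    client = Client("hsuwill000/DeepSeek-R1-Distill-Qwen-1.5B-openvino")

    # gr.Interface exposes respond() under /hchat via api_name="hchat".
    result = client.predict("Hello, who are you?", api_name="/hchat")
    print(result)

Note that history lives in a single global list, so all callers share one conversation; the 🧹 Clear History button (or restarting the Space) resets it.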