hsuwill000 committed on
Commit a12a90a · verified · 1 Parent(s): 3f48b5b

Update app.py

Files changed (1)
  1. app.py +58 -43
app.py CHANGED
@@ -1,60 +1,76 @@
 import gradio as gr
 from transformers import AutoTokenizer
 from optimum.intel import OVModelForCausalLM
+from sentence_transformers import SentenceTransformer
+import faiss
+import numpy as np
 import warnings
 
 warnings.filterwarnings("ignore", category=DeprecationWarning, message="__array__ implementation doesn't accept a copy keyword")
 
-# Load the model and tokenizer
+# Load the OpenVINO language model
 model_id = "hsuwill000/DeepSeek-R1-Distill-Qwen-1.5B-openvino"
 print("Loading model...")
 model = OVModelForCausalLM.from_pretrained(model_id, device_map="auto")
 print("Loading tokenizer...")
 tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
 
+# Load the embedding model (used to convert text into vectors)
+encoder = SentenceTransformer("all-MiniLM-L6-v2")
+
+# FAQ knowledge base (question + answer)
+faq_data = [
+    ("What is FAISS?", "FAISS is a library for efficient similarity search and clustering of dense vectors."),
+    ("How does FAISS work?", "FAISS uses indexing structures to quickly retrieve the nearest neighbors of a query vector."),
+    ("Can FAISS run on GPU?", "Yes, FAISS supports GPU acceleration for faster computation."),
+    ("What is OpenVINO?", "OpenVINO is an inference engine optimized for Intel hardware."),
+    ("How to fine-tune a model?", "Fine-tuning involves training a model on a specific dataset to adapt it to a particular task."),
+    ("What is the best way to optimize inference speed?", "Using quantization and model distillation can significantly improve inference speed.")
+]
+
+# Encode the FAQ questions as vectors
+faq_questions = [q for q, _ in faq_data]
+faq_answers = [a for _, a in faq_data]
+faq_vectors = np.array(encoder.encode(faq_questions)).astype("float32")
+
+# Build the FAISS index
+d = faq_vectors.shape[1]  # vector dimensionality
+index = faiss.IndexFlatL2(d)
+index.add(faq_vectors)
+
 # Conversation history
 history = []
 
-# Response function
+# Query function (search the FAQ first; fall back to the model if there is no match)
 def respond(prompt):
-    global history  # keep history in a global variable
-
-    # Convert history into the messages format
-    messages = [{"role": "system", "content": "Answer the question in English only."}]
+    global history
 
-    # Append the past conversation
-    for user_text, assistant_text in history:
-        messages.append({"role": "user", "content": user_text})
-        messages.append({"role": "assistant", "content": assistant_text})
-
-    # Append the current input
-    messages.append({"role": "user", "content": prompt})
+    # Encode the input as a vector and query FAISS
+    query_vector = np.array(encoder.encode([prompt])).astype("float32")
+    D, I = index.search(query_vector, 1)  # find the closest FAQ entry
 
-    # Convert to the format the tokenizer expects
-    text = tokenizer.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
-
-    # Run model inference
-    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
-    generated_ids = model.generate(
-        **model_inputs,
-        max_new_tokens=4096,
-        temperature=0.7,
-        top_p=0.9,
-        do_sample=True
-    )
-    generated_ids = [
-        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
-    ]
-    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    response = response.replace("<think>", "**THINK**").replace("</think>", "**THINK**").strip()
-
-    # Update history
+    if D[0][0] < 1.0:  # similarity threshold (lower distance = more similar; 5.0 was too loose and routed every question to the FAQ)
+        response = faq_answers[I[0][0]]  # answer directly from the FAQ
+    else:
+        # No FAQ match, so use the language model
+        messages = [{"role": "system", "content": "Answer the question in English only."}]
+        for user_text, assistant_text in history:
+            messages.append({"role": "user", "content": user_text})
+            messages.append({"role": "assistant", "content": assistant_text})
+        messages.append({"role": "user", "content": prompt})
+
+        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+        generated_ids = model.generate(
+            **model_inputs,
+            max_new_tokens=512,
+            temperature=0.7,
+            top_p=0.9,
+            do_sample=True
+        )
+        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
+
     history.append((prompt, response))
-
     return response
 
 # Clear the conversation history
@@ -65,8 +81,8 @@ def clear_history():
 
 # Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# DeepSeek-R1-Distill-Qwen-1.5B-openvino")
-
+    gr.Markdown("# DeepSeek-R1-Distill-Qwen-1.5B-openvino with history,FAISS ")
+
     with gr.Tabs():
         with gr.TabItem("Chat"):
             chat_if = gr.Interface(
@@ -74,14 +90,13 @@ with gr.Blocks() as demo:
                 inputs=gr.Textbox(label="Prompt", placeholder="Type a message..."),
                 outputs=gr.Textbox(label="Response", interactive=False),
                 api_name="hchat",
-                title="DeepSeek-R1-Distill-Qwen-1.5B-openvino(with history)",
-                description="Test API that returns the input text",
+                title="DeepSeek-R1 with FAISS FAQ",
+                description="This chatbot first searches an FAQ database using FAISS, then responds using a language model if no match is found."
            )
-
+
             with gr.Row():
                 clear_button = gr.Button("🧹 Clear History")
 
-            # Clear history when the button is clicked
             clear_button.click(fn=clear_history, inputs=[], outputs=[])
 
 if __name__ == "__main__":
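
One behavioral change buried in this diff: the removed respond() sliced the prompt tokens off generated_ids before decoding, while the new fallback branch decodes the full sequence, so the returned text also contains the rendered chat template and the user prompt. Below is a minimal sketch of how the removed trimming idiom could be restored inside the new else: branch; decode_new_tokens is a hypothetical helper name, and tokenizer, model_inputs, and generated_ids are the objects already defined in respond().

# Sketch: decode only the newly generated tokens, reusing the slicing
# idiom from the removed lines of this diff. `decode_new_tokens` is a
# hypothetical helper, not part of the commit.
def decode_new_tokens(tokenizer, model_inputs, generated_ids):
    # Drop each prompt's tokens from the front of its output sequence
    # so the chat template and user prompt are not echoed back.
    trimmed = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    return tokenizer.batch_decode(trimmed, skip_special_tokens=True)[0].strip()

With this helper, the last line of the else: branch would become response = decode_new_tokens(tokenizer, model_inputs, generated_ids).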
 
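The 1.0 cutoff in respond() is applied to a raw (unnormalized) L2 distance, whose scale depends on the embedding model, which is why the in-code comment notes that 5.0 matched everything. A common alternative is to normalize the embeddings and use an inner-product index, so scores become cosine similarities in [-1, 1] and the threshold is easier to reason about. The sketch below illustrates that variant; it is not part of the commit, the two FAQ strings are reused from it, and the 0.6 cutoff is an arbitrary assumption.

# Sketch: cosine-similarity variant of the FAQ lookup (an assumed
# alternative, not the commit's code). normalize_embeddings=True plus
# IndexFlatIP makes the inner product equal to cosine similarity.
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer("all-MiniLM-L6-v2")
faq_questions = ["What is FAISS?", "What is OpenVINO?"]  # subset of faq_data

vecs = np.asarray(encoder.encode(faq_questions, normalize_embeddings=True), dtype="float32")
index = faiss.IndexFlatIP(vecs.shape[1])  # inner-product index over unit vectors
index.add(vecs)

query = np.asarray(encoder.encode(["Tell me about OpenVINO"], normalize_embeddings=True), dtype="float32")
scores, ids = index.search(query, 1)

if scores[0][0] > 0.6:  # assumed cutoff; cosine scores live in [-1, 1]
    print("FAQ hit:", faq_questions[ids[0][0]])
else:
    print("No FAQ match; fall back to the language model")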
 
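Since the Interface registers api_name="hchat", the endpoint can also be exercised programmatically with gradio_client once the Space is running. A minimal sketch, assuming the Space id matches the model repo's naming (substitute the real one):

# Sketch: call the /hchat endpoint remotely. The Space id below is an
# assumption; replace it with the actual Space hosting this app.py.
from gradio_client import Client

client = Client("hsuwill000/DeepSeek-R1-Distill-Qwen-1.5B-openvino")  # hypothetical Space id
result = client.predict("What is FAISS?", api_name="/hchat")
print(result)  # an FAQ hit returns the canned FAISS answer verbatim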