File size: 3,063 Bytes
ed14d2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import pandas as pd
import numpy as np
import faiss
import pickle
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from llama_cpp import Llama
import gradio as gr

# =========================
# STEP 1: 載入 Hugging Face Dataset
# =========================
# Pull the drug-licence table from the Hugging Face Hub as a DataFrame, then
# strip stray whitespace from the column names so the exact-name lookups in
# make_passage (e.g. row['中文品名']) resolve.
dataset = load_dataset("pcreem/37", split="train")
df = dataset.to_pandas()
df.columns = [column_name.strip() for column_name in df.columns]

# (display label, DataFrame column) for each line of a retrieval passage,
# in output order.
_PASSAGE_FIELDS = (
    ("藥品名稱", "中文品名"),
    ("英文品名", "英文品名"),
    ("主成分", "主成分略述"),
    ("劑型", "劑型"),
    ("適應症", "適應症"),
    ("用法用量", "用法用量"),
    ("申請商", "申請商名稱"),
    ("製造商", "製造商名稱"),
    ("製造廠地址", "製造廠廠址"),
    ("包裝", "包裝"),
    ("有效日期", "有效日期"),
    ("許可證字號", "許可證字號"),
)


def _passage_value(row, column):
    """Return ``row[column]``, mapping a missing scalar (None / NaN / pd.NA) to ""."""
    value = row[column]
    # Without this, a missing cell is formatted as the literal text "None" or
    # "nan", which then pollutes both the embeddings and the LLM context.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    return value


def make_passage(row):
    """Build the retrieval passage text for one drug-licence row.

    Args:
        row: one row of ``df`` (a pandas Series) or any mapping indexed by the
            dataset's column names listed in ``_PASSAGE_FIELDS``.

    Returns:
        A string with one "label:value" line per field in ``_PASSAGE_FIELDS``,
        joined by "\\n" with no trailing newline. Missing values render as "".

    Raises:
        KeyError: if ``row`` lacks one of the required columns.
    """
    return "\n".join(
        f"{label}:{_passage_value(row, column)}" for label, column in _PASSAGE_FIELDS
    )

# Attach one retrieval passage per row, and keep a plain list of them whose
# positions match the ids the FAISS index assigns when these are added.
df = df.assign(retrieval_passage=df.apply(make_passage, axis=1))
passages = list(df["retrieval_passage"])

# =========================
# STEP 2: 建立 FAISS 檢索
# =========================
# Embed every passage with a multilingual sentence-transformer and load the
# vectors into an exact (flat) L2 FAISS index; row i of the index is passages[i].
embedding_model = SentenceTransformer(
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
)
embeddings = embedding_model.encode(passages, show_progress_bar=True)
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(np.asarray(embeddings, dtype="float32"))

# =========================
# STEP 3: 載入 Llama 模型
# =========================
# Download the q5_1 GGUF build of Llama-3-Taiwan-8B-Instruct from the Hub and
# load it with llama.cpp for local generation.
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(
    filename="llama-3-taiwan-8B-instruct-q5_1.gguf",
    repo_id="chienweichang/Llama-3-Taiwan-8B-Instruct-GGUF",
)

llm = Llama(
    model_path=model_path,
    # NOTE(review): the RAG prompt (k=3 passages) plus max_tokens=512 in rag_qa
    # must fit in this window; confirm long passages do not overflow it.
    n_ctx=2048,
    n_gpu_layers=35,  # layers offloaded to the GPU
    seed=42,  # fixed sampling seed
    verbose=False,
)

# =========================
# STEP 4: 定義查詢函式
# =========================
def rag_qa(query, k=3):
    """Answer a drug question via retrieval-augmented generation.

    Embeds ``query``, fetches the ``k`` nearest passages from the FAISS
    ``index``, and asks the Llama model to answer from that context.

    Args:
        query: the user's question (the Gradio textbox value).
        k: number of passages to retrieve; clamped to [1, index.ntotal].

    Returns:
        The model's answer with surrounding whitespace stripped, or the prompt
        "請輸入問題。" when ``query`` is empty or whitespace-only.
    """
    # Gradio submits "" for an empty textbox; skip retrieval and an 8B-model
    # generation for it.
    if not query or not query.strip():
        return "請輸入問題。"

    # FAISS requires k >= 1, and for k > ntotal it pads the result with id -1,
    # which would otherwise silently pick passages[-1].
    k = max(1, min(int(k), index.ntotal))
    query_embedding = embedding_model.encode([query])
    _, ids = index.search(np.asarray(query_embedding, dtype="float32"), k=k)
    top_passages = [passages[i] for i in ids[0] if i >= 0]

    context = "\n\n---\n\n".join(top_passages)
    system_prompt = "你是一位專業藥師,根據以下藥品資料,回答使用者的問題,請用簡潔中文說明並避免虛構資訊。\n"
    user_prompt = f"{system_prompt}\n以下是參考資料:\n\n{context}\n\n使用者問題:{query}"
    # Llama-3 chat template: every header is followed by a blank line ("\n\n").
    # <|begin_of_text|> is omitted on the assumption that llama-cpp prepends
    # BOS when tokenizing the prompt — TODO confirm.
    chat_prompt = (
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{user_prompt}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    # NOTE(review): with n_ctx=2048, long passages plus max_tokens=512 may
    # exceed the model's context window.
    output = llm(chat_prompt, max_tokens=512, temperature=0.7, top_p=0.9, stop=["<|eot_id|>"])
    return output["choices"][0]["text"].strip()

# =========================
# STEP 5: Gradio 介面
# =========================
# Single-textbox web UI: the question is passed to rag_qa and the returned
# answer is shown in an output textbox; launch() starts the Gradio server.
demo = gr.Interface(
    title="台灣藥品問答系統",
    description="輸入藥品相關問題,我會根據台灣合法藥品資料庫回答你!",
    fn=rag_qa,
    inputs=gr.Textbox(label="請輸入問題", placeholder="例如:感冒藥有什麼選擇?"),
    outputs=gr.Textbox(label="藥師回答"),
)
demo.launch()