Pudding48 committed
Commit b7371e3 · verified · 1 Parent(s): 3202c16

Update qabot.py

Files changed (1)
  1. qabot.py +31 -40
qabot.py CHANGED
@@ -5,70 +5,61 @@ from langchain.chains import RetrievalQA
 from langchain_community.embeddings import GPT4AllEmbeddings
 from langchain_community.vectorstores import FAISS
 
-# from huggingface_hub import hf_hub_download
+from huggingface_hub import hf_hub_download
 # !pip install llama-cpp-python
 
-from llama_cpp import Llama
+# from llama_cpp import Llama
 
-model_file = Llama.from_pretrained(
-    repo_id="Pudding48/TinyLLamaTest",
-    filename="tinyllama-1.1b-chat-v1.0.Q8_0.gguf",
-)
-
-# model_file = hf_hub_download(
-#     repo_id="Pudding48/TinyLlamaTest",  # Replace with your model repo
-#     filename="tinyllama-1.1b-chat-v1.0.Q8_0.gguf",
-#     cache_dir="model"  # Will be created in the Space's environment
+# model_file = Llama.from_pretrained(
+#     repo_id="Pudding48/TinyLLamaTest",
+#     filename="tinyllama-1.1b-chat-v1.0.Q8_0.gguf",
 # )
 
-# Configuration
-# model_file = "model/tinyllama-1.1b-chat-v1.0.Q8_0.gguf"
+model_file = hf_hub_download(
+    repo_id="Pudding48/TinyLlamaTest",  # 🟢 This must be a model repo, not a Space
+    filename="tinyllama-1.1b-chat-v1.0.Q8_0.gguf",
+    cache_dir="model"
+)
+
+# Vector store location
 vector_dp_path = "vectorstores/db_faiss"
 
-# Load LLM
+# Load LLM with CTransformers
 def load_llm(model_file):
-    llm = CTransformers(
+    return CTransformers(
         model=model_file,
         model_type="llama",
         temperature=0.01,
         config={'gpu_layers': 0},
-        max_new_tokens=128,
+        max_new_tokens=128,
         context_length=512
     )
-    return llm
 
-# Create prompt template
+# Create the prompt
 def creat_prompt(template):
-    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
-    return prompt
+    return PromptTemplate(template=template, input_variables=["context", "question"])
 
-# Create pipeline chain (replacing LLMChain)
+# Create QA pipeline
 def create_qa_chain(prompt, llm, db):
-    llm_chain = RetrievalQA.from_chain_type(
-        llm=llm,
-        chain_type="stuff",
-        retriever=db.as_retriever(search_kwargs={"k": 1}),
-        return_source_documents=False,
-        chain_type_kwargs={'prompt': prompt}
+    return RetrievalQA.from_chain_type(
+        llm=llm,
+        chain_type="stuff",
+        retriever=db.as_retriever(search_kwargs={"k": 1}),
+        return_source_documents=False,
+        chain_type_kwargs={'prompt': prompt}
     )
-    return llm_chain
 
+# Load vector DB
 def read_vector_db():
-    embedding_model = GPT4AllEmbeddings(model_file="model/all-minilm-l6-v2-q4_0.gguf")
-    db = FAISS.load_local(vector_dp_path, embedding_model, allow_dangerous_deserialization=True)
-    return db
+    embedding_model = GPT4AllEmbeddings(model_file="model/all-minilm-l6-v2-q4_0.gguf")
+    return FAISS.load_local(vector_dp_path, embedding_model, allow_dangerous_deserialization=True)
 
+# Build everything
 db = read_vector_db()
 llm = load_llm(model_file)
-# Prompt template
+
 template = """<|im_start|>system\nSử dụng thông tin sau đây để trả lời câu hỏi. Nếu bạn không biết câu trả lời, hãy nói không biết, đừng cố tạo ra câu trả lời\n
-{context}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant"""
+{context}<|im_end|>\n<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant"""
 
-# Initialize the components
 prompt = creat_prompt(template)
-llm_chain = create_qa_chain(prompt, llm, db)
-
-# Test-run the chain
-question = "Khoa công nghệ thông tin thành lập năm nào ?"
-response = llm_chain.invoke({"query": question})
-print(response)
+llm_chain = create_qa_chain(prompt, llm, db)
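
Note that FAISS.load_local can only deserialize an index that was built with the same embedding model, so vectorstores/db_faiss must already exist before qabot.py runs. A minimal sketch of how such an index could be built, assuming a hypothetical data/faq.txt source document and the same GPT4All embedding file (neither is part of this commit):

from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_text_splitters import RecursiveCharacterTextSplitter

# Load and chunk the source text (data/faq.txt is a hypothetical placeholder).
docs = TextLoader("data/faq.txt", encoding="utf-8").load()
chunks = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50).split_documents(docs)

# Embed with the same model that read_vector_db() uses, then persist the index
# to the folder qabot.py expects.
embedding_model = GPT4AllEmbeddings(model_file="model/all-minilm-l6-v2-q4_0.gguf")
db = FAISS.from_documents(chunks, embedding_model)
db.save_local("vectorstores/db_faiss")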
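
The commit also removes the inline test-run at the bottom of the file. A quick smoke test could still be run from a separate script; this is a sketch assuming qabot.py is importable and its module-level download and FAISS load succeed:

# Importing qabot triggers its module-level setup (hf_hub_download, FAISS load).
from qabot import llm_chain

# Sample question reused from the removed test block ("What year was the
# Faculty of Information Technology founded?").
question = "Khoa công nghệ thông tin thành lập năm nào ?"
response = llm_chain.invoke({"query": question})  # RetrievalQA expects the "query" key
print(response["result"])  # the answer text lives under "result"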