ramysaidagieb committed on
Commit
2d59ec1
·
verified ·
1 Parent(s): 704db9d

Update rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +7 -11
rag_pipeline.py CHANGED
@@ -14,10 +14,8 @@ def load_documents(pdf_dir):
14
  return docs
15
 
16
  def load_rag_chain():
17
- pdf_dir = Path("data")
18
- pdf_dir.mkdir(parents=True, exist_ok=True)
19
-
20
- raw_docs = load_documents(pdf_dir)
21
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
22
  pages = splitter.split_documents(raw_docs)
23
 
@@ -26,19 +24,17 @@ def load_rag_chain():
26
  model_kwargs={"device": "cpu"},
27
  )
28
 
29
- vectordb_dir = "chroma_db"
30
- vectordb = Chroma.from_documents(pages, embeddings, persist_directory=vectordb_dir)
31
  retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 5})
32
 
33
  hf_pipeline = pipeline(
34
  "text2text-generation",
35
- model="csebuetnlp/mT5_small_finetuned_squad",
36
- tokenizer=AutoTokenizer.from_pretrained("csebuetnlp/mT5_small_finetuned_squad"),
37
  max_new_tokens=512,
38
  temperature=0.3,
39
- device=-1
40
  )
41
  llm = HuggingFacePipeline(pipeline=hf_pipeline)
42
 
43
- qa_chain = RetrievalQA.from_llm(llm=llm, retriever=retriever)
44
- return qa_chain
 
14
  return docs
15
 
16
  def load_rag_chain():
17
+ Path("data").mkdir(exist_ok=True)
18
+ raw_docs = load_documents("data")
 
 
19
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
20
  pages = splitter.split_documents(raw_docs)
21
 
 
24
  model_kwargs={"device": "cpu"},
25
  )
26
 
27
+ vectordb = Chroma.from_documents(pages, embeddings, persist_directory="chroma_db")
 
28
  retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 5})
29
 
30
  hf_pipeline = pipeline(
31
  "text2text-generation",
32
+ model="ArabicNLP/mT5-base_ar",
33
+ tokenizer=AutoTokenizer.from_pretrained("ArabicNLP/mT5-base_ar"),
34
  max_new_tokens=512,
35
  temperature=0.3,
36
+ device=-1,
37
  )
38
  llm = HuggingFacePipeline(pipeline=hf_pipeline)
39
 
40
+ return RetrievalQA.from_llm(llm=llm, retriever=retriever)