ramysaidagieb committed on
Commit
9f0b7c7
·
verified ·
1 Parent(s): ad6ee04

Update rag_pipeline.py

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +10 -5
rag_pipeline.py CHANGED
@@ -1,6 +1,6 @@
1
  from pathlib import Path
2
  from langchain.chains import RetrievalQA
3
- from transformers import pipeline, AutoTokenizer
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_community.document_loaders import PyMuPDFLoader
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -14,8 +14,10 @@ def load_documents(pdf_dir):
14
  return docs
15
 
16
  def load_rag_chain():
17
- Path("data").mkdir(exist_ok=True)
18
- raw_docs = load_documents("data")
 
 
19
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
20
  pages = splitter.split_documents(raw_docs)
21
 
@@ -24,13 +26,16 @@ def load_rag_chain():
24
  model_kwargs={"device": "cpu"},
25
  )
26
 
27
- vectordb = Chroma.from_documents(pages, embeddings, persist_directory="chroma_db")
 
28
  retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 5})
29
 
 
 
30
  hf_pipeline = pipeline(
31
  "text2text-generation",
32
  model="ArabicNLP/mT5-base_ar",
33
- tokenizer=AutoTokenizer.from_pretrained("ArabicNLP/mT5-base_ar"),
34
  max_new_tokens=512,
35
  temperature=0.3,
36
  device=-1,
 
1
  from pathlib import Path
2
  from langchain.chains import RetrievalQA
3
+ from transformers import pipeline, T5Tokenizer
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_community.document_loaders import PyMuPDFLoader
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
14
  return docs
15
 
16
  def load_rag_chain():
17
+ pdf_dir = Path("data")
18
+ pdf_dir.mkdir(parents=True, exist_ok=True)
19
+
20
+ raw_docs = load_documents(pdf_dir)
21
  splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
22
  pages = splitter.split_documents(raw_docs)
23
 
 
26
  model_kwargs={"device": "cpu"},
27
  )
28
 
29
+ vectordb_dir = "chroma_db"
30
+ vectordb = Chroma.from_documents(pages, embeddings, persist_directory=vectordb_dir)
31
  retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 5})
32
 
33
+ # ✅ Use slow tokenizer explicitly
34
+ tokenizer = T5Tokenizer.from_pretrained("ArabicNLP/mT5-base_ar", use_fast=False)
35
  hf_pipeline = pipeline(
36
  "text2text-generation",
37
  model="ArabicNLP/mT5-base_ar",
38
+ tokenizer=tokenizer,
39
  max_new_tokens=512,
40
  temperature=0.3,
41
  device=-1,