Spaces:
Sleeping
Sleeping
Update rag_pipeline.py
Browse files- rag_pipeline.py +10 -5
rag_pipeline.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from pathlib import Path
|
2 |
from langchain.chains import RetrievalQA
|
3 |
-
from transformers import pipeline,
|
4 |
from langchain_community.vectorstores import Chroma
|
5 |
from langchain_community.document_loaders import PyMuPDFLoader
|
6 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
@@ -14,8 +14,10 @@ def load_documents(pdf_dir):
|
|
14 |
return docs
|
15 |
|
16 |
def load_rag_chain():
|
17 |
-
Path("data")
|
18 |
-
|
|
|
|
|
19 |
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
|
20 |
pages = splitter.split_documents(raw_docs)
|
21 |
|
@@ -24,13 +26,16 @@ def load_rag_chain():
|
|
24 |
model_kwargs={"device": "cpu"},
|
25 |
)
|
26 |
|
27 |
-
|
|
|
28 |
retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 5})
|
29 |
|
|
|
|
|
30 |
hf_pipeline = pipeline(
|
31 |
"text2text-generation",
|
32 |
model="ArabicNLP/mT5-base_ar",
|
33 |
-
tokenizer=
|
34 |
max_new_tokens=512,
|
35 |
temperature=0.3,
|
36 |
device=-1,
|
|
|
1 |
from pathlib import Path
|
2 |
from langchain.chains import RetrievalQA
|
3 |
+
from transformers import pipeline, T5Tokenizer
|
4 |
from langchain_community.vectorstores import Chroma
|
5 |
from langchain_community.document_loaders import PyMuPDFLoader
|
6 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
14 |
return docs
|
15 |
|
16 |
def load_rag_chain():
|
17 |
+
pdf_dir = Path("data")
|
18 |
+
pdf_dir.mkdir(parents=True, exist_ok=True)
|
19 |
+
|
20 |
+
raw_docs = load_documents(pdf_dir)
|
21 |
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
|
22 |
pages = splitter.split_documents(raw_docs)
|
23 |
|
|
|
26 |
model_kwargs={"device": "cpu"},
|
27 |
)
|
28 |
|
29 |
+
vectordb_dir = "chroma_db"
|
30 |
+
vectordb = Chroma.from_documents(pages, embeddings, persist_directory=vectordb_dir)
|
31 |
retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 5})
|
32 |
|
33 |
+
# ✅ Use slow tokenizer explicitly
|
34 |
+
tokenizer = T5Tokenizer.from_pretrained("ArabicNLP/mT5-base_ar", use_fast=False)
|
35 |
hf_pipeline = pipeline(
|
36 |
"text2text-generation",
|
37 |
model="ArabicNLP/mT5-base_ar",
|
38 |
+
tokenizer=tokenizer,
|
39 |
max_new_tokens=512,
|
40 |
temperature=0.3,
|
41 |
device=-1,
|