ramysaidagieb commited on
Commit
77f3883
·
verified ·
1 Parent(s): 3281db1

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +16 -0
  2. rag_pipeline.py +36 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from rag_pipeline import load_rag_chain

# Build the RAG chain once at startup so every request reuses the same
# vector index and model pipeline instead of rebuilding them per call.
rag_chain = load_rag_chain()


def ask_question(query):
    """Answer *query* against the indexed PDF corpus.

    Args:
        query: Free-text question from the UI textbox.

    Returns:
        The chain's answer string, or a short hint when the input is
        empty — invoking a 7B model on a blank query would waste a full
        generation pass and return meaningless text.
    """
    if not query or not query.strip():
        return "Please enter a question."
    result = rag_chain.invoke(query)
    # RetrievalQA exposes the generated answer under the "result" key.
    return result["result"]


iface = gr.Interface(fn=ask_question,
                     inputs=gr.Textbox(lines=3, label="Ask a Question"),
                     outputs="text",
                     title="Custom PDF RAG Chatbot")

if __name__ == "__main__":
    iface.launch()
rag_pipeline.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from langchain.chains import RetrievalQA
3
+ from transformers import pipeline, AutoTokenizer
4
+ from langchain_community.vectorstores import Chroma
5
+ from langchain_community.document_loaders import DirectoryLoader
6
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
7
+ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
8
+
9
def load_rag_chain():
    """Build a RetrievalQA chain over the PDFs in ``data/``.

    Loads every PDF in the ``data`` directory, splits it into
    overlapping chunks, embeds the chunks into a persistent Chroma
    store, and wires an MMR retriever to a local Mistral
    text-generation pipeline.

    Returns:
        A langchain ``RetrievalQA`` chain; its ``invoke`` result dict
        exposes the answer under the ``"result"`` key.
    """
    pdf_dir = Path("data")
    loader = DirectoryLoader(str(pdf_dir), glob="*.pdf")
    # 1000-char chunks with 200-char overlap keep retrieved context
    # coherent across chunk boundaries.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    pages = loader.load_and_split(text_splitter=text_splitter)

    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        model_kwargs={"device": "cpu"},
    )

    vectordb_dir = "chroma_db"
    vectordb = Chroma.from_documents(pages, embeddings, persist_directory=vectordb_dir)
    retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 5})

    hf_pipeline = pipeline(
        "text-generation",
        model="mistralai/Mistral-7B-Instruct-v0.2",
        tokenizer=AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2"),
        max_new_tokens=512,
        # transformers ignores `temperature` unless sampling is enabled.
        do_sample=True,
        temperature=0.3,
        # BUG FIX: return_full_text=True made the pipeline prepend the
        # entire prompt (retrieved context + question) to the answer, so
        # the chatbot echoed its own prompt back to the user. Only the
        # newly generated text should be returned.
        return_full_text=False,
        device=-1,  # CPU
    )
    llm = HuggingFacePipeline(pipeline=hf_pipeline)

    qa_chain = RetrievalQA.from_llm(llm=llm, retriever=retriever)
    return qa_chain
requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ gradio
2
+ transformers
3
+ sentence-transformers
4
+ langchain
5
+ langchain-community
6
+ langchain-huggingface
7
+ chromadb
8
+ torch
9
+ unstructured[pdf]