Almaatla committed
Commit 8c7bb55 · 1 Parent(s): 6bdbd0f

Create app.py

Files changed (1): app.py (+258, -0)
app.py ADDED
@@ -0,0 +1,258 @@
+ import langchain
+ from langchain.embeddings import SentenceTransformerEmbeddings
+ from langchain.chains.question_answering import load_qa_chain
+ from langchain.document_loaders import UnstructuredPDFLoader, UnstructuredWordDocumentLoader
+ from langchain.indexes import VectorstoreIndexCreator
+ from langchain.vectorstores import FAISS
+ from langchain import HuggingFaceHub
+ from langchain import PromptTemplate
+ from langchain.chat_models import ChatOpenAI
+ from zipfile import ZipFile
+ import gradio as gr
+ import openpyxl
+ import os
+ import shutil
+ from langchain.schema import Document
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ import tiktoken
+ import secrets
+
+ tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo")
+
+ # create the length function
+ def tiktoken_len(text):
+     tokens = tokenizer.encode(
+         text,
+         disallowed_special=()
+     )
+     return len(tokens)
+
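+ # Token-aware splitter: chunks of at most 400 tokens with a 20-token
+ # overlap, measured with the gpt-3.5-turbo tokenizer defined above.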
+ text_splitter = RecursiveCharacterTextSplitter(
+     chunk_size=400,
+     chunk_overlap=20,
+     length_function=tiktoken_len,
+     separators=["\n\n", "\n", " ", ""]
+ )
+
+ embeddings = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
+ # Placeholder document used to seed a brand-new FAISS index.
+ foo = Document(page_content='foo is fou!', metadata={"source": 'foo source'})
+
+ def reset_database(ui_session_id):
+     session_id = f"PDFAISS-{ui_session_id}"
+     if 'drive' in session_id:
+         print("RESET DATABASE: session_id contains 'drive' !!")
+         return None
+
+     try:
+         shutil.rmtree(session_id)
+     except FileNotFoundError:
+         print(f'no {session_id} directory present')
+
+     try:
+         os.remove(f"{session_id}.zip")
+     except FileNotFoundError:
+         print(f"no {session_id}.zip present")
+
+     return None
+
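+ # Heuristic duplicate check: sum the similarity scores (L2 distances for
+ # a default FAISS index, lower = closer) of the document's first three
+ # chunks against their nearest neighbours already in the index; a total
+ # below 0.05 is treated as "already embedded".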
+ def is_duplicate(split_docs, db):
+     epsilon = 0.0
+     print(f"DUPLICATE: Treating: {split_docs[0].metadata['source'].split('/')[-1]}")
+     for i in range(min(3, len(split_docs))):
+         query = split_docs[i].page_content
+         docs = db.similarity_search_with_score(query, k=1)
+         _, score = docs[0]
+         epsilon += score
+     print(f"DUPLICATE: epsilon: {epsilon}")
+     return epsilon < 0.05
+
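+ # Chunks are embedded in batches of 20: each batch becomes a small
+ # temporary FAISS index that is merged into the session index, which
+ # lets the progress bar advance between batches.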
+ def merge_split_docs_to_db(split_docs, db, progress, progress_step=0.1):
+     progress(progress_step, desc="merging docs")
+     if len(split_docs) == 0:
+         print("MERGE to db: NO docs!!")
+         return db
+
+     filename = split_docs[0].metadata['source']
+     if is_duplicate(split_docs, db):
+         print(f"MERGE: Document is duplicated: {filename}")
+         return db
+     print(f"MERGE: number of split docs: {len(split_docs)}")
+     batch = 20
+     for i in range(0, len(split_docs), batch):
+         progress(i / len(split_docs), desc=f"added {i} chunks of {len(split_docs)} chunks")
+         db1 = FAISS.from_documents(split_docs[i:i + batch], embeddings)
+         db.merge_from(db1)
+     return db
+
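+ # The three loaders below share one pattern: load the file, rewrite its
+ # 'source' metadata to the bare filename, split it into chunks, and
+ # merge the chunks into the index.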
+ def merge_pdf_to_db(filename, db, progress, progress_step=0.1):
+     progress_step += 0.05
+     progress(progress_step, 'unpacking pdf')
+     doc = UnstructuredPDFLoader(filename).load()
+     doc[0].metadata['source'] = filename.split('/')[-1]
+     split_docs = text_splitter.split_documents(doc)
+     progress_step += 0.3
+     progress(progress_step, 'pdf unpacked')
+     return merge_split_docs_to_db(split_docs, db, progress, progress_step)
+
+ def merge_docx_to_db(filename, db, progress, progress_step=0.1):
+     progress_step += 0.05
+     progress(progress_step, 'unpacking docx')
+     doc = UnstructuredWordDocumentLoader(filename).load()
+     doc[0].metadata['source'] = filename.split('/')[-1]
+     split_docs = text_splitter.split_documents(doc)
+     progress_step += 0.3
+     progress(progress_step, 'docx unpacked')
+     return merge_split_docs_to_db(split_docs, db, progress, progress_step)
+
+ def merge_txt_to_db(filename, db, progress, progress_step=0.1):
+     progress_step += 0.05
+     progress(progress_step, 'unpacking txt')
+     with open(filename) as f:
+         docs = text_splitter.split_text(f.read())
+     split_docs = [Document(page_content=doc, metadata={'source': filename.split('/')[-1]}) for doc in docs]
+     progress_step += 0.3
+     progress(progress_step, 'txt unpacked')
+     return merge_split_docs_to_db(split_docs, db, progress, progress_step)
+
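+ # Zip uploads are handled two ways: an archive that already contains a
+ # FAISS index ('index.faiss') is merged directly; otherwise every
+ # .docx/.pdf/.txt file inside is embedded individually.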
+ def unpack_zip_file(filename, db, progress):
+     with ZipFile(filename, 'r') as zipObj:
+         contents = zipObj.namelist()
+     print(f"unpack zip: contents: {contents}")
+     tmp_directory = filename.split('/')[-1].split('.')[-2]
+     shutil.unpack_archive(filename, tmp_directory)
+
+     if 'index.faiss' in [item.lower() for item in contents]:
+         db2 = FAISS.load_local(tmp_directory, embeddings)
+         db.merge_from(db2)
+         return db
+
+     for file in contents:
+         if file.lower().endswith('.docx'):
+             db = merge_docx_to_db(f"{tmp_directory}/{file}", db, progress)
+         if file.lower().endswith('.pdf'):
+             db = merge_pdf_to_db(f"{tmp_directory}/{file}", db, progress)
+         if file.lower().endswith('.txt'):
+             db = merge_txt_to_db(f"{tmp_directory}/{file}", db, progress)
+     return db
+
+ def add_files_to_zip(session_id):
+     zip_file_name = f"{session_id}.zip"
+     with ZipFile(zip_file_name, "w") as zipObj:
+         for root, dirs, files in os.walk(session_id):
+             for file_name in files:
+                 file_path = os.path.join(root, file_name)
+                 arcname = os.path.relpath(file_path, session_id)
+                 zipObj.write(file_path, arcname)
+
+ #### UI Functions ####
+
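+ # Handler for the "Generate database" button: creates a session id when
+ # the textbox is empty, loads (or seeds) the session's FAISS index,
+ # merges every uploaded file into it, moves the originals under
+ # <session>/store/, and returns the zipped database for download.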
+ def embed_files(files, ui_session_id, progress=gr.Progress(), progress_step=0.05):
+     progress(progress_step, desc="Starting...")
+     if len(ui_session_id) == 0:
+         ui_session_id = secrets.token_urlsafe(16)
+     session_id = f"PDFAISS-{ui_session_id}"
+
+     try:
+         db = FAISS.load_local(session_id, embeddings)
+     except:
+         print(f"SESSION: {session_id} database does not exist, create a FAISS db")
+         db = FAISS.from_documents([foo], embeddings)
+         db.save_local(session_id)
+         print(f"SESSION: {session_id} database created")
+
+     print("EMBEDDED, before embedding: ", session_id, len(db.index_to_docstore_id))
+     for file_id, file in enumerate(files):
+         file_type = file.name.split('.')[-1].lower()
+         source = file.name.split('/')[-1]
+         print(f"current file: {source}")
+         progress(file_id / len(files), desc=f"Treating {source}")
+
+         if file_type == 'pdf':
+             db = merge_pdf_to_db(file.name, db, progress)
+             db.save_local(session_id)
+
+         if file_type == 'txt':
+             db = merge_txt_to_db(file.name, db, progress)
+             db.save_local(session_id)
+
+         if file_type == 'docx':
+             db = merge_docx_to_db(file.name, db, progress)
+             db.save_local(session_id)
+
+         if file_type == 'zip':
+             db = unpack_zip_file(file.name, db, progress)
+             db.save_local(session_id)
+
+         ### move file to store ###
+         progress(progress_step, desc='moving file to store')
+         directory_path = f"{session_id}/store/"
+         if not os.path.exists(directory_path):
+             os.makedirs(directory_path)
+         shutil.move(file.name, directory_path)
+
+     ### load the updated db and zip it ###
+     progress(progress_step, desc='loading db')
+     db = FAISS.load_local(session_id, embeddings)
+     print("EMBEDDED, after embedding: ", session_id, len(db.index_to_docstore_id))
+     progress(progress_step, desc='zipping db for download')
+     add_files_to_zip(session_id)
+     print("EMBEDDED: db zipped")
+     progress(progress_step, desc='db zipped')
+     return f"{session_id}.zip", ui_session_id
+
+ def display_docs(docs):
+     output_str = ''
+     for i, doc in enumerate(docs):
+         source = doc.metadata['source'].split('/')[-1]
+         output_str += f"Ref: {i+1}\n{repr(doc.page_content)}\nSource: {source}\n\n"
+     return output_str
+
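+ # Handler for the "Answer" button: retrieves the most similar chunks for
+ # the query; without an OpenAI API key it only returns the references,
+ # otherwise it runs a "stuff" QA chain over them with gpt-3.5-turbo and
+ # appends query, answer and references to the running history.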
+ def ask_gpt(query, apikey, history, ui_session_id):
+     session_id = f"PDFAISS-{ui_session_id}"
+     try:
+         db = FAISS.load_local(session_id, embeddings)
+         print("ASKGPT after loading", session_id, len(db.index_to_docstore_id))
+     except:
+         print(f"SESSION: {session_id} database does not exist")
+         return f"SESSION: {session_id} database does not exist", "", ""
+
+     docs = db.similarity_search(query)
+     history += f"[query]\n{query}\n[answer]\n"
+     if apikey == "":
+         history += f"None\n[references]\n{display_docs(docs)}\n\n"
+         return "No answer from GPT", display_docs(docs), history
+     else:
+         llm = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo', openai_api_key=apikey)
+         chain = load_qa_chain(llm, chain_type="stuff", verbose=True)
+         answer = chain.run(input_documents=docs, question=query)
+         history += f"{answer}\n[references]\n{display_docs(docs)}\n\n"
+         return answer, display_docs(docs), history
+
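+ # Two-tab UI: one tab to build and download the database, one to query it.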
+ with gr.Blocks() as demo:
+     gr.Markdown("Upload your documents and question them.")
+     with gr.Tab("Upload PDF & TXT"):
+         tb_session_id = gr.Textbox(label='session id')
+         docs_input = gr.File(file_count="multiple", file_types=[".txt", ".pdf", ".zip", ".docx"])
+         db_output = gr.File(label="Download zipped database")
+         btn_generate_db = gr.Button("Generate database")
+         btn_reset_db = gr.Button("Reset database")
+
+     with gr.Tab("Ask PDF"):
+         with gr.Column():
+             api_key = gr.Textbox(placeholder="Leave blank if you don't have any", label="OpenAI API Key", type='password')
+             query_input = gr.Textbox(placeholder="Type your question", label="Question")
+             btn_askGPT = gr.Button("Answer")
+             answer_output = gr.Textbox(label='GPT 3.5 answer')
+             answer_output.style(show_copy_button=True)
+             sources = gr.Textbox(label='Sources')
+             sources.style(show_copy_button=True)
+             history = gr.Textbox(label='History')
+             history.style(show_copy_button=True)
+
+     btn_generate_db.click(embed_files, inputs=[docs_input, tb_session_id], outputs=[db_output, tb_session_id])
+     btn_reset_db.click(reset_database, inputs=[tb_session_id], outputs=[db_output])
+     btn_askGPT.click(ask_gpt, inputs=[query_input, api_key, history, tb_session_id], outputs=[answer_output, sources, history])
+
+ demo.queue(concurrency_count=10)
+ demo.launch(debug=False, share=True)