DHEIVER committed
Commit 1bfd20b · verified · 1 Parent(s): 302b740

Update app.py

Files changed (1)
  1. app.py +62 -223
app.py CHANGED
@@ -12,273 +12,112 @@ import chromadb
 from unidecode import unidecode
 import re
 
-# Lista de modelos LLM disponíveis
 list_llm = [
-    "mistralai/Mistral-7B-Instruct-v0.2",
-    "mistralai/Mixtral-8x7B-Instruct-v0.1",
-    "mistralai/Mistral-7B-Instruct-v0.1",
-    "google/gemma-7b-it",
-    "google/gemma-2b-it",
-    "HuggingFaceH4/zephyr-7b-beta",
-    "HuggingFaceH4/zephyr-7b-gemma-v0.1",
-    "meta-llama/Llama-2-7b-chat-hf",
-    "microsoft/phi-2",
-    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-    "mosaicml/mpt-7b-instruct",
-    "tiiuae/falcon-7b-instruct",
-    "google/flan-t5-xxl"
 ]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
-# Função para carregar documentos PDF e dividir em chunks
-def load_doc(list_file_path, chunk_size, chunk_overlap):
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=chunk_size,
-        chunk_overlap=chunk_overlap
-    )
-    doc_splits = text_splitter.split_documents(pages)
-    return doc_splits
 
-# Função para criar o banco de dados vetorial
-def create_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
-    # Usando PersistentClient para persistir o banco de dados
     new_client = chromadb.PersistentClient(path="./chroma_db")
-    vectordb = Chroma.from_documents(
-        documents=splits,
-        embedding=embedding,
-        client=new_client,
-        collection_name=collection_name,
-    )
-    return vectordb
 
-# Função para inicializar a cadeia de QA com o modelo LLM
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    progress(0.1, desc="Inicializando tokenizer da HF...")
-    progress(0.5, desc="Inicializando Hub da HF...")
-    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-            top_k=top_k,
-            load_in_8bit=True,
-        )
-    elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1", "mosaicml/mpt-7b-instruct"]:
-        raise gr.Error("O modelo LLM é muito grande para ser carregado automaticamente no endpoint de inferência gratuito")
-    elif llm_model == "microsoft/phi-2":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-            top_k=top_k,
-            trust_remote_code=True,
-            torch_dtype="auto",
-        )
-    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            temperature=temperature,
-            max_new_tokens=250,
-            top_k=top_k,
-        )
-    elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
-        raise gr.Error("O modelo Llama-2-7b-chat-hf requer uma assinatura Pro...")
-    else:
-        llm = HuggingFaceEndpoint(
-            repo_id=llm_model,
-            temperature=temperature,
-            max_new_tokens=max_tokens,
-            top_k=top_k,
-        )
-
-    progress(0.75, desc="Definindo memória de buffer...")
-    memory = ConversationBufferMemory(
-        memory_key="chat_history",
-        output_key='answer',
-        return_messages=True
     )
     retriever = vector_db.as_retriever()
-    progress(0.8, desc="Definindo cadeia de recuperação...")
-    qa_chain = ConversationalRetrievalChain.from_llm(
-        llm,
-        retriever=retriever,
-        chain_type="stuff",
-        memory=memory,
-        return_source_documents=True,
-        verbose=False,
-    )
     progress(0.9, desc="Concluído!")
     return qa_chain
 
 # Função para gerar um nome de coleção válido
 def create_collection_name(filepath):
     collection_name = Path(filepath).stem
-    collection_name = collection_name.replace(" ", "-")
-    collection_name = unidecode(collection_name)
-    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
-    collection_name = collection_name[:50]
-    if len(collection_name) < 3:
-        collection_name = collection_name + 'xyz'
-    if not collection_name[0].isalnum():
-        collection_name = 'A' + collection_name[1:]
-    if not collection_name[-1].isalnum():
-        collection_name = collection_name[:-1] + 'Z'
-    print('Caminho do arquivo: ', filepath)
-    print('Nome da coleção: ', collection_name)
-    return collection_name
 
-# Função para inicializar o banco de dados
-def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
     list_file_path = [x.name for x in list_file_obj if x is not None]
     progress(0.1, desc="Criando nome da coleção...")
     collection_name = create_collection_name(list_file_path[0])
-    progress(0.25, desc="Carregando documento...")
-    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
     progress(0.5, desc="Gerando banco de dados vetorial...")
-    vector_db = create_db(doc_splits, collection_name)
-    progress(0.9, desc="Concluído!")
-    return vector_db, collection_name, "Completo!"
-
-# Função para inicializar o modelo LLM
-def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     llm_name = list_llm[llm_option]
-    print("Nome do LLM: ", llm_name)
     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
-    return qa_chain, "Completo!"
-
-# Função para formatar o histórico de conversa
-def format_chat_history(message, chat_history):
-    formatted_chat_history = []
-    for user_message, bot_message in chat_history:
-        formatted_chat_history.append(f"Usuário: {user_message}")
-        formatted_chat_history.append(f"Assistente: {bot_message}")
-    return formatted_chat_history
 
-# Função para realizar a conversa com o chatbot
 def conversation(qa_chain, message, history):
-    formatted_chat_history = format_chat_history(message, history)
     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
-    response_answer = response["answer"]
-    if response_answer.find("Resposta útil:") != -1:
-        response_answer = response_answer.split("Resposta útil:")[-1]
-    response_sources = response["source_documents"]
-    response_source1 = response_sources[0].page_content.strip()
-    response_source2 = response_sources[1].page_content.strip()
-    response_source3 = response_sources[2].page_content.strip()
-    response_source1_page = response_sources[0].metadata["page"] + 1
-    response_source2_page = response_sources[1].metadata["page"] + 1
-    response_source3_page = response_sources[2].metadata["page"] + 1
     new_history = history + [(message, response_answer)]
-    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
 
-# Função para carregar arquivos
 def upload_file(file_obj):
-    list_file_path = []
-    for idx, file in enumerate(file_obj):
-        file_path = file_obj.name
-        list_file_path.append(file_path)
-    return list_file_path
 
 def demo():
     with gr.Blocks(theme="base") as demo:
-        vector_db = gr.State()
-        qa_chain = gr.State()
-        collection_name = gr.State()
-
-        gr.Markdown(
-            """<center><h2>Chatbot baseado em PDF</center></h2>
-            <h3>Faça qualquer pergunta sobre seus documentos PDF</h3>""")
-        gr.Markdown(
-            """<b>Nota:</b> Este assistente de IA, utilizando Langchain e LLMs de código aberto, realiza geração aumentada por recuperação (RAG) a partir de seus documentos PDF. \
-            A interface do usuário mostra explicitamente várias etapas para ajudar a entender o fluxo de trabalho do RAG.
-            Este chatbot leva em consideração perguntas anteriores ao gerar respostas (via memória conversacional), e inclui referências documentais para maior clareza.<br>
-            <br><b>Aviso:</b> Este espaço usa a CPU básica gratuita do Hugging Face. Algumas etapas e modelos LLM utilizados abaixo (pontos finais de inferência gratuitos) podem levar algum tempo para gerar uma resposta.
-            """)
-
         with gr.Tab("Etapa 1 - Carregar PDF"):
-            with gr.Row():
-                document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Carregue seus documentos PDF (único ou múltiplos)")
-            # upload_btn = gr.UploadButton("Carregando documento...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
-
         with gr.Tab("Etapa 2 - Processar documento"):
-            with gr.Row():
-                db_btn = gr.Radio(["ChromaDB"], label="Tipo de banco de dados vetorial", value = "ChromaDB", type="index", info="Escolha o banco de dados vetorial")
-            with gr.Accordion("Opções avançadas - Divisor de texto do documento", open=False):
-                with gr.Row():
-                    slider_chunk_size = gr.Slider(minimum = 100, maximum = 1000, value=600, step=20, label="Tamanho do bloco", info="Tamanho do bloco", interactive=True)
-                with gr.Row():
-                    slider_chunk_overlap = gr.Slider(minimum = 10, maximum = 200, value=40, step=10, label="Sobreposição do bloco", info="Sobreposição do bloco", interactive=True)
-            with gr.Row():
-                db_progress = gr.Textbox(label="Inicialização do banco de dados vetorial", value="Nenhum")
-            with gr.Row():
-                db_btn = gr.Button("Gerar banco de dados vetorial")
-
         with gr.Tab("Etapa 3 - Inicializar cadeia de QA"):
-            with gr.Row():
-                llm_btn = gr.Radio(list_llm_simple, label="Modelos LLM", value = list_llm_simple[0], type="index", info="Escolha seu modelo LLM")
-            with gr.Accordion("Opções avançadas - Modelo LLM", open=False):
-                with gr.Row():
-                    slider_temperature = gr.Slider(minimum = 0.01, maximum = 1.0, value=0.7, step=0.1, label="Temperatura", info="Temperatura do modelo", interactive=True)
-                with gr.Row():
-                    slider_maxtokens = gr.Slider(minimum = 224, maximum = 4096, value=1024, step=32, label="Máximo de Tokens", info="Máximo de tokens do modelo", interactive=True)
-                with gr.Row():
-                    slider_topk = gr.Slider(minimum = 1, maximum = 10, value=3, step=1, label="Amostras top-k", info="Amostras top-k do modelo", interactive=True)
-            with gr.Row():
-                llm_progress = gr.Textbox(value="Nenhum",label="Inicialização da cadeia QA")
-            with gr.Row():
-                qachain_btn = gr.Button("Inicializar cadeia de Pergunta e Resposta")
 
         with gr.Tab("Etapa 4 - Chatbot"):
             chatbot = gr.Chatbot(height=300)
-            with gr.Accordion("Avançado - Referências do documento", open=False):
-                with gr.Row():
-                    doc_source1 = gr.Textbox(label="Referência 1", lines=2, container=True, scale=20)
-                    source1_page = gr.Number(label="Página", scale=1)
-                with gr.Row():
-                    doc_source2 = gr.Textbox(label="Referência 2", lines=2, container=True, scale=20)
-                    source2_page = gr.Number(label="Página", scale=1)
-                with gr.Row():
-                    doc_source3 = gr.Textbox(label="Referência 3", lines=2, container=True, scale=20)
-                    source3_page = gr.Number(label="Página", scale=1)
-            with gr.Row():
-                msg = gr.Textbox(placeholder="Digite a mensagem (exemplo: 'Sobre o que é este documento?')", container=True)
-            with gr.Row():
-                submit_btn = gr.Button("Enviar mensagem")
-                clear_btn = gr.ClearButton([msg, chatbot], value="Limpar conversa")
 
-        # Eventos de pré-processamento
-        #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
-        db_btn.click(initialize_database, inputs=[document, slider_chunk_size, slider_chunk_overlap], outputs=[vector_db, collection_name, db_progress])
-        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
-
-        # Eventos do Chatbot
-        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
-        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
-        clear_btn.click(lambda:[None,"",0,"",0,"",0], inputs=None, outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], queue=False)
-    demo.queue().launch(debug=True)
-
 
 if __name__ == "__main__":
     demo()
 
 from unidecode import unidecode
 import re
 
+# Modelos LLM disponíveis
 list_llm = [
+    "mistralai/Mistral-7B-Instruct-v0.2", "mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mistral-7B-Instruct-v0.1",
+    "google/gemma-7b-it", "google/gemma-2b-it", "HuggingFaceH4/zephyr-7b-beta", "HuggingFaceH4/zephyr-7b-gemma-v0.1",
+    "meta-llama/Llama-2-7b-chat-hf", "microsoft/phi-2", "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "mosaicml/mpt-7b-instruct",
+    "tiiuae/falcon-7b-instruct", "google/flan-t5-xxl"
 ]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 
+# Função de carregamento e divisão de documentos
+def load_and_split_documents(list_file_path, chunk_size, chunk_overlap):
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
+    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+    return text_splitter.split_documents(pages)
 
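For orientation, a minimal usage sketch of the consolidated loader above; the file path is a placeholder and 600/40 are simply the defaults of the chunk sliders defined further down:

    # Illustrative only: "example.pdf" is a placeholder path, not part of the app.
    splits = load_and_split_documents(["example.pdf"], chunk_size=600, chunk_overlap=40)
    print(f"{len(splits)} chunks")
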
+# Função para criar banco de dados vetorial com ChromaDB
+def create_vector_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
     new_client = chromadb.PersistentClient(path="./chroma_db")
+    return Chroma.from_documents(documents=splits, embedding=embedding, client=new_client, collection_name=collection_name)
 
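Because the PersistentClient writes to ./chroma_db, the returned store can be queried directly; a minimal sketch, assuming the splits from the loader sketch above and an illustrative collection name:

    # Illustrative only; the collection name is a placeholder.
    vector_db = create_vector_db(splits, "meu-documento")
    docs = vector_db.similarity_search("Sobre o que é este documento?", k=3)
    print(docs[0].page_content[:200])
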
+# Função para inicializar a cadeia de QA
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+    progress(0.1, desc="Inicializando tokenizer e Hub...")
+    llm = HuggingFaceEndpoint(
+        repo_id=llm_model, temperature=temperature, max_new_tokens=max_tokens, top_k=top_k, load_in_8bit=True
     )
+    progress(0.5, desc="Definindo memória de buffer e cadeia de recuperação...")
+    memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
     retriever = vector_db.as_retriever()
+    qa_chain = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True)
     progress(0.9, desc="Concluído!")
     return qa_chain
 
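Note that the condensed endpoint call now passes load_in_8bit=True to every model, whereas the removed branches used it only for Mixtral-8x7B; a minimal sketch that keeps that distinction (a hypothetical restructuring, not part of the commit):

    # Sketch: forward load_in_8bit only for the model the removed branch used it with.
    endpoint_kwargs = dict(repo_id=llm_model, temperature=temperature,
                           max_new_tokens=max_tokens, top_k=top_k)
    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
        endpoint_kwargs["load_in_8bit"] = True
    llm = HuggingFaceEndpoint(**endpoint_kwargs)
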
 # Função para gerar um nome de coleção válido
 def create_collection_name(filepath):
     collection_name = Path(filepath).stem
+    collection_name = unidecode(collection_name.replace(" ", "-"))
+    return re.sub('[^A-Za-z0-9]+', '-', collection_name)[:50]
 
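The shortened helper drops the guards the removed version applied (minimum length and an alphanumeric first and last character), which Chroma enforces on collection names; a sketch restoring them with the same logic as the removed lines, under a hypothetical name:

    # Hypothetical helper reproducing the removed guards; not part of the commit.
    def safe_collection_name(filepath):
        name = re.sub('[^A-Za-z0-9]+', '-', unidecode(Path(filepath).stem.replace(" ", "-")))[:50]
        if len(name) < 3:
            name += 'xyz'
        if not name[0].isalnum():
            name = 'A' + name[1:]
        if not name[-1].isalnum():
            name = name[:-1] + 'Z'
        return name
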
+# Função para inicializar o banco de dados e o modelo LLM
+def initialize_database_and_llm(list_file_obj, chunk_size, chunk_overlap, llm_option, llm_temperature, max_tokens, top_k, progress=gr.Progress()):
     list_file_path = [x.name for x in list_file_obj if x is not None]
     progress(0.1, desc="Criando nome da coleção...")
     collection_name = create_collection_name(list_file_path[0])
+    progress(0.25, desc="Carregando e dividindo documentos...")
+    doc_splits = load_and_split_documents(list_file_path, chunk_size, chunk_overlap)
     progress(0.5, desc="Gerando banco de dados vetorial...")
+    vector_db = create_vector_db(doc_splits, collection_name)
+    progress(0.75, desc="Inicializando modelo LLM...")
     llm_name = list_llm[llm_option]
     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
+    progress(0.9, desc="Concluído!")
+    return vector_db, collection_name, qa_chain
 
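Because the combined initializer indexes list_llm with llm_option, the Radio defined in the Etapa 3 tab below must yield an integer index rather than the selected label; the removed UI did this with type="index". A minimal sketch of a compatible component, with keyword values mirroring the removed version:

    # Sketch: a Radio that returns the selected index, so list_llm[llm_option] works.
    llm_btn = gr.Radio(list_llm_simple, label="Modelos LLM",
                       value=list_llm_simple[0], type="index")
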
+# Função de interação com o chatbot
 def conversation(qa_chain, message, history):
+    formatted_chat_history = [f"Usuário: {user_message}\nAssistente: {bot_message}" for user_message, bot_message in history]
     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
+    response_answer = response["answer"].split("Resposta útil:")[-1]
+    response_sources = [doc.page_content.strip() for doc in response["source_documents"]]
+    response_pages = [doc.metadata["page"] + 1 for doc in response["source_documents"]]
     new_history = history + [(message, response_answer)]
+    return qa_chain, gr.update(value=""), new_history, *response_sources, *response_pages
 
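The unpacked return above emits the three source texts followed by the three page numbers, while the reference widgets are read as source/page pairs (and a retrieval with fewer documents would leave the outputs short); a sketch of an interleaved, padded return under those assumptions:

    # Sketch: pair each source with its page and pad to three references.
    pairs = [(doc.page_content.strip(), doc.metadata["page"] + 1)
             for doc in response["source_documents"][:3]]
    pairs += [("", 0)] * (3 - len(pairs))
    flat = [item for pair in pairs for item in pair]
    return qa_chain, gr.update(value=""), new_history, *flat
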
+# Função de carregamento de arquivos
 def upload_file(file_obj):
+    return [file_obj.name for file_obj in file_obj if file_obj is not None]
 
+# Interface Gradio
 def demo():
     with gr.Blocks(theme="base") as demo:
+        vector_db, qa_chain, collection_name = gr.State(), gr.State(), gr.State()
+        gr.Markdown("<center><h2>Chatbot baseado em PDF</center></h2><h3>Faça qualquer pergunta sobre seus documentos PDF</h3>")
+
         with gr.Tab("Etapa 1 - Carregar PDF"):
+            document = gr.Files(height=100, file_count="multiple", file_types=["pdf"])
+
         with gr.Tab("Etapa 2 - Processar documento"):
+            db_btn = gr.Button("Gerar banco de dados vetorial")
+            slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Tamanho do bloco")
+            slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Sobreposição do bloco")
+            db_progress = gr.Textbox(label="Inicialização do banco de dados vetorial")
+
         with gr.Tab("Etapa 3 - Inicializar cadeia de QA"):
+            llm_btn = gr.Radio(list_llm_simple, label="Modelos LLM")
+            slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperatura")
+            slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Máximo de Tokens")
+            slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Amostras top-k")
+            llm_progress = gr.Textbox(value="Nenhum", label="Inicialização da cadeia QA")
+            qachain_btn = gr.Button("Inicializar cadeia de Pergunta e Resposta")
 
         with gr.Tab("Etapa 4 - Chatbot"):
             chatbot = gr.Chatbot(height=300)
+            doc_source1, doc_source2, doc_source3 = gr.Textbox(label="Referência 1"), gr.Textbox(label="Referência 2"), gr.Textbox(label="Referência 3")
+            source1_page, source2_page, source3_page = gr.Number(label="Página 1"), gr.Number(label="Página 2"), gr.Number(label="Página 3")
 
+        # Implementação de lógica de interação de conversa
+        qachain_btn.click(initialize_database_and_llm, inputs=[document, slider_chunk_size, slider_chunk_overlap, llm_btn, slider_temperature, slider_maxtokens, slider_topk], outputs=[vector_db, collection_name, qa_chain])
+        chatbot.submit(conversation, inputs=[qa_chain, chatbot.input, chatbot.history], outputs=[qa_chain, gr.update(value=""), chatbot.history, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
+
+    demo.launch()
 
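gr.Chatbot exposes no submit event and no input or history attributes, and gr.update(value="") cannot appear in an outputs list, so the chat wiring above will not run as written; a minimal sketch of working wiring in the style of the removed version, assuming a message textbox and a submit button are added to the Etapa 4 tab:

    # Sketch only: msg and submit_btn are hypothetical components for this wiring.
    msg = gr.Textbox(placeholder="Digite a mensagem")
    submit_btn = gr.Button("Enviar mensagem")
    for trigger in (msg.submit, submit_btn.click):
        trigger(conversation,
                inputs=[qa_chain, msg, chatbot],
                outputs=[qa_chain, msg, chatbot,
                         doc_source1, source1_page,
                         doc_source2, source2_page,
                         doc_source3, source3_page],
                queue=False)
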
+# Executando o app
 if __name__ == "__main__":
     demo()