DHEIVER committed on
Commit 7e5d1ad · verified · 1 parent: 791cd44

Update app.py

Files changed (1): app.py (+3 -187)
app.py CHANGED
@@ -1,188 +1,6 @@
+ # gradio_interface.py
  import gradio as gr
+ from rag_functions import *
- import os
- from langchain_community.document_loaders import PyPDFLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_community.vectorstores import Chroma
- from langchain.chains import ConversationalRetrievalChain
- from langchain_community.embeddings import HuggingFaceEmbeddings
- from langchain_community.llms import HuggingFaceEndpoint
- from langchain.memory import ConversationBufferMemory
- from pathlib import Path
- import chromadb
- from unidecode import unidecode
- import re
-
- # List of available LLM models
- list_llm = [
-     "mistralai/Mistral-7B-Instruct-v0.2",
-     "mistralai/Mixtral-8x7B-Instruct-v0.1",
-     "mistralai/Mistral-7B-Instruct-v0.1",
-     "google/gemma-7b-it",
-     "google/gemma-2b-it",
-     "HuggingFaceH4/zephyr-7b-beta",
-     "HuggingFaceH4/zephyr-7b-gemma-v0.1",
-     "meta-llama/Llama-2-7b-chat-hf",
-     "microsoft/phi-2",
-     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-     "mosaicml/mpt-7b-instruct",
-     "tiiuae/falcon-7b-instruct",
-     "google/flan-t5-xxl"
- ]
- list_llm_simple = [os.path.basename(llm) for llm in list_llm]
-
- # Load PDF documents and split them into chunks
- def load_doc(list_file_path, chunk_size, chunk_overlap):
-     loaders = [PyPDFLoader(x) for x in list_file_path]
-     pages = []
-     for loader in loaders:
-         pages.extend(loader.load())
-     text_splitter = RecursiveCharacterTextSplitter(
-         chunk_size=chunk_size,
-         chunk_overlap=chunk_overlap
-     )
-     doc_splits = text_splitter.split_documents(pages)
-     return doc_splits
-
- # Create the vector database
- def create_db(splits, collection_name):
-     embedding = HuggingFaceEmbeddings()
-     # Use PersistentClient so the database is persisted on disk
-     new_client = chromadb.PersistentClient(path="./chroma_db")
-     vectordb = Chroma.from_documents(
-         documents=splits,
-         embedding=embedding,
-         client=new_client,
-         collection_name=collection_name,
-     )
-     return vectordb
-
- # Initialize the QA chain with the LLM model
- def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-     progress(0.1, desc="Initializing HF tokenizer...")
-     progress(0.5, desc="Initializing HF Hub...")
-     if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-         llm = HuggingFaceEndpoint(
-             repo_id=llm_model,
-             temperature=temperature,
-             max_new_tokens=max_tokens,
-             top_k=top_k,
-             load_in_8bit=True,
-         )
-     elif llm_model in ["HuggingFaceH4/zephyr-7b-gemma-v0.1", "mosaicml/mpt-7b-instruct"]:
-         raise gr.Error("This LLM model is too large to be loaded automatically on the free inference endpoint")
-     elif llm_model == "microsoft/phi-2":
-         llm = HuggingFaceEndpoint(
-             repo_id=llm_model,
-             temperature=temperature,
-             max_new_tokens=max_tokens,
-             top_k=top_k,
-             trust_remote_code=True,
-             torch_dtype="auto",
-         )
-     elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
-         llm = HuggingFaceEndpoint(
-             repo_id=llm_model,
-             temperature=temperature,
-             max_new_tokens=250,
-             top_k=top_k,
-         )
-     elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
-         raise gr.Error("The Llama-2-7b-chat-hf model requires a Pro subscription...")
-     else:
-         llm = HuggingFaceEndpoint(
-             repo_id=llm_model,
-             temperature=temperature,
-             max_new_tokens=max_tokens,
-             top_k=top_k,
-         )
-
-     progress(0.75, desc="Setting up buffer memory...")
-     memory = ConversationBufferMemory(
-         memory_key="chat_history",
-         output_key='answer',
-         return_messages=True
-     )
-     retriever = vector_db.as_retriever()
-     progress(0.8, desc="Setting up retrieval chain...")
-     qa_chain = ConversationalRetrievalChain.from_llm(
-         llm,
-         retriever=retriever,
-         chain_type="stuff",
-         memory=memory,
-         return_source_documents=True,
-         verbose=False,
-     )
-     progress(0.9, desc="Done!")
-     return qa_chain
-
- # Generate a valid collection name
- # (Chroma requires 3-63 characters, starting and ending with an alphanumeric)
- def create_collection_name(filepath):
-     collection_name = Path(filepath).stem
-     collection_name = collection_name.replace(" ", "-")
-     collection_name = unidecode(collection_name)
-     collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
-     collection_name = collection_name[:50]
-     if len(collection_name) < 3:
-         collection_name = collection_name + 'xyz'
-     if not collection_name[0].isalnum():
-         collection_name = 'A' + collection_name[1:]
-     if not collection_name[-1].isalnum():
-         collection_name = collection_name[:-1] + 'Z'
-     print('File path: ', filepath)
-     print('Collection name: ', collection_name)
-     return collection_name
-
- # Initialize the database
- def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
-     list_file_path = [x.name for x in list_file_obj if x is not None]
-     progress(0.1, desc="Creating collection name...")
-     collection_name = create_collection_name(list_file_path[0])
-     progress(0.25, desc="Loading document...")
-     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
-     progress(0.5, desc="Generating vector database...")
-     vector_db = create_db(doc_splits, collection_name)
-     progress(0.9, desc="Done!")
-     return vector_db, collection_name, "Complete!"
-
- # Initialize the LLM model
- def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-     llm_name = list_llm[llm_option]
-     print("LLM name: ", llm_name)
-     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, progress)
-     return qa_chain, "Complete!"
-
- # Format the chat history for the chain
- def format_chat_history(message, chat_history):
-     formatted_chat_history = []
-     for user_message, bot_message in chat_history:
-         formatted_chat_history.append(f"User: {user_message}")
-         formatted_chat_history.append(f"Assistant: {bot_message}")
-     return formatted_chat_history
-
- # Run one turn of the chatbot conversation
- def conversation(qa_chain, message, history):
-     formatted_chat_history = format_chat_history(message, history)
-     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
-     response_answer = response["answer"]
-     if response_answer.find("Helpful Answer:") != -1:
-         response_answer = response_answer.split("Helpful Answer:")[-1]
-     response_sources = response["source_documents"]
-     response_source1 = response_sources[0].page_content.strip()
-     response_source2 = response_sources[1].page_content.strip()
-     response_source3 = response_sources[2].page_content.strip()
-     response_source1_page = response_sources[0].metadata["page"] + 1
-     response_source2_page = response_sources[1].metadata["page"] + 1
-     response_source3_page = response_sources[2].metadata["page"] + 1
-     new_history = history + [(message, response_answer)]
-     return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
-
- # Collect the paths of uploaded files
- def upload_file(file_obj):
-     list_file_path = []
-     for idx, file in enumerate(file_obj):
-         file_path = file.name  # each uploaded file exposes its temp path via .name
-         list_file_path.append(file_path)
-     return list_file_path
 
  def demo():
      with gr.Blocks(theme="base") as demo:
@@ -203,7 +21,6 @@ def demo():
          with gr.Tab("Step 1 - Upload PDF"):
              with gr.Row():
                  document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
-                 # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
 
          with gr.Tab("Step 2 - Process document"):
              with gr.Row():
@@ -253,7 +70,6 @@ def demo():
          clear_btn = gr.ClearButton([msg, chatbot], value="Clear conversation")
 
          # Preprocessing events
-         # upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
          db_btn.click(initialize_database, \
                       inputs=[document, slider_chunk_size, slider_chunk_overlap], \
                       outputs=[vector_db, collection_name, db_progress])
@@ -281,4 +97,4 @@ def demo():
 
 
  if __name__ == "__main__":
-     demo()
+     demo()
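
The wildcard import above only resolves if the deleted helpers now live in a sibling module named rag_functions (the name is implied by the new import line; the module itself is not part of this commit). A minimal sketch of the surface that module would need to expose, assuming the removed code was moved over unchanged; stub bodies stand in for the functions deleted above:

# rag_functions.py -- hypothetical companion module, NOT part of this commit.
# Assumes the code deleted from app.py above was moved here unchanged; only
# the public surface that app.py consumes is sketched.
import os
import gradio as gr

# Same model list that was removed from app.py.
list_llm = [
    "mistralai/Mistral-7B-Instruct-v0.2",
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.1",
    "google/gemma-7b-it",
    "google/gemma-2b-it",
    "HuggingFaceH4/zephyr-7b-beta",
    "HuggingFaceH4/zephyr-7b-gemma-v0.1",
    "meta-llama/Llama-2-7b-chat-hf",
    "microsoft/phi-2",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "mosaicml/mpt-7b-instruct",
    "tiiuae/falcon-7b-instruct",
    "google/flan-t5-xxl",
]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

# Stubs below stand in for the bodies deleted from app.py above.
def load_doc(list_file_path, chunk_size, chunk_overlap): ...
def create_db(splits, collection_name): ...
def create_collection_name(filepath): ...
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()): ...
def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()): ...
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()): ...
def format_chat_history(message, chat_history): ...
def conversation(qa_chain, message, history): ...
def upload_file(file_obj): ...

Splitting the files this way lets the Gradio layout in app.py be edited without touching the LangChain/Chroma plumbing.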
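One caveat with `from rag_functions import *`: it keeps the diff small but hides which names app.py actually depends on. A hypothetical explicit-import variant (not what this commit does); the exact name set is an assumption based on the handlers and widgets visible in this diff:

# Hypothetical alternative first lines for app.py.
from rag_functions import (
    list_llm,             # full model ids, used when resolving the radio choice
    list_llm_simple,      # short names shown in the model-selection radio
    initialize_database,  # wired to db_btn.click above
    initialize_LLM,       # model setup step
    conversation,         # chat submit handler
)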