DHEIVER commited on
Commit
36fef6e
·
verified ·
1 Parent(s): 4f1f3fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -49
app.py CHANGED
@@ -1,13 +1,11 @@
1
  import os
2
- from typing import Optional, Tuple, Dict
3
  import gradio as gr
4
  from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
  from langchain_community.embeddings import HuggingFaceEmbeddings
7
  from langchain_community.vectorstores import FAISS
8
- from langchain_community.llms import HuggingFacePipeline
9
- from langchain.chains import RetrievalQA
10
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
11
  import torch
12
  import tempfile
13
  import time
@@ -19,6 +17,7 @@ DOCS_DIR = "documents"
19
 
20
  class RAGSystem:
21
  def __init__(self):
 
22
  self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
23
  self.model = AutoModelForSeq2SeqLM.from_pretrained(
24
  LLM_MODEL,
@@ -26,21 +25,13 @@ class RAGSystem:
26
  torch_dtype=torch.float32
27
  )
28
 
29
- pipe = pipeline(
30
- "text2text-generation",
31
- model=self.model,
32
- tokenizer=self.tokenizer,
33
- max_length=512,
34
- temperature=0.3, # Respostas mais precisas
35
- top_p=0.9, # Diversidade controlada
36
- repetition_penalty=1.2 # Evita repetições
37
- )
38
-
39
- self.llm = HuggingFacePipeline(pipeline=pipe)
40
  self.embeddings = HuggingFaceEmbeddings(
41
  model_name=EMBEDDING_MODEL,
42
  model_kwargs={'device': 'cpu'}
43
  )
 
 
44
  self.base_db = self.load_base_knowledge()
45
 
46
  def load_base_knowledge(self) -> Optional[FAISS]:
@@ -49,6 +40,7 @@ class RAGSystem:
49
  os.makedirs(DOCS_DIR)
50
  return None
51
 
 
52
  loader = DirectoryLoader(
53
  DOCS_DIR,
54
  glob="**/*.pdf",
@@ -59,6 +51,7 @@ class RAGSystem:
59
  if not documents:
60
  return None
61
 
 
62
  text_splitter = RecursiveCharacterTextSplitter(
63
  chunk_size=500,
64
  chunk_overlap=100,
@@ -67,6 +60,7 @@ class RAGSystem:
67
  )
68
  texts = text_splitter.split_documents(documents)
69
 
 
70
  return FAISS.from_documents(texts, self.embeddings)
71
 
72
  except Exception as e:
@@ -75,10 +69,12 @@ class RAGSystem:
75
 
76
  def process_pdf(self, file_content: bytes) -> Optional[FAISS]:
77
  try:
 
78
  with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
79
  tmp_file.write(file_content)
80
  tmp_path = tmp_file.name
81
 
 
82
  loader = PyPDFLoader(tmp_path)
83
  documents = loader.load()
84
  os.unlink(tmp_path)
@@ -86,6 +82,7 @@ class RAGSystem:
86
  if not documents:
87
  return None
88
 
 
89
  text_splitter = RecursiveCharacterTextSplitter(
90
  chunk_size=500,
91
  chunk_overlap=100,
@@ -94,8 +91,10 @@ class RAGSystem:
94
  )
95
  texts = text_splitter.split_documents(documents)
96
 
 
97
  db = FAISS.from_documents(texts, self.embeddings)
98
 
 
99
  if self.base_db is not None:
100
  db.merge_from(self.base_db)
101
 
@@ -105,26 +104,6 @@ class RAGSystem:
105
  print(f"Erro ao processar PDF: {str(e)}")
106
  return None
107
 
108
- def format_response(self, raw_response: str, source_type: str, context_found: bool) -> str:
109
- """Formata a resposta para um formato padronizado e claro"""
110
- if not context_found:
111
- return "🔍 Não foram encontradas informações suficientes nos documentos para responder esta pergunta."
112
-
113
- prefix = ""
114
- if source_type == "pdf":
115
- prefix = "📄 [Resposta baseada no PDF enviado]\n\n"
116
- elif source_type == "base":
117
- prefix = "📚 [Resposta baseada na base de documentos]\n\n"
118
- elif source_type == "both":
119
- prefix = "📚📄 [Resposta baseada em ambas as fontes]\n\n"
120
-
121
- # Limpa e formata a resposta
122
- response = raw_response.strip()
123
- if not response:
124
- return "🔍 Não foi possível gerar uma resposta adequada com as informações disponíveis."
125
-
126
- return f"{prefix}{response}"
127
-
128
  def generate_response(self, file_obj, query: str, progress=gr.Progress()) -> Tuple[str, str, str]:
129
  """Retorna (resposta, status, tempo_decorrido)"""
130
  if not query.strip():
@@ -152,16 +131,14 @@ class RAGSystem:
152
 
153
  progress(0.4, desc="Buscando informações relevantes...")
154
 
155
- # Configuração do RAG
156
  retriever = db.as_retriever(
157
  search_kwargs={
158
- "k": 6, # Aumenta o número de trechos recuperados
159
  "fetch_k": 10,
160
  "score_threshold": 0.5 # Limiar de relevância
161
  }
162
  )
163
-
164
- # Recupera o contexto
165
  context_docs = retriever.get_relevant_documents(query)
166
 
167
  # Verifica se o contexto é relevante
@@ -173,7 +150,7 @@ class RAGSystem:
173
 
174
  progress(0.6, desc="Gerando resposta...")
175
 
176
- # Prompt mais estruturado
177
  prompt = f"""Instruções:
178
  1. Analise cuidadosamente o contexto fornecido.
179
  2. Responda à seguinte pergunta em português de forma clara e direta: {query}
@@ -187,20 +164,21 @@ class RAGSystem:
187
 
188
  Pergunta: {query}"""
189
 
190
- # Gera resposta
191
- result = self.llm(prompt)
192
-
193
- # Formata a resposta
194
- formatted_response = self.format_response(
195
- result,
196
- source_type,
197
- context_found=True
198
  )
 
199
 
200
  elapsed_time = f"{time.time() - start_time:.1f}s"
201
  progress(1.0, desc="Concluído!")
202
 
203
- return formatted_response, "✅ Sucesso", elapsed_time
204
 
205
  except Exception as e:
206
  elapsed_time = f"{time.time() - start_time:.1f}s"
 
1
  import os
2
+ from typing import Optional, Tuple
3
  import gradio as gr
4
  from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
  from langchain_community.embeddings import HuggingFaceEmbeddings
7
  from langchain_community.vectorstores import FAISS
8
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
 
9
  import torch
10
  import tempfile
11
  import time
 
17
 
18
  class RAGSystem:
19
  def __init__(self):
20
+ # Carrega o modelo e o tokenizador
21
  self.tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL)
22
  self.model = AutoModelForSeq2SeqLM.from_pretrained(
23
  LLM_MODEL,
 
25
  torch_dtype=torch.float32
26
  )
27
 
28
+ # Configurações de embedding
 
 
 
 
 
 
 
 
 
 
29
  self.embeddings = HuggingFaceEmbeddings(
30
  model_name=EMBEDDING_MODEL,
31
  model_kwargs={'device': 'cpu'}
32
  )
33
+
34
+ # Carrega a base de conhecimento
35
  self.base_db = self.load_base_knowledge()
36
 
37
  def load_base_knowledge(self) -> Optional[FAISS]:
 
40
  os.makedirs(DOCS_DIR)
41
  return None
42
 
43
+ # Carrega documentos da pasta
44
  loader = DirectoryLoader(
45
  DOCS_DIR,
46
  glob="**/*.pdf",
 
51
  if not documents:
52
  return None
53
 
54
+ # Divide os documentos em trechos menores
55
  text_splitter = RecursiveCharacterTextSplitter(
56
  chunk_size=500,
57
  chunk_overlap=100,
 
60
  )
61
  texts = text_splitter.split_documents(documents)
62
 
63
+ # Cria o banco de dados de embeddings
64
  return FAISS.from_documents(texts, self.embeddings)
65
 
66
  except Exception as e:
 
69
 
70
  def process_pdf(self, file_content: bytes) -> Optional[FAISS]:
71
  try:
72
+ # Salva o PDF temporariamente
73
  with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp_file:
74
  tmp_file.write(file_content)
75
  tmp_path = tmp_file.name
76
 
77
+ # Carrega o PDF
78
  loader = PyPDFLoader(tmp_path)
79
  documents = loader.load()
80
  os.unlink(tmp_path)
 
82
  if not documents:
83
  return None
84
 
85
+ # Divide o PDF em trechos menores
86
  text_splitter = RecursiveCharacterTextSplitter(
87
  chunk_size=500,
88
  chunk_overlap=100,
 
91
  )
92
  texts = text_splitter.split_documents(documents)
93
 
94
+ # Cria o banco de dados de embeddings
95
  db = FAISS.from_documents(texts, self.embeddings)
96
 
97
+ # Combina com a base de conhecimento existente, se houver
98
  if self.base_db is not None:
99
  db.merge_from(self.base_db)
100
 
 
104
  print(f"Erro ao processar PDF: {str(e)}")
105
  return None
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  def generate_response(self, file_obj, query: str, progress=gr.Progress()) -> Tuple[str, str, str]:
108
  """Retorna (resposta, status, tempo_decorrido)"""
109
  if not query.strip():
 
131
 
132
  progress(0.4, desc="Buscando informações relevantes...")
133
 
134
+ # Recupera os trechos relevantes
135
  retriever = db.as_retriever(
136
  search_kwargs={
137
+ "k": 6, # Número de trechos recuperados
138
  "fetch_k": 10,
139
  "score_threshold": 0.5 # Limiar de relevância
140
  }
141
  )
 
 
142
  context_docs = retriever.get_relevant_documents(query)
143
 
144
  # Verifica se o contexto é relevante
 
150
 
151
  progress(0.6, desc="Gerando resposta...")
152
 
153
+ # Cria o prompt
154
  prompt = f"""Instruções:
155
  1. Analise cuidadosamente o contexto fornecido.
156
  2. Responda à seguinte pergunta em português de forma clara e direta: {query}
 
164
 
165
  Pergunta: {query}"""
166
 
167
+ # Gera a resposta usando o modelo diretamente
168
+ inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
169
+ outputs = self.model.generate(
170
+ inputs["input_ids"],
171
+ max_length=512,
172
+ temperature=0.3,
173
+ top_p=0.9,
174
+ repetition_penalty=1.2
175
  )
176
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
177
 
178
  elapsed_time = f"{time.time() - start_time:.1f}s"
179
  progress(1.0, desc="Concluído!")
180
 
181
+ return response, "✅ Sucesso", elapsed_time
182
 
183
  except Exception as e:
184
  elapsed_time = f"{time.time() - start_time:.1f}s"