Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -102,97 +102,68 @@ uploaded_documents = load_documents()
|
|
102 |
from langchain.vectorstores import FAISS
|
103 |
import faiss
|
104 |
|
105 |
-
def
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
for file in files:
|
118 |
logging.info(f"Processing file: {file.name}")
|
119 |
try:
|
120 |
-
|
121 |
-
if not
|
122 |
logging.warning(f"No chunks loaded from {file.name}")
|
123 |
continue
|
124 |
-
logging.info(f"Loaded {len(
|
125 |
-
|
126 |
-
for chunk in data:
|
127 |
-
if chunk.page_content not in seen_contents:
|
128 |
-
all_data.append(chunk)
|
129 |
-
seen_contents.add(chunk.page_content)
|
130 |
-
else:
|
131 |
-
logging.warning(f"Duplicate content detected in {file.name}, skipping...")
|
132 |
-
|
133 |
-
if not any(doc["name"] == file.name for doc in uploaded_documents):
|
134 |
-
uploaded_documents.append({"name": file.name, "selected": True})
|
135 |
-
logging.info(f"Added new document to uploaded_documents: {file.name}")
|
136 |
-
else:
|
137 |
-
logging.info(f"Document already exists in uploaded_documents: {file.name}")
|
138 |
except Exception as e:
|
139 |
logging.error(f"Error processing file {file.name}: {str(e)}")
|
140 |
|
141 |
-
if not
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
try:
|
146 |
-
|
147 |
-
|
148 |
-
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
|
149 |
-
initial_size = database.index.ntotal
|
150 |
-
database.add_documents(all_data)
|
151 |
-
final_size = database.index.ntotal
|
152 |
-
logging.info(f"FAISS database updated. Initial size: {initial_size}, Final size: {final_size}")
|
153 |
-
else:
|
154 |
-
logging.info("Creating new FAISS database")
|
155 |
-
database = FAISS.from_documents(all_data, embed)
|
156 |
-
logging.info(f"New FAISS database created with {database.index.ntotal} vectors")
|
157 |
-
|
158 |
-
database.save_local("faiss_database")
|
159 |
-
logging.info("FAISS database saved")
|
160 |
-
|
161 |
-
# Check the database after updating
|
162 |
-
check_faiss_database()
|
163 |
-
|
164 |
-
# Analyze document similarity
|
165 |
-
analyze_document_similarity()
|
166 |
-
|
167 |
except Exception as e:
|
168 |
logging.error(f"Error updating FAISS database: {str(e)}")
|
169 |
-
return f"Error updating vector store: {str(e)}"
|
170 |
-
|
171 |
-
save_documents(uploaded_documents)
|
172 |
-
logging.info(f"Updated documents saved. Total documents: {len(uploaded_documents)}")
|
173 |
-
|
174 |
-
return f"Vector store updated successfully. Processed {len(all_data)} chunks from {len(files)} files using {parser}.", display_documents()
|
175 |
-
|
176 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
177 |
-
|
178 |
-
def analyze_document_similarity():
|
179 |
-
embed = get_embeddings()
|
180 |
-
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
|
181 |
-
|
182 |
-
docs = list(database.docstore.docs.values())
|
183 |
-
embeddings = [database.embedding_function(doc.page_content) for doc in docs]
|
184 |
-
|
185 |
-
similarity_matrix = cosine_similarity(embeddings)
|
186 |
-
|
187 |
-
for i in range(len(docs)):
|
188 |
-
for j in range(i+1, len(docs)):
|
189 |
-
similarity = similarity_matrix[i][j]
|
190 |
-
logging.info(f"Similarity between {docs[i].metadata['source']} and {docs[j].metadata['source']}: {similarity}")
|
191 |
-
if similarity > 0.9: # Adjust this threshold as needed
|
192 |
-
logging.warning(f"High similarity detected between {docs[i].metadata['source']} and {docs[j].metadata['source']}")
|
193 |
-
|
194 |
-
# Call this after updating the vector store
|
195 |
-
analyze_document_similarity()
|
196 |
|
197 |
def delete_documents(selected_docs):
|
198 |
global uploaded_documents
|
@@ -522,17 +493,17 @@ def get_response_from_pdf(query, model, selected_docs, num_calls=3, temperature=
|
|
522 |
return
|
523 |
|
524 |
try:
|
525 |
-
retriever = database.as_retriever(search_kwargs={"k":
|
526 |
logging.info(f"Retrieving relevant documents for query: {query}")
|
527 |
-
|
528 |
-
logging.info(f"Number of relevant documents retrieved: {len(
|
529 |
|
530 |
-
|
531 |
-
|
532 |
-
logging.info(f"
|
533 |
-
|
534 |
-
# Filter relevant_docs based on selected documents
|
535 |
-
filtered_docs = [doc for doc in
|
536 |
logging.info(f"Number of filtered documents: {len(filtered_docs)}")
|
537 |
|
538 |
if not filtered_docs:
|
@@ -541,47 +512,37 @@ def get_response_from_pdf(query, model, selected_docs, num_calls=3, temperature=
|
|
541 |
return
|
542 |
|
543 |
for i, doc in enumerate(filtered_docs):
|
544 |
-
logging.info(f"
|
545 |
-
logging.info(f"
|
546 |
|
547 |
-
context_str = "\n
|
548 |
logging.info(f"Total context length: {len(context_str)}")
|
549 |
|
550 |
-
prompt = f"""You are analyzing multiple financial documents. The following documents have been selected: {', '.join(selected_docs)}
|
551 |
-
|
552 |
-
Using the following context from the selected PDF documents:
|
553 |
-
|
554 |
-
{context_str}
|
555 |
-
|
556 |
-
Please provide a detailed and complete response that answers the following user question, making sure to consider information from all selected documents: '{query}'
|
557 |
-
|
558 |
-
If the information is not found in the provided context, please state that clearly."""
|
559 |
-
|
560 |
if model == "@cf/meta/llama-3.1-8b-instruct":
|
561 |
logging.info("Using Cloudflare API")
|
562 |
-
for response in get_response_from_cloudflare(prompt=
|
563 |
yield response
|
564 |
else:
|
565 |
logging.info("Using Hugging Face API")
|
|
|
|
|
|
|
|
|
566 |
client = InferenceClient(model, token=huggingface_token)
|
567 |
|
568 |
response = ""
|
569 |
for i in range(num_calls):
|
570 |
logging.info(f"API call {i+1}/{num_calls}")
|
571 |
-
|
572 |
-
|
573 |
-
|
574 |
-
|
575 |
-
|
576 |
-
|
577 |
-
|
578 |
-
|
579 |
-
|
580 |
-
|
581 |
-
yield response # Yield partial response
|
582 |
-
except Exception as e:
|
583 |
-
logging.error(f"Error in API call {i+1}: {str(e)}")
|
584 |
-
yield f"Error in API call {i+1}: {str(e)}. Attempting next call..."
|
585 |
|
586 |
logging.info("Finished generating response")
|
587 |
|
|
|
102 |
from langchain.vectorstores import FAISS
|
103 |
import faiss
|
104 |
|
105 |
def add_documents_to_faiss(documents: List[Document], embeddings):
    """Persist *documents* into the on-disk FAISS index at "faiss_database".

    Loads and extends the existing index when one is present on disk,
    otherwise builds a fresh index from the documents. The index is saved
    back to "faiss_database" in either case.

    Args:
        documents: Chunked documents to embed and store.
        embeddings: Embedding function/object compatible with FAISS.

    Returns:
        The loaded-or-created FAISS database instance.
    """
    logging.info(f"Adding {len(documents)} documents to FAISS database")

    if not os.path.exists("faiss_database"):
        # No index on disk yet — build one from scratch.
        db = FAISS.from_documents(documents, embeddings)
        logging.info(f"Created new FAISS database with {db.index.ntotal} vectors")
    else:
        # Extend the existing index; deserialization flag is required by
        # recent langchain versions for pickle-backed FAISS stores.
        db = FAISS.load_local("faiss_database", embeddings, allow_dangerous_deserialization=True)
        logging.info(f"Loaded existing FAISS database with {db.index.ntotal} vectors")
        initial_size = db.index.ntotal
        db.add_documents(documents)
        final_size = db.index.ntotal
        logging.info(f"FAISS database updated. Initial size: {initial_size}, Final size: {final_size}")

    db.save_local("faiss_database")
    logging.info("FAISS database saved")
    return db
def get_relevant_documents(query: str, selected_docs: List[str], embeddings, k: int = 20) -> List[Document]:
    """Run a similarity search and keep only chunks from the selected documents.

    Args:
        query: Free-text user query.
        selected_docs: Source names (file paths) the user has selected.
        embeddings: Embedding function/object compatible with FAISS.
        k: Number of candidates to pull before filtering (was hard-coded;
           kept at 20 by default for backward compatibility).

    Returns:
        Matching documents whose metadata "source" is in *selected_docs*;
        empty list when no FAISS database exists on disk.
    """
    if not os.path.exists("faiss_database"):
        logging.warning("No FAISS database found")
        return []

    db = FAISS.load_local("faiss_database", embeddings, allow_dangerous_deserialization=True)
    logging.info(f"Loaded FAISS database with {db.index.ntotal} vectors")

    # Retrieve a generous candidate set first, then filter — filtering
    # inside the search would shrink recall for the selected sources.
    all_docs = db.similarity_search(query, k=k)
    logging.info(f"Retrieved {len(all_docs)} documents from FAISS")

    # Log all retrieved documents. Use .get() — a chunk indexed without a
    # "source" metadata key previously raised KeyError here.
    for i, doc in enumerate(all_docs):
        logging.info(f"Retrieved document {i+1} source: {doc.metadata.get('source')}")

    # Filter documents based on selected_docs; chunks with no "source"
    # metadata can never match a selection, so they are excluded.
    filtered_docs = [doc for doc in all_docs if doc.metadata.get("source") in selected_docs]
    logging.info(f"Filtered to {len(filtered_docs)} documents based on selection")

    return filtered_docs
def update_vectors(files: List[NamedTemporaryFile], parser: str, embeddings) -> str:
    """Load uploaded files, chunk them, and add the chunks to the FAISS store.

    Per-file parse errors are logged and skipped so one bad upload does not
    abort the whole batch.

    Args:
        files: Uploaded temporary files to ingest.
        parser: Name of the document parser to use for loading.
        embeddings: Embedding function/object compatible with FAISS.

    Returns:
        A human-readable status string (success or error description).
    """
    all_documents = []
    for file in files:
        logging.info(f"Processing file: {file.name}")
        try:
            documents = load_document(file, parser)
            if not documents:
                logging.warning(f"No chunks loaded from {file.name}")
                continue
            logging.info(f"Loaded {len(documents)} chunks from {file.name}")
            all_documents.extend(documents)
        except Exception as e:
            # Best-effort ingestion: skip this file, keep processing the rest.
            logging.error(f"Error processing file {file.name}: {str(e)}")

    if not all_documents:
        return "No valid data could be extracted from the uploaded files."

    try:
        # Return value of add_documents_to_faiss was previously bound to an
        # unused local; the side effect (saving the index) is what matters.
        add_documents_to_faiss(all_documents, embeddings)
        return f"Vector store updated successfully. Added {len(all_documents)} chunks from {len(files)} files."
    except Exception as e:
        logging.error(f"Error updating FAISS database: {str(e)}")
        return f"Error updating vector store: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
def delete_documents(selected_docs):
|
169 |
global uploaded_documents
|
|
|
493 |
return
|
494 |
|
495 |
try:
|
496 |
+
retriever = database.as_retriever(search_kwargs={"k": 20}) # Increase k to retrieve more documents initially
|
497 |
logging.info(f"Retrieving relevant documents for query: {query}")
|
498 |
+
all_relevant_docs = retriever.get_relevant_documents(query)
|
499 |
+
logging.info(f"Number of relevant documents retrieved: {len(all_relevant_docs)}")
|
500 |
|
501 |
+
# Log all retrieved documents before filtering
|
502 |
+
for i, doc in enumerate(all_relevant_docs):
|
503 |
+
logging.info(f"Retrieved document {i+1} source: {doc.metadata['source']}")
|
504 |
+
|
505 |
+
# Filter relevant_docs based on selected documents
|
506 |
+
filtered_docs = [doc for doc in all_relevant_docs if doc.metadata["source"] in selected_docs]
|
507 |
logging.info(f"Number of filtered documents: {len(filtered_docs)}")
|
508 |
|
509 |
if not filtered_docs:
|
|
|
512 |
return
|
513 |
|
514 |
for i, doc in enumerate(filtered_docs):
|
515 |
+
logging.info(f"Document {i+1} source: {doc.metadata['source']}")
|
516 |
+
logging.info(f"Document {i+1} content preview: {doc.page_content[:100]}...")
|
517 |
|
518 |
+
context_str = "\n".join([doc.page_content for doc in filtered_docs])
|
519 |
logging.info(f"Total context length: {len(context_str)}")
|
520 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
521 |
if model == "@cf/meta/llama-3.1-8b-instruct":
|
522 |
logging.info("Using Cloudflare API")
|
523 |
+
for response in get_response_from_cloudflare(prompt="", context=context_str, query=query, num_calls=num_calls, temperature=temperature, search_type="pdf"):
|
524 |
yield response
|
525 |
else:
|
526 |
logging.info("Using Hugging Face API")
|
527 |
+
prompt = f"""Using the following context from the PDF documents:
|
528 |
+
{context_str}
|
529 |
+
Write a detailed and complete response that answers the following user question: '{query}'"""
|
530 |
+
|
531 |
client = InferenceClient(model, token=huggingface_token)
|
532 |
|
533 |
response = ""
|
534 |
for i in range(num_calls):
|
535 |
logging.info(f"API call {i+1}/{num_calls}")
|
536 |
+
for message in client.chat_completion(
|
537 |
+
messages=[{"role": "user", "content": prompt}],
|
538 |
+
max_tokens=10000,
|
539 |
+
temperature=temperature,
|
540 |
+
stream=True,
|
541 |
+
):
|
542 |
+
if message.choices and message.choices[0].delta and message.choices[0].delta.content:
|
543 |
+
chunk = message.choices[0].delta.content
|
544 |
+
response += chunk
|
545 |
+
yield response # Yield partial response
|
|
|
|
|
|
|
|
|
546 |
|
547 |
logging.info("Finished generating response")
|
548 |
|