Spaces:
Running
Running
Update agent.py
Browse files
agent.py
CHANGED
@@ -41,7 +41,7 @@ from io import StringIO
|
|
41 |
|
42 |
from transformers import BertTokenizer, BertModel
|
43 |
import torch
|
44 |
-
|
45 |
|
46 |
|
47 |
load_dotenv()
|
@@ -361,28 +361,75 @@ class BERTEmbeddings(Embeddings):
|
|
361 |
|
362 |
# Example usage of BERTEmbedding with LangChain
|
363 |
|
364 |
-
embedding_model = BERTEmbeddings(model_name="bert-base-uncased")
|
365 |
|
366 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
docs = [
|
368 |
-
Document(page_content="Mercedes Sosa released many albums between 2000 and 2009."),
|
369 |
-
Document(page_content="She was a prominent Argentine folk singer."),
|
370 |
-
Document(page_content="Her album 'Al Despertar' was released in 1998."),
|
371 |
-
Document(page_content="She continued releasing music well into the 2000s.")
|
372 |
]
|
373 |
-
# Get the embeddings for the documents
|
374 |
-
vector_store = FAISS.from_documents(docs, embedding_model)
|
375 |
|
376 |
-
# Now, you can use the embeddings with FAISS or other retrieval systems
|
377 |
-
# For example, with FAISS:
|
378 |
|
379 |
-
#
|
|
|
|
|
380 |
vector_store = FAISS.from_documents(docs, embedding_model)
|
381 |
vector_store.save_local("faiss_index")
|
382 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
383 |
|
384 |
# -----------------------------
|
385 |
-
#
|
386 |
# -----------------------------
|
387 |
retriever = vector_store.as_retriever()
|
388 |
|
|
|
41 |
|
42 |
from transformers import BertTokenizer, BertModel
|
43 |
import torch
|
44 |
+
import torch.nn.functional as F
|
45 |
|
46 |
|
47 |
load_dotenv()
|
|
|
361 |
|
362 |
# Example usage of BERTEmbedding with LangChain
|
363 |
|
|
|
364 |
|
365 |
+
# -----------------------------
|
366 |
+
# 1. Define Custom BERT Embedding Model
|
367 |
+
# -----------------------------
|
368 |
+
class BERTEmbeddings(Embeddings):
|
369 |
+
def __init__(self, model_name='bert-base-uncased'):
|
370 |
+
self.tokenizer = BertTokenizer.from_pretrained(model_name)
|
371 |
+
self.model = BertModel.from_pretrained(model_name)
|
372 |
+
self.model.eval() # Set model to eval mode
|
373 |
+
|
374 |
+
def embed_documents(self, texts):
|
375 |
+
inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
|
376 |
+
with torch.no_grad():
|
377 |
+
outputs = self.model(**inputs)
|
378 |
+
embeddings = outputs.last_hidden_state.mean(dim=1)
|
379 |
+
embeddings = F.normalize(embeddings, p=2, dim=1) # Normalize for cosine similarity
|
380 |
+
return embeddings.cpu().numpy()
|
381 |
+
|
382 |
+
def embed_query(self, text):
|
383 |
+
return self.embed_documents([text])[0]
|
384 |
+
|
385 |
+
|
386 |
+
# -----------------------------
|
387 |
+
# 2. Initialize Embedding Model
|
388 |
+
# -----------------------------
|
389 |
+
embedding_model = BERTEmbeddings()
|
390 |
+
|
391 |
+
|
392 |
+
# -----------------------------
|
393 |
+
# 3. Prepare Documents
|
394 |
+
# -----------------------------
|
395 |
docs = [
|
396 |
+
Document(page_content="Mercedes Sosa released many albums between 2000 and 2009.", metadata={"id": 1}),
|
397 |
+
Document(page_content="She was a prominent Argentine folk singer.", metadata={"id": 2}),
|
398 |
+
Document(page_content="Her album 'Al Despertar' was released in 1998.", metadata={"id": 3}),
|
399 |
+
Document(page_content="She continued releasing music well into the 2000s.", metadata={"id": 4}),
|
400 |
]
|
|
|
|
|
401 |
|
|
|
|
|
402 |
|
403 |
+
# -----------------------------
|
404 |
+
# 4. Create FAISS Vector Store
|
405 |
+
# -----------------------------
|
406 |
vector_store = FAISS.from_documents(docs, embedding_model)
|
407 |
vector_store.save_local("faiss_index")
|
408 |
|
409 |
+
# -----------------------------
|
410 |
+
# 5. Query & Filter Results (optional preview)
|
411 |
+
# -----------------------------
|
412 |
+
query = "How many albums did Mercedes Sosa release between 2000 and 2009?"
|
413 |
+
results = vector_store.similarity_search_with_score(query, k=5)
|
414 |
+
threshold = 0.75
|
415 |
+
filtered = [doc for doc, score in results if score < threshold]
|
416 |
+
|
417 |
+
|
418 |
+
print("\n📊 Retrieved Documents with Similarity Scores:")
|
419 |
+
filtered = []
|
420 |
+
for doc, score in results:
|
421 |
+
print(f"🔢 Score: {score:.4f}")
|
422 |
+
print(f"📄 Content: {doc.page_content}")
|
423 |
+
if score < threshold:
|
424 |
+
filtered.append(doc)
|
425 |
+
print("✅ Accepted")
|
426 |
+
else:
|
427 |
+
print("❌ Rejected")
|
428 |
+
print("-" * 80)
|
429 |
+
|
430 |
|
431 |
# -----------------------------
|
432 |
+
# 6. Create LangChain Retriever Tool
|
433 |
# -----------------------------
|
434 |
retriever = vector_store.as_retriever()
|
435 |
|