Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -41,7 +41,7 @@ from io import StringIO
|
|
| 41 |
|
| 42 |
from transformers import BertTokenizer, BertModel
|
| 43 |
import torch
|
| 44 |
-
|
| 45 |
|
| 46 |
|
| 47 |
load_dotenv()
|
|
@@ -361,28 +361,75 @@ class BERTEmbeddings(Embeddings):
|
|
| 361 |
|
| 362 |
# Example usage of BERTEmbedding with LangChain
|
| 363 |
|
| 364 |
-
embedding_model = BERTEmbeddings(model_name="bert-base-uncased")
|
| 365 |
|
| 366 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
docs = [
|
| 368 |
-
Document(page_content="Mercedes Sosa released many albums between 2000 and 2009."),
|
| 369 |
-
Document(page_content="She was a prominent Argentine folk singer."),
|
| 370 |
-
Document(page_content="Her album 'Al Despertar' was released in 1998."),
|
| 371 |
-
Document(page_content="She continued releasing music well into the 2000s.")
|
| 372 |
]
|
| 373 |
-
# Get the embeddings for the documents
|
| 374 |
-
vector_store = FAISS.from_documents(docs, embedding_model)
|
| 375 |
|
| 376 |
-
# Now, you can use the embeddings with FAISS or other retrieval systems
|
| 377 |
-
# For example, with FAISS:
|
| 378 |
|
| 379 |
-
#
|
|
|
|
|
|
|
| 380 |
vector_store = FAISS.from_documents(docs, embedding_model)
|
| 381 |
vector_store.save_local("faiss_index")
|
| 382 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
|
| 384 |
# -----------------------------
|
| 385 |
-
#
|
| 386 |
# -----------------------------
|
| 387 |
retriever = vector_store.as_retriever()
|
| 388 |
|
|
|
|
| 41 |
|
| 42 |
from transformers import BertTokenizer, BertModel
|
| 43 |
import torch
|
| 44 |
+
import torch.nn.functional as F
|
| 45 |
|
| 46 |
|
| 47 |
load_dotenv()
|
|
|
|
| 361 |
|
| 362 |
# Example usage of BERTEmbedding with LangChain
|
| 363 |
|
|
|
|
| 364 |
|
| 365 |
+
# -----------------------------
|
| 366 |
+
# 1. Define Custom BERT Embedding Model
|
| 367 |
+
# -----------------------------
|
| 368 |
+
class BERTEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter backed by a Hugging Face BERT encoder.

    Every text is tokenized, run through the encoder, mean-pooled over its
    real (non-padding) tokens and L2-normalized, so the dot product of two
    embeddings equals their cosine similarity.
    """

    def __init__(self, model_name='bert-base-uncased'):
        """Load the tokenizer and encoder weights for ``model_name``."""
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.model.eval()  # Inference only: disables dropout.

    def embed_documents(self, texts):
        """Embed a batch of texts.

        Args:
            texts: Iterable of strings to embed.

        Returns:
            A 2-D NumPy array with one L2-normalized row per input text.
            An empty input yields an array with zero rows.
        """
        texts = list(texts)
        if not texts:
            # Nothing to encode; return an empty (0, hidden_size) matrix
            # instead of handing an empty batch to the tokenizer.
            return torch.empty((0, self.model.config.hidden_size)).numpy()

        inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)

        hidden = outputs.last_hidden_state
        # Shorter texts are padded to the longest one in the batch. Weight
        # padding positions by zero so they do not dilute the mean and an
        # embedding does not depend on what else happens to be in the batch.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)
        embeddings = (hidden * mask).sum(dim=1) / token_counts
        embeddings = F.normalize(embeddings, p=2, dim=1)  # Normalize for cosine similarity
        return embeddings.cpu().numpy()

    def embed_query(self, text):
        """Embed a single query string and return its 1-D embedding row."""
        return self.embed_documents([text])[0]
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# -----------------------------
|
| 387 |
+
# 2. Initialize Embedding Model
|
| 388 |
+
# -----------------------------
|
| 389 |
+
# Embedder built with the class's default checkpoint.
embedding_model = BERTEmbeddings()


# -----------------------------
# 3. Prepare Documents
# -----------------------------
# Sample corpus; each document's metadata "id" is its 1-based position.
docs = [
    Document(page_content=text, metadata={"id": doc_id})
    for doc_id, text in enumerate(
        (
            "Mercedes Sosa released many albums between 2000 and 2009.",
            "She was a prominent Argentine folk singer.",
            "Her album 'Al Despertar' was released in 1998.",
            "She continued releasing music well into the 2000s.",
        ),
        start=1,
    )
]


# -----------------------------
# 4. Create FAISS Vector Store
# -----------------------------
# Index the sample corpus, then write the index to the "faiss_index" folder.
vector_store = FAISS.from_documents(docs, embedding_model)
vector_store.save_local("faiss_index")
|
| 408 |
|
| 409 |
+
# -----------------------------
|
| 410 |
+
# 5. Query & Filter Results (optional preview)
|
| 411 |
+
# -----------------------------
|
| 412 |
+
# Preview: run one sample query against the index and keep only the hits
# whose score is below `threshold`.
# NOTE(review): with the default FAISS index the score is presumably an L2
# distance (lower = closer), which is why "< threshold" accepts a hit —
# confirm if the index's distance strategy is ever changed.
query = "How many albums did Mercedes Sosa release between 2000 and 2009?"
results = vector_store.similarity_search_with_score(query, k=5)
threshold = 0.75

print("\n📊 Retrieved Documents with Similarity Scores:")
# Build `filtered` once, in the same pass that prints each hit's verdict.
filtered = []
for doc, score in results:
    print(f"🔢 Score: {score:.4f}")
    print(f"📄 Content: {doc.page_content}")
    if score < threshold:
        filtered.append(doc)
        print("✅ Accepted")
    else:
        print("❌ Rejected")
    print("-" * 80)
|
| 429 |
+
|
| 430 |
|
| 431 |
# -----------------------------
|
| 432 |
+
# 6. Create LangChain Retriever Tool
|
| 433 |
# -----------------------------
|
| 434 |
retriever = vector_store.as_retriever()
|
| 435 |
|