wt002 committed on
Commit
cf02c0e
·
verified ·
1 Parent(s): 7df3234

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +60 -13
agent.py CHANGED
@@ -41,7 +41,7 @@ from io import StringIO
41
 
42
  from transformers import BertTokenizer, BertModel
43
  import torch
44
-
45
 
46
 
47
  load_dotenv()
@@ -361,28 +361,75 @@ class BERTEmbeddings(Embeddings):
361
 
362
  # Example usage of BERTEmbedding with LangChain
363
 
364
- embedding_model = BERTEmbeddings(model_name="bert-base-uncased")
365
 
366
- # Sample text (replace with your own text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  docs = [
368
- Document(page_content="Mercedes Sosa released many albums between 2000 and 2009."),
369
- Document(page_content="She was a prominent Argentine folk singer."),
370
- Document(page_content="Her album 'Al Despertar' was released in 1998."),
371
- Document(page_content="She continued releasing music well into the 2000s.")
372
  ]
373
- # Get the embeddings for the documents
374
- vector_store = FAISS.from_documents(docs, embedding_model)
375
 
376
- # Now, you can use the embeddings with FAISS or other retrieval systems
377
- # For example, with FAISS:
378
 
379
- # Assuming 'docs' contains your list of documents and 'embedding_model' is the model you created
 
 
380
  vector_store = FAISS.from_documents(docs, embedding_model)
381
  vector_store.save_local("faiss_index")
382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
 
384
  # -----------------------------
385
- # Step 4: Create Retriever Tool
386
  # -----------------------------
387
  retriever = vector_store.as_retriever()
388
 
 
41
 
42
  from transformers import BertTokenizer, BertModel
43
  import torch
44
+ import torch.nn.functional as F
45
 
46
 
47
  load_dotenv()
 
361
 
362
  # Example usage of BERTEmbedding with LangChain
363
 
 
364
 
365
+ # -----------------------------
366
+ # 1. Define Custom BERT Embedding Model
367
+ # -----------------------------
368
+ class BERTEmbeddings(Embeddings):
369
+ def __init__(self, model_name='bert-base-uncased'):
370
+ self.tokenizer = BertTokenizer.from_pretrained(model_name)
371
+ self.model = BertModel.from_pretrained(model_name)
372
+ self.model.eval() # Set model to eval mode
373
+
374
+ def embed_documents(self, texts):
375
+ inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
376
+ with torch.no_grad():
377
+ outputs = self.model(**inputs)
378
+ embeddings = outputs.last_hidden_state.mean(dim=1)
379
+ embeddings = F.normalize(embeddings, p=2, dim=1) # Normalize for cosine similarity
380
+ return embeddings.cpu().numpy()
381
+
382
+ def embed_query(self, text):
383
+ return self.embed_documents([text])[0]
384
+
385
+
386
+ # -----------------------------
387
+ # 2. Initialize Embedding Model
388
+ # -----------------------------
389
+ embedding_model = BERTEmbeddings()
390
+
391
+
392
+ # -----------------------------
393
+ # 3. Prepare Documents
394
+ # -----------------------------
395
  docs = [
396
+ Document(page_content="Mercedes Sosa released many albums between 2000 and 2009.", metadata={"id": 1}),
397
+ Document(page_content="She was a prominent Argentine folk singer.", metadata={"id": 2}),
398
+ Document(page_content="Her album 'Al Despertar' was released in 1998.", metadata={"id": 3}),
399
+ Document(page_content="She continued releasing music well into the 2000s.", metadata={"id": 4}),
400
  ]
 
 
401
 
 
 
402
 
403
+ # -----------------------------
404
+ # 4. Create FAISS Vector Store
405
+ # -----------------------------
406
  vector_store = FAISS.from_documents(docs, embedding_model)
407
  vector_store.save_local("faiss_index")
408
 
409
+ # -----------------------------
410
+ # 5. Query & Filter Results (optional preview)
411
+ # -----------------------------
412
+ query = "How many albums did Mercedes Sosa release between 2000 and 2009?"
413
+ results = vector_store.similarity_search_with_score(query, k=5)
414
+ threshold = 0.75
415
+ filtered = [doc for doc, score in results if score < threshold]
416
+
417
+
418
+ print("\n📊 Retrieved Documents with Similarity Scores:")
419
+ filtered = []
420
+ for doc, score in results:
421
+ print(f"🔢 Score: {score:.4f}")
422
+ print(f"📄 Content: {doc.page_content}")
423
+ if score < threshold:
424
+ filtered.append(doc)
425
+ print("✅ Accepted")
426
+ else:
427
+ print("❌ Rejected")
428
+ print("-" * 80)
429
+
430
 
431
  # -----------------------------
432
+ # 6. Create LangChain Retriever Tool
433
  # -----------------------------
434
  retriever = vector_store.as_retriever()
435