wt002 commited on
Commit
abd3036
·
verified ·
1 Parent(s): 212b89d

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +1 -32
agent.py CHANGED
@@ -29,7 +29,7 @@ import json
29
  from langchain_core.documents import Document
30
  from langchain_community.vectorstores import FAISS
31
  from langchain.vectorstores import FAISS
32
- from langchain.embeddings import BERTEmbeddings
33
  from langchain_community.embeddings import HuggingFaceEmbeddings
34
 
35
  from youtube_transcript_api import YouTubeTranscriptApi
@@ -343,37 +343,6 @@ for name in enabled_tool_names:
343
  tools.append(tool_map[name])
344
 
345
 
346
-
347
- # -------------------------------
348
- # Set up BERT Embeddings
349
- # -------------------------------
350
-
351
- # -----------------------------
352
- # Define Custom BERT Embedding Model
353
- # -----------------------------
354
- import torch
355
- import torch.nn.functional as F
356
- from transformers import BertTokenizer, BertModel
357
-
358
- class BERTEmbeddings:
359
- def __init__(self, model_name='bert-base-uncased'):
360
- self.tokenizer = BertTokenizer.from_pretrained(model_name)
361
- self.model = BertModel.from_pretrained(model_name)
362
- self.model.eval() # Set to evaluation mode
363
-
364
- def embed_documents(self, texts):
365
- inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
366
- with torch.no_grad():
367
- outputs = self.model(**inputs)
368
- embeddings = outputs.last_hidden_state.mean(dim=1)
369
- embeddings = F.normalize(embeddings, p=2, dim=1) # Normalize for cosine similarity
370
- return embeddings.cpu().numpy()
371
-
372
- def embed_query(self, text):
373
- return self.embed_documents([text])[0]
374
-
375
-
376
-
377
  # -----------------------------
378
  # Create FAISS Vector Store
379
  # -----------------------------
 
29
  from langchain_core.documents import Document
30
  from langchain_community.vectorstores import FAISS
31
  from langchain.vectorstores import FAISS
32
+ #from langchain.embeddings import BERTEmbeddings
33
  from langchain_community.embeddings import HuggingFaceEmbeddings
34
 
35
  from youtube_transcript_api import YouTubeTranscriptApi
 
343
  tools.append(tool_map[name])
344
 
345
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
  # -----------------------------
347
  # Create FAISS Vector Store
348
  # -----------------------------