Spaces:
Running
Running
Update agent.py
Browse files
agent.py
CHANGED
@@ -324,22 +324,21 @@ for task in tasks:
|
|
324 |
|
325 |
class BERTEmbeddings(Embeddings):
|
326 |
def __init__(self, model_name='bert-base-uncased'):
|
327 |
-
# Load the pre-trained BERT model and tokenizer
|
328 |
self.tokenizer = BertTokenizer.from_pretrained(model_name)
|
329 |
self.model = BertModel.from_pretrained(model_name)
|
330 |
-
|
331 |
-
|
332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
333 |
inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
|
334 |
-
|
335 |
-
# Get the BERT embeddings (we use the last hidden state)
|
336 |
with torch.no_grad():
|
337 |
outputs = self.model(**inputs)
|
338 |
-
|
339 |
-
# Use the mean of the last layer hidden states as the embedding
|
340 |
-
embeddings = outputs.last_hidden_state.mean(dim=1) # Shape: (batch_size, hidden_dim)
|
341 |
-
|
342 |
-
# Return the embeddings as a list of numpy arrays
|
343 |
return embeddings.cpu().numpy().tolist()
|
344 |
|
345 |
# Example usage of BERTEmbeddings with LangChain
|
|
|
324 |
|
325 |
class BERTEmbeddings(Embeddings):
    """Embeddings implementation backed by a pre-trained BERT encoder.

    Every text is represented by the mean of the encoder's last-layer hidden
    states taken over that text's real tokens only. Padding positions are
    excluded, so the vector for a text does not depend on which other texts
    happen to share its batch.
    """

    def __init__(self, model_name='bert-base-uncased'):
        """Load the tokenizer and the encoder weights for ``model_name``."""
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.model.eval()  # Set to evaluation mode

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per entry of ``texts``, in order."""
        return self._embed(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding vector for a single query string."""
        return self._embed([text])[0]

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Tokenize ``texts`` as one batch and mean-pool the encoder output.

        Args:
            texts: Strings to embed; they are padded to a common length and
                truncated by the tokenizer.

        Returns:
            A list with one list of floats per input text. An empty input
            yields an empty list.
        """
        if not texts:
            # Nothing to encode; avoid handing the tokenizer an empty batch.
            return []

        inputs = self.tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
        with torch.no_grad():
            outputs = self.model(**inputs)

        hidden = outputs.last_hidden_state
        # attention_mask is 1 for real tokens and 0 for padding. Weighting by
        # it keeps padded positions out of the average; a plain .mean(dim=1)
        # would dilute shorter texts with padding-token states.
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        token_counts = mask.sum(dim=1).clamp(min=1)  # guard against divide-by-zero
        embeddings = (hidden * mask).sum(dim=1) / token_counts  # Masked mean pooling
        return embeddings.cpu().numpy().tolist()
|
343 |
|
344 |
# Example usage of BERTEmbeddings with LangChain
|