saherPervaiz commited on
Commit
d2ef00f
·
verified ·
1 Parent(s): a69a9de

Update embedder.py

Browse files
Files changed (1) hide show
  1. embedder.py +14 -3
embedder.py CHANGED
@@ -1,8 +1,19 @@
1
  # embedder.py
2
 
3
- from sentence_transformers import SentenceTransformer
 
4
 
5
- model = SentenceTransformer("all-MiniLM-L6-v2")
 
 
 
 
6
 
7
  def get_embeddings(texts):
8
- return model.encode(texts)
 
 
 
 
 
 
 
1
  # embedder.py
2
 
3
+ from transformers import AutoTokenizer, AutoModel
4
+ import torch
5
 
6
+ # Use a model with PyTorch weights available
7
+ MODEL_NAME = "thenlper/gte-small"
8
+
9
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
10
+ model = AutoModel.from_pretrained(MODEL_NAME)
11
 
12
  def get_embeddings(texts):
13
+ inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
14
+ with torch.no_grad():
15
+ model_output = model(**inputs)
16
+
17
+ # Mean Pooling
18
+ embeddings = model_output.last_hidden_state.mean(dim=1)
19
+ return embeddings.numpy()