File size: 544 Bytes
cb19513
 
d2ef00f
 
60fdc18
d2ef00f
 
 
 
 
60fdc18
f91e02f
d2ef00f
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# embedder.py

from transformers import AutoTokenizer, AutoModel
import torch

# Use a model with PyTorch weights available
MODEL_NAME = "thenlper/gte-small"

# Tokenizer and model are loaded once at import time so repeated
# get_embeddings() calls reuse them.
# NOTE(review): the first import triggers a download from the Hugging Face
# hub — confirm callers tolerate that latency / require network access.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)

def get_embeddings(texts):
    """Return sentence embeddings for *texts* as a NumPy array.

    Tokenizes the input batch (with padding/truncation), runs the model
    without gradient tracking, and mean-pools the token embeddings into
    one vector per input text.

    Args:
        texts: a string or list of strings to embed.

    Returns:
        numpy.ndarray of shape (batch, hidden_dim).
    """
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        model_output = model(**inputs)

    # Masked mean pooling: only average over real tokens.
    # A plain .mean(dim=1) would include padding positions, so a short text
    # batched with longer ones gets its embedding diluted by pad vectors
    # (the same text would embed differently depending on its batchmates).
    token_embeddings = model_output.last_hidden_state          # (B, T, H)
    mask = inputs["attention_mask"].unsqueeze(-1).float()      # (B, T, 1)
    summed = (token_embeddings * mask).sum(dim=1)              # (B, H)
    counts = mask.sum(dim=1).clamp(min=1e-9)                   # avoid div-by-zero
    embeddings = summed / counts
    return embeddings.numpy()