cv / embedder.py
saherPervaiz's picture
Update embedder.py
d2ef00f verified
raw
history blame
544 Bytes
# embedder.py
from transformers import AutoTokenizer, AutoModel
import torch
# Use a model with PyTorch weights available
# Checkpoint identifier handed to both from_pretrained() calls below, so the
# tokenizer and the model always come from the same checkpoint.
MODEL_NAME = "thenlper/gte-small"
# Both objects are created once, when this module is imported, and are then
# shared by every call to get_embeddings().
# NOTE(review): from_pretrained() presumably fetches the checkpoint on first
# use, so importing this module may need network access or a warm local cache
# - confirm for the deployment environment.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
def get_embeddings(texts):
    """Embed text(s) as fixed-size vectors using masked mean pooling.

    The texts are tokenized as one padded batch, run through the module-level
    model without gradient tracking, and each text's last-layer token vectors
    are averaged into a single vector.

    Args:
        texts: A string or a list of strings; passed unchanged to the
            module-level ``tokenizer``.

    Returns:
        A NumPy array with one row per input text.
    """
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        model_output = model(**inputs)
        token_vectors = model_output.last_hidden_state

        # padding=True pads shorter texts up to the longest one in the batch.
        # A plain .mean(dim=1) would average those padding positions in, so a
        # text's embedding would depend on what else is in the batch. Weight
        # each position by the attention mask so only real tokens contribute.
        mask = inputs["attention_mask"].unsqueeze(-1).to(token_vectors.dtype)
        summed = (token_vectors * mask).sum(dim=1)
        # clamp guards against division by zero should a row have no
        # attended tokens.
        token_counts = mask.sum(dim=1).clamp(min=1)
        embeddings = summed / token_counts
    return embeddings.numpy()