saherPervaiz committed on
Commit bc8d8e3 · verified · 1 Parent(s): 8fa1c54

Update embedder.py

Files changed (1)
  1. embedder.py +39 -9
embedder.py CHANGED
@@ -1,7 +1,6 @@
- # embedder.py
-
  from transformers import AutoTokenizer, AutoModel
  import torch
+ import numpy as np

  # Use a model with PyTorch weights available
  MODEL_NAME = "thenlper/gte-small"
@@ -9,11 +8,42 @@ MODEL_NAME = "thenlper/gte-small"
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
  model = AutoModel.from_pretrained(MODEL_NAME)

- def get_embeddings(texts):
-     inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
-     with torch.no_grad():
-         model_output = model(**inputs)
-
-     # Mean Pooling
-     embeddings = model_output.last_hidden_state.mean(dim=1)
-     return embeddings.numpy()
+ def get_embeddings(texts, max_length=512):
+     """
+     Generate embeddings for long text by chunking and averaging.
+
+     Args:
+         texts (str or list): One or multiple texts to embed.
+         max_length (int): Maximum tokens per chunk (default is 512).
+
+     Returns:
+         np.ndarray: Averaged embeddings.
+     """
+     if isinstance(texts, str):
+         texts = [texts]
+
+     final_embeddings = []
+
+     for text in texts:
+         # Tokenize and split into chunks
+         tokens = tokenizer.tokenize(text)
+         chunks = [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]
+
+         chunk_embeddings = []
+
+         for chunk in chunks:
+             input_ids = tokenizer.convert_tokens_to_ids(chunk)
+             input_ids = torch.tensor([input_ids])
+             with torch.no_grad():
+                 output = model(input_ids=input_ids)
+             embedding = output.last_hidden_state.mean(dim=1)  # Mean pooling
+             chunk_embeddings.append(embedding)
+
+         # Average embeddings of all chunks
+         if chunk_embeddings:
+             avg_embedding = torch.stack(chunk_embeddings).mean(dim=0)
+             final_embeddings.append(avg_embedding.squeeze(0).numpy())
+         else:
+             final_embeddings.append(np.zeros(model.config.hidden_size))
+
+     return np.array(final_embeddings)
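
For context, a minimal usage sketch of the updated get_embeddings. The sample inputs, the embedder import path, and the 384-dimensional output (gte-small's hidden size) are illustrative assumptions, not part of the commit:

# Hypothetical usage — assumes this module is importable as embedder.py
from embedder import get_embeddings

docs = [
    "A short sentence.",
    "A long document. " * 500,  # well past 512 tokens, exercising the chunking path
]
vectors = get_embeddings(docs)
print(vectors.shape)  # expected: (2, 384) if model.config.hidden_size == 384

One design note: tokenizer.tokenize() followed by convert_tokens_to_ids() does not insert the model's special tokens ([CLS]/[SEP] for BERT-style tokenizers such as gte-small's), whereas the replaced tokenizer(...) call added them by default, so each chunk embedding here is mean-pooled over content tokens only.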