Upload 3 files

- README.md +121 -3
- modeling_adaptor.py +91 -0
- yuan-adaptors.pth +3 -0

README.md
CHANGED
@@ -1,3 +1,121 @@
---
language:
- zh
base_model: OpenSearch-AI/Ops-MoA-Yuan-embedding-1.0
model-index:
- name: Ops-MoA-Yuan-embedding-1.0
  results:
  - task:
      type: Retrieval
    dataset:
      type: C-MTEB/CmedqaRetrieval
      name: MTEB CmedqaRetrieval
      config: default
      split: dev
      revision: cd540c506dae1cf9e9a59c3e06f42030d54e7301
    metrics:
    - type: ndcg_at_10
      value: 51.46
  - task:
      type: Retrieval
    dataset:
      type: C-MTEB/CovidRetrieval
      name: MTEB CovidRetrieval
      config: default
      split: dev
      revision: 1271c7809071a13532e05f25fb53511ffce77117
    metrics:
    - type: ndcg_at_10
      value: 93.2
  - task:
      type: Retrieval
    dataset:
      type: C-MTEB/DuRetrieval
      name: MTEB DuRetrieval
      config: default
      split: dev
      revision: a1a333e290fe30b10f3f56498e3a0d911a693ced
    metrics:
    - type: ndcg_at_10
      value: 89.84
  - task:
      type: Retrieval
    dataset:
      type: C-MTEB/EcomRetrieval
      name: MTEB EcomRetrieval
      config: default
      split: dev
      revision: 687de13dc7294d6fd9be10c6945f9e8fec8166b9
    metrics:
    - type: ndcg_at_10
      value: 71.08
  - task:
      type: Retrieval
    dataset:
      type: C-MTEB/MMarcoRetrieval
      name: MTEB MMarcoRetrieval
      config: default
      split: dev
      revision: 539bbde593d947e2a124ba72651aafc09eb33fc2
    metrics:
    - type: ndcg_at_10
      value: 79.27
  - task:
      type: Retrieval
    dataset:
      type: C-MTEB/MedicalRetrieval
      name: MTEB MedicalRetrieval
      config: default
      split: dev
      revision: 2039188fb5800a9803ba5048df7b76e6fb151fc6
    metrics:
    - type: ndcg_at_10
      value: 74.84
  - task:
      type: Retrieval
    dataset:
      type: C-MTEB/T2Retrieval
      name: MTEB T2Retrieval
      config: default
      split: dev
      revision: 8731a845f1bf500a4f111cf1070785c793d10e64
    metrics:
    - type: ndcg_at_10
      value: 85.78
  - task:
      type: Retrieval
    dataset:
      type: C-MTEB/VideoRetrieval
      name: MTEB VideoRetrieval
      config: default
      split: dev
      revision: 58c2597a5943a2ba48f4668c3b90d796283c5639
    metrics:
    - type: ndcg_at_10
      value: 79.51
pipeline_tag: feature-extraction
tags:
- mteb
- sentence-transformers
library_name: transformers
---

Usage:

```python
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from modeling_adaptor import MixtureOfAdaptors

class CustomSentenceTransformer(nn.Module):
    def __init__(self, output_dim=1536):
        super().__init__()
        # Base Yuan embedding model (1792-dim sentence embeddings).
        self.model = SentenceTransformer('IEITYuan/Yuan-embedding-1.0', trust_remote_code=True)
        # Load the mixture-of-adaptors head and append it to the encoding pipeline.
        adaptor = MixtureOfAdaptors(5, 1792)
        adaptor.load_state_dict(torch.load("yuan-adaptors.pth"))
        self.model.add_module('adaptor', adaptor)
        self.output_dim = output_dim

    def encode(self, sentences, **kwargs):
        embeddings = self.model.encode(sentences, **kwargs)
        # Keep only the first output_dim dimensions of the adapted embeddings.
        return embeddings[:, :self.output_dim]

model = CustomSentenceTransformer(output_dim=1536)
model.encode(['text'])
```
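The card's snippet returns truncated embeddings but stops short of using them. A minimal retrieval sketch follows; it is hypothetical, and the query and passages are illustrative, not part of the original card:

```python
import numpy as np

# encode() returns numpy arrays of shape (n, 1536) after truncation.
query_emb = model.encode(['如何预防流感?'])
doc_embs = model.encode(['接种疫苗是预防流感的有效方法。', '今天股市大幅上涨。'])

def cosine_sim(a, b):
    # Normalize rows, then take dot products.
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T

print(cosine_sim(query_emb, doc_embs))  # the on-topic passage should score higher
```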
modeling_adaptor.py
ADDED
@@ -0,0 +1,91 @@
```python
import torch
import torch.nn as nn


class MoAGate(nn.Module):
    """Routes each embedding to an adaptor via its nearest routing vector."""

    def __init__(self, num_adaptors, hidden_dim):
        super().__init__()
        # Frozen routing centroids; values are loaded from the checkpoint.
        self.routing_vectors = nn.Parameter(
            torch.empty(num_adaptors, hidden_dim, dtype=torch.float32),
            requires_grad=False
        )

    def forward(self, hidden_states):
        # Move the routing vectors to the input device on first use. Assigning
        # a plain tensor to a Parameter attribute raises a TypeError, so move
        # the underlying data instead.
        if self.routing_vectors.device != hidden_states.device:
            self.routing_vectors.data = self.routing_vectors.data.to(hidden_states.device)
        hidden_states = hidden_states.unsqueeze(1)
        batch_size, seq_len, hidden_dim = hidden_states.shape

        hidden_states = hidden_states.view(-1, hidden_dim)
        # Euclidean distance from each embedding to each routing vector.
        distances = torch.cdist(hidden_states, self.routing_vectors)

        _, cluster_indices = torch.min(distances, dim=1)
        cluster_indices = cluster_indices.view(-1, 1)

        # NOTE: the nearest-centroid assignment above is discarded here, so
        # every embedding is routed to adaptor 0 with weight 1.
        topk_indices = torch.zeros_like(cluster_indices, device=hidden_states.device)
        topk_weights = torch.ones_like(topk_indices, device=hidden_states.device)

        return topk_indices, topk_weights


class LinearLayer(nn.Module):
    """A single linear adaptor."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)


class MixtureOfAdaptors(nn.Module):
    """Applies one of several linear adaptors to each sentence embedding."""

    def __init__(self, num_adaptors, hidden_dim):
        super().__init__()
        self.adaptors = nn.ModuleList([
            LinearLayer(input_dim=hidden_dim, output_dim=hidden_dim)
            for _ in range(num_adaptors)
        ])
        self.gate = MoAGate(num_adaptors, hidden_dim)

    def forward(self, inputs):
        # Inside a SentenceTransformer pipeline the input is a features dict;
        # plain tensors are supported as well.
        if isinstance(inputs, dict):
            hidden_states = inputs['sentence_embedding']
        else:
            hidden_states = inputs

        original_shape = hidden_states.shape
        topk_indices, topk_weights = self.gate(hidden_states)

        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_indices = topk_indices.view(-1)
        output = self.moa_inference(
            hidden_states, flat_topk_indices, topk_weights.view(-1, 1)
        ).view(*original_shape)

        if isinstance(inputs, dict):
            inputs['sentence_embedding'] = output
            return inputs
        return output

    @torch.no_grad()
    def moa_inference(self, x, flat_adaptor_indices, flat_adaptor_weights):
        # Group inputs by assigned adaptor, run each adaptor once on its
        # group, and scatter the weighted outputs back to their original rows.
        adaptor_cache = torch.zeros_like(x)
        sorted_indices = flat_adaptor_indices.argsort()
        # Cumulative end offsets of each adaptor's slice in the sorted order.
        tokens_per_adaptor = flat_adaptor_indices.bincount().cpu().numpy().cumsum(0)
        for i, end_idx in enumerate(tokens_per_adaptor):
            start_idx = 0 if i == 0 else tokens_per_adaptor[i - 1]
            if start_idx == end_idx:
                continue  # no inputs routed to this adaptor
            adaptor = self.adaptors[i]
            adaptor_token_indices = sorted_indices[start_idx:end_idx]
            adaptor_output = adaptor(x[adaptor_token_indices])
            adaptor_output.mul_(flat_adaptor_weights[adaptor_token_indices])
            adaptor_cache.scatter_reduce_(
                0,
                adaptor_token_indices.view(-1, 1).repeat(1, x.shape[-1]),
                adaptor_output,
                reduce='sum'
            )
        return adaptor_cache
```
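For intuition, a standalone smoke test of the adaptor stack. This is a sketch with synthetic weights (the real routing vectors and adaptor parameters come from yuan-adaptors.pth), so the outputs are meaningless; it only demonstrates shapes and routing:

```python
import torch
from modeling_adaptor import MixtureOfAdaptors

moa = MixtureOfAdaptors(num_adaptors=5, hidden_dim=1792)
# routing_vectors ships uninitialized (torch.empty); fill it with random
# stand-ins so cdist() produces well-defined distances.
with torch.no_grad():
    moa.gate.routing_vectors.copy_(torch.randn(5, 1792))

emb = torch.randn(4, 1792)  # a batch of 4 sentence embeddings
out = moa(emb)              # adapted embeddings, same shape as the input
print(out.shape)            # torch.Size([4, 1792])
```

As written, the gate discards its nearest-centroid assignment and routes every embedding through adaptor 0, so only one of the five adaptors is active at inference time.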
yuan-adaptors.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0f6968776d9d3f223d89e115050f15ce0cd11a62d5dbaeb65cbfac16b0443901
size 64301391
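This entry is a Git LFS pointer, not the weights themselves, so fetching the raw blob returns only the three pointer lines above. One way to obtain the actual checkpoint (about 64 MB) is via huggingface_hub; a sketch, assuming the repo id matches the base_model field in the card:

```python
from huggingface_hub import hf_hub_download

# repo_id taken from the card's base_model field; adjust if the file
# lives in a different repository.
path = hf_hub_download(
    repo_id="OpenSearch-AI/Ops-MoA-Yuan-embedding-1.0",
    filename="yuan-adaptors.pth",
)
print(path)  # local cache path of the downloaded checkpoint
```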