xzx22 committed on
Commit 8e79170 · verified · 1 Parent(s): 131e42d

Upload 3 files

Files changed (3)
  1. README.md +121 -3
  2. modeling_adaptor.py +91 -0
  3. yuan-adaptors.pth +3 -0
README.md CHANGED
@@ -1,3 +1,121 @@
- ---
- license: apache-2.0
- ---
+ ---
+ language:
+ - zh
+ base_model: OpenSearch-AI/Ops-MoA-Yuan-embedding-1.0
+ model-index:
+ - name: Ops-MoA-Yuan-embedding-1.0
+   results:
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/CmedqaRetrieval
+       name: MTEB CmedqaRetrieval
+       config: default
+       split: dev
+       revision: cd540c506dae1cf9e9a59c3e06f42030d54e7301
+     metrics:
+     - type: ndcg_at_10
+       value: 51.46
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/CovidRetrieval
+       name: MTEB CovidRetrieval
+       config: default
+       split: dev
+       revision: 1271c7809071a13532e05f25fb53511ffce77117
+     metrics:
+     - type: ndcg_at_10
+       value: 93.2
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/DuRetrieval
+       name: MTEB DuRetrieval
+       config: default
+       split: dev
+       revision: a1a333e290fe30b10f3f56498e3a0d911a693ced
+     metrics:
+     - type: ndcg_at_10
+       value: 89.84
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/EcomRetrieval
+       name: MTEB EcomRetrieval
+       config: default
+       split: dev
+       revision: 687de13dc7294d6fd9be10c6945f9e8fec8166b9
+     metrics:
+     - type: ndcg_at_10
+       value: 71.08
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/MMarcoRetrieval
+       name: MTEB MMarcoRetrieval
+       config: default
+       split: dev
+       revision: 539bbde593d947e2a124ba72651aafc09eb33fc2
+     metrics:
+     - type: ndcg_at_10
+       value: 79.27
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/MedicalRetrieval
+       name: MTEB MedicalRetrieval
+       config: default
+       split: dev
+       revision: 2039188fb5800a9803ba5048df7b76e6fb151fc6
+     metrics:
+     - type: ndcg_at_10
+       value: 74.84
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/T2Retrieval
+       name: MTEB T2Retrieval
+       config: default
+       split: dev
+       revision: 8731a845f1bf500a4f111cf1070785c793d10e64
+     metrics:
+     - type: ndcg_at_10
+       value: 85.78
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/VideoRetrieval
+       name: MTEB VideoRetrieval
+       config: default
+       split: dev
+       revision: 58c2597a5943a2ba48f4668c3b90d796283c5639
+     metrics:
+     - type: ndcg_at_10
+       value: 79.51
+ pipeline_tag: feature-extraction
+ tags:
+ - mteb
+ - sentence-transformers
+ library_name: transformers
+ ---
+
+ ```python
+ import torch
+ import torch.nn as nn
+ from sentence_transformers import SentenceTransformer
+ from modeling_adaptor import MixtureOfAdaptors
+
+ class CustomSentenceTransformer(nn.Module):
+     def __init__(self, output_dim=1536):
+         super(CustomSentenceTransformer, self).__init__()
+         # Base Yuan embedding model; the adaptor is appended as an extra module
+         # so it is applied to the pooled sentence embedding during encode().
+         self.model = SentenceTransformer('IEITYuan/Yuan-embedding-1.0', trust_remote_code=True)
+         adaptor = MixtureOfAdaptors(5, 1792)
+         adaptor.load_state_dict(torch.load("yuan-adaptors.pth"))
+         self.model.add_module('adaptor', adaptor)
+         self.output_dim = output_dim
+
+     def encode(self, sentences, **kwargs):
+         embeddings = self.model.encode(sentences, **kwargs)
+         # Truncate the 1792-dim adaptor output to the requested dimension.
+         return embeddings[:, :self.output_dim]
+
+ model = CustomSentenceTransformer(output_dim=1536)
+ model.encode(['text'])
+ ```
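As a usage sketch (not part of the committed files), the truncated embeddings can be compared by cosine similarity for retrieval. The snippet below assumes the `CustomSentenceTransformer` class defined above and that `modeling_adaptor.py` and `yuan-adaptors.pth` are in the working directory; the query and passages are placeholder examples.

```python
import numpy as np

# Reuses CustomSentenceTransformer from the snippet above.
model = CustomSentenceTransformer(output_dim=1536)

query = "高血压患者日常饮食需要注意什么?"
passages = [
    "高血压患者应低盐饮食,控制钠的摄入量。",
    "视频检索是根据文本查询寻找相关视频的任务。",
]

# encode() returns numpy arrays truncated to output_dim columns.
q_emb = model.encode([query])[0]
p_embs = model.encode(passages)

# Cosine similarity = dot product of L2-normalized vectors.
q_emb = q_emb / np.linalg.norm(q_emb)
p_embs = p_embs / np.linalg.norm(p_embs, axis=1, keepdims=True)
scores = p_embs @ q_emb

for passage, score in sorted(zip(passages, scores), key=lambda t: -t[1]):
    print(f"{score:.4f}  {passage}")
```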
modeling_adaptor.py ADDED
@@ -0,0 +1,91 @@
+ import torch
+ import torch.nn as nn
+ import math
+ import torch.nn.functional as F
+ import joblib
+
+ class MoAGate(nn.Module):
+     # Gate that assigns each embedding to an adaptor via its routing vectors.
+     def __init__(self, num_adaptors, hidden_dim):
+         super().__init__()
+         self.routing_vectors = nn.Parameter(
+             torch.empty(num_adaptors, hidden_dim, dtype=torch.float32),
+             requires_grad=False
+         )
+
+     def forward(self, hidden_states):
+         # Move the routing vectors to the input device on first use.
+         # (Assigning through .data keeps the registered Parameter intact.)
+         if self.routing_vectors.device != hidden_states.device:
+             self.routing_vectors.data = self.routing_vectors.data.to(hidden_states.device)
+         hidden_states = hidden_states.unsqueeze(1)
+         batch_size, seq_len, hidden_dim = hidden_states.shape
+
+         hidden_states = hidden_states.view(-1, hidden_dim)
+         # Nearest routing vector per embedding (Euclidean distance).
+         distances = torch.cdist(hidden_states, self.routing_vectors)
+         _, cluster_indices = torch.min(distances, dim=1)
+         cluster_indices = cluster_indices.view(-1, 1)
+
+         # The cluster assignment is then overridden: every embedding is
+         # routed to adaptor 0 with weight 1.
+         topk_indices = torch.zeros_like(cluster_indices, device=hidden_states.device)
+         topk_weights = torch.ones_like(topk_indices, device=hidden_states.device)
+
+         return topk_indices, topk_weights
+
+ class LinearLayer(nn.Module):
+     def __init__(self, input_dim, output_dim):
+         super().__init__()
+         self.linear = nn.Linear(input_dim, output_dim)
+
+     def forward(self, x):
+         return self.linear(x)
+
+ class MixtureOfAdaptors(nn.Module):
+     def __init__(self, num_adaptors, hidden_dim):
+         super().__init__()
+         self.adaptors = nn.ModuleList([
+             LinearLayer(input_dim=hidden_dim, output_dim=hidden_dim)
+             for _ in range(num_adaptors)
+         ])
+         self.gate = MoAGate(num_adaptors, hidden_dim)
+
+     def forward(self, inputs):
+         # Accepts either a raw tensor or a sentence-transformers feature dict.
+         if isinstance(inputs, dict):
+             hidden_states = inputs['sentence_embedding']
+         else:
+             hidden_states = inputs
+
+         residual = hidden_states  # kept for reference; not added back to the output
+         original_shape = hidden_states.shape
+         topk_indices, topk_weights = self.gate(hidden_states)
+
+         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+         flat_topk_indices = topk_indices.view(-1)
+         output = self.moa_inference(hidden_states, flat_topk_indices, topk_weights.view(-1, 1)).view(*original_shape)
+
+         if isinstance(inputs, dict):
+             inputs['sentence_embedding'] = output
+             return inputs
+         return output
+
+     @torch.no_grad()
+     def moa_inference(self, x, flat_adaptor_indices, flat_adaptor_weights):
+         # Bucketed dispatch: sort embeddings by adaptor index, run each adaptor
+         # on its slice, then scatter the weighted outputs back to their rows.
+         adaptor_cache = torch.zeros_like(x)
+         sorted_indices = flat_adaptor_indices.argsort()
+         tokens_per_adaptor = flat_adaptor_indices.bincount().cpu().numpy().cumsum(0)
+         token_indices = sorted_indices
+         for i, end_idx in enumerate(tokens_per_adaptor):
+             start_idx = 0 if i == 0 else tokens_per_adaptor[i-1]
+             if start_idx == end_idx:
+                 continue
+             adaptor = self.adaptors[i]
+             adaptor_token_indices = token_indices[start_idx:end_idx]
+             adaptor_tokens = x[adaptor_token_indices]
+             adaptor_output = adaptor(adaptor_tokens)
+             adaptor_output.mul_(flat_adaptor_weights[sorted_indices[start_idx:end_idx]])
+             adaptor_cache.scatter_reduce_(
+                 0,
+                 adaptor_token_indices.view(-1, 1).repeat(1, x.shape[-1]),
+                 adaptor_output,
+                 reduce='sum'
+             )
+         return adaptor_cache
+
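For orientation, a minimal smoke-test sketch (not part of the commit) of how `modeling_adaptor.py` behaves on its own: with freshly initialized weights, `MixtureOfAdaptors` routes every embedding to adaptor 0 and returns a tensor of the same shape, whether it is fed a raw tensor or a sentence-transformers feature dict. The 1792 dimension matches the Yuan embedding size used in the README snippet.

```python
import torch
from modeling_adaptor import MixtureOfAdaptors

adaptors = MixtureOfAdaptors(num_adaptors=5, hidden_dim=1792)
adaptors.eval()

# Raw-tensor path: a batch of 4 sentence embeddings.
emb = torch.randn(4, 1792)
out = adaptors(emb)
print(out.shape)  # torch.Size([4, 1792])

# Dict path, as used inside a SentenceTransformer pipeline.
features = {'sentence_embedding': torch.randn(4, 1792)}
out = adaptors(features)
print(out['sentence_embedding'].shape)  # torch.Size([4, 1792])
```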
yuan-adaptors.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0f6968776d9d3f223d89e115050f15ce0cd11a62d5dbaeb65cbfac16b0443901
+ size 64301391
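`yuan-adaptors.pth` is stored via Git LFS (about 64 MB) and holds the `MixtureOfAdaptors` state dict. A minimal loading sketch, assuming the repository id `OpenSearch-AI/Ops-MoA-Yuan-embedding-1.0` from the README metadata (adjust if the actual repo id differs) and a local copy of `modeling_adaptor.py`:

```python
import torch
from huggingface_hub import hf_hub_download
from modeling_adaptor import MixtureOfAdaptors

# Repo id taken from the base_model field above; replace if needed.
repo_id = "OpenSearch-AI/Ops-MoA-Yuan-embedding-1.0"
weights_path = hf_hub_download(repo_id=repo_id, filename="yuan-adaptors.pth")

adaptor = MixtureOfAdaptors(5, 1792)
adaptor.load_state_dict(torch.load(weights_path, map_location="cpu"))
adaptor.eval()
```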