xzx22 committed
Commit cb7e414 · verified · 1 Parent(s): 84d8057

Upload 3 files

Files changed (3):
  1. README.md +120 -3
  2. conan-adaptors.pth +3 -0
  3. modeling_adaptor.py +89 -0
README.md CHANGED
@@ -1,3 +1,120 @@
- ---
- license: apache-2.0
- ---
+ ---
+ language:
+ - zh
+ base_model: OpenSearch-AI/Ops-MoA-Conan-embedding-v1
+ model-index:
+ - name: Ops-MoA-Conan-embedding-v1
+   results:
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/CmedqaRetrieval
+       name: MTEB CmedqaRetrieval
+       config: default
+       split: dev
+       revision: cd540c506dae1cf9e9a59c3e06f42030d54e7301
+     metrics:
+     - type: ndcg_at_10
+       value: 48.21
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/CovidRetrieval
+       name: MTEB CovidRetrieval
+       config: default
+       split: dev
+       revision: 1271c7809071a13532e05f25fb53511ffce77117
+     metrics:
+     - type: ndcg_at_10
+       value: 92.66
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/DuRetrieval
+       name: MTEB DuRetrieval
+       config: default
+       split: dev
+       revision: a1a333e290fe30b10f3f56498e3a0d911a693ced
+     metrics:
+     - type: ndcg_at_10
+       value: 89.23
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/EcomRetrieval
+       name: MTEB EcomRetrieval
+       config: default
+       split: dev
+       revision: 687de13dc7294d6fd9be10c6945f9e8fec8166b9
+     metrics:
+     - type: ndcg_at_10
+       value: 70.93
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/MMarcoRetrieval
+       name: MTEB MMarcoRetrieval
+       config: default
+       split: dev
+       revision: 539bbde593d947e2a124ba72651aafc09eb33fc2
+     metrics:
+     - type: ndcg_at_10
+       value: 82.35
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/MedicalRetrieval
+       name: MTEB MedicalRetrieval
+       config: default
+       split: dev
+       revision: 2039188fb5800a9803ba5048df7b76e6fb151fc6
+     metrics:
+     - type: ndcg_at_10
+       value: 68.27
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/T2Retrieval
+       name: MTEB T2Retrieval
+       config: default
+       split: dev
+       revision: 8731a845f1bf500a4f111cf1070785c793d10e64
+     metrics:
+     - type: ndcg_at_10
+       value: 83.51
+   - task:
+       type: Retrieval
+     dataset:
+       type: C-MTEB/VideoRetrieval
+       name: MTEB VideoRetrieval
+       config: default
+       split: dev
+       revision: 58c2597a5943a2ba48f4668c3b90d796283c5639
+     metrics:
+     - type: ndcg_at_10
+       value: 80.64
+ pipeline_tag: feature-extraction
+ tags:
+ - mteb
+ - sentence-transformers
+ library_name: transformers
+ ---
+ ```python
+ import torch
+ import torch.nn as nn
+ from sentence_transformers import SentenceTransformer
+ from modeling_adaptor import MixtureOfAdaptors
+ 
+ class CustomSentenceTransformer(nn.Module):
+     def __init__(self, output_dim=1536):
+         super(CustomSentenceTransformer, self).__init__()
+         # Base Conan sentence encoder
+         self.model = SentenceTransformer('TencentBAC/Conan-embedding-v1', trust_remote_code=True)
+         # Mixture-of-adaptors head (5 adaptors, 1792-dim) with the weights from this repo
+         adaptor = MixtureOfAdaptors(5, 1792)
+         adaptor.load_state_dict(torch.load("conan-adaptors.pth"))
+         self.model.add_module('adaptor', adaptor)
+         self.output_dim = output_dim
+ 
+     def encode(self, sentences, **kwargs):
+         embeddings = self.model.encode(sentences, **kwargs)
+         # Keep only the first output_dim dimensions of each embedding
+         return embeddings[:, :self.output_dim]
+ 
+ model = CustomSentenceTransformer(output_dim=1536)
+ model.encode(['text'])
+ ```
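
The snippet above returns adaptor-transformed Conan embeddings truncated to `output_dim` dimensions. Below is a minimal retrieval-style smoke test, not part of the repository: it assumes the `CustomSentenceTransformer` class from the README snippet, `conan-adaptors.pth` in the working directory, and an ordinary cosine-similarity computation in NumPy.

```python
import numpy as np

# Hypothetical check using the model built in the README snippet above.
query = ["头痛应该怎么办"]
docs = ["头痛可以先休息并补充水分", "今天天气很好"]

q_emb = model.encode(query)   # shape (1, 1536), numpy array
d_emb = model.encode(docs)    # shape (2, 1536)

# Cosine similarity between the query and each document
q_norm = q_emb / np.linalg.norm(q_emb, axis=1, keepdims=True)
d_norm = d_emb / np.linalg.norm(d_emb, axis=1, keepdims=True)
print(q_norm @ d_norm.T)      # the topically related document is expected to score higher
```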
conan-adaptors.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b47a90b0cfe78adbc7a13bb4397f993722fac7c12c6f3a2bc0fdde9ee788a2ec
+ size 64301406
modeling_adaptor.py ADDED
@@ -0,0 +1,89 @@
+ import torch
+ import torch.nn as nn
+ import math
+ import torch.nn.functional as F
+ import joblib
+ 
+ class MoAGate(nn.Module):
+     def __init__(self, num_adaptors, hidden_dim):
+         super().__init__()
+         # Frozen routing centroids, one per adaptor
+         self.routing_vectors = nn.Parameter(
+             torch.empty(num_adaptors, hidden_dim, dtype=torch.float32),
+             requires_grad=False
+         )
+ 
+     def forward(self, hidden_states):
+         hidden_states = hidden_states.unsqueeze(1)
+         batch_size, seq_len, hidden_dim = hidden_states.shape
+ 
+         hidden_states = hidden_states.view(-1, hidden_dim)
+         # Nearest-centroid assignment by Euclidean distance
+         distances = torch.cdist(hidden_states, self.routing_vectors)
+ 
+         _, cluster_indices = torch.min(distances, dim=1)
+         cluster_indices = cluster_indices.view(-1, 1)
+ 
+         # The cluster assignment is then overridden: every embedding is routed
+         # to adaptor 0 with weight 1
+         topk_indices = cluster_indices
+         topk_indices = torch.zeros_like(topk_indices, device=hidden_states.device)
+         topk_weights = torch.ones_like(topk_indices, device=hidden_states.device)
+ 
+         return topk_indices, topk_weights
+ 
+ class LinearLayer(nn.Module):
+     def __init__(self, input_dim, output_dim):
+         super().__init__()
+         self.linear = nn.Linear(input_dim, output_dim)
+ 
+     def forward(self, x):
+         return self.linear(x)
+ 
+ class MixtureOfAdaptors(nn.Module):
+     def __init__(self, num_adaptors, hidden_dim):
+         super().__init__()
+         self.adaptors = nn.ModuleList([
+             LinearLayer(input_dim=hidden_dim, output_dim=hidden_dim)
+             for _ in range(num_adaptors)
+         ])
+         self.gate = MoAGate(num_adaptors, hidden_dim)
+ 
+     def forward(self, inputs):
+         # Accept either a sentence-transformers feature dict or a raw embedding tensor
+         if isinstance(inputs, dict):
+             hidden_states = inputs['sentence_embedding']
+         else:
+             hidden_states = inputs
+ 
+         residual = hidden_states
+         original_shape = hidden_states.shape
+         topk_indices, topk_weights = self.gate(hidden_states)
+ 
+         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+         flat_topk_indices = topk_indices.view(-1)
+         output = self.moa_inference(hidden_states, flat_topk_indices, topk_weights.view(-1, 1)).view(*original_shape)
+ 
+         if isinstance(inputs, dict):
+             inputs['sentence_embedding'] = output
+             return inputs
+         return output
+ 
+     @torch.no_grad()
+     def moa_inference(self, x, flat_adaptor_indices, flat_adaptor_weights):
+         # Group embeddings by assigned adaptor, apply each adaptor to its group,
+         # then scatter the weighted outputs back to the original row order
+         adaptor_cache = torch.zeros_like(x)
+         sorted_indices = flat_adaptor_indices.argsort()
+         tokens_per_adaptor = flat_adaptor_indices.bincount().cpu().numpy().cumsum(0)
+         token_indices = sorted_indices
+         for i, end_idx in enumerate(tokens_per_adaptor):
+             start_idx = 0 if i == 0 else tokens_per_adaptor[i-1]
+             if start_idx == end_idx:
+                 continue
+             adaptor = self.adaptors[i]
+             adaptor_token_indices = token_indices[start_idx:end_idx]
+             adaptor_tokens = x[adaptor_token_indices]
+             adaptor_output = adaptor(adaptor_tokens)
+             adaptor_output.mul_(flat_adaptor_weights[sorted_indices[start_idx:end_idx]])
+             adaptor_cache.scatter_reduce_(
+                 0,
+                 adaptor_token_indices.view(-1, 1).repeat(1, x.shape[-1]),
+                 adaptor_output,
+                 reduce='sum'
+             )
+         return adaptor_cache
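
For reference, a standalone sketch of how `MixtureOfAdaptors` can be exercised outside sentence-transformers. It uses randomly initialized weights rather than the released `conan-adaptors.pth` and only illustrates the tensor/dict interface of `forward`:

```python
import torch
from modeling_adaptor import MixtureOfAdaptors

# Same configuration as in the README snippet: 5 adaptors over 1792-dim embeddings
moa = MixtureOfAdaptors(num_adaptors=5, hidden_dim=1792)

# Raw tensor input: a batch of 4 sentence embeddings
emb = torch.randn(4, 1792)
out = moa(emb)
print(out.shape)  # torch.Size([4, 1792])

# Feature-dict input, as passed between sentence-transformers modules
features = {'sentence_embedding': emb}
out_features = moa(features)
print(out_features['sentence_embedding'].shape)  # torch.Size([4, 1792])
```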