File size: 2,931 Bytes
54f351c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# modified from https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py

import os
import torch
from torch import nn, Tensor
from transformers import AutoModel, AutoConfig
from huggingface_hub import snapshot_download
from typing import Dict


class BGEM3InferenceModel(nn.Module):
    """Inference-only BGE-M3 model producing dense, sparse and ColBERT outputs.

    Wraps a Hugging Face ``AutoModel`` backbone together with the two extra
    linear heads (``colbert_linear.pt`` and ``sparse_linear.pt``) stored in the
    same Hub repository. ``forward`` runs entirely under ``torch.no_grad()``;
    this module is not intended for training.
    """

    def __init__(
        self,
        model_name: str = "BAAI/bge-m3",
        colbert_dim: int = -1,
    ) -> None:
        """Download the checkpoint and build the backbone and projection heads.

        Args:
            model_name: Hugging Face Hub repo id passed to ``snapshot_download``.
            colbert_dim: Output width of the ColBERT head. ``-1`` means "use the
                backbone's ``hidden_size``". It must match the weights stored in
                ``colbert_linear.pt``, otherwise ``load_state_dict`` raises.
        """
        super().__init__()

        # Fetch only the weight/config files this class reads; tokenizer files
        # are not part of the pattern list and are not downloaded here.
        model_dir = snapshot_download(
            repo_id=model_name,
            allow_patterns=[
                "model.safetensors",
                "colbert_linear.pt",
                "sparse_linear.pt",
                "config.json",
            ],
        )

        self.config = AutoConfig.from_pretrained(model_dir)
        self.model = AutoModel.from_pretrained(model_dir)
        hidden_size = self.model.config.hidden_size

        self.colbert_linear = torch.nn.Linear(
            in_features=hidden_size,
            out_features=hidden_size if colbert_dim == -1 else colbert_dim,
        )
        self.sparse_linear = torch.nn.Linear(in_features=hidden_size, out_features=1)

        self.colbert_linear.load_state_dict(
            self._load_head_state_dict(model_dir, "colbert_linear.pt")
        )
        self.sparse_linear.load_state_dict(
            self._load_head_state_dict(model_dir, "sparse_linear.pt")
        )

        # nn.Module instances start in training mode; switch the wrapper and
        # its heads to eval so ``self.training`` matches the inference-only use.
        self.eval()

    @staticmethod
    def _load_head_state_dict(model_dir: str, filename: str) -> Dict[str, Tensor]:
        """Load the state dict stored at ``model_dir/filename`` onto the CPU."""
        # These files come from the network; weights_only=True restricts
        # unpickling to tensors and plain containers rather than arbitrary
        # Python objects (a plain state dict needs nothing more).
        return torch.load(
            os.path.join(model_dir, filename),
            map_location="cpu",
            weights_only=True,
        )

    def dense_embedding(self, last_hidden_state: Tensor) -> Tensor:
        """Return the hidden state of the first token of each sequence.

        Args:
            last_hidden_state: Backbone output indexed as (batch, seq, hidden).

        Returns:
            Tensor of shape (batch, hidden). Not normalized here.
        """
        return last_hidden_state[:, 0]

    def sparse_embedding(self, last_hidden_state: Tensor) -> Tensor:
        """Return a non-negative weight for every token position.

        Args:
            last_hidden_state: Backbone output indexed as (batch, seq, hidden).

        Returns:
            ``relu(sparse_linear(last_hidden_state))`` with shape
            (batch, seq, 1). The weights stay per position; they are not
            aggregated into a vocabulary-sized vector here.
        """
        with torch.no_grad():
            return torch.relu(self.sparse_linear(last_hidden_state))

    def colbert_embedding(
        self, last_hidden_state: Tensor, attention_mask: Tensor
    ) -> Tensor:
        """Project every token except the first one and mask out padding.

        Args:
            last_hidden_state: Backbone output indexed as (batch, seq, hidden).
            attention_mask: (batch, seq) mask; each projected vector is
                multiplied by its mask value, so 0 positions become zero vectors.

        Returns:
            Tensor of shape (batch, seq - 1, colbert_linear.out_features).
            Not normalized here.
        """
        with torch.no_grad():
            # Position 0 is the token used for the dense vector, so it is skipped.
            colbert_vecs = self.colbert_linear(last_hidden_state[:, 1:])
            # Broadcast the (batch, seq - 1) mask over the feature dimension.
            return colbert_vecs * attention_mask[:, 1:, None].float()

    def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Dict[str, Tensor]:
        """Run the backbone once and derive all three representations.

        Args:
            input_ids: Token ids passed straight to the backbone.
            attention_mask: Mask passed to the backbone and to the ColBERT head.

        Returns:
            Dict with keys, in this order:
              * ``"dense_vecs"``: L2-normalized first-token vectors, (batch, hidden).
              * ``"sparse_vecs"``: raw per-token weights, (batch, seq, 1).
              * ``"colbert_vecs"``: L2-normalized vectors for positions 1..seq-1;
                masked positions remain all-zero after normalization.
        """
        with torch.no_grad():
            last_hidden_state = self.model(
                input_ids=input_ids, attention_mask=attention_mask, return_dict=True
            ).last_hidden_state

            dense_vecs = self.dense_embedding(last_hidden_state)
            sparse_vecs = self.sparse_embedding(last_hidden_state)
            colbert_vecs = self.colbert_embedding(last_hidden_state, attention_mask)

            return {
                "dense_vecs": torch.nn.functional.normalize(dense_vecs, dim=-1),
                "sparse_vecs": sparse_vecs,
                "colbert_vecs": torch.nn.functional.normalize(colbert_vecs, dim=-1),
            }