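"""Context-parallel (CP) metadata registration for supported Diffusers transformers.

Each block below registers a ``cp_plan`` for one model class, describing which
inputs and outputs of which submodules are sharded across context-parallel ranks.
"""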
from diffusers import (
    CogVideoXTransformer3DModel,
    CogView4Transformer2DModel,
    FluxTransformer2DModel,
    WanTransformer3DModel,
)

from finetrainers._metadata import CPInput, CPOutput, ParamId, TransformerMetadata, TransformerRegistry
from finetrainers.logging import get_logger


logger = get_logger()


def register_transformer_metadata():
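    # Conventions used in the cp_plan dicts below (a best-effort description based
    # on the names imported from finetrainers._metadata):
    #   * Keys are module paths inside the transformer: "" is the model's own
    #     forward, "blocks.*" matches every block, "blocks.0" only the first one.
    #   * ParamId(name, index) identifies a forward argument by keyword name and/or
    #     positional index; CPInput(dim, ndim) marks it for splitting along `dim`
    #     of an expected rank-`ndim` tensor across context-parallel ranks.
    #   * split_output=True shards the module's output instead of its argument;
    #     CPOutput entries (e.g. on proj_out) describe how the final output is
    #     handled, presumably gathered back along the sequence dimension.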
    # CogVideoX
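    # image_rotary_emb (argument index 5) is a (cos, sin) pair, hence the list of
    # two CPInput entries; the image and text streams are split once as they enter
    # the first transformer block.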
    TransformerRegistry.register(
        model_class=CogVideoXTransformer3DModel,
        metadata=TransformerMetadata(
            cp_plan={
                "": {
                    ParamId("image_rotary_emb", 5): [CPInput(0, 2), CPInput(0, 2)],
                },
                "transformer_blocks.0": {
                    ParamId("hidden_states", 0): CPInput(1, 3),
                    ParamId("encoder_hidden_states", 1): CPInput(1, 3),
                },
                "proj_out": [CPOutput(1, 3)],
            }
        ),
    )

    # CogView4
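    # split_output=True: the outputs of patch_embed and rope are sharded across
    # CP ranks rather than their inputs.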
    TransformerRegistry.register(
        model_class=CogView4Transformer2DModel,
        metadata=TransformerMetadata(
            cp_plan={
                "patch_embed": {
                    ParamId(index=0): CPInput(1, 3, split_output=True),
                    ParamId(index=1): CPInput(1, 3, split_output=True),
                },
                "rope": {
                    ParamId(index=0): CPInput(0, 2, split_output=True),
                    ParamId(index=1): CPInput(0, 2, split_output=True),
                },
                "proj_out": [CPOutput(1, 3)],
            }
        ),
    )

    # Flux
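    # hidden_states/encoder_hidden_states are split along the token dimension at
    # the model root; img_ids/txt_ids (positional args 4 and 5) are rank-2 and are
    # split along dim 0 so they stay aligned with the sharded tokens.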
    TransformerRegistry.register(
        model_class=FluxTransformer2DModel,
        metadata=TransformerMetadata(
            cp_plan={
                "": {
                    ParamId("hidden_states", 0): CPInput(1, 3),
                    ParamId("encoder_hidden_states", 1): CPInput(1, 3),
                    ParamId("img_ids", 4): CPInput(0, 2),
                    ParamId("txt_ids", 5): CPInput(0, 2),
                },
                "proj_out": [CPOutput(1, 3)],
            }
        ),
    )

    # Wan2.1
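    # encoder_hidden_states is passed to every block, so it is split under the
    # "blocks.*" wildcard; hidden_states flows block-to-block and only needs to be
    # split once at "blocks.0". The rope output is rank-4 here, hence CPInput(2, 4)
    # with split_output=True.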
    TransformerRegistry.register(
        model_class=WanTransformer3DModel,
        metadata=TransformerMetadata(
            cp_plan={
                "rope": {
                    ParamId(index=0): CPInput(2, 4, split_output=True),
                },
                "blocks.*": {
                    ParamId("encoder_hidden_states", 1): CPInput(1, 3),
                },
                "blocks.0": {
                    ParamId("hidden_states", 0): CPInput(1, 3),
                },
                "proj_out": [CPOutput(1, 3)],
            }
        ),
    )

    logger.debug("Metadata for transformer registered")