Spaces:
Running
Running
File size: 2,692 Bytes
9fd1204 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from diffusers import (
CogVideoXTransformer3DModel,
CogView4Transformer2DModel,
FluxTransformer2DModel,
WanTransformer3DModel,
)
from finetrainers._metadata import CPInput, CPOutput, ParamId, TransformerMetadata, TransformerRegistry
from finetrainers.logging import get_logger
# Module-level logger obtained from finetrainers' logging helper
# (imported above from finetrainers.logging).
logger = get_logger()
# NOTE(review): throughout the plans below, ``ParamId(name, i)`` appears to
# identify a module argument by name and positional index, and
# ``CPInput(a, b)`` / ``CPOutput(a, b)`` look like (split dim, expected ndim).
# ``split_output=True`` looks like it targets the module's outputs rather than
# its inputs. Confirm all of these against ``finetrainers._metadata``.


def _cogvideox_cp_plan():
    """Build the CP plan registered for ``CogVideoXTransformer3DModel``."""
    return {
        # Root module: the argument at positional index 5 gets one CPInput per
        # element (presumably a cos/sin rotary pair — TODO confirm).
        "": {
            ParamId("image_rotary_emb", 5): [CPInput(0, 2), CPInput(0, 2)],
        },
        # Only the first transformer block is listed for these two inputs.
        "transformer_blocks.0": {
            ParamId("hidden_states", 0): CPInput(1, 3),
            ParamId("encoder_hidden_states", 1): CPInput(1, 3),
        },
        "proj_out": [CPOutput(1, 3)],
    }


def _cogview4_cp_plan():
    """Build the CP plan registered for ``CogView4Transformer2DModel``."""
    return {
        # Indices 0 and 1 with split_output=True — presumably the two values
        # the patch embedding returns.
        "patch_embed": {
            ParamId(index=0): CPInput(1, 3, split_output=True),
            ParamId(index=1): CPInput(1, 3, split_output=True),
        },
        "rope": {
            ParamId(index=0): CPInput(0, 2, split_output=True),
            ParamId(index=1): CPInput(0, 2, split_output=True),
        },
        "proj_out": [CPOutput(1, 3)],
    }


def _flux_cp_plan():
    """Build the CP plan registered for ``FluxTransformer2DModel``."""
    return {
        # Everything is declared on the root module's own arguments.
        "": {
            ParamId("hidden_states", 0): CPInput(1, 3),
            ParamId("encoder_hidden_states", 1): CPInput(1, 3),
            ParamId("img_ids", 4): CPInput(0, 2),
            ParamId("txt_ids", 5): CPInput(0, 2),
        },
        "proj_out": [CPOutput(1, 3)],
    }


def _wan_cp_plan():
    """Build the CP plan registered for ``WanTransformer3DModel``."""
    return {
        "rope": {
            ParamId(index=0): CPInput(2, 4, split_output=True),
        },
        # "blocks.*" looks like a wildcard over every block, so
        # encoder_hidden_states is declared per block, while hidden_states is
        # declared only for blocks.0 — confirm how overlapping keys resolve.
        "blocks.*": {
            ParamId("encoder_hidden_states", 1): CPInput(1, 3),
        },
        "blocks.0": {
            ParamId("hidden_states", 0): CPInput(1, 3),
        },
        "proj_out": [CPOutput(1, 3)],
    }


def register_transformer_metadata():
    """Register CP metadata for the supported diffusers transformer classes.

    For each model class, builds its ``cp_plan``, wraps it in a
    ``TransformerMetadata`` and hands it to ``TransformerRegistry.register``.
    The order is CogVideoX, CogView4, Flux, Wan2.1. A debug message is
    logged once all four are registered.
    """
    plan_builders = (
        (CogVideoXTransformer3DModel, _cogvideox_cp_plan),  # CogVideoX
        (CogView4Transformer2DModel, _cogview4_cp_plan),  # CogView4
        (FluxTransformer2DModel, _flux_cp_plan),  # Flux
        (WanTransformer3DModel, _wan_cp_plan),  # Wan2.1
    )
    for model_class, build_plan in plan_builders:
        # The plan is built lazily, right before its own registration call.
        TransformerRegistry.register(
            model_class=model_class,
            metadata=TransformerMetadata(cp_plan=build_plan()),
        )
    logger.debug("Metadata for transformer registered")
|