Spaces:
Running
Running
from diffusers import ( | |
CogVideoXTransformer3DModel, | |
CogView4Transformer2DModel, | |
FluxTransformer2DModel, | |
WanTransformer3DModel, | |
) | |
from finetrainers._metadata import CPInput, CPOutput, ParamId, TransformerMetadata, TransformerRegistry | |
from finetrainers.logging import get_logger | |
# Module-level logger created via finetrainers' `get_logger` helper; this module
# uses it to emit a debug message about transformer metadata registration.
logger = get_logger()
def register_transformer_metadata():
    """Register a ``TransformerMetadata`` entry for each supported transformer class.

    One ``TransformerRegistry.register`` call is made, in this order, for
    ``CogVideoXTransformer3DModel``, ``CogView4Transformer2DModel``,
    ``FluxTransformer2DModel`` and ``WanTransformer3DModel``. Each entry carries
    a ``cp_plan`` dict whose keys are strings (they look like submodule paths,
    with ``""`` presumably denoting the model itself -- confirm against
    ``finetrainers._metadata``) and whose values are either a dict mapping
    ``ParamId`` to ``CPInput`` (or a list of ``CPInput``), or a list of
    ``CPOutput``.

    The meaning of the integer arguments passed to ``CPInput`` / ``CPOutput``
    and of ``split_output`` is defined in ``finetrainers._metadata`` and is not
    visible from this module.

    A debug message is logged once all four registrations have been issued.
    """

    def _register(model_class, cp_plan):
        # Wrap the plan in TransformerMetadata and hand it to the registry.
        TransformerRegistry.register(
            model_class=model_class,
            metadata=TransformerMetadata(cp_plan=cp_plan),
        )

    # CogVideoX
    cogvideox_plan = {
        "": {
            ParamId("image_rotary_emb", 5): [CPInput(0, 2), CPInput(0, 2)],
        },
        "transformer_blocks.0": {
            ParamId("hidden_states", 0): CPInput(1, 3),
            ParamId("encoder_hidden_states", 1): CPInput(1, 3),
        },
        "proj_out": [CPOutput(1, 3)],
    }
    _register(CogVideoXTransformer3DModel, cogvideox_plan)

    # CogView4
    cogview4_plan = {
        "patch_embed": {
            ParamId(index=0): CPInput(1, 3, split_output=True),
            ParamId(index=1): CPInput(1, 3, split_output=True),
        },
        "rope": {
            ParamId(index=0): CPInput(0, 2, split_output=True),
            ParamId(index=1): CPInput(0, 2, split_output=True),
        },
        "proj_out": [CPOutput(1, 3)],
    }
    _register(CogView4Transformer2DModel, cogview4_plan)

    # Flux
    flux_plan = {
        "": {
            ParamId("hidden_states", 0): CPInput(1, 3),
            ParamId("encoder_hidden_states", 1): CPInput(1, 3),
            ParamId("img_ids", 4): CPInput(0, 2),
            ParamId("txt_ids", 5): CPInput(0, 2),
        },
        "proj_out": [CPOutput(1, 3)],
    }
    _register(FluxTransformer2DModel, flux_plan)

    # Wan2.1
    # Key order is kept exactly as authored ("blocks.*" before "blocks.0") in
    # case the consumer of the plan is order-sensitive.
    wan_plan = {
        "rope": {
            ParamId(index=0): CPInput(2, 4, split_output=True),
        },
        "blocks.*": {
            ParamId("encoder_hidden_states", 1): CPInput(1, 3),
        },
        "blocks.0": {
            ParamId("hidden_states", 0): CPInput(1, 3),
        },
        "proj_out": [CPOutput(1, 3)],
    }
    _register(WanTransformer3DModel, wan_plan)

    logger.debug("Metadata for transformer registered")