|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel |
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel |
|
|
|
sd3_common_transformer_block_config = { |
|
"dummy_input": { |
|
"hidden_states": (2, 4096, 1536), |
|
"encoder_hidden_states": (2, 333, 1536), |
|
"temb": (2, 1536), |
|
}, |
|
"output_names": ["encoder_hidden_states_out", "hidden_states_out"], |
|
"dynamic_axes": { |
|
"hidden_states": {0: "batch_size"}, |
|
"encoder_hidden_states": {0: "batch_size"}, |
|
"temb": {0: "steps"}, |
|
}, |
|
} |
|
|
|
ONNX_CONFIG = { |
|
UNet2DConditionModel: { |
|
"down_blocks.0": { |
|
"dummy_input": { |
|
"hidden_states": (2, 320, 128, 128), |
|
"temb": (2, 1280), |
|
}, |
|
"output_names": ["sample", "res_samples_0", "res_samples_1", "res_samples_2"], |
|
"dynamic_axes": { |
|
"hidden_states": {0: "batch_size"}, |
|
"temb": {0: "steps"}, |
|
}, |
|
}, |
|
"down_blocks.1": { |
|
"dummy_input": { |
|
"hidden_states": (2, 320, 64, 64), |
|
"temb": (2, 1280), |
|
"encoder_hidden_states": (2, 77, 2048), |
|
}, |
|
"output_names": ["sample", "res_samples_0", "res_samples_1", "res_samples_2"], |
|
"dynamic_axes": { |
|
"hidden_states": {0: "batch_size"}, |
|
"temb": {0: "steps"}, |
|
"encoder_hidden_states": {0: "batch_size"}, |
|
}, |
|
}, |
|
"down_blocks.2": { |
|
"dummy_input": { |
|
"hidden_states": (2, 640, 32, 32), |
|
"temb": (2, 1280), |
|
"encoder_hidden_states": (2, 77, 2048), |
|
}, |
|
"output_names": ["sample", "res_samples_0", "res_samples_1"], |
|
"dynamic_axes": { |
|
"hidden_states": {0: "batch_size"}, |
|
"temb": {0: "steps"}, |
|
"encoder_hidden_states": {0: "batch_size"}, |
|
}, |
|
}, |
|
"mid_block": { |
|
"dummy_input": { |
|
"hidden_states": (2, 1280, 32, 32), |
|
"temb": (2, 1280), |
|
"encoder_hidden_states": (2, 77, 2048), |
|
}, |
|
"output_names": ["sample"], |
|
"dynamic_axes": { |
|
"hidden_states": {0: "batch_size"}, |
|
"temb": {0: "steps"}, |
|
"encoder_hidden_states": {0: "batch_size"}, |
|
}, |
|
}, |
|
"up_blocks.0": { |
|
"dummy_input": { |
|
"hidden_states": (2, 1280, 32, 32), |
|
"res_hidden_states_0": (2, 640, 32, 32), |
|
"res_hidden_states_1": (2, 1280, 32, 32), |
|
"res_hidden_states_2": (2, 1280, 32, 32), |
|
"temb": (2, 1280), |
|
"encoder_hidden_states": (2, 77, 2048), |
|
}, |
|
"output_names": ["sample"], |
|
"dynamic_axes": { |
|
"hidden_states": {0: "batch_size"}, |
|
"temb": {0: "steps"}, |
|
"encoder_hidden_states": {0: "batch_size"}, |
|
"res_hidden_states_0": {0: "batch_size"}, |
|
"res_hidden_states_1": {0: "batch_size"}, |
|
"res_hidden_states_2": {0: "batch_size"}, |
|
}, |
|
}, |
|
"up_blocks.1": { |
|
"dummy_input": { |
|
"hidden_states": (2, 1280, 64, 64), |
|
"res_hidden_states_0": (2, 320, 64, 64), |
|
"res_hidden_states_1": (2, 640, 64, 64), |
|
"res_hidden_states_2": (2, 640, 64, 64), |
|
"temb": (2, 1280), |
|
"encoder_hidden_states": (2, 77, 2048), |
|
}, |
|
"output_names": ["sample"], |
|
"dynamic_axes": { |
|
"hidden_states": {0: "batch_size"}, |
|
"temb": {0: "steps"}, |
|
"encoder_hidden_states": {0: "batch_size"}, |
|
"res_hidden_states_0": {0: "batch_size"}, |
|
"res_hidden_states_1": {0: "batch_size"}, |
|
"res_hidden_states_2": {0: "batch_size"}, |
|
}, |
|
}, |
|
"up_blocks.2": { |
|
"dummy_input": { |
|
"hidden_states": (2, 640, 128, 128), |
|
"res_hidden_states_0": (2, 320, 128, 128), |
|
"res_hidden_states_1": (2, 320, 128, 128), |
|
"res_hidden_states_2": (2, 320, 128, 128), |
|
"temb": (2, 1280), |
|
}, |
|
"output_names": ["sample"], |
|
"dynamic_axes": { |
|
"hidden_states": {0: "batch_size"}, |
|
"temb": {0: "steps"}, |
|
"res_hidden_states_0": {0: "batch_size"}, |
|
"res_hidden_states_1": {0: "batch_size"}, |
|
"res_hidden_states_2": {0: "batch_size"}, |
|
}, |
|
}, |
|
}, |
|
SD3Transformer2DModel: { |
|
**{f"transformer_blocks.{i}": sd3_common_transformer_block_config for i in range(23)}, |
|
"transformer_blocks.23": { |
|
"dummy_input": { |
|
"hidden_states": (2, 4096, 1536), |
|
"encoder_hidden_states": (2, 333, 1536), |
|
"temb": (2, 1536), |
|
}, |
|
"output_names": ["hidden_states_out"], |
|
"dynamic_axes": { |
|
"hidden_states": {0: "batch_size"}, |
|
"encoder_hidden_states": {0: "batch_size"}, |
|
"temb": {0: "steps"}, |
|
}, |
|
}, |
|
}, |
|
} |
|
|