from typing import Any, Dict, Optional, Tuple, Union

import torch

from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput


# Forward for CrossAttnUpBlock2D with the residual skip connections passed as
# three explicit tensors instead of a tuple, giving the block the flat,
# fixed-arity signature that ONNX/TensorRT export requires.
def cachecrossattnupblock2d_forward(
    self,
    hidden_states: torch.FloatTensor,
    res_hidden_states_0: torch.FloatTensor,
    res_hidden_states_1: torch.FloatTensor,
    res_hidden_states_2: torch.FloatTensor,
    temb: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    upsample_size: Optional[int] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
    # Re-pack the explicit arguments into a tuple so the loop below consumes
    # them in the same order as the stock diffusers implementation.
    res_hidden_states_tuple = (res_hidden_states_0, res_hidden_states_1, res_hidden_states_2)
    for resnet, attn in zip(self.resnets, self.attentions):
        # Pop the most recent skip connection and concatenate it on channels.
        res_hidden_states = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        hidden_states = resnet(hidden_states, temb)
        hidden_states = attn(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            return_dict=False,
        )[0]

    if self.upsamplers is not None:
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)

    return hidden_states


# Forward for the plain UpBlock2D: same residual-unpacking scheme as the
# cross-attention variant above, without the attention inputs.
def cacheupblock2d_forward(
    self,
    hidden_states: torch.FloatTensor,
    res_hidden_states_0: torch.FloatTensor,
    res_hidden_states_1: torch.FloatTensor,
    res_hidden_states_2: torch.FloatTensor,
    temb: Optional[torch.FloatTensor] = None,
    upsample_size: Optional[int] = None,
) -> torch.FloatTensor:
    res_hidden_states_tuple = (res_hidden_states_0, res_hidden_states_1, res_hidden_states_2)
    for resnet in self.resnets:
        # Pop the most recent skip connection and concatenate it on channels.
        res_hidden_states = res_hidden_states_tuple[-1]
        res_hidden_states_tuple = res_hidden_states_tuple[:-1]
        hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

        hidden_states = resnet(hidden_states, temb)

    if self.upsamplers is not None:
        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, upsample_size)

    return hidden_states
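

# --- Export sketch (illustrative; the helper name, default file path, and
# default input layout below are assumptions, not part of the original
# module). The blocks above take three explicit residual tensors precisely so
# they can be traced with a flat positional signature; a minimal sketch of
# exporting one patched block, assuming the caller supplies dummy inputs
# shaped for that specific block:
def _export_up_block_onnx(
    block,
    dummy_inputs,
    out_path="up_block.onnx",
    input_names=(
        "hidden_states",
        "res_hidden_states_0",
        "res_hidden_states_1",
        "res_hidden_states_2",
        "temb",
    ),
):
    # dummy_inputs must match the patched forward's positional order, e.g.
    # (hidden_states, res_0, res_1, res_2, temb) for a plain UpBlock2D; a
    # cross-attention block additionally takes encoder_hidden_states.
    torch.onnx.export(
        block,
        dummy_inputs,
        out_path,
        input_names=list(input_names),
        output_names=["sample"],
    )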


# Forward for UNet2DConditionModel that routes each block either through the
# regular PyTorch modules or, when self.use_trt_infer is set, through the
# per-block TensorRT engines in self.engines. class_labels and the ControlNet
# and adapter residual arguments are accepted for API compatibility but are
# not used by this trimmed forward.
def cacheunet_forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    class_labels: Optional[torch.Tensor] = None,
    timestep_cond: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    mid_block_additional_residual: Optional[torch.Tensor] = None,
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
    # 0. time and conditioning embeddings
    t_emb = self.get_time_embed(sample=sample, timestep=timestep)
    emb = self.time_embedding(t_emb, timestep_cond)

    aug_emb = self.get_aug_embed(
        emb=emb,
        encoder_hidden_states=encoder_hidden_states,
        added_cond_kwargs=added_cond_kwargs,
    )
    emb = emb + aug_emb if aug_emb is not None else emb

    encoder_hidden_states = self.process_encoder_hidden_states(
        encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
    )

    # 1. pre-process
    sample = self.conv_in(sample)

    # When exporting the pre-processing stage to ONNX, stop here and return
    # the tensors the per-block engines consume.
    if hasattr(self, "_export_precess_onnx") and self._export_precess_onnx:
        return (
            sample,
            encoder_hidden_states,
            emb,
        )

    # 2. down blocks
    down_block_res_samples = (sample,)
    for i, downsample_block in enumerate(self.down_blocks):
        if (
            hasattr(downsample_block, "has_cross_attention")
            and downsample_block.has_cross_attention
        ):
            if hasattr(self, "use_trt_infer") and self.use_trt_infer:
                feed_dict = {
                    "hidden_states": sample,
                    "temb": emb,
                    "encoder_hidden_states": encoder_hidden_states,
                }
                down_results = self.engines[f"down_blocks.{i}"](feed_dict, self.cuda_stream)
                sample = down_results["sample"]
                res_samples_0 = down_results["res_samples_0"]
                res_samples_1 = down_results["res_samples_1"]
                # Blocks with a downsampler expose a third residual output.
                if "res_samples_2" in down_results:
                    res_samples_2 = down_results["res_samples_2"]
            else:
                # No intrablock (adapter) residuals are injected here; the
                # empty dict keeps parity with the stock diffusers loop.
                additional_residuals = {}
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
        else:
            if hasattr(self, "use_trt_infer") and self.use_trt_infer:
                feed_dict = {"hidden_states": sample, "temb": emb}
                down_results = self.engines[f"down_blocks.{i}"](feed_dict, self.cuda_stream)
                sample = down_results["sample"]
                res_samples_0 = down_results["res_samples_0"]
                res_samples_1 = down_results["res_samples_1"]
                if "res_samples_2" in down_results:
                    res_samples_2 = down_results["res_samples_2"]
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

        # TensorRT engines return the residuals as named outputs; collect
        # them in the same order the tuple-returning PyTorch blocks would.
        if hasattr(self, "use_trt_infer") and self.use_trt_infer:
            down_block_res_samples += (
                res_samples_0,
                res_samples_1,
            )
            if "res_samples_2" in down_results:
                down_block_res_samples += (res_samples_2,)
        else:
            down_block_res_samples += res_samples

    # 3. mid block
    if hasattr(self, "use_trt_infer") and self.use_trt_infer:
        feed_dict = {
            "hidden_states": sample,
            "temb": emb,
            "encoder_hidden_states": encoder_hidden_states,
        }
        mid_results = self.engines["mid_block"](feed_dict, self.cuda_stream)
        sample = mid_results["sample"]
    else:
        sample = self.mid_block(
            sample,
            emb,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            cross_attention_kwargs=cross_attention_kwargs,
            encoder_attention_mask=encoder_attention_mask,
        )

    # 4. up blocks
    for i, upsample_block in enumerate(self.up_blocks):
        # Take this block's skip connections off the end of the stack.
        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            if hasattr(self, "use_trt_infer") and self.use_trt_infer:
                feed_dict = {
                    "hidden_states": sample,
                    "res_hidden_states_0": res_samples[0],
                    "res_hidden_states_1": res_samples[1],
                    "res_hidden_states_2": res_samples[2],
                    "temb": emb,
                    "encoder_hidden_states": encoder_hidden_states,
                }
                up_results = self.engines[f"up_blocks.{i}"](feed_dict, self.cuda_stream)
                sample = up_results["sample"]
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_0=res_samples[0],
                    res_hidden_states_1=res_samples[1],
                    res_hidden_states_2=res_samples[2],
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
        else:
            if hasattr(self, "use_trt_infer") and self.use_trt_infer:
                feed_dict = {
                    "hidden_states": sample,
                    "res_hidden_states_0": res_samples[0],
                    "res_hidden_states_1": res_samples[1],
                    "res_hidden_states_2": res_samples[2],
                    "temb": emb,
                }
                up_results = self.engines[f"up_blocks.{i}"](feed_dict, self.cuda_stream)
                sample = up_results["sample"]
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_0=res_samples[0],
                    res_hidden_states_1=res_samples[1],
                    res_hidden_states_2=res_samples[2],
                )

    # 5. post-process
    if self.conv_norm_out:
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    if not return_dict:
        return (sample,)

    return UNet2DConditionOutput(sample=sample)
|
|
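# --- Usage sketch (illustrative; the helper name and the assumption that
# every up block carries exactly three resnets, as in Stable Diffusion 1.5,
# are not from the original module). The three forwards above all take
# `self` as their first argument, so they are meant to be bound onto an
# existing diffusers UNet2DConditionModel and its blocks:
def apply_cache_forwards(unet):
    import types

    # Replace the model-level forward, then each up block's forward with the
    # matching fixed-signature variant.
    unet.forward = types.MethodType(cacheunet_forward, unet)
    for block in unet.up_blocks:
        if getattr(block, "has_cross_attention", False):
            block.forward = types.MethodType(cachecrossattnupblock2d_forward, block)
        else:
            block.forward = types.MethodType(cacheupblock2d_forward, block)
    return unet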