# Copyright 2024 Vchitect/Latte
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from Latte.
# This file is modified from https://github.com/Vchitect/Latte/blob/main/models/latte.py
#
# With references to:
# Latte: https://github.com/Vchitect/Latte
# DiT: https://github.com/facebookresearch/DiT/tree/main
import torch
from einops import rearrange, repeat

from opensora.acceleration.checkpoint import auto_grad_checkpoint
from opensora.models.dit import DiT
from opensora.registry import MODELS
from opensora.utils.ckpt_utils import load_checkpoint


class Latte(DiT):
    def forward(self, x, t, y):
        """
        Forward pass of Latte.
        x: (B, C, T, H, W) tensor of video inputs
        t: (B,) tensor of diffusion timesteps
        y: list of text prompts
        """
        # original inputs should be float32; cast to the model's dtype
        x = x.to(self.dtype)

        # embedding: patchify, then add the spatial positional embedding per frame
        x = self.x_embedder(x)  # (B, N, D), where N = num_temporal * num_spatial
        x = rearrange(x, "b (t s) d -> b t s d", t=self.num_temporal, s=self.num_spatial)
        x = x + self.pos_embed_spatial
        x = rearrange(x, "b t s d -> b (t s) d")

        t = self.t_embedder(t, dtype=x.dtype)  # (B, D)
        y = self.y_embedder(y, self.training)  # (B, 1, 1, D) when using a text encoder
        if self.use_text_encoder:
            y = y.squeeze(1).squeeze(1)  # drop singleton token dims -> (B, D)
        condition = t + y  # (B, D)
        # broadcast the condition to the per-frame and per-location batches used below
        condition_spatial = repeat(condition, "b d -> (b t) d", t=self.num_temporal)
        condition_temporal = repeat(condition, "b d -> (b s) d", s=self.num_spatial)
        # blocks: even-indexed blocks attend spatially, odd-indexed blocks temporally
        for i, block in enumerate(self.blocks):
            if i % 2 == 0:
                # spatial block: fold frames into the batch, attend over patches within a frame
                x = rearrange(x, "b (t s) d -> (b t) s d", t=self.num_temporal, s=self.num_spatial)
                c = condition_spatial
            else:
                # temporal block: fold spatial locations into the batch, attend over frames
                x = rearrange(x, "b (t s) d -> (b s) t d", t=self.num_temporal, s=self.num_spatial)
                c = condition_temporal
                if i == 1:
                    # add the temporal positional embedding once, before the first temporal block
                    x = x + self.pos_embed_temporal
            x = auto_grad_checkpoint(block, x, c)  # (B*T, S, D) or (B*S, T, D)
            if i % 2 == 0:
                x = rearrange(x, "(b t) s d -> b (t s) d", t=self.num_temporal, s=self.num_spatial)
            else:
                x = rearrange(x, "(b s) t d -> b (t s) d", t=self.num_temporal, s=self.num_spatial)
        # final layer and unpatchify
        x = self.final_layer(x, condition)  # (B, N, prod(patch_size) * out_channels)
        x = self.unpatchify(x)  # (B, out_channels, T, H, W)

        # cast back to float32 for better accuracy
        x = x.to(torch.float32)
        return x


@MODELS.register_module("Latte-XL/2")
def Latte_XL_2(from_pretrained=None, **kwargs):
    model = Latte(
        depth=28,
        hidden_size=1152,
        patch_size=(1, 2, 2),
        num_heads=16,
        **kwargs,
    )
    if from_pretrained is not None:
        load_checkpoint(model, from_pretrained)
    return model


@MODELS.register_module("Latte-XL/2x2")
def Latte_XL_2x2(from_pretrained=None, **kwargs):
    model = Latte(
        depth=28,
        hidden_size=1152,
        patch_size=(2, 2, 2),
        num_heads=16,
        **kwargs,
    )
    if from_pretrained is not None:
        load_checkpoint(model, from_pretrained)
    return model
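

# ---------------------------------------------------------------------------
# Minimal usage sketch. The constructor kwargs below (input_size) and the
# attributes read from the model (in_channels) are assumptions inherited from
# the DiT base class, not confirmed by this file; adjust to your config.
if __name__ == "__main__":
    model = Latte_XL_2(input_size=(16, 32, 32))  # (T, H, W) latent size; hypothetical
    model.eval()

    B = 2
    x = torch.randn(B, model.in_channels, 16, 32, 32)  # (B, C, T, H, W)
    t = torch.randint(0, 1000, (B,))  # diffusion timesteps
    y = ["a video of a cat", "ocean waves at sunset"]  # text prompts

    with torch.no_grad():
        out = model(x, t, y)
    print(out.shape)  # expected: (B, out_channels, 16, 32, 32)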