# Copyright 2024 Vchitect/Latte
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from Latte
#
# This file is adapted from the Latte project.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Latte: https://github.com/Vchitect/Latte
# DiT: https://github.com/facebookresearch/DiT/tree/main
# --------------------------------------------------------
import torch
import torch.nn as nn
import transformers
from transformers import CLIPTextModel, CLIPTokenizer

from opensora.registry import MODELS

transformers.logging.set_verbosity_error()


class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError

class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from Hugging Face)."""

    def __init__(self, path="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
        super().__init__()
        self.tokenizer = CLIPTokenizer.from_pretrained(path)
        self.transformer = CLIPTextModel.from_pretrained(path)
        self.device = device
        self.max_length = max_length
        self._freeze()

    def _freeze(self):
        # The text encoder stays frozen: eval mode and no gradients.
        self.transformer = self.transformer.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        batch_encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        tokens = batch_encoding["input_ids"].to(self.device)
        outputs = self.transformer(input_ids=tokens)
        # Per-token features and the pooled (EOS-token) feature.
        z = outputs.last_hidden_state
        pooled_z = outputs.pooler_output
        return z, pooled_z

    def encode(self, text):
        return self(text)
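
# Illustrative usage sketch (an assumption, not part of the original module). The checkpoint
# name and prompts are placeholders; any checkpoint loadable by CLIPTokenizer and
# CLIPTextModel should work:
#
#   encoder = FrozenCLIPEmbedder(path="openai/clip-vit-base-patch32", device="cpu")
#   z, pooled_z = encoder.encode(["a cat playing piano", "a dog surfing a wave"])
#   # z:        [2, max_length, hidden_size] per-token features (last_hidden_state)
#   # pooled_z: [2, hidden_size] pooled EOS-token features (pooler_output)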

# Register the encoder in the model registry so it can be built from configs.
@MODELS.register_module("clip")
class ClipEncoder:
    """
    Embeds text prompts into vector representations. Also handles text dropout for classifier-free guidance.
    """

    def __init__(
        self,
        from_pretrained,
        model_max_length=77,
        device="cuda",
        dtype=torch.float,
    ):
        super().__init__()
        assert from_pretrained is not None, "Please specify the path to the CLIP model"

        # Pass the device through so tokenized inputs land on the same device as the weights.
        self.text_encoder = FrozenCLIPEmbedder(
            path=from_pretrained, device=device, max_length=model_max_length
        ).to(device, dtype)

        self.y_embedder = None  # set by the downstream model; holds the learned null embedding
        self.model_max_length = model_max_length
        self.output_dim = self.text_encoder.transformer.config.hidden_size

    def encode(self, text):
        _, pooled_embeddings = self.text_encoder.encode(text)
        # [B, hidden_size] -> [B, 1, 1, hidden_size] to match the conditioning layout expected downstream.
        y = pooled_embeddings.unsqueeze(1).unsqueeze(1)
        return dict(y=y)

    def null(self, n):
        # Null (unconditional) embedding used for classifier-free guidance text dropout.
        null_y = self.y_embedder.y_embedding[None].repeat(n, 1, 1)[:, None]
        return null_y

    def to(self, dtype):
        self.text_encoder = self.text_encoder.to(dtype)
        return self
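

if __name__ == "__main__":
    # Minimal smoke-test sketch of how the encoder is used, assuming a small public CLIP
    # checkpoint ("openai/clip-vit-base-patch32") purely for illustration; the module's own
    # default is "openai/clip-vit-huge-patch14". Falls back to CPU when no GPU is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = ClipEncoder(
        from_pretrained="openai/clip-vit-base-patch32",
        model_max_length=77,
        device=device,
        dtype=torch.float,
    )
    cond = encoder.encode(["a cat playing piano", "a dog surfing a wave"])
    # Pooled CLIP features reshaped to [batch, 1, 1, hidden_size] for conditioning.
    print(cond["y"].shape)
    # Note: encoder.null(n) requires the downstream model to set encoder.y_embedder first.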