# NOTE(review): removed non-Python page-export residue ("Spaces:" / "Runtime error"
# banners from a Hugging Face Spaces page) that would break parsing.
import math
import torch
import torch.nn as nn
from diffusers import ModelMixin
from diffusers.configuration_utils import (ConfigMixin,
                                           register_to_config)
class FontDiffuserModel(ModelMixin, ConfigMixin):
    """FontDiffuser wrapper tying together the UNet, the style encoder,
    and the content encoder for a single training/forward pass.

    The forward pass returns the UNet's noise prediction together with
    its summed offset output.
    """

    def __init__(
        self,
        unet,
        style_encoder,
        content_encoder,
    ):
        super().__init__()
        self.unet = unet
        self.style_encoder = style_encoder
        self.content_encoder = content_encoder

    def forward(
        self,
        x_t,
        timesteps,
        style_images,
        content_images,
        content_encoder_downsample_size,
    ):
        """Predict noise for ``x_t`` conditioned on style/content images.

        Returns:
            (noise_pred, offset_out_sum): first and second outputs of the UNet.
        """
        # Style feature map; the encoder's two residual outputs are unused here.
        style_feat, _, _ = self.style_encoder(style_images)
        b, c, h, w = style_feat.shape
        # Flatten the (B, C, H, W) map into (B, H*W, C) tokens for cross-attention.
        style_tokens = style_feat.permute(0, 2, 3, 1).reshape(b, h * w, c)

        # Multi-scale content features; the final feature map is appended
        # so the UNet sees the full pyramid in one list.
        content_feat, content_res = self.content_encoder(content_images)
        content_res.append(content_feat)

        # Run the *style* image through the content encoder as well
        # (content features of the reference image).
        ref_feat, ref_res = self.content_encoder(style_images)
        ref_res.append(ref_feat)

        conditions = [style_feat, content_res, style_tokens, ref_res]
        out = self.unet(
            x_t,
            timesteps,
            encoder_hidden_states=conditions,
            content_encoder_downsample_size=content_encoder_downsample_size,
        )
        return out[0], out[1]
class FontDiffuserModelDPM(ModelMixin, ConfigMixin):
    """DPM-solver forward wrapper for FontDiffuser with content encoder,
    style encoder, and UNet.

    Unlike ``FontDiffuserModel``, conditioning images arrive packed in
    ``cond`` and only the noise prediction is returned.
    """

    def __init__(
        self,
        unet,
        style_encoder,
        content_encoder,
    ):
        super().__init__()
        self.unet = unet
        self.style_encoder = style_encoder
        self.content_encoder = content_encoder

    def forward(
        self,
        x_t,
        timesteps,
        cond,
        content_encoder_downsample_size,
        version,
    ):
        """Predict noise for ``x_t`` during DPM-solver sampling.

        Args:
            cond: sequence packing ``(content_images, style_images)``.
            version: accepted for solver-interface compatibility but not
                used in this forward pass — NOTE(review): confirm with callers.

        Returns:
            The UNet's noise prediction (first output only).
        """
        content_images = cond[0]
        style_images = cond[1]
        # Discard the encoder's residual outputs, consistent with
        # FontDiffuserModel (they were previously bound but never used).
        style_img_feature, _, _ = self.style_encoder(style_images)
        batch_size, channel, height, width = style_img_feature.shape
        # Flatten (B, C, H, W) into (B, H*W, C) tokens for cross-attention.
        style_hidden_states = style_img_feature.permute(0, 2, 3, 1).reshape(
            batch_size, height * width, channel
        )
        # Content feature pyramid; final map appended (fixed the
        # "content_img_feture" misspelling of the local name).
        content_img_feature, content_residual_features = self.content_encoder(content_images)
        content_residual_features.append(content_img_feature)
        # Content-encoder features of the reference (style) image.
        style_content_feature, style_content_res_features = self.content_encoder(style_images)
        style_content_res_features.append(style_content_feature)
        input_hidden_states = [
            style_img_feature,
            content_residual_features,
            style_hidden_states,
            style_content_res_features,
        ]
        out = self.unet(
            x_t,
            timesteps,
            encoder_hidden_states=input_hidden_states,
            content_encoder_downsample_size=content_encoder_downsample_size,
        )
        noise_pred = out[0]
        return noise_pred