File size: 7,303 Bytes
c7db14f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
#!/usr/bin/env python
# coding=utf-8
# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union, List, Dict, Any
import math
import os
import torch
import torch.nn as nn
from diffusers import DiffusionPipeline, EulerDiscreteScheduler, SchedulerMixin
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import logging
from PIL import Image
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPFeatureExtractor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
def get_noise(
num_samples: int,
channel: int,
height: int,
width: int,
device: torch.device,
dtype: torch.dtype,
seed: int,
):
return torch.randn(
num_samples,
channel,
# allow for packing
2 * math.ceil(height / 16),
2 * math.ceil(width / 16),
device=device,
dtype=dtype,
generator=torch.Generator(device=device).manual_seed(seed),
)
class ChatsSDXLPipeline(DiffusionPipeline, ConfigMixin):
@register_to_config
def __init__(
self,
unet_win: nn.Module,
unet_lose: nn.Module,
text_encoder: CLIPTextModel,
text_encoder_two: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
tokenizer_two: CLIPTokenizer,
vae: AutoencoderKL,
scheduler: SchedulerMixin,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor
):
super().__init__()
self.register_modules(
unet_win=unet_win,
unet_lose=unet_lose,
text_encoder=text_encoder,
text_encoder_two=text_encoder_two,
tokenizer=tokenizer,
tokenizer_two=tokenizer_two,
vae=vae,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor
)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, os.PathLike],
**kwargs,
) -> "ChatsSDXLPipeline":
return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
super().save_pretrained(save_directory)
@torch.no_grad()
def encode_text(self, tokenizers, text_encoders, prompt):
prompt_embeds_list = []
with torch.no_grad():
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
text_inputs = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt",)
text_input_ids = text_inputs.input_ids
prompt_embeds = text_encoder(text_input_ids.to(self.unet_win.device), output_hidden_states=True)
pooled_prompt_embeds = prompt_embeds[0]
prompt_embeds = prompt_embeds.hidden_states[-2]
prompt_embeds_list.append(prompt_embeds)
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
prompt_embeds = prompt_embeds.to(dtype=text_encoders[-1].dtype, device=text_encoders[-1].device)
return prompt_embeds, pooled_prompt_embeds
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
latents: torch.FloatTensor = None,
height: int = 1024,
width: int = 1024,
seed: int = 0,
alpha: float=0.5
):
if isinstance(prompt, str):
prompt = [prompt]
device = self.unet_win.device
tokenizers = [self.tokenizer, self.tokenizer_two]
text_encoders = [self.text_encoder, self.text_encoder_two]
prompt_embeds, pooled_prompt_embeds = self.encode_text(tokenizers, text_encoders, prompt)
negative_prompt_embeds, negative_pooled_prompt_embeds = self.encode_text(tokenizers, text_encoders, "")
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
bs = len(prompt)
channel = self.vae.config.latent_channels
height = 16 * (height // 16)
width = 16 * (width // 16)
# prepare input
latents = get_noise(
bs,
channel,
height,
width,
device=device,
dtype=self.unet_win.dtype,
seed=seed,
)
latents = latents * self.scheduler.init_noise_sigma
add_time_ids = torch.tensor([height, width, 0, 0, height, width], dtype=latents.dtype, device=device)[None, :].repeat(latents.size(0), 1)
for i, t in enumerate(timesteps):
latent_model_input = self.scheduler.scale_model_input(latents, t)
added_cond_kwargs_win = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
added_cond_kwargs_lose = {"text_embeds": pooled_prompt_embeds * (-alpha) + negative_pooled_prompt_embeds * (1. + alpha), "time_ids": add_time_ids}
pred_win = self.unet_win(latent_model_input, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs_win, return_dict=False)[0]
pred_lose = self.unet_lose(latent_model_input, t, encoder_hidden_states=prompt_embeds * (-alpha) + negative_prompt_embeds * (1. + alpha), added_cond_kwargs=added_cond_kwargs_lose, return_dict=False)[0]
noise_pred = pred_win + guidance_scale * (pred_win - pred_lose)
latents = self.scheduler.step(noise_pred, t, latents, generator=None, return_dict=False)[0]
x = latents.float()
with torch.no_grad():
with torch.autocast(device_type=device.type, dtype=torch.float32):
if hasattr(self.vae.config, 'scaling_factor') and self.vae.config.scaling_factor is not None:
x = x / self.vae.config.scaling_factor
if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor is not None:
x = x + self.vae.config.shift_factor
x = self.vae.decode(x, return_dict=False)[0]
# bring into PIL format and save
x = (x / 2 + 0.5).clamp(0, 1)
x = x.cpu().permute(0, 2, 3, 1).float().numpy()
images = (x * 255).round().astype("uint8")
clip_input = self.feature_extractor(images=images, return_tensors="pt").to(self.device)
filtered_images, has_nsfw_flags = self.safety_checker(images=images, clip_input=clip_input.pixel_values)
return {"images": filtered_images, "nsfw_flags": has_nsfw_flags}
|