Flourish committed
Commit c7db14f · verified · 1 Parent(s): 09fac77

Upload 5 files

Files changed (5)
  1. README.md +6 -5
  2. app.py +83 -0
  3. gitattributes +35 -0
  4. pipeline.py +189 -0
  5. requirements.txt +22 -0
README.md CHANGED
@@ -1,13 +1,14 @@
  ---
  title: CHATS
- emoji: 🚀
- colorFrom: blue
- colorTo: yellow
+ emoji: 🖼
+ colorFrom: purple
+ colorTo: red
  sdk: gradio
- sdk_version: 5.31.0
+ sdk_version: 5.25.2
  app_file: app.py
  pinned: false
  license: apache-2.0
+ short_description: The demo for CHATS-SDXL text-to-image generation model
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
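The YAML block above is the Space's configuration header: emoji, colorFrom and colorTo only style the Space card, short_description is the card's one-line summary, while sdk and sdk_version pin the Gradio runtime (5.25.2 here) and app_file tells the runner to launch app.py. The linked configuration reference documents the full set of fields.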
app.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ import gradio as gr
+
+ from pipeline import ChatsSDXLPipeline
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+ from transformers import CLIPFeatureExtractor
+ from diffusers.utils import logging
+ from PIL import Image
+
+ logging.set_verbosity_error()
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")
+ safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
+
+ # Load CHATS-SDXL pipeline
+ pipe = ChatsSDXLPipeline.from_pretrained(
+     "AIDC-AI/CHATS",
+     safety_checker=safety_checker,
+     feature_extractor=feature_extractor,
+     torch_dtype=torch.float16
+ )
+ pipe.to(DEVICE)
+
+ def generate(prompt, steps=50, guidance_scale=7.5, height=768, width=512):
+     output = pipe(
+         prompt=prompt,
+         num_inference_steps=steps,
+         guidance_scale=guidance_scale,
+         height=height,
+         width=width,
+         seed=0
+     )
+     image = output['images'][0]
+     image = Image.fromarray(image)
+     return [image]  # gr.Gallery expects a list of images
+
+ with gr.Blocks(title="🔥 CHATS-SDXL Demo") as demo:
+     gr.Markdown(
+         "## CHATS-SDXL Text-to-Image Demo\n\n"
+         "Enter your prompt and click **Generate Image**. All NSFW content will be automatically filtered."
+     )
+     with gr.Row():
+         prompt_input = gr.Textbox(
+             label="Prompt",
+             placeholder="Enter your description here...",
+             lines=2,
+         )
+     with gr.Row():
+         steps_slider = gr.Slider(
+             minimum=1, maximum=100, value=50, step=1,
+             label="Inference Steps"
+         )
+         scale_slider = gr.Slider(
+             minimum=1.0, maximum=14.0, value=5.0, step=0.1,
+             label="Guidance Scale"
+         )
+     with gr.Row():
+         height_slider = gr.Slider(
+             minimum=64, maximum=2048, value=1024, step=64,
+             label="Image Height"
+         )
+         width_slider = gr.Slider(
+             minimum=64, maximum=2048, value=1024, step=64,
+             label="Image Width"
+         )
+     generate_button = gr.Button("Generate Image")
+     gallery = gr.Gallery(
+         label="Generated Images",
+         show_label=False,
+         columns=2,
+         elem_id="gallery"
+     )
+
+     generate_button.click(
+         fn=generate,
+         inputs=[prompt_input, steps_slider, scale_slider, height_slider, width_slider],
+         outputs=[gallery],
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
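For reference, a minimal sketch of driving the pipeline without the Gradio UI, mirroring the loading code in app.py above (the prompt and output filename are placeholders):

```python
import torch
from PIL import Image
from transformers import CLIPFeatureExtractor
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

from pipeline import ChatsSDXLPipeline

# Same components as app.py: the safety checker and feature extractor are passed in explicitly.
pipe = ChatsSDXLPipeline.from_pretrained(
    "AIDC-AI/CHATS",
    safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
    feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
    torch_dtype=torch.float16,
)
pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# The pipeline returns uint8 numpy arrays plus per-image NSFW flags; flagged images come back blacked out.
out = pipe(
    prompt="a watercolor painting of a fox in a snowy forest",  # placeholder prompt
    num_inference_steps=50,
    guidance_scale=7.5,
    height=1024,
    width=1024,
    seed=0,
)
Image.fromarray(out["images"][0]).save("sample.png")
```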
gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
pipeline.py ADDED
@@ -0,0 +1,189 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright (C) 2025 AIDC-AI
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, Union, List, Dict, Any
+
+ import math
+ import os
+ import torch
+ import torch.nn as nn
+ from diffusers import DiffusionPipeline, EulerDiscreteScheduler, SchedulerMixin
+ from diffusers.models import AutoencoderKL, UNet2DConditionModel
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.utils import logging
+ from PIL import Image
+
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, CLIPFeatureExtractor
+ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+
+ def get_noise(
+     num_samples: int,
+     channel: int,
+     height: int,
+     width: int,
+     device: torch.device,
+     dtype: torch.dtype,
+     seed: int,
+ ):
+     return torch.randn(
+         num_samples,
+         channel,
+         # allow for packing
+         2 * math.ceil(height / 16),
+         2 * math.ceil(width / 16),
+         device=device,
+         dtype=dtype,
+         generator=torch.Generator(device=device).manual_seed(seed),
+     )
+
+ class ChatsSDXLPipeline(DiffusionPipeline, ConfigMixin):
+
+     @register_to_config
+     def __init__(
+         self,
+         unet_win: nn.Module,
+         unet_lose: nn.Module,
+         text_encoder: CLIPTextModel,
+         text_encoder_two: CLIPTextModelWithProjection,
+         tokenizer: CLIPTokenizer,
+         tokenizer_two: CLIPTokenizer,
+         vae: AutoencoderKL,
+         scheduler: SchedulerMixin,
+         safety_checker: StableDiffusionSafetyChecker,
+         feature_extractor: CLIPFeatureExtractor
+     ):
+         super().__init__()
+
+         self.register_modules(
+             unet_win=unet_win,
+             unet_lose=unet_lose,
+             text_encoder=text_encoder,
+             text_encoder_two=text_encoder_two,
+             tokenizer=tokenizer,
+             tokenizer_two=tokenizer_two,
+             vae=vae,
+             scheduler=scheduler,
+             safety_checker=safety_checker,
+             feature_extractor=feature_extractor
+         )
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         pretrained_model_name_or_path: Union[str, os.PathLike],
+         **kwargs,
+     ) -> "ChatsSDXLPipeline":
+         return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
+
+     def save_pretrained(self, save_directory: Union[str, os.PathLike]):
+         super().save_pretrained(save_directory)
+
+     @torch.no_grad()
+     def encode_text(self, tokenizers, text_encoders, prompt):
+         prompt_embeds_list = []
+
+         with torch.no_grad():
+             for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+                 text_inputs = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+                 text_input_ids = text_inputs.input_ids
+                 prompt_embeds = text_encoder(text_input_ids.to(self.unet_win.device), output_hidden_states=True)
+                 pooled_prompt_embeds = prompt_embeds[0]
+                 prompt_embeds = prompt_embeds.hidden_states[-2]
+                 prompt_embeds_list.append(prompt_embeds)
+
+         prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+         prompt_embeds = prompt_embeds.to(dtype=text_encoders[-1].dtype, device=text_encoders[-1].device)
+
+         return prompt_embeds, pooled_prompt_embeds
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]],
+         num_inference_steps: int = 50,
+         guidance_scale: float = 7.5,
+         latents: torch.FloatTensor = None,
+         height: int = 1024,
+         width: int = 1024,
+         seed: int = 0,
+         alpha: float = 0.5
+     ):
+         if isinstance(prompt, str):
+             prompt = [prompt]
+
+         device = self.unet_win.device
+
+         tokenizers = [self.tokenizer, self.tokenizer_two]
+         text_encoders = [self.text_encoder, self.text_encoder_two]
+
+         prompt_embeds, pooled_prompt_embeds = self.encode_text(tokenizers, text_encoders, prompt)
+         negative_prompt_embeds, negative_pooled_prompt_embeds = self.encode_text(tokenizers, text_encoders, "")
+
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps
+
+         bs = len(prompt)
+         channel = self.vae.config.latent_channels
+         height = 16 * (height // 16)
+         width = 16 * (width // 16)
+
+         # prepare input
+         latents = get_noise(
+             bs,
+             channel,
+             height,
+             width,
+             device=device,
+             dtype=self.unet_win.dtype,
+             seed=seed,
+         )
+         latents = latents * self.scheduler.init_noise_sigma
+
+         add_time_ids = torch.tensor([height, width, 0, 0, height, width], dtype=latents.dtype, device=device)[None, :].repeat(latents.size(0), 1)
+
+         for i, t in enumerate(timesteps):
+             latent_model_input = self.scheduler.scale_model_input(latents, t)
+
+             added_cond_kwargs_win = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}
+             added_cond_kwargs_lose = {"text_embeds": pooled_prompt_embeds * (-alpha) + negative_pooled_prompt_embeds * (1. + alpha), "time_ids": add_time_ids}
+
+             pred_win = self.unet_win(latent_model_input, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs_win, return_dict=False)[0]
+             pred_lose = self.unet_lose(latent_model_input, t, encoder_hidden_states=prompt_embeds * (-alpha) + negative_prompt_embeds * (1. + alpha), added_cond_kwargs=added_cond_kwargs_lose, return_dict=False)[0]
+
+             noise_pred = pred_win + guidance_scale * (pred_win - pred_lose)
+             latents = self.scheduler.step(noise_pred, t, latents, generator=None, return_dict=False)[0]
+
+         x = latents.float()
+
+         with torch.no_grad():
+             with torch.autocast(device_type=device.type, dtype=torch.float32):
+                 if hasattr(self.vae.config, 'scaling_factor') and self.vae.config.scaling_factor is not None:
+                     x = x / self.vae.config.scaling_factor
+                 if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor is not None:
+                     x = x + self.vae.config.shift_factor
+                 x = self.vae.decode(x, return_dict=False)[0]
+
+         # bring into PIL format and save
+         x = (x / 2 + 0.5).clamp(0, 1)
+         x = x.cpu().permute(0, 2, 3, 1).float().numpy()
+         images = (x * 255).round().astype("uint8")
+
+         clip_input = self.feature_extractor(images=images, return_tensors="pt").to(self.device)
+         filtered_images, has_nsfw_flags = self.safety_checker(images=images, clip_input=clip_input.pixel_values)
+
+         return {"images": filtered_images, "nsfw_flags": has_nsfw_flags}
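Reading the sampling loop above: the two UNets are combined with a CFG-style extrapolation, noise_pred = pred_win + guidance_scale * (pred_win - pred_lose), and the "lose" branch is conditioned on the blended embedding (1 + alpha) * negative_embeds - alpha * prompt_embeds (alpha = 0.5 by default). The dispreferred prediction is therefore pulled toward the unconditional prompt, so the extrapolation pushes samples toward the preferred ("win") model's text-aligned prediction.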
requirements.txt ADDED
@@ -0,0 +1,22 @@
+ transformers==4.44.2
+ accelerate==0.31.0
+ deepspeed==0.14.5
+ numpy==1.24.3
+ diffusers
+ datasets
+ requests
+ fastapi
+ scipy
+ pandas
+ xformers
+ ftfy
+ Jinja2
+ bitsandbytes
+ safetensors
+ pyyaml
+ pillow==10.3.0
+ gradio
+ --extra-index-url https://download.pytorch.org/whl/cu124
+ torch==2.4.1
+ torchvision==0.19.1
+ torchaudio==2.4.1
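The --extra-index-url line points pip at the CUDA 12.4 wheel index, so torch, torchvision and torchaudio resolve to cu124 builds; assuming a matching GPU driver, the dependencies install with a plain `pip install -r requirements.txt`.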