QHL067 committed on
Commit 09d905b · 1 Parent(s): f19fa67
Files changed (1)
  1. app.py +403 -245
app.py CHANGED
@@ -1,249 +1,402 @@
1
- import gradio as gr
2
 
3
- from absl import flags
4
- from absl import app
5
- from ml_collections import config_flags
6
- import os
7
 
8
- import spaces #[uncomment to use ZeroGPU]
9
- import torch
10
 
11
 
12
- import os
13
- import random
14
 
15
- import numpy as np
16
- import torch
17
- import torch.nn.functional as F
18
- from torchvision.utils import save_image
19
- from huggingface_hub import hf_hub_download
20
 
21
- from absl import logging
22
- import ml_collections
23
 
24
- from diffusion.flow_matching import ODEEulerFlowMatchingSolver
25
- import utils
26
- import libs.autoencoder
27
- from libs.clip import FrozenCLIPEmbedder
28
- from configs import t2i_512px_clip_dimr
29
 
30
 
31
- def unpreprocess(x: torch.Tensor) -> torch.Tensor:
32
- x = 0.5 * (x + 1.0)
33
- x.clamp_(0.0, 1.0)
34
- return x
35
 
36
- def cosine_similarity_torch(latent1: torch.Tensor, latent2: torch.Tensor) -> torch.Tensor:
37
- latent1_flat = latent1.view(-1)
38
- latent2_flat = latent2.view(-1)
39
- cosine_similarity = F.cosine_similarity(
40
- latent1_flat.unsqueeze(0), latent2_flat.unsqueeze(0), dim=1
41
- )
42
- return cosine_similarity
43
-
44
- def kl_divergence(latent1: torch.Tensor, latent2: torch.Tensor) -> torch.Tensor:
45
- latent1_prob = F.softmax(latent1, dim=-1)
46
- latent2_prob = F.softmax(latent2, dim=-1)
47
- latent1_log_prob = torch.log(latent1_prob)
48
- kl_div = F.kl_div(latent1_log_prob, latent2_prob, reduction="batchmean")
49
- return kl_div
50
-
51
- def batch_decode(_z: torch.Tensor, decode, batch_size: int = 10) -> torch.Tensor:
52
- num_samples = _z.size(0)
53
- decoded_batches = []
54
-
55
- for i in range(0, num_samples, batch_size):
56
- batch = _z[i : i + batch_size]
57
- decoded_batch = decode(batch)
58
- decoded_batches.append(decoded_batch)
59
-
60
- return torch.cat(decoded_batches, dim=0)
61
-
62
- def get_caption(llm: str, text_model, prompt_dict: dict, batch_size: int):
63
- if batch_size == 3:
64
- # Only addition or only subtraction mode.
65
- assert len(prompt_dict) == 2, "Expected 2 prompts for batch_size 3."
66
- batch_prompts = list(prompt_dict.values()) + [" "]
67
- elif batch_size == 4:
68
- # Addition and subtraction mode.
69
- assert len(prompt_dict) == 3, "Expected 3 prompts for batch_size 4."
70
- batch_prompts = list(prompt_dict.values()) + [" "]
71
- elif batch_size >= 5:
72
- # Linear interpolation mode.
73
- assert len(prompt_dict) == 2, "Expected 2 prompts for linear interpolation."
74
- batch_prompts = [prompt_dict["prompt_1"]] + [" "] * (batch_size - 2) + [prompt_dict["prompt_2"]]
75
- else:
76
- raise ValueError(f"Unsupported batch_size: {batch_size}")
77
-
78
- if llm == "clip":
79
- latent, latent_and_others = text_model.encode(batch_prompts)
80
- context = latent_and_others["token_embedding"].detach()
81
- elif llm == "t5":
82
- latent, latent_and_others = text_model.get_text_embeddings(batch_prompts)
83
- context = (latent_and_others["token_embedding"] * 10.0).detach()
84
- else:
85
- raise NotImplementedError(f"Language model {llm} not supported.")
86
-
87
- token_mask = latent_and_others["token_mask"].detach()
88
- tokens = latent_and_others["tokens"].detach()
89
- captions = batch_prompts
90
-
91
- return context, token_mask, tokens, captions
92
-
93
- # Load configuration and initialize models.
94
- config_dict = t2i_512px_clip_dimr.get_config()
95
- config = ml_collections.ConfigDict(config_dict)
96
-
97
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
98
- logging.info(f"Using device: {device}")
99
-
100
- # Freeze configuration.
101
- config = ml_collections.FrozenConfigDict(config)
102
-
103
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
104
- MAX_SEED = np.iinfo(np.int32).max
105
- MAX_IMAGE_SIZE = 1024 # Currently not used.
106
 
107
- # Load the main diffusion model.
108
- repo_id = "QHL067/CrossFlow"
109
- filename = "pretrained_models/t2i_512px_clip_dimr.pth"
110
- checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
111
- nnet = utils.get_nnet(**config.nnet)
112
- nnet = nnet.to(device)
113
- state_dict = torch.load(checkpoint_path, map_location=device)
114
- nnet.load_state_dict(state_dict)
115
- nnet.eval()
116
 
117
- # Initialize text model.
118
- llm = "clip"
119
- clip = FrozenCLIPEmbedder()
120
- clip.eval()
121
- clip.to(device)
 
122
 
123
- # Load autoencoder.
124
- autoencoder = libs.autoencoder.get_model(**config.autoencoder)
125
- autoencoder.to(device)
126
 
127
 
128
- @torch.cuda.amp.autocast()
129
- def encode(_batch: torch.Tensor) -> torch.Tensor:
130
- """Encode a batch of images using the autoencoder."""
131
- return autoencoder.encode(_batch)
132
 
133
 
134
- @torch.cuda.amp.autocast()
135
- def decode(_batch: torch.Tensor) -> torch.Tensor:
136
- """Decode a batch of latent vectors using the autoencoder."""
137
- return autoencoder.decode(_batch)
 
138
 
139
 
140
- @spaces.GPU #[uncomment to use ZeroGPU]
141
  def infer(
142
- prompt1,
143
- prompt2,
144
  seed,
145
  randomize_seed,
 
 
146
  guidance_scale,
147
  num_inference_steps,
148
- num_of_interpolation,
149
- save_gpu_memory=True,
150
  progress=gr.Progress(track_tqdm=True),
151
  ):
152
  if randomize_seed:
153
  seed = random.randint(0, MAX_SEED)
154
 
155
- torch.manual_seed(seed)
156
- if device.type == "cuda":
157
- torch.cuda.manual_seed_all(seed)
158
 
159
- # Only support interpolation in this implementation.
160
- prompt_dict = {"prompt_1": prompt1, "prompt_2": prompt2}
161
- for key, value in prompt_dict.items():
162
- assert value is not None, f"{key} must not be None."
163
- assert num_of_interpolation >= 5, "For linear interpolation, please sample at least five images."
164
 
165
- # Get text embeddings and tokens.
166
- _context, _token_mask, _token, _caption = get_caption(
167
- llm, clip, prompt_dict=prompt_dict, batch_size=num_of_interpolation
168
- )
169
 
170
- with torch.no_grad():
171
- _z_gaussian = torch.randn(num_of_interpolation, *config.z_shape, device=device)
172
- _z_x0, _mu, _log_var = nnet(
173
- _context, text_encoder=True, shape=_z_gaussian.shape, mask=_token_mask
174
- )
175
- _z_init = _z_x0.reshape(_z_gaussian.shape)
176
-
177
- # Prepare the initial latent representations based on the number of interpolations.
178
- if num_of_interpolation == 3:
179
- # Addition or subtraction mode.
180
- if config.prompt_a is not None:
181
- assert config.prompt_s is None, "Only one of prompt_a or prompt_s should be provided."
182
- z_init_temp = _z_init[0] + _z_init[1]
183
- elif config.prompt_s is not None:
184
- assert config.prompt_a is None, "Only one of prompt_a or prompt_s should be provided."
185
- z_init_temp = _z_init[0] - _z_init[1]
186
- else:
187
- raise NotImplementedError("Either prompt_a or prompt_s must be provided for 3-sample mode.")
188
- mean = z_init_temp.mean()
189
- std = z_init_temp.std()
190
- _z_init[2] = (z_init_temp - mean) / std
191
-
192
- elif num_of_interpolation == 4:
193
- z_init_temp = _z_init[0] + _z_init[1] - _z_init[2]
194
- mean = z_init_temp.mean()
195
- std = z_init_temp.std()
196
- _z_init[3] = (z_init_temp - mean) / std
197
-
198
- elif num_of_interpolation >= 5:
199
- tensor_a = _z_init[0]
200
- tensor_b = _z_init[-1]
201
- num_interpolations = num_of_interpolation - 2
202
- interpolations = [
203
- tensor_a + (tensor_b - tensor_a) * (i / (num_interpolations + 1))
204
- for i in range(1, num_interpolations + 1)
205
- ]
206
- _z_init = torch.stack([tensor_a] + interpolations + [tensor_b], dim=0)
207
-
208
- else:
209
- raise ValueError("Unsupported number of interpolations.")
210
-
211
- assert guidance_scale > 1, "Guidance scale must be greater than 1."
212
-
213
- has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")
214
- ode_solver = ODEEulerFlowMatchingSolver(
215
- nnet,
216
- bdv_model_fn=None,
217
- step_size_type="step_in_dsigma",
218
- guidance_scale=guidance_scale,
219
- )
220
- _z, _ = ode_solver.sample(
221
- x_T=_z_init,
222
- batch_size=num_of_interpolation,
223
- sample_steps=num_inference_steps,
224
- unconditional_guidance_scale=guidance_scale,
225
- has_null_indicator=has_null_indicator,
226
- )
227
-
228
- if save_gpu_memory:
229
- image_unprocessed = batch_decode(_z, decode)
230
- else:
231
- image_unprocessed = decode(_z)
232
-
233
- samples = unpreprocess(image_unprocessed).contiguous()[0]
234
-
235
- # return samples, seed
236
- return seed
237
 
238
 
239
- # examples = [
240
- # "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
241
- # "An astronaut riding a green horse",
242
- # "A delicious ceviche cheesecake slice",
243
- # ]
244
-
245
  examples = [
246
- ["A dog cooking dinner in the kitchen", "An orange cat wearing sunglasses on a ship"],
 
 
247
  ]
248
 
249
  css = """
@@ -255,33 +408,29 @@ css = """
255
 
256
  with gr.Blocks(css=css) as demo:
257
  with gr.Column(elem_id="col-container"):
258
- gr.Markdown(" # CrossFlow")
259
- gr.Markdown(" CrossFlow directly transforms text representations into images for text-to-image generation, enabling interpolation in the input text latent space.")
260
 
261
  with gr.Row():
262
- prompt1 = gr.Text(
263
- label="Prompt_1",
264
  show_label=False,
265
  max_lines=1,
266
- placeholder="Enter your prompt for the first image",
267
- container=False,
268
- )
269
-
270
- with gr.Row():
271
- prompt2 = gr.Text(
272
- label="Prompt_2",
273
- show_label=False,
274
- max_lines=1,
275
- placeholder="Enter your prompt for the second image",
276
  container=False,
277
  )
278
 
279
- with gr.Row():
280
  run_button = gr.Button("Run", scale=0, variant="primary")
281
 
282
  result = gr.Image(label="Result", show_label=False)
283
 
284
  with gr.Accordion("Advanced Settings", open=False):
285
  seed = gr.Slider(
286
  label="Seed",
287
  minimum=0,
@@ -292,47 +441,56 @@ with gr.Blocks(css=css) as demo:
292
 
293
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
294
 
295
  with gr.Row():
296
  guidance_scale = gr.Slider(
297
  label="Guidance scale",
298
  minimum=0.0,
299
  maximum=10.0,
300
  step=0.1,
301
- value=7.0, # Replace with defaults that work for your model
302
  )
303
- with gr.Row():
304
  num_inference_steps = gr.Slider(
305
  label="Number of inference steps",
306
  minimum=1,
307
  maximum=50,
308
  step=1,
309
- value=50, # Replace with defaults that work for your model
310
- )
311
- with gr.Row():
312
- num_of_interpolation = gr.Slider(
313
- label="Number of images for interpolation",
314
- minimum=5,
315
- maximum=50,
316
- step=1,
317
- value=10, # Replace with defaults that work for your model
318
  )
319
 
320
- gr.Examples(examples=examples, inputs=[prompt1, prompt2])
321
  gr.on(
322
- triggers=[run_button.click, prompt1.submit, prompt2.submit],
323
  fn=infer,
324
  inputs=[
325
- prompt1,
326
- prompt2,
327
  seed,
328
  randomize_seed,
 
 
329
  guidance_scale,
330
  num_inference_steps,
331
- num_of_interpolation,
332
  ],
333
- # outputs=[result, seed],
334
- outputs=[seed],
335
  )
336
 
337
  if __name__ == "__main__":
338
- demo.launch()
 
1
+ # import gradio as gr
2
 
3
+ # from absl import flags
4
+ # from absl import app
5
+ # from ml_collections import config_flags
6
+ # import os
7
 
8
+ # import spaces #[uncomment to use ZeroGPU]
9
+ # import torch
10
 
11
 
12
+ # import os
13
+ # import random
14
 
15
+ # import numpy as np
16
+ # import torch
17
+ # import torch.nn.functional as F
18
+ # from torchvision.utils import save_image
19
+ # from huggingface_hub import hf_hub_download
20
 
21
+ # from absl import logging
22
+ # import ml_collections
23
 
24
+ # from diffusion.flow_matching import ODEEulerFlowMatchingSolver
25
+ # import utils
26
+ # import libs.autoencoder
27
+ # from libs.clip import FrozenCLIPEmbedder
28
+ # from configs import t2i_512px_clip_dimr
29
 
30
 
31
+ # def unpreprocess(x: torch.Tensor) -> torch.Tensor:
32
+ # x = 0.5 * (x + 1.0)
33
+ # x.clamp_(0.0, 1.0)
34
+ # return x
35
 
36
+ # def cosine_similarity_torch(latent1: torch.Tensor, latent2: torch.Tensor) -> torch.Tensor:
37
+ # latent1_flat = latent1.view(-1)
38
+ # latent2_flat = latent2.view(-1)
39
+ # cosine_similarity = F.cosine_similarity(
40
+ # latent1_flat.unsqueeze(0), latent2_flat.unsqueeze(0), dim=1
41
+ # )
42
+ # return cosine_similarity
43
+
44
+ # def kl_divergence(latent1: torch.Tensor, latent2: torch.Tensor) -> torch.Tensor:
45
+ # latent1_prob = F.softmax(latent1, dim=-1)
46
+ # latent2_prob = F.softmax(latent2, dim=-1)
47
+ # latent1_log_prob = torch.log(latent1_prob)
48
+ # kl_div = F.kl_div(latent1_log_prob, latent2_prob, reduction="batchmean")
49
+ # return kl_div
50
+
51
+ # def batch_decode(_z: torch.Tensor, decode, batch_size: int = 10) -> torch.Tensor:
52
+ # num_samples = _z.size(0)
53
+ # decoded_batches = []
54
+
55
+ # for i in range(0, num_samples, batch_size):
56
+ # batch = _z[i : i + batch_size]
57
+ # decoded_batch = decode(batch)
58
+ # decoded_batches.append(decoded_batch)
59
+
60
+ # return torch.cat(decoded_batches, dim=0)
61
+
62
+ # def get_caption(llm: str, text_model, prompt_dict: dict, batch_size: int):
63
+ # if batch_size == 3:
64
+ # # Only addition or only subtraction mode.
65
+ # assert len(prompt_dict) == 2, "Expected 2 prompts for batch_size 3."
66
+ # batch_prompts = list(prompt_dict.values()) + [" "]
67
+ # elif batch_size == 4:
68
+ # # Addition and subtraction mode.
69
+ # assert len(prompt_dict) == 3, "Expected 3 prompts for batch_size 4."
70
+ # batch_prompts = list(prompt_dict.values()) + [" "]
71
+ # elif batch_size >= 5:
72
+ # # Linear interpolation mode.
73
+ # assert len(prompt_dict) == 2, "Expected 2 prompts for linear interpolation."
74
+ # batch_prompts = [prompt_dict["prompt_1"]] + [" "] * (batch_size - 2) + [prompt_dict["prompt_2"]]
75
+ # else:
76
+ # raise ValueError(f"Unsupported batch_size: {batch_size}")
77
+
78
+ # if llm == "clip":
79
+ # latent, latent_and_others = text_model.encode(batch_prompts)
80
+ # context = latent_and_others["token_embedding"].detach()
81
+ # elif llm == "t5":
82
+ # latent, latent_and_others = text_model.get_text_embeddings(batch_prompts)
83
+ # context = (latent_and_others["token_embedding"] * 10.0).detach()
84
+ # else:
85
+ # raise NotImplementedError(f"Language model {llm} not supported.")
86
+
87
+ # token_mask = latent_and_others["token_mask"].detach()
88
+ # tokens = latent_and_others["tokens"].detach()
89
+ # captions = batch_prompts
90
+
91
+ # return context, token_mask, tokens, captions
92
+
93
+ # # Load configuration and initialize models.
94
+ # config_dict = t2i_512px_clip_dimr.get_config()
95
+ # config = ml_collections.ConfigDict(config_dict)
96
+
97
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
98
+ # logging.info(f"Using device: {device}")
99
+
100
+ # # Freeze configuration.
101
+ # config = ml_collections.FrozenConfigDict(config)
102
+
103
+ # torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
104
+ # MAX_SEED = np.iinfo(np.int32).max
105
+ # MAX_IMAGE_SIZE = 1024 # Currently not used.
106
+
107
+ # # Load the main diffusion model.
108
+ # repo_id = "QHL067/CrossFlow"
109
+ # filename = "pretrained_models/t2i_512px_clip_dimr.pth"
110
+ # checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
111
+ # nnet = utils.get_nnet(**config.nnet)
112
+ # nnet = nnet.to(device)
113
+ # state_dict = torch.load(checkpoint_path, map_location=device)
114
+ # nnet.load_state_dict(state_dict)
115
+ # nnet.eval()
116
+
117
+ # # Initialize text model.
118
+ # llm = "clip"
119
+ # clip = FrozenCLIPEmbedder()
120
+ # clip.eval()
121
+ # clip.to(device)
122
+
123
+ # # Load autoencoder.
124
+ # autoencoder = libs.autoencoder.get_model(**config.autoencoder)
125
+ # autoencoder.to(device)
126
+
127
+
128
+ # @torch.cuda.amp.autocast()
129
+ # def encode(_batch: torch.Tensor) -> torch.Tensor:
130
+ # """Encode a batch of images using the autoencoder."""
131
+ # return autoencoder.encode(_batch)
132
+
133
+
134
+ # @torch.cuda.amp.autocast()
135
+ # def decode(_batch: torch.Tensor) -> torch.Tensor:
136
+ # """Decode a batch of latent vectors using the autoencoder."""
137
+ # return autoencoder.decode(_batch)
138
+
139
+
140
+ # @spaces.GPU #[uncomment to use ZeroGPU]
141
+ # def infer(
142
+ # prompt1,
143
+ # prompt2,
144
+ # seed,
145
+ # randomize_seed,
146
+ # guidance_scale,
147
+ # num_inference_steps,
148
+ # num_of_interpolation,
149
+ # save_gpu_memory=True,
150
+ # progress=gr.Progress(track_tqdm=True),
151
+ # ):
152
+ # if randomize_seed:
153
+ # seed = random.randint(0, MAX_SEED)
154
+
155
+ # torch.manual_seed(seed)
156
+ # if device.type == "cuda":
157
+ # torch.cuda.manual_seed_all(seed)
158
+
159
+ # # Only support interpolation in this implementation.
160
+ # prompt_dict = {"prompt_1": prompt1, "prompt_2": prompt2}
161
+ # for key, value in prompt_dict.items():
162
+ # assert value is not None, f"{key} must not be None."
163
+ # assert num_of_interpolation >= 5, "For linear interpolation, please sample at least five images."
164
+
165
+ # # Get text embeddings and tokens.
166
+ # _context, _token_mask, _token, _caption = get_caption(
167
+ # llm, clip, prompt_dict=prompt_dict, batch_size=num_of_interpolation
168
+ # )
169
+
170
+ # with torch.no_grad():
171
+ # _z_gaussian = torch.randn(num_of_interpolation, *config.z_shape, device=device)
172
+ # _z_x0, _mu, _log_var = nnet(
173
+ # _context, text_encoder=True, shape=_z_gaussian.shape, mask=_token_mask
174
+ # )
175
+ # _z_init = _z_x0.reshape(_z_gaussian.shape)
176
+
177
+ # # Prepare the initial latent representations based on the number of interpolations.
178
+ # if num_of_interpolation == 3:
179
+ # # Addition or subtraction mode.
180
+ # if config.prompt_a is not None:
181
+ # assert config.prompt_s is None, "Only one of prompt_a or prompt_s should be provided."
182
+ # z_init_temp = _z_init[0] + _z_init[1]
183
+ # elif config.prompt_s is not None:
184
+ # assert config.prompt_a is None, "Only one of prompt_a or prompt_s should be provided."
185
+ # z_init_temp = _z_init[0] - _z_init[1]
186
+ # else:
187
+ # raise NotImplementedError("Either prompt_a or prompt_s must be provided for 3-sample mode.")
188
+ # mean = z_init_temp.mean()
189
+ # std = z_init_temp.std()
190
+ # _z_init[2] = (z_init_temp - mean) / std
191
+
192
+ # elif num_of_interpolation == 4:
193
+ # z_init_temp = _z_init[0] + _z_init[1] - _z_init[2]
194
+ # mean = z_init_temp.mean()
195
+ # std = z_init_temp.std()
196
+ # _z_init[3] = (z_init_temp - mean) / std
197
+
198
+ # elif num_of_interpolation >= 5:
199
+ # tensor_a = _z_init[0]
200
+ # tensor_b = _z_init[-1]
201
+ # num_interpolations = num_of_interpolation - 2
202
+ # interpolations = [
203
+ # tensor_a + (tensor_b - tensor_a) * (i / (num_interpolations + 1))
204
+ # for i in range(1, num_interpolations + 1)
205
+ # ]
206
+ # _z_init = torch.stack([tensor_a] + interpolations + [tensor_b], dim=0)
207
+
208
+ # else:
209
+ # raise ValueError("Unsupported number of interpolations.")
210
+
211
+ # assert guidance_scale > 1, "Guidance scale must be greater than 1."
212
+
213
+ # has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")
214
+ # ode_solver = ODEEulerFlowMatchingSolver(
215
+ # nnet,
216
+ # bdv_model_fn=None,
217
+ # step_size_type="step_in_dsigma",
218
+ # guidance_scale=guidance_scale,
219
+ # )
220
+ # _z, _ = ode_solver.sample(
221
+ # x_T=_z_init,
222
+ # batch_size=num_of_interpolation,
223
+ # sample_steps=num_inference_steps,
224
+ # unconditional_guidance_scale=guidance_scale,
225
+ # has_null_indicator=has_null_indicator,
226
+ # )
227
+
228
+ # if save_gpu_memory:
229
+ # image_unprocessed = batch_decode(_z, decode)
230
+ # else:
231
+ # image_unprocessed = decode(_z)
232
+
233
+ # samples = unpreprocess(image_unprocessed).contiguous()[0]
234
+
235
+ # # return samples, seed
236
+ # return seed
237
+
238
+
239
+ # # examples = [
240
+ # # "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
241
+ # # "An astronaut riding a green horse",
242
+ # # "A delicious ceviche cheesecake slice",
243
+ # # ]
244
 
245
+ # examples = [
246
+ # ["A dog cooking dinner in the kitchen", "An orange cat wearing sunglasses on a ship"],
247
+ # ]
248
 
249
+ # css = """
250
+ # #col-container {
251
+ # margin: 0 auto;
252
+ # max-width: 640px;
253
+ # }
254
+ # """
255
+
256
+ # with gr.Blocks(css=css) as demo:
257
+ # with gr.Column(elem_id="col-container"):
258
+ # gr.Markdown(" # CrossFlow")
259
+ # gr.Markdown(" CrossFlow directly transforms text representations into images for text-to-image generation, enabling interpolation in the input text latent space.")
260
+
261
+ # with gr.Row():
262
+ # prompt1 = gr.Text(
263
+ # label="Prompt_1",
264
+ # show_label=False,
265
+ # max_lines=1,
266
+ # placeholder="Enter your prompt for the first image",
267
+ # container=False,
268
+ # )
269
+
270
+ # with gr.Row():
271
+ # prompt2 = gr.Text(
272
+ # label="Prompt_2",
273
+ # show_label=False,
274
+ # max_lines=1,
275
+ # placeholder="Enter your prompt for the second image",
276
+ # container=False,
277
+ # )
278
+
279
+ # with gr.Row():
280
+ # run_button = gr.Button("Run", scale=0, variant="primary")
281
+
282
+ # result = gr.Image(label="Result", show_label=False)
283
+
284
+ # with gr.Accordion("Advanced Settings", open=False):
285
+ # seed = gr.Slider(
286
+ # label="Seed",
287
+ # minimum=0,
288
+ # maximum=MAX_SEED,
289
+ # step=1,
290
+ # value=0,
291
+ # )
292
+
293
+ # randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
294
+
295
+ # with gr.Row():
296
+ # guidance_scale = gr.Slider(
297
+ # label="Guidance scale",
298
+ # minimum=0.0,
299
+ # maximum=10.0,
300
+ # step=0.1,
301
+ # value=7.0, # Replace with defaults that work for your model
302
+ # )
303
+ # with gr.Row():
304
+ # num_inference_steps = gr.Slider(
305
+ # label="Number of inference steps",
306
+ # minimum=1,
307
+ # maximum=50,
308
+ # step=1,
309
+ # value=50, # Replace with defaults that work for your model
310
+ # )
311
+ # with gr.Row():
312
+ # num_of_interpolation = gr.Slider(
313
+ # label="Number of images for interpolation",
314
+ # minimum=5,
315
+ # maximum=50,
316
+ # step=1,
317
+ # value=10, # Replace with defaults that work for your model
318
+ # )
319
+
320
+ # gr.Examples(examples=examples, inputs=[prompt1, prompt2])
321
+ # gr.on(
322
+ # triggers=[run_button.click, prompt1.submit, prompt2.submit],
323
+ # fn=infer,
324
+ # inputs=[
325
+ # prompt1,
326
+ # prompt2,
327
+ # seed,
328
+ # randomize_seed,
329
+ # guidance_scale,
330
+ # num_inference_steps,
331
+ # num_of_interpolation,
332
+ # ],
333
+ # # outputs=[result, seed],
334
+ # outputs=[seed],
335
+ # )
336
+
337
+ # if __name__ == "__main__":
338
+ # demo.launch()
339
 
340
+ import gradio as gr
341
+ import numpy as np
342
+ import random
343
 
344
+ # import spaces #[uncomment to use ZeroGPU]
345
+ from diffusers import DiffusionPipeline
346
+ import torch
347
 
348
+ device = "cuda" if torch.cuda.is_available() else "cpu"
349
+ model_repo_id = "stabilityai/sdxl-turbo"  # Replace with the model you would like to use
 
 
350
 
351
+ if torch.cuda.is_available():
352
+ torch_dtype = torch.float16
353
+ else:
354
+ torch_dtype = torch.float32
355
 
356
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
357
+ pipe = pipe.to(device)
358
+
359
+ MAX_SEED = np.iinfo(np.int32).max
360
+ MAX_IMAGE_SIZE = 1024
361
 
362
 
363
+ # @spaces.GPU #[uncomment to use ZeroGPU]
364
  def infer(
365
+ prompt,
366
+ negative_prompt,
367
  seed,
368
  randomize_seed,
369
+ width,
370
+ height,
371
  guidance_scale,
372
  num_inference_steps,
 
 
373
  progress=gr.Progress(track_tqdm=True),
374
  ):
375
  if randomize_seed:
376
  seed = random.randint(0, MAX_SEED)
377
 
378
+ generator = torch.Generator().manual_seed(seed)
 
 
379
 
380
+ image = pipe(
381
+ prompt=prompt,
382
+ negative_prompt=negative_prompt,
383
+ guidance_scale=guidance_scale,
384
+ num_inference_steps=num_inference_steps,
385
+ width=width,
386
+ height=height,
387
+ generator=generator,
388
+ ).images[0]
389
 
390
+ # Debug: the pipeline returns a PIL image, which exposes .size rather than .shape.
391
+ print(image.size)
 
 
392
 
393
+ return image, seed
 
394
 
395
 
396
  examples = [
397
+ "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
398
+ "An astronaut riding a green horse",
399
+ "A delicious ceviche cheesecake slice",
400
  ]
401
 
402
  css = """
 
408
 
409
  with gr.Blocks(css=css) as demo:
410
  with gr.Column(elem_id="col-container"):
411
+ gr.Markdown(" # Text-to-Image Gradio Template")
 
412
 
413
  with gr.Row():
414
+ prompt = gr.Text(
415
+ label="Prompt",
416
  show_label=False,
417
  max_lines=1,
418
+ placeholder="Enter your prompt",
419
  container=False,
420
  )
421
 
 
422
  run_button = gr.Button("Run", scale=0, variant="primary")
423
 
424
  result = gr.Image(label="Result", show_label=False)
425
 
426
  with gr.Accordion("Advanced Settings", open=False):
427
+ negative_prompt = gr.Text(
428
+ label="Negative prompt",
429
+ max_lines=1,
430
+ placeholder="Enter a negative prompt",
431
+ visible=False,
432
+ )
433
+
434
  seed = gr.Slider(
435
  label="Seed",
436
  minimum=0,
 
441
 
442
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
443
 
444
+ with gr.Row():
445
+ width = gr.Slider(
446
+ label="Width",
447
+ minimum=256,
448
+ maximum=MAX_IMAGE_SIZE,
449
+ step=32,
450
+ value=1024, # Replace with defaults that work for your model
451
+ )
452
+
453
+ height = gr.Slider(
454
+ label="Height",
455
+ minimum=256,
456
+ maximum=MAX_IMAGE_SIZE,
457
+ step=32,
458
+ value=1024, # Replace with defaults that work for your model
459
+ )
460
+
461
  with gr.Row():
462
  guidance_scale = gr.Slider(
463
  label="Guidance scale",
464
  minimum=0.0,
465
  maximum=10.0,
466
  step=0.1,
467
+ value=0.0, # Replace with defaults that work for your model
468
  )
469
+
470
  num_inference_steps = gr.Slider(
471
  label="Number of inference steps",
472
  minimum=1,
473
  maximum=50,
474
  step=1,
475
+ value=2, # Replace with defaults that work for your model
476
  )
477
 
478
+ gr.Examples(examples=examples, inputs=[prompt])
479
  gr.on(
480
+ triggers=[run_button.click, prompt.submit],
481
  fn=infer,
482
  inputs=[
483
+ prompt,
484
+ negative_prompt,
485
  seed,
486
  randomize_seed,
487
+ width,
488
+ height,
489
  guidance_scale,
490
  num_inference_steps,
 
491
  ],
492
+ outputs=[result, seed],
 
493
  )
494
 
495
  if __name__ == "__main__":
496
+ demo.launch()
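
For reference, a minimal standalone sketch of the same generation path the new app.py wires into Gradio. It uses only calls and settings that appear in the diff above (DiffusionPipeline.from_pretrained, the stabilityai/sdxl-turbo checkpoint, guidance_scale=0.0, num_inference_steps=2); the example prompt and the output filename are illustrative assumptions, not part of the commit.

import torch
from diffusers import DiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Same checkpoint and dtype selection as the new app.py.
pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch_dtype).to(device)
generator = torch.Generator().manual_seed(42)  # fixed seed for reproducibility

image = pipe(
    prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    negative_prompt="",
    guidance_scale=0.0,       # SDXL-Turbo is intended to run without classifier-free guidance
    num_inference_steps=2,    # matches the template's default slider value
    width=1024,
    height=1024,
    generator=generator,
).images[0]

image.save("sample.png")  # hypothetical output path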