"""Stable Diffusion XL text-to-image inference with JAX/Flax.

The Flax pipeline is loaded once, its generation step is pmapped and compiled
ahead of time, and images are then produced from text prompts.
"""

import time

import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from jax import pmap

# Persistent compilation cache, so a later run can reuse executables compiled
# by an earlier one instead of compiling from scratch.
from jax.experimental.compilation_cache import compilation_cache as cc

from diffusers import FlaxStableDiffusionXLPipeline


# NOTE(review): `initialize_cache` is the older entry point of JAX's
# compilation cache; newer JAX releases may expect
# `jax.config.update("jax_compilation_cache_dir", ...)` instead — confirm
# against the pinned JAX version.
cc.initialize_cache("/tmp/sdxl_cache")

# Number of devices JAX reports via jax.device_count() (all processes).
NUM_DEVICES = jax.device_count()

# Load the SDXL base pipeline. In JAX's functional style the weights are
# returned separately from the pipeline object and passed in explicitly.
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    revision="refs/pr/95",
    split_head_dim=True,
)

# Cast every parameter leaf to bfloat16 EXCEPT the scheduler state: it is
# removed before the cast and put back afterwards, so it is left uncast.
scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda leaf: leaf.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state

# Default request, used e.g. to trace shapes for ahead-of-time compilation.
default_prompt = "a colorful photo of a castle in the middle of a forest with trees and bushes, by Ismail Inceoglu, shadows, high contrast, dynamic shading, hdr, detailed vegetation, digital painting, digital drawing, detailed painting, a detailed digital painting, gothic art, featured on deviantart"
default_neg_prompt = "fog, grainy, purple"
default_seed = 33
default_guidance_scale = 5.0
default_num_steps = 25
width, height = 1024, 1024
					
						
						|  | def tokenize_prompt(prompt, neg_prompt): | 
					
						
						|  | prompt_ids = pipeline.prepare_inputs(prompt) | 
					
						
						|  | neg_prompt_ids = pipeline.prepare_inputs(neg_prompt) | 
					
						
						|  | return prompt_ids, neg_prompt_ids | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Replicate the parameter tree (bfloat16 except the scheduler state) onto the
# local devices once, up front; every generate call reuses this copy.
p_params = replicate(params)
					
						
						|  | def replicate_all(prompt_ids, neg_prompt_ids, seed): | 
					
						
						|  | p_prompt_ids = replicate(prompt_ids) | 
					
						
						|  | p_neg_prompt_ids = replicate(neg_prompt_ids) | 
					
						
						|  | rng = jax.random.PRNGKey(seed) | 
					
						
						|  | rng = jax.random.split(rng, NUM_DEVICES) | 
					
						
						|  | return p_prompt_ids, p_neg_prompt_ids, rng | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def aot_compile( | 
					
						
						|  | prompt=default_prompt, | 
					
						
						|  | negative_prompt=default_neg_prompt, | 
					
						
						|  | seed=default_seed, | 
					
						
						|  | guidance_scale=default_guidance_scale, | 
					
						
						|  | num_inference_steps=default_num_steps, | 
					
						
						|  | ): | 
					
						
						|  | prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt) | 
					
						
						|  | prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed) | 
					
						
						|  | g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32) | 
					
						
						|  | g = g[:, None] | 
					
						
						|  |  | 
					
						
						|  | return ( | 
					
						
						|  | pmap(pipeline._generate, static_broadcasted_argnums=[3, 4, 5, 9]) | 
					
						
						|  | .lower( | 
					
						
						|  | prompt_ids, | 
					
						
						|  | p_params, | 
					
						
						|  | rng, | 
					
						
						|  | num_inference_steps, | 
					
						
						|  | height, | 
					
						
						|  | width, | 
					
						
						|  | g, | 
					
						
						|  | None, | 
					
						
						|  | neg_prompt_ids, | 
					
						
						|  | False, | 
					
						
						|  | ) | 
					
						
						|  | .compile() | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | start = time.time() | 
					
						
						|  | print("Compiling ...") | 
					
						
						|  | p_generate = aot_compile() | 
					
						
						|  | print(f"Compiled in {time.time() - start}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def generate(prompt, negative_prompt, seed=default_seed, guidance_scale=default_guidance_scale): | 
					
						
						|  | prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt) | 
					
						
						|  | prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed) | 
					
						
						|  | g = jnp.array([guidance_scale] * prompt_ids.shape[0], dtype=jnp.float32) | 
					
						
						|  | g = g[:, None] | 
					
						
						|  | images = p_generate(prompt_ids, p_params, rng, g, None, neg_prompt_ids) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:]) | 
					
						
						|  | return pipeline.numpy_to_pil(np.array(images)) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | start = time.time() | 
					
						
						|  | prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang" | 
					
						
						|  | neg_prompt = "cartoon, illustration, animation. face. male, female" | 
					
						
						|  | images = generate(prompt, neg_prompt) | 
					
						
						|  | print(f"First inference in {time.time() - start}") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | start = time.time() | 
					
						
						|  | prompt = "photo of a rhino dressed suit and tie sitting at a table in a bar with a bar stools, award winning photography, Elke vogelsang" | 
					
						
						|  | neg_prompt = "cartoon, illustration, animation. face. male, female" | 
					
						
						|  | images = generate(prompt, neg_prompt) | 
					
						
						|  | print(f"Inference in {time.time() - start}") | 
					
						
						|  |  | 
					
						
						|  | for i, image in enumerate(images): | 
					
						
						|  | image.save(f"castle_{i}.png") | 
					
						
						|  |  |