"""Generate one image with FLUX.1-schnell through Hugging Face diffusers.

Run as a script (``python <this file>``). All heavy work happens inside
functions behind a ``__main__`` guard, and ``diffusers`` is imported lazily, so
importing this module does not download weights or start inference.

Review notes on what changed versus the original snippet:

* FLUX checkpoints are loaded with ``FluxPipeline``; ``StableDiffusionPipeline``
  targets the UNet-based Stable Diffusion repo layout.
* The checkpoint's bundled scheduler is kept. ``from_config`` expects a config
  dict, not a repo id, so ``DPMSolverMultistepScheduler.from_config(model_id)``
  was invalid.
* ``torch.backends.mps.graph_mode`` is not a PyTorch option (assigning it just
  adds an unused attribute), so it was dropped.
* ``torch.compile`` is applied to ``pipe.transformer`` (an ``nn.Module``)
  instead of to the pipeline object.
* NOTE(review): ``enable_attention_slicing()`` was dropped; in diffusers it only
  affects components exposing ``set_attention_slice`` (UNet-style models).
  Re-add it if the installed diffusers version supports it for FLUX.
"""

from typing import Any, Optional

import torch

MODEL_ID = "black-forest-labs/FLUX.1-schnell"
PROMPT = "A cat holding a sign that says hello world"
HEIGHT = 768  # FluxPipeline works in multiples of 16 px; 768 = 48 * 16
WIDTH = 1360  # 1360 = 85 * 16
NUM_INFERENCE_STEPS = 4  # schnell is distilled for 1-4 steps (model card)
OUTPUT_PATH = "cat_with_sign.png"


def pick_device() -> str:
    """Return the preferred torch device name: ``"cuda"``, then ``"mps"``, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def load_pipeline(
    model_id: str = MODEL_ID,
    device: Optional[str] = None,
    *,
    compile_transformer: Optional[bool] = None,
) -> Any:
    """Load a FLUX pipeline, move it to *device*, and apply memory/speed options.

    Args:
        model_id: Hugging Face repo id of a FLUX checkpoint.
        device: Target device name; defaults to :func:`pick_device`.
        compile_transformer: Wrap ``pipe.transformer`` in ``torch.compile``.
            ``None`` (default) compiles only on CUDA; pass ``True`` to try it
            on other backends.

    Returns:
        The loaded ``diffusers.FluxPipeline``.
    """
    # Lazy import keeps module import cheap and independent of diffusers.
    from diffusers import FluxPipeline

    if device is None:
        device = pick_device()
    if compile_transformer is None:
        # NOTE(review): inductor support on MPS/CPU varies by PyTorch version,
        # so compilation is opt-in there.
        compile_transformer = device == "cuda"

    # bfloat16 matches the model card; diffusers' FLUX docs note fp16 needs
    # extra activation clipping in the text encoders.
    # NOTE(review): bfloat16 on MPS needs a recent macOS/PyTorch — confirm.
    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    pipe.to(device)

    # Decode the VAE one image at a time to lower peak memory.
    pipe.enable_vae_slicing()

    if compile_transformer:
        pipe.transformer = torch.compile(pipe.transformer)
    return pipe


def generate(
    pipe: Any,
    prompt: str = PROMPT,
    *,
    height: int = HEIGHT,
    width: int = WIDTH,
    num_inference_steps: int = NUM_INFERENCE_STEPS,
    seed: Optional[int] = None,
) -> Any:
    """Run *pipe* on *prompt* and return the first generated image.

    Args:
        pipe: A loaded FLUX pipeline (see :func:`load_pipeline`).
        prompt: Text prompt.
        height: Output height in pixels.
        width: Output width in pixels.
        num_inference_steps: Number of denoising steps.
        seed: If given, seeds a CPU ``torch.Generator`` for reproducible output.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    generator = None if seed is None else torch.Generator("cpu").manual_seed(seed)
    result = pipe(
        prompt,
        height=height,
        width=width,
        num_inference_steps=num_inference_steps,
        guidance_scale=0.0,  # schnell is guidance-distilled; model card uses 0.0
        max_sequence_length=256,  # schnell's T5 prompt limit per the model card
        generator=generator,
    )
    return result.images[0]


def main() -> None:
    """Load the pipeline, generate the default prompt, and save it to OUTPUT_PATH."""
    pipe = load_pipeline()
    image = generate(pipe)
    image.save(OUTPUT_PATH)


if __name__ == "__main__":
    main()