Spaces: Running
```python
import torch
from diffusers import FluxPipeline

# Model selection and device placement.
# FLUX.1-schnell is a FLUX (flow-matching transformer) checkpoint, not a
# Stable Diffusion one, so it is loaded with FluxPipeline.
model_id = "black-forest-labs/FLUX.1-schnell"
device = "mps" if torch.backends.mps.is_available() else "cpu"  # Prefer Apple's Metal backend

# bfloat16 is the dtype used by the FLUX model card and the diffusers FLUX
# examples. float16 is avoided: the diffusers FLUX docs note that the text
# encoder's activations overflow in float16, and the CPU fallback lacks many
# float16 kernels.
# NOTE(review): bfloat16 on MPS needs macOS 14+ and a recent PyTorch — confirm
# on the target machine.
dtype = torch.bfloat16

# The checkpoint's own scheduler is kept on purpose: schnell is
# timestep-distilled for its flow-matching schedule and very few steps, so
# swapping in a different scheduler degrades or breaks the output.
pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype)
pipe.to(device)

# Memory optimization: decode latents slice by slice.
# NOTE(review): attention slicing is intentionally not enabled. It only acts on
# modules implementing `set_attention_slice` (Stable Diffusion UNets); the
# FLUX transformer appears not to, so it would be a no-op here.
pipe.vae.enable_slicing()

# Optional speed-up: compile only the denoising transformer, which is where the
# per-step work happens. Compiling the pipeline object itself would replace
# `pipe` with a plain compiled callable and make Dynamo trace the whole
# __call__ (tokenizers, progress bar, PIL conversion).
# Off by default: TorchInductor support for MPS is still experimental, and for
# a single 4-step image the one-time compile cost usually outweighs the gain.
compile_transformer = False
if compile_transformer:
    pipe.transformer = torch.compile(pipe.transformer)

# Inference parameters
prompt = "A cat holding a sign that says hello world"
height = 768  # FLUX works in multiples of 16; other sizes get adjusted with a warning
width = 1360
num_inference_steps = 4  # schnell is distilled for very few (1-4) steps
guidance_scale = 0.0  # schnell has no guidance embedding; 0.0 matches the model card
max_sequence_length = 256  # prompt-token limit used for schnell in its model card / diffusers docs
seed = 0  # fixed seed for reproducible output; change it for different images

# A CPU generator keeps the sampled noise identical regardless of `device`.
generator = torch.Generator("cpu").manual_seed(seed)

# Inference
image = pipe(
    prompt,
    height=height,
    width=width,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    max_sequence_length=max_sequence_length,
    generator=generator,
).images[0]

# Save the image (optional)
image.save("cat_with_sign.png")
```