```python
import torch
from diffusers import FluxPipeline

# Model selection and device placement
model_id = "black-forest-labs/FLUX.1-schnell"
device = "mps" if torch.backends.mps.is_available() else "cpu"  # Prefer Apple MPS, fall back to CPU

# Load the FLUX pipeline. FLUX models use FluxPipeline (not StableDiffusionPipeline)
# and ship with their own flow-matching scheduler, so no scheduler override is needed.
pipe = FluxPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # bfloat16 is the dtype recommended for FLUX
)
pipe.to(device)

# Memory optimizations
pipe.enable_attention_slicing()  # Slice attention computation where supported to lower peak memory
pipe.vae.enable_slicing()        # Decode latents one image at a time to lower VAE memory use

# Optional speed optimization: compile the transformer. torch.compile support on MPS
# is still limited, so leave this commented out unless it works on your setup.
# pipe.transformer = torch.compile(pipe.transformer)

# Inference parameters
prompt = "A cat holding a sign that says hello world"
height = 768
width = 1360
num_inference_steps = 4  # FLUX.1-schnell is distilled for very few steps

# Inference (schnell is guidance-distilled, so guidance_scale is set to 0.0)
image = pipe(
    prompt,
    height=height,
    width=width,
    guidance_scale=0.0,
    num_inference_steps=num_inference_steps,
    max_sequence_length=256,
).images[0]

# Save the image (optional)
image.save("cat_with_sign.png")
```
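
For reproducible outputs, the pipeline call also accepts a `generator`. A minimal sketch, reusing the `pipe`, `prompt`, and size variables defined above (the seed value is an arbitrary example):

```python
# Seeded generation for reproducibility (reuses `pipe`, `prompt`, and sizes from above)
seed = 0  # arbitrary example seed
generator = torch.Generator("cpu").manual_seed(seed)  # a CPU generator works on both MPS and CPU

image = pipe(
    prompt,
    height=height,
    width=width,
    guidance_scale=0.0,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
image.save("cat_with_sign_seeded.png")
```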