Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -29,8 +29,8 @@ PREVIEW_IMAGES = True
|
|
| 29 |
dtype = torch.bfloat16
|
| 30 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 31 |
if torch.cuda.is_available():
|
| 32 |
-
prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)
|
| 33 |
-
decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)
|
| 34 |
|
| 35 |
if ENABLE_CPU_OFFLOAD:
|
| 36 |
prior_pipeline.enable_model_cpu_offload()
|
|
@@ -45,8 +45,8 @@ if torch.cuda.is_available():
|
|
| 45 |
|
| 46 |
if PREVIEW_IMAGES:
|
| 47 |
previewer = Previewer()
|
| 48 |
-
|
| 49 |
-
previewer.
|
| 50 |
def callback_prior(i, t, latents):
|
| 51 |
output = previewer(latents)
|
| 52 |
output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
|
|
@@ -82,9 +82,10 @@ def generate(
|
|
| 82 |
num_images_per_prompt: int = 2,
|
| 83 |
profile: gr.OAuthProfile | None = None,
|
| 84 |
) -> PIL.Image.Image:
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
| 88 |
generator = torch.Generator().manual_seed(seed)
|
| 89 |
prior_output = prior_pipeline(
|
| 90 |
prompt=prompt,
|
|
|
|
| 29 |
dtype = torch.bfloat16
|
| 30 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 31 |
if torch.cuda.is_available():
|
| 32 |
+
prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)#.to(device)
|
| 33 |
+
decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
|
| 34 |
|
| 35 |
if ENABLE_CPU_OFFLOAD:
|
| 36 |
prior_pipeline.enable_model_cpu_offload()
|
|
|
|
| 45 |
|
| 46 |
if PREVIEW_IMAGES:
|
| 47 |
previewer = Previewer()
|
| 48 |
+
previewer_state_dict = torch.load("previewer/previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
|
| 49 |
+
previewer.load_state_dict(previewer_state_dict)
|
| 50 |
def callback_prior(i, t, latents):
|
| 51 |
output = previewer(latents)
|
| 52 |
output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
|
|
|
|
| 82 |
num_images_per_prompt: int = 2,
|
| 83 |
profile: gr.OAuthProfile | None = None,
|
| 84 |
) -> PIL.Image.Image:
|
| 85 |
+
previewer.eval().requires_grad_(False).to(device).to(dtype)
|
| 86 |
+
prior_pipeline.to(device)
|
| 87 |
+
decoder_pipeline.to(device)
|
| 88 |
+
|
| 89 |
generator = torch.Generator().manual_seed(seed)
|
| 90 |
prior_output = prior_pipeline(
|
| 91 |
prompt=prompt,
|