Spaces:
Runtime error
Runtime error
Alexander McKinney
commited on
Commit
·
92ba1f6
1
Parent(s):
3d237d0
updates for CPU only diffusion
Browse files
app.py
CHANGED
|
@@ -15,6 +15,7 @@ from transformers.models.detr.feature_extraction_detr import rgb_to_id
|
|
| 15 |
from diffusers import StableDiffusionInpaintPipeline
|
| 16 |
|
| 17 |
auth_token = os.environ.get("READ_TOKEN")
|
|
|
|
| 18 |
|
| 19 |
torch.inference_mode()
|
| 20 |
torch.no_grad()
|
|
@@ -32,7 +33,7 @@ def load_diffusion_pipeline(model_name: str = 'runwayml/stable-diffusion-inpaint
|
|
| 32 |
return StableDiffusionInpaintPipeline.from_pretrained(
|
| 33 |
model_name,
|
| 34 |
revision='fp16',
|
| 35 |
-
torch_dtype=torch.float16,
|
| 36 |
use_auth_token=auth_token
|
| 37 |
)
|
| 38 |
|
|
@@ -60,7 +61,7 @@ def clean_mask(mask, max_kernel: int = 23, min_kernel: int = 5):
|
|
| 60 |
feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
|
| 61 |
pipe = load_diffusion_pipeline()
|
| 62 |
|
| 63 |
-
device = get_device()
|
| 64 |
pipe = pipe.to(device)
|
| 65 |
|
| 66 |
# Callback function that runs segmentation and updates CheckboxGroup
|
|
@@ -161,7 +162,9 @@ demo = gr.Blocks(css=open('app.css').read())
|
|
| 161 |
|
| 162 |
with demo:
|
| 163 |
gr.HTML(open('app_header.html').read())
|
| 164 |
-
|
|
|
|
|
|
|
| 165 |
|
| 166 |
# Input image control
|
| 167 |
input_image = gr.Image(value="example.png", type='pil', label="Input Image")
|
|
|
|
| 15 |
from diffusers import StableDiffusionInpaintPipeline
|
| 16 |
|
| 17 |
auth_token = os.environ.get("READ_TOKEN")
|
| 18 |
+
try_cuda = True
|
| 19 |
|
| 20 |
torch.inference_mode()
|
| 21 |
torch.no_grad()
|
|
|
|
| 33 |
return StableDiffusionInpaintPipeline.from_pretrained(
|
| 34 |
model_name,
|
| 35 |
revision='fp16',
|
| 36 |
+
torch_dtype=torch.float16 if try_cuda and torch.cuda.is_available() else torch.float32,
|
| 37 |
use_auth_token=auth_token
|
| 38 |
)
|
| 39 |
|
|
|
|
| 61 |
feature_extractor, segmentation_model, segmentation_cfg = load_segmentation_models()
|
| 62 |
pipe = load_diffusion_pipeline()
|
| 63 |
|
| 64 |
+
device = get_device(try_cuda=try_cuda)
|
| 65 |
pipe = pipe.to(device)
|
| 66 |
|
| 67 |
# Callback function that runs segmentation and updates CheckboxGroup
|
|
|
|
| 162 |
|
| 163 |
with demo:
|
| 164 |
gr.HTML(open('app_header.html').read())
|
| 165 |
+
|
| 166 |
+
if not try_cuda or not torch.cuda.is_available():
|
| 167 |
+
gr.HTML('<div class="alert alert-warning" role="alert" style="color:red"><b>Warning: GPU not available! Diffusion will be slow.</b></div>')
|
| 168 |
|
| 169 |
# Input image control
|
| 170 |
input_image = gr.Image(value="example.png", type='pil', label="Input Image")
|