Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,6 +7,8 @@ is_shared_ui = True if "fffiloni/sdxl-control-loras" in os.environ['SPACE_ID'] e
|
|
| 7 |
hf_token = os.environ.get("HF_TOKEN")
|
| 8 |
login(token=hf_token)
|
| 9 |
|
|
|
|
|
|
|
| 10 |
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
|
| 11 |
from diffusers.utils import load_image
|
| 12 |
from PIL import Image
|
|
@@ -30,7 +32,7 @@ pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
|
|
| 30 |
use_safetensors=True
|
| 31 |
)
|
| 32 |
|
| 33 |
-
pipe.to(
|
| 34 |
|
| 35 |
|
| 36 |
|
|
@@ -60,7 +62,7 @@ def resize_image(input_path, output_path, target_height):
|
|
| 60 |
def infer(use_custom_model, model_name, custom_lora_weight, image_in, prompt, negative_prompt, preprocessor, controlnet_conditioning_scale, guidance_scale, inf_steps, seed, progress=gr.Progress(track_tqdm=True)):
|
| 61 |
prompt = prompt
|
| 62 |
negative_prompt = negative_prompt
|
| 63 |
-
generator = torch.Generator(device=
|
| 64 |
|
| 65 |
if image_in == None:
|
| 66 |
raise gr.Error("You forgot to upload a source image.")
|
|
|
|
| 7 |
hf_token = os.environ.get("HF_TOKEN")
|
| 8 |
login(token=hf_token)
|
| 9 |
|
| 10 |
+
device="cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
+
|
| 12 |
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
|
| 13 |
from diffusers.utils import load_image
|
| 14 |
from PIL import Image
|
|
|
|
| 32 |
use_safetensors=True
|
| 33 |
)
|
| 34 |
|
| 35 |
+
pipe.to(device)
|
| 36 |
|
| 37 |
|
| 38 |
|
|
|
|
| 62 |
def infer(use_custom_model, model_name, custom_lora_weight, image_in, prompt, negative_prompt, preprocessor, controlnet_conditioning_scale, guidance_scale, inf_steps, seed, progress=gr.Progress(track_tqdm=True)):
|
| 63 |
prompt = prompt
|
| 64 |
negative_prompt = negative_prompt
|
| 65 |
+
generator = torch.Generator(device=device).manual_seed(seed)
|
| 66 |
|
| 67 |
if image_in == None:
|
| 68 |
raise gr.Error("You forgot to upload a source image.")
|