cocktailpeanut committed
Commit a2efe11
Parent: 542ea3b
Files changed (3)
  1. app.py +8 -2
  2. merge.py +3 -0
  3. requirements.txt +1 -1
app.py CHANGED
@@ -16,9 +16,15 @@ h1 {
   display: block;
 }
 """
+if torch.cuda.is_available():
+    device = "cuda"
+elif torch.backends.mps.is_available():
+    device = "mps"
+else:
+    device = "cpu"
 
 # Pipeline
-pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to("cuda", torch.float16)
+pipe = diffusers.StableDiffusionPipeline.from_pretrained("Lykon/DreamShaper").to(device, torch.float16)
 pipe.scheduler = diffusers.EulerDiscreteScheduler.from_config(pipe.scheduler.config)
 pipe.safety_checker = None
 
@@ -159,4 +165,4 @@ with gr.Blocks(css=css) as demo:
     gen.click(generate_merged, inputs=[prompt, seed, steps, height_width, negative_prompt,
                                        guidance_scale, method], outputs=[output_image, output_result])
 
-demo.launch(share=True)
+demo.launch(share=True)
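For reference, a minimal standalone sketch of the device-selection pattern used above; the per-device dtype fallback (float32 on CPU, where float16 inference is generally impractical) is an assumption and not part of this commit.

import torch
import diffusers

# Pick the best available accelerator, falling back to CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
elif torch.backends.mps.is_available():
    device, dtype = "mps", torch.float16
else:
    device, dtype = "cpu", torch.float32  # assumption: fp16 is a poor fit for CPU inference

pipe = diffusers.StableDiffusionPipeline.from_pretrained(
    "Lykon/DreamShaper", torch_dtype=dtype
).to(device)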
merge.py CHANGED
@@ -26,10 +26,13 @@ def init_generator(device: torch.device, fallback: torch.Generator = None):
     """
     Forks the current default random generator given device.
     """
+    print(f"init_generator device = {device}")
     if device.type == "cpu":
         return torch.Generator(device="cpu").set_state(torch.get_rng_state())
     elif device.type == "cuda":
         return torch.Generator(device=device).set_state(torch.cuda.get_rng_state())
+    elif device.type == "mps":
+        return torch.Generator(device=device).set_state(torch.mps.get_rng_state())
     else:
         if fallback is None:
             return init_generator(torch.device("cpu"))
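For context, a minimal CPU-only sketch of what "forking the current default random generator" in init_generator means; the variable names below are illustrative and not taken from merge.py.

import torch

torch.manual_seed(1234)                  # seed the default CPU generator
gen = torch.Generator(device="cpu")
gen.set_state(torch.get_rng_state())     # fork: copy the default generator's current state
a = torch.randn(3, generator=gen)        # the forked generator...
b = torch.randn(3)                       # ...and the default one now produce identical draws
assert torch.equal(a, b)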
requirements.txt CHANGED
@@ -1,4 +1,4 @@
 diffusers
 transformers
 accelerate
-xformers
+#xformers
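Since xformers is commented out above, one possible pattern (a sketch assuming the pipe and device names from app.py, not code from this commit) is to keep its memory-efficient attention as an opt-in that only activates on CUDA when the package happens to be installed:

if device == "cuda":
    try:
        import xformers  # noqa: F401  # optional dependency; its attention kernels target CUDA
        pipe.enable_xformers_memory_efficient_attention()
    except ImportError:
        pass  # fall back to the default attention implementation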