Update app.py
Browse files
app.py
CHANGED
|
@@ -10,8 +10,8 @@ import PIL
|
|
| 10 |
base = "stabilityai/stable-diffusion-xl-base-1.0"
|
| 11 |
repo = "tianweiy/DMD2"
|
| 12 |
checkpoints = {
|
| 13 |
-
"1-Step" : ["
|
| 14 |
-
"4-Step" : ["
|
| 15 |
}
|
| 16 |
loaded = None
|
| 17 |
|
|
@@ -37,7 +37,7 @@ def generate_image(prompt, ckpt):
|
|
| 37 |
num_inference_steps = checkpoints[ckpt][1]
|
| 38 |
|
| 39 |
if loaded != num_inference_steps:
|
| 40 |
-
unet.load_state_dict(torch.load(hf_hub_download(repo,
|
| 41 |
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps==1 else "epsilon")
|
| 42 |
loaded = num_inference_steps
|
| 43 |
|
|
@@ -51,7 +51,7 @@ def generate_image(prompt, ckpt):
|
|
| 51 |
|
| 52 |
with gr.Blocks(css=CSS) as demo:
|
| 53 |
gr.HTML("<h1><center>Adobe DMD2🦖</center></h1>")
|
| 54 |
-
gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>
|
| 55 |
with gr.Group():
|
| 56 |
with gr.Row():
|
| 57 |
prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
|
|
|
|
| 10 |
base = "stabilityai/stable-diffusion-xl-base-1.0"
|
| 11 |
repo = "tianweiy/DMD2"
|
| 12 |
checkpoints = {
|
| 13 |
+
"1-Step" : ["dmd2_sdxl_1step_unet_fp16.bin", 1],
|
| 14 |
+
"4-Step" : ["dmd2_sdxl_4step_unet_fp16.bin", 4],
|
| 15 |
}
|
| 16 |
loaded = None
|
| 17 |
|
|
|
|
| 37 |
num_inference_steps = checkpoints[ckpt][1]
|
| 38 |
|
| 39 |
if loaded != num_inference_steps:
|
| 40 |
+
unet.load_state_dict(torch.load(hf_hub_download(repo, checkpoint)), map_location="cuda")
|
| 41 |
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps==1 else "epsilon")
|
| 42 |
loaded = num_inference_steps
|
| 43 |
|
|
|
|
| 51 |
|
| 52 |
with gr.Blocks(css=CSS) as demo:
|
| 53 |
gr.HTML("<h1><center>Adobe DMD2🦖</center></h1>")
|
| 54 |
+
gr.HTML("<p><center><a href='https://huggingface.co/tianweiy/DMD2'>DMD2</a> text-to-image generation</center></p>")
|
| 55 |
with gr.Group():
|
| 56 |
with gr.Row():
|
| 57 |
prompt = gr.Textbox(label='Enter your prompt (English)', scale=8)
|