File size: 5,108 Bytes
fd1c028
092fcaa
 
 
fd1c028
 
 
092fcaa
 
 
fd1c028
 
092fcaa
 
 
 
 
 
 
 
 
fd1c028
092fcaa
 
 
 
 
 
 
 
 
 
d639c7d
fd1c028
 
 
 
 
 
 
 
092fcaa
fd1c028
 
 
092fcaa
 
 
d639c7d
 
 
 
 
 
092fcaa
d639c7d
092fcaa
 
d639c7d
 
092fcaa
 
 
d639c7d
 
 
092fcaa
fd1c028
 
d639c7d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
092fcaa
d639c7d
 
 
 
 
 
 
092fcaa
d639c7d
 
 
 
092fcaa
fd1c028
33739c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd1c028
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import gradio as gr
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL
from diffusers.utils import load_image
from transformers import DPTImageProcessor, DPTForDepthEstimation
import torch
import mediapy
import sa_handler
import pipeline_calls



# ---------------------------------------------------------------------------
# Model initialisation.
# Everything below runs once at import time and places each model on "cuda".
# ---------------------------------------------------------------------------

# Hub ids of the checkpoints used by this demo.
DEPTH_MODEL_ID = "Intel/dpt-hybrid-midas"
CONTROLNET_MODEL_ID = "diffusers/controlnet-depth-sdxl-1.0"
VAE_MODEL_ID = "madebyollin/sdxl-vae-fp16-fix"
SDXL_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# Loading options shared by the ControlNet and the SDXL pipeline.
_FP16_SAFETENSORS = dict(
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)

# Depth estimation model and the image processor loaded from the same checkpoint.
depth_estimator = DPTForDepthEstimation.from_pretrained(DEPTH_MODEL_ID).to("cuda")
feature_processor = DPTImageProcessor.from_pretrained(DEPTH_MODEL_ID)

# Depth ControlNet, float16 VAE and the SDXL ControlNet pipeline built from them.
controlnet = ControlNetModel.from_pretrained(CONTROLNET_MODEL_ID, **_FP16_SAFETENSORS).to("cuda")
vae = AutoencoderKL.from_pretrained(VAE_MODEL_ID, torch_dtype=torch.float16).to("cuda")
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
    SDXL_MODEL_ID,
    controlnet=controlnet,
    vae=vae,
    **_FP16_SAFETENSORS,
).to("cuda")

# Turn on the pipeline's model CPU offload and VAE slicing options.
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()

# StyleAligned configuration: attention sharing with AdaIN on queries and keys
# is switched on; group-norm sharing, layer-norm sharing and AdaIN on values
# are switched off.
sa_args = sa_handler.StyleAlignedArgs(
    share_group_norm=False,
    share_layer_norm=False,
    share_attention=True,
    adain_queries=True,
    adain_keys=True,
    adain_values=False,
)

# Register the StyleAligned settings on the pipeline through its handler.
handler = sa_handler.Handler(pipeline)
handler.register(sa_args)




# run ControlNet depth with StyleAligned
def style_aligned_controlnet(ref_style_prompt, depth_map, ref_image, img_generation_prompt):
    """Run the depth ControlNet pipeline with StyleAligned for the Gradio UI.

    Args:
        ref_style_prompt: Prompt passed first in the prompt list; its result
            (``images[0]``) is shown as the reference style image.
        depth_map: When truthy, a depth map is computed from ``ref_image`` via
            ``pipeline_calls.get_depth_map``. When falsy, ``ref_image`` itself
            is resized to 1024x1024 and used directly as the conditioning
            image (i.e. it is presumably already a depth map).
        ref_image: Location of the uploaded image, in any form accepted by
            ``diffusers.utils.load_image`` (the UI supplies a file path).
        img_generation_prompt: Prompt passed second in the prompt list.

    Returns:
        A 2-tuple ``(gallery_items, reference_update)`` where
        ``gallery_items`` is ``[images[0], depth_image] + images[1:]`` and
        ``reference_update`` is a visible ``gr.Image`` holding ``images[0]``.

    Raises:
        gr.Error: If no reference image was provided.
    """
    # Fail early with a readable UI message instead of an opaque loader error.
    if not ref_image:
        raise gr.Error("Please upload a reference image before generating.")

    source_image = load_image(ref_image)
    if depth_map:
        depth_image = pipeline_calls.get_depth_map(source_image, feature_processor, depth_estimator)
    else:
        depth_image = source_image.resize((1024, 1024))

    controlnet_conditioning_scale = 0.8
    num_images_per_prompt = 3  # adjust according to VRAM size

    # One latent for the reference image plus one per generated image. A single
    # randn call already yields independent noise for every entry, so no
    # second sampling pass over latents[1:] is needed.
    latents = torch.randn(1 + num_images_per_prompt, 4, 128, 128).to(pipeline.unet.dtype)

    images = pipeline_calls.controlnet_call(
        pipeline,
        [ref_style_prompt, img_generation_prompt],
        image=depth_image,
        num_inference_steps=50,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_images_per_prompt=num_images_per_prompt,
        latents=latents,
    )

    # Gallery order: reference result, conditioning image, then the rest.
    gallery_items = [images[0], depth_image] + images[1:]
    return gallery_items, gr.Image(value=images[0], visible=True)


# Gradio front-end: two input panels, a "Generate" button, a results gallery
# and a set of ready-made examples, all wired to style_aligned_controlnet.
with gr.Blocks() as demo:

    with gr.Row():

        # Left panel: reference-style prompt, depth-map toggle and the
        # initially hidden image that is returned by the callback.
        with gr.Column(variant='panel'):
            ref_style_prompt = gr.Textbox(
                label='Reference style prompt',
                info="Enter a Prompt to generate the reference image",
                placeholder='a poster in <style name> style',
            )
            depth_map = gr.Checkbox(label='Depth-map')
            ref_style_image = gr.Image(visible=False, label='Reference style image')

        # Right panel: uploaded image (delivered as a file path) and the
        # prompt for the ControlNet generations.
        with gr.Column(variant='panel'):
            ref_image = gr.Image(
                label="Upload the reference image",
                type='filepath',
            )
            img_generation_prompt = gr.Textbox(
                label='ControlNet Prompt',
                info="Enter a Prompt to generate images using ControlNet and Style-aligned",
            )

    btn = gr.Button("Generate", size='sm')
    gallery = gr.Gallery(
        label="Style-Aligned ControlNet - Generated images",
        elem_id="gallery",
        columns=5,
        rows=1,
        object_fit="contain",
        height="auto",
    )

    # The button and the examples use the same components, in the same order
    # as the parameters / return values of style_aligned_controlnet.
    callback_inputs = [ref_style_prompt, depth_map, ref_image, img_generation_prompt]
    callback_outputs = [gallery, ref_style_image]

    btn.click(
        fn=style_aligned_controlnet,
        inputs=callback_inputs,
        outputs=callback_outputs,
        api_name="style_aligned_controlnet",
    )

    # Every example uses the same two prompts; only the depth-map flag and
    # the image differ.
    example_style_prompt = 'A poster in a papercut art style.'
    example_target_prompt = 'A village in a papercut art style.'
    example_rows = [
        [example_style_prompt, use_depth, image_path, example_target_prompt]
        for use_depth, image_path in (
            (True, 'example_image/A.png'),
            (True, 'example_image/camel.jpg'),
            (False, 'example_image/train.jpg'),
            (False, 'example_image/sun.png'),
            (True, 'example_image/whale.png'),
        )
    ]

    gr.Examples(
        examples=example_rows,
        inputs=callback_inputs,
        outputs=callback_outputs,
        fn=style_aligned_controlnet,
    )


demo.launch()