# File size: 8,235 Bytes
# 2bbf6b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import torch
from pipeline_flux_regional import RegionalFluxPipeline, RegionalFluxAttnProcessor2_0
from pipeline_flux_controlnet_regional import RegionalFluxControlNetPipeline
from diffusers import FluxControlNetModel, FluxMultiControlNetModel

def build_regional_prompts_and_masks(regional_prompt_mask_pairs, background_prompt, image_width, image_height):
    """Rasterize box-shaped regional prompts into full-resolution masks.

    Args:
        regional_prompt_mask_pairs: mapping of region id -> {"description": str, "mask": [x1, y1, x2, y2]},
            with the box in pixel coordinates (x2 / y2 exclusive, as in Python slicing).
        background_prompt: prompt used for pixels that no region covers.
        image_width: width of the generated image in pixels.
        image_height: height of the generated image in pixels.

    Returns:
        (regional_prompts, regional_masks): parallel lists. Each mask is a float tensor of shape
        (image_height, image_width) holding 1.0 inside its box and 0.0 elsewhere. If the boxes
        leave any pixel uncovered, the background prompt and a 0/1 mask of the uncovered pixels
        are appended last.

    Raises:
        ValueError: if a box is empty once clipped to the image bounds.
    """
    regional_prompts = []
    regional_masks = []
    # 1.0 for every pixel that no region has claimed yet
    background_mask = torch.ones((image_height, image_width))

    for region_idx, region in regional_prompt_mask_pairs.items():
        x1, y1, x2, y2 = region['mask']
        # Clip to the image: a negative start would otherwise wrap around under Python slicing.
        x1, x2 = max(0, x1), min(image_width, x2)
        y1, y2 = max(0, y1), min(image_height, y2)
        if x1 >= x2 or y1 >= y2:
            raise ValueError(
                f"region {region_idx!r} has an empty mask {region['mask']} "
                f"for a {image_width}x{image_height} image"
            )

        mask = torch.zeros((image_height, image_width))
        mask[y1:y2, x1:x2] = 1.0
        # Zero (rather than subtract) so overlapping regions cannot push the background below 0.
        background_mask[y1:y2, x1:x2] = 0.0

        regional_prompts.append(region['description'])
        regional_masks.append(mask)

    # If the regional masks don't cover the whole image, append the background prompt and mask.
    if background_mask.any():
        regional_prompts.append(background_prompt)
        regional_masks.append(background_mask)

    return regional_prompts, regional_masks


def install_regional_attn_processors(transformer):
    """Replace the attention processor of every transformer block with RegionalFluxAttnProcessor2_0.

    Processors whose names don't match are kept unchanged.
    """
    # Read the current name -> processor mapping once instead of re-querying the property per name.
    current_procs = transformer.attn_processors
    attn_procs = {}
    for name, proc in current_procs.items():
        # The substring test also matches names containing 'single_transformer_blocks'.
        if 'transformer_blocks' in name and name.endswith("attn.processor"):
            attn_procs[name] = RegionalFluxAttnProcessor2_0()
        else:
            attn_procs[name] = proc
    transformer.set_attn_processor(attn_procs)


if __name__ == "__main__":

    model_path = "black-forest-labs/FLUX.1-dev"

    use_lora = False
    use_controlnet = False

    if use_controlnet:  # takes up more gpu memory
        # READ https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro for detailed usage tutorial
        controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
        controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
        controlnet = FluxMultiControlNetModel([controlnet_union])
        pipeline = RegionalFluxControlNetPipeline.from_pretrained(model_path, controlnet=controlnet, torch_dtype=torch.bfloat16).to("cuda")
    else:
        pipeline = RegionalFluxPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")

    if use_lora:
        # READ https://huggingface.co/Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch for detailed usage tutorial
        pipeline.load_lora_weights("Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch", weight_name="FLUX-dev-lora-children-simple-sketch.safetensors")

    install_regional_attn_processors(pipeline.transformer)

    ## generation settings

    # example regional prompt and mask pairs
    image_width = 1280
    image_height = 768
    num_samples = 1
    num_inference_steps = 24
    guidance_scale = 3.5
    seed = 124
    base_prompt = "An ancient woman stands solemnly holding a blazing torch, while a fierce battle rages in the background, capturing both strength and tragedy in a historical war scene."
    background_prompt = "a photo"
    regional_prompt_mask_pairs = {
        "0": {
            "description": "A dignified woman in ancient robes stands in the foreground, her face illuminated by the torch she holds high. Her expression is one of determination and sorrow, her clothing and appearance reflecting the historical period. The torch casts dramatic shadows across her features, its flames dancing vibrantly against the darkness.",
            "mask": [128, 128, 640, 768]
        }
    }
    # region control settings
    mask_inject_steps = 10
    double_inject_blocks_interval = 1
    single_inject_blocks_interval = 1
    base_ratio = 0.3

    # controlnet settings: the controlnet pipeline call below needs all three names.
    # The depth map is the one used by the sports-car example below; swap in a
    # condition image that matches your own prompt.
    if use_controlnet:
        from diffusers.utils import load_image
        control_image = [load_image("./assets/condition_depth.png")]
        control_mode = [2]  # (2) stands for depth control
        controlnet_conditioning_scale = [0.7]

    # example input with controlnet enabled
    # image_width = 1280
    # image_height = 968
    # num_samples = 1
    # num_inference_steps = 24
    # guidance_scale = 3.5
    # seed = 124
    # base_prompt = "Three high-performance sports cars, red, blue, and yellow, are racing side by side on a city street"
    # background_prompt = "city street" # needed if regional masks don't cover the whole image
    # regional_prompt_mask_pairs = {
    #         "0": {
    #             "description": "A sleek red sports car in the lead position, with aggressive aerodynamic styling and gleaming paint that catches the light. The car appears to be moving at high speed with motion blur effects.",
    #             "mask": [0, 0, 426, 968]
    #         },
    #         "1": {
    #             "description": "A powerful blue sports car in the middle position, neck-and-neck with its competitors. Its metallic paint shimmers as it races forward, with visible speed lines and dynamic movement.",
    #             "mask": [426, 0, 853, 968]
    #         },
    #         "2": {
    #             "description": "A striking yellow sports car in the third position, its bold color standing out against the street. The car's aggressive stance and aerodynamic profile emphasize its racing performance.",
    #             "mask": [853, 0, 1280, 968]
    #         }
    # }

    # ## controlnet settings
    # if use_controlnet:
    #     control_image = [load_image("./assets/condition_depth.png")]
    #     control_mode = [2] # (2) stands for depth control
    #     controlnet_conditioning_scale = [0.7]
    # ## region control settings
    # mask_inject_steps = 10
    # double_inject_blocks_interval = 1 # 1 for full blocks
    # single_inject_blocks_interval = 2 # 1 for full blocks
    # base_ratio = 0.2


    # example input with lora enabled
    # image_width = 1280
    # image_height = 1280
    # num_samples = 1
    # num_inference_steps = 24
    # guidance_scale = 3.5
    # seed = 124
    # base_prompt = "Sketched style: A cute dinosaur playfully blowing tiny fire puffs over a cartoon city in a cheerful scene."
    # background_prompt = "white background"
    # regional_prompt_mask_pairs = {
    #     "0": {
    #         "description": "Sketched style: dinosaur with round eyes and a mischievous smile, puffing small flames over the city.",
    #         "mask": [0, 0, 640, 1280]
    #     },
    #     "1": {
    #         "description": "Sketched style: city with colorful buildings and tiny flames gently floating above, adding a playful touch.",
    #         "mask": [640, 0, 1280, 1280]
    #     }
    # }
    # ## lora settings
    # if use_lora:
    #     pipeline.fuse_lora(lora_scale=1.5)
    # ## region control settings
    # mask_inject_steps = 10
    # NOTE(review): the two "full blocks" hints below disagree with the "1 for full blocks"
    # hints in the controlnet example above — confirm against the pipeline implementation.
    # double_inject_blocks_interval = 1 # 18 for full blocks
    # single_inject_blocks_interval = 1 # 39 for full blocks
    # base_ratio = 0.1

    ## prepare regional prompts and masks
    # ensure image width and height are divisible by the vae scale factor
    image_width = (image_width // pipeline.vae_scale_factor) * pipeline.vae_scale_factor
    image_height = (image_height // pipeline.vae_scale_factor) * pipeline.vae_scale_factor

    regional_prompts, regional_masks = build_regional_prompts_and_masks(
        regional_prompt_mask_pairs, background_prompt, image_width, image_height
    )

    # setup regional kwargs that pass to the pipeline
    joint_attention_kwargs = {
        'regional_prompts': regional_prompts,
        'regional_masks': regional_masks,
        'double_inject_blocks_interval': double_inject_blocks_interval,
        'single_inject_blocks_interval': single_inject_blocks_interval,
        'base_ratio': base_ratio,
    }

    # generate images: both pipelines share these arguments; controlnet adds its conditioning inputs
    generation_kwargs = dict(
        prompt=base_prompt,
        num_samples=num_samples,
        width=image_width, height=image_height,
        mask_inject_steps=mask_inject_steps,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator("cuda").manual_seed(seed),
        joint_attention_kwargs=joint_attention_kwargs,
    )
    if use_controlnet:
        generation_kwargs.update(
            control_image=control_image,
            control_mode=control_mode,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
        )
    images = pipeline(**generation_kwargs).images

    for idx, image in enumerate(images):
        image.save(f"output_{idx}.jpg")