# File size: 8,235 Bytes
# Revision: 2bbf6b0
import torch
from pipeline_flux_regional import RegionalFluxPipeline, RegionalFluxAttnProcessor2_0
from pipeline_flux_controlnet_regional import RegionalFluxControlNetPipeline
from diffusers import FluxControlNetModel, FluxMultiControlNetModel
def build_regional_masks(regional_prompt_mask_pairs, background_prompt, image_width, image_height):
    """Turn box-style regional prompts into parallel prompt / mask lists.

    Args:
        regional_prompt_mask_pairs: mapping whose values are dicts with a
            ``"description"`` prompt and a ``"mask"`` box ``[x1, y1, x2, y2]`` in
            pixel coordinates, applied as the half-open slice ``mask[y1:y2, x1:x2]``.
        background_prompt: prompt assigned to pixels not covered by any region.
        image_width: mask width in pixels.
        image_height: mask height in pixels.

    Returns:
        ``(regional_prompts, regional_masks)``: one prompt and one float mask of
        shape ``(image_height, image_width)`` (1.0 inside the box, 0.0 elsewhere)
        per region, in mapping order. If any pixel is covered by no region,
        ``background_prompt`` and a mask of exactly those uncovered pixels are
        appended last.
    """
    regional_prompts = []
    regional_masks = []
    background_mask = torch.ones((image_height, image_width))
    for region in regional_prompt_mask_pairs.values():
        x1, y1, x2, y2 = region['mask']
        mask = torch.zeros((image_height, image_width))
        mask[y1:y2, x1:x2] = 1.0
        # Clamp at 0: with overlapping regions a plain subtraction drives pixels
        # negative, which both corrupts the background mask and can cancel out
        # genuinely uncovered pixels in the sum() check below.
        background_mask = torch.clamp(background_mask - mask, min=0.0)
        regional_prompts.append(region['description'])
        regional_masks.append(mask)
    # if regional masks don't cover the whole image, append background prompt and mask
    if background_mask.sum() > 0:
        regional_prompts.append(background_prompt)
        regional_masks.append(background_mask)
    return regional_prompts, regional_masks


if __name__ == "__main__":
    model_path = "black-forest-labs/FLUX.1-dev"
    use_lora = False
    use_controlnet = False

    if use_controlnet:  # takes up more gpu memory
        # READ https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro for detailed usage tutorial
        controlnet_model_union = 'Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro'
        controlnet_union = FluxControlNetModel.from_pretrained(controlnet_model_union, torch_dtype=torch.bfloat16)
        controlnet = FluxMultiControlNetModel([controlnet_union])
        pipeline = RegionalFluxControlNetPipeline.from_pretrained(model_path, controlnet=controlnet, torch_dtype=torch.bfloat16).to("cuda")
    else:
        pipeline = RegionalFluxPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")

    if use_lora:
        # READ https://huggingface.co/Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch for detailed usage tutorial
        pipeline.load_lora_weights("Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch", weight_name="FLUX-dev-lora-children-simple-sketch.safetensors")

    # Install the regional attention processor on every "...transformer_blocks...attn.processor"
    # entry (this substring matches both double- and single-stream block names); keep the
    # existing processor everywhere else. Read `attn_processors` once: in diffusers it is a
    # property that re-walks the model on each access.
    default_procs = pipeline.transformer.attn_processors
    attn_procs = {
        name: (
            RegionalFluxAttnProcessor2_0()
            if 'transformer_blocks' in name and name.endswith("attn.processor")
            else processor
        )
        for name, processor in default_procs.items()
    }
    pipeline.transformer.set_attn_processor(attn_procs)

    ## generation settings
    # example regional prompt and mask pairs; each "mask" is a pixel box [x1, y1, x2, y2]
    image_width = 1280
    image_height = 768
    num_samples = 1
    num_inference_steps = 24
    guidance_scale = 3.5
    seed = 124
    base_prompt = "An ancient woman stands solemnly holding a blazing torch, while a fierce battle rages in the background, capturing both strength and tragedy in a historical war scene."
    background_prompt = "a photo"
    regional_prompt_mask_pairs = {
        "0": {
            "description": "A dignified woman in ancient robes stands in the foreground, her face illuminated by the torch she holds high. Her expression is one of determination and sorrow, her clothing and appearance reflecting the historical period. The torch casts dramatic shadows across her features, its flames dancing vibrantly against the darkness.",
            "mask": [128, 128, 640, 768]
        }
    }
    # region control settings
    mask_inject_steps = 10
    double_inject_blocks_interval = 1
    single_inject_blocks_interval = 1
    base_ratio = 0.3

    # example input with controlnet enabled (set use_controlnet = True above and uncomment;
    # the generation call below needs control_image / control_mode / controlnet_conditioning_scale)
    # image_width = 1280
    # image_height = 968
    # num_samples = 1
    # num_inference_steps = 24
    # guidance_scale = 3.5
    # seed = 124
    # base_prompt = "Three high-performance sports cars, red, blue, and yellow, are racing side by side on a city street"
    # background_prompt = "city street" # needed if regional masks don't cover the whole image
    # regional_prompt_mask_pairs = {
    #     "0": {
    #         "description": "A sleek red sports car in the lead position, with aggressive aerodynamic styling and gleaming paint that catches the light. The car appears to be moving at high speed with motion blur effects.",
    #         "mask": [0, 0, 426, 968]
    #     },
    #     "1": {
    #         "description": "A powerful blue sports car in the middle position, neck-and-neck with its competitors. Its metallic paint shimmers as it races forward, with visible speed lines and dynamic movement.",
    #         "mask": [426, 0, 853, 968]
    #     },
    #     "2": {
    #         "description": "A striking yellow sports car in the third position, its bold color standing out against the street. The car's aggressive stance and aerodynamic profile emphasize its racing performance.",
    #         "mask": [853, 0, 1280, 968]
    #     }
    # }
    # ## controlnet settings
    # if use_controlnet:
    #     control_image = [Image.open("./assets/condition_depth.png")]  # NOTE: needs `from PIL import Image`, not imported in this file
    #     control_mode = [2] # (2) stands for depth control
    #     controlnet_conditioning_scale = [0.7]
    # ## region control settings
    # mask_inject_steps = 10
    # double_inject_blocks_interval = 1 # 1 for full blocks
    # single_inject_blocks_interval = 2 # 1 for full blocks
    # base_ratio = 0.2

    # example input with lora enabled (set use_lora = True above and uncomment)
    # image_width = 1280
    # image_height = 1280
    # num_samples = 1
    # num_inference_steps = 24
    # guidance_scale = 3.5
    # seed = 124
    # base_prompt = "Sketched style: A cute dinosaur playfully blowing tiny fire puffs over a cartoon city in a cheerful scene."
    # background_prompt = "white background"
    # regional_prompt_mask_pairs = {
    #     "0": {
    #         "description": "Sketched style: dinosaur with round eyes and a mischievous smile, puffing small flames over the city.",
    #         "mask": [0, 0, 640, 1280]
    #     },
    #     "1": {
    #         "description": "Sketched style: city with colorful buildings and tiny flames gently floating above, adding a playful touch.",
    #         "mask": [640, 0, 1280, 1280]
    #     }
    # }
    # ## lora settings
    # if use_lora:
    #     pipeline.fuse_lora(lora_scale=1.5)
    # ## region control settings
    # mask_inject_steps = 10
    # double_inject_blocks_interval = 1 # 18 for full blocks
    # single_inject_blocks_interval = 1 # 39 for full blocks
    # NOTE(review): the "18/39 for full blocks" notes contradict the "1 for full blocks"
    # notes in the controlnet example — confirm the interval semantics in the pipeline.
    # base_ratio = 0.1

    ## prepare regional prompts and masks
    # ensure image width and height are divisible by the vae scale factor
    image_width = (image_width // pipeline.vae_scale_factor) * pipeline.vae_scale_factor
    image_height = (image_height // pipeline.vae_scale_factor) * pipeline.vae_scale_factor
    regional_prompts, regional_masks = build_regional_masks(
        regional_prompt_mask_pairs, background_prompt, image_width, image_height
    )

    # setup regional kwargs that pass to the pipeline
    joint_attention_kwargs = {
        'regional_prompts': regional_prompts,
        'regional_masks': regional_masks,
        'double_inject_blocks_interval': double_inject_blocks_interval,
        'single_inject_blocks_interval': single_inject_blocks_interval,
        'base_ratio': base_ratio,
    }

    # generate images; the controlnet pipeline takes the same arguments plus the control inputs
    pipeline_kwargs = dict(
        prompt=base_prompt,
        num_samples=num_samples,
        width=image_width, height=image_height,
        mask_inject_steps=mask_inject_steps,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator("cuda").manual_seed(seed),
        joint_attention_kwargs=joint_attention_kwargs,
    )
    if use_controlnet:
        # these names are only defined by the (commented-out) controlnet example above
        pipeline_kwargs.update(
            control_image=control_image,
            control_mode=control_mode,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
        )
    images = pipeline(**pipeline_kwargs).images

    for idx, image in enumerate(images):
        image.save(f"output_{idx}.jpg")