|
import torch |
|
from pipeline_flux_regional_pulid import RegionalFluxPipeline_PULID, RegionalFluxAttnProcessor2_0 |
|
|
|
def _install_regional_attn_processors(pipeline):
    """Install RegionalFluxAttnProcessor2_0 on every transformer-block
    attention layer of *pipeline*; all other processors are kept as-is."""
    attn_procs = {}
    for name, processor in pipeline.transformer.attn_processors.items():
        # Only the transformer-block attention layers participate in regional
        # masking; auxiliary processors keep their defaults.
        if "transformer_blocks" in name and name.endswith("attn.processor"):
            attn_procs[name] = RegionalFluxAttnProcessor2_0()
        else:
            attn_procs[name] = processor
    pipeline.transformer.set_attn_processor(attn_procs)


def build_regional_masks(image_width, image_height, regional_prompt_mask_pairs, background_prompt):
    """Build per-region prompts and binary masks plus a background catch-all.

    Args:
        image_width: final (already VAE-aligned) image width in pixels.
        image_height: final (already VAE-aligned) image height in pixels.
        regional_prompt_mask_pairs: mapping of region id -> {"description": str,
            "mask": [x1, y1, x2, y2]} with pixel coordinates on the canvas.
        background_prompt: prompt applied to every pixel no region covers.

    Returns:
        (regional_prompts, regional_masks): parallel lists, one entry per
        region in insertion order; a background entry is appended last only
        when at least one pixel is left uncovered.
    """
    regional_prompts = []
    regional_masks = []
    background_mask = torch.ones((image_height, image_width))

    for region in regional_prompt_mask_pairs.values():
        x1, y1, x2, y2 = region["mask"]

        mask = torch.zeros((image_height, image_width))
        mask[y1:y2, x1:x2] = 1.0

        # Clamp so overlapping regions cannot drive the background below zero,
        # which would corrupt both the coverage test below and the attention
        # weighting downstream (plain subtraction went negative on overlap).
        background_mask = torch.clamp(background_mask - mask, min=0.0)

        regional_prompts.append(region["description"])
        regional_masks.append(mask)

    if background_mask.sum() > 0:
        regional_prompts.append(background_prompt)
        regional_masks.append(background_mask)

    return regional_prompts, regional_masks


if __name__ == "__main__":
    model_path = "black-forest-labs/FLUX.1-dev"

    pipeline = RegionalFluxPipeline_PULID.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")
    _install_regional_attn_processors(pipeline)

    # Load the PuLID identity-adapter weights on top of the base transformer.
    pipeline.load_pulid_models()
    pipeline.load_pretrain()

    # Generation settings.
    image_width = 1280
    image_height = 1280
    num_samples = 1
    num_inference_steps = 24
    guidance_scale = 3.5
    seed = 124

    # Regional-control settings.
    mask_inject_steps = 10
    double_inject_blocks_interval = 1
    single_inject_blocks_interval = 1
    base_ratio = 0.2

    base_prompt = "In a classroom during the afternoon, a man is practicing guitar by himself, with sunlight beautifully illuminating the room"
    background_prompt = "empty classroom"
    regional_prompt_mask_pairs = {
        "0": {
            "description": "A man in a blue shirt and jeans, playing guitar",
            "mask": [64, 320, 448, 1280],  # [x1, y1, x2, y2] in pixels
        }
    }

    # PuLID identity reference images, one per region, in region order.
    id_image_paths = ["./assets/lecun.jpeg"]
    id_weights = [1.0]

    # Snap the canvas to the VAE downsampling grid. NOTE(review): the mask
    # coordinates above assume the original 1280x1280 canvas; out-of-range
    # slice bounds are silently truncated by torch indexing.
    image_width = (image_width // pipeline.vae_scale_factor) * pipeline.vae_scale_factor
    image_height = (image_height // pipeline.vae_scale_factor) * pipeline.vae_scale_factor

    regional_prompts, regional_masks = build_regional_masks(
        image_width, image_height, regional_prompt_mask_pairs, background_prompt
    )

    joint_attention_kwargs = {
        'regional_prompts': regional_prompts,
        'regional_masks': regional_masks,
        'double_inject_blocks_interval': double_inject_blocks_interval,
        'single_inject_blocks_interval': single_inject_blocks_interval,
        'base_ratio': base_ratio,
        'id_image_paths': id_image_paths,
        'id_weights': id_weights,
        # One identity mask per reference image, taken in region order.
        'id_masks': regional_masks[:len(id_image_paths)],
    }

    images = pipeline(
        prompt=base_prompt,
        num_samples=num_samples,
        width=image_width,
        height=image_height,
        mask_inject_steps=mask_inject_steps,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=torch.Generator("cuda").manual_seed(seed),
        joint_attention_kwargs=joint_attention_kwargs,
    ).images

    for idx, image in enumerate(images):
        image.save(f"output_{idx}.jpg")