# FontDiffuser-Gradio: batch_sample.py
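"""Batch sampling for FontDiffuser.

Loads content/style images (or renders characters from a TTF) in parallel,
generates glyphs in batches with the DPM-Solver pipeline, and falls back to
smaller sub-batches on GPU out-of-memory errors.
"""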
import os
import time
from PIL import Image
from typing import List, Tuple, Optional, Union
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import torch
import torchvision.transforms as transforms
from accelerate.utils import set_seed
from src import (
FontDiffuserDPMPipeline,
FontDiffuserModelDPM,
build_ddpm_scheduler,
build_unet,
build_content_encoder,
build_style_encoder,
)
from utils import (
ttf2im,
load_ttf,
is_char_in_font,
save_args_to_yaml,
save_single_image,
save_image_with_content_style,
)
class BatchProcessor:
"""Handles batch processing logic for FontDiffuser"""
def __init__(self, args):
self.args = args
self.device = args.device
self.max_batch_size = getattr(args, "max_batch_size", 8)
self.num_workers = getattr(args, "num_workers", 4)
def batch_image_process(
self,
content_inputs: List[Union[str, Image.Image]],
style_inputs: List[Union[str, Image.Image]],
content_characters: Optional[List[str]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, List[Optional[Image.Image]]]:
"""
Process multiple images in batch
Args:
content_inputs: List of content image paths or PIL Images
style_inputs: List of style image paths or PIL Images
content_characters: List of characters if using character input mode
Returns:
Tuple of (content_tensors, style_tensors, content_pil_images)
"""
batch_size = len(content_inputs)
assert len(style_inputs) == batch_size, (
"Content and style inputs must have same length"
)
if content_characters:
assert len(content_characters) == batch_size, (
"Content characters must match batch size"
)
# Transform setup
content_inference_transforms = transforms.Compose(
[
transforms.Resize(
self.args.content_image_size,
interpolation=transforms.InterpolationMode.BILINEAR,
),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
style_inference_transforms = transforms.Compose(
[
transforms.Resize(
self.args.style_image_size,
interpolation=transforms.InterpolationMode.BILINEAR,
),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
# Initialize ordered lists for results
content_tensors = [None] * batch_size
style_tensors = [None] * batch_size
content_pil_images = [None] * batch_size
# Process in parallel using ThreadPoolExecutor for I/O operations
with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
# Submit content processing tasks
content_futures = []
for i, content_input in enumerate(content_inputs):
if content_characters and i < len(content_characters):
future = executor.submit(
self._process_content_character,
content_characters[i],
content_inference_transforms,
)
else:
future = executor.submit(
self._process_content_image,
content_input,
content_inference_transforms,
)
content_futures.append((i, future))
# Submit style processing tasks
style_futures = []
for i, style_input in enumerate(style_inputs):
future = executor.submit(
self._process_style_image, style_input, style_inference_transforms
)
style_futures.append((i, future))
# Collect results in order
for i, future in content_futures:
try:
content_tensor, content_pil = future.result()
if content_tensor is not None:
content_tensors[i] = content_tensor
content_pil_images[i] = content_pil
except Exception as e:
print(f"Error processing content at index {i}: {e}")
continue
for i, future in style_futures:
try:
style_tensor = future.result()
if style_tensor is not None:
style_tensors[i] = style_tensor
except Exception as e:
print(f"Error processing style at index {i}: {e}")
continue
        # Keep only the indices where BOTH content and style succeeded, so the
        # stacked batches stay aligned with each other (dropping one side alone
        # would pair content images with the wrong styles).
        valid_indices = [
            i
            for i in range(batch_size)
            if content_tensors[i] is not None and style_tensors[i] is not None
        ]
        if not valid_indices:
            return None, None, []
        content_batch = torch.stack([content_tensors[i] for i in valid_indices])
        style_batch = torch.stack([style_tensors[i] for i in valid_indices])
        content_pil_images = [content_pil_images[i] for i in valid_indices]
        return content_batch, style_batch, content_pil_images
def _process_content_character(
self, character: str, transform
) -> Tuple[Optional[torch.Tensor], Optional[Image.Image]]:
"""Process content character into tensor"""
if not is_char_in_font(font_path=self.args.ttf_path, char=character):
print(f"Character '{character}' not found in font")
return None, None
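        # Note: the TTF file is re-parsed on every call; for large batches,
        # caching the font object would avoid repeated disk I/O.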
font = load_ttf(ttf_path=self.args.ttf_path)
content_image = ttf2im(font=font, char=character)
content_image_pil = content_image.copy()
content_tensor = transform(content_image)
return content_tensor, content_image_pil
def _process_content_image(
self, image_input: Union[str, Image.Image], transform
) -> Tuple[Optional[torch.Tensor], None]:
"""Process content image into tensor"""
try:
if isinstance(image_input, str):
content_image = Image.open(image_input).convert("RGB")
else:
content_image = image_input.convert("RGB")
content_tensor = transform(content_image)
return content_tensor, None
except Exception as e:
print(f"Error processing content image: {e}")
return None, None
def _process_style_image(
self, image_input: Union[str, Image.Image], transform
) -> Optional[torch.Tensor]:
"""Process style image into tensor"""
try:
if isinstance(image_input, str):
style_image = Image.open(image_input).convert("RGB")
else:
style_image = image_input.convert("RGB")
style_tensor = transform(style_image)
return style_tensor
except Exception as e:
print(f"Error processing style image: {e}")
return None
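
# Illustrative, standalone use of BatchProcessor (file names below are
# placeholders, not shipped assets):
#
#     processor = BatchProcessor(args)
#     content_batch, style_batch, content_pils = processor.batch_image_process(
#         ["content_a.png", "content_b.png"],
#         ["style_a.png", "style_b.png"],
#     )
#     # content_batch / style_batch are stacked tensors ready for pipe.generate().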
def arg_parse():
from configs.fontdiffuser import get_parser
parser = get_parser()
parser.add_argument("--ckpt_dir", type=str, default=None)
parser.add_argument("--demo", action="store_true")
    # argparse's type=bool treats any non-empty string (even "False") as True,
    # so a plain store_true flag is used instead.
    parser.add_argument(
        "--controlnet",
        action="store_true",
        help="If in demo mode, the controlnet can be added.",
    )
parser.add_argument("--character_input", action="store_true")
parser.add_argument("--content_character", type=str, default=None)
parser.add_argument("--content_image_path", type=str, default=None)
parser.add_argument("--style_image_path", type=str, default=None)
parser.add_argument("--save_image", action="store_true")
parser.add_argument(
"--save_image_dir", type=str, default=None, help="The saving directory."
)
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--ttf_path", type=str, default="ttf/KaiXinSongA.ttf")
# Batch processing arguments
parser.add_argument(
"--batch_size",
type=int,
default=4,
help="Batch size for processing multiple images",
)
parser.add_argument(
"--max_batch_size",
type=int,
default=8,
help="Maximum batch size based on GPU memory",
)
parser.add_argument(
"--num_workers",
type=int,
default=4,
help="Number of workers for parallel image loading",
)
parser.add_argument(
"--batch_content_paths",
type=str,
nargs="+",
default=None,
help="List of content image paths for batch processing",
)
parser.add_argument(
"--batch_style_paths",
type=str,
nargs="+",
default=None,
help="List of style image paths for batch processing",
)
parser.add_argument(
"--batch_characters",
type=str,
nargs="+",
default=None,
help="List of characters for batch processing",
)
parser.add_argument(
"--adaptive_batch_size",
action="store_true",
help="Automatically adjust batch size based on GPU memory",
)
args = parser.parse_args()
style_image_size = args.style_image_size
content_image_size = args.content_image_size
args.style_image_size = (style_image_size, style_image_size)
args.content_image_size = (content_image_size, content_image_size)
return args
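
# Example invocations (illustrative; checkpoint and image paths are placeholders):
#
#   python batch_sample.py --ckpt_dir ckpt --character_input \
#       --batch_characters 永 和 九 --style_image_path style.png \
#       --save_image --save_image_dir outputs
#
#   python batch_sample.py --ckpt_dir ckpt \
#       --batch_content_paths c1.png c2.png --batch_style_paths s1.png s2.png \
#       --save_image --save_image_dir outputs --adaptive_batch_size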
def get_optimal_batch_size(args) -> int:
"""Determine optimal batch size based on GPU memory"""
    # Fall back to batch size 1 when CUDA is unavailable or a non-CUDA device
    # was requested (get_device_properties would fail on "cpu").
    if not torch.cuda.is_available() or "cuda" not in str(args.device):
        return 1
# Get GPU memory info
gpu_memory = torch.cuda.get_device_properties(args.device).total_memory / (
1024**3
) # GB
# Estimate batch size based on GPU memory (rough heuristic)
if gpu_memory >= 24: # RTX 4090, A100, etc.
optimal_batch = min(16, args.max_batch_size)
elif gpu_memory >= 12: # RTX 3080 Ti, RTX 4070 Ti, etc.
optimal_batch = min(8, args.max_batch_size)
elif gpu_memory >= 8: # RTX 3070, RTX 4060 Ti, etc.
optimal_batch = min(4, args.max_batch_size)
else: # Lower end GPUs
optimal_batch = min(2, args.max_batch_size)
return optimal_batch
def load_fontdiffuser_pipeline(args):
    """Load the FontDiffuser model weights, scheduler, and DPM-Solver pipeline"""
    # Load the state_dicts on CPU first; the assembled model is moved to
    # args.device below.
    unet = build_unet(args=args)
    unet.load_state_dict(torch.load(f"{args.ckpt_dir}/unet.pth", map_location="cpu"))
    style_encoder = build_style_encoder(args=args)
    style_encoder.load_state_dict(
        torch.load(f"{args.ckpt_dir}/style_encoder.pth", map_location="cpu")
    )
    content_encoder = build_content_encoder(args=args)
    content_encoder.load_state_dict(
        torch.load(f"{args.ckpt_dir}/content_encoder.pth", map_location="cpu")
    )
model = FontDiffuserModelDPM(
unet=unet, style_encoder=style_encoder, content_encoder=content_encoder
)
model.to(args.device)
print("Loaded the model state_dict successfully!")
    # Load the training DDPM scheduler.
    train_scheduler = build_ddpm_scheduler(args=args)
    print("Loaded training DDPM scheduler successfully!")
    # Build the DPM-Solver pipeline used for sampling.
    pipe = FontDiffuserDPMPipeline(
        model=model,
        ddpm_train_scheduler=train_scheduler,
        model_type=args.model_type,
        guidance_type=args.guidance_type,
        guidance_scale=args.guidance_scale,
    )
    print("Loaded DPM-Solver pipeline successfully!")
return pipe
def batch_sampling(
args,
pipe,
content_inputs: List[Union[str, Image.Image]],
style_inputs: List[Union[str, Image.Image]],
content_characters: Optional[List[str]] = None,
) -> List[Image.Image]:
"""
Perform batch sampling with FontDiffuser
Args:
args: Arguments
pipe: FontDiffuser pipeline
content_inputs: List of content images/paths
style_inputs: List of style images/paths
content_characters: List of characters (if using character input)
Returns:
List of generated images
"""
if not args.demo:
os.makedirs(args.save_image_dir, exist_ok=True)
save_args_to_yaml(
args=args, output_file=f"{args.save_image_dir}/sampling_config.yaml"
)
    if args.seed is not None:
        set_seed(seed=args.seed)
# Determine optimal batch size
if args.adaptive_batch_size:
optimal_batch_size = get_optimal_batch_size(args)
print(f"Using adaptive batch size: {optimal_batch_size}")
else:
optimal_batch_size = args.batch_size
    batch_processor = BatchProcessor(args)
    # In character mode the callers pass an empty content list; pad it with
    # placeholders so batch counting and indexing line up with the characters
    # (otherwise total_samples would be 0 and nothing would be generated).
    if content_characters and not content_inputs:
        content_inputs = [None] * len(content_characters)
    total_samples = len(content_inputs)
all_generated_images = []
print(f"Processing {total_samples} samples in batches of {optimal_batch_size}")
# Process in batches
for batch_start in range(0, total_samples, optimal_batch_size):
batch_end = min(batch_start + optimal_batch_size, total_samples)
batch_content = content_inputs[batch_start:batch_end]
batch_style = style_inputs[batch_start:batch_end]
batch_chars = (
content_characters[batch_start:batch_end] if content_characters else None
)
        num_batches = (total_samples + optimal_batch_size - 1) // optimal_batch_size
        print(
            f"Processing batch {batch_start // optimal_batch_size + 1}/{num_batches}"
        )
# Process batch
content_batch, style_batch, content_pil_images = (
batch_processor.batch_image_process(batch_content, batch_style, batch_chars)
)
if content_batch is None or style_batch is None:
print("Skipping batch due to processing errors")
continue
current_batch_size = content_batch.shape[0]
with torch.no_grad():
content_batch = content_batch.to(args.device)
style_batch = style_batch.to(args.device)
print(f"Generating {current_batch_size} images with DPM-Solver++...")
start_time = time.time()
try:
# Generate batch
images = pipe.generate(
content_images=content_batch,
style_images=style_batch,
batch_size=current_batch_size,
order=args.order,
num_inference_step=args.num_inference_steps,
content_encoder_downsample_size=args.content_encoder_downsample_size,
t_start=args.t_start,
t_end=args.t_end,
dm_size=args.content_image_size,
algorithm_type=args.algorithm_type,
skip_type=args.skip_type,
method=args.method,
correcting_x0_fn=args.correcting_x0_fn,
)
end_time = time.time()
print(f"Batch generation completed in {end_time - start_time:.2f}s")
# Save images if requested
if args.save_image:
save_batch_images(
args,
images,
content_pil_images,
batch_content,
batch_style,
batch_start,
)
all_generated_images.extend(images)
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print(
                        f"GPU out of memory with batch size {current_batch_size}, trying smaller batch..."
                    )
                    torch.cuda.empty_cache()
                    # Retry in halved sub-batches. Note that a second OOM here
                    # would still abort the whole batch.
                    smaller_batch_size = max(1, current_batch_size // 2)
                    for sub_batch_start in range(
                        0, current_batch_size, smaller_batch_size
                    ):
                        sub_batch_end = min(
                            sub_batch_start + smaller_batch_size, current_batch_size
                        )
                        sub_content = content_batch[sub_batch_start:sub_batch_end]
                        sub_style = style_batch[sub_batch_start:sub_batch_end]
                        sub_images = pipe.generate(
                            content_images=sub_content,
                            style_images=sub_style,
                            batch_size=sub_batch_end - sub_batch_start,
                            order=args.order,
                            num_inference_step=args.num_inference_steps,
                            content_encoder_downsample_size=args.content_encoder_downsample_size,
                            t_start=args.t_start,
                            t_end=args.t_end,
                            dm_size=args.content_image_size,
                            algorithm_type=args.algorithm_type,
                            skip_type=args.skip_type,
                            method=args.method,
                            correcting_x0_fn=args.correcting_x0_fn,
                        )
                        # Keep the save behavior consistent with the main path.
                        if args.save_image:
                            save_batch_images(
                                args,
                                sub_images,
                                content_pil_images[sub_batch_start:sub_batch_end],
                                batch_content[sub_batch_start:sub_batch_end],
                                batch_style[sub_batch_start:sub_batch_end],
                                batch_start + sub_batch_start,
                            )
                        all_generated_images.extend(sub_images)
                else:
                    print(f"Error during generation: {e}")
                    continue
# Clear GPU cache between batches
torch.cuda.empty_cache()
print(f"Batch processing completed! Generated {len(all_generated_images)} images.")
return all_generated_images
def save_batch_images(
args, images, content_pil_images, batch_content, batch_style, batch_offset
):
"""Save batch of generated images"""
for i, image in enumerate(images):
# Create unique filename for each image
image_idx = batch_offset + i
save_single_image(
save_dir=args.save_image_dir, image=image, suffix=f"_{image_idx:04d}"
)
# Save with content and style context if available
if args.character_input and i < len(content_pil_images):
save_image_with_content_style(
save_dir=args.save_image_dir,
image=image,
content_image_pil=content_pil_images[i],
content_image_path=None,
style_image_path=batch_style[i]
if isinstance(batch_style[i], str)
else None,
resolution=args.resolution,
suffix=f"_{image_idx:04d}",
)
elif not args.character_input:
save_image_with_content_style(
save_dir=args.save_image_dir,
image=image,
content_image_pil=None,
content_image_path=batch_content[i]
if isinstance(batch_content[i], str)
else None,
style_image_path=batch_style[i]
if isinstance(batch_style[i], str)
else None,
resolution=args.resolution,
suffix=f"_{image_idx:04d}",
)
def sampling(args, pipe, content_image=None, style_image=None):
"""Original single image sampling function (for backward compatibility)"""
if not args.demo:
os.makedirs(args.save_image_dir, exist_ok=True)
save_args_to_yaml(
args=args, output_file=f"{args.save_image_dir}/sampling_config.yaml"
)
if args.seed:
set_seed(seed=args.seed)
# Use single image processing
    if args.character_input:
        # arg_parse() always defines content_character (default None), so
        # hasattr() was always true; check the value instead and fall back
        # to "A" when it is unset.
        characters = [getattr(args, "content_character", None) or "A"]
        style_inputs = [style_image or args.style_image_path]
        result = batch_sampling(args, pipe, [], style_inputs, characters)
else:
content_inputs = [content_image or args.content_image_path]
style_inputs = [style_image or args.style_image_path]
result = batch_sampling(args, pipe, content_inputs, style_inputs)
return result[0] if result else None
# Additional utility functions for batch processing
def load_images_from_directory(
    directory_path: str,
    extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png", ".bmp"),
) -> List[str]:
    """Collect all image paths in a directory (non-recursive)"""
    directory = Path(directory_path)
    # Use a set so case-insensitive filesystems don't yield duplicates when
    # matching both lower- and upper-case extensions; a tuple default also
    # avoids the mutable-default-argument pitfall.
    image_paths = set()
    for ext in extensions:
        image_paths.update(directory.glob(f"*{ext}"))
        image_paths.update(directory.glob(f"*{ext.upper()}"))
    return sorted(str(path) for path in image_paths)
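
# Illustrative pairing of a content directory with one repeated style image
# (paths are placeholders):
#
#     contents = load_images_from_directory("examples/content")
#     styles = ["examples/style/style.png"] * len(contents)
#     images = batch_sampling(args, pipe, contents, styles)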
def create_batch_from_config(
config_file: str,
) -> Tuple[List[str], List[str], List[str]]:
"""Create batch inputs from configuration file"""
import json
with open(config_file, "r") as f:
config = json.load(f)
content_inputs = config.get("content_images", [])
style_inputs = config.get("style_images", [])
characters = config.get("characters", [])
return content_inputs, style_inputs, characters
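
# Example JSON config for create_batch_from_config (illustrative):
#
#     {
#         "content_images": ["content/a.png", "content/b.png"],
#         "style_images": ["style/s1.png", "style/s2.png"],
#         "characters": []
#     }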
if __name__ == "__main__":
args = arg_parse()
    # Load the FontDiffuser pipeline
    pipe = load_fontdiffuser_pipeline(args=args)
# Check if batch processing is requested
if args.batch_content_paths or args.batch_style_paths or args.batch_characters:
# Batch processing mode
content_inputs = args.batch_content_paths or []
style_inputs = args.batch_style_paths or []
characters = args.batch_characters or None
if characters and args.character_input:
# Character-based batch processing
style_inputs = style_inputs or [args.style_image_path] * len(characters)
generated_images = batch_sampling(args, pipe, [], style_inputs, characters)
else:
# Image-based batch processing
if len(content_inputs) != len(style_inputs):
print("Error: Number of content and style images must match")
exit(1)
generated_images = batch_sampling(args, pipe, content_inputs, style_inputs)
print(f"Batch processing completed! Generated {len(generated_images)} images.")
else:
# Single image processing (original behavior)
out_image = sampling(args=args, pipe=pipe)