import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
import torchvision.transforms as transforms
from accelerate.utils import set_seed
from PIL import Image

from src import (
    FontDiffuserDPMPipeline,
    FontDiffuserModelDPM,
    build_ddpm_scheduler,
    build_unet,
    build_content_encoder,
    build_style_encoder,
)
from utils import (
    ttf2im,
    load_ttf,
    is_char_in_font,
    save_args_to_yaml,
    save_single_image,
    save_image_with_content_style,
)


class BatchProcessor:
    """Handles batch processing logic for FontDiffuser."""

    def __init__(self, args):
        self.args = args
        self.device = args.device
        self.max_batch_size = getattr(args, "max_batch_size", 8)
        self.num_workers = getattr(args, "num_workers", 4)
        # The TTF font is loaded lazily and cached so character-mode batches
        # do not reload it for every glyph.
        self._font = None

    def batch_image_process(
        self,
        content_inputs: List[Union[str, Image.Image]],
        style_inputs: List[Union[str, Image.Image]],
        content_characters: Optional[List[str]] = None,
    ) -> Tuple[
        Optional[torch.Tensor], Optional[torch.Tensor], List[Optional[Image.Image]]
    ]:
        """
        Process multiple images in batch.

        Args:
            content_inputs: List of content image paths or PIL Images. May be
                empty when `content_characters` is given.
            style_inputs: List of style image paths or PIL Images.
            content_characters: List of characters if using character input mode.

        Returns:
            Tuple of (content_batch, style_batch, content_pil_images), or
            (None, None, []) if nothing could be processed.
        """
        # In character mode the characters define the batch size; the content
        # image list may legitimately be empty.
        if content_characters:
            batch_size = len(content_characters)
        else:
            batch_size = len(content_inputs)
        assert len(style_inputs) == batch_size, (
            "Style inputs must match the number of content inputs/characters"
        )

        # Transform setup
        content_inference_transforms = transforms.Compose(
            [
                transforms.Resize(
                    self.args.content_image_size,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        style_inference_transforms = transforms.Compose(
            [
                transforms.Resize(
                    self.args.style_image_size,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        # Initialize ordered lists for results
        content_tensors = [None] * batch_size
        style_tensors = [None] * batch_size
        content_pil_images = [None] * batch_size

        # Process in parallel using a ThreadPoolExecutor for the I/O-bound work
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            # Submit content processing tasks
            content_futures = []
            for i in range(batch_size):
                if content_characters:
                    future = executor.submit(
                        self._process_content_character,
                        content_characters[i],
                        content_inference_transforms,
                    )
                else:
                    future = executor.submit(
                        self._process_content_image,
                        content_inputs[i],
                        content_inference_transforms,
                    )
                content_futures.append((i, future))

            # Submit style processing tasks
            style_futures = []
            for i, style_input in enumerate(style_inputs):
                future = executor.submit(
                    self._process_style_image, style_input, style_inference_transforms
                )
                style_futures.append((i, future))

            # Collect results in order
            for i, future in content_futures:
                try:
                    content_tensor, content_pil = future.result()
                    if content_tensor is not None:
                        content_tensors[i] = content_tensor
                        content_pil_images[i] = content_pil
                except Exception as e:
                    print(f"Error processing content at index {i}: {e}")

            for i, future in style_futures:
                try:
                    style_tensor = future.result()
                    if style_tensor is not None:
                        style_tensors[i] = style_tensor
                except Exception as e:
                    print(f"Error processing style at index {i}: {e}")

        # Keep only indices where both the content and the style succeeded, so
        # the stacked content and style batches stay pairwise aligned.
        valid_indices = [
            i
            for i in range(batch_size)
            if content_tensors[i] is not None and style_tensors[i] is not None
        ]
        if not valid_indices:
            return None, None, []

        content_batch = torch.stack([content_tensors[i] for i in valid_indices])
        style_batch = torch.stack([style_tensors[i] for i in valid_indices])
        content_pil_images = [content_pil_images[i] for i in valid_indices]
        return content_batch, style_batch, content_pil_images

    def _process_content_character(
        self, character: str, transform
    ) -> Tuple[Optional[torch.Tensor], Optional[Image.Image]]:
        """Render a content character with the configured TTF font."""
        if not is_char_in_font(font_path=self.args.ttf_path, char=character):
            print(f"Character '{character}' not found in font")
            return None, None
        if self._font is None:
            self._font = load_ttf(ttf_path=self.args.ttf_path)
        content_image = ttf2im(font=self._font, char=character)
        content_image_pil = content_image.copy()
        content_tensor = transform(content_image)
        return content_tensor, content_image_pil

    def _process_content_image(
        self, image_input: Union[str, Image.Image], transform
    ) -> Tuple[Optional[torch.Tensor], None]:
        """Process a content image into a tensor."""
        try:
            if isinstance(image_input, str):
                content_image = Image.open(image_input).convert("RGB")
            else:
                content_image = image_input.convert("RGB")
            content_tensor = transform(content_image)
            return content_tensor, None
        except Exception as e:
            print(f"Error processing content image: {e}")
            return None, None

    def _process_style_image(
        self, image_input: Union[str, Image.Image], transform
    ) -> Optional[torch.Tensor]:
        """Process a style image into a tensor."""
        try:
            if isinstance(image_input, str):
                style_image = Image.open(image_input).convert("RGB")
            else:
                style_image = image_input.convert("RGB")
            style_tensor = transform(style_image)
            return style_tensor
        except Exception as e:
            print(f"Error processing style image: {e}")
            return None
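# A minimal usage sketch (hypothetical paths, assuming `args` comes from
# arg_parse() below): the processor returns stacked tensors that can be fed
# straight into `pipe.generate()`.
#
#     processor = BatchProcessor(args)
#     content_batch, style_batch, content_pils = processor.batch_image_process(
#         content_inputs=["content/a.png", "content/b.png"],
#         style_inputs=["style/s1.png", "style/s2.png"],
#     )
#     # content_batch and style_batch are (N, C, H, W) tensors on the CPU.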
def arg_parse():
    from configs.fontdiffuser import get_parser

    parser = get_parser()
    parser.add_argument("--ckpt_dir", type=str, default=None)
    parser.add_argument("--demo", action="store_true")
    parser.add_argument(
        "--controlnet",
        type=bool,
        default=False,
        help="If in demo mode, the controlnet can be added.",
    )
    parser.add_argument("--character_input", action="store_true")
    parser.add_argument("--content_character", type=str, default=None)
    parser.add_argument("--content_image_path", type=str, default=None)
    parser.add_argument("--style_image_path", type=str, default=None)
    parser.add_argument("--save_image", action="store_true")
    parser.add_argument(
        "--save_image_dir", type=str, default=None, help="The saving directory."
    )
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--ttf_path", type=str, default="ttf/KaiXinSongA.ttf")

    # Batch processing arguments
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Batch size for processing multiple images.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=8,
        help="Maximum batch size based on GPU memory.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of workers for parallel image loading.",
    )
    parser.add_argument(
        "--batch_content_paths",
        type=str,
        nargs="+",
        default=None,
        help="List of content image paths for batch processing.",
    )
    parser.add_argument(
        "--batch_style_paths",
        type=str,
        nargs="+",
        default=None,
        help="List of style image paths for batch processing.",
    )
    parser.add_argument(
        "--batch_characters",
        type=str,
        nargs="+",
        default=None,
        help="List of characters for batch processing.",
    )
    parser.add_argument(
        "--adaptive_batch_size",
        action="store_true",
        help="Automatically adjust batch size based on GPU memory.",
    )
    args = parser.parse_args()

    # The model expects square inputs, so expand the scalar sizes to (H, W).
    style_image_size = args.style_image_size
    content_image_size = args.content_image_size
    args.style_image_size = (style_image_size, style_image_size)
    args.content_image_size = (content_image_size, content_image_size)

    return args
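# Example invocations (a sketch; the script name and all paths are
# hypothetical):
#
#   Image-based batch:
#     python sample.py --ckpt_dir ckpt --save_image --save_image_dir outputs \
#         --batch_content_paths c1.png c2.png --batch_style_paths s1.png s2.png
#
#   Character-based batch, reusing one style image for every character:
#     python sample.py --ckpt_dir ckpt --character_input --save_image \
#         --save_image_dir outputs --style_image_path style.png \
#         --batch_characters 书 法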
def get_optimal_batch_size(args) -> int:
    """Determine a batch size heuristically from total GPU memory."""
    if not torch.cuda.is_available():
        return 1

    # Get GPU memory info in GB
    gpu_memory = torch.cuda.get_device_properties(args.device).total_memory / (
        1024**3
    )

    # Estimate batch size based on GPU memory (rough heuristic)
    if gpu_memory >= 24:  # RTX 4090, A100, etc.
        optimal_batch = min(16, args.max_batch_size)
    elif gpu_memory >= 12:  # RTX 3080 Ti, RTX 4070 Ti, etc.
        optimal_batch = min(8, args.max_batch_size)
    elif gpu_memory >= 8:  # RTX 3070, RTX 4060 Ti, etc.
        optimal_batch = min(4, args.max_batch_size)
    else:  # Lower-end GPUs
        optimal_batch = min(2, args.max_batch_size)

    return optimal_batch


def load_fontdiffuser_pipeline(args):
    """Load the FontDiffuser pipeline (unchanged from the original script)."""
    # Load the model state_dict
    unet = build_unet(args=args)
    unet.load_state_dict(torch.load(f"{args.ckpt_dir}/unet.pth"))
    style_encoder = build_style_encoder(args=args)
    style_encoder.load_state_dict(torch.load(f"{args.ckpt_dir}/style_encoder.pth"))
    content_encoder = build_content_encoder(args=args)
    content_encoder.load_state_dict(torch.load(f"{args.ckpt_dir}/content_encoder.pth"))
    model = FontDiffuserModelDPM(
        unet=unet, style_encoder=style_encoder, content_encoder=content_encoder
    )
    model.to(args.device)
    print("Loaded the model state_dict successfully!")

    # Load the training DDPM scheduler.
    train_scheduler = build_ddpm_scheduler(args=args)
    print("Loaded training DDPM scheduler successfully!")

    # Load the DPM-Solver pipeline to generate samples.
    pipe = FontDiffuserDPMPipeline(
        model=model,
        ddpm_train_scheduler=train_scheduler,
        model_type=args.model_type,
        guidance_type=args.guidance_type,
        guidance_scale=args.guidance_scale,
    )
    print("Loaded dpm_solver pipeline successfully!")

    return pipe
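# The checkpoint directory passed as --ckpt_dir is expected to contain three
# state_dict files, matching the torch.load() calls above:
#
#   <ckpt_dir>/
#   ├── unet.pth
#   ├── style_encoder.pth
#   └── content_encoder.pth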
def batch_sampling(
    args,
    pipe,
    content_inputs: List[Union[str, Image.Image]],
    style_inputs: List[Union[str, Image.Image]],
    content_characters: Optional[List[str]] = None,
) -> List[Image.Image]:
    """
    Perform batch sampling with FontDiffuser.

    Args:
        args: Arguments.
        pipe: FontDiffuser pipeline.
        content_inputs: List of content images/paths. May be empty when
            `content_characters` is given.
        style_inputs: List of style images/paths.
        content_characters: List of characters (if using character input).

    Returns:
        List of generated images.
    """
    if not args.demo:
        os.makedirs(args.save_image_dir, exist_ok=True)
        save_args_to_yaml(
            args=args, output_file=f"{args.save_image_dir}/sampling_config.yaml"
        )

    if args.seed:
        set_seed(seed=args.seed)

    # Determine optimal batch size
    if args.adaptive_batch_size:
        optimal_batch_size = get_optimal_batch_size(args)
        print(f"Using adaptive batch size: {optimal_batch_size}")
    else:
        optimal_batch_size = args.batch_size

    batch_processor = BatchProcessor(args)

    # In character mode the characters define the sample count, since the
    # content image list may be empty.
    if content_characters:
        total_samples = len(content_characters)
    else:
        total_samples = len(content_inputs)
    all_generated_images = []
    num_batches = (total_samples + optimal_batch_size - 1) // optimal_batch_size

    print(f"Processing {total_samples} samples in batches of {optimal_batch_size}")

    # Process in batches
    for batch_start in range(0, total_samples, optimal_batch_size):
        batch_end = min(batch_start + optimal_batch_size, total_samples)
        batch_content = content_inputs[batch_start:batch_end]
        batch_style = style_inputs[batch_start:batch_end]
        batch_chars = (
            content_characters[batch_start:batch_end] if content_characters else None
        )

        print(f"Processing batch {batch_start // optimal_batch_size + 1}/{num_batches}")

        # Process batch
        content_batch, style_batch, content_pil_images = (
            batch_processor.batch_image_process(batch_content, batch_style, batch_chars)
        )

        if content_batch is None or style_batch is None:
            print("Skipping batch due to processing errors")
            continue

        current_batch_size = content_batch.shape[0]

        with torch.no_grad():
            content_batch = content_batch.to(args.device)
            style_batch = style_batch.to(args.device)

            print(f"Generating {current_batch_size} images with DPM-Solver++...")
            start_time = time.time()

            try:
                # Generate batch
                images = pipe.generate(
                    content_images=content_batch,
                    style_images=style_batch,
                    batch_size=current_batch_size,
                    order=args.order,
                    num_inference_step=args.num_inference_steps,
                    content_encoder_downsample_size=args.content_encoder_downsample_size,
                    t_start=args.t_start,
                    t_end=args.t_end,
                    dm_size=args.content_image_size,
                    algorithm_type=args.algorithm_type,
                    skip_type=args.skip_type,
                    method=args.method,
                    correcting_x0_fn=args.correcting_x0_fn,
                )

                end_time = time.time()
                print(f"Batch generation completed in {end_time - start_time:.2f}s")

                # Save images if requested
                if args.save_image:
                    save_batch_images(
                        args,
                        images,
                        content_pil_images,
                        batch_content,
                        batch_style,
                        batch_start,
                    )

                all_generated_images.extend(images)

            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print(
                        f"GPU out of memory with batch size {current_batch_size}, "
                        f"trying smaller batch..."
                    )
                    torch.cuda.empty_cache()

                    # Retry with smaller sub-batches. Note that images produced
                    # in this fallback path are collected but not saved to disk.
                    smaller_batch_size = max(1, current_batch_size // 2)
                    for sub_batch_start in range(
                        0, current_batch_size, smaller_batch_size
                    ):
                        sub_batch_end = min(
                            sub_batch_start + smaller_batch_size, current_batch_size
                        )
                        sub_content = content_batch[sub_batch_start:sub_batch_end]
                        sub_style = style_batch[sub_batch_start:sub_batch_end]

                        sub_images = pipe.generate(
                            content_images=sub_content,
                            style_images=sub_style,
                            batch_size=sub_batch_end - sub_batch_start,
                            order=args.order,
                            num_inference_step=args.num_inference_steps,
                            content_encoder_downsample_size=args.content_encoder_downsample_size,
                            t_start=args.t_start,
                            t_end=args.t_end,
                            dm_size=args.content_image_size,
                            algorithm_type=args.algorithm_type,
                            skip_type=args.skip_type,
                            method=args.method,
                            correcting_x0_fn=args.correcting_x0_fn,
                        )
                        all_generated_images.extend(sub_images)
                else:
                    print(f"Error during generation: {e}")

        # Clear GPU cache between batches
        torch.cuda.empty_cache()

    print(f"Batch processing completed! Generated {len(all_generated_images)} images.")
    return all_generated_images
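# Programmatic usage sketch (hypothetical paths; `args` from arg_parse() and
# `pipe` from load_fontdiffuser_pipeline()):
#
#     images = batch_sampling(
#         args,
#         pipe,
#         content_inputs=["content/a.png", "content/b.png"],
#         style_inputs=["style/s.png", "style/s.png"],
#     )
#     # `images` is a list of PIL images, one per successfully processed pair.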
def save_batch_images(
    args, images, content_pil_images, batch_content, batch_style, batch_offset
):
    """Save a batch of generated images."""
    for i, image in enumerate(images):
        # Create a unique filename for each image
        image_idx = batch_offset + i

        save_single_image(
            save_dir=args.save_image_dir, image=image, suffix=f"_{image_idx:04d}"
        )

        # Save with content and style context if available
        if args.character_input and i < len(content_pil_images):
            save_image_with_content_style(
                save_dir=args.save_image_dir,
                image=image,
                content_image_pil=content_pil_images[i],
                content_image_path=None,
                style_image_path=batch_style[i]
                if isinstance(batch_style[i], str)
                else None,
                resolution=args.resolution,
                suffix=f"_{image_idx:04d}",
            )
        elif not args.character_input:
            save_image_with_content_style(
                save_dir=args.save_image_dir,
                image=image,
                content_image_pil=None,
                content_image_path=batch_content[i]
                if isinstance(batch_content[i], str)
                else None,
                style_image_path=batch_style[i]
                if isinstance(batch_style[i], str)
                else None,
                resolution=args.resolution,
                suffix=f"_{image_idx:04d}",
            )


def sampling(args, pipe, content_image=None, style_image=None):
    """Original single-image sampling function (kept for backward compatibility)."""
    if not args.demo:
        os.makedirs(args.save_image_dir, exist_ok=True)
        save_args_to_yaml(
            args=args, output_file=f"{args.save_image_dir}/sampling_config.yaml"
        )

    if args.seed:
        set_seed(seed=args.seed)

    # Route the single sample through the batch implementation.
    if args.character_input:
        content_characters = (
            [args.content_character]
            if getattr(args, "content_character", None)
            else ["A"]
        )
        style_inputs = [style_image or args.style_image_path]
        result = batch_sampling(args, pipe, [], style_inputs, content_characters)
    else:
        content_inputs = [content_image or args.content_image_path]
        style_inputs = [style_image or args.style_image_path]
        result = batch_sampling(args, pipe, content_inputs, style_inputs)

    return result[0] if result else None


# Additional utility functions for batch processing
def load_images_from_directory(
    directory_path: str, extensions: Optional[List[str]] = None
) -> List[str]:
    """Load all image paths from a directory."""
    if extensions is None:
        extensions = [".jpg", ".jpeg", ".png", ".bmp"]
    directory = Path(directory_path)
    image_paths = set()

    for ext in extensions:
        image_paths.update(directory.glob(f"*{ext}"))
        image_paths.update(directory.glob(f"*{ext.upper()}"))

    # The set deduplicates matches on case-insensitive filesystems.
    return sorted(str(path) for path in image_paths)


def create_batch_from_config(
    config_file: str,
) -> Tuple[List[str], List[str], List[str]]:
    """Create batch inputs from a JSON configuration file."""
    with open(config_file, "r") as f:
        config = json.load(f)

    content_inputs = config.get("content_images", [])
    style_inputs = config.get("style_images", [])
    characters = config.get("characters", [])

    return content_inputs, style_inputs, characters
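# Example configuration file for create_batch_from_config() (a sketch; the
# keys match the .get() calls above, and the paths are hypothetical):
#
#   {
#       "content_images": ["content/a.png", "content/b.png"],
#       "style_images": ["style/s1.png", "style/s2.png"],
#       "characters": []
#   }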
if __name__ == "__main__":
    args = arg_parse()

    # Load the FontDiffuser pipeline
    pipe = load_fontdiffuser_pipeline(args=args)

    # Check if batch processing is requested
    if args.batch_content_paths or args.batch_style_paths or args.batch_characters:
        # Batch processing mode
        content_inputs = args.batch_content_paths or []
        style_inputs = args.batch_style_paths or []
        characters = args.batch_characters or None

        if characters and args.character_input:
            # Character-based batch processing; reuse the single style image
            # when no per-character style paths were given.
            style_inputs = style_inputs or [args.style_image_path] * len(characters)
            generated_images = batch_sampling(args, pipe, [], style_inputs, characters)
        else:
            # Image-based batch processing
            if len(content_inputs) != len(style_inputs):
                print("Error: Number of content and style images must match")
                exit(1)
            generated_images = batch_sampling(args, pipe, content_inputs, style_inputs)

        print(f"Batch processing completed! Generated {len(generated_images)} images.")
    else:
        # Single-image processing (original behavior)
        out_image = sampling(args=args, pipe=pipe)