import os
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
import torchvision.transforms as transforms
from accelerate.utils import set_seed
from PIL import Image

from src import (
    FontDiffuserDPMPipeline,
    FontDiffuserModelDPM,
    build_ddpm_scheduler,
    build_unet,
    build_content_encoder,
    build_style_encoder,
)
from utils import (
    ttf2im,
    load_ttf,
    is_char_in_font,
    save_args_to_yaml,
    save_single_image,
    save_image_with_content_style,
)

class BatchProcessor:
    """Handles batch processing logic for FontDiffuser."""

    def __init__(self, args):
        self.args = args
        self.device = args.device
        self.max_batch_size = getattr(args, "max_batch_size", 8)
        self.num_workers = getattr(args, "num_workers", 4)

    def batch_image_process(
        self,
        content_inputs: List[Union[str, Image.Image]],
        style_inputs: List[Union[str, Image.Image]],
        content_characters: Optional[List[str]] = None,
    ) -> Tuple[
        Optional[torch.Tensor], Optional[torch.Tensor], List[Optional[Image.Image]]
    ]:
        """
        Process multiple images in batch.

        Args:
            content_inputs: List of content image paths or PIL Images
                (may be empty when content_characters is given).
            style_inputs: List of style image paths or PIL Images.
            content_characters: List of characters if using character input mode.

        Returns:
            Tuple of (content_tensors, style_tensors, content_pil_images), or
            (None, None, []) if no sample could be processed.
        """
        # In character mode the batch size is driven by the character list,
        # since content_inputs may be empty.
        batch_size = (
            len(content_characters) if content_characters else len(content_inputs)
        )
        assert len(style_inputs) == batch_size, (
            "Content and style inputs must have the same length"
        )
        # Transform setup
        content_inference_transforms = transforms.Compose(
            [
                transforms.Resize(
                    self.args.content_image_size,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        style_inference_transforms = transforms.Compose(
            [
                transforms.Resize(
                    self.args.style_image_size,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        # Initialize ordered lists for results
        content_tensors = [None] * batch_size
        style_tensors = [None] * batch_size
        content_pil_images = [None] * batch_size
        # Process in parallel using a ThreadPoolExecutor for the I/O-bound work
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            # Submit content processing tasks (character rendering or image loading)
            content_futures = []
            for i in range(batch_size):
                if content_characters:
                    future = executor.submit(
                        self._process_content_character,
                        content_characters[i],
                        content_inference_transforms,
                    )
                else:
                    future = executor.submit(
                        self._process_content_image,
                        content_inputs[i],
                        content_inference_transforms,
                    )
                content_futures.append((i, future))
            # Submit style processing tasks
            style_futures = []
            for i, style_input in enumerate(style_inputs):
                future = executor.submit(
                    self._process_style_image, style_input, style_inference_transforms
                )
                style_futures.append((i, future))
            # Collect results in order
            for i, future in content_futures:
                try:
                    content_tensor, content_pil = future.result()
                    if content_tensor is not None:
                        content_tensors[i] = content_tensor
                        content_pil_images[i] = content_pil
                except Exception as e:
                    print(f"Error processing content at index {i}: {e}")
            for i, future in style_futures:
                try:
                    style_tensor = future.result()
                    if style_tensor is not None:
                        style_tensors[i] = style_tensor
                except Exception as e:
                    print(f"Error processing style at index {i}: {e}")
        # Keep only the indices where both the content and the style were
        # processed successfully, so content/style pairs stay aligned;
        # filtering each list independently could pair mismatched samples.
        valid_indices = [
            i
            for i in range(batch_size)
            if content_tensors[i] is not None and style_tensors[i] is not None
        ]
        if not valid_indices:
            return None, None, []
        content_batch = torch.stack([content_tensors[i] for i in valid_indices])
        style_batch = torch.stack([style_tensors[i] for i in valid_indices])
        content_pil_images = [content_pil_images[i] for i in valid_indices]
        return content_batch, style_batch, content_pil_images

    def _process_content_character(
        self, character: str, transform
    ) -> Tuple[Optional[torch.Tensor], Optional[Image.Image]]:
        """Render a content character with the configured TTF into a tensor."""
        if not is_char_in_font(font_path=self.args.ttf_path, char=character):
            print(f"Character '{character}' not found in font")
            return None, None
        font = load_ttf(ttf_path=self.args.ttf_path)
        content_image = ttf2im(font=font, char=character)
        content_image_pil = content_image.copy()
        content_tensor = transform(content_image)
        return content_tensor, content_image_pil

    def _process_content_image(
        self, image_input: Union[str, Image.Image], transform
    ) -> Tuple[Optional[torch.Tensor], None]:
        """Process a content image into a tensor."""
        try:
            if isinstance(image_input, str):
                content_image = Image.open(image_input).convert("RGB")
            else:
                content_image = image_input.convert("RGB")
            content_tensor = transform(content_image)
            return content_tensor, None
        except Exception as e:
            print(f"Error processing content image: {e}")
            return None, None

    def _process_style_image(
        self, image_input: Union[str, Image.Image], transform
    ) -> Optional[torch.Tensor]:
        """Process a style image into a tensor."""
        try:
            if isinstance(image_input, str):
                style_image = Image.open(image_input).convert("RGB")
            else:
                style_image = image_input.convert("RGB")
            style_tensor = transform(style_image)
            return style_tensor
        except Exception as e:
            print(f"Error processing style image: {e}")
            return None
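
# Usage sketch for BatchProcessor (illustrative only; assumes `args` was built
# by arg_parse() below and that the listed image files exist):
#
#   processor = BatchProcessor(args)
#   content_batch, style_batch, content_pils = processor.batch_image_process(
#       content_inputs=["content_0.png", "content_1.png"],
#       style_inputs=["style_0.png", "style_1.png"],
#   )
#   # content_batch / style_batch are stacked (N, C, H, W) tensors, already
#   # normalized to [-1, 1], ready to be moved to the device and passed to
#   # pipe.generate().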

def arg_parse():
    from configs.fontdiffuser import get_parser

    parser = get_parser()
    parser.add_argument("--ckpt_dir", type=str, default=None)
    parser.add_argument("--demo", action="store_true")
    # NOTE: argparse's bool() treats any non-empty string as True, so
    # "--controlnet False" still enables it; omit the flag to keep it disabled.
    parser.add_argument(
        "--controlnet",
        type=bool,
        default=False,
        help="If in demo mode, the controlnet can be added.",
    )
    parser.add_argument("--character_input", action="store_true")
    parser.add_argument("--content_character", type=str, default=None)
    parser.add_argument("--content_image_path", type=str, default=None)
    parser.add_argument("--style_image_path", type=str, default=None)
    parser.add_argument("--save_image", action="store_true")
    parser.add_argument(
        "--save_image_dir", type=str, default=None, help="The saving directory."
    )
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--ttf_path", type=str, default="ttf/KaiXinSongA.ttf")
    # Batch processing arguments
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4,
        help="Batch size for processing multiple images",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=8,
        help="Maximum batch size based on GPU memory",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=4,
        help="Number of workers for parallel image loading",
    )
    parser.add_argument(
        "--batch_content_paths",
        type=str,
        nargs="+",
        default=None,
        help="List of content image paths for batch processing",
    )
    parser.add_argument(
        "--batch_style_paths",
        type=str,
        nargs="+",
        default=None,
        help="List of style image paths for batch processing",
    )
    parser.add_argument(
        "--batch_characters",
        type=str,
        nargs="+",
        default=None,
        help="List of characters for batch processing",
    )
    parser.add_argument(
        "--adaptive_batch_size",
        action="store_true",
        help="Automatically adjust batch size based on GPU memory",
    )
    args = parser.parse_args()
    # Expand the scalar sizes into the square (H, W) tuples expected downstream.
    style_image_size = args.style_image_size
    content_image_size = args.content_image_size
    args.style_image_size = (style_image_size, style_image_size)
    args.content_image_size = (content_image_size, content_image_size)
    return args
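
# Example invocations (illustrative; "sample_batch.py" is a placeholder for this
# script's filename, and the checkpoint/image paths are assumptions):
#
#   # Image-based batch processing:
#   python sample_batch.py --ckpt_dir ckpt --save_image --save_image_dir outputs \
#       --batch_content_paths c0.png c1.png --batch_style_paths s0.png s1.png
#
#   # Character-based batch processing (one style image reused per character):
#   python sample_batch.py --ckpt_dir ckpt --character_input --save_image \
#       --save_image_dir outputs --style_image_path style.png \
#       --batch_characters A B C --adaptive_batch_size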

def get_optimal_batch_size(args) -> int:
    """Determine the optimal batch size based on GPU memory."""
    if not torch.cuda.is_available():
        return 1
    # Total GPU memory in GiB
    gpu_memory = torch.cuda.get_device_properties(args.device).total_memory / (1024**3)
    # Estimate batch size based on GPU memory (rough heuristic)
    if gpu_memory >= 24:  # RTX 4090, A100, etc.
        optimal_batch = min(16, args.max_batch_size)
    elif gpu_memory >= 12:  # RTX 3080 Ti, RTX 4070 Ti, etc.
        optimal_batch = min(8, args.max_batch_size)
    elif gpu_memory >= 8:  # RTX 3070, RTX 4060 Ti, etc.
        optimal_batch = min(4, args.max_batch_size)
    else:  # Lower-end GPUs
        optimal_batch = min(2, args.max_batch_size)
    return optimal_batch
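
# Worked example for the heuristic above: a 16 GiB card falls into the ">= 12"
# tier, so with the default --max_batch_size of 8 it returns min(8, 8) == 8,
# while an 11 GiB card lands in the ">= 8" tier and returns min(4, 8) == 4.
# The tiers are rough assumptions, not measurements; actual memory use also
# depends on image size and the number of inference steps.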

def load_fontdiffuer_pipeline(args):
    """Load the FontDiffuser pipeline (unchanged from the original script)."""
    # Load the model state_dicts. map_location="cpu" keeps loading robust on
    # machines whose device layout differs from the training machine's.
    unet = build_unet(args=args)
    unet.load_state_dict(torch.load(f"{args.ckpt_dir}/unet.pth", map_location="cpu"))
    style_encoder = build_style_encoder(args=args)
    style_encoder.load_state_dict(
        torch.load(f"{args.ckpt_dir}/style_encoder.pth", map_location="cpu")
    )
    content_encoder = build_content_encoder(args=args)
    content_encoder.load_state_dict(
        torch.load(f"{args.ckpt_dir}/content_encoder.pth", map_location="cpu")
    )
    model = FontDiffuserModelDPM(
        unet=unet, style_encoder=style_encoder, content_encoder=content_encoder
    )
    model.to(args.device)
    print("Loaded the model state_dict successfully!")
    # Load the training DDPM scheduler.
    train_scheduler = build_ddpm_scheduler(args=args)
    print("Loaded training DDPM scheduler successfully!")
    # Load the DPM-Solver pipeline used for sampling.
    pipe = FontDiffuserDPMPipeline(
        model=model,
        ddpm_train_scheduler=train_scheduler,
        model_type=args.model_type,
        guidance_type=args.guidance_type,
        guidance_scale=args.guidance_scale,
    )
    print("Loaded dpm_solver pipeline successfully!")
    return pipe

def batch_sampling(
    args,
    pipe,
    content_inputs: List[Union[str, Image.Image]],
    style_inputs: List[Union[str, Image.Image]],
    content_characters: Optional[List[str]] = None,
) -> List[Image.Image]:
    """
    Perform batch sampling with FontDiffuser.

    Args:
        args: Arguments.
        pipe: FontDiffuser pipeline.
        content_inputs: List of content images/paths (may be empty when
            content_characters is given).
        style_inputs: List of style images/paths.
        content_characters: List of characters (if using character input).

    Returns:
        List of generated images.
    """
    if not args.demo:
        os.makedirs(args.save_image_dir, exist_ok=True)
        save_args_to_yaml(
            args=args, output_file=f"{args.save_image_dir}/sampling_config.yaml"
        )
    if args.seed:
        set_seed(seed=args.seed)
    # Determine the batch size to use
    if args.adaptive_batch_size:
        optimal_batch_size = get_optimal_batch_size(args)
        print(f"Using adaptive batch size: {optimal_batch_size}")
    else:
        optimal_batch_size = args.batch_size
    batch_processor = BatchProcessor(args)
    # In character mode the sample count is driven by the character list,
    # since content_inputs may be empty.
    total_samples = (
        len(content_characters) if content_characters else len(content_inputs)
    )
    all_generated_images = []
    print(f"Processing {total_samples} samples in batches of {optimal_batch_size}")
    # Process in batches
    num_batches = (total_samples + optimal_batch_size - 1) // optimal_batch_size
    for batch_start in range(0, total_samples, optimal_batch_size):
        batch_end = min(batch_start + optimal_batch_size, total_samples)
        batch_content = content_inputs[batch_start:batch_end]
        batch_style = style_inputs[batch_start:batch_end]
        batch_chars = (
            content_characters[batch_start:batch_end] if content_characters else None
        )
        print(f"Processing batch {batch_start // optimal_batch_size + 1}/{num_batches}")
        # Process batch
        content_batch, style_batch, content_pil_images = (
            batch_processor.batch_image_process(batch_content, batch_style, batch_chars)
        )
        if content_batch is None or style_batch is None:
            print("Skipping batch due to processing errors")
            continue
        current_batch_size = content_batch.shape[0]
        with torch.no_grad():
            content_batch = content_batch.to(args.device)
            style_batch = style_batch.to(args.device)
            print(f"Generating {current_batch_size} images with DPM-Solver++...")
            start_time = time.time()
            try:
                # Generate the whole batch in one pipeline call
                images = pipe.generate(
                    content_images=content_batch,
                    style_images=style_batch,
                    batch_size=current_batch_size,
                    order=args.order,
                    num_inference_step=args.num_inference_steps,
                    content_encoder_downsample_size=args.content_encoder_downsample_size,
                    t_start=args.t_start,
                    t_end=args.t_end,
                    dm_size=args.content_image_size,
                    algorithm_type=args.algorithm_type,
                    skip_type=args.skip_type,
                    method=args.method,
                    correcting_x0_fn=args.correcting_x0_fn,
                )
                end_time = time.time()
                print(f"Batch generation completed in {end_time - start_time:.2f}s")
                # Save images if requested
                if args.save_image:
                    save_batch_images(
                        args,
                        images,
                        content_pil_images,
                        batch_content,
                        batch_style,
                        batch_start,
                    )
                all_generated_images.extend(images)
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print(
                        f"GPU out of memory with batch size {current_batch_size}, "
                        "trying smaller batch..."
                    )
                    torch.cuda.empty_cache()
                    # Retry in halved sub-batches
                    smaller_batch_size = max(1, current_batch_size // 2)
                    for sub_batch_start in range(
                        0, current_batch_size, smaller_batch_size
                    ):
                        sub_batch_end = min(
                            sub_batch_start + smaller_batch_size, current_batch_size
                        )
                        sub_content = content_batch[sub_batch_start:sub_batch_end]
                        sub_style = style_batch[sub_batch_start:sub_batch_end]
                        sub_images = pipe.generate(
                            content_images=sub_content,
                            style_images=sub_style,
                            batch_size=sub_batch_end - sub_batch_start,
                            order=args.order,
                            num_inference_step=args.num_inference_steps,
                            content_encoder_downsample_size=args.content_encoder_downsample_size,
                            t_start=args.t_start,
                            t_end=args.t_end,
                            dm_size=args.content_image_size,
                            algorithm_type=args.algorithm_type,
                            skip_type=args.skip_type,
                            method=args.method,
                            correcting_x0_fn=args.correcting_x0_fn,
                        )
                        # Save the retried sub-batch too; the first attempt
                        # never reached the saving step.
                        if args.save_image:
                            save_batch_images(
                                args,
                                sub_images,
                                content_pil_images[sub_batch_start:sub_batch_end],
                                batch_content[sub_batch_start:sub_batch_end],
                                batch_style[sub_batch_start:sub_batch_end],
                                batch_start + sub_batch_start,
                            )
                        all_generated_images.extend(sub_images)
                else:
                    print(f"Error during generation: {e}")
                    continue
        # Clear the GPU cache between batches
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    print(f"Batch processing completed! Generated {len(all_generated_images)} images.")
    return all_generated_images

def save_batch_images(
    args, images, content_pil_images, batch_content, batch_style, batch_offset
):
    """Save a batch of generated images.

    Assumes inputs and outputs are index-aligned, i.e. no sample in the batch
    was dropped during preprocessing.
    """
    for i, image in enumerate(images):
        # Create a unique filename for each image
        image_idx = batch_offset + i
        save_single_image(
            save_dir=args.save_image_dir, image=image, suffix=f"_{image_idx:04d}"
        )
        # Save with content and style context if available
        if args.character_input and i < len(content_pil_images):
            save_image_with_content_style(
                save_dir=args.save_image_dir,
                image=image,
                content_image_pil=content_pil_images[i],
                content_image_path=None,
                style_image_path=batch_style[i]
                if isinstance(batch_style[i], str)
                else None,
                resolution=args.resolution,
                suffix=f"_{image_idx:04d}",
            )
        elif not args.character_input:
            save_image_with_content_style(
                save_dir=args.save_image_dir,
                image=image,
                content_image_pil=None,
                content_image_path=batch_content[i]
                if isinstance(batch_content[i], str)
                else None,
                style_image_path=batch_style[i]
                if isinstance(batch_style[i], str)
                else None,
                resolution=args.resolution,
                suffix=f"_{image_idx:04d}",
            )

def sampling(args, pipe, content_image=None, style_image=None):
    """Original single-image sampling function (kept for backward compatibility)."""
    if not args.demo:
        os.makedirs(args.save_image_dir, exist_ok=True)
        save_args_to_yaml(
            args=args, output_file=f"{args.save_image_dir}/sampling_config.yaml"
        )
    if args.seed:
        set_seed(seed=args.seed)
    # Delegate to the batch path with a batch of one
    if args.character_input:
        # hasattr() is always True for argparse-defined attributes, so test the
        # value itself; fall back to "A" when no character was provided.
        content_characters = (
            [args.content_character]
            if getattr(args, "content_character", None)
            else ["A"]
        )
        style_inputs = [style_image or args.style_image_path]
        result = batch_sampling(args, pipe, [], style_inputs, content_characters)
    else:
        content_inputs = [content_image or args.content_image_path]
        style_inputs = [style_image or args.style_image_path]
        result = batch_sampling(args, pipe, content_inputs, style_inputs)
    return result[0] if result else None
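
# Single-image usage (a sketch; `pil_content` and `pil_style` stand in for any
# PIL.Image instances you already have in memory):
#
#   out = sampling(args, pipe)  # uses --content_image_path / --style_image_path
#   out = sampling(args, pipe, content_image=pil_content, style_image=pil_style)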

# Additional utility functions for batch processing
def load_images_from_directory(
    directory_path: str,
    extensions: Tuple[str, ...] = (".jpg", ".jpeg", ".png", ".bmp"),
) -> List[str]:
    """Collect all image paths from a directory, sorted and de-duplicated."""
    directory = Path(directory_path)
    image_paths = set()
    for ext in extensions:
        image_paths.update(directory.glob(f"*{ext}"))
        image_paths.update(directory.glob(f"*{ext.upper()}"))
    # The set guards against case-insensitive filesystems matching the same
    # file for both glob patterns; the tuple default avoids the mutable
    # default-argument pitfall.
    return sorted(str(path) for path in image_paths)
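
# Directory-driven batch run (a minimal sketch, assuming "content_dir/" and
# "style_dir/" hold equal numbers of correspondingly ordered images):
#
#   content_paths = load_images_from_directory("content_dir")
#   style_paths = load_images_from_directory("style_dir")
#   images = batch_sampling(args, pipe, content_paths, style_paths)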

def create_batch_from_config(
    config_file: str,
) -> Tuple[List[str], List[str], List[str]]:
    """Create batch inputs from a JSON configuration file."""
    import json

    with open(config_file, "r") as f:
        config = json.load(f)
    content_inputs = config.get("content_images", [])
    style_inputs = config.get("style_images", [])
    characters = config.get("characters", [])
    return content_inputs, style_inputs, characters
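
# Expected JSON layout (inferred from the keys read above; all three keys are
# optional and default to empty lists):
#
#   {
#       "content_images": ["content_0.png", "content_1.png"],
#       "style_images": ["style_0.png", "style_1.png"],
#       "characters": ["A", "B"]
#   }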

if __name__ == "__main__":
    args = arg_parse()
    # Load the FontDiffuser pipeline
    pipe = load_fontdiffuer_pipeline(args=args)
    # Check whether batch processing is requested
    if args.batch_content_paths or args.batch_style_paths or args.batch_characters:
        # Batch processing mode
        content_inputs = args.batch_content_paths or []
        style_inputs = args.batch_style_paths or []
        characters = args.batch_characters or None
        if characters and args.character_input:
            # Character-based batch processing
            style_inputs = style_inputs or [args.style_image_path] * len(characters)
            generated_images = batch_sampling(args, pipe, [], style_inputs, characters)
        else:
            # Image-based batch processing
            if len(content_inputs) != len(style_inputs):
                print("Error: Number of content and style images must match")
                exit(1)
            generated_images = batch_sampling(args, pipe, content_inputs, style_inputs)
        print(f"Batch processing completed! Generated {len(generated_images)} images.")
    else:
        # Single-image processing (original behavior)
        out_image = sampling(args=args, pipe=pipe)