import random from typing import List, Union, Optional, Tuple import torch from PIL import Image import spaces import gradio as gr from sample import (arg_parse, sampling, load_fontdiffuer_pipeline) from batch_sample import batch_sampling @spaces.GPU() def run_fontdiffuer(source_image, character, reference_image, sampling_step, guidance_scale, batch_size): args.character_input = False if source_image is not None else True args.content_character = character args.sampling_step = sampling_step args.guidance_scale = guidance_scale args.batch_size = batch_size args.seed = random.randint(0, 10000) out_image = sampling( args=args, pipe=pipe, content_image=source_image, style_image=reference_image) if out_image is not None: out_image.format = 'PNG' return out_image def _normalize_batch_inputs(source_images, characters, reference_images) -> Tuple[List, List, List, int]: """ Normalize different input types to consistent lists Returns: Tuple of (content_inputs, style_inputs, char_inputs, total_samples) """ content_inputs = [] style_inputs = [] char_inputs = [] # Handle character mode if source_images is None: if isinstance(characters, str): char_inputs = [characters] elif isinstance(characters, list): char_inputs = characters else: return [], [], [], 0 # Replicate reference images to match character count if isinstance(reference_images, Image.Image): style_inputs = [reference_images] * len(char_inputs) elif isinstance(reference_images, list): if len(reference_images) == 1: style_inputs = reference_images * len(char_inputs) elif len(reference_images) == len(char_inputs): style_inputs = reference_images else: # Cycle through reference images if counts don't match style_inputs = [reference_images[i % len(reference_images)] for i in range(len(char_inputs))] total_samples = len(char_inputs) # Handle image mode else: if isinstance(source_images, Image.Image): content_inputs = [source_images] elif isinstance(source_images, list): # Handle Gradio Gallery format: list of tuples (image, caption) content_inputs = [] for item in source_images: if isinstance(item, tuple) and len(item) >= 1: # Extract the image from tuple (image, caption) content_inputs.append(item[0]) elif isinstance(item, Image.Image): # Direct image content_inputs.append(item) else: return [], [], [], 0 # Handle reference images if isinstance(reference_images, Image.Image): style_inputs = [reference_images] * len(content_inputs) elif isinstance(reference_images, list): if len(reference_images) == 1: style_inputs = reference_images * len(content_inputs) elif len(reference_images) == len(content_inputs): style_inputs = reference_images else: # Cycle through reference images if counts don't match style_inputs = [reference_images[i % len(reference_images)] for i in range(len(content_inputs))] total_samples = len(content_inputs) return content_inputs, style_inputs, char_inputs, total_samples @spaces.GPU() def run_fontdiffuer_batch(source_images: Union[List[Image.Image], Image.Image, None], # characters: Union[List[str], str, None], # reference_images: Union[List[Image.Image], Image.Image], reference_image: Image.Image, sampling_step: int = 50, guidance_scale: float = 7.5, batch_size: int = 4, seed: Optional[int] = None) -> List[Image.Image]: """ Run FontDiffuser in batch mode Args: source_images: Single image, list of images, or None (for character mode) characters: Single character, list of characters, or None (for image mode) reference_images: Single style image or list of style images sampling_step: Number of sampling steps guidance_scale: Guidance scale for diffusion batch_size: Batch size for processing seed: Random seed (if None, generates random seed) Returns: List of generated images """ args.adaptive_batch_size = True characters = None reference_images = [reference_image] # Normalize inputs to lists content_inputs, style_inputs, char_inputs, total_samples = _normalize_batch_inputs( source_images, characters, reference_images ) if total_samples == 0: return [] # Set up arguments args.character_input = source_images is None args.sampling_step = sampling_step args.guidance_scale = guidance_scale args.batch_size = min(batch_size, total_samples) # Don't exceed available samples args.seed = seed if seed is not None else random.randint(0, 10000) print(f"Processing {total_samples} samples with batch size {args.batch_size}") # Use the enhanced batch_sampling function if args.character_input: # Character-based generation generated_images = batch_sampling( args=args, pipe=pipe, content_inputs=content_inputs, # Empty for character mode style_inputs=style_inputs, content_characters=char_inputs ) else: # Image-based generation generated_images = batch_sampling( args=args, pipe=pipe, content_inputs=content_inputs, style_inputs=style_inputs, content_characters=None ) # Set format for all output images for img in generated_images: img.format = 'PNG' return generated_images if __name__ == '__main__': args = arg_parse() args.demo = True args.ckpt_dir = 'ckpt' args.ttf_path = 'ttf/KaiXinSongA.ttf' args.device = 'cuda' args.max_batch_size = 64 args.num_workers = 64 args.adaptive_batch_size = True # load fontdiffuer pipeline pipe = load_fontdiffuer_pipeline(args=args) with gr.Blocks() as demo: with gr.Row(): with gr.Column(scale=1): gr.HTML("""