chulanpro5 committed on
Commit fa4c65b · 1 Parent(s): 1dc498e

feat: batch_sampling

Files changed (2):
  1. app.py +153 -21
  2. batch_sample.py +604 -0
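
For orientation, the new batch_sample.py is meant to be driven either from the command line (via the --batch_* flags it adds) or programmatically. A minimal sketch of programmatic use follows; the checkpoint flags and image paths are illustrative, not part of this commit:

from batch_sample import arg_parse, load_fontdiffuer_pipeline, batch_sampling

# e.g. invoked as: python batch_sample.py --ckpt_dir ckpt --demo --batch_size 4
args = arg_parse()
pipe = load_fontdiffuer_pipeline(args=args)

images = batch_sampling(
    args=args,
    pipe=pipe,
    content_inputs=["figures/source_imgs/source_鑫.jpg"],  # content glyph images
    style_inputs=["figures/ref_imgs/ref_雕.jpg"],          # style reference images
)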
app.py CHANGED
@@ -1,13 +1,18 @@
 import random
+from typing import List, Union, Optional, Tuple
+import torch
+from PIL import Image
 import spaces
 import gradio as gr
-from sample import (arg_parse,
+from sample import (arg_parse,
                     sampling,
                     load_fontdiffuer_pipeline)
 
+from batch_sample import batch_sampling
+
 @spaces.GPU()
-def run_fontdiffuer(source_image,
-                    character,
+def run_fontdiffuer(source_image,
+                    character,
                     reference_image,
                     sampling_step,
                     guidance_scale,
@@ -23,12 +28,139 @@ def run_fontdiffuer(source_image,
                          pipe=pipe,
                          content_image=source_image,
                          style_image=reference_image)
-
+
     if out_image is not None:
         out_image.format = 'PNG'
-
+
     return out_image
 
+def _normalize_batch_inputs(source_images, characters, reference_images) -> Tuple[List, List, List, int]:
+    """
+    Normalize different input types to consistent lists.
+
+    Returns:
+        Tuple of (content_inputs, style_inputs, char_inputs, total_samples)
+    """
+    content_inputs = []
+    style_inputs = []
+    char_inputs = []
+
+    # Handle character mode
+    if source_images is None:
+        if isinstance(characters, str):
+            char_inputs = [characters]
+        elif isinstance(characters, list):
+            char_inputs = characters
+        else:
+            return [], [], [], 0
+
+        # Replicate reference images to match character count
+        if isinstance(reference_images, Image.Image):
+            style_inputs = [reference_images] * len(char_inputs)
+        elif isinstance(reference_images, list):
+            if len(reference_images) == 1:
+                style_inputs = reference_images * len(char_inputs)
+            elif len(reference_images) == len(char_inputs):
+                style_inputs = reference_images
+            else:
+                # Cycle through reference images if counts don't match
+                style_inputs = [reference_images[i % len(reference_images)] for i in range(len(char_inputs))]
+
+        total_samples = len(char_inputs)
+
+    # Handle image mode
+    else:
+        if isinstance(source_images, Image.Image):
+            content_inputs = [source_images]
+        elif isinstance(source_images, list):
+            content_inputs = source_images
+        else:
+            return [], [], [], 0
+
+        # Handle reference images
+        if isinstance(reference_images, Image.Image):
+            style_inputs = [reference_images] * len(content_inputs)
+        elif isinstance(reference_images, list):
+            if len(reference_images) == 1:
+                style_inputs = reference_images * len(content_inputs)
+            elif len(reference_images) == len(content_inputs):
+                style_inputs = reference_images
+            else:
+                # Cycle through reference images if counts don't match
+                style_inputs = [reference_images[i % len(reference_images)] for i in range(len(content_inputs))]
+
+        total_samples = len(content_inputs)
+
+    return content_inputs, style_inputs, char_inputs, total_samples
+
+
+@spaces.GPU()
+def run_fontdiffuer_batch(source_images: Union[List[Image.Image], Image.Image, None],
+                          characters: Union[List[str], str, None],
+                          reference_images: Union[List[Image.Image], Image.Image],
+                          sampling_step: int = 50,
+                          guidance_scale: float = 7.5,
+                          batch_size: int = 4,
+                          seed: Optional[int] = None) -> List[Image.Image]:
+    """
+    Run FontDiffuser in batch mode.
+
+    Args:
+        source_images: Single image, list of images, or None (for character mode)
+        characters: Single character, list of characters, or None (for image mode)
+        reference_images: Single style image or list of style images
+        sampling_step: Number of sampling steps
+        guidance_scale: Guidance scale for diffusion
+        batch_size: Batch size for processing
+        seed: Random seed (if None, generates a random seed)
+
+    Returns:
+        List of generated images
+    """
+
+    # Normalize inputs to lists
+    content_inputs, style_inputs, char_inputs, total_samples = _normalize_batch_inputs(
+        source_images, characters, reference_images
+    )
+
+    if total_samples == 0:
+        return []
+
+    # Set up arguments
+    args.character_input = source_images is None
+    args.sampling_step = sampling_step
+    args.guidance_scale = guidance_scale
+    args.batch_size = min(batch_size, total_samples)  # Don't exceed available samples
+    args.seed = seed if seed is not None else random.randint(0, 10000)
+
+    print(f"Processing {total_samples} samples with batch size {args.batch_size}")
+
+    # Use the enhanced batch_sampling function
+    if args.character_input:
+        # Character-based generation
+        generated_images = batch_sampling(
+            args=args,
+            pipe=pipe,
+            content_inputs=content_inputs,  # Empty for character mode
+            style_inputs=style_inputs,
+            content_characters=char_inputs
+        )
+    else:
+        # Image-based generation
+        generated_images = batch_sampling(
+            args=args,
+            pipe=pipe,
+            content_inputs=content_inputs,
+            style_inputs=style_inputs,
+            content_characters=None
+        )
+
+    # Set format for all output images
+    for img in generated_images:
+        img.format = 'PNG'
+
+    return generated_images
+
 
 if __name__ == '__main__':
     args = arg_parse()
@@ -49,18 +181,18 @@ if __name__ == '__main__':
                 FontDiffuser
             </h1>
             <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
-                <a href="https://yeungchenwa.github.io/">Zhenhua Yang</a>,
-                <a href="https://scholar.google.com/citations?user=6zNgcjAAAAAJ&hl=zh-CN&oi=ao">Dezhi Peng</a>,
-                <a href="https://github.com/kyxscut">Yuxin Kong</a>,
-                <a href="https://github.com/ZZXF11">Yuyi Zhang</a>,
-                <a href="https://scholar.google.com/citations?user=IpmnLFcAAAAJ&hl=zh-CN&oi=ao">Cong Yao</a>,
+                <a href="https://yeungchenwa.github.io/">Zhenhua Yang</a>,
+                <a href="https://scholar.google.com/citations?user=6zNgcjAAAAAJ&hl=zh-CN&oi=ao">Dezhi Peng</a>,
+                <a href="https://github.com/kyxscut">Yuxin Kong</a>,
+                <a href="https://github.com/ZZXF11">Yuyi Zhang</a>,
+                <a href="https://scholar.google.com/citations?user=IpmnLFcAAAAJ&hl=zh-CN&oi=ao">Cong Yao</a>,
                 <a href="http://www.dlvc-lab.net/lianwen/Index.html">Lianwen Jin</a>†
             </h2>
             <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
                 <strong>South China University of Technology</strong>, Alibaba DAMO Academy
             </h2>
-            <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
-                [<a href="https://arxiv.org/abs/2312.12142" style="color:blue;">arXiv</a>]
+            <h3 style="font-weight: 450; font-size: 1rem; margin: 0rem">
+                [<a href="https://arxiv.org/abs/2312.12142" style="color:blue;">arXiv</a>]
                 [<a href="https://yeungchenwa.github.io/fontdiffuser-homepage/" style="color:green;">Homepage</a>]
                 [<a href="https://github.com/yeungchenwa/FontDiffuser" style="color:green;">Github</a>]
             </h3>
@@ -83,12 +215,12 @@ if __name__ == '__main__':
                 with gr.Row():
                     fontdiffuer_output_image = gr.Image(height=200, label="FontDiffuser Output Image", image_mode='RGB', type='pil', format='png')
 
-                sampling_step = gr.Slider(20, 50, value=20, step=10,
+                sampling_step = gr.Slider(20, 50, value=20, step=10,
                                           label="Sampling Step", info="The sampling step by FontDiffuser.")
-                guidance_scale = gr.Slider(1, 12, value=7.5, step=0.5,
-                                           label="Scale of Classifier-free Guidance",
+                guidance_scale = gr.Slider(1, 12, value=7.5, step=0.5,
+                                           label="Scale of Classifier-free Guidance",
                                            info="The scale used for classifier-free guidance sampling")
-                batch_size = gr.Slider(1, 4, value=1, step=1,
+                batch_size = gr.Slider(1, 4, value=1, step=1,
                                        label="Batch Size", info="The number of images to be sampled.")
 
                 FontDiffuser = gr.Button('Run FontDiffuser')
@@ -101,7 +233,7 @@ if __name__ == '__main__':
             gr.Markdown("### In this mode, we provide both the source image and \
                 the reference image for you to try our demo!")
             gr.Examples(
-                examples=[['figures/source_imgs/source_灨.jpg', 'figures/ref_imgs/ref_籍.jpg'],
+                examples=[['figures/source_imgs/source_灨.jpg', 'figures/ref_imgs/ref_籍.jpg'],
                           ['figures/source_imgs/source_鑻.jpg', 'figures/ref_imgs/ref_鹰.jpg'],
                           ['figures/source_imgs/source_鑫.jpg', 'figures/ref_imgs/ref_壤.jpg'],
                           ['figures/source_imgs/source_釅.jpg', 'figures/ref_imgs/ref_雕.jpg']],
@@ -124,7 +256,7 @@ if __name__ == '__main__':
                 you can upload your own source image or you choose the character above \
                 to try our demo!")
             gr.Examples(
-                examples=['figures/ref_imgs/ref_闡.jpg',
+                examples=['figures/ref_imgs/ref_闡.jpg',
                           'figures/ref_imgs/ref_雕.jpg',
                           'figures/ref_imgs/ref_豄.jpg',
                           'figures/ref_imgs/ref_馨.jpg',
@@ -145,11 +277,11 @@ if __name__ == '__main__':
         )
         FontDiffuser.click(
             fn=run_fontdiffuer,
-            inputs=[source_image,
-                    character,
+            inputs=[source_image,
+                    character,
                     reference_image,
                     sampling_step,
                     guidance_scale,
                     batch_size],
            outputs=fontdiffuer_output_image)
-    demo.launch(debug=True)
+    demo.launch(debug=True)
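
The diff above keeps the single-image path unchanged and adds run_fontdiffuer_batch, which normalizes flexible inputs via _normalize_batch_inputs. A sketch of calling it in character mode, assuming the app's globals (args, pipe) are already set up; the reference image path is illustrative:

from PIL import Image

ref = Image.open("figures/ref_imgs/ref_雕.jpg")
outs = run_fontdiffuer_batch(
    source_images=None,        # character mode: glyphs rendered from the TTF
    characters=["鑫", "壤"],   # one output image per character
    reference_images=ref,      # a single style image is replicated
    sampling_step=20,
    guidance_scale=7.5,
    batch_size=2,
)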
batch_sample.py ADDED
@@ -0,0 +1,604 @@
+import os
+import time
+from PIL import Image
+from typing import List, Tuple, Optional, Union
+from concurrent.futures import ThreadPoolExecutor
+from pathlib import Path
+
+import torch
+import torchvision.transforms as transforms
+from accelerate.utils import set_seed
+
+from src import (
+    FontDiffuserDPMPipeline,
+    FontDiffuserModelDPM,
+    build_ddpm_scheduler,
+    build_unet,
+    build_content_encoder,
+    build_style_encoder,
+)
+from utils import (
+    ttf2im,
+    load_ttf,
+    is_char_in_font,
+    save_args_to_yaml,
+    save_single_image,
+    save_image_with_content_style,
+)
+
+
+class BatchProcessor:
+    """Handles batch processing logic for FontDiffuser"""
+
+    def __init__(self, args):
+        self.args = args
+        self.device = args.device
+        self.max_batch_size = getattr(args, "max_batch_size", 8)
+        self.num_workers = getattr(args, "num_workers", 4)
+
+    def batch_image_process(
+        self,
+        content_inputs: List[Union[str, Image.Image]],
+        style_inputs: List[Union[str, Image.Image]],
+        content_characters: Optional[List[str]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, List[Optional[Image.Image]]]:
+        """
+        Process multiple images in batch.
+
+        Args:
+            content_inputs: List of content image paths or PIL Images
+            style_inputs: List of style image paths or PIL Images
+            content_characters: List of characters if using character input mode
+
+        Returns:
+            Tuple of (content_tensors, style_tensors, content_pil_images)
+        """
+        # In character mode content_inputs may be empty, so size the batch
+        # from the characters when they are given.
+        batch_size = (
+            len(content_characters) if content_characters else len(content_inputs)
+        )
+        assert len(style_inputs) == batch_size, (
+            "Content and style inputs must have same length"
+        )
+
+        # Transform setup
+        content_inference_transforms = transforms.Compose(
+            [
+                transforms.Resize(
+                    self.args.content_image_size,
+                    interpolation=transforms.InterpolationMode.BILINEAR,
+                ),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+
+        style_inference_transforms = transforms.Compose(
+            [
+                transforms.Resize(
+                    self.args.style_image_size,
+                    interpolation=transforms.InterpolationMode.BILINEAR,
+                ),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+
+        content_tensors = []
+        style_tensors = []
+        content_pil_images = []
+
+        # Process in parallel using ThreadPoolExecutor for I/O operations
+        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
+            # Submit content processing tasks
+            content_futures = []
+            for i in range(batch_size):
+                if content_characters:
+                    future = executor.submit(
+                        self._process_content_character,
+                        content_characters[i],
+                        content_inference_transforms,
+                    )
+                else:
+                    future = executor.submit(
+                        self._process_content_image,
+                        content_inputs[i],
+                        content_inference_transforms,
+                    )
+                content_futures.append(future)
+
+            # Submit style processing tasks
+            style_futures = []
+            for style_input in style_inputs:
+                future = executor.submit(
+                    self._process_style_image, style_input, style_inference_transforms
+                )
+                style_futures.append(future)
+
+            # Collect results in submission order so content/style pairs stay
+            # aligned (result() blocks until the corresponding future is done).
+            for future in content_futures:
+                try:
+                    content_tensor, content_pil = future.result()
+                    if content_tensor is not None:
+                        content_tensors.append(content_tensor)
+                        content_pil_images.append(content_pil)
+                except Exception as e:
+                    print(f"Error processing content: {e}")
+                    continue
+
+            for future in style_futures:
+                try:
+                    style_tensor = future.result()
+                    if style_tensor is not None:
+                        style_tensors.append(style_tensor)
+                except Exception as e:
+                    print(f"Error processing style: {e}")
+                    continue
+
+        # Stack tensors into batches
+        if content_tensors and style_tensors:
+            content_batch = torch.stack(content_tensors)
+            style_batch = torch.stack(style_tensors)
+            return content_batch, style_batch, content_pil_images
+        else:
+            return None, None, []
+
+    def _process_content_character(
+        self, character: str, transform
+    ) -> Tuple[Optional[torch.Tensor], Optional[Image.Image]]:
+        """Process content character into tensor"""
+        if not is_char_in_font(font_path=self.args.ttf_path, char=character):
+            print(f"Character '{character}' not found in font")
+            return None, None
+
+        font = load_ttf(ttf_path=self.args.ttf_path)
+        content_image = ttf2im(font=font, char=character)
+        content_image_pil = content_image.copy()
+        content_tensor = transform(content_image)
+
+        return content_tensor, content_image_pil
+
+    def _process_content_image(
+        self, image_input: Union[str, Image.Image], transform
+    ) -> Tuple[Optional[torch.Tensor], None]:
+        """Process content image into tensor"""
+        try:
+            if isinstance(image_input, str):
+                content_image = Image.open(image_input).convert("RGB")
+            else:
+                content_image = image_input.convert("RGB")
+
+            content_tensor = transform(content_image)
+            return content_tensor, None
+        except Exception as e:
+            print(f"Error processing content image: {e}")
+            return None, None
+
+    def _process_style_image(
+        self, image_input: Union[str, Image.Image], transform
+    ) -> Optional[torch.Tensor]:
+        """Process style image into tensor"""
+        try:
+            if isinstance(image_input, str):
+                style_image = Image.open(image_input).convert("RGB")
+            else:
+                style_image = image_input.convert("RGB")
+
+            style_tensor = transform(style_image)
+            return style_tensor
+        except Exception as e:
+            print(f"Error processing style image: {e}")
+            return None
+
+
+def arg_parse():
+    from configs.fontdiffuser import get_parser
+
+    parser = get_parser()
+    parser.add_argument("--ckpt_dir", type=str, default=None)
+    parser.add_argument("--demo", action="store_true")
+    parser.add_argument(
+        "--controlnet",
+        type=bool,
+        default=False,
+        help="If in demo mode, the controlnet can be added.",
+    )
+    parser.add_argument("--character_input", action="store_true")
+    parser.add_argument("--content_character", type=str, default=None)
+    parser.add_argument("--content_image_path", type=str, default=None)
+    parser.add_argument("--style_image_path", type=str, default=None)
+    parser.add_argument("--save_image", action="store_true")
+    parser.add_argument(
+        "--save_image_dir", type=str, default=None, help="The saving directory."
+    )
+    parser.add_argument("--device", type=str, default="cuda:0")
+    parser.add_argument("--ttf_path", type=str, default="ttf/KaiXinSongA.ttf")
+
+    # Batch processing arguments
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=4,
+        help="Batch size for processing multiple images",
+    )
+    parser.add_argument(
+        "--max_batch_size",
+        type=int,
+        default=8,
+        help="Maximum batch size based on GPU memory",
+    )
+    parser.add_argument(
+        "--num_workers",
+        type=int,
+        default=4,
+        help="Number of workers for parallel image loading",
+    )
+    parser.add_argument(
+        "--batch_content_paths",
+        type=str,
+        nargs="+",
+        default=None,
+        help="List of content image paths for batch processing",
+    )
+    parser.add_argument(
+        "--batch_style_paths",
+        type=str,
+        nargs="+",
+        default=None,
+        help="List of style image paths for batch processing",
+    )
+    parser.add_argument(
+        "--batch_characters",
+        type=str,
+        nargs="+",
+        default=None,
+        help="List of characters for batch processing",
+    )
+    parser.add_argument(
+        "--adaptive_batch_size",
+        action="store_true",
+        help="Automatically adjust batch size based on GPU memory",
+    )
+
+    args = parser.parse_args()
+    style_image_size = args.style_image_size
+    content_image_size = args.content_image_size
+    args.style_image_size = (style_image_size, style_image_size)
+    args.content_image_size = (content_image_size, content_image_size)
+
+    return args
+
+
+def get_optimal_batch_size(args) -> int:
+    """Determine optimal batch size based on GPU memory"""
+    if not torch.cuda.is_available():
+        return 1
+
+    # Get GPU memory info in GB
+    gpu_memory = torch.cuda.get_device_properties(args.device).total_memory / (
+        1024**3
+    )
+
+    # Estimate batch size based on GPU memory (rough heuristic)
+    if gpu_memory >= 24:  # RTX 4090, A100, etc.
+        optimal_batch = min(16, args.max_batch_size)
+    elif gpu_memory >= 12:  # RTX 3080 Ti, RTX 4070 Ti, etc.
+        optimal_batch = min(8, args.max_batch_size)
+    elif gpu_memory >= 8:  # RTX 3070, RTX 4060 Ti, etc.
+        optimal_batch = min(4, args.max_batch_size)
+    else:  # Lower-end GPUs
+        optimal_batch = min(2, args.max_batch_size)
+
+    return optimal_batch
+
+
+def load_fontdiffuer_pipeline(args):
+    """Load FontDiffuser pipeline (unchanged from original)"""
+    # Load the model state_dict
+    unet = build_unet(args=args)
+    unet.load_state_dict(torch.load(f"{args.ckpt_dir}/unet.pth"))
+    style_encoder = build_style_encoder(args=args)
+    style_encoder.load_state_dict(torch.load(f"{args.ckpt_dir}/style_encoder.pth"))
+    content_encoder = build_content_encoder(args=args)
+    content_encoder.load_state_dict(torch.load(f"{args.ckpt_dir}/content_encoder.pth"))
+    model = FontDiffuserModelDPM(
+        unet=unet, style_encoder=style_encoder, content_encoder=content_encoder
+    )
+    model.to(args.device)
+    print("Loaded the model state_dict successfully!")
+
+    # Load the training ddpm_scheduler.
+    train_scheduler = build_ddpm_scheduler(args=args)
+    print("Loaded training DDPM scheduler successfully!")
+
+    # Load the DPM_Solver to generate the sample.
+    pipe = FontDiffuserDPMPipeline(
+        model=model,
+        ddpm_train_scheduler=train_scheduler,
+        model_type=args.model_type,
+        guidance_type=args.guidance_type,
+        guidance_scale=args.guidance_scale,
+    )
+    print("Loaded dpm_solver pipeline successfully!")
+
+    return pipe
+
+
+def batch_sampling(
+    args,
+    pipe,
+    content_inputs: List[Union[str, Image.Image]],
+    style_inputs: List[Union[str, Image.Image]],
+    content_characters: Optional[List[str]] = None,
+) -> List[Image.Image]:
+    """
+    Perform batch sampling with FontDiffuser.
+
+    Args:
+        args: Arguments
+        pipe: FontDiffuser pipeline
+        content_inputs: List of content images/paths
+        style_inputs: List of style images/paths
+        content_characters: List of characters (if using character input)
+
+    Returns:
+        List of generated images
+    """
+    if not args.demo:
+        os.makedirs(args.save_image_dir, exist_ok=True)
+        save_args_to_yaml(
+            args=args, output_file=f"{args.save_image_dir}/sampling_config.yaml"
+        )
+
+    if args.seed:
+        set_seed(seed=args.seed)
+
+    # Determine optimal batch size
+    if args.adaptive_batch_size:
+        optimal_batch_size = get_optimal_batch_size(args)
+        print(f"Using adaptive batch size: {optimal_batch_size}")
+    else:
+        optimal_batch_size = args.batch_size
+
+    batch_processor = BatchProcessor(args)
+    # In character mode content_inputs is empty, so count the characters instead.
+    total_samples = (
+        len(content_characters) if content_characters else len(content_inputs)
+    )
+    all_generated_images = []
+
+    print(f"Processing {total_samples} samples in batches of {optimal_batch_size}")
+
+    # Process in batches
+    for batch_start in range(0, total_samples, optimal_batch_size):
+        batch_end = min(batch_start + optimal_batch_size, total_samples)
+        batch_content = content_inputs[batch_start:batch_end]
+        batch_style = style_inputs[batch_start:batch_end]
+        batch_chars = (
+            content_characters[batch_start:batch_end] if content_characters else None
+        )
+
+        print(
+            f"Processing batch {batch_start // optimal_batch_size + 1}/{(total_samples + optimal_batch_size - 1) // optimal_batch_size}"
+        )
+
+        # Process batch
+        content_batch, style_batch, content_pil_images = (
+            batch_processor.batch_image_process(batch_content, batch_style, batch_chars)
+        )
+
+        if content_batch is None or style_batch is None:
+            print("Skipping batch due to processing errors")
+            continue
+
+        current_batch_size = content_batch.shape[0]
+
+        with torch.no_grad():
+            content_batch = content_batch.to(args.device)
+            style_batch = style_batch.to(args.device)
+
+            print(f"Generating {current_batch_size} images with DPM-Solver++...")
+            start_time = time.time()
+
+            try:
+                # Generate batch
+                images = pipe.generate(
+                    content_images=content_batch,
+                    style_images=style_batch,
+                    batch_size=current_batch_size,
+                    order=args.order,
+                    num_inference_step=args.num_inference_steps,
+                    content_encoder_downsample_size=args.content_encoder_downsample_size,
+                    t_start=args.t_start,
+                    t_end=args.t_end,
+                    dm_size=args.content_image_size,
+                    algorithm_type=args.algorithm_type,
+                    skip_type=args.skip_type,
+                    method=args.method,
+                    correcting_x0_fn=args.correcting_x0_fn,
+                )
+
+                end_time = time.time()
+                print(f"Batch generation completed in {end_time - start_time:.2f}s")
+
+                # Save images if requested
+                if args.save_image:
+                    save_batch_images(
+                        args,
+                        images,
+                        content_pil_images,
+                        batch_content,
+                        batch_style,
+                        batch_start,
+                    )
+
+                all_generated_images.extend(images)
+
+            except RuntimeError as e:
+                if "out of memory" in str(e).lower():
+                    print(
+                        f"GPU out of memory with batch size {current_batch_size}, trying smaller batch..."
+                    )
+                    torch.cuda.empty_cache()
+                    # Retry with smaller batch
+                    smaller_batch_size = max(1, current_batch_size // 2)
+                    for sub_batch_start in range(
+                        0, current_batch_size, smaller_batch_size
+                    ):
+                        sub_batch_end = min(
+                            sub_batch_start + smaller_batch_size, current_batch_size
+                        )
+                        sub_content = content_batch[sub_batch_start:sub_batch_end]
+                        sub_style = style_batch[sub_batch_start:sub_batch_end]
+
+                        sub_images = pipe.generate(
+                            content_images=sub_content,
+                            style_images=sub_style,
+                            batch_size=sub_batch_end - sub_batch_start,
+                            order=args.order,
+                            num_inference_step=args.num_inference_steps,
+                            content_encoder_downsample_size=args.content_encoder_downsample_size,
+                            t_start=args.t_start,
+                            t_end=args.t_end,
+                            dm_size=args.content_image_size,
+                            algorithm_type=args.algorithm_type,
+                            skip_type=args.skip_type,
+                            method=args.method,
+                            correcting_x0_fn=args.correcting_x0_fn,
+                        )
+                        all_generated_images.extend(sub_images)
+                else:
+                    print(f"Error during generation: {e}")
+                    continue
+
+        # Clear GPU cache between batches
+        torch.cuda.empty_cache()
+
+    print(f"Batch processing completed! Generated {len(all_generated_images)} images.")
+    return all_generated_images
+
+
+def save_batch_images(
+    args, images, content_pil_images, batch_content, batch_style, batch_offset
+):
+    """Save batch of generated images"""
+    for i, image in enumerate(images):
+        # Create unique filename for each image
+        image_idx = batch_offset + i
+        save_single_image(
+            save_dir=args.save_image_dir, image=image, suffix=f"_{image_idx:04d}"
+        )
+
+        # Save with content and style context if available
+        if args.character_input and i < len(content_pil_images):
+            save_image_with_content_style(
+                save_dir=args.save_image_dir,
+                image=image,
+                content_image_pil=content_pil_images[i],
+                content_image_path=None,
+                style_image_path=batch_style[i]
+                if isinstance(batch_style[i], str)
+                else None,
+                resolution=args.resolution,
+                suffix=f"_{image_idx:04d}",
+            )
+        elif not args.character_input:
+            save_image_with_content_style(
+                save_dir=args.save_image_dir,
+                image=image,
+                content_image_pil=None,
+                content_image_path=batch_content[i]
+                if isinstance(batch_content[i], str)
+                else None,
+                style_image_path=batch_style[i]
+                if isinstance(batch_style[i], str)
+                else None,
+                resolution=args.resolution,
+                suffix=f"_{image_idx:04d}",
+            )
+
+
+def sampling(args, pipe, content_image=None, style_image=None):
+    """Original single-image sampling function (for backward compatibility)"""
+    if not args.demo:
+        os.makedirs(args.save_image_dir, exist_ok=True)
+        save_args_to_yaml(
+            args=args, output_file=f"{args.save_image_dir}/sampling_config.yaml"
+        )
+
+    if args.seed:
+        set_seed(seed=args.seed)
+
+    # Use single image processing
+    if args.character_input:
+        # The attribute always exists after arg_parse, so test its value
+        # (falling back to "A"), not merely hasattr.
+        content_inputs = (
+            [args.content_character]
+            if getattr(args, "content_character", None)
+            else ["A"]
+        )
+        style_inputs = [style_image or args.style_image_path]
+        result = batch_sampling(args, pipe, [], style_inputs, content_inputs)
+    else:
+        content_inputs = [content_image or args.content_image_path]
+        style_inputs = [style_image or args.style_image_path]
+        result = batch_sampling(args, pipe, content_inputs, style_inputs)
+
+    return result[0] if result else None
+
+
+# Additional utility functions for batch processing
+def load_images_from_directory(
+    directory_path: str, extensions: List[str] = [".jpg", ".jpeg", ".png", ".bmp"]
+) -> List[str]:
+    """Load all image paths from a directory"""
+    directory = Path(directory_path)
+    image_paths = []
+
+    for ext in extensions:
+        image_paths.extend(directory.glob(f"*{ext}"))
+        image_paths.extend(directory.glob(f"*{ext.upper()}"))
+
+    return [str(path) for path in sorted(image_paths)]
+
+
+def create_batch_from_config(
+    config_file: str,
+) -> Tuple[List[str], List[str], List[str]]:
+    """Create batch inputs from a configuration file"""
+    import json
+
+    with open(config_file, "r") as f:
+        config = json.load(f)
+
+    content_inputs = config.get("content_images", [])
+    style_inputs = config.get("style_images", [])
+    characters = config.get("characters", [])
+
+    return content_inputs, style_inputs, characters
+
+
+if __name__ == "__main__":
+    args = arg_parse()
+
+    # Load fontdiffuser pipeline
+    pipe = load_fontdiffuer_pipeline(args=args)
+
+    # Check if batch processing is requested
+    if args.batch_content_paths or args.batch_style_paths or args.batch_characters:
+        # Batch processing mode
+        content_inputs = args.batch_content_paths or []
+        style_inputs = args.batch_style_paths or []
+        characters = args.batch_characters or None
+
+        if characters and args.character_input:
+            # Character-based batch processing
+            style_inputs = style_inputs or [args.style_image_path] * len(characters)
+            generated_images = batch_sampling(args, pipe, [], style_inputs, characters)
+        else:
+            # Image-based batch processing
+            if len(content_inputs) != len(style_inputs):
+                print("Error: Number of content and style images must match")
+                exit(1)
+            generated_images = batch_sampling(args, pipe, content_inputs, style_inputs)
+
+        print(f"Batch processing completed! Generated {len(generated_images)} images.")
+    else:
+        # Single image processing (original behavior)
+        out_image = sampling(args=args, pipe=pipe)
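
create_batch_from_config above reads the content_images, style_images, and characters keys from a JSON file. A sketch of producing a compatible config; the file name is illustrative:

import json

config = {
    "content_images": ["figures/source_imgs/source_鑫.jpg"],
    "style_images": ["figures/ref_imgs/ref_雕.jpg"],
    "characters": [],
}
with open("batch_config.json", "w") as f:
    json.dump(config, f, ensure_ascii=False, indent=2)

content_inputs, style_inputs, characters = create_batch_from_config("batch_config.json")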