Phat K Tran commited on
Commit
f7864bb
·
1 Parent(s): b14c4f8

refactor: improve result collection in BatchProcessor for parallel processing

Browse files
Files changed (1) hide show
  1. batch_sample.py +20 -15
batch_sample.py CHANGED
@@ -86,9 +86,10 @@ class BatchProcessor:
86
  ]
87
  )
88
 
89
- content_tensors = []
90
- style_tensors = []
91
- content_pil_images = []
 
92
 
93
  # Process in parallel using ThreadPoolExecutor for I/O operations
94
  with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
@@ -107,37 +108,41 @@ class BatchProcessor:
107
  content_input,
108
  content_inference_transforms,
109
  )
110
- content_futures.append(future)
111
 
112
  # Submit style processing tasks
113
  style_futures = []
114
- for style_input in style_inputs:
115
  future = executor.submit(
116
  self._process_style_image, style_input, style_inference_transforms
117
  )
118
- style_futures.append(future)
119
 
120
- # Collect results
121
- for future in as_completed(content_futures):
122
  try:
123
  content_tensor, content_pil = future.result()
124
  if content_tensor is not None:
125
- content_tensors.append(content_tensor)
126
- content_pil_images.append(content_pil)
127
  except Exception as e:
128
- print(f"Error processing content: {e}")
129
  continue
130
 
131
- for future in as_completed(style_futures):
132
  try:
133
  style_tensor = future.result()
134
  if style_tensor is not None:
135
- style_tensors.append(style_tensor)
136
  except Exception as e:
137
- print(f"Error processing style: {e}")
138
  continue
139
 
140
- # Stack tensors into batches
 
 
 
 
141
  if content_tensors and style_tensors:
142
  content_batch = torch.stack(content_tensors)
143
  style_batch = torch.stack(style_tensors)
 
86
  ]
87
  )
88
 
89
+ # Initialize ordered lists for results
90
+ content_tensors = [None] * batch_size
91
+ style_tensors = [None] * batch_size
92
+ content_pil_images = [None] * batch_size
93
 
94
  # Process in parallel using ThreadPoolExecutor for I/O operations
95
  with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
 
108
  content_input,
109
  content_inference_transforms,
110
  )
111
+ content_futures.append((i, future))
112
 
113
  # Submit style processing tasks
114
  style_futures = []
115
+ for i, style_input in enumerate(style_inputs):
116
  future = executor.submit(
117
  self._process_style_image, style_input, style_inference_transforms
118
  )
119
+ style_futures.append((i, future))
120
 
121
+ # Collect results in order
122
+ for i, future in content_futures:
123
  try:
124
  content_tensor, content_pil = future.result()
125
  if content_tensor is not None:
126
+ content_tensors[i] = content_tensor
127
+ content_pil_images[i] = content_pil
128
  except Exception as e:
129
+ print(f"Error processing content at index {i}: {e}")
130
  continue
131
 
132
+ for i, future in style_futures:
133
  try:
134
  style_tensor = future.result()
135
  if style_tensor is not None:
136
+ style_tensors[i] = style_tensor
137
  except Exception as e:
138
+ print(f"Error processing style at index {i}: {e}")
139
  continue
140
 
141
+ # Filter out None values and stack tensors
142
+ content_tensors = [t for t in content_tensors if t is not None]
143
+ style_tensors = [t for t in style_tensors if t is not None]
144
+ content_pil_images = [img for img in content_pil_images if img is not None]
145
+
146
  if content_tensors and style_tensors:
147
  content_batch = torch.stack(content_tensors)
148
  style_batch = torch.stack(style_tensors)