par-meta commited on
Commit
8f2cf88
·
unverified ·
1 Parent(s): ea1fc75

Let process start before yielding preloaded prefetch buffer, avoid needlessly losing buffer in edge cases (#75)

Browse files
bytelatent/data/iterators/multiprocess_iterator.py CHANGED
@@ -190,7 +190,11 @@ class MultiprocessIterator(StatefulIterator):
190
  logging.info(
191
  "Main thread: Emptying the batch_queue until batch.is_final=True is found."
192
  )
193
- self.prefetch_buffer = []
 
 
 
 
194
  final_batch_received = False
195
  while True:
196
  try:
@@ -261,12 +265,14 @@ class MultiprocessIterator(StatefulIterator):
261
  "Attempted to get approximate state, but queue was erroniously empty."
262
  )
263
  self.received_approximate_state_event.set()
 
 
 
 
264
  return MultiprocessIteratorState(
265
  base_iterator_state=base_iterator_state,
266
  n_batches_to_prefetch=self.n_batches_to_prefetch,
267
- serialized_prefetch_buffer=json.dumps(
268
- [b.to_python_dict() for b in self.prefetch_buffer]
269
- ),
270
  persist_type=self.persist_type,
271
  )
272
 
@@ -281,9 +287,12 @@ class MultiprocessIterator(StatefulIterator):
281
  "State will be invalid if shutdown was forced before state persisted."
282
  )
283
  if self.producer is None:
284
- serialized_prefetch_buffer = json.dumps(
285
- [b.to_python_dict() for b in self.prefetch_buffer]
286
- )
 
 
 
287
  return MultiprocessIteratorState(
288
  base_iterator_state=self.base_iterator.get_state(),
289
  n_batches_to_prefetch=self.n_batches_to_prefetch,
@@ -304,12 +313,6 @@ class MultiprocessIterator(StatefulIterator):
304
  "Iterator may be invalid if shutdown was forced before state persisted."
305
  )
306
  logging.info("Main thread: Creating MP iterator")
307
- # First yield from the stored prefetch buffer.
308
- if self.prefetch_buffer is not None:
309
- while len(self.prefetch_buffer) > 0:
310
- item = self.prefetch_buffer.pop(0)
311
- yield item
312
- self.prefetch_buffer = None
313
 
314
  assert (
315
  self.producer is None
@@ -349,6 +352,13 @@ class MultiprocessIterator(StatefulIterator):
349
  logger.info("Async dataloader started")
350
  self.producer.start()
351
 
 
 
 
 
 
 
 
352
  while True:
353
  if self.producer.exitcode is not None:
354
  raise RuntimeError(
 
190
  logging.info(
191
  "Main thread: Emptying the batch_queue until batch.is_final=True is found."
192
  )
193
+ if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0:
194
+ buffer = self.prefetch_buffer
195
+ else:
196
+ buffer = []
197
+ self.prefetch_buffer = buffer
198
  final_batch_received = False
199
  while True:
200
  try:
 
265
  "Attempted to get approximate state, but queue was erroniously empty."
266
  )
267
  self.received_approximate_state_event.set()
268
+ if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0:
269
+ buffer = [b.to_python_dict() for b in self.prefetch_buffer]
270
+ else:
271
+ buffer = []
272
  return MultiprocessIteratorState(
273
  base_iterator_state=base_iterator_state,
274
  n_batches_to_prefetch=self.n_batches_to_prefetch,
275
+ serialized_prefetch_buffer=json.dumps(buffer),
 
 
276
  persist_type=self.persist_type,
277
  )
278
 
 
287
  "State will be invalid if shutdown was forced before state persisted."
288
  )
289
  if self.producer is None:
290
+ if self.prefetch_buffer is not None and len(self.prefetch_buffer) > 0:
291
+ serialized_prefetch_buffer = json.dumps(
292
+ [b.to_python_dict() for b in self.prefetch_buffer]
293
+ )
294
+ else:
295
+ serialized_prefetch_buffer = json.dumps([])
296
  return MultiprocessIteratorState(
297
  base_iterator_state=self.base_iterator.get_state(),
298
  n_batches_to_prefetch=self.n_batches_to_prefetch,
 
313
  "Iterator may be invalid if shutdown was forced before state persisted."
314
  )
315
  logging.info("Main thread: Creating MP iterator")
 
 
 
 
 
 
316
 
317
  assert (
318
  self.producer is None
 
352
  logger.info("Async dataloader started")
353
  self.producer.start()
354
 
355
+ # First yield from the stored prefetch buffer.
356
+ if self.prefetch_buffer is not None:
357
+ while len(self.prefetch_buffer) > 0:
358
+ item = self.prefetch_buffer.pop(0)
359
+ yield item
360
+ self.prefetch_buffer = None
361
+
362
  while True:
363
  if self.producer.exitcode is not None:
364
  raise RuntimeError(