henok3878 commited on
Commit
f5121bb
·
1 Parent(s): 0f4a5c1

support only single batch for now

Browse files
Files changed (1) hide show
  1. main.py +3 -10
main.py CHANGED
@@ -212,20 +212,13 @@ def generate_strokes(
212
  prime=primingData
213
  )
214
 
215
- if len(stroke_tensors) == 1 and stroke_tensors[0].dim() == 2:
 
216
  all_strokes_tensor = stroke_tensors[0]
217
  stroke_offsets = all_strokes_tensor.cpu().numpy().tolist()
218
  else:
219
  stroke_offsets = []
220
- for stroke_tensor in stroke_tensors:
221
- if stroke_tensor.dim() == 2:
222
- stroke_data = stroke_tensor.squeeze(0).cpu().numpy().tolist()
223
- else:
224
- stroke_data = stroke_tensor.cpu().numpy().tolist()
225
-
226
- if len(stroke_data) == 3:
227
- stroke_offsets.append(stroke_data)
228
-
229
  return stroke_offsets
230
 
231
  except Exception as e:
 
212
  prime=primingData
213
  )
214
 
215
+ # batch_size is 1
216
+ if len(stroke_tensors) == 1:
217
  all_strokes_tensor = stroke_tensors[0]
218
  stroke_offsets = all_strokes_tensor.cpu().numpy().tolist()
219
  else:
220
  stroke_offsets = []
221
+ logger.warning(f"Expected single batch, but got {len(stroke_tensors)}")
 
 
 
 
 
 
 
 
222
  return stroke_offsets
223
 
224
  except Exception as e: