thecollabagepatch commited on
Commit
bf8ae4c
·
1 Parent(s): eb5d99b

reseed splice context fix

Browse files
Files changed (1) hide show
  1. jam_worker.py +62 -9
jam_worker.py CHANGED
@@ -271,35 +271,88 @@ class JamWorker(threading.Thread):
271
  depth = original_tokens.shape[1]
272
  frames_per_bar = self._frames_per_bar()
273
 
274
- # 1) Anchor tail
275
- # Use floor, not round, to avoid grabbing an extra bar.
276
  anchor = self._bar_aligned_tail(original_tokens, math.floor(anchor_bars))
277
 
278
- # 2) Fill remainder with recent (in whole bars when possible)
279
  a = anchor.shape[0]
280
  remain = max(ctx_frames - a, 0)
 
 
 
281
  if remain > 0:
282
  bars_fit = remain // frames_per_bar
283
  if bars_fit >= 1:
284
  want_recent_frames = int(bars_fit * frames_per_bar)
285
- recent = recent_tokens[-want_recent_frames:] if recent_tokens.shape[0] > want_recent_frames else recent_tokens
 
286
  else:
287
- recent = recent_tokens[-remain:] if recent_tokens.shape[0] > remain else recent_tokens
 
 
 
 
 
288
  else:
289
- recent = recent_tokens[:0]
 
290
 
291
- out = np.concatenate([anchor, recent], axis=0) if (anchor.size or recent.size) else recent_tokens[-ctx_frames:]
292
  if out.shape[0] > ctx_frames:
293
  out = out[-ctx_frames:]
294
 
295
- # --- NEW: force total length to a whole number of bars
296
- max_bar_aligned = (out.shape[0] // frames_per_bar) * frames_per_bar
 
 
 
297
  if max_bar_aligned > 0 and out.shape[0] != max_bar_aligned:
298
  out = out[-max_bar_aligned:]
299
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
  if out.shape[1] != depth:
301
  out = out[:, :depth]
302
  return out
 
303
 
304
  def _realign_emit_pointer_to_bar(self, sr_model: int):
305
  """Advance _next_emit_start to the next bar boundary in model-sample space."""
 
271
  depth = original_tokens.shape[1]
272
  frames_per_bar = self._frames_per_bar()
273
 
274
+ # 1) Anchor tail (whole bars)
 
275
  anchor = self._bar_aligned_tail(original_tokens, math.floor(anchor_bars))
276
 
277
+ # 2) Fill remainder with recent (prefer whole bars)
278
  a = anchor.shape[0]
279
  remain = max(ctx_frames - a, 0)
280
+
281
+ recent = recent_tokens[:0]
282
+ used_recent = 0 # frames taken from the END of recent_tokens
283
  if remain > 0:
284
  bars_fit = remain // frames_per_bar
285
  if bars_fit >= 1:
286
  want_recent_frames = int(bars_fit * frames_per_bar)
287
+ used_recent = min(want_recent_frames, recent_tokens.shape[0])
288
+ recent = recent_tokens[-used_recent:] if used_recent > 0 else recent_tokens[:0]
289
  else:
290
+ used_recent = min(remain, recent_tokens.shape[0])
291
+ recent = recent_tokens[-used_recent:] if used_recent > 0 else recent_tokens[:0]
292
+
293
+ # 3) Concat in order [anchor, recent]
294
+ if anchor.size or recent.size:
295
+ out = np.concatenate([anchor, recent], axis=0)
296
  else:
297
+ # fallback: just take the last ctx window from recent
298
+ out = recent_tokens[-ctx_frames:]
299
 
300
+ # 4) Trim if we overshot
301
  if out.shape[0] > ctx_frames:
302
  out = out[-ctx_frames:]
303
 
304
+ # 5) Snap the **END** to the nearest LOWER bar boundary
305
+ if frames_per_bar > 0:
306
+ max_bar_aligned = (out.shape[0] // frames_per_bar) * frames_per_bar
307
+ else:
308
+ max_bar_aligned = out.shape[0]
309
  if max_bar_aligned > 0 and out.shape[0] != max_bar_aligned:
310
  out = out[-max_bar_aligned:]
311
 
312
+ # 6) Left-fill to reach ctx_frames **without moving the END**
313
+ deficit = ctx_frames - out.shape[0]
314
+ if deficit > 0:
315
+ left_parts = []
316
+
317
+ # Prefer frames immediately BEFORE the region we used from 'recent_tokens'
318
+ if used_recent < recent_tokens.shape[0]:
319
+ take = min(deficit, recent_tokens.shape[0] - used_recent)
320
+ if used_recent > 0:
321
+ left_parts.append(recent_tokens[-(used_recent + take) : -used_recent])
322
+ else:
323
+ left_parts.append(recent_tokens[-take:])
324
+
325
+ # Then take frames immediately BEFORE the 'anchor' in original_tokens
326
+ if sum(p.shape[0] for p in left_parts) < deficit and anchor.shape[0] > 0:
327
+ need = deficit - sum(p.shape[0] for p in left_parts)
328
+ a_len = anchor.shape[0]
329
+ avail = max(original_tokens.shape[0] - a_len, 0)
330
+ take2 = min(need, avail)
331
+ if take2 > 0:
332
+ left_parts.append(original_tokens[-(a_len + take2) : -a_len])
333
+
334
+ # Still short? tile from what's available
335
+ have = sum(p.shape[0] for p in left_parts)
336
+ if have < deficit:
337
+ base = out if out.shape[0] > 0 else (recent_tokens if recent_tokens.shape[0] > 0 else original_tokens)
338
+ reps = int(np.ceil((deficit - have) / max(1, base.shape[0])))
339
+ left_parts.append(np.tile(base, (reps, 1))[: (deficit - have)])
340
+
341
+ left = np.concatenate(left_parts, axis=0)
342
+ out = np.concatenate([left[-deficit:], out], axis=0)
343
+
344
+ # 7) Final guard to exact length
345
+ if out.shape[0] > ctx_frames:
346
+ out = out[-ctx_frames:]
347
+ elif out.shape[0] < ctx_frames:
348
+ reps = int(np.ceil(ctx_frames / max(1, out.shape[0])))
349
+ out = np.tile(out, (reps, 1))[-ctx_frames:]
350
+
351
+ # 8) Depth guard
352
  if out.shape[1] != depth:
353
  out = out[:, :depth]
354
  return out
355
+
356
 
357
  def _realign_emit_pointer_to_bar(self, sr_model: int):
358
  """Advance _next_emit_start to the next bar boundary in model-sample space."""