thecollabagepatch committed
Commit c4dc2c2 · Parent(s): 1b98b73

fixing flam

Files changed (1)
jam_worker.py +77 -52
jam_worker.py CHANGED
@@ -355,49 +355,74 @@ class JamWorker(threading.Thread):
         chunk_secs = self.params.bars_per_chunk * spb
         xfade = float(self.mrt.config.crossfade_length)  # seconds

-        # local fallback stitcher that *keeps* the first head if utils.stitch_generated
-        # doesn't yet support drop_first_pre_roll
-        def _stitch_keep_head(chunks, sr: int, xfade_s: float):
-            from magenta_rt import audio as au
-            import numpy as _np
-            if not chunks:
-                raise ValueError("no chunks to stitch")
-            xfade_n = int(round(max(0.0, xfade_s) * sr))
-            # Fast-path: no crossfade
-            if xfade_n <= 0:
-                out = _np.concatenate([c.samples for c in chunks], axis=0)
-                return au.Waveform(out, sr)
-            # build equal-power curves
-            t = _np.linspace(0, _np.pi / 2, xfade_n, endpoint=False, dtype=_np.float32)
-            eq_in, eq_out = _np.sin(t)[:, None], _np.cos(t)[:, None]
-
-            first = chunks[0].samples
-            if first.shape[0] < xfade_n:
-                raise ValueError("chunk shorter than crossfade prefix")
-            out = first.copy()  # 👈 keep the head for live seam
-
-            for i in range(1, len(chunks)):
-                cur = chunks[i].samples
-                if cur.shape[0] < xfade_n:
-                    # too short to crossfade; just butt-join
-                    out = _np.concatenate([out, cur], axis=0)
-                    continue
-                head, tail = cur[:xfade_n], cur[xfade_n:]
-                mixed = out[-xfade_n:] * eq_out + head * eq_in
-                out = _np.concatenate([out[:-xfade_n], mixed, tail], axis=0)
-            return au.Waveform(out, sr)
+        # ---- tiny helper: mono + simple envelope ----
+        def _mono_env(x: np.ndarray, sr: int, win_ms: float = 20.0) -> np.ndarray:
+            if x.ndim == 2:
+                x = x.mean(axis=1)
+            x = np.abs(x).astype(np.float32)
+            w = max(1, int(round(win_ms * 1e-3 * sr)))
+            if w == 1:
+                return x
+            kern = np.ones(w, dtype=np.float32) / float(w)
+            # moving average (same length)
+            return np.convolve(x, kern, mode="same")
+
+        # ---- estimate how late the first downbeat is (<= max_ms) ----
+        def _estimate_first_offset_samples(ref_loop_wav, gen_wav, sr: int, max_ms: int = 120) -> int:
+            try:
+                # resample ref to model SR if needed
+                ref = ref_loop_wav
+                if ref.sample_rate != sr:
+                    ref = ref.resample(sr)
+                # last 1 bar of the reference (what the model just "heard")
+                n_bar = int(round(spb * sr))
+                ref_tail = ref.samples[-n_bar:, :] if ref.samples.shape[0] >= n_bar else ref.samples
+                # first 2 bars of the generated chunk (search window)
+                gen_head = gen_wav.samples[: int(2 * n_bar), :]
+                if ref_tail.size == 0 or gen_head.size == 0:
+                    return 0
+
+                # envelopes
+                e_ref = _mono_env(ref_tail, sr)
+                e_gen = _mono_env(gen_head, sr)
+
+                max_lag = int(round((max_ms / 1000.0) * sr))
+                # ensure the window is long enough
+                seg = min(len(e_ref), len(e_gen))
+                e_ref = e_ref[-seg:]
+                e_gen = e_gen[: seg + max_lag]  # allow positive lag (gen late)
+
+                if len(e_gen) < seg:
+                    return 0
+
+                # brute-force short-range correlation (gen late => positive lag)
+                best_lag = 0
+                best_score = -1e9
+                for lag in range(0, max_lag + 1):
+                    a = e_ref
+                    b = e_gen[lag : lag + seg]
+                    if len(b) != seg:
+                        break
+                    # normalized dot to be robust-ish
+                    denom = (np.linalg.norm(a) * np.linalg.norm(b)) or 1.0
+                    score = float(np.dot(a, b) / denom)
+                    if score > best_score:
+                        best_score = score
+                        best_lag = lag
+                return int(best_lag)
+            except Exception:
+                return 0

         print("🚀 JamWorker started with flow control...")

         while not self._stop_event.is_set():
             # Don't get too far ahead of the consumer
             if not self._should_generate_next_chunk():
-                # We're ahead enough, wait a bit for frontend to catch up
-                # (kept short so stop() stays responsive)
+                print("⏸️ Buffer full, waiting for consumption...")
                 time.sleep(0.5)
                 continue

-            # Snapshot knobs + compute index atomically
+            # Snapshot knobs + compute index
             with self._lock:
                 style_vec = self.params.style_vec
                 self.mrt.guidance_weight = float(self.params.guidance_weight)
@@ -409,12 +434,10 @@ class JamWorker(threading.Thread):
             self.last_chunk_started_at = time.time()

             # ---- Generate enough model sub-chunks to yield *audible* chunk_secs ----
-            # Count the first chunk at full length L, and each subsequent at (L - xfade)
+            # First sub-chunk contributes full L; subsequent contribute (L - xfade)
             assembled = 0.0
             chunks = []
-
             while assembled < chunk_secs and not self._stop_event.is_set():
-                # generate_chunk returns (au.Waveform, new_state)
                 wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
                 chunks.append(wav)
                 L = wav.samples.shape[0] / float(self.mrt.sample_rate)
@@ -423,27 +446,30 @@ class JamWorker(threading.Thread):
             if self._stop_event.is_set():
                 break

-            # ---- Stitch and trim at model SR (keep first head for seamless handoff) ----
-            try:
-                # Preferred path if you've added the new param in utils.stitch_generated
-                y = stitch_generated(chunks, self.mrt.sample_rate, xfade, drop_first_pre_roll=False).as_stereo()
-            except TypeError:
-                # Backward-compatible: local stitcher that keeps the head
-                y = _stitch_keep_head(chunks, int(self.mrt.sample_rate), xfade).as_stereo()
-
-            # Hard trim to the exact musical duration (still at model SR)
+            # ---- Stitch (utils drops the very first model pre-roll) & trim at model SR ----
+            y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
             y = hard_trim_seconds(y, chunk_secs)

+            # ---- ONE-TIME: grid-align the very first jam chunk to kill the flam ----
+            if next_idx == 1 and self.params.combined_loop is not None:
+                offset = _estimate_first_offset_samples(
+                    self.params.combined_loop, y, int(self.mrt.sample_rate), max_ms=120
+                )
+                if offset > 0:
+                    # Trim the head by the detected offset; we'll snap length later
+                    y.samples = y.samples[offset:, :]
+                    print(f"🎯 First-chunk offset compensation: -{offset/self.mrt.sample_rate:.3f}s")
+                    # hard trim again (defensive), remaining length exactness happens in _snap_and_encode
+                    y = hard_trim_seconds(y, chunk_secs)
+
             # ---- Post-processing ----
             if next_idx == 1 and self.params.ref_loop is not None:
-                # match loudness to the provided reference on the very first audible chunk
                 y, _ = match_loudness_to_reference(
                     self.params.ref_loop, y,
                     method=self.params.loudness_mode,
                     headroom_db=self.params.headroom_db
                 )
             else:
-                # light micro-fades to guard against clicks
                 apply_micro_fades(y, 3)

             # ---- Resample + bar-snap + encode ----
@@ -453,14 +479,12 @@ class JamWorker(threading.Thread):
                 target_sr=self.params.target_sr,
                 bars=self.params.bars_per_chunk
             )
-            # small hint for the client if you want UI butter between chunks
-            meta["xfade_seconds"] = xfade
+            meta["xfade_seconds"] = xfade  # tiny hint for client if you want butter at chunk joins

-            # ---- Publish the completed chunk ----
+            # ---- Publish ----
             with self._lock:
                 self.idx = next_idx
                 self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
-                # Keep outbox bounded (trim far-behind entries)
                 if len(self.outbox) > 10:
                     cutoff = self._last_delivered_index - 5
                     self.outbox = [ch for ch in self.outbox if ch.index > cutoff]
@@ -469,3 +493,4 @@ class JamWorker(threading.Thread):
             print(f"✅ Completed chunk {next_idx}")

         print("🛑 JamWorker stopped")
+
 
 
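The flam fix itself rests on the envelope-correlation helper added above: rectify and smooth both signals with a roughly 20 ms moving average, slide the head of the generated chunk against the tail of the combined loop over a small positive lag range (up to 120 ms), and keep the lag with the best normalized match. The sketch below is a rough standalone sanity check of that idea on synthetic audio; moving_avg_env, estimate_lag, and the pulse-train input are illustrative stand-ins, not code or data from jam_worker.py.

# Standalone sanity check of the envelope-correlation lag estimate:
# a "generated" head that starts true_lag samples late should be detected.
import numpy as np

def moving_avg_env(x: np.ndarray, win: int) -> np.ndarray:
    # rectified signal smoothed with a boxcar, same length as the input
    x = np.abs(x).astype(np.float32)
    return np.convolve(x, np.ones(win, dtype=np.float32) / win, mode="same")

def estimate_lag(ref_tail: np.ndarray, gen_head: np.ndarray, max_lag: int, win: int = 960) -> int:
    # brute-force search over 0..max_lag for the best normalized envelope match
    e_ref = moving_avg_env(ref_tail, win)
    e_gen = moving_avg_env(gen_head, win)
    seg = len(e_ref)
    best_lag, best_score = 0, -np.inf
    for lag in range(max_lag + 1):
        b = e_gen[lag:lag + seg]
        if len(b) != seg:
            break
        denom = (np.linalg.norm(e_ref) * np.linalg.norm(b)) or 1.0
        score = float(np.dot(e_ref, b) / denom)
        if score > best_score:
            best_score, best_lag = score, lag
    return best_lag

if __name__ == "__main__":
    sr, true_lag = 48_000, 1000                    # downbeat arrives ~21 ms late
    t = np.arange(sr) / sr
    beat = (np.sin(2 * np.pi * 2.0 * t) > 0.99).astype(np.float32)   # sparse pulse train
    ref_tail = beat                                # stands in for the last bar of the loop
    gen_head = np.concatenate([np.zeros(true_lag, dtype=np.float32), beat,
                               np.zeros(sr, dtype=np.float32)])
    print(estimate_lag(ref_tail, gen_head, max_lag=int(0.120 * sr)))  # expect ~1000

In jam_worker.py the equivalent search runs once, on the first jam chunk only, and the detected lag is simply trimmed off the front of y before the usual hard trim and bar-snap.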