thecollabagepatch committed on
Commit
184daaa
·
1 Parent(s): a36f465

fixing jam seamlessness

Browse files
Files changed (1) hide show
  1. jam_worker.py +95 -59
jam_worker.py CHANGED
@@ -3,6 +3,7 @@ import threading, time, base64, io, uuid
3
  from dataclasses import dataclass, field
4
  import numpy as np
5
  import soundfile as sf
 
6
 
7
  from utils import (
8
  match_loudness_to_reference, stitch_generated, hard_trim_seconds,
@@ -155,77 +156,112 @@ class JamWorker(threading.Thread):
155
  }
156
  return b64, meta
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  def run(self):
159
- """Main worker loop - generate chunks continuously but don't get too far ahead"""
 
160
  spb = self._seconds_per_bar()
161
- chunk_secs = self.params.bars_per_chunk * spb
 
162
  xfade = self.mrt.config.crossfade_length
163
 
164
- print("πŸš€ JamWorker started with flow control...")
165
-
 
 
 
 
166
  while not self._stop_event.is_set():
167
- # Check if we should generate the next chunk
168
- if not self._should_generate_next_chunk():
169
- # We're ahead enough, wait a bit for frontend to catch up
170
- print(f"⏸️ Buffer full, waiting for consumption...")
171
- time.sleep(0.5)
172
- continue
173
-
174
- # Generate the next chunk
175
  with self._lock:
 
 
 
176
  style_vec = self.params.style_vec
177
  self.mrt.guidance_weight = self.params.guidance_weight
178
- self.mrt.temperature = self.params.temperature
179
- self.mrt.topk = self.params.topk
180
- next_idx = self.idx + 1
181
 
182
- print(f"🎹 Generating chunk {next_idx}...")
183
-
184
- # Generate enough model chunks to cover chunk_secs
185
- need = chunk_secs
186
- chunks = []
187
  self.last_chunk_started_at = time.time()
188
-
189
- while need > 0 and not self._stop_event.is_set():
190
- wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
191
- chunks.append(wav)
192
- need -= (wav.samples.shape[0] / float(self.mrt.sample_rate))
193
-
194
- if self._stop_event.is_set():
195
- break
196
-
197
- # Stitch and trim to exact seconds at model SR
198
- y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
199
- y = hard_trim_seconds(y, chunk_secs)
200
-
201
- # Post-process
202
- if next_idx == 1 and self.params.ref_loop is not None:
203
- y, _ = match_loudness_to_reference(
204
- self.params.ref_loop, y,
205
- method=self.params.loudness_mode,
206
- headroom_db=self.params.headroom_db
 
 
 
 
 
 
 
 
 
 
 
207
  )
208
- else:
209
- apply_micro_fades(y, 3)
210
 
211
- # Resample + snap + b64
212
- b64, meta = self._snap_and_encode(
213
- y, seconds=chunk_secs,
214
- target_sr=self.params.target_sr,
215
- bars=self.params.bars_per_chunk
216
- )
217
 
218
- # Store the completed chunk
219
- with self._lock:
220
- self.idx = next_idx
221
- self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
222
-
223
- # Keep outbox bounded (remove old chunks)
224
- if len(self.outbox) > 10:
225
- # Remove chunks that are way behind the delivery point
226
- self.outbox = [ch for ch in self.outbox if ch.index > self._last_delivered_index - 5]
227
 
228
- self.last_chunk_completed_at = time.time()
229
- print(f"βœ… Completed chunk {next_idx}")
 
 
 
230
 
231
- print("πŸ›‘ JamWorker stopped")
 
3
  from dataclasses import dataclass, field
4
  import numpy as np
5
  import soundfile as sf
6
+ from magenta_rt import audio as au
7
 
8
  from utils import (
9
  match_loudness_to_reference, stitch_generated, hard_trim_seconds,
 
156
  }
157
  return b64, meta
158
 
159
+ def _append_model_chunk_to_stream(self, wav):
160
+ """Incrementally append a model chunk with equal-power crossfade."""
161
+ xfade_s = float(self.mrt.config.crossfade_length)
162
+ sr = int(self.mrt.sample_rate)
163
+ xfade_n = int(round(xfade_s * sr))
164
+
165
+ s = wav.samples if wav.samples.ndim == 2 else wav.samples[:, None]
166
+
167
+ if getattr(self, "_stream", None) is None:
168
+ # First chunk: drop model pre-roll (xfade head)
169
+ if s.shape[0] > xfade_n:
170
+ self._stream = s[xfade_n:].astype(np.float32, copy=True)
171
+ else:
172
+ self._stream = np.zeros((0, s.shape[1]), dtype=np.float32)
173
+ self._next_emit_start = 0 # pointer into _stream (model SR samples)
174
+ return
175
+
176
+ # Crossfade last xfade_n samples of _stream with head of new s
177
+ if s.shape[0] <= xfade_n or self._stream.shape[0] < xfade_n:
178
+ # Degenerate safeguard
179
+ self._stream = np.concatenate([self._stream, s], axis=0)
180
+ return
181
+
182
+ tail = self._stream[-xfade_n:]
183
+ head = s[:xfade_n]
184
+
185
+ # Equal-power envelopes
186
+ t = np.linspace(0, np.pi/2, xfade_n, endpoint=False, dtype=np.float32)[:, None]
187
+ eq_in, eq_out = np.sin(t), np.cos(t)
188
+ mixed = tail * eq_out + head * eq_in
189
+
190
+ self._stream = np.concatenate([self._stream[:-xfade_n], mixed, s[xfade_n:]], axis=0)
191
+
192
  def run(self):
193
+ """Continuous stream + sliding 8-bar window emitter."""
194
+ sr_model = int(self.mrt.sample_rate)
195
  spb = self._seconds_per_bar()
196
+ chunk_secs = float(self.params.bars_per_chunk) * spb
197
+ chunk_n_model = int(round(chunk_secs * sr_model))
198
  xfade = self.mrt.config.crossfade_length
199
 
200
+ # Streaming state
201
+ self._stream = None # np.ndarray [S, C] at model SR
202
+ self._next_emit_start = 0 # sample pointer for next 8-bar cut
203
+
204
+ print("πŸš€ JamWorker (streaming) started...")
205
+
206
  while not self._stop_event.is_set():
207
+ # Flow control: don't get too far ahead of the consumer
 
 
 
 
 
 
 
208
  with self._lock:
209
+ if self.idx > self._last_delivered_index + self._max_buffer_ahead:
210
+ time.sleep(0.25)
211
+ continue
212
  style_vec = self.params.style_vec
213
  self.mrt.guidance_weight = self.params.guidance_weight
214
+ self.mrt.temperature = self.params.temperature
215
+ self.mrt.topk = self.params.topk
 
216
 
217
+ # Generate ONE model chunk and append to the continuous stream
 
 
 
 
218
  self.last_chunk_started_at = time.time()
219
+ wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
220
+ self._append_model_chunk_to_stream(wav)
221
+ self.last_chunk_completed_at = time.time()
222
+
223
+ # While we have at least one full 8-bar window available, emit it
224
+ while (getattr(self, "_stream", None) is not None and
225
+ self._stream.shape[0] - self._next_emit_start >= chunk_n_model and
226
+ not self._stop_event.is_set()):
227
+
228
+ seg = self._stream[self._next_emit_start:self._next_emit_start + chunk_n_model]
229
+
230
+ # Wrap as Waveform at model SR
231
+ y = au.Waveform(seg.astype(np.float32, copy=False), sr_model).as_stereo()
232
+
233
+ # Post-processing:
234
+ # - First emitted chunk: loudness-match to ref_loop
235
+ # - No micro-fades on mid-stream windows (they cause dips)
236
+ next_idx = self.idx + 1
237
+ if next_idx == 1 and self.params.ref_loop is not None:
238
+ y, _ = match_loudness_to_reference(
239
+ self.params.ref_loop, y,
240
+ method=self.params.loudness_mode,
241
+ headroom_db=self.params.headroom_db
242
+ )
243
+
244
+ # Resample + snap + encode exactly chunk_secs long
245
+ b64, meta = self._snap_and_encode(
246
+ y, seconds=chunk_secs,
247
+ target_sr=self.params.target_sr,
248
+ bars=self.params.bars_per_chunk
249
  )
 
 
250
 
251
+ with self._lock:
252
+ self.idx = next_idx
253
+ self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
254
+ # Bound the outbox
255
+ if len(self.outbox) > 10:
256
+ self.outbox = [ch for ch in self.outbox if ch.index > self._last_delivered_index - 5]
257
 
258
+ # Advance window pointer to the next 8-bar slot
259
+ self._next_emit_start += chunk_n_model
 
 
 
 
 
 
 
260
 
261
+ # Trim old samples to keep memory bounded (keep a little guard)
262
+ keep_from = max(0, self._next_emit_start - chunk_n_model) # keep 1 extra window
263
+ if keep_from > 0:
264
+ self._stream = self._stream[keep_from:]
265
+ self._next_emit_start -= keep_from
266
 
267
+ print("πŸ›‘ JamWorker (streaming) stopped")