thecollabagepatch commited on
Commit
a8d3e47
·
1 Parent(s): 53cce5a

the next 8 bars not the last 8 bars

Browse files
Files changed (2) hide show
  1. app.py +45 -24
  2. jam_worker.py +90 -47
app.py CHANGED
@@ -323,25 +323,46 @@ def jam_start(
323
  return {"session_id": sid}
324
 
325
  @app.get("/jam/next")
326
- def jam_next(session_id: str, since: int = 0):
 
 
 
 
327
  with jam_lock:
328
  worker = jam_registry.get(session_id)
329
  if worker is None or not worker.is_alive():
330
  raise HTTPException(status_code=404, detail="Session not found")
331
 
332
- # drain outbox entries with index > since
333
- items = []
334
- with worker._lock:
335
- for ch in worker.outbox:
336
- if ch.index > since:
337
- items.append({"index": ch.index, "audio_base64": ch.audio_base64, "metadata": ch.metadata})
338
- # optional: truncate old items to keep memory bounded
339
- if len(worker.outbox) > 32:
340
- worker.outbox = worker.outbox[-16:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
 
342
- if not items:
343
- return Response(status_code=204) # nothing yet
344
- return {"chunks": items}
345
 
346
  @app.post("/jam/stop")
347
  def jam_stop(session_id: str = Body(..., embed=True)):
@@ -384,26 +405,26 @@ def jam_status(session_id: str):
384
 
385
  # Snapshot safely
386
  with worker._lock:
387
- last_index = int(worker.idx)
 
388
  queued = len(worker.outbox)
 
389
  p = worker.params
390
  spb = p.beats_per_bar * (60.0 / p.bpm)
391
  chunk_secs = p.bars_per_chunk * spb
392
- target_sr = p.target_sr
393
- bars_per_chunk = p.bars_per_chunk
394
- beats_per_bar = p.beats_per_bar
395
- bpm = p.bpm
396
 
397
  return {
398
  "running": running,
399
- "last_index": last_index, # last finished chunk index (0 if none yet)
400
- "queued_chunks": queued, # how many not-yet-fetched chunks are in the outbox
401
- "bpm": bpm,
402
- "beats_per_bar": beats_per_bar,
403
- "bars_per_chunk": bars_per_chunk,
 
 
404
  "seconds_per_bar": spb,
405
  "chunk_duration_seconds": chunk_secs,
406
- "target_sample_rate": target_sr,
407
  "last_chunk_started_at": worker.last_chunk_started_at,
408
  "last_chunk_completed_at": worker.last_chunk_completed_at,
409
  }
 
323
  return {"session_id": sid}
324
 
325
  @app.get("/jam/next")
326
+ def jam_next(session_id: str):
327
+ """
328
+ Get the next sequential chunk in the jam session.
329
+ This ensures chunks are delivered in order without gaps.
330
+ """
331
  with jam_lock:
332
  worker = jam_registry.get(session_id)
333
  if worker is None or not worker.is_alive():
334
  raise HTTPException(status_code=404, detail="Session not found")
335
 
336
+ # Get the next sequential chunk (this blocks until ready)
337
+ chunk = worker.get_next_chunk()
338
+
339
+ if chunk is None:
340
+ raise HTTPException(status_code=408, detail="Chunk not ready within timeout")
341
+
342
+ return {
343
+ "chunk": {
344
+ "index": chunk.index,
345
+ "audio_base64": chunk.audio_base64,
346
+ "metadata": chunk.metadata
347
+ }
348
+ }
349
+
350
@app.post("/jam/consume")
def jam_consume(session_id: str = Form(...), chunk_index: int = Form(...)):
    """
    Mark a chunk as consumed by the frontend.
    This helps the worker manage its buffer and generation flow.
    """
    # Resolve the session under the registry lock; unknown or dead sessions are a 404.
    with jam_lock:
        session_worker = jam_registry.get(session_id)
        session_gone = session_worker is None or not session_worker.is_alive()
        if session_gone:
            raise HTTPException(status_code=404, detail="Session not found")

    # Forward the acknowledgement to the worker, then echo the index back.
    session_worker.mark_chunk_consumed(chunk_index)
    acknowledgement = {"consumed": chunk_index}
    return acknowledgement
364
+
365
 
 
 
 
366
 
367
  @app.post("/jam/stop")
368
  def jam_stop(session_id: str = Body(..., embed=True)):
 
405
 
406
  # Snapshot safely
407
  with worker._lock:
408
+ last_generated = int(worker.idx)
409
+ last_delivered = int(worker._last_delivered_index)
410
  queued = len(worker.outbox)
411
+ buffer_ahead = last_generated - last_delivered
412
  p = worker.params
413
  spb = p.beats_per_bar * (60.0 / p.bpm)
414
  chunk_secs = p.bars_per_chunk * spb
 
 
 
 
415
 
416
  return {
417
  "running": running,
418
+ "last_generated_index": last_generated, # Last chunk that finished generating
419
+ "last_delivered_index": last_delivered, # Last chunk sent to frontend
420
+ "buffer_ahead": buffer_ahead, # How many chunks ahead we are
421
+ "queued_chunks": queued, # Total chunks in outbox
422
+ "bpm": p.bpm,
423
+ "beats_per_bar": p.beats_per_bar,
424
+ "bars_per_chunk": p.bars_per_chunk,
425
  "seconds_per_bar": spb,
426
  "chunk_duration_seconds": chunk_secs,
427
+ "target_sample_rate": p.target_sr,
428
  "last_chunk_started_at": worker.last_chunk_started_at,
429
  "last_chunk_completed_at": worker.last_chunk_completed_at,
430
  }
jam_worker.py CHANGED
@@ -1,17 +1,14 @@
1
- # jam_worker.py
2
  import threading, time, base64, io, uuid
3
  from dataclasses import dataclass, field
4
  import numpy as np
5
  import soundfile as sf
6
 
7
- # Pull in your helpers from app.py or refactor them into a shared utils module.
8
  from utils import (
9
  match_loudness_to_reference, stitch_generated, hard_trim_seconds,
10
  apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
11
  resample_and_snap, wav_bytes_base64
12
  )
13
- from scipy.signal import resample_poly
14
- from math import gcd
15
 
16
  @dataclass
17
  class JamParams:
@@ -22,8 +19,8 @@ class JamParams:
22
  loudness_mode: str = "auto"
23
  headroom_db: float = 1.0
24
  style_vec: np.ndarray | None = None
25
- ref_loop: any = None # au.Waveform at model SR for 1st-chunk loudness
26
- combined_loop: any = None # NEW: Full combined audio for context setup
27
  guidance_weight: float = 1.1
28
  temperature: float = 1.1
29
  topk: int = 40
@@ -39,15 +36,20 @@ class JamWorker(threading.Thread):
39
  super().__init__(daemon=True)
40
  self.mrt = mrt
41
  self.params = params
42
- # Initialize fresh state
43
  self.state = mrt.init_state()
44
 
45
- # CRITICAL: Set up fresh context from the new combined audio
46
  if params.combined_loop is not None:
47
  self._setup_context_from_combined_loop()
 
48
  self.idx = 0
49
  self.outbox: list[JamChunk] = []
50
  self._stop_event = threading.Event()
 
 
 
 
 
 
51
  self.last_chunk_started_at = None
52
  self.last_chunk_completed_at = None
53
  self._lock = threading.Lock()
@@ -55,14 +57,11 @@ class JamWorker(threading.Thread):
55
  def _setup_context_from_combined_loop(self):
56
  """Set up MRT context tokens from the combined loop audio"""
57
  try:
58
- # Import the utility functions (same as used in main generation)
59
  from utils import make_bar_aligned_context, take_bar_aligned_tail
60
 
61
- # Extract context from combined loop (same logic as generate_loop_continuation_with_mrt)
62
  codec_fps = float(self.mrt.codec.frame_rate)
63
  ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
64
 
65
- # Take tail portion for context (matches main generation)
66
  loop_for_context = take_bar_aligned_tail(
67
  self.params.combined_loop,
68
  self.params.bpm,
@@ -70,11 +69,9 @@ class JamWorker(threading.Thread):
70
  ctx_seconds
71
  )
72
 
73
- # Encode to tokens
74
  tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
75
  tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
76
 
77
- # Create bar-aligned context
78
  context_tokens = make_bar_aligned_context(
79
  tokens,
80
  bpm=self.params.bpm,
@@ -83,30 +80,58 @@ class JamWorker(threading.Thread):
83
  beats_per_bar=self.params.beats_per_bar
84
  )
85
 
86
- # Set context on state - this is the key fix!
87
  self.state.context_tokens = context_tokens
88
-
89
  print(f"✅ JamWorker: Set up fresh context from combined loop")
90
- print(f" Context shape: {context_tokens.shape if context_tokens is not None else None}")
91
 
92
  except Exception as e:
93
  print(f"❌ Failed to setup context from combined loop: {e}")
94
- # Continue without context rather than crashing
95
 
96
  def stop(self):
97
  self._stop_event.set()
98
 
99
- def update_style(self, new_style_vec: np.ndarray | None):
100
- with self._lock:
101
- if new_style_vec is not None:
102
- self.params.style_vec = new_style_vec
103
-
104
  def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
105
  with self._lock:
106
  if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight)
107
  if temperature is not None: self.params.temperature = float(temperature)
108
  if topk is not None: self.params.topk = int(topk)
109
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
  def _seconds_per_bar(self) -> float:
111
  return self.params.beats_per_bar * (60.0 / self.params.bpm)
112
 
@@ -131,58 +156,76 @@ class JamWorker(threading.Thread):
131
  return b64, meta
132
 
133
  def run(self):
      """Thread body: generate fixed-length audio chunks until stop() is called.

      Each pass snapshots the live knobs, asks the model for enough audio to
      cover `bars_per_chunk` bars, stitches/trims/post-processes it, encodes
      it, and appends a JamChunk with the next index to `outbox`.
      """
      spb = self._seconds_per_bar()
      chunk_secs = self.params.bars_per_chunk * spb
      xfade = self.mrt.config.crossfade_length

      # Prime: set initial context on state (caller should have done this; safe to re-set here)
      # NOTE: We assume caller passed a style_vec computed from tail/whole/blend.
      while not self._stop_event.is_set():
          # Honor live knob updates atomically: read params and push them onto
          # the model while holding the same lock update_knobs() writes under.
          with self._lock:
              style_vec = self.params.style_vec
              # Temporarily override MRT knobs (thread-local overrides)
              self.mrt.guidance_weight = self.params.guidance_weight
              self.mrt.temperature = self.params.temperature
              self.mrt.topk = self.params.topk

          # 1) Generate enough model chunks to cover chunk_secs; `need` counts
          #    down the seconds still missing.
          need = chunk_secs
          chunks = []
          self.last_chunk_started_at = time.time()
          while need > 0 and not self._stop_event.is_set():
              # The returned state is fed back into the next generate_chunk call.
              wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
              chunks.append(wav)
              # Model chunk length (seconds) at model SR
              need -= (wav.samples.shape[0] / float(self.mrt.sample_rate))

          # A stop request during generation discards the partial chunk.
          if self._stop_event.is_set():
              break

          # 2) Stitch and trim to exact seconds at model SR
          y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
          y = hard_trim_seconds(y, chunk_secs)

          # 3) Post-process: the very first chunk (idx still 0) is loudness-matched
          #    to the reference loop when one was supplied; every other chunk
          #    gets micro fades instead.
          if self.idx == 0 and self.params.ref_loop is not None:
              y, _ = match_loudness_to_reference(self.params.ref_loop, y,
                                                 method=self.params.loudness_mode,
                                                 headroom_db=self.params.headroom_db)
          else:
              apply_micro_fades(y, 3)

          # 4) Resample + snap + b64
          b64, meta = self._snap_and_encode(y, seconds=chunk_secs,
                                            target_sr=self.params.target_sr,
                                            bars=self.params.bars_per_chunk)

          # 5) Enqueue under the lock so readers see idx and outbox change together.
          with self._lock:
              self.idx += 1
              self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))

          self.last_chunk_completed_at = time.time()
185
-
186
- # optional: cleanup here if needed
187
-
188
-
 
 
 
 
 
1
+ # jam_worker.py - SIMPLE FIX VERSION
2
  import threading, time, base64, io, uuid
3
  from dataclasses import dataclass, field
4
  import numpy as np
5
  import soundfile as sf
6
 
 
7
  from utils import (
8
  match_loudness_to_reference, stitch_generated, hard_trim_seconds,
9
  apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
10
  resample_and_snap, wav_bytes_base64
11
  )
 
 
12
 
13
  @dataclass
14
  class JamParams:
 
19
  loudness_mode: str = "auto"
20
  headroom_db: float = 1.0
21
  style_vec: np.ndarray | None = None
22
+ ref_loop: any = None
23
+ combined_loop: any = None
24
  guidance_weight: float = 1.1
25
  temperature: float = 1.1
26
  topk: int = 40
 
36
  super().__init__(daemon=True)
37
  self.mrt = mrt
38
  self.params = params
 
39
  self.state = mrt.init_state()
40
 
 
41
  if params.combined_loop is not None:
42
  self._setup_context_from_combined_loop()
43
+
44
  self.idx = 0
45
  self.outbox: list[JamChunk] = []
46
  self._stop_event = threading.Event()
47
+
48
+ # NEW: Track delivery state
49
+ self._last_delivered_index = 0
50
+ self._max_buffer_ahead = 5 # Don't generate more than 5 chunks ahead
51
+
52
+ # Timing info
53
  self.last_chunk_started_at = None
54
  self.last_chunk_completed_at = None
55
  self._lock = threading.Lock()
 
57
  def _setup_context_from_combined_loop(self):
58
  """Set up MRT context tokens from the combined loop audio"""
59
  try:
 
60
  from utils import make_bar_aligned_context, take_bar_aligned_tail
61
 
 
62
  codec_fps = float(self.mrt.codec.frame_rate)
63
  ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
64
 
 
65
  loop_for_context = take_bar_aligned_tail(
66
  self.params.combined_loop,
67
  self.params.bpm,
 
69
  ctx_seconds
70
  )
71
 
 
72
  tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
73
  tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
74
 
 
75
  context_tokens = make_bar_aligned_context(
76
  tokens,
77
  bpm=self.params.bpm,
 
80
  beats_per_bar=self.params.beats_per_bar
81
  )
82
 
 
83
  self.state.context_tokens = context_tokens
 
84
  print(f"✅ JamWorker: Set up fresh context from combined loop")
 
85
 
86
  except Exception as e:
87
  print(f"❌ Failed to setup context from combined loop: {e}")
 
88
 
89
  def stop(self):
90
  self._stop_event.set()
91
 
 
 
 
 
 
92
  def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
93
  with self._lock:
94
  if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight)
95
  if temperature is not None: self.params.temperature = float(temperature)
96
  if topk is not None: self.params.topk = int(topk)
97
 
98
+ def get_next_chunk(self) -> JamChunk | None:
99
+ """Get the next sequential chunk (blocks/waits if not ready)"""
100
+ target_index = self._last_delivered_index + 1
101
+
102
+ # Wait for the target chunk to be ready (with timeout)
103
+ max_wait = 30.0 # seconds
104
+ start_time = time.time()
105
+
106
+ while time.time() - start_time < max_wait and not self._stop_event.is_set():
107
+ with self._lock:
108
+ # Look for the exact chunk we need
109
+ for chunk in self.outbox:
110
+ if chunk.index == target_index:
111
+ self._last_delivered_index = target_index
112
+ print(f"📦 Delivered chunk {target_index}")
113
+ return chunk
114
+
115
+ # Not ready yet, wait a bit
116
+ time.sleep(0.1)
117
+
118
+ # Timeout or stopped
119
+ return None
120
+
121
+ def mark_chunk_consumed(self, chunk_index: int):
122
+ """Mark a chunk as consumed by the frontend"""
123
+ with self._lock:
124
+ self._last_delivered_index = max(self._last_delivered_index, chunk_index)
125
+ print(f"✅ Chunk {chunk_index} consumed")
126
+
127
+ def _should_generate_next_chunk(self) -> bool:
128
+ """Check if we should generate the next chunk (don't get too far ahead)"""
129
+ with self._lock:
130
+ # Don't generate if we're already too far ahead
131
+ if self.idx > self._last_delivered_index + self._max_buffer_ahead:
132
+ return False
133
+ return True
134
+
135
  def _seconds_per_bar(self) -> float:
136
  return self.params.beats_per_bar * (60.0 / self.params.bpm)
137
 
 
156
  return b64, meta
157
 
158
  def run(self):
159
+ """Main worker loop - generate chunks continuously but don't get too far ahead"""
160
  spb = self._seconds_per_bar()
161
  chunk_secs = self.params.bars_per_chunk * spb
162
  xfade = self.mrt.config.crossfade_length
163
 
164
+ print("🚀 JamWorker started with flow control...")
165
+
166
  while not self._stop_event.is_set():
167
+ # Check if we should generate the next chunk
168
+ if not self._should_generate_next_chunk():
169
+ # We're ahead enough, wait a bit for frontend to catch up
170
+ print(f"⏸️ Buffer full, waiting for consumption...")
171
+ time.sleep(0.5)
172
+ continue
173
+
174
+ # Generate the next chunk
175
  with self._lock:
176
  style_vec = self.params.style_vec
 
177
  self.mrt.guidance_weight = self.params.guidance_weight
178
  self.mrt.temperature = self.params.temperature
179
  self.mrt.topk = self.params.topk
180
+ next_idx = self.idx + 1
181
 
182
+ print(f"🎹 Generating chunk {next_idx}...")
183
+
184
+ # Generate enough model chunks to cover chunk_secs
185
  need = chunk_secs
186
  chunks = []
187
  self.last_chunk_started_at = time.time()
188
+
189
  while need > 0 and not self._stop_event.is_set():
190
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
191
  chunks.append(wav)
 
192
  need -= (wav.samples.shape[0] / float(self.mrt.sample_rate))
193
 
194
  if self._stop_event.is_set():
195
  break
196
 
197
+ # Stitch and trim to exact seconds at model SR
198
  y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
199
  y = hard_trim_seconds(y, chunk_secs)
200
 
201
+ # Post-process
202
+ if next_idx == 1 and self.params.ref_loop is not None:
203
+ y, _ = match_loudness_to_reference(
204
+ self.params.ref_loop, y,
205
+ method=self.params.loudness_mode,
206
+ headroom_db=self.params.headroom_db
207
+ )
208
  else:
209
  apply_micro_fades(y, 3)
210
 
211
+ # Resample + snap + b64
212
+ b64, meta = self._snap_and_encode(
213
+ y, seconds=chunk_secs,
214
+ target_sr=self.params.target_sr,
215
+ bars=self.params.bars_per_chunk
216
+ )
217
 
218
+ # Store the completed chunk
219
  with self._lock:
220
+ self.idx = next_idx
221
+ self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
222
+
223
+ # Keep outbox bounded (remove old chunks)
224
+ if len(self.outbox) > 10:
225
+ # Remove chunks that are way behind the delivery point
226
+ self.outbox = [ch for ch in self.outbox if ch.index > self._last_delivered_index - 5]
227
+
228
+ self.last_chunk_completed_at = time.time()
229
+ print(f"✅ Completed chunk {next_idx}")
230
+
231
+ print("🛑 JamWorker stopped")