Spaces:
Running
Running
Commit
·
a8d3e47
1
Parent(s):
53cce5a
the next 8 bars not the last 8 bars
Browse files- app.py +45 -24
- jam_worker.py +90 -47
app.py
CHANGED
@@ -323,25 +323,46 @@ def jam_start(
|
|
323 |
return {"session_id": sid}
|
324 |
|
325 |
@app.get("/jam/next")
|
326 |
-
def jam_next(session_id: str
|
|
|
|
|
|
|
|
|
327 |
with jam_lock:
|
328 |
worker = jam_registry.get(session_id)
|
329 |
if worker is None or not worker.is_alive():
|
330 |
raise HTTPException(status_code=404, detail="Session not found")
|
331 |
|
332 |
-
#
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
341 |
|
342 |
-
if not items:
|
343 |
-
return Response(status_code=204) # nothing yet
|
344 |
-
return {"chunks": items}
|
345 |
|
346 |
@app.post("/jam/stop")
|
347 |
def jam_stop(session_id: str = Body(..., embed=True)):
|
@@ -384,26 +405,26 @@ def jam_status(session_id: str):
|
|
384 |
|
385 |
# Snapshot safely
|
386 |
with worker._lock:
|
387 |
-
|
|
|
388 |
queued = len(worker.outbox)
|
|
|
389 |
p = worker.params
|
390 |
spb = p.beats_per_bar * (60.0 / p.bpm)
|
391 |
chunk_secs = p.bars_per_chunk * spb
|
392 |
-
target_sr = p.target_sr
|
393 |
-
bars_per_chunk = p.bars_per_chunk
|
394 |
-
beats_per_bar = p.beats_per_bar
|
395 |
-
bpm = p.bpm
|
396 |
|
397 |
return {
|
398 |
"running": running,
|
399 |
-
"
|
400 |
-
"
|
401 |
-
"
|
402 |
-
"
|
403 |
-
"
|
|
|
|
|
404 |
"seconds_per_bar": spb,
|
405 |
"chunk_duration_seconds": chunk_secs,
|
406 |
-
"target_sample_rate": target_sr,
|
407 |
"last_chunk_started_at": worker.last_chunk_started_at,
|
408 |
"last_chunk_completed_at": worker.last_chunk_completed_at,
|
409 |
}
|
|
|
323 |
return {"session_id": sid}
|
324 |
|
325 |
@app.get("/jam/next")
|
326 |
+
def jam_next(session_id: str):
|
327 |
+
"""
|
328 |
+
Get the next sequential chunk in the jam session.
|
329 |
+
This ensures chunks are delivered in order without gaps.
|
330 |
+
"""
|
331 |
with jam_lock:
|
332 |
worker = jam_registry.get(session_id)
|
333 |
if worker is None or not worker.is_alive():
|
334 |
raise HTTPException(status_code=404, detail="Session not found")
|
335 |
|
336 |
+
# Get the next sequential chunk (this blocks until ready)
|
337 |
+
chunk = worker.get_next_chunk()
|
338 |
+
|
339 |
+
if chunk is None:
|
340 |
+
raise HTTPException(status_code=408, detail="Chunk not ready within timeout")
|
341 |
+
|
342 |
+
return {
|
343 |
+
"chunk": {
|
344 |
+
"index": chunk.index,
|
345 |
+
"audio_base64": chunk.audio_base64,
|
346 |
+
"metadata": chunk.metadata
|
347 |
+
}
|
348 |
+
}
|
349 |
+
|
350 |
+
@app.post("/jam/consume")
|
351 |
+
def jam_consume(session_id: str = Form(...), chunk_index: int = Form(...)):
|
352 |
+
"""
|
353 |
+
Mark a chunk as consumed by the frontend.
|
354 |
+
This helps the worker manage its buffer and generation flow.
|
355 |
+
"""
|
356 |
+
with jam_lock:
|
357 |
+
worker = jam_registry.get(session_id)
|
358 |
+
if worker is None or not worker.is_alive():
|
359 |
+
raise HTTPException(status_code=404, detail="Session not found")
|
360 |
+
|
361 |
+
worker.mark_chunk_consumed(chunk_index)
|
362 |
+
|
363 |
+
return {"consumed": chunk_index}
|
364 |
+
|
365 |
|
|
|
|
|
|
|
366 |
|
367 |
@app.post("/jam/stop")
|
368 |
def jam_stop(session_id: str = Body(..., embed=True)):
|
|
|
405 |
|
406 |
# Snapshot safely
|
407 |
with worker._lock:
|
408 |
+
last_generated = int(worker.idx)
|
409 |
+
last_delivered = int(worker._last_delivered_index)
|
410 |
queued = len(worker.outbox)
|
411 |
+
buffer_ahead = last_generated - last_delivered
|
412 |
p = worker.params
|
413 |
spb = p.beats_per_bar * (60.0 / p.bpm)
|
414 |
chunk_secs = p.bars_per_chunk * spb
|
|
|
|
|
|
|
|
|
415 |
|
416 |
return {
|
417 |
"running": running,
|
418 |
+
"last_generated_index": last_generated, # Last chunk that finished generating
|
419 |
+
"last_delivered_index": last_delivered, # Last chunk sent to frontend
|
420 |
+
"buffer_ahead": buffer_ahead, # How many chunks ahead we are
|
421 |
+
"queued_chunks": queued, # Total chunks in outbox
|
422 |
+
"bpm": p.bpm,
|
423 |
+
"beats_per_bar": p.beats_per_bar,
|
424 |
+
"bars_per_chunk": p.bars_per_chunk,
|
425 |
"seconds_per_bar": spb,
|
426 |
"chunk_duration_seconds": chunk_secs,
|
427 |
+
"target_sample_rate": p.target_sr,
|
428 |
"last_chunk_started_at": worker.last_chunk_started_at,
|
429 |
"last_chunk_completed_at": worker.last_chunk_completed_at,
|
430 |
}
|
jam_worker.py
CHANGED
@@ -1,17 +1,14 @@
|
|
1 |
-
# jam_worker.py
|
2 |
import threading, time, base64, io, uuid
|
3 |
from dataclasses import dataclass, field
|
4 |
import numpy as np
|
5 |
import soundfile as sf
|
6 |
|
7 |
-
# Pull in your helpers from app.py or refactor them into a shared utils module.
|
8 |
from utils import (
|
9 |
match_loudness_to_reference, stitch_generated, hard_trim_seconds,
|
10 |
apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
|
11 |
resample_and_snap, wav_bytes_base64
|
12 |
)
|
13 |
-
from scipy.signal import resample_poly
|
14 |
-
from math import gcd
|
15 |
|
16 |
@dataclass
|
17 |
class JamParams:
|
@@ -22,8 +19,8 @@ class JamParams:
|
|
22 |
loudness_mode: str = "auto"
|
23 |
headroom_db: float = 1.0
|
24 |
style_vec: np.ndarray | None = None
|
25 |
-
ref_loop: any = None
|
26 |
-
combined_loop: any = None
|
27 |
guidance_weight: float = 1.1
|
28 |
temperature: float = 1.1
|
29 |
topk: int = 40
|
@@ -39,15 +36,20 @@ class JamWorker(threading.Thread):
|
|
39 |
super().__init__(daemon=True)
|
40 |
self.mrt = mrt
|
41 |
self.params = params
|
42 |
-
# Initialize fresh state
|
43 |
self.state = mrt.init_state()
|
44 |
|
45 |
-
# CRITICAL: Set up fresh context from the new combined audio
|
46 |
if params.combined_loop is not None:
|
47 |
self._setup_context_from_combined_loop()
|
|
|
48 |
self.idx = 0
|
49 |
self.outbox: list[JamChunk] = []
|
50 |
self._stop_event = threading.Event()
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
self.last_chunk_started_at = None
|
52 |
self.last_chunk_completed_at = None
|
53 |
self._lock = threading.Lock()
|
@@ -55,14 +57,11 @@ class JamWorker(threading.Thread):
|
|
55 |
def _setup_context_from_combined_loop(self):
|
56 |
"""Set up MRT context tokens from the combined loop audio"""
|
57 |
try:
|
58 |
-
# Import the utility functions (same as used in main generation)
|
59 |
from utils import make_bar_aligned_context, take_bar_aligned_tail
|
60 |
|
61 |
-
# Extract context from combined loop (same logic as generate_loop_continuation_with_mrt)
|
62 |
codec_fps = float(self.mrt.codec.frame_rate)
|
63 |
ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
|
64 |
|
65 |
-
# Take tail portion for context (matches main generation)
|
66 |
loop_for_context = take_bar_aligned_tail(
|
67 |
self.params.combined_loop,
|
68 |
self.params.bpm,
|
@@ -70,11 +69,9 @@ class JamWorker(threading.Thread):
|
|
70 |
ctx_seconds
|
71 |
)
|
72 |
|
73 |
-
# Encode to tokens
|
74 |
tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
|
75 |
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
76 |
|
77 |
-
# Create bar-aligned context
|
78 |
context_tokens = make_bar_aligned_context(
|
79 |
tokens,
|
80 |
bpm=self.params.bpm,
|
@@ -83,30 +80,58 @@ class JamWorker(threading.Thread):
|
|
83 |
beats_per_bar=self.params.beats_per_bar
|
84 |
)
|
85 |
|
86 |
-
# Set context on state - this is the key fix!
|
87 |
self.state.context_tokens = context_tokens
|
88 |
-
|
89 |
print(f"✅ JamWorker: Set up fresh context from combined loop")
|
90 |
-
print(f" Context shape: {context_tokens.shape if context_tokens is not None else None}")
|
91 |
|
92 |
except Exception as e:
|
93 |
print(f"❌ Failed to setup context from combined loop: {e}")
|
94 |
-
# Continue without context rather than crashing
|
95 |
|
96 |
def stop(self):
|
97 |
self._stop_event.set()
|
98 |
|
99 |
-
def update_style(self, new_style_vec: np.ndarray | None):
|
100 |
-
with self._lock:
|
101 |
-
if new_style_vec is not None:
|
102 |
-
self.params.style_vec = new_style_vec
|
103 |
-
|
104 |
def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
|
105 |
with self._lock:
|
106 |
if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight)
|
107 |
if temperature is not None: self.params.temperature = float(temperature)
|
108 |
if topk is not None: self.params.topk = int(topk)
|
109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
def _seconds_per_bar(self) -> float:
|
111 |
return self.params.beats_per_bar * (60.0 / self.params.bpm)
|
112 |
|
@@ -131,58 +156,76 @@ class JamWorker(threading.Thread):
|
|
131 |
return b64, meta
|
132 |
|
133 |
def run(self):
|
|
|
134 |
spb = self._seconds_per_bar()
|
135 |
chunk_secs = self.params.bars_per_chunk * spb
|
136 |
xfade = self.mrt.config.crossfade_length
|
137 |
|
138 |
-
|
139 |
-
|
140 |
while not self._stop_event.is_set():
|
141 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
with self._lock:
|
143 |
style_vec = self.params.style_vec
|
144 |
-
# Temporarily override MRT knobs (thread-local overrides)
|
145 |
self.mrt.guidance_weight = self.params.guidance_weight
|
146 |
self.mrt.temperature = self.params.temperature
|
147 |
self.mrt.topk = self.params.topk
|
|
|
148 |
|
149 |
-
|
|
|
|
|
150 |
need = chunk_secs
|
151 |
chunks = []
|
152 |
self.last_chunk_started_at = time.time()
|
|
|
153 |
while need > 0 and not self._stop_event.is_set():
|
154 |
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
155 |
chunks.append(wav)
|
156 |
-
# model chunk length (seconds) at model SR
|
157 |
need -= (wav.samples.shape[0] / float(self.mrt.sample_rate))
|
158 |
|
159 |
if self._stop_event.is_set():
|
160 |
break
|
161 |
|
162 |
-
#
|
163 |
y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
|
164 |
y = hard_trim_seconds(y, chunk_secs)
|
165 |
|
166 |
-
#
|
167 |
-
if
|
168 |
-
y, _ = match_loudness_to_reference(
|
169 |
-
|
170 |
-
|
|
|
|
|
171 |
else:
|
172 |
apply_micro_fades(y, 3)
|
173 |
|
174 |
-
#
|
175 |
-
b64, meta = self._snap_and_encode(
|
176 |
-
|
177 |
-
|
|
|
|
|
178 |
|
179 |
-
#
|
180 |
with self._lock:
|
181 |
-
self.idx
|
182 |
-
self.outbox.append(JamChunk(index=
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
|
|
|
|
|
|
|
|
|
1 |
+
# jam_worker.py - SIMPLE FIX VERSION
|
2 |
import threading, time, base64, io, uuid
|
3 |
from dataclasses import dataclass, field
|
4 |
import numpy as np
|
5 |
import soundfile as sf
|
6 |
|
|
|
7 |
from utils import (
|
8 |
match_loudness_to_reference, stitch_generated, hard_trim_seconds,
|
9 |
apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
|
10 |
resample_and_snap, wav_bytes_base64
|
11 |
)
|
|
|
|
|
12 |
|
13 |
@dataclass
|
14 |
class JamParams:
|
|
|
19 |
loudness_mode: str = "auto"
|
20 |
headroom_db: float = 1.0
|
21 |
style_vec: np.ndarray | None = None
|
22 |
+
ref_loop: any = None
|
23 |
+
combined_loop: any = None
|
24 |
guidance_weight: float = 1.1
|
25 |
temperature: float = 1.1
|
26 |
topk: int = 40
|
|
|
36 |
super().__init__(daemon=True)
|
37 |
self.mrt = mrt
|
38 |
self.params = params
|
|
|
39 |
self.state = mrt.init_state()
|
40 |
|
|
|
41 |
if params.combined_loop is not None:
|
42 |
self._setup_context_from_combined_loop()
|
43 |
+
|
44 |
self.idx = 0
|
45 |
self.outbox: list[JamChunk] = []
|
46 |
self._stop_event = threading.Event()
|
47 |
+
|
48 |
+
# NEW: Track delivery state
|
49 |
+
self._last_delivered_index = 0
|
50 |
+
self._max_buffer_ahead = 5 # Don't generate more than 3 chunks ahead
|
51 |
+
|
52 |
+
# Timing info
|
53 |
self.last_chunk_started_at = None
|
54 |
self.last_chunk_completed_at = None
|
55 |
self._lock = threading.Lock()
|
|
|
57 |
def _setup_context_from_combined_loop(self):
|
58 |
"""Set up MRT context tokens from the combined loop audio"""
|
59 |
try:
|
|
|
60 |
from utils import make_bar_aligned_context, take_bar_aligned_tail
|
61 |
|
|
|
62 |
codec_fps = float(self.mrt.codec.frame_rate)
|
63 |
ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
|
64 |
|
|
|
65 |
loop_for_context = take_bar_aligned_tail(
|
66 |
self.params.combined_loop,
|
67 |
self.params.bpm,
|
|
|
69 |
ctx_seconds
|
70 |
)
|
71 |
|
|
|
72 |
tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
|
73 |
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
74 |
|
|
|
75 |
context_tokens = make_bar_aligned_context(
|
76 |
tokens,
|
77 |
bpm=self.params.bpm,
|
|
|
80 |
beats_per_bar=self.params.beats_per_bar
|
81 |
)
|
82 |
|
|
|
83 |
self.state.context_tokens = context_tokens
|
|
|
84 |
print(f"✅ JamWorker: Set up fresh context from combined loop")
|
|
|
85 |
|
86 |
except Exception as e:
|
87 |
print(f"❌ Failed to setup context from combined loop: {e}")
|
|
|
88 |
|
89 |
def stop(self):
|
90 |
self._stop_event.set()
|
91 |
|
|
|
|
|
|
|
|
|
|
|
92 |
def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
|
93 |
with self._lock:
|
94 |
if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight)
|
95 |
if temperature is not None: self.params.temperature = float(temperature)
|
96 |
if topk is not None: self.params.topk = int(topk)
|
97 |
|
98 |
+
def get_next_chunk(self) -> JamChunk | None:
|
99 |
+
"""Get the next sequential chunk (blocks/waits if not ready)"""
|
100 |
+
target_index = self._last_delivered_index + 1
|
101 |
+
|
102 |
+
# Wait for the target chunk to be ready (with timeout)
|
103 |
+
max_wait = 30.0 # seconds
|
104 |
+
start_time = time.time()
|
105 |
+
|
106 |
+
while time.time() - start_time < max_wait and not self._stop_event.is_set():
|
107 |
+
with self._lock:
|
108 |
+
# Look for the exact chunk we need
|
109 |
+
for chunk in self.outbox:
|
110 |
+
if chunk.index == target_index:
|
111 |
+
self._last_delivered_index = target_index
|
112 |
+
print(f"📦 Delivered chunk {target_index}")
|
113 |
+
return chunk
|
114 |
+
|
115 |
+
# Not ready yet, wait a bit
|
116 |
+
time.sleep(0.1)
|
117 |
+
|
118 |
+
# Timeout or stopped
|
119 |
+
return None
|
120 |
+
|
121 |
+
def mark_chunk_consumed(self, chunk_index: int):
|
122 |
+
"""Mark a chunk as consumed by the frontend"""
|
123 |
+
with self._lock:
|
124 |
+
self._last_delivered_index = max(self._last_delivered_index, chunk_index)
|
125 |
+
print(f"✅ Chunk {chunk_index} consumed")
|
126 |
+
|
127 |
+
def _should_generate_next_chunk(self) -> bool:
|
128 |
+
"""Check if we should generate the next chunk (don't get too far ahead)"""
|
129 |
+
with self._lock:
|
130 |
+
# Don't generate if we're already too far ahead
|
131 |
+
if self.idx > self._last_delivered_index + self._max_buffer_ahead:
|
132 |
+
return False
|
133 |
+
return True
|
134 |
+
|
135 |
def _seconds_per_bar(self) -> float:
|
136 |
return self.params.beats_per_bar * (60.0 / self.params.bpm)
|
137 |
|
|
|
156 |
return b64, meta
|
157 |
|
158 |
def run(self):
|
159 |
+
"""Main worker loop - generate chunks continuously but don't get too far ahead"""
|
160 |
spb = self._seconds_per_bar()
|
161 |
chunk_secs = self.params.bars_per_chunk * spb
|
162 |
xfade = self.mrt.config.crossfade_length
|
163 |
|
164 |
+
print("🚀 JamWorker started with flow control...")
|
165 |
+
|
166 |
while not self._stop_event.is_set():
|
167 |
+
# Check if we should generate the next chunk
|
168 |
+
if not self._should_generate_next_chunk():
|
169 |
+
# We're ahead enough, wait a bit for frontend to catch up
|
170 |
+
print(f"⏸️ Buffer full, waiting for consumption...")
|
171 |
+
time.sleep(0.5)
|
172 |
+
continue
|
173 |
+
|
174 |
+
# Generate the next chunk
|
175 |
with self._lock:
|
176 |
style_vec = self.params.style_vec
|
|
|
177 |
self.mrt.guidance_weight = self.params.guidance_weight
|
178 |
self.mrt.temperature = self.params.temperature
|
179 |
self.mrt.topk = self.params.topk
|
180 |
+
next_idx = self.idx + 1
|
181 |
|
182 |
+
print(f"🎹 Generating chunk {next_idx}...")
|
183 |
+
|
184 |
+
# Generate enough model chunks to cover chunk_secs
|
185 |
need = chunk_secs
|
186 |
chunks = []
|
187 |
self.last_chunk_started_at = time.time()
|
188 |
+
|
189 |
while need > 0 and not self._stop_event.is_set():
|
190 |
wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
|
191 |
chunks.append(wav)
|
|
|
192 |
need -= (wav.samples.shape[0] / float(self.mrt.sample_rate))
|
193 |
|
194 |
if self._stop_event.is_set():
|
195 |
break
|
196 |
|
197 |
+
# Stitch and trim to exact seconds at model SR
|
198 |
y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
|
199 |
y = hard_trim_seconds(y, chunk_secs)
|
200 |
|
201 |
+
# Post-process
|
202 |
+
if next_idx == 1 and self.params.ref_loop is not None:
|
203 |
+
y, _ = match_loudness_to_reference(
|
204 |
+
self.params.ref_loop, y,
|
205 |
+
method=self.params.loudness_mode,
|
206 |
+
headroom_db=self.params.headroom_db
|
207 |
+
)
|
208 |
else:
|
209 |
apply_micro_fades(y, 3)
|
210 |
|
211 |
+
# Resample + snap + b64
|
212 |
+
b64, meta = self._snap_and_encode(
|
213 |
+
y, seconds=chunk_secs,
|
214 |
+
target_sr=self.params.target_sr,
|
215 |
+
bars=self.params.bars_per_chunk
|
216 |
+
)
|
217 |
|
218 |
+
# Store the completed chunk
|
219 |
with self._lock:
|
220 |
+
self.idx = next_idx
|
221 |
+
self.outbox.append(JamChunk(index=next_idx, audio_base64=b64, metadata=meta))
|
222 |
+
|
223 |
+
# Keep outbox bounded (remove old chunks)
|
224 |
+
if len(self.outbox) > 10:
|
225 |
+
# Remove chunks that are way behind the delivery point
|
226 |
+
self.outbox = [ch for ch in self.outbox if ch.index > self._last_delivered_index - 5]
|
227 |
+
|
228 |
+
self.last_chunk_completed_at = time.time()
|
229 |
+
print(f"✅ Completed chunk {next_idx}")
|
230 |
+
|
231 |
+
print("🛑 JamWorker stopped")
|