Spaces:
Running
Running
Commit
·
53cce5a
1
Parent(s):
1889c0a
refresh context
Browse files- app.py +14 -4
- jam_worker.py +51 -2
app.py
CHANGED
@@ -131,6 +131,8 @@ def generate_loop_continuation_with_mrt(
|
|
131 |
|
132 |
return out, loud_stats
|
133 |
|
|
|
|
|
134 |
# ----------------------------
|
135 |
# FastAPI app with lazy, thread-safe model init
|
136 |
# ----------------------------
|
@@ -298,10 +300,18 @@ def jam_start(
|
|
298 |
target_sr = int(target_sample_rate or input_sr)
|
299 |
|
300 |
params = JamParams(
|
301 |
-
bpm=bpm,
|
302 |
-
|
303 |
-
|
304 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
305 |
)
|
306 |
|
307 |
worker = JamWorker(mrt, params)
|
|
|
131 |
|
132 |
return out, loud_stats
|
133 |
|
134 |
+
|
135 |
+
|
136 |
# ----------------------------
|
137 |
# FastAPI app with lazy, thread-safe model init
|
138 |
# ----------------------------
|
|
|
300 |
target_sr = int(target_sample_rate or input_sr)
|
301 |
|
302 |
params = JamParams(
|
303 |
+
bpm=bpm,
|
304 |
+
beats_per_bar=beats_per_bar,
|
305 |
+
bars_per_chunk=bars_per_chunk,
|
306 |
+
target_sr=target_sr,
|
307 |
+
loudness_mode=loudness_mode,
|
308 |
+
headroom_db=loudness_headroom_db,
|
309 |
+
style_vec=style_vec,
|
310 |
+
ref_loop=loop_tail, # For loudness matching
|
311 |
+
combined_loop=loop, # NEW: Full loop for context setup
|
312 |
+
guidance_weight=guidance_weight,
|
313 |
+
temperature=temperature,
|
314 |
+
topk=topk
|
315 |
)
|
316 |
|
317 |
worker = JamWorker(mrt, params)
|
jam_worker.py
CHANGED
@@ -17,12 +17,13 @@ from math import gcd
|
|
17 |
class JamParams:
|
18 |
bpm: float
|
19 |
beats_per_bar: int
|
20 |
-
bars_per_chunk: int
|
21 |
target_sr: int
|
22 |
loudness_mode: str = "auto"
|
23 |
headroom_db: float = 1.0
|
24 |
-
style_vec: np.ndarray | None = None
|
25 |
ref_loop: any = None # au.Waveform at model SR for 1st-chunk loudness
|
|
|
26 |
guidance_weight: float = 1.1
|
27 |
temperature: float = 1.1
|
28 |
topk: int = 40
|
@@ -38,7 +39,12 @@ class JamWorker(threading.Thread):
|
|
38 |
super().__init__(daemon=True)
|
39 |
self.mrt = mrt
|
40 |
self.params = params
|
|
|
41 |
self.state = mrt.init_state()
|
|
|
|
|
|
|
|
|
42 |
self.idx = 0
|
43 |
self.outbox: list[JamChunk] = []
|
44 |
self._stop_event = threading.Event()
|
@@ -46,6 +52,47 @@ class JamWorker(threading.Thread):
|
|
46 |
self.last_chunk_completed_at = None
|
47 |
self._lock = threading.Lock()
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
def stop(self):
    """Set the worker's internal stop event.

    Only sets ``self._stop_event``; it does not join or otherwise wait on
    the thread.
    """
    self._stop_event.set()
|
51 |
|
@@ -137,3 +184,5 @@ class JamWorker(threading.Thread):
|
|
137 |
self.last_chunk_completed_at = time.time()
|
138 |
|
139 |
# optional: cleanup here if needed
|
|
|
|
|
|
17 |
class JamParams:
|
18 |
bpm: float
|
19 |
beats_per_bar: int
|
20 |
+
bars_per_chunk: int
|
21 |
target_sr: int
|
22 |
loudness_mode: str = "auto"
|
23 |
headroom_db: float = 1.0
|
24 |
+
style_vec: np.ndarray | None = None
|
25 |
ref_loop: any = None # au.Waveform at model SR for 1st-chunk loudness
|
26 |
+
combined_loop: any = None # NEW: Full combined audio for context setup
|
27 |
guidance_weight: float = 1.1
|
28 |
temperature: float = 1.1
|
29 |
topk: int = 40
|
|
|
39 |
super().__init__(daemon=True)
|
40 |
self.mrt = mrt
|
41 |
self.params = params
|
42 |
+
# Initialize fresh state
|
43 |
self.state = mrt.init_state()
|
44 |
+
|
45 |
+
# CRITICAL: Set up fresh context from the new combined audio
|
46 |
+
if params.combined_loop is not None:
|
47 |
+
self._setup_context_from_combined_loop()
|
48 |
self.idx = 0
|
49 |
self.outbox: list[JamChunk] = []
|
50 |
self._stop_event = threading.Event()
|
|
|
52 |
self.last_chunk_completed_at = None
|
53 |
self._lock = threading.Lock()
|
54 |
|
55 |
+
def _setup_context_from_combined_loop(self):
|
56 |
+
"""Set up MRT context tokens from the combined loop audio"""
|
57 |
+
try:
|
58 |
+
# Import the utility functions (same as used in main generation)
|
59 |
+
from utils import make_bar_aligned_context, take_bar_aligned_tail
|
60 |
+
|
61 |
+
# Extract context from combined loop (same logic as generate_loop_continuation_with_mrt)
|
62 |
+
codec_fps = float(self.mrt.codec.frame_rate)
|
63 |
+
ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
|
64 |
+
|
65 |
+
# Take tail portion for context (matches main generation)
|
66 |
+
loop_for_context = take_bar_aligned_tail(
|
67 |
+
self.params.combined_loop,
|
68 |
+
self.params.bpm,
|
69 |
+
self.params.beats_per_bar,
|
70 |
+
ctx_seconds
|
71 |
+
)
|
72 |
+
|
73 |
+
# Encode to tokens
|
74 |
+
tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
|
75 |
+
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
76 |
+
|
77 |
+
# Create bar-aligned context
|
78 |
+
context_tokens = make_bar_aligned_context(
|
79 |
+
tokens,
|
80 |
+
bpm=self.params.bpm,
|
81 |
+
fps=int(self.mrt.codec.frame_rate),
|
82 |
+
ctx_frames=self.mrt.config.context_length_frames,
|
83 |
+
beats_per_bar=self.params.beats_per_bar
|
84 |
+
)
|
85 |
+
|
86 |
+
# Set context on state - this is the key fix!
|
87 |
+
self.state.context_tokens = context_tokens
|
88 |
+
|
89 |
+
print(f"✅ JamWorker: Set up fresh context from combined loop")
|
90 |
+
print(f" Context shape: {context_tokens.shape if context_tokens is not None else None}")
|
91 |
+
|
92 |
+
except Exception as e:
|
93 |
+
print(f"❌ Failed to setup context from combined loop: {e}")
|
94 |
+
# Continue without context rather than crashing
|
95 |
+
|
96 |
def stop(self):
    """Set the worker's internal stop event.

    Only sets ``self._stop_event``; it does not join or otherwise wait on
    the thread.
    """
    self._stop_event.set()
|
98 |
|
|
|
184 |
self.last_chunk_completed_at = time.time()
|
185 |
|
186 |
# optional: cleanup here if needed
|
187 |
+
|
188 |
+
|