thecollabagepatch committed on
Commit
53cce5a
·
1 Parent(s): 1889c0a

refresh context

Browse files
Files changed (2) hide show
  1. app.py +14 -4
  2. jam_worker.py +51 -2
app.py CHANGED
@@ -131,6 +131,8 @@ def generate_loop_continuation_with_mrt(
131
 
132
  return out, loud_stats
133
 
 
 
134
  # ----------------------------
135
  # FastAPI app with lazy, thread-safe model init
136
  # ----------------------------
@@ -298,10 +300,18 @@ def jam_start(
298
  target_sr = int(target_sample_rate or input_sr)
299
 
300
  params = JamParams(
301
- bpm=bpm, beats_per_bar=beats_per_bar, bars_per_chunk=bars_per_chunk,
302
- target_sr=target_sr, loudness_mode=loudness_mode, headroom_db=loudness_headroom_db,
303
- style_vec=style_vec, ref_loop=loop_tail,
304
- guidance_weight=guidance_weight, temperature=temperature, topk=topk
 
 
 
 
 
 
 
 
305
  )
306
 
307
  worker = JamWorker(mrt, params)
 
131
 
132
  return out, loud_stats
133
 
134
+
135
+
136
  # ----------------------------
137
  # FastAPI app with lazy, thread-safe model init
138
  # ----------------------------
 
300
  target_sr = int(target_sample_rate or input_sr)
301
 
302
  params = JamParams(
303
+ bpm=bpm,
304
+ beats_per_bar=beats_per_bar,
305
+ bars_per_chunk=bars_per_chunk,
306
+ target_sr=target_sr,
307
+ loudness_mode=loudness_mode,
308
+ headroom_db=loudness_headroom_db,
309
+ style_vec=style_vec,
310
+ ref_loop=loop_tail, # For loudness matching
311
+ combined_loop=loop, # NEW: Full loop for context setup
312
+ guidance_weight=guidance_weight,
313
+ temperature=temperature,
314
+ topk=topk
315
  )
316
 
317
  worker = JamWorker(mrt, params)
jam_worker.py CHANGED
@@ -17,12 +17,13 @@ from math import gcd
17
  class JamParams:
18
  bpm: float
19
  beats_per_bar: int
20
- bars_per_chunk: int # 4 or 8
21
  target_sr: int
22
  loudness_mode: str = "auto"
23
  headroom_db: float = 1.0
24
- style_vec: np.ndarray | None = None # combined_style vector
25
  ref_loop: any = None # au.Waveform at model SR for 1st-chunk loudness
 
26
  guidance_weight: float = 1.1
27
  temperature: float = 1.1
28
  topk: int = 40
@@ -38,7 +39,12 @@ class JamWorker(threading.Thread):
38
  super().__init__(daemon=True)
39
  self.mrt = mrt
40
  self.params = params
 
41
  self.state = mrt.init_state()
 
 
 
 
42
  self.idx = 0
43
  self.outbox: list[JamChunk] = []
44
  self._stop_event = threading.Event()
@@ -46,6 +52,47 @@ class JamWorker(threading.Thread):
46
  self.last_chunk_completed_at = None
47
  self._lock = threading.Lock()
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  def stop(self):
50
  self._stop_event.set()
51
 
@@ -137,3 +184,5 @@ class JamWorker(threading.Thread):
137
  self.last_chunk_completed_at = time.time()
138
 
139
  # optional: cleanup here if needed
 
 
 
17
  class JamParams:
18
  bpm: float
19
  beats_per_bar: int
20
+ bars_per_chunk: int
21
  target_sr: int
22
  loudness_mode: str = "auto"
23
  headroom_db: float = 1.0
24
+ style_vec: np.ndarray | None = None
25
  ref_loop: any = None # au.Waveform at model SR for 1st-chunk loudness
26
+ combined_loop: any = None # NEW: Full combined audio for context setup
27
  guidance_weight: float = 1.1
28
  temperature: float = 1.1
29
  topk: int = 40
 
39
  super().__init__(daemon=True)
40
  self.mrt = mrt
41
  self.params = params
42
+ # Initialize fresh state
43
  self.state = mrt.init_state()
44
+
45
+ # CRITICAL: Set up fresh context from the new combined audio
46
+ if params.combined_loop is not None:
47
+ self._setup_context_from_combined_loop()
48
  self.idx = 0
49
  self.outbox: list[JamChunk] = []
50
  self._stop_event = threading.Event()
 
52
  self.last_chunk_completed_at = None
53
  self._lock = threading.Lock()
54
 
55
+ def _setup_context_from_combined_loop(self):
56
+ """Set up MRT context tokens from the combined loop audio"""
57
+ try:
58
+ # Import the utility functions (same as used in main generation)
59
+ from utils import make_bar_aligned_context, take_bar_aligned_tail
60
+
61
+ # Extract context from combined loop (same logic as generate_loop_continuation_with_mrt)
62
+ codec_fps = float(self.mrt.codec.frame_rate)
63
+ ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
64
+
65
+ # Take tail portion for context (matches main generation)
66
+ loop_for_context = take_bar_aligned_tail(
67
+ self.params.combined_loop,
68
+ self.params.bpm,
69
+ self.params.beats_per_bar,
70
+ ctx_seconds
71
+ )
72
+
73
+ # Encode to tokens
74
+ tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
75
+ tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
76
+
77
+ # Create bar-aligned context
78
+ context_tokens = make_bar_aligned_context(
79
+ tokens,
80
+ bpm=self.params.bpm,
81
+ fps=int(self.mrt.codec.frame_rate),
82
+ ctx_frames=self.mrt.config.context_length_frames,
83
+ beats_per_bar=self.params.beats_per_bar
84
+ )
85
+
86
+ # Set context on state - this is the key fix!
87
+ self.state.context_tokens = context_tokens
88
+
89
+ print(f"✅ JamWorker: Set up fresh context from combined loop")
90
+ print(f" Context shape: {context_tokens.shape if context_tokens is not None else None}")
91
+
92
+ except Exception as e:
93
+ print(f"❌ Failed to setup context from combined loop: {e}")
94
+ # Continue without context rather than crashing
95
+
96
  def stop(self):
97
  self._stop_event.set()
98
 
 
184
  self.last_chunk_completed_at = time.time()
185
 
186
  # optional: cleanup here if needed
187
+
188
+