thecollabagepatch commited on
Commit
d41a575
·
1 Parent(s): 9a1b4dc

oh boy bar boundaries is hard

Browse files
Files changed (2) hide show
  1. jam_worker.py +3 -3
  2. utils.py +90 -21
jam_worker.py CHANGED
@@ -80,7 +80,7 @@ class JamWorker(threading.Thread):
80
  context_tokens = make_bar_aligned_context(
81
  tokens,
82
  bpm=self.params.bpm,
83
- fps=int(self.mrt.codec.frame_rate),
84
  ctx_frames=self.mrt.config.context_length_frames,
85
  beats_per_bar=self.params.beats_per_bar
86
  )
@@ -213,7 +213,7 @@ class JamWorker(threading.Thread):
213
  tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
214
  tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
215
  context_tokens = make_bar_aligned_context(tokens,
216
- bpm=self.params.bpm, fps=int(self.mrt.codec.frame_rate),
217
  ctx_frames=self.mrt.config.context_length_frames,
218
  beats_per_bar=self.params.beats_per_bar
219
  )
@@ -242,7 +242,7 @@ class JamWorker(threading.Thread):
242
  ctx = make_bar_aligned_context(
243
  tokens,
244
  bpm=self.params.bpm,
245
- fps=int(self.mrt.codec.frame_rate),
246
  ctx_frames=self.mrt.config.context_length_frames,
247
  beats_per_bar=self.params.beats_per_bar
248
  )
 
80
  context_tokens = make_bar_aligned_context(
81
  tokens,
82
  bpm=self.params.bpm,
83
+ fps=float(self.mrt.codec.frame_rate), # keep fractional fps
84
  ctx_frames=self.mrt.config.context_length_frames,
85
  beats_per_bar=self.params.beats_per_bar
86
  )
 
213
  tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
214
  tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
215
  context_tokens = make_bar_aligned_context(tokens,
216
+ bpm=self.params.bpm, fps=float(self.mrt.codec.frame_rate),
217
  ctx_frames=self.mrt.config.context_length_frames,
218
  beats_per_bar=self.params.beats_per_bar
219
  )
 
242
  ctx = make_bar_aligned_context(
243
  tokens,
244
  bpm=self.params.bpm,
245
+ fps=float(self.mrt.codec.frame_rate), # keep fractional fps
246
  ctx_frames=self.mrt.config.context_length_frames,
247
  beats_per_bar=self.params.beats_per_bar
248
  )
utils.py CHANGED
@@ -109,30 +109,99 @@ def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
109
 
110
 
111
  # ---------- Token context helpers ----------
112
- def make_bar_aligned_context(tokens, bpm, fps=25, ctx_frames=250, beats_per_bar=4):
113
- frames_per_bar_f = (beats_per_bar * 60.0 / bpm) * fps
114
- frames_per_bar = int(round(frames_per_bar_f))
115
- if abs(frames_per_bar - frames_per_bar_f) > 1e-3:
116
- reps = int(np.ceil(ctx_frames / len(tokens)))
117
- return np.tile(tokens, (reps, 1))[-ctx_frames:]
118
- reps = int(np.ceil(ctx_frames / len(tokens)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  tiled = np.tile(tokens, (reps, 1))
120
- end = (len(tiled) // frames_per_bar) * frames_per_bar
121
- if end < ctx_frames:
122
- return tiled[-ctx_frames:]
123
- start = end - ctx_frames
124
- return tiled[start:end]
125
-
126
- def take_bar_aligned_tail(wav: au.Waveform, bpm: float, beats_per_bar: int, ctx_seconds: float, max_bars=None) -> au.Waveform:
127
- spb = (60.0 / bpm) * beats_per_bar
128
- bars_needed = max(1, int(round(ctx_seconds / spb)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  if max_bars is not None:
130
- bars_needed = min(bars_needed, max_bars)
131
- tail_seconds = bars_needed * spb
132
- n = int(round(tail_seconds * wav.sample_rate))
133
- if n >= wav.samples.shape[0]:
 
 
 
 
 
134
  return wav
135
- return au.Waveform(wav.samples[-n:], wav.sample_rate)
 
 
136
 
137
 
138
  # ---------- SR normalize + snap ----------
 
109
 
110
 
111
  # ---------- Token context helpers ----------
112
+ def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4):
113
+ """
114
+ Return a ctx_frames-long slice of `tokens` whose **end** lands on the nearest
115
+ whole-bar boundary in codec-frame space, even when frames_per_bar is fractional.
116
+
117
+ tokens: np.ndarray of shape (T, D) or (T,) where T = codec frames
118
+ bpm: float
119
+ fps: float (codec frames per second; keep this as float)
120
+ ctx_frames: int (length of context window in codec frames)
121
+ beats_per_bar: int
122
+ """
123
+
124
+
125
+ if tokens is None:
126
+ raise ValueError("tokens is None")
127
+ tokens = np.asarray(tokens)
128
+ if tokens.ndim == 1:
129
+ tokens = tokens[:, None] # promote to (T, 1) for uniform tiling
130
+
131
+ T = tokens.shape[0]
132
+ if T == 0:
133
+ return tokens
134
+
135
+ fps = float(fps)
136
+ frames_per_bar_f = (beats_per_bar * 60.0 / float(bpm)) * fps # float frames per bar
137
+
138
+ # Tile a little more than we need so we can always snap the END to a bar boundary
139
+ reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
140
  tiled = np.tile(tokens, (reps, 1))
141
+ total = tiled.shape[0]
142
+
143
+ # How many whole bars fit?
144
+ k_bars = int(np.floor(total / frames_per_bar_f))
145
+ if k_bars <= 0:
146
+ # Fallback: just take the last ctx_frames
147
+ window = tiled[-ctx_frames:]
148
+ return window
149
+
150
+ # Snap END index to the nearest integer frame at a whole-bar boundary
151
+ end_idx = int(round(k_bars * frames_per_bar_f))
152
+ end_idx = min(max(end_idx, ctx_frames), total)
153
+ start_idx = end_idx - ctx_frames
154
+ if start_idx < 0:
155
+ start_idx = 0
156
+ end_idx = ctx_frames
157
+
158
+ window = tiled[start_idx:end_idx]
159
+
160
+ # Guard against rare off-by-one due to rounding
161
+ if window.shape[0] < ctx_frames:
162
+ pad = np.tile(tokens, (int(np.ceil((ctx_frames - window.shape[0]) / T)), 1))
163
+ window = np.vstack([window, pad])[:ctx_frames]
164
+ elif window.shape[0] > ctx_frames:
165
+ window = window[-ctx_frames:]
166
+
167
+ return window
168
+
169
+
170
+ def take_bar_aligned_tail(
171
+ wav: au.Waveform,
172
+ bpm: float,
173
+ beats_per_bar: int,
174
+ ctx_seconds: float,
175
+ max_bars=None
176
+ ) -> au.Waveform:
177
+ """
178
+ Take a tail whose length is an integer number of bars, with the END aligned
179
+ to a bar boundary. Uses ceil for bars_needed so we never under-fill the context.
180
+ """
181
+ import math
182
+
183
+ # seconds per bar
184
+ spb = (60.0 / float(bpm)) * float(beats_per_bar)
185
+
186
+ # Pick enough whole bars to cover ctx_seconds (avoid underfilling on round-down).
187
+ # The small epsilon avoids an extra bar due to FP jitter when ctx_seconds ~= k * spb.
188
+ eps = 1e-9
189
+ bars_needed = max(1, int(math.ceil((float(ctx_seconds) - eps) / spb)))
190
+
191
  if max_bars is not None:
192
+ bars_needed = min(bars_needed, int(max_bars))
193
+
194
+ # Convert bars -> samples (do rounding once at the end for stability)
195
+ samples_per_bar_f = spb * float(wav.sample_rate)
196
+ n = int(round(bars_needed * samples_per_bar_f))
197
+
198
+ total = int(wav.samples.shape[0])
199
+ if n >= total:
200
+ # Not enough audio to take that many bars—return as-is (current behavior).
201
  return wav
202
+
203
+ start = total - n
204
+ return au.Waveform(wav.samples[start:], wav.sample_rate)
205
 
206
 
207
  # ---------- SR normalize + snap ----------