Commit
·
d41a575
1
Parent(s):
9a1b4dc
oh boy bar boundaries is hard
Browse files- jam_worker.py +3 -3
- utils.py +90 -21
jam_worker.py
CHANGED
@@ -80,7 +80,7 @@ class JamWorker(threading.Thread):
|
|
80 |
context_tokens = make_bar_aligned_context(
|
81 |
tokens,
|
82 |
bpm=self.params.bpm,
|
83 |
-
fps=
|
84 |
ctx_frames=self.mrt.config.context_length_frames,
|
85 |
beats_per_bar=self.params.beats_per_bar
|
86 |
)
|
@@ -213,7 +213,7 @@ class JamWorker(threading.Thread):
|
|
213 |
tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
|
214 |
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
215 |
context_tokens = make_bar_aligned_context(tokens,
|
216 |
-
bpm=self.params.bpm, fps=
|
217 |
ctx_frames=self.mrt.config.context_length_frames,
|
218 |
beats_per_bar=self.params.beats_per_bar
|
219 |
)
|
@@ -242,7 +242,7 @@ class JamWorker(threading.Thread):
|
|
242 |
ctx = make_bar_aligned_context(
|
243 |
tokens,
|
244 |
bpm=self.params.bpm,
|
245 |
-
fps=
|
246 |
ctx_frames=self.mrt.config.context_length_frames,
|
247 |
beats_per_bar=self.params.beats_per_bar
|
248 |
)
|
|
|
80 |
context_tokens = make_bar_aligned_context(
|
81 |
tokens,
|
82 |
bpm=self.params.bpm,
|
83 |
+
fps=float(self.mrt.codec.frame_rate), # keep fractional fps
|
84 |
ctx_frames=self.mrt.config.context_length_frames,
|
85 |
beats_per_bar=self.params.beats_per_bar
|
86 |
)
|
|
|
213 |
tokens_full = self.mrt.codec.encode(tail).astype(np.int32)
|
214 |
tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
|
215 |
context_tokens = make_bar_aligned_context(tokens,
|
216 |
+
bpm=self.params.bpm, fps=float(self.mrt.codec.frame_rate),
|
217 |
ctx_frames=self.mrt.config.context_length_frames,
|
218 |
beats_per_bar=self.params.beats_per_bar
|
219 |
)
|
|
|
242 |
ctx = make_bar_aligned_context(
|
243 |
tokens,
|
244 |
bpm=self.params.bpm,
|
245 |
+
fps=float(self.mrt.codec.frame_rate), # keep fractional fps
|
246 |
ctx_frames=self.mrt.config.context_length_frames,
|
247 |
beats_per_bar=self.params.beats_per_bar
|
248 |
)
|
utils.py
CHANGED
@@ -109,30 +109,99 @@ def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
|
|
109 |
|
110 |
|
111 |
# ---------- Token context helpers ----------
|
112 |
-
def make_bar_aligned_context(tokens, bpm, fps=25, ctx_frames=250, beats_per_bar=4):
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
tiled = np.tile(tokens, (reps, 1))
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
if max_bars is not None:
|
130 |
-
bars_needed = min(bars_needed, max_bars)
|
131 |
-
|
132 |
-
|
133 |
-
|
|
|
|
|
|
|
|
|
|
|
134 |
return wav
|
135 |
-
|
|
|
|
|
136 |
|
137 |
|
138 |
# ---------- SR normalize + snap ----------
|
|
|
109 |
|
110 |
|
111 |
# ---------- Token context helpers ----------
|
112 |
+
def make_bar_aligned_context(tokens, bpm, fps=25.0, ctx_frames=250, beats_per_bar=4):
    """
    Return a ctx_frames-long window of `tokens` whose **end** lands on the nearest
    whole-bar boundary in codec-frame space, even when frames_per_bar is fractional.

    The tokens are treated as a loop: they are repeated along axis 0 until there
    is strictly more material than ctx_frames, the window end is snapped to the
    last whole-bar boundary that fits inside that repeated material, and the
    ctx_frames frames leading up to that point are returned.

    tokens: np.ndarray of shape (T, D) or (T,) where T = codec frames
    bpm: float, must be > 0
    fps: float (codec frames per second; keep this as float), must be > 0
    ctx_frames: int (length of context window in codec frames), must be >= 0
    beats_per_bar: int, must be > 0

    Returns:
        Array with ctx_frames rows. 1-D input is promoted to (T, 1), so the
        result is always at least 2-D. Empty input (T == 0) is returned as-is
        (after promotion) without looking at the other arguments.

    Raises:
        ValueError: if `tokens` is None, if ctx_frames is negative, or if
            bpm / fps / beats_per_bar do not yield a finite, positive number
            of frames per bar.
    """
    if tokens is None:
        raise ValueError("tokens is None")
    tokens = np.asarray(tokens)
    if tokens.ndim == 1:
        tokens = tokens[:, None]  # promote to (T, 1) for uniform tiling

    T = tokens.shape[0]
    if T == 0:
        return tokens

    if ctx_frames < 0:
        raise ValueError(f"ctx_frames must be >= 0, got {ctx_frames!r}")

    bpm = float(bpm)
    fps = float(fps)
    # Check bpm before dividing so a zero tempo reports a clear error instead
    # of a bare ZeroDivisionError (NaN also fails this comparison).
    if not bpm > 0:
        raise ValueError(f"bpm must be > 0, got {bpm!r}")
    frames_per_bar_f = (beats_per_bar * 60.0 / bpm) * fps  # float frames per bar
    if not (np.isfinite(frames_per_bar_f) and frames_per_bar_f > 0):
        raise ValueError(
            "fps and beats_per_bar must give a finite, positive frames-per-bar; "
            f"got fps={fps!r}, beats_per_bar={beats_per_bar!r}"
        )

    # Tile a little more than we need so we can always snap the END to a bar
    # boundary. This guarantees total >= ctx_frames + 2 * T, i.e. the tiled
    # material is always longer than the requested window.
    reps = int(np.ceil((ctx_frames + T) / float(T))) + 1
    # Repeat along the frame axis only, whatever the number of trailing axes.
    tiled = np.tile(tokens, (reps,) + (1,) * (tokens.ndim - 1))
    total = tiled.shape[0]

    # How many whole bars fit?
    k_bars = int(np.floor(total / frames_per_bar_f))
    if k_bars <= 0:
        # Fallback: a single bar is longer than everything we tiled, so just
        # take the last ctx_frames. Slice from an explicit start index because
        # tiled[-0:] would return the whole array when ctx_frames == 0.
        return tiled[total - ctx_frames:]

    # Snap END index to the nearest integer frame at a whole-bar boundary,
    # then clamp it so a full ctx_frames window always fits in front of it.
    end_idx = int(round(k_bars * frames_per_bar_f))
    end_idx = min(max(end_idx, ctx_frames), total)

    # end_idx is within [ctx_frames, total], so this slice is exactly
    # ctx_frames long and never needs padding or trimming.
    return tiled[end_idx - ctx_frames:end_idx]
|
168 |
+
|
169 |
+
|
170 |
+
def take_bar_aligned_tail(
    wav: au.Waveform,
    bpm: float,
    beats_per_bar: int,
    ctx_seconds: float,
    max_bars=None
) -> au.Waveform:
    """
    Cut the end of `wav` down to a whole number of bars.

    The number of bars is the smallest count (at least 1) whose duration covers
    ctx_seconds, optionally capped by max_bars. The returned waveform keeps the
    original END and drops audio from the front.

    wav: waveform exposing `.samples` (sample axis first) and `.sample_rate`
    bpm: tempo in beats per minute
    beats_per_bar: beats in one bar
    ctx_seconds: context duration the tail should cover
    max_bars: optional upper limit on the number of bars (None = no limit)

    If `wav` does not hold more samples than the requested tail, it is returned
    unchanged rather than shortened.
    """
    import math

    # Duration of one bar in seconds.
    bar_seconds = (60.0 / float(bpm)) * float(beats_per_bar)

    # Round UP so the tail never under-fills the context. Subtracting a tiny
    # guard keeps FP noise from adding a bar when ctx_seconds ~= k * bar_seconds.
    jitter_guard = 1e-9
    bar_count = int(math.ceil((float(ctx_seconds) - jitter_guard) / bar_seconds))
    if bar_count < 1:
        bar_count = 1

    if max_bars is not None:
        bar_cap = int(max_bars)
        if bar_cap < bar_count:
            bar_count = bar_cap

    # Stay in floats until the very end so rounding happens exactly once.
    bar_samples = bar_seconds * float(wav.sample_rate)
    tail_len = int(round(bar_count * bar_samples))

    available = int(wav.samples.shape[0])
    if tail_len < available:
        return au.Waveform(wav.samples[available - tail_len:], wav.sample_rate)

    # Not enough audio to take that many bars: hand the input back as-is.
    return wav
|
205 |
|
206 |
|
207 |
# ---------- SR normalize + snap ----------
|