thecollabagepatch commited on
Commit
241e975
·
1 Parent(s): f8b3793

use tail end of longer contexts

Browse files
Files changed (1) hide show
  1. app.py +33 -2
app.py CHANGED
@@ -141,6 +141,23 @@ def apply_micro_fades(wav: au.Waveform, ms: int = 5) -> None:
141
  wav.samples[:n] *= env
142
  wav.samples[-n:] *= env[::-1]
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  # ----------------------------
145
  # Main generation (single combined style vector)
146
  # ----------------------------
@@ -156,9 +173,23 @@ def generate_loop_continuation_with_mrt(
156
  loudness_mode: str = "auto", # "auto"|"lufs"|"rms"|"none"
157
  loudness_headroom_db: float = 1.0, # for the peak guard
158
  ):
159
- # Load loop & encode
160
  loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
161
- tokens_full = mrt.codec.encode(loop).astype(np.int32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
163
 
164
  # Context
 
141
  wav.samples[:n] *= env
142
  wav.samples[-n:] *= env[::-1]
143
 
144
+ def take_bar_aligned_tail(wav: au.Waveform,
145
+ bpm: float,
146
+ beats_per_bar: int,
147
+ ctx_seconds: float) -> au.Waveform:
148
+ """
149
+ Return the LAST N bars whose duration is as close as possible to ctx_seconds,
150
+ anchored to the end of `wav`, and bar-aligned.
151
+ """
152
+ spb = (60.0 / bpm) * beats_per_bar # seconds per bar
153
+ bars_needed = max(1, int(round(ctx_seconds / spb)))
154
+ tail_seconds = bars_needed * spb # exact multiple of bars
155
+ n = int(round(tail_seconds * wav.sample_rate))
156
+ if n >= wav.samples.shape[0]:
157
+ # Input shorter than desired tail: keep whole thing (your existing behavior will tile)
158
+ return wav
159
+ return au.Waveform(wav.samples[-n:], wav.sample_rate)
160
+
161
  # ----------------------------
162
  # Main generation (single combined style vector)
163
  # ----------------------------
 
173
  loudness_mode: str = "auto", # "auto"|"lufs"|"rms"|"none"
174
  loudness_headroom_db: float = 1.0, # for the peak guard
175
  ):
176
+ # Load loop & put into model SR/channels
177
  loop = au.Waveform.from_file(input_wav_path).resample(mrt.sample_rate).as_stereo()
178
+
179
+ # Compute the model's desired context seconds (e.g., 250 frames / 25 fps = 10s)
180
+ codec_fps = float(mrt.codec.frame_rate)
181
+ ctx_seconds = float(mrt.config.context_length_frames) / codec_fps # typically 10.0s
182
+
183
+ # ✅ NEW: take bar-aligned TAIL for context, if input is long enough
184
+ loop_for_context = take_bar_aligned_tail(
185
+ wav=loop,
186
+ bpm=bpm,
187
+ beats_per_bar=beats_per_bar,
188
+ ctx_seconds=ctx_seconds
189
+ )
190
+
191
+ # Encode ONLY the tail (so we condition on recent audio)
192
+ tokens_full = mrt.codec.encode(loop_for_context).astype(np.int32)
193
  tokens = tokens_full[:, :mrt.config.decoder_codec_rvq_depth]
194
 
195
  # Context