# jam_worker.py
import threading, time
from dataclasses import dataclass
from typing import Any
import numpy as np

# Pull in your helpers from app.py or refactor them into a shared utils module.
from utils import (
    match_loudness_to_reference, stitch_generated, hard_trim_seconds,
    apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
    resample_and_snap, wav_bytes_base64
)

@dataclass
class JamParams:
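    """Timing, sampling, and generation settings for a continuous jam session."""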
    bpm: float
    beats_per_bar: int
    bars_per_chunk: int
    target_sr: int
    loudness_mode: str = "auto"
    headroom_db: float = 1.0
    style_vec: np.ndarray | None = None
    ref_loop: Any = None                  # au.Waveform at model SR for first-chunk loudness match
    combined_loop: Any = None             # full combined audio used to seed the fresh context
    guidance_weight: float = 1.1
    temperature: float = 1.1
    topk: int = 40

@dataclass
class JamChunk:
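    """One finished chunk: base64-encoded WAV audio plus timing/knob metadata."""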
    index: int
    audio_base64: str
    metadata: dict

class JamWorker(threading.Thread):
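    """Background thread that streams bar-aligned audio chunks from an MRT-style
    model and queues them (as base64 WAV plus metadata) in `outbox`."""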
    def __init__(self, mrt, params: JamParams):
        super().__init__(daemon=True)
        self.mrt = mrt
        self.params = params
        # Initialize fresh state
        self.state = mrt.init_state()
        
        # CRITICAL: Set up fresh context from the new combined audio
        if params.combined_loop is not None:
            self._setup_context_from_combined_loop()
        self.idx = 0
        self.outbox: list[JamChunk] = []
        self._stop_event = threading.Event()
        self.last_chunk_started_at = None
        self.last_chunk_completed_at = None
        self._lock = threading.Lock()

    def _setup_context_from_combined_loop(self):
        """Set up MRT context tokens from the combined loop audio"""
        try:
            # Extract context from combined loop (same logic as generate_loop_continuation_with_mrt)
            codec_fps = float(self.mrt.codec.frame_rate)
            ctx_seconds = float(self.mrt.config.context_length_frames) / codec_fps
            
            # Take tail portion for context (matches main generation)
            loop_for_context = take_bar_aligned_tail(
                self.params.combined_loop, 
                self.params.bpm, 
                self.params.beats_per_bar, 
                ctx_seconds
            )
            
            # Encode to tokens
            tokens_full = self.mrt.codec.encode(loop_for_context).astype(np.int32)
            tokens = tokens_full[:, :self.mrt.config.decoder_codec_rvq_depth]
            
            # Create bar-aligned context
            context_tokens = make_bar_aligned_context(
                tokens, 
                bpm=self.params.bpm, 
                fps=int(self.mrt.codec.frame_rate),
                ctx_frames=self.mrt.config.context_length_frames, 
                beats_per_bar=self.params.beats_per_bar
            )
            
            # Seed the generation state with the bar-aligned context tokens
            self.state.context_tokens = context_tokens
            
            print("✅ JamWorker: Set up fresh context from combined loop")
            print(f"   Context shape: {context_tokens.shape if context_tokens is not None else None}")
            
        except Exception as e:
            print(f"❌ Failed to setup context from combined loop: {e}")
            # Continue without context rather than crashing

    def stop(self):
        self._stop_event.set()

    def update_style(self, new_style_vec: np.ndarray | None):
        with self._lock:
            if new_style_vec is not None:
                self.params.style_vec = new_style_vec

    def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
        with self._lock:
            if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight)
            if temperature is not None:     self.params.temperature     = float(temperature)
            if topk is not None:            self.params.topk            = int(topk)

    def _seconds_per_bar(self) -> float:
        return self.params.beats_per_bar * (60.0 / self.params.bpm)
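        # e.g. at 120 BPM with 4 beats per bar: 4 * (60 / 120) = 2.0 s per bar,
        # so an 8-bar chunk spans 16.0 s of audio.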

    def _snap_and_encode(self, y, seconds, target_sr, bars):
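        """Resample to target_sr, snap to exactly `seconds` of audio, and return (base64 WAV, metadata)."""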
        cur_sr = int(self.mrt.sample_rate)
        x = y.samples if y.samples.ndim == 2 else y.samples[:, None]
        x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
        b64, total_samples, channels = wav_bytes_base64(x, target_sr)
        meta = {
            "bpm": int(round(self.params.bpm)),
            "bars": int(bars),
            "beats_per_bar": int(self.params.beats_per_bar),
            "sample_rate": int(target_sr),
            "channels": channels,
            "total_samples": total_samples,
            "seconds_per_bar": self._seconds_per_bar(),
            "loop_duration_seconds": bars * self._seconds_per_bar(),
            "guidance_weight": self.params.guidance_weight,
            "temperature": self.params.temperature,
            "topk": self.params.topk,
        }
        return b64, meta

    def run(self):
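        """Per iteration: generate enough model chunks to cover one bar-aligned chunk, then stitch, trim, post-process, and enqueue it."""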
        spb = self._seconds_per_bar()
        chunk_secs = self.params.bars_per_chunk * spb
        xfade = self.mrt.config.crossfade_length

        # Context was primed in __init__ (via _setup_context_from_combined_loop).
        # NOTE: We assume the caller passed a style_vec computed from tail/whole/blend.
        while not self._stop_event.is_set():
            # honor live knob updates atomically
            with self._lock:
                style_vec = self.params.style_vec
                # Apply live knob values to the MRT instance (a shared object, not thread-local;
                # assumes one worker per MRT)
                self.mrt.guidance_weight = self.params.guidance_weight
                self.mrt.temperature = self.params.temperature
                self.mrt.topk = self.params.topk

            # 1) generate enough model chunks to cover chunk_secs
            need = chunk_secs
            chunks = []
            self.last_chunk_started_at = time.time()
            while need > 0 and not self._stop_event.is_set():
                wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
                chunks.append(wav)
                # model chunk length (seconds) at model SR
                need -= (wav.samples.shape[0] / float(self.mrt.sample_rate))

            if self._stop_event.is_set():
                break

            # 2) stitch and trim to exact seconds at model SR
            y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
            y = hard_trim_seconds(y, chunk_secs)

            # 3) post-process
            if self.idx == 0 and self.params.ref_loop is not None:
                y, _ = match_loudness_to_reference(self.params.ref_loop, y,
                                                   method=self.params.loudness_mode,
                                                   headroom_db=self.params.headroom_db)
            else:
                apply_micro_fades(y, 3)  # micro-fades at the chunk edges to avoid clicks

            # 4) resample + snap + b64
            b64, meta = self._snap_and_encode(y, seconds=chunk_secs,
                                              target_sr=self.params.target_sr,
                                              bars=self.params.bars_per_chunk)

            # 5) enqueue
            with self._lock:
                self.idx += 1
                self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))
                self.last_chunk_completed_at = time.time()

        # optional: cleanup here if needed
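

# --- Usage sketch -----------------------------------------------------------
# A minimal, assumption-laden illustration of how a caller might drive
# JamWorker; it is not part of the worker API. `mrt` is assumed to be an
# already-loaded MRT model exposing init_state()/generate_chunk()/codec/
# config/sample_rate as the worker expects; `style_vec` and `combined_loop`
# are hypothetical inputs prepared elsewhere (e.g. by app.py).
def _example_jam_loop(mrt, style_vec, combined_loop, n_chunks: int = 4):
    params = JamParams(bpm=120.0, beats_per_bar=4, bars_per_chunk=8,
                       target_sr=48000, style_vec=style_vec,
                       combined_loop=combined_loop)
    worker = JamWorker(mrt, params)
    worker.start()
    received: list[JamChunk] = []
    try:
        while len(received) < n_chunks:
            with worker._lock:            # outbox is shared with the worker thread
                received.extend(worker.outbox)
                worker.outbox.clear()
            time.sleep(0.25)              # simple polling; a real server would push
    finally:
        worker.stop()
        worker.join()
    return received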