File size: 5,544 Bytes
cd609af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# jam_worker.py
import threading, time, base64, io, uuid
from dataclasses import dataclass, field
from math import gcd
from typing import Any

import numpy as np
import soundfile as sf
from scipy.signal import resample_poly

# Pull in your helpers from app.py or refactor them into a shared utils module.
from utils import (
    match_loudness_to_reference, stitch_generated, hard_trim_seconds,
    apply_micro_fades, make_bar_aligned_context, take_bar_aligned_tail,
    resample_and_snap, wav_bytes_base64
)

@dataclass
class JamParams:
    bpm: float
    beats_per_bar: int
    bars_per_chunk: int           # 4 or 8
    target_sr: int
    loudness_mode: str = "auto"
    headroom_db: float = 1.0
    style_vec: np.ndarray | None = None   # combined_style vector
    ref_loop: any = None                  # au.Waveform at model SR for 1st-chunk loudness
    guidance_weight: float = 1.1
    temperature: float = 1.1
    topk: int = 40

@dataclass
class JamChunk:
    """One finished chunk of generated audio queued on JamWorker.outbox."""
    index: int          # 1-based sequence number within the jam session
    audio_base64: str   # complete WAV file, base64-encoded
    metadata: dict      # timing / sample-rate / knob info built in _snap_and_encode

class JamWorker(threading.Thread):
    """Background thread that streams bar-aligned audio chunks from an MRT model.

    Each loop iteration generates enough model chunks to cover
    ``bars_per_chunk`` bars, stitches/trims them, post-processes the result,
    and appends a :class:`JamChunk` to ``outbox``. ``outbox``, ``idx``, and the
    live knobs are guarded by ``_lock``; callers drain ``outbox`` from another
    thread and may adjust style/knobs between chunks via ``update_style`` /
    ``update_knobs``.
    """

    def __init__(self, mrt, params: JamParams):
        super().__init__(daemon=True)
        self.mrt = mrt
        self.params = params
        self.state = mrt.init_state()
        self.idx = 0                      # chunks emitted so far; outbox indices are 1-based
        self.outbox: list[JamChunk] = []
        # BUGFIX: this was `self._stop`, which shadows threading.Thread._stop().
        # CPython's Thread internals call self._stop() when the thread
        # finishes (e.g. inside join()); an Event is not callable, so joining
        # a finished worker raised TypeError. A distinct name avoids the clash.
        self._stop_event = threading.Event()
        self.last_chunk_started_at = None      # wall-clock start of current chunk
        self.last_chunk_completed_at = None    # wall-clock end of last finished chunk
        self._lock = threading.Lock()

    def stop(self):
        """Request shutdown; run() checks the flag between model calls."""
        self._stop_event.set()

    def update_style(self, new_style_vec: np.ndarray | None):
        """Atomically swap in a new style vector (None is ignored)."""
        with self._lock:
            if new_style_vec is not None:
                self.params.style_vec = new_style_vec

    def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
        """Atomically update any subset of the sampling knobs."""
        with self._lock:
            if guidance_weight is not None: self.params.guidance_weight = float(guidance_weight)
            if temperature is not None:     self.params.temperature     = float(temperature)
            if topk is not None:            self.params.topk            = int(topk)

    def _seconds_per_bar(self) -> float:
        """Duration of one bar in seconds at the configured tempo."""
        return self.params.beats_per_bar * (60.0 / self.params.bpm)

    def _snap_and_encode(self, y, seconds, target_sr, bars):
        """Resample `y` to target_sr, snap to exactly `seconds`, return (b64, metadata)."""
        cur_sr = int(self.mrt.sample_rate)
        # Ensure a 2-D (samples, channels) array before resampling.
        x = y.samples if y.samples.ndim == 2 else y.samples[:, None]
        x = resample_and_snap(x, cur_sr=cur_sr, target_sr=target_sr, seconds=seconds)
        b64, total_samples, channels = wav_bytes_base64(x, target_sr)
        meta = {
            "bpm": int(round(self.params.bpm)),
            "bars": int(bars),
            "beats_per_bar": int(self.params.beats_per_bar),
            "sample_rate": int(target_sr),
            "channels": channels,
            "total_samples": total_samples,
            "seconds_per_bar": self._seconds_per_bar(),
            "loop_duration_seconds": bars * self._seconds_per_bar(),
            "guidance_weight": self.params.guidance_weight,
            "temperature": self.params.temperature,
            "topk": self.params.topk,
        }
        return b64, meta

    def run(self):
        """Main generation loop: produce chunks until stop() is called."""
        spb = self._seconds_per_bar()
        chunk_secs = self.params.bars_per_chunk * spb
        xfade = self.mrt.config.crossfade_length

        # Prime: set initial context on state (caller should have done this; safe to re-set here)
        # NOTE: We assume caller passed a style_vec computed from tail/whole/blend.
        while not self._stop_event.is_set():
            # honor live knob updates atomically
            with self._lock:
                style_vec = self.params.style_vec
                # Push the current knob values onto the MRT object for this chunk.
                self.mrt.guidance_weight = self.params.guidance_weight
                self.mrt.temperature = self.params.temperature
                self.mrt.topk = self.params.topk

            # 1) generate enough model chunks to cover chunk_secs
            need = chunk_secs
            chunks = []
            self.last_chunk_started_at = time.time()
            while need > 0 and not self._stop_event.is_set():
                wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
                chunks.append(wav)
                # model chunk length (seconds) at model SR
                need -= (wav.samples.shape[0] / float(self.mrt.sample_rate))

            if self._stop_event.is_set():
                break

            # 2) stitch and trim to exact seconds at model SR
            y = stitch_generated(chunks, self.mrt.sample_rate, xfade).as_stereo()
            y = hard_trim_seconds(y, chunk_secs)

            # 3) post-process: match the very first chunk's loudness to the
            #    reference loop; later chunks just get edge de-click fades.
            if self.idx == 0 and self.params.ref_loop is not None:
                y, _ = match_loudness_to_reference(self.params.ref_loop, y,
                                                   method=self.params.loudness_mode,
                                                   headroom_db=self.params.headroom_db)
            else:
                apply_micro_fades(y, 3)

            # 4) resample + snap + b64
            b64, meta = self._snap_and_encode(y, seconds=chunk_secs,
                                              target_sr=self.params.target_sr,
                                              bars=self.params.bars_per_chunk)

            # 5) enqueue
            with self._lock:
                self.idx += 1
                self.outbox.append(JamChunk(index=self.idx, audio_base64=b64, metadata=meta))

                self.last_chunk_completed_at = time.time()

        # optional: cleanup here if needed