thecollabagepatch commited on
Commit
5709926
·
1 Parent(s): 02fcba6

vibe coded spaceship

Browse files
Files changed (1) hide show
  1. jam_worker.py +35 -28
jam_worker.py CHANGED
@@ -79,6 +79,9 @@ class JamWorker(threading.Thread):
79
  self.mrt = mrt
80
  self.params = params
81
 
 
 
 
82
  # generation state
83
  self.state = self.mrt.init_state()
84
  self.mrt.guidance_weight = float(self.params.guidance_weight)
@@ -115,7 +118,7 @@ class JamWorker(threading.Thread):
115
  self._cv = threading.Condition()
116
 
117
  # control flags
118
- self._stop = threading.Event()
119
  self._max_buffer_ahead = 5
120
 
121
  # reseed queue (install at next safe point)
@@ -128,7 +131,7 @@ class JamWorker(threading.Thread):
128
  # ---------- lifecycle ----------
129
 
130
  def stop(self):
131
- self._stop.set()
132
 
133
  # FastAPI reads this to block until the next sequential chunk is ready
134
  def get_next_chunk(self, timeout: float = 30.0) -> Optional[JamChunk]:
@@ -154,16 +157,17 @@ class JamWorker(threading.Thread):
154
  self._outbox.pop(k, None)
155
 
156
  def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
157
- if guidance_weight is not None:
158
- self.params.guidance_weight = float(guidance_weight)
159
- if temperature is not None:
160
- self.params.temperature = float(temperature)
161
- if topk is not None:
162
- self.params.topk = int(topk)
163
- # push into mrt (thread-safe enough for our use)
164
- self.mrt.guidance_weight = float(self.params.guidance_weight)
165
- self.mrt.temperature = float(self.params.temperature)
166
- self.mrt.topk = int(self.params.topk)
 
167
 
168
  # ---------- context / reseed ----------
169
 
@@ -242,16 +246,18 @@ class JamWorker(threading.Thread):
242
  def reseed_from_waveform(self, wav: au.Waveform):
243
  """Immediate reseed: replace context from provided wave (bar-locked, exact length)."""
244
  context_tokens = self._encode_exact_context_tokens(wav)
245
- s = self.mrt.init_state()
246
- s.context_tokens = context_tokens
247
- self.state = s
248
- self._model_stream = None # drop model-domain continuity so next chunk starts cleanly
249
- self._original_context_tokens = np.copy(context_tokens)
 
250
 
251
  def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
252
  """Queue a splice reseed to be applied right after the next emitted loop."""
253
  new_ctx = self._encode_exact_context_tokens(recent_wav)
254
- self._pending_reseed = {"ctx": new_ctx}
 
255
 
256
 
257
  def reseed_from_waveform(self, wav: au.Waveform):
@@ -366,21 +372,19 @@ class JamWorker(threading.Thread):
366
  self.idx += 1
367
 
368
  # If a reseed is queued, install it *right after* we finish a chunk
369
- if self._pending_reseed is not None:
370
- new_state = self.mrt.init_state()
371
- new_state.context_tokens = self._pending_reseed["ctx"]
372
- self.state = new_state
373
- self._model_stream = None # drop model-domain continuity so next chunk starts clean
374
- self._pending_reseed = None
 
375
 
376
  # ---------- main loop ----------
377
 
378
  def run(self):
379
- # set style vector if present
380
- style_vec = self._style_vec
381
-
382
  # generate until stopped
383
- while not self._stop.is_set():
384
  # throttle generation if we are far ahead
385
  if not self._should_generate_next_chunk():
386
  # still try to emit if spool already has enough
@@ -389,6 +393,9 @@ class JamWorker(threading.Thread):
389
  continue
390
 
391
  # generate next model chunk
 
 
 
392
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
393
  # append and spool
394
  self._append_model_chunk_and_spool(wav)
 
79
  self.mrt = mrt
80
  self.params = params
81
 
82
+ # external callers (FastAPI endpoints) use this for atomic updates
83
+ self._lock = threading.RLock()
84
+
85
  # generation state
86
  self.state = self.mrt.init_state()
87
  self.mrt.guidance_weight = float(self.params.guidance_weight)
 
118
  self._cv = threading.Condition()
119
 
120
  # control flags
121
+ self._stop_event = threading.Event()
122
  self._max_buffer_ahead = 5
123
 
124
  # reseed queue (install at next safe point)
 
131
  # ---------- lifecycle ----------
132
 
133
  def stop(self):
134
+ self._stop_event.set()
135
 
136
  # FastAPI reads this to block until the next sequential chunk is ready
137
  def get_next_chunk(self, timeout: float = 30.0) -> Optional[JamChunk]:
 
157
  self._outbox.pop(k, None)
158
 
159
  def update_knobs(self, *, guidance_weight=None, temperature=None, topk=None):
160
+ with self._lock:
161
+ if guidance_weight is not None:
162
+ self.params.guidance_weight = float(guidance_weight)
163
+ if temperature is not None:
164
+ self.params.temperature = float(temperature)
165
+ if topk is not None:
166
+ self.params.topk = int(topk)
167
+ # push into mrt
168
+ self.mrt.guidance_weight = float(self.params.guidance_weight)
169
+ self.mrt.temperature = float(self.params.temperature)
170
+ self.mrt.topk = int(self.params.topk)
171
 
172
  # ---------- context / reseed ----------
173
 
 
246
  def reseed_from_waveform(self, wav: au.Waveform):
247
  """Immediate reseed: replace context from provided wave (bar-locked, exact length)."""
248
  context_tokens = self._encode_exact_context_tokens(wav)
249
+ with self._lock:
250
+ s = self.mrt.init_state()
251
+ s.context_tokens = context_tokens
252
+ self.state = s
253
+ self._model_stream = None # drop model-domain continuity so next chunk starts cleanly
254
+ self._original_context_tokens = np.copy(context_tokens)
255
 
256
  def reseed_splice(self, recent_wav: au.Waveform, anchor_bars: float):
257
  """Queue a splice reseed to be applied right after the next emitted loop."""
258
  new_ctx = self._encode_exact_context_tokens(recent_wav)
259
+ with self._lock:
260
+ self._pending_reseed = {"ctx": new_ctx}
261
 
262
 
263
  def reseed_from_waveform(self, wav: au.Waveform):
 
372
  self.idx += 1
373
 
374
  # If a reseed is queued, install it *right after* we finish a chunk
375
+ with self._lock:
376
+ if self._pending_reseed is not None:
377
+ new_state = self.mrt.init_state()
378
+ new_state.context_tokens = self._pending_reseed["ctx"]
379
+ self.state = new_state
380
+ self._model_stream = None
381
+ self._pending_reseed = None
382
 
383
  # ---------- main loop ----------
384
 
385
  def run(self):
 
 
 
386
  # generate until stopped
387
+ while not self._stop_event.is_set():
388
  # throttle generation if we are far ahead
389
  if not self._should_generate_next_chunk():
390
  # still try to emit if spool already has enough
 
393
  continue
394
 
395
  # generate next model chunk
396
+ # snapshot current style vector under lock for this step
397
+ with self._lock:
398
+ style_vec = self._style_vec
399
  wav, self.state = self.mrt.generate_chunk(state=self.state, style=style_vec)
400
  # append and spool
401
  self._append_model_chunk_and_spool(wav)