thecollabagepatch commited on
Commit
af5a7b2
·
1 Parent(s): 9fb3c06

update styles endpoint

Browse files
Files changed (1) hide show
  1. app.py +48 -34
app.py CHANGED
@@ -17,6 +17,7 @@ from jam_worker import JamWorker, JamParams, JamChunk
17
  import uuid, threading
18
 
19
  import gradio as gr
 
20
 
21
  def create_documentation_interface():
22
  """Create a Gradio interface for documentation and transparency"""
@@ -581,47 +582,60 @@ def jam_stop(session_id: str = Body(..., embed=True)):
581
  jam_registry.pop(session_id, None)
582
  return {"stopped": True}
583
 
584
- @app.post("/jam/update")
585
- def jam_update(session_id: str = Form(...),
586
- guidance_weight: float | None = Form(None),
587
- temperature: float | None = Form(None),
588
- topk: int | None = Form(None)):
589
- with jam_lock:
590
- worker = jam_registry.get(session_id)
591
- if worker is None or not worker.is_alive():
592
- raise HTTPException(status_code=404, detail="Session not found")
593
- worker.update_knobs(guidance_weight=guidance_weight, temperature=temperature, topk=topk)
594
- return {"ok": True}
595
 
596
- @app.post("/jam/update_styles")
597
- def jam_update_styles(session_id: str = Form(...),
598
- styles: str = Form(""),
599
- style_weights: str = Form(""),
600
- loop_weight: float = Form(1.0),
601
- use_current_mix_as_style: bool = Form(False)):
602
  with jam_lock:
603
  worker = jam_registry.get(session_id)
604
  if worker is None or not worker.is_alive():
605
  raise HTTPException(status_code=404, detail="Session not found")
606
 
607
- embeds, weights = [], []
608
- # Optionally re-embed from current combined loop
609
- if use_current_mix_as_style and worker.params.combined_loop is not None:
610
- embeds.append(worker.mrt.embed_style(worker.params.combined_loop))
611
- weights.append(float(loop_weight))
612
-
613
- extra = [s for s in (styles.split(",") if styles else []) if s.strip()]
614
- sw = [float(x) for x in style_weights.split(",")] if style_weights else []
615
- for i, s in enumerate(extra):
616
- embeds.append(worker.mrt.embed_style(s.strip()))
617
- weights.append(sw[i] if i < len(sw) else 1.0)
618
-
619
- wsum = sum(weights) or 1.0
620
- weights = [w/wsum for w in weights]
621
- style_vec = np.sum([w*e for w,e in zip(weights, embeds)], axis=0).astype(np.float32)
622
 
623
- with worker._lock:
624
- worker.params.style_vec = style_vec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
625
 
626
  return {"ok": True}
627
 
 
17
  import uuid, threading
18
 
19
  import gradio as gr
20
+ from typing import Optional
21
 
22
  def create_documentation_interface():
23
  """Create a Gradio interface for documentation and transparency"""
 
582
  jam_registry.pop(session_id, None)
583
  return {"stopped": True}
584
 
585
+ @app.post("/jam/update") # consolidated
586
+ def jam_update(
587
+ session_id: str = Form(...),
588
+
589
+ # knobs (all optional)
590
+ guidance_weight: Optional[float] = Form(None),
591
+ temperature: Optional[float] = Form(None),
592
+ topk: Optional[int] = Form(None),
 
 
 
593
 
594
+ # styles (all optional)
595
+ styles: str = Form(""),
596
+ style_weights: str = Form(""),
597
+ loop_weight: Optional[float] = Form(None), # None means "don’t change"
598
+ use_current_mix_as_style: bool = Form(False),
599
+ ):
600
  with jam_lock:
601
  worker = jam_registry.get(session_id)
602
  if worker is None or not worker.is_alive():
603
  raise HTTPException(status_code=404, detail="Session not found")
604
 
605
+ # --- 1) Apply knob updates (atomic under lock)
606
+ if any(v is not None for v in (guidance_weight, temperature, topk)):
607
+ worker.update_knobs(
608
+ guidance_weight=guidance_weight,
609
+ temperature=temperature,
610
+ topk=topk
611
+ )
 
 
 
 
 
 
 
 
612
 
613
+ # --- 2) Apply style updates only if requested
614
+ wants_style_update = use_current_mix_as_style or (styles.strip() != "")
615
+ if wants_style_update:
616
+ embeds, weights = [], []
617
+
618
+ # optional: include current mix as a style component
619
+ if use_current_mix_as_style and worker.params.combined_loop is not None:
620
+ lw = 1.0 if loop_weight is None else float(loop_weight)
621
+ embeds.append(worker.mrt.embed_style(worker.params.combined_loop))
622
+ weights.append(lw)
623
+
624
+ # extra text styles
625
+ extra = [s for s in (styles.split(",") if styles else []) if s.strip()]
626
+ sw = [float(x) for x in style_weights.split(",")] if style_weights else []
627
+ for i, s in enumerate(extra):
628
+ embeds.append(worker.mrt.embed_style(s.strip()))
629
+ weights.append(sw[i] if i < len(sw) else 1.0)
630
+
631
+ if embeds: # only swap if we actually built something
632
+ wsum = sum(weights) or 1.0
633
+ weights = [w / wsum for w in weights]
634
+ style_vec = np.sum([w * e for w, e in zip(weights, embeds)], axis=0).astype(np.float32)
635
+
636
+ # install atomically
637
+ with worker._lock:
638
+ worker.params.style_vec = style_vec
639
 
640
  return {"ok": True}
641