Staticaliza committed
Commit 4dbefe2 · verified · 1 Parent(s): c8264c2

Update app.py

Files changed (1)
  1. app.py +59 -39
app.py CHANGED
@@ -1,52 +1,72 @@
-import os, shlex, subprocess, torch, numpy as np, spaces, gradio as gr, torchaudio
+# app.py
+import os, shlex, subprocess, torch, numpy as np, gradio as gr, torchaudio, spaces
 from zonos.model import Zonos
 from zonos.conditioning import make_cond_dict, supported_language_codes
 
-subprocess.run(shlex.split("pip install flash-attn --no-build-isolation"), env=os.environ | {"FLASH_ATTENTION_SKIP_CUDA_BUILD":"TRUE"}, check=True)
-subprocess.run(shlex.split("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"), check=True)
-subprocess.run(shlex.split("pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.0.post8/causal_conv1d-1.5.0.post8+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"), check=True)
+# ── optional perf wheels (safe to ignore if they fail) ───────────────────────
+cmds = [
+    "pip install flash-attn --no-build-isolation",
+    "pip install https://github.com/state-spaces/mamba/releases/download/v2.2.4/"
+    "mamba_ssm-2.2.4+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
+    "pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/"
+    "v1.5.0.post8/causal_conv1d-1.5.0.post8+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl",
+]
+for c in cmds:
+    try:
+        subprocess.run(shlex.split(c), check=True)
+    except subprocess.CalledProcessError:
+        print("wheel skipped:", c.split()[2 if c.startswith('pip') else -1])
 
-os.environ["TORCH_COMPILE_DISABLE"]="1"
-os.environ["TORCHINDUCTOR_DISABLE"]="1"
+# ── disable torch.compile: zerogpu lacks full cuda props ─────────────────────
+os.environ["TORCH_COMPILE_DISABLE"] = "1"
+os.environ["TORCHINDUCTOR_DISABLE"] = "1"
 import torch._dynamo; torch._dynamo.disable()
 
-device="cuda"
-MODEL_NAMES=["Zyphra/Zonos-v0.1-transformer","Zyphra/Zonos-v0.1-hybrid"]
-MODELS={n:Zonos.from_pretrained(n,device=device).eval() for n in MODEL_NAMES}
+device = "cuda"  # zerogpu maps this transparently
+MODEL_NAME = "Zyphra/Zonos-v0.1-transformer"  # hybrid commented out for now
 
-def _speaker_embed(model,audio):
+_cached_model: Zonos | None = None
+def get_model() -> Zonos:
+    global _cached_model
+    if _cached_model is None:
+        _cached_model = Zonos.from_pretrained(MODEL_NAME, device=device).eval()
+    return _cached_model
+
+def _speaker_embed(audio):
     if audio is None: return None
-    sr,wav=audio
-    if wav.dtype.kind in "iu": wav=wav.astype(np.float32)/np.iinfo(wav.dtype).max
-    wav=torch.from_numpy(wav).unsqueeze(0)
-    return model.make_speaker_embedding(wav,sr)
+    sr, wav = audio
+    if wav.dtype.kind in "iu": wav = wav.astype(np.float32) / np.iinfo(wav.dtype).max
+    wav = torch.from_numpy(wav).unsqueeze(0)
+    return get_model().make_speaker_embedding(wav, sr)
 
 @spaces.GPU
-def tts(model_choice,text,language,speaker_audio,
-        e1,e2,e3,e4,e5,e6,e7,e8,
-        speaking_rate,pitch_std):
-    m=MODELS[model_choice]
-    speaker=_speaker_embed(m,speaker_audio)
-    emotion=[e1,e2,e3,e4,e5,e6,e7,e8]
-    cond=make_cond_dict(text=text,language=language,speaker=speaker,
-                        emotion=emotion,speaking_rate=float(speaking_rate),
-                        pitch_std=float(pitch_std),device=device)
+def tts(text, language, speaker_audio,
+        e1,e2,e3,e4,e5,e6,e7,e8, speaking_rate, pitch_std):
+    m = get_model()
+    speaker = _speaker_embed(speaker_audio)
+    emotion = [e1,e2,e3,e4,e5,e6,e7,e8]
+    cond = make_cond_dict(
+        text=text, language=language, speaker=speaker, emotion=emotion,
+        speaking_rate=float(speaking_rate), pitch_std=float(pitch_std), device=device
+    )
     with torch.no_grad():
-        wav=m.autoencoder.decode(m.generate(m.prepare_conditioning(cond)))[0].cpu()
-    return (m.autoencoder.sampling_rate,wav.numpy())
+        codes = m.generate(m.prepare_conditioning(cond))
+        wav = m.autoencoder.decode(codes)[0].cpu()
+    return (m.autoencoder.sampling_rate, wav.numpy())
 
-langs=supported_language_codes
+langs = supported_language_codes  # from the library itself
+
 with gr.Blocks() as demo:
-    mc=gr.Dropdown(MODEL_NAMES,value=MODEL_NAMES[0],label="model")
-    txt=gr.Textbox(label="text")
-    lang=gr.Dropdown(langs,value="en-us",label="language")
-    spk=gr.Audio(type="numpy",label="speaker ref")
-    emos=[gr.Slider(0,1,0.3 if i==0 else 0.0,0.05,label=l) for i,l in enumerate(
-        ["happiness","sadness","disgust","fear","surprise","anger","other","neutral"])]
-    rate=gr.Slider(0,40,15,1,label="speaking_rate")
-    pitch=gr.Slider(0,400,20,1,label="pitch_std")
-    out=gr.Audio(label="output")
-    gr.Button("generate").click(fn=tts,
-        inputs=[mc,txt,lang,spk,*emos,rate,pitch],
-        outputs=out)
-if __name__=="__main__": demo.launch()
+    txt = gr.Textbox(label="text")
+    lng = gr.Dropdown(langs, value="en-us", label="language")
+    aud = gr.Audio(type="numpy", label="speaker ref (optional)")
+    emos = [gr.Slider(0,1,0.3 if i==0 else 0.0,0.05,label=l)
+            for i,l in enumerate(["happiness","sadness","disgust","fear",
+                                  "surprise","anger","other","neutral"])]
+    rate = gr.Slider(0,40,15,1,label="speaking_rate")
+    pitch = gr.Slider(0,400,20,1,label="pitch_std")
+    out = gr.Audio(label="output")
+    gr.Button("generate").click(tts,[txt,lng,aud,*emos,rate,pitch],out)
+
+if __name__ == "__main__":
+    demo.launch()
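
The wheel installs are now best-effort: a failed build prints a notice and the app keeps booting instead of crashing the Space. Since flash-attn, mamba-ssm, and causal-conv1d are all optional fast paths, a quick probe makes a silent failure visible in the startup logs; a minimal sketch using only the standard library, not part of the commit:

    import importlib.util

    # Import names of the three optional wheels installed above.
    for mod in ("flash_attn", "mamba_ssm", "causal_conv1d"):
        found = importlib.util.find_spec(mod) is not None
        print(f"{mod}: {'available' if found else 'missing, slower fallback path'}")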
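Model loading changed from eager (both checkpoints pulled at import time) to a lazy, cached single load, so the app no longer pays for a checkpoint nobody selected. The global-plus-None-check idiom amounts to memoizing a zero-argument loader; a sketch of the same design with functools, assuming the MODEL_NAME and device globals from the file above:

    from functools import lru_cache

    @lru_cache(maxsize=None)  # first call loads the model, later calls return the cached one
    def get_model() -> Zonos:
        return Zonos.from_pretrained(MODEL_NAME, device=device).eval()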
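_speaker_embed rescales integer PCM to floats (an int16 sample of 32767 maps to just under 1.0) and then adds a batch dimension, which assumes a mono array of shape (samples,). Gradio's numpy audio can also arrive as stereo (samples, channels), where unsqueeze(0) would produce a (1, samples, 2) tensor. A defensive downmix, offered as an assumption about the intended input shape rather than something the commit does:

    def _to_mono(wav: np.ndarray) -> np.ndarray:
        # Average the channel axis if Gradio handed back (samples, channels).
        if wav.ndim == 2:
            wav = wav.mean(axis=1)
        return wav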
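The generate button wires thirteen positional inputs (text, language, speaker ref, the eight emotion sliders, speaking_rate, pitch_std), so a programmatic caller must pass them in that exact order. A hedged sketch with gradio_client; the Space id and api_name are placeholders, not values from this commit:

    from gradio_client import Client

    client = Client("OWNER/SPACE")  # placeholder Space id
    result = client.predict(
        "Hello from Zonos.",                      # text
        "en-us",                                  # language
        None,                                     # speaker ref (optional)
        0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,  # happiness .. neutral
        15,                                       # speaking_rate
        20,                                       # pitch_std
        api_name="/predict",                      # assumed endpoint name
    )
    print(result)  # path to the rendered audio file saved by gradio_client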