Staticaliza committed on
Commit a00a4e7 · verified · 1 Parent(s): 767aa72

Update app.py

Files changed (1)
  1. app.py +33 -34
app.py CHANGED
@@ -1,58 +1,57 @@
-import os, shlex, subprocess, torch
+import os, shlex, subprocess, torch, numpy as np, gradio as gr, torchaudio, spaces
+from zonos.model import Zonos
+from zonos.conditioning import make_cond_dict, supported_language_codes
 
-# extra wheels (safe to skip if they fail)
+# optional speed-up wheels, silently skip on failure
 for cmd, env in [
-    ("pip install flash-attn --no-build-isolation", {"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"}),
+    ("pip install flash-attn --no-build-isolation", {"FLASH_ATTENTION_SKIP_CUDA_BUILD":"TRUE"}),
     ("pip install https://github.com/state-spaces/mamba/releases/download/v2.2.4/mamba_ssm-2.2.4+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl", {}),
     ("pip install https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.5.0.post8/causal_conv1d-1.5.0.post8+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl", {}),
 ]:
-    try: subprocess.run(shlex.split(cmd), env=os.environ | env, check=True)
+    try: subprocess.run(shlex.split(cmd), env=os.environ|env, check=True)
     except subprocess.CalledProcessError: pass
 
-# hard-nuke torch.compile everywhere
-os.environ["TORCH_COMPILE_DISABLE"]="1"
-os.environ["TORCHINDUCTOR_DISABLE"]="1"
+os.environ["TORCH_COMPILE_DISABLE"]=os.environ["TORCHINDUCTOR_DISABLE"]="1"
 torch._dynamo.disable()
-torch.compile=lambda fn,*a,**k:fn
-
-import torchaudio, gradio as gr, spaces, numpy as np
-from zonos.model import Zonos
-from zonos.conditioning import make_cond_dict, supported_language_codes
+torch.compile=lambda f,*a,**k:f
 
 device="cuda"
-MODEL_NAMES=["Zyphra/Zonos-v0.1-transformer","Zyphra/Zonos-v0.1-hybrid"]
-MODELS={n:Zonos.from_pretrained(n,device=device).eval() for n in MODEL_NAMES}
+model=Zonos.from_pretrained("Zyphra/Zonos-v0.1-transformer",device=device).eval()
 
-def _spk(model,aud):
+def _spk(aud):
     if aud is None: return None
     sr,wav=aud
     if wav.dtype.kind in "iu": wav=wav.astype(np.float32)/np.iinfo(wav.dtype).max
     return model.make_speaker_embedding(torch.from_numpy(wav).unsqueeze(0),sr)
 
-@spaces.GPU(duration=120)
-def tts(m,text,lang,speaker,
-        h,sad,disg,fear,sur,ang,oth,neu,
-        speak,pitch):
-    model=MODELS[m]
-    emotion=[h,sad,disg,fear,sur,ang,oth,neu]
-    cond=make_cond_dict(text=text,language=lang,speaker=_spk(model,speaker),
-                        emotion=emotion,speaking_rate=float(speak),
-                        pitch_std=float(pitch),device=device)
+@spaces.GPU
+def tts(text,lang,speaker,vq,fmax,pitch,rate,dnsmos):
+    cond=make_cond_dict(
+        text=text,
+        language=lang,
+        speaker=_spk(speaker),
+        vqscore_8=torch.tensor([vq]*8,device=device).unsqueeze(0),
+        fmax=float(fmax),
+        pitch_std=float(pitch),
+        speaking_rate=float(rate),
+        dnsmos_ovrl=float(dnsmos),
+        device=device,
+    )
     with torch.no_grad():
-        codes=model.generate(model.prepare_conditioning(cond))
-        wav=model.autoencoder.decode(codes)[0].cpu()
-    return (model.autoencoder.sampling_rate,wav.numpy())
+        wav=model.autoencoder.decode(model.generate(model.prepare_conditioning(cond)))[0].cpu()
+    out=(wav.clip(-1,1)*32767).short().numpy() # int16 fix
+    return (model.autoencoder.sampling_rate,out)
 
 langs=supported_language_codes
 with gr.Blocks() as demo:
-    mc=gr.Dropdown(MODEL_NAMES,value=MODEL_NAMES[0],label="model")
     txt=gr.Textbox(label="text")
     lng=gr.Dropdown(langs,value="en-us",label="language")
-    spk=gr.Audio(type="numpy",label="speaker ref")
-    emos=[gr.Slider(0,1,0.3 if i==0 else 0.0,0.05,label=l) for i,l in
-          enumerate(["happiness","sad","disgust","fear","surprise","anger","other","neutral"])]
-    rate=gr.Slider(0,40,15,1,label="speaking_rate")
-    pit=gr.Slider(0,400,20,1,label="pitch_std")
+    spk=gr.Audio(type="numpy",label="speaker ref (optional)")
+    vq =gr.Slider(0.5,0.9,0.78,0.01,label="clarity (vq)")
+    fmx=gr.Slider(8000,24000,24000,100,label="fmax hz")
+    pit=gr.Slider(0,300,20,1,label="pitch std")
+    rte=gr.Slider(5,30,15,0.5,label="speaking rate")
+    dns=gr.Slider(1,5,4,0.1,label="quality target")
    out=gr.Audio(label="output")
-    gr.Button("generate").click(tts,[mc,txt,lng,spk,*emos,rate,pit],out)
+    gr.Button("generate").click(tts,[txt,lng,spk,vq,fmx,pit,rte,dns],out)
 if __name__=="__main__": demo.launch()
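
Aside on the one functional fix in this commit, the "# int16 fix" on the output path: Gradio's type="numpy" audio pathway is dtype-sensitive, and some versions rescale or clip raw float output unpredictably, so the new code clamps samples to [-1, 1] and converts to full-scale int16 PCM itself before handing the array back. A minimal standalone sketch of that conversion, assuming only torch and numpy (the helper name to_pcm16 and the random stand-in for the decoder output are ours, not part of the commit):

import numpy as np
import torch

def to_pcm16(wav: torch.Tensor) -> np.ndarray:
    # clamp to the valid float range, then scale to full-scale 16-bit PCM
    return (wav.clip(-1, 1) * 32767).short().numpy()

wav = torch.randn(1, 24000) * 1.2       # stand-in decoder output, deliberately out of range
pcm = to_pcm16(wav)
print(pcm.dtype, pcm.min(), pcm.max())  # int16, magnitudes capped at 32767

The new vqscore_8 conditioning follows the same spirit: [vq]*8 expands a single slider value into a (1, 8) tensor, presumably one score per slot that make_cond_dict expects; the exact semantics are defined by the Zonos API rather than by anything visible in this diff.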