import os

# Set the compile-disabling env vars before torch is imported so every code
# path that reads them at import time sees them.
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCHINDUCTOR_DISABLE"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"
os.environ["TORCHDYNAMO_SUPPRESS_ERRORS"] = "True"

import spaces  # on ZeroGPU Spaces, import before torch so CUDA init is patched
import torch
import torchaudio
import gradio as gr

from zonos.model import Zonos
from zonos.conditioning import make_cond_dict, supported_language_codes

torch._dynamo.disable()


def _no_compile(fn=None, *args, **kwargs):
    # torch.compile is used both directly as torch.compile(fn) and as a
    # decorator factory torch.compile(mode=...); make both forms a no-op.
    return fn if fn is not None else (lambda f: f)


torch.compile = _no_compile

# Load the Zonos transformer once at startup, frozen and in eval mode.
# device="cuda" assumes a GPU runtime (e.g. a ZeroGPU Space).
device = "cuda"
MODEL_NAME = "Zyphra/Zonos-v0.1-transformer"
MODEL = Zonos.from_pretrained(MODEL_NAME, device=device).requires_grad_(False).eval()


def _patch_cuda_props():
    # Some inductor/triton code paths read device properties that may be
    # missing from the (possibly mocked) CUDA device on ZeroGPU. The fallback
    # values below are typical for recent NVIDIA GPUs.
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            p = torch.cuda.get_device_properties(i)
            if not hasattr(p, "regs_per_multiprocessor"):
                p.regs_per_multiprocessor = 65536
            if not hasattr(p, "max_threads_per_multi_processor"):
                p.max_threads_per_multi_processor = 2048


_patch_cuda_props()

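# @spaces.GPU requests a GPU slice for the duration of each call (ZeroGPU).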
@spaces.GPU
def generate_audio(
    text,
    language,
    speaker_audio,
    e1,
    e2,
    e3,
    e4,
    e5,
    e6,
    e7,
    e8,
    clarity,
    fmax,
    pitch_std,
    speaking_rate,
    dnsmos_ovrl,
    cfg_scale,
    min_p,
    steps,
    seed,
    randomize_seed,
    progress=gr.Progress(),
):
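    # Honor the UI seed unless randomization was requested; the seed actually
    # used is returned so the result can be reproduced.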
    if randomize_seed:
        seed = torch.randint(0, 2**32 - 1, (1,)).item()
    torch.manual_seed(int(seed))

    # Optional voice cloning: build a speaker embedding from the reference clip.
    speaker_embedding = None
    if speaker_audio is not None:
        wav, sr = torchaudio.load(speaker_audio)
        speaker_embedding = MODEL.make_speaker_embedding(wav, sr).to(
            device, dtype=torch.bfloat16
        )

    # Eight emotion weights (happy .. neutral) and an 8-dim VQ clarity score;
    # vqscore_8 is given a leading batch dimension.
    emotion_tensor = torch.tensor(
        [e1, e2, e3, e4, e5, e6, e7, e8], device=device, dtype=torch.float32
    )
    vq_tensor = torch.tensor(
        [clarity] * 8, device=device, dtype=torch.float32
    ).unsqueeze(0)

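    # Assemble the Zonos conditioning dictionary from text, language, speaker,
    # emotion, and prosody controls.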
    cond_dict = make_cond_dict(
        text=text,
        language=language,
        speaker=speaker_embedding,
        emotion=emotion_tensor,
        vqscore_8=vq_tensor,
        fmax=float(fmax),
        pitch_std=float(pitch_std),
        speaking_rate=float(speaking_rate),
        dnsmos_ovrl=float(dnsmos_ovrl),
        device=device,
    )
    conditioning = MODEL.prepare_conditioning(cond_dict)

    estimated_total_steps = int(steps)

    def cb(_frame, step, _total):
        # Called once per decoding step; returning True continues generation.
        progress((step, estimated_total_steps))
        return True

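    # Autoregressive generation with classifier-free guidance and min-p
    # sampling; "steps" caps the number of new tokens, so generation may
    # finish before the progress bar does.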
    codes = MODEL.generate(
        prefix_conditioning=conditioning,
        max_new_tokens=int(steps),
        cfg_scale=float(cfg_scale),
        batch_size=1,
        sampling_params=dict(min_p=float(min_p)),
        callback=cb,
    )

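    # Decode the audio codes to a waveform and keep only the first channel if
    # the autoencoder returns more than one.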
    wav_out = MODEL.autoencoder.decode(codes).cpu().detach()
    sr_out = MODEL.autoencoder.sampling_rate
    if wav_out.dim() == 2 and wav_out.size(0) > 1:
        wav_out = wav_out[0:1, :]
    return (sr_out, wav_out.squeeze().numpy()), seed


def build_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# ✨ zonos tts generator ✨")

        text = gr.Textbox(label="text", value="hello, world!", lines=4, max_length=500)
        language = gr.Dropdown(
            choices=supported_language_codes, value="en-us", label="language"
        )
        speaker_audio = gr.Audio(label="voice reference", type="filepath")

        clarity_slider = gr.Slider(0.5, 0.8, value=0.78, step=0.01, label="clarity")
        steps_slider = gr.Slider(1, 3000, value=300, step=1, label="steps")

        dnsmos_slider = gr.Slider(1.0, 5.0, value=4.0, step=0.1, label="quality")
        fmax_slider = gr.Slider(0, 24000, value=24000, step=1, label="fmax")
        pitch_std_slider = gr.Slider(0.0, 300.0, value=45.0, step=1, label="pitch std")
        speaking_rate_slider = gr.Slider(5.0, 30.0, value=15.0, step=0.5, label="rate")

        cfg_scale_slider = gr.Slider(1.0, 5.0, value=2.0, step=0.1, label="guidance")
        min_p_slider = gr.Slider(0.0, 1.0, value=0.15, step=0.01, label="min p")

        with gr.Row():
            e1 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="happy")
            e2 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="sad")
            e3 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="disgust")
            e4 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="fear")

        with gr.Row():
            e5 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="surprise")
            e6 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="anger")
            e7 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="other")
            e8 = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="neutral")

        seed_number = gr.Number(label="seed", value=420, precision=0)
        randomize_seed_toggle = gr.Checkbox(label="randomize seed", value=True)

        generate_button = gr.Button("generate")
        output_audio = gr.Audio(label="output", type="numpy", autoplay=True)

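        # Wire the button to the generator. The inputs list must stay in the
        # same order as generate_audio's positional parameters.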
        generate_button.click(
            fn=generate_audio,
            inputs=[
                text,
                language,
                speaker_audio,
                e1,
                e2,
                e3,
                e4,
                e5,
                e6,
                e7,
                e8,
                clarity_slider,
                fmax_slider,
                pitch_std_slider,
                speaking_rate_slider,
                dnsmos_slider,
                cfg_scale_slider,
                min_p_slider,
                steps_slider,
                seed_number,
                randomize_seed_toggle,
            ],
            outputs=[output_audio, seed_number],
        )
    return demo


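# On Hugging Face Spaces this file is typically saved as app.py and launched
# automatically; `python app.py` also works locally, assuming a CUDA GPU and
# the zonos package are available.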
if __name__ == "__main__":
    build_interface().launch()