import sys

# Make the locally installed pyharp package importable.
sys.path.append('/home/user/audio_ai/diffusers_harp/venv/src')

from pyharp import ModelCard, build_endpoint, save_and_return_filepath

from audiotools import AudioSignal
import scipy.io.wavfile
import torch
import gradio as gr
from diffusers import AudioLDM2Pipeline
import subprocess as sp


# Optional dependency bootstrap (disabled): installs pyharp and its
# dependencies at runtime if the import above fails.
#
# harp_deps = ["descript-audiotools"]
#
# try:
#     from pyharp import ModelCard, build_endpoint, save_and_return_filepath
# except ImportError:
#     print("Installing harp dependencies...")
#     sp.check_call(["pip", "install", *harp_deps])
#     sp.check_call(["pip", "install", "-e", "git+https://github.com/audacitorch/pyharp.git#egg=pyharp"])
#     sp.check_call(["pip", "install", "pydantic<2.0.0"])
#     from pyharp import ModelCard, build_endpoint, save_and_return_filepath

# Create a Model Card
card = ModelCard(
    name='Diffusers AudioLDM2 Style Transfer',
    description='AudioLDM2 style transfer, operates on region selected in track.',
    author='Team Audio',
    tags=['AudioLDM', 'Diffusers', 'Style Transfer']
)

# Load the model
repo_id = "cvssp/audioldm2"
pipe = AudioLDM2Pipeline.from_pretrained(repo_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")



def process_fn(input_audio_path, prompt, negative_prompt, seed, num_inference_steps, audio_length_in_s, num_waveforms_per_prompt):
    """
    This function defines the audio processing steps

    Args:
        input_audio_path (str): the audio filepath to be processed.
        prompt (str): text prompt describing the audio to generate.
        negative_prompt (str): text prompt describing qualities to avoid.
        seed (int): seed for the random generator, for reproducibility.
        num_inference_steps (int): number of denoising steps.
        audio_length_in_s (float): duration of the generated audio, in seconds.
        num_waveforms_per_prompt (int): number of candidate waveforms to generate.
            NOTE: These arguments must correspond to and match the order of the
            UI elements defined below.

    Returns:
        output_audio_path (str): the filepath of the processed audio.
    """

    # Load the selected region. NOTE: currently unused; AudioLDM2Pipeline is
    # text-to-audio, so the call below generates from the prompt alone.
    sig = AudioSignal(input_audio_path)
    outfile = "./output.wav"

    # Seed the generator so results are reproducible.
    generator = torch.Generator("cuda").manual_seed(int(seed))

    # Gradio sliders may deliver floats, so cast the integer-valued
    # parameters explicitly before handing them to the pipeline.
    audio = pipe(
        prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=int(num_inference_steps),
        audio_length_in_s=audio_length_in_s,
        num_waveforms_per_prompt=int(num_waveforms_per_prompt),
        generator=generator,
    ).audios

    # AudioLDM2 generates 16 kHz audio. Cast defensively to float32, since
    # scipy's WAV writer does not accept float16, and write the first waveform.
    scipy.io.wavfile.write(outfile, rate=16000, data=audio[0].astype("float32"))
    return outfile
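
# A minimal smoke test of process_fn outside of HARP might look like this
# (hypothetical path and values; assumes a CUDA machine and an existing
# input file):
#
#     out = process_fn("input.wav", "The sound of rain on a tin roof.",
#                      "Low quality.", seed=0, num_inference_steps=50,
#                      audio_length_in_s=5.0, num_waveforms_per_prompt=1)
#     print(out)  # ./output.wav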


# Build the endpoint
with gr.Blocks() as webapp:
    # Define your Gradio interface
    inputs = [
        gr.Audio(
            label="Audio Input", 
            type="filepath"
        ), 
        gr.Text(
            label="Prompt", 
            interactive=True
        ),
        gr.Text(
            label="Negative Prompt", 
            interactive=True
        ),
        gr.Slider(
            label="Seed",
            minimum=0,
            maximum=65535,
            value=0,
            step=1
        ),
        gr.Slider(
            minimum=1, maximum=500, 
            step=1, value=1, 
            label="Inference Steps"
        ),
        gr.Slider(
            minimum=2.5, maximum=10.0,
            step=2.5, value=2.5,
            label="Duration (s)"
        ),
        gr.Slider(
            minimum=1, maximum=10, 
            step=1, value=1, 
            label="Waveforms Per Prompt"
        ),
    ]
    
    # make an output audio widget
    output = gr.Audio(label="Audio Output", type="filepath")

    # Build the endpoint
    ctrls_data, ctrls_button, process_button, cancel_button = build_endpoint(inputs, output, process_fn, card)

#webapp.queue()
webapp.launch(share=True)  # share=True exposes a public Gradio link for HARP to connect to