import spaces
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from moshi.models import loaders, LMGen
import numpy as np


# Download the Mimi (audio codec) and Moshi (LM) checkpoints once at startup.
mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)


@spaces.GPU
def compute_codes(wav):
    """wav = torch.randn(1, 1, 24000 * 10)  # should be [B, C=1, T]"""
    mimi = loaders.get_mimi(mimi_weight)
    mimi.set_num_codebooks(8)  # up to 32 for mimi, but limited to 8 for moshi.

    with torch.no_grad():
        # Supports streaming too.
        frame_size = int(mimi.sample_rate / mimi.frame_rate)
        all_codes = []
        with mimi.streaming(batch_size=1):
            for offset in range(0, wav.shape[-1], frame_size):
                frame = wav[:, :, offset: offset + frame_size]
                codes = mimi.encode(frame)
                assert codes.shape[-1] == 1, codes.shape
                all_codes.append(codes)
    return all_codes
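
# A quick shape sanity-check for compute_codes (a sketch; the numbers assume
# Mimi's published 24 kHz sample rate and 12.5 Hz frame rate, i.e. a frame
# size of 24000 / 12.5 = 1920 samples):
#
#     wav = torch.zeros(1, 1, 24000 * 10)            # 10 s of silence, [B, C=1, T]
#     codes = compute_codes(wav)                     # list of 240000 / 1920 = 125 frames
#     assert all(c.shape == (1, 8, 1) for c in codes)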


@spaces.GPU
def generate_response(all_codes):
    """Run Moshi over a list of Mimi code frames and decode its reply to a waveform."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set up Mimi (load on CPU, then move to the target device).
    mimi = loaders.get_mimi(mimi_weight, device='cpu')
    mimi.set_num_codebooks(8)  # up to 32 for mimi, but limited to 8 for moshi.
    mimi.to(device)

    # Set up Moshi/LM Gen.
    moshi = loaders.get_moshi_lm(moshi_weight, device='cpu')
    moshi.to(device)  # Move to GPU after loading.
    lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)  # this handles sampling params etc.

    out_wav_chunks = []

    # Stream over Moshi's input and output, decoding on the fly with Mimi.
    with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
        for idx, code in enumerate(all_codes):
            tokens_out = lm_gen.step(code.to(device))
            # tokens_out is [B, 1 + 8, 1]; tokens_out[:, 0] is the text token,
            # tokens_out[:, 1:] are the 8 audio codebooks.
            if tokens_out is not None:
                wav_chunk = mimi.decode(tokens_out[:, 1:])
                out_wav_chunks.append(wav_chunk)
            print(idx, end='\r')

    # Move back to CPU before the function returns and the GPU is released.
    return torch.cat(out_wav_chunks, dim=-1).cpu()
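
# Illustrative single step through the loop above (a sketch; shapes follow the
# comments in generate_response):
#
#     code = all_codes[0]                # [1, 8, 1]: one 80 ms frame of Mimi codes
#     tokens_out = lm_gen.step(code)     # None while Moshi warms up, then [1, 9, 1]
#     text_token = tokens_out[:, :1]     # the "Inner Monologue" text stream
#     audio_codes = tokens_out[:, 1:]    # 8 audio codebooks, fed to mimi.decode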

def convert2wav(audio):
    if audio is None:
        return None

    sr, data = audio

    # Gradio's numpy audio is typically int16 PCM; normalize to [-1, 1] floats
    # before feeding the codec.
    if data.dtype == np.int16:
        data = data.astype(np.float32) / 32768.0

    # Convert to mono if stereo.
    if data.ndim > 1:
        data = np.mean(data, axis=1)

    # Convert to a torch tensor shaped (1, 1, samples).
    wav = torch.from_numpy(data).float().unsqueeze(0).unsqueeze(0)

    # Resample to 24000 Hz if necessary, preserving duration. Linear
    # interpolation is a crude resampler; a proper filter (e.g. julius or
    # torchaudio) would sound better.
    if sr != 24000:
        new_len = int(round(wav.shape[-1] * 24000 / sr))
        wav = torch.nn.functional.interpolate(wav, size=new_len, mode='linear', align_corners=False)

    # Pad or truncate to exactly 10 s at 24 kHz, i.e. (1, 1, 24000 * 10).
    target = 24000 * 10
    if wav.shape[-1] < target:
        wav = torch.nn.functional.pad(wav, (0, target - wav.shape[-1]))
    return wav[:, :, :target]
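
# Worked example with hypothetical numbers: a 5 s mono clip recorded at 48 kHz.
#
#     audio = (48000, np.zeros(48000 * 5, dtype=np.int16))
#     wav = convert2wav(audio)           # resampled to 5 s @ 24 kHz, zero-padded to 10 s
#     assert wav.shape == (1, 1, 24000 * 10)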



##########################################################################################################
##########################################################################################################

def process_audio(audio, instream):
    outwav = torch.zeros(1, 1, 24000 * 10)  # fallback output: 10 s of silence
    stream = instream

    print("Audio received")
    if audio is None:
        return gr.update(), gr.update(), instream, gr.update(visible=False)

    try:
        if instream is None:
            # Prime the conversation buffer with 10 s of silence, matching
            # Gradio's (sample_rate, int16 ndarray) audio format.
            instream = (24000, np.zeros(24000 * 10, dtype=np.int16))
        print("STREAM RECEIVED")
        stream = (audio[0], np.concatenate((instream[1], audio[1])))

        # instream and audio are (sample_rate, ndarray) pairs, as convert2wav expects.
        wav1 = convert2wav(instream)
        wav2 = convert2wav(audio)

        # Concatenate along the last dimension (time axis).
        combined_wav = torch.cat((wav1, wav2), dim=2)
        print("WAV COMBINED")

        mimi_codes = compute_codes(combined_wav)
        print("CODES COMPUTED")
        outwav = generate_response(mimi_codes)
    except Exception as e:
        return gr.update(value=None), (24000, outwav.squeeze().cpu().numpy()), stream, gr.update(visible=True, value=f"LOG: {e}")

    return gr.update(value=None), (24000, outwav.squeeze().cpu().numpy()), stream, gr.update(visible=False)


with gr.Blocks() as demo:
    gr.Markdown("# Moshi Demo")
    gr.Markdown(" ")
    gr.Markdown("-----------")

    gr.Markdown("### Model Description")
    gr.Markdown("""Moshi is a speech-text foundation model that casts spoken dialogue as speech-to-speech generation. Starting from a text language model backbone, Moshi generates speech as tokens from the residual quantizer of a neural audio codec, while modeling separately its own speech and that of the user into parallel streams. This allows for the removal of explicit speaker turns, and the modeling of arbitrary conversational dynamics.
Moshi also predicts time-aligned text tokens as a prefix to audio tokens. This “Inner
Monologue” method significantly improves the linguistic quality of generated speech and provides streaming speech recognition and text-to-speech. As a result, Moshi is the first real-time full-duplex spoken large language model, with a theoretical latency of 160ms, 200ms in practice. 
""")
    gr.Markdown("""
                - **Developed by:**  Kyutai
                - **Model type:** Multimodal speech-text foundation model
                - **Language(s) (NLP):** English
                - **License:** CC-BY""")

    gr.Markdown("### Model Sources ")
    gr.Markdown("""
                - **Repository:** [repo](https://github.com/kyutai-labs/moshi)
                - **Paper:** [paper](http://kyutai.org/Moshi.pdf)
                - **Demo:** [demo](https://moshi.chat/) """)


    gr.Markdown("""
                🚨 
                The Model will produce a lot of silence, because it is actually meant to stream the input and output.
                I will try to create a demo which works with the streaming.""")
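    # A true-streaming variant would mark the microphone input as streaming and
    # react to chunks as they arrive instead of waiting for stop_recording. A
    # rough sketch (assuming Gradio's Audio.stream event; not wired up here):
    #
    #     mic = gr.Audio(sources="microphone", streaming=True)
    #     mic.stream(fn=process_audio, inputs=[mic, stream],
    #                outputs=[mic, output_audio, stream, log_out])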

    input_audio = gr.Audio(sources="microphone", label="Input Audio")
    output_audio = gr.Audio(label="Processed Audio", streaming=True, autoplay=True)
    stream = gr.State()

    log_out = gr.Textbox("Log", visible=False)

    input_audio.stop_recording(
        fn=process_audio,
        inputs=[input_audio, stream],
        outputs=[input_audio, output_audio, stream, log_out]
    )

    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Textbox(
                value="""@techreport{kyutai2024moshi,
    author = {Alexandre D\'efossez and Laurent Mazar\'e and Manu Orsini and Am\'elie Royer and Patrick P\'erez and Herv\'e J\'egou and Edouard Grave and Neil Zeghidour},
    title = {Moshi: a speech-text foundation model for real-time dialogue},
    institution = {Kyutai},
    year={2024},
    month={September},
    url={http://kyutai.org/Moshi.pdf},
}
""", lines=7,
                label="Copy the BibTeX snippet to cite this source",
                elem_id="citation-button",
                show_copy_button=True,
            )

    
demo.launch(debug=True)