import spaces
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from moshi.models import loaders, LMGen
import numpy as np
from tqdm import tqdm
MAX_LENGTH = 24000 * 5  # 5 seconds of audio at 24 kHz
mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
moshi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
@spaces.GPU
def compute_codes(wav):
    """Encode a waveform of shape [B, C=1, T] at 24 kHz into Mimi codes, frame by frame."""
    mimi = loaders.get_mimi(mimi_weight)
    mimi.set_num_codebooks(8)  # Mimi supports up to 32 codebooks, but Moshi uses 8.
    with torch.no_grad():
        # Encoding supports streaming, one frame at a time.
        frame_size = int(mimi.sample_rate / mimi.frame_rate)
        all_codes = []
        with mimi.streaming(batch_size=1):
            for offset in tqdm(range(0, wav.shape[-1], frame_size), desc="Computing codes"):
                frame = wav[:, :, offset: offset + frame_size]
                codes = mimi.encode(frame)
                if codes.shape[-1] == 1:
                    all_codes.append(codes)
                else:
                    # Skip frames that produced no code (e.g. a trailing partial frame).
                    print(f"Warning: empty codes for frame at offset {offset}")
    return all_codes
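
# A minimal usage sketch (kept as a comment so it does not run at import time),
# assuming a GPU context: Mimi encodes 24 kHz audio at a 12.5 Hz frame rate, so
# each frame is 1920 samples and 10 s of audio yields ~125 codes of shape [1, 8, 1].
#
#     wav = torch.zeros(1, 1, 24000 * 10)  # [B, C=1, T]
#     codes = compute_codes(wav)
#     print(len(codes), codes[0].shape)  # 125 torch.Size([1, 8, 1])
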
@spaces.GPU
def generate_response(all_codes):
    """Run Moshi over a list of Mimi codes and decode the generated audio on the fly."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Set up Mimi.
    mimi = loaders.get_mimi(mimi_weight, device='cpu')
    mimi.set_num_codebooks(8)  # Mimi supports up to 32 codebooks, but Moshi uses 8.
    mimi.to(device)
    # Set up Moshi / LMGen.
    moshi = loaders.get_moshi_lm(moshi_weight, device='cpu')
    moshi.to(device)  # Move to the GPU after loading.
    lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)  # LMGen handles sampling parameters etc.
    out_wav_chunks = []
    # Stream over Moshi I/O, decoding on the fly with Mimi.
    with torch.no_grad(), lm_gen.streaming(1), mimi.streaming(1):
        for code in tqdm(all_codes, desc="Generating tokens"):
            tokens_out = lm_gen.step(code.to(device))
            # tokens_out is [B, 1 + 8, 1]: tokens_out[:, 0] is the text token,
            # tokens_out[:, 1:] are the audio tokens.
            if tokens_out is not None:
                wav_chunk = mimi.decode(tokens_out[:, 1:])
                out_wav_chunks.append(wav_chunk)
    return torch.cat(out_wav_chunks, dim=-1)
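
# A minimal end-to-end sketch (comment only; assumes a CUDA GPU via @spaces.GPU):
# feed the Mimi codes from compute_codes into Moshi and get back a 24 kHz
# waveform. Moshi may stay silent for the first steps, so the output can be
# shorter than the input.
#
#     codes = compute_codes(torch.zeros(1, 1, 24000 * 10))
#     out = generate_response(codes)  # float tensor of shape [1, 1, T] at 24 kHz
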
def convert2wav(audio):
    if audio is None:
        return None
    sr, data = audio
    # Normalize integer PCM (Gradio microphone input is typically int16) to [-1, 1].
    if np.issubdtype(data.dtype, np.integer):
        data = data.astype(np.float32) / np.iinfo(data.dtype).max
    # Convert to mono if stereo.
    if data.ndim > 1:
        data = np.mean(data, axis=1)
    # Convert to a torch tensor of shape (1, 1, samples).
    wav = torch.from_numpy(data).float().unsqueeze(0).unsqueeze(0)
    # Resample to 24000 Hz if necessary (linear interpolation as a crude resampler).
    if sr != 24000:
        new_size = int(wav.shape[-1] * 24000 / sr)
        wav = torch.nn.functional.interpolate(wav, size=new_size, mode='linear', align_corners=False)
    # Cap the waveform at 10 seconds, i.e. at most (1, 1, 24000 * 10).
    wav = wav[:, :, :24000 * 10]
    return wav
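
# Usage sketch: gr.Audio hands over a (sample_rate, numpy_array) tuple, e.g.
# int16 samples at 48 kHz. convert2wav normalizes, mixes to mono, resamples to
# 24 kHz via linear interpolation (crude but dependency-free), and caps the
# result at 10 seconds.
#
#     audio = (48000, np.zeros(48000 * 3, dtype=np.int16))  # 3 s of silence
#     wav = convert2wav(audio)
#     print(wav.shape)  # torch.Size([1, 1, 72000]) -- 3 s at 24 kHz
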
def truncate_audio(wav, max_length):
    # Keep only the most recent max_length samples.
    if wav.shape[2] > max_length:
        return wav[:, :, -max_length:]
    return wav
##########################################################################################################
##########################################################################################################
def process_audio(audio, instream):
    # Fallback output: random noise, returned if generation fails.
    outwav = torch.randn(1, 1, 24000 * 2)
    stream = torch.randn(1, 1, 24000 * 2)
    print("Audio received")
    if audio is None:
        return gr.update(), (24000, outwav.squeeze().cpu().numpy()), instream, gr.update(visible=True, value="Audio is None")
    try:
        if instream is None:
            instream = (24000, torch.randn(1, 1, 24000 * 10).squeeze().cpu().numpy())
        print("1. COMBINE AUDIO WITH PREVIOUS CONVERSATION TO STORE")
        stream = (audio[0], np.concatenate((instream[1], audio[1])))
        # Assuming instream and audio are valid inputs for convert2wav.
        print("2. CONVERT AUDIO TO WAV")
        wav1 = convert2wav(instream)
        wav2 = convert2wav(audio)
        # Concatenate along the last dimension (time axis).
        print("3. COMBINE AUDIOS INTO A SINGLE STREAM")
        combined_wav = torch.cat((wav1, wav2), dim=2)
        # Truncate the audio to a fixed length to reduce computational cost.
        print("4. TRUNCATE AUDIO LENGTH TO GIVEN DURATION")
        combined_wav = truncate_audio(combined_wav, MAX_LENGTH)
        # Preprocessing: convert the audio into Mimi codes/tokens.
        print("5. COMPUTE CODES")
        mimi_codes = compute_codes(combined_wav)
        # Generate the model's response.
        print("6. GENERATE TOKENS")
        outwav = generate_response(mimi_codes)
    except Exception as e:
        return gr.update(value=None), (24000, outwav.squeeze().cpu().numpy()), stream, gr.update(visible=True, value=f"LOG:\n{e}")
    return gr.update(value=None), (24000, outwav.squeeze().cpu().numpy()), stream, gr.update(visible=False)
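
# The four return values of process_audio map onto the outputs wired up below:
# [input_audio, output_audio, stream, log_out]. The recorder is cleared, the
# response plays as a (sample_rate, samples) tuple, the running conversation is
# kept in gr.State, and the log textbox only becomes visible on errors.
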
with gr.Blocks() as demo:
    gr.Markdown("# Moshi Demo")
    gr.Markdown(" ")
    gr.Markdown("-----------")
    gr.Markdown("### Model Description")
    gr.Markdown("""Moshi is a speech-text foundation model that casts spoken dialogue as speech-to-speech generation. Starting from a text language model backbone, Moshi generates speech as tokens from the residual quantizer of a neural audio codec, while modeling its own speech and that of the user as separate parallel streams. This allows for the removal of explicit speaker turns and the modeling of arbitrary conversational dynamics.

Moshi also predicts time-aligned text tokens as a prefix to its audio tokens. This “Inner Monologue” method significantly improves the linguistic quality of generated speech and provides streaming speech recognition and text-to-speech. As a result, Moshi is the first real-time full-duplex spoken large language model, with a theoretical latency of 160 ms (200 ms in practice).
""")
    gr.Markdown("""
- **Developed by:** Kyutai
- **Model type:** Multimodal speech-text foundation model
- **Language(s) (NLP):** English
- **License:** CC-BY""")
    gr.Markdown("### Model Sources")
    gr.Markdown("""
- **Repository:** [repo](https://github.com/kyutai-labs/moshi)
- **Paper:** [paper](http://kyutai.org/Moshi.pdf)
- **Demo:** [demo](https://moshi.chat/)""")
gr.Markdown("""
🚨
The Model will produce a lot of silence, because it is actually meant to stream the input and output.
I will try to create a demo which works with the streaming.""")
    input_audio = gr.Audio(sources="microphone", label="Input Audio")
    output_audio = gr.Audio(label="Processed Audio", streaming=True, autoplay=True)
    stream = gr.State()
    log_out = gr.Textbox("Log", visible=False)
    input_audio.stop_recording(
        fn=process_audio,
        inputs=[input_audio, stream],
        outputs=[input_audio, output_audio, stream, log_out],
    )
    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            gr.Textbox(
                value=r"""@techreport{kyutai2024moshi,
    author = {Alexandre D\'efossez and Laurent Mazar\'e and Manu Orsini and Am\'elie Royer and Patrick P\'erez and Herv\'e J\'egou and Edouard Grave and Neil Zeghidour},
    title = {Moshi: a speech-text foundation model for real-time dialogue},
    institution = {Kyutai},
    year = {2024},
    month = {September},
    url = {http://kyutai.org/Moshi.pdf},
}
""",
                lines=7,
                label="Copy the BibTeX snippet to cite this source",
                elem_id="citation-button",
                show_copy_button=True,
            )
demo.launch(debug=True)