from dataclasses import dataclass
from typing import List, Tuple

import torch
from huggingface_hub import hf_hub_download
from models import Model
from moshi.models import loaders
from tokenizers.processors import TemplateProcessing
from transformers import AutoTokenizer


@dataclass
class Segment:
    speaker: int
    text: str
    # 1-D audio waveform tensor of shape (num_samples,), at the generator's (Mimi codec) sample rate
    audio: torch.Tensor


def load_llama3_tokenizer():
    """
    Load the Llama-3.2 tokenizer and patch its post-processor so that BOS/EOS
    tokens are added around single and paired sequences.
    """
    tokenizer_name = "meta-llama/Llama-3.2-1B"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    bos = tokenizer.bos_token
    eos = tokenizer.eos_token
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=f"{bos}:0 $A:0 {eos}:0",
        pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
        special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
    )

    return tokenizer


class Generator:
    def __init__(self, model: Model):
        self._model = model
        self._model.setup_caches(1)  # KV caches sized for batch size 1

        self._text_tokenizer = load_llama3_tokenizer()

        device = next(model.parameters()).device
        # Mimi neural audio codec with 32 codebooks per frame, used to tokenize and detokenize audio.
        mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
        mimi = loaders.get_mimi(mimi_weight, device=device)
        mimi.set_num_codebooks(32)
        self._audio_tokenizer = mimi

        self.sample_rate = mimi.sample_rate
        self.device = device

    def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Each frame has 33 columns: 32 audio codebooks followed by one text column.
        # Text tokens occupy only the last column, and the mask marks only that column as valid.
        text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
        text_frame = torch.zeros(len(text_tokens), 33).long()
        text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
        text_frame[:, -1] = torch.tensor(text_tokens)
        text_frame_mask[:, -1] = True

        return text_frame.to(self.device), text_frame_mask.to(self.device)

    def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        assert audio.ndim == 1
        audio = audio.to(self.device)
        # Encode the waveform into Mimi codes of shape (num_codebooks=32, num_frames).
        audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]

        # Append an all-zero EOS frame to mark the end of the audio segment.
        eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
        audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)

        # Audio codes fill the first 32 columns; the text column stays masked out.
        audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
        audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
        audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
        audio_frame_mask[:, :-1] = True

        return audio_frame, audio_frame_mask

    def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
        text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
        audio_tokens, audio_masks = self._tokenize_audio(segment.audio)

        return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)

    @torch.inference_mode()
    def generate(
        self,
        text: str,
        speaker: int,
        context: List[Segment],
        max_audio_length_ms: float = 90_000,
        temperature: float = 0.9,
        topk: int = 50,
        top_p: float = 1.0,
    ) -> torch.Tensor:
        self._model.reset_caches()

        # Each generated frame corresponds to 80 ms of audio, so the frame budget is ms / 80.
        max_generation_len = int(max_audio_length_ms / 80)

        # Build the prompt: tokenized context segments followed by the text segment to synthesize.
        tokens, tokens_mask = [], []
        for segment in context:
            segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
            tokens.append(segment_tokens)
            tokens_mask.append(segment_tokens_mask)

        gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
        tokens.append(gen_segment_tokens)
        tokens_mask.append(gen_segment_tokens_mask)

        prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
        prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)

        samples = []
        curr_tokens = prompt_tokens.unsqueeze(0)
        curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
        curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)

        max_seq_len = 2048
        max_context_len = max_seq_len - max_generation_len
        if curr_tokens.size(1) >= max_context_len:
            raise ValueError("Input too long")

        for _ in range(max_generation_len):
            sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk, top_p)
            # An all-zero frame signals the end of the audio.
            if torch.all(sample == 0):
                break

            samples.append(sample)

            # Feed the sampled audio frame (plus an empty text column) back as the next input.
            curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
            curr_tokens_mask = torch.cat(
                [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
            ).unsqueeze(1)
            curr_pos = curr_pos[:, -1:] + 1

        # Stack sampled frames into (batch, num_codebooks, num_frames) and decode to a waveform.
        return self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)


def load_csm_1b(device: str = "cuda") -> Generator:
    model = Model.from_pretrained("xlr8harder/csm-1b-tahm-kench")
    model.to(device=device, dtype=torch.bfloat16)

    return Generator(model)
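

# Usage sketch: a minimal example of driving the Generator end to end. It assumes
# a CUDA device is available and that torchaudio is installed for writing the
# output waveform; the prompt text and speaker id below are placeholders.
if __name__ == "__main__":
    import torchaudio

    generator = load_csm_1b(device="cuda")
    audio = generator.generate(
        text="Hello from the CSM generator.",
        speaker=0,
        context=[],  # optionally pass prior Segment objects as conversational context
        max_audio_length_ms=10_000,
    )
    # `audio` is a 1-D waveform tensor at generator.sample_rate.
    torchaudio.save("output.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)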