preston-cell committed on
Commit 4c9d528 · verified · 1 parent: 8a51ef5

Create generator.py

Files changed (1)
  1. generator.py +176 -0
generator.py ADDED
@@ -0,0 +1,176 @@
+ from dataclasses import dataclass
+ from typing import List, Tuple
+
+ import torch
+ import torchaudio
+ from huggingface_hub import hf_hub_download
+ from models import Model
+ from moshi.models import loaders
+ from tokenizers.processors import TemplateProcessing
+ from transformers import AutoTokenizer
+ from watermarking import CSM_1B_GH_WATERMARK, load_watermarker, watermark
+
+
+ @dataclass
+ class Segment:
+     speaker: int
+     text: str
+     # (num_samples,), sample_rate = 24_000
+     audio: torch.Tensor
+
+
+ def load_llama3_tokenizer():
+     """
+     https://github.com/huggingface/transformers/issues/22794#issuecomment-2092623992
+     """
+     tokenizer_name = "meta-llama/Llama-3.2-1B"
+     tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+     bos = tokenizer.bos_token
+     eos = tokenizer.eos_token
+     tokenizer._tokenizer.post_processor = TemplateProcessing(
+         single=f"{bos}:0 $A:0 {eos}:0",
+         pair=f"{bos}:0 $A:0 {eos}:0 {bos}:1 $B:1 {eos}:1",
+         special_tokens=[(f"{bos}", tokenizer.bos_token_id), (f"{eos}", tokenizer.eos_token_id)],
+     )
+
+     return tokenizer
+
+
+ class Generator:
+     def __init__(
+         self,
+         model: Model,
+     ):
+         self._model = model
+         self._model.setup_caches(1)
+
+         self._text_tokenizer = load_llama3_tokenizer()
+
+         device = next(model.parameters()).device
+         mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
+         mimi = loaders.get_mimi(mimi_weight, device=device)
+         mimi.set_num_codebooks(32)
+         self._audio_tokenizer = mimi
+
+         self._watermarker = load_watermarker(device=device)
+
+         self.sample_rate = mimi.sample_rate
+         self.device = device
+
+     def _tokenize_text_segment(self, text: str, speaker: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         frame_tokens = []
+         frame_masks = []
+
+         text_tokens = self._text_tokenizer.encode(f"[{speaker}]{text}")
+         text_frame = torch.zeros(len(text_tokens), 33).long()
+         text_frame_mask = torch.zeros(len(text_tokens), 33).bool()
+         text_frame[:, -1] = torch.tensor(text_tokens)
+         text_frame_mask[:, -1] = True
+
+         frame_tokens.append(text_frame.to(self.device))
+         frame_masks.append(text_frame_mask.to(self.device))
+
+         return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
+
+     def _tokenize_audio(self, audio: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         assert audio.ndim == 1, "Audio must be single channel"
+
+         frame_tokens = []
+         frame_masks = []
+
+         # (K, T)
+         audio = audio.to(self.device)
+         audio_tokens = self._audio_tokenizer.encode(audio.unsqueeze(0).unsqueeze(0))[0]
+         # add EOS frame
+         eos_frame = torch.zeros(audio_tokens.size(0), 1).to(self.device)
+         audio_tokens = torch.cat([audio_tokens, eos_frame], dim=1)
+
+         audio_frame = torch.zeros(audio_tokens.size(1), 33).long().to(self.device)
+         audio_frame_mask = torch.zeros(audio_tokens.size(1), 33).bool().to(self.device)
+         audio_frame[:, :-1] = audio_tokens.transpose(0, 1)
+         audio_frame_mask[:, :-1] = True
+
+         frame_tokens.append(audio_frame)
+         frame_masks.append(audio_frame_mask)
+
+         return torch.cat(frame_tokens, dim=0), torch.cat(frame_masks, dim=0)
+
+     def _tokenize_segment(self, segment: Segment) -> Tuple[torch.Tensor, torch.Tensor]:
+         """
+         Returns:
+             (seq_len, 33), (seq_len, 33)
+         """
+         text_tokens, text_masks = self._tokenize_text_segment(segment.text, segment.speaker)
+         audio_tokens, audio_masks = self._tokenize_audio(segment.audio)
+
+         return torch.cat([text_tokens, audio_tokens], dim=0), torch.cat([text_masks, audio_masks], dim=0)
+
+     @torch.inference_mode()
+     def generate(
+         self,
+         text: str,
+         speaker: int,
+         context: List[Segment],
+         max_audio_length_ms: float = 90_000,
+         temperature: float = 0.9,
+         topk: int = 50,
+     ) -> torch.Tensor:
+         self._model.reset_caches()
+
+         max_generation_len = int(max_audio_length_ms / 80)
+         tokens, tokens_mask = [], []
+         for segment in context:
+             segment_tokens, segment_tokens_mask = self._tokenize_segment(segment)
+             tokens.append(segment_tokens)
+             tokens_mask.append(segment_tokens_mask)
+
+         gen_segment_tokens, gen_segment_tokens_mask = self._tokenize_text_segment(text, speaker)
+         tokens.append(gen_segment_tokens)
+         tokens_mask.append(gen_segment_tokens_mask)
+
+         prompt_tokens = torch.cat(tokens, dim=0).long().to(self.device)
+         prompt_tokens_mask = torch.cat(tokens_mask, dim=0).bool().to(self.device)
+
+         samples = []
+         curr_tokens = prompt_tokens.unsqueeze(0)
+         curr_tokens_mask = prompt_tokens_mask.unsqueeze(0)
+         curr_pos = torch.arange(0, prompt_tokens.size(0)).unsqueeze(0).long().to(self.device)
+
+         max_seq_len = 2048
+         max_context_len = max_seq_len - max_generation_len
+         if curr_tokens.size(1) >= max_context_len:
+             raise ValueError(
+                 f"Inputs too long, must be below max_seq_len - max_generation_len: {max_context_len}"
+             )
+
+         for _ in range(max_generation_len):
+             sample = self._model.generate_frame(curr_tokens, curr_tokens_mask, curr_pos, temperature, topk)
+             if torch.all(sample == 0):
+                 break  # eos
+
+             samples.append(sample)
+
+             curr_tokens = torch.cat([sample, torch.zeros(1, 1).long().to(self.device)], dim=1).unsqueeze(1)
+             curr_tokens_mask = torch.cat(
+                 [torch.ones_like(sample).bool(), torch.zeros(1, 1).bool().to(self.device)], dim=1
+             ).unsqueeze(1)
+             curr_pos = curr_pos[:, -1:] + 1
+
+         audio = self._audio_tokenizer.decode(torch.stack(samples).permute(1, 2, 0)).squeeze(0).squeeze(0)
+
+         # This applies an imperceptible watermark to identify audio as AI-generated.
+         # Watermarking ensures transparency, dissuades misuse, and enables traceability.
+         # Please be a responsible AI citizen and keep the watermarking in place.
+         # If using CSM 1B in another application, use your own private key and keep it secret.
+         audio, wm_sample_rate = watermark(self._watermarker, audio, self.sample_rate, CSM_1B_GH_WATERMARK)
+         audio = torchaudio.functional.resample(audio, orig_freq=wm_sample_rate, new_freq=self.sample_rate)
+
+         return audio
+
+
+ def load_csm_1b(device: str = "cuda") -> Generator:
+     model = Model.from_pretrained("sesame/csm-1b")
+     model.to(device=device, dtype=torch.bfloat16)
+
+     generator = Generator(model)
+     return generator
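
For reference, a minimal usage sketch of the API this file adds (not part of the commit; the prompt text, output filename, and empty context list are illustrative, and a CUDA device plus access to the gated Llama tokenizer and model weights are assumed):

    import torchaudio
    from generator import Segment, load_csm_1b

    # load_csm_1b builds the Generator around the pretrained CSM-1B model
    generator = load_csm_1b(device="cuda")

    # context is a list of prior Segment(speaker, text, audio) turns; it may be empty
    audio = generator.generate(
        text="Hello from CSM.",
        speaker=0,
        context=[],
        max_audio_length_ms=10_000,
    )

    # generate() returns a mono waveform at generator.sample_rate (Mimi's rate)
    torchaudio.save("output.wav", audio.unsqueeze(0).cpu(), generator.sample_rate)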