XXXXRT666 committed
Commit 7619997 · 1 Parent(s): 5cfeca6
AR/models/t2s_model_abc.py CHANGED
@@ -5,10 +5,10 @@ Modified From https://github.com/XXXXRT666/GPT-SoVITS
 from __future__ import annotations
 
 import os
+import time
 from abc import ABC, abstractmethod
 from contextlib import nullcontext
 from typing import Any, Dict, List, MutableSequence, Optional, Tuple, Type
-import time
 
 import torch
 import torch._inductor.config
@@ -30,138 +30,6 @@ class Sampler(nn.Module):
         super().__init__()
         self.batch_size = batch_size
 
-        self.logits: Tensor
-        self.samples: Tensor
-        self.register_buffer("logits", torch.zeros((batch_size, vocab_size)), persistent=False)
-        self.register_buffer("samples", torch.zeros((batch_size,), dtype=torch.int32), persistent=False)
-
-        self.__CUDAGraph: Optional[CUDAGraph] = None
-
-
-    def empty_cache(self):
-        self.logits.zero_()
-        self.__CUDAGraph = None
-
-    @staticmethod
-    def multinomial_sample_one_no_sync(probs_sort: Tensor):  # Does multinomial sampling without a cuda synchronization
-        q = torch.empty_like(probs_sort).exponential_(1)
-        return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int32)
-
-    @staticmethod
-    def logits_to_probs(
-        logits: Tensor,
-        previous_tokens: Tensor,
-        temperature: float,
-        top_k: int,
-        top_p: float,
-        repetition_penalty: float,
-    ):
-        previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=1, index=previous_tokens)
-        score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
-        logits.scatter_(dim=1, index=previous_tokens, src=score)
-
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
-        sorted_indices_to_remove = cum_probs > top_p
-        sorted_indices_to_remove[:, 0] = False  # keep at least one option
-        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
-        logits = logits.masked_fill(indices_to_remove, -float("Inf"))
-
-        logits = logits / max(temperature, 1e-5)
-
-        v, _ = torch.topk(logits, top_k)
-        pivot = v[:, -1].unsqueeze(-1)
-        logits = torch.where(logits < pivot, -float("Inf"), logits)
-
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        return probs
-
-    @staticmethod
-    def apply_repetition_penalty(logits: Tensor, previous_tokens: Tensor, repetition_penalty: float):
-        previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=1, index=previous_tokens)
-        score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
-        logits.scatter_(dim=1, index=previous_tokens, src=score)
-        return logits
-
-    @staticmethod
-    def logits_to_probs_cuda_graph(
-        logits: Tensor,
-        temperature: float,
-        top_k: int,
-        top_p: float,
-    ):
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
-        sorted_indices_to_remove = cum_probs > top_p
-        sorted_indices_to_remove[:, 0] = False  # keep at least one option
-        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
-        logits = logits.masked_fill(indices_to_remove, -float("Inf"))
-
-        logits = logits / max(temperature, 1e-5)
-
-        v, _ = torch.topk(logits, top_k)
-        pivot = v[:, -1].unsqueeze(-1)
-        logits = torch.where(logits < pivot, -float("Inf"), logits)
-
-        probs = torch.nn.functional.softmax(logits, dim=-1)
-        return probs
-
-    def __sample(
-        self,
-        logits: Tensor,
-        previous_tokens: Tensor,
-        temperature: float,
-        top_k: int,
-        top_p: float,
-        repetition_penalty: float,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        probs = self.logits_to_probs(
-            logits=logits,
-            previous_tokens=previous_tokens,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-        )
-        idx_next = self.multinomial_sample_one_no_sync(probs)
-        return idx_next, probs
-
-    def __sample_cuda_graph(
-        self,
-        logits: Tensor,
-        temperature: float,
-        top_k: int,
-        top_p: float,
-    ):
-        probs = self.logits_to_probs_cuda_graph(
-            logits=logits,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-        )
-        idx_next = self.multinomial_sample_one_no_sync(probs)
-        return idx_next
-
-    def capture(self, temperature: float, top_k: int, top_p: float):
-        t1=time.perf_counter()
-        s = torch.cuda.Stream()
-        s.wait_stream(torch.cuda.current_stream())
-
-        logits = self.logits
-
-        with torch.cuda.stream(s):  # type: ignore
-            for _ in range(5):
-                self.__sample_cuda_graph(logits, temperature, top_k, top_p)
-        torch.cuda.current_stream().wait_stream(s)
-
-        self.__CUDAGraph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.__CUDAGraph):
-            self.samples = self.__sample_cuda_graph(logits, temperature, top_k, top_p)
-        torch.cuda.synchronize()
-        print("Sample",time.perf_counter()-t1)
-
     # @torch.jit.script
     def sample(
         self,
@@ -172,7 +40,6 @@ class Sampler(nn.Module):
         top_p: float,
         repetition_penalty: float,
     ) -> Tensor:
-
         previous_tokens = previous_tokens.long()
         score = torch.gather(logits, dim=1, index=previous_tokens)
         score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
@@ -198,7 +65,6 @@ class Sampler(nn.Module):
         return idx_next
 
 
-
 class KVCacheABC(ABC, nn.Module):
     def __init__(self, *args, **kwds) -> None:
         super().__init__()
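Note: this hunk drops the CUDA-graph sampling path (`capture`, `__sample_cuda_graph`, `logits_to_probs_cuda_graph`) and the pre-allocated `logits`/`samples` buffers, leaving only the plain `sample` method, and moves `import time` into the stdlib import block. The removed `multinomial_sample_one_no_sync` helper uses the exponential-race (Gumbel-max style) trick: dividing probabilities by i.i.d. Exponential(1) noise and taking the argmax selects index i with probability proportional to `probs[i]`, without the host-device synchronization that the code's comment attributes to plain multinomial sampling. A minimal, illustrative sketch of that draw (not part of the commit):

import torch

probs = torch.tensor([0.1, 0.2, 0.7])
counts = torch.zeros(3)
for _ in range(10_000):
    q = torch.empty_like(probs).exponential_(1)  # i.i.d. Exp(1) noise
    idx = torch.argmax(probs / q, dim=-1)        # same draw as the removed helper
    counts[idx] += 1
print(counts / counts.sum())  # approaches tensor([0.1, 0.2, 0.7])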
AR/models/t2s_model_flash_attn.py CHANGED
@@ -5,10 +5,10 @@ Modified From https://github.com/XXXXRT666/GPT-SoVITS
 import os
 import time
 import traceback
-from typing import Dict, List, Tuple,Optional
-import gradio as gr
+from typing import Dict, List, Optional, Tuple
 
 import flash_attn  # type: ignore
+import gradio as gr
 import torch
 import torch.nn as nn
 from tqdm import tqdm
@@ -54,7 +54,7 @@ class Attention(AttentionABC):
 
         attn: Tensor = flash_attn.flash_attn_with_kvcache(
             q, kv_cache.k_cache, kv_cache.v_cache, k, v, cache_seqlens=input_pos - 1
-        )  # type: ignore
+        )  # type: ignore
 
         attn = self.dropout.forward(attn)
 
@@ -219,10 +219,10 @@ class CUDAGraphRunner:
 
         self.decoder_path: os.PathLike
         self.decoder_model: T2SDecoderABC = decoder_model.to(self.device, self.dtype)
-
-        self.graph: Optional[torch.cuda.CUDAGraph]= None
-        self.xy_pos_ = torch.rand((1, 1, decoder_model.embedding_dim),device=device).to(dtype)
-        self.xy_dec_ = torch.rand((1, 1, decoder_model.embedding_dim),device=device).to(dtype)
+
+        self.graph: Optional[torch.cuda.CUDAGraph] = None
+        self.xy_pos_ = torch.rand((1, 1, decoder_model.embedding_dim), device=device).to(dtype)
+        self.xy_dec_ = torch.rand((1, 1, decoder_model.embedding_dim), device=device).to(dtype)
         self.kv_cache = decoder_model.init_cache(1)
         self.input_pos = torch.tensor([10]).int().cuda()
 
@@ -230,13 +230,13 @@ class CUDAGraphRunner:
         with self.device:
             for i in self.kv_cache:
                 i.empty()
-
+
         decoder = self.decoder_model
         session = T2SSession(decoder, request, device=self.device, dtype=self.dtype)
         self.input_pos.copy_(session.input_pos)
-
+
         t1 = 0.0
-        y = session.y
+        y = session.y
         bsz = y.size(0)
         torch_profiler = TorchProfiler(request.debug)
         with torch_profiler.profiler():
@@ -271,14 +271,14 @@ class CUDAGraphRunner:
                    *args,
                    **kwds,
                )
-
+
                decoder.post_forward(idx, session)
                logits = decoder.ar_predict_layer(xy_dec[:, -1])
                self.input_pos.add_(1)
 
                if idx == 0:
                    logits[:, -1] = float("-inf")
-
+
                with torch_profiler.record("Sampling"):
                    samples = session.sampler.sample(
                        logits=logits,
@@ -291,22 +291,20 @@ class CUDAGraphRunner:
 
                session.y = torch.cat([session.y, samples], dim=1)
 
-
                with torch_profiler.record("EOS"):
                    argmax_token = torch.argmax(logits, dim=-1)
                    sample_token = samples.squeeze(1)
                    EOS_mask = (argmax_token == decoder.EOS) | (sample_token == decoder.EOS)
-
+
                    newly_done_mask = EOS_mask & (~session.completed)
                    newly_done_indices = newly_done_mask.nonzero()
-
-
+
                    if newly_done_indices.numel() > 0:
                        session.y_results[newly_done_indices[0]] = session.y[
                            newly_done_indices[0], session.y_len : -1
                        ].squeeze(0)
                        session.completed[newly_done_indices] = True
-
+
                    if torch.all(session.completed).item():
                        if session.y.size(1) == 0:
                            session.y = torch.cat([session.y, torch.zeros_like(samples)], dim=1)
@@ -316,13 +314,15 @@ class CUDAGraphRunner:
                            f"T2S Decoding EOS {session.prefill_len.tolist().__str__().strip('[]')} -> \n{[i.size(0) for i in session.y_results].__str__().strip('[]')}"
                        )
                        tqdm.write(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s")
-                        gr.Info(f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s",duration=0.75)
+                        gr.Info(
+                            f"Infer Speed: {(idx - 1) / (time.perf_counter() - t1):.2f} token/s", duration=0.75
+                        )
                        break
-
+
                if (
-                    (request.early_stop_num != -1
-                    and (session.y.size(1) - session.y_len) > request.early_stop_num )or idx ==1499
-                ):
+                    request.early_stop_num != -1
+                    and (session.y.size(1) - session.y_len) > request.early_stop_num
+                ) or idx == 1499:
                    for i in range(bsz):
                        if not session.completed[i].item():
                            session.y_results[i] = session.y[i, session.y_len :]
@@ -339,7 +339,7 @@ class CUDAGraphRunner:
 
                if idx == 51:
                    torch_profiler.end()
-
+
                if idx % 100 == 0:
                    match session.device.type:
                        case "cuda":
@@ -360,7 +360,7 @@ class CUDAGraphRunner:
                            torch.xpu.empty_cache()
                        case "mtia":
                            torch.mtia.empty_cache()
-
+
        torch_profiler.end()
        return session.y_results[: request.valid_length]
 
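Note: the decode step in this file attends with `flash_attn.flash_attn_with_kvcache`, which writes the new key/value into the pre-allocated cache at position `cache_seqlens` in place and attends the query against everything cached so far (the diff passes `cache_seqlens=input_pos - 1`). A minimal, hypothetical sketch of that call; batch/head/length sizes below are made up, and flash-attn expects the `(batch, seqlen, heads, head_dim)` layout on CUDA:

import torch
import flash_attn

bsz, n_heads, head_dim, max_len = 1, 16, 64, 1024
dev, dt = "cuda", torch.float16

q = torch.randn(bsz, 1, n_heads, head_dim, device=dev, dtype=dt)  # current step's query
k = torch.randn(bsz, 1, n_heads, head_dim, device=dev, dtype=dt)  # current step's key
v = torch.randn(bsz, 1, n_heads, head_dim, device=dev, dtype=dt)  # current step's value
k_cache = torch.zeros(bsz, max_len, n_heads, head_dim, device=dev, dtype=dt)
v_cache = torch.zeros_like(k_cache)
cache_seqlens = torch.tensor([10], dtype=torch.int32, device=dev)  # tokens already cached

# k/v are appended into k_cache/v_cache at cache_seqlens, then q attends over the cache.
attn = flash_attn.flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens=cache_seqlens)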
AR/models/utils.py DELETED
@@ -1,229 +0,0 @@
-# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
-# reference: https://github.com/lifeiteng/vall-e
-import torch
-import torch.nn.functional as F
-from typing import Tuple
-
-def sequence_mask(length, max_length=None):
-    if max_length is None:
-        max_length = length.max()
-    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
-    return x.unsqueeze(0) < length.unsqueeze(1)
-
-
-def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
-    """
-    Args:
-      lengths:
-        A 1-D tensor containing sentence lengths.
-      max_len:
-        The length of masks.
-    Returns:
-      Return a 2-D bool tensor, where masked positions
-      are filled with `True` and non-masked positions are
-      filled with `False`.
-
-    #>>> lengths = torch.tensor([1, 3, 2, 5])
-    #>>> make_pad_mask(lengths)
-    tensor([[False, True, True, True, True],
-            [False, False, False, True, True],
-            [False, False, True, True, True],
-            [False, False, False, False, False]])
-    """
-    assert lengths.ndim == 1, lengths.ndim
-    max_len = max(max_len, lengths.max())
-    n = lengths.size(0)
-    seq_range = torch.arange(0, max_len, device=lengths.device)
-    expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len)
-
-    return expaned_lengths >= lengths.unsqueeze(-1)
-
-
-# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
-def top_k_top_p_filtering(
-    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
-):
-    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-    Args:
-        logits: logits distribution shape (batch size, vocabulary size)
-        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
-        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
-            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
-        Make sure we keep at least min_tokens_to_keep per batch example in the output
-    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
-    """
-    if top_k > 0:
-        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
-        # Remove all tokens with a probability less than the last token of the top-k
-        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
-        logits[indices_to_remove] = filter_value
-
-    if top_p < 1.0:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
-
-        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
-        sorted_indices_to_remove = cumulative_probs > top_p
-        if min_tokens_to_keep > 1:
-            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
-            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
-        # Shift the indices to the right to keep also the first token above the threshold
-        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
-        sorted_indices_to_remove[..., 0] = 0
-
-        # scatter sorted tensors to original indexing
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            1, sorted_indices, sorted_indices_to_remove
-        )
-        logits[indices_to_remove] = filter_value
-    return logits
-
-
-def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
-    # temperature: (`optional`) float
-    #     The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
-    # top_k: (`optional`) int
-    #     The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Default to 50.
-    # top_p: (`optional`) float
-    #     The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. Must be between 0 and 1. Default to 1.
-
-    # Temperature (higher temperature => more likely to sample low probability tokens)
-    if temperature != 1.0:
-        logits = logits / temperature
-    # Top-p/top-k filtering
-    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
-    # Sample
-    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
-    return token
-
-
-from typing import Optional, Tuple
-
-
-def multinomial_sample_one_no_sync(
-    probs_sort,
-):  # Does multinomial sampling without a cuda synchronization
-    q = torch.empty_like(probs_sort).exponential_(1)
-    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
-
-
-def logits_to_probs(
-    logits,
-    previous_tokens: Optional[torch.Tensor] = None,
-    temperature: float = 1.0,
-    top_k: Optional[int] = None,
-    top_p: Optional[int] = None,
-    repetition_penalty: float = 1.0,
-):
-    if previous_tokens is not None:
-        previous_tokens = previous_tokens.squeeze()
-    # print(logits.shape,previous_tokens.shape)
-    # pdb.set_trace()
-    if previous_tokens is not None and repetition_penalty != 1.0:
-        previous_tokens = previous_tokens.long()
-        score = torch.gather(logits, dim=0, index=previous_tokens)
-        score = torch.where(
-            score < 0, score * repetition_penalty, score / repetition_penalty
-        )
-        logits.scatter_(dim=0, index=previous_tokens, src=score)
-
-    if top_p is not None and top_p < 1.0:
-        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
-        cum_probs = torch.cumsum(
-            torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
-        )
-        sorted_indices_to_remove = cum_probs > top_p
-        sorted_indices_to_remove[0] = False  # keep at least one option
-        indices_to_remove = sorted_indices_to_remove.scatter(
-            dim=0, index=sorted_indices, src=sorted_indices_to_remove
-        )
-        logits = logits.masked_fill(indices_to_remove, -float("Inf"))
-
-    logits = logits / max(temperature, 1e-5)
-
-    if top_k is not None:
-        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-        pivot = v.select(-1, -1).unsqueeze(-1)
-        logits = torch.where(logits < pivot, -float("Inf"), logits)
-
-    probs = torch.nn.functional.softmax(logits, dim=-1)
-    return probs
-
-
-def sample(
-    logits,
-    previous_tokens: Optional[torch.Tensor] = None,
-    **sampling_kwargs,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    probs = logits_to_probs(
-        logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
-    )
-    idx_next = multinomial_sample_one_no_sync(probs)
-    return idx_next, probs
-
-def dpo_loss(policy_chosen_logps: torch.FloatTensor,
-             policy_rejected_logps: torch.FloatTensor,
-             reference_chosen_logps: torch.FloatTensor,
-             reference_rejected_logps: torch.FloatTensor,
-             beta: float,
-             reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
-    pi_logratios = policy_chosen_logps - policy_rejected_logps
-    ref_logratios = reference_chosen_logps - reference_rejected_logps
-
-    if reference_free:
-        ref_logratios = 0
-
-    logits = pi_logratios - ref_logratios
-
-    losses = -F.logsigmoid(beta * logits)
-    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
-    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()
-
-    return losses.mean(), chosen_rewards, rejected_rewards
-
-def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-
-    # dummy token; we'll ignore the losses on these tokens later
-
-    per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
-    per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
-
-    return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
-
-def make_reject_y(y_o, y_lens):
-    def repeat_P(y):
-        range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
-        pre = y[:range_idx[0]]
-        shf = y[range_idx[1]:]
-        range_text = y[range_idx[0]:range_idx[1]]
-        new_y = torch.cat([pre, range_text, range_text, shf])
-        return new_y
-    def lost_P(y):
-        range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
-        pre = y[:range_idx[0]]
-        shf = y[range_idx[1]:]
-        range_text = y[range_idx[0]:range_idx[1]]
-        new_y = torch.cat([pre, shf])
-        return new_y
-    bs = len(y_lens)
-    reject_y = []
-    reject_y_lens = []
-    for b in range(bs):
-        process_item_idx = torch.randint(0, 1, size=(1, ))[0]
-        if process_item_idx == 0:
-            new_y = repeat_P(y_o[b])
-            reject_y.append(new_y)
-            reject_y_lens.append(len(new_y))
-        elif process_item_idx==1:
-            new_y = lost_P(y_o[b])
-            reject_y.append(new_y)
-            reject_y_lens.append(len(new_y))
-    max_length = max(reject_y_lens)
-    for b in range(bs):
-        pad_length = max_length - reject_y_lens[b]
-        reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
-
-    reject_y = torch.stack(reject_y, dim = 0)
-    reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
-
-    return reject_y, reject_y_lens
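Note: AR/models/utils.py is deleted outright; its top-k/top-p sampling helpers largely duplicate the logic kept in Sampler.sample in t2s_model_abc.py, and the remaining functions (`dpo_loss`, `get_batch_logps`, `make_reject_y`) are training-time utilities. For reference, the deleted `dpo_loss` is the standard DPO objective on summed token log-probs; a toy sketch of its arithmetic (the numbers are invented, not from the repo):

import torch
import torch.nn.functional as F

beta = 0.1
policy_chosen   = torch.tensor([-10.0, -12.0])  # log p_policy(chosen)
policy_rejected = torch.tensor([-11.5, -12.5])  # log p_policy(rejected)
ref_chosen      = torch.tensor([-10.5, -12.2])  # log p_ref(chosen)
ref_rejected    = torch.tensor([-11.0, -12.4])  # log p_ref(rejected)

logits = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
loss = -F.logsigmoid(beta * logits).mean()      # matches dpo_loss(...)[0] above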
inference_webui.py CHANGED
@@ -52,13 +52,12 @@ import os
 import pdb
 import re
 import sys
+import threading
 
 import LangSegment
 import spaces
 import torch
 
-import threading
-
 lock = threading.Lock()
 
 version = "v2"  # os.environ.get("version","v2")
@@ -544,7 +543,7 @@ def get_tts_wav(
         if i_text in cache and if_freeze == True:
             pred_semantic = cache[i_text]
         else:
-            with torch.no_grad(),lock:
+            with torch.no_grad(), lock:
                 t2s_request = T2SRequest(
                     [all_phoneme_ids.squeeze(0)],
                     all_phoneme_len,
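Note: inference_webui.py only moves `import threading` into the standard-library import block (it is needed by the module-level `lock = threading.Lock()`) and reformats the guard around T2S inference. The guarded block stacks two context managers in one `with` statement, so concurrent Gradio requests are serialized and no autograd graph is built during decoding. A minimal sketch of the pattern; `decode` and `synthesize` are stand-ins, not the Space's actual API:

import threading
import torch

lock = threading.Lock()

def decode(x: torch.Tensor) -> torch.Tensor:
    return x * 2  # stand-in for building a T2SRequest and running the T2S model

def synthesize(x: torch.Tensor) -> torch.Tensor:
    # Hold the lock for the whole decode so concurrent calls run one at a time,
    # and skip autograd bookkeeping since this is pure inference.
    with torch.no_grad(), lock:
        return decode(x)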