ollieollie committed
Commit 99cc645 · 1 Parent(s): c30668f

add alignment

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. orator/src/orator/__pycache__/__init__.cpython-311.pyc +0 -0
  2. orator/src/orator/__pycache__/tts.cpython-311.pyc +0 -0
  3. orator/src/orator/models/bigvgan/__pycache__/activations.cpython-311.pyc +0 -0
  4. orator/src/orator/models/bigvgan/__pycache__/bigvgan.cpython-311.pyc +0 -0
  5. orator/src/orator/models/bigvgan/activations.py +120 -0
  6. orator/src/orator/models/bigvgan/alias_free_torch/__init__.py +6 -0
  7. orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/__init__.cpython-311.pyc +0 -0
  8. orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/act.cpython-311.pyc +0 -0
  9. orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/filter.cpython-311.pyc +0 -0
  10. orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/resample.cpython-311.pyc +0 -0
  11. orator/src/orator/models/bigvgan/alias_free_torch/act.py +28 -0
  12. orator/src/orator/models/bigvgan/alias_free_torch/filter.py +95 -0
  13. orator/src/orator/models/bigvgan/alias_free_torch/resample.py +55 -0
  14. orator/src/orator/models/bigvgan/bigvgan.py +212 -0
  15. orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc +0 -0
  16. orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc +0 -0
  17. orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc +0 -0
  18. orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc +0 -0
  19. orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc +0 -0
  20. orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc +0 -0
  21. orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc +0 -0
  22. orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc +0 -0
  23. orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc +0 -0
  24. orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc +0 -0
  25. orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc +0 -0
  26. orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc +0 -0
  27. orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc +0 -0
  28. orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc +0 -0
  29. orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc +0 -0
  30. orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc +0 -0
  31. orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc +0 -0
  32. orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc +0 -0
  33. orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc +0 -0
  34. orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc +0 -0
  35. orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc +0 -0
  36. orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc +0 -0
  37. orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc +0 -0
  38. orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc +0 -0
  39. orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc +0 -0
  40. orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc +0 -0
  41. orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc +0 -0
  42. orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc +0 -0
  43. orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc +0 -0
  44. orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc +0 -0
  45. orator/src/orator/models/t3/inference/alignment_stream_analyzer.py +154 -0
  46. orator/src/orator/models/t3/inference/t3_hf_backend.py +6 -6
  47. orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc +0 -0
  48. orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc +0 -0
  49. orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc +0 -0
  50. orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc +0 -0
orator/src/orator/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/__pycache__/__init__.cpython-311.pyc differ
 
orator/src/orator/__pycache__/tts.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/__pycache__/tts.cpython-311.pyc and b/orator/src/orator/__pycache__/tts.cpython-311.pyc differ
 
orator/src/orator/models/bigvgan/__pycache__/activations.cpython-311.pyc ADDED
Binary file (6.09 kB).
 
orator/src/orator/models/bigvgan/__pycache__/bigvgan.cpython-311.pyc ADDED
Binary file (13.3 kB).
 
orator/src/orator/models/bigvgan/activations.py ADDED
@@ -0,0 +1,120 @@
+# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import torch
+from torch import nn, sin, pow
+from torch.nn import Parameter
+
+
+class Snake(nn.Module):
+    '''
+    Implementation of a sine-based periodic activation function
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter
+    References:
+        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+          https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = Snake(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    '''
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+        '''
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha: trainable parameter
+            alpha is initialized to 1 by default, higher values = higher-frequency.
+            alpha will be trained along with the rest of your model.
+        '''
+        super(Snake, self).__init__()
+        self.in_features = in_features
+
+        # initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # log scale alphas initialized to zeros
+            self.alpha = Parameter(torch.zeros(in_features) * alpha)
+        else:  # linear scale alphas initialized to ones
+            self.alpha = Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        '''
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        Snake := x + 1/a * sin^2(xa)
+        '''
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
+
+
+class SnakeBeta(nn.Module):
+    '''
+    A modified Snake function which uses separate parameters for the magnitude of the periodic components
+    Shape:
+        - Input: (B, C, T)
+        - Output: (B, C, T), same shape as the input
+    Parameters:
+        - alpha - trainable parameter that controls frequency
+        - beta - trainable parameter that controls magnitude
+    References:
+        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
+          https://arxiv.org/abs/2006.08195
+    Examples:
+        >>> a1 = SnakeBeta(256)
+        >>> x = torch.randn(256)
+        >>> x = a1(x)
+    '''
+    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
+        '''
+        Initialization.
+        INPUT:
+            - in_features: shape of the input
+            - alpha - trainable parameter that controls frequency
+            - beta - trainable parameter that controls magnitude
+            alpha is initialized to 1 by default, higher values = higher-frequency.
+            beta is initialized to 1 by default, higher values = higher-magnitude.
+            alpha will be trained along with the rest of your model.
+        '''
+        super(SnakeBeta, self).__init__()
+        self.in_features = in_features
+
+        # initialize alpha
+        self.alpha_logscale = alpha_logscale
+        if self.alpha_logscale:  # log scale alphas initialized to zeros
+            self.alpha = Parameter(torch.zeros(in_features) * alpha)
+            self.beta = Parameter(torch.zeros(in_features) * alpha)
+        else:  # linear scale alphas initialized to ones
+            self.alpha = Parameter(torch.ones(in_features) * alpha)
+            self.beta = Parameter(torch.ones(in_features) * alpha)
+
+        self.alpha.requires_grad = alpha_trainable
+        self.beta.requires_grad = alpha_trainable
+
+        self.no_div_by_zero = 0.000000001
+
+    def forward(self, x):
+        '''
+        Forward pass of the function.
+        Applies the function to the input elementwise.
+        SnakeBeta := x + 1/b * sin^2(xa)
+        '''
+        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
+        beta = self.beta.unsqueeze(0).unsqueeze(-1)
+        if self.alpha_logscale:
+            alpha = torch.exp(alpha)
+            beta = torch.exp(beta)
+        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
+
+        return x
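A quick sanity check of the formula above (a minimal sketch, not part of the diff; it assumes SnakeBeta from this file is in scope). With alpha_logscale=True, alpha and beta start at exp(0) = 1, so the activation reduces to x + sin^2(x):

import torch

act = SnakeBeta(in_features=4, alpha_logscale=True)  # alpha = beta = exp(0) = 1
x = torch.randn(2, 4, 8)                             # [B, C, T]
y = act(x)
assert y.shape == x.shape                            # elementwise, shape-preserving
assert torch.allclose(y, x + torch.sin(x) ** 2, atol=1e-6)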
orator/src/orator/models/bigvgan/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,6 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+from .filter import *
+from .resample import *
+from .act import *
orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (281 Bytes).
 
orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/act.cpython-311.pyc ADDED
Binary file (1.67 kB).
 
orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/filter.cpython-311.pyc ADDED
Binary file (4.51 kB).
 
orator/src/orator/models/bigvgan/alias_free_torch/__pycache__/resample.cpython-311.pyc ADDED
Binary file (3.43 kB).
 
orator/src/orator/models/bigvgan/alias_free_torch/act.py ADDED
@@ -0,0 +1,28 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+
+from .resample import UpSample1d, DownSample1d
+
+
+class Activation1d(nn.Module):
+    def __init__(self,
+                 activation,
+                 up_ratio: int = 2,
+                 down_ratio: int = 2,
+                 up_kernel_size: int = 12,
+                 down_kernel_size: int = 12):
+        super().__init__()
+        self.up_ratio = up_ratio
+        self.down_ratio = down_ratio
+        self.act = activation
+        self.upsample = UpSample1d(up_ratio, up_kernel_size)
+        self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+    # x: [B, C, T]
+    def forward(self, x):
+        x = self.upsample(x)
+        x = self.act(x)
+        x = self.downsample(x)
+        return x
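The wrapper exists because a pointwise nonlinearity generates harmonics above the Nyquist frequency of the incoming signal; running it at twice the rate and low-pass filtering back down suppresses that aliasing. A minimal sketch (assuming Activation1d and SnakeBeta from this commit are in scope); with the default even kernel sizes the round trip preserves the sequence length:

import torch

act = Activation1d(activation=SnakeBeta(64, alpha_logscale=True))
x = torch.randn(1, 64, 100)  # [B, C, T]
y = act(x)
assert y.shape == x.shape    # upsample x2, activate, downsample x2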
orator/src/orator/models/bigvgan/alias_free_torch/filter.py ADDED
@@ -0,0 +1,95 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+if 'sinc' in dir(torch):
+    sinc = torch.sinc
+else:
+    # This code is adopted from adefossez's julius.core.sinc under the MIT License
+    # https://adefossez.github.io/julius/julius/core.html
+    # LICENSE is in incl_licenses directory.
+    def sinc(x: torch.Tensor):
+        """
+        Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+        __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+        """
+        return torch.where(x == 0,
+                           torch.tensor(1., device=x.device, dtype=x.dtype),
+                           torch.sin(math.pi * x) / math.pi / x)
+
+
+# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+# https://adefossez.github.io/julius/julius/lowpass.html
+# LICENSE is in incl_licenses directory.
+def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1, 1, kernel_size]
+    even = (kernel_size % 2 == 0)
+    half_size = kernel_size // 2
+
+    # For kaiser window
+    delta_f = 4 * half_width
+    A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+    if A > 50.:
+        beta = 0.1102 * (A - 8.7)
+    elif A >= 21.:
+        beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
+    else:
+        beta = 0.
+    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+    # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+    if even:
+        time = (torch.arange(-half_size, half_size) + 0.5)
+    else:
+        time = torch.arange(kernel_size) - half_size
+    if cutoff == 0:
+        filter_ = torch.zeros_like(time)
+    else:
+        filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+        # Normalize filter to have sum = 1, otherwise we will have a small leakage
+        # of the constant component in the input signal.
+        filter_ /= filter_.sum()
+    filter = filter_.view(1, 1, kernel_size)
+
+    return filter
+
+
+class LowPassFilter1d(nn.Module):
+    def __init__(self,
+                 cutoff=0.5,
+                 half_width=0.6,
+                 stride: int = 1,
+                 padding: bool = True,
+                 padding_mode: str = 'replicate',
+                 kernel_size: int = 12):
+        # kernel_size should be even number for stylegan3 setup,
+        # in this implementation, odd number is also possible.
+        super().__init__()
+        if cutoff < -0.:
+            raise ValueError("Minimum cutoff must be larger than zero.")
+        if cutoff > 0.5:
+            raise ValueError("A cutoff above 0.5 does not make sense.")
+        self.kernel_size = kernel_size
+        self.even = (kernel_size % 2 == 0)
+        self.pad_left = kernel_size // 2 - int(self.even)
+        self.pad_right = kernel_size // 2
+        self.stride = stride
+        self.padding = padding
+        self.padding_mode = padding_mode
+        filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+        self.register_buffer("filter", filter)
+
+    # input [B, C, T]
+    def forward(self, x):
+        _, C, _ = x.shape
+
+        if self.padding:
+            x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
+        out = F.conv1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+
+        return out
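Concretely, kaiser_sinc_filter1d designs a windowed-sinc low-pass whose taps sum to 1, so a constant (DC) signal passes through unchanged; that is what the normalization comment is about. A small check (a sketch assuming the functions above are in scope; the cutoff and half_width values are arbitrary):

import torch

taps = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
assert taps.shape == (1, 1, 12)
assert torch.isclose(taps.sum(), torch.tensor(1.0), atol=1e-6)  # unit DC gain

lp = LowPassFilter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
y = lp(torch.ones(1, 1, 32))                              # constant input
assert torch.allclose(y, torch.ones_like(y), atol=1e-4)   # passes through ~unchanged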
orator/src/orator/models/bigvgan/alias_free_torch/resample.py ADDED
@@ -0,0 +1,55 @@
+# Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+# LICENSE is in incl_licenses directory.
+
+import torch.nn as nn
+from torch.nn import functional as F
+
+from .filter import LowPassFilter1d
+from .filter import kaiser_sinc_filter1d
+
+
+class UpSample1d(nn.Module):
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        self.stride = ratio
+        self.pad = self.kernel_size // ratio - 1
+        self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+        self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+        filter = kaiser_sinc_filter1d(
+            cutoff=0.5 / ratio,
+            half_width=0.6 / ratio,
+            kernel_size=self.kernel_size
+        )
+        self.register_buffer("filter", filter)
+
+    # x: [B, C, T]
+    def forward(self, x):
+        _, C, _ = x.shape
+
+        x = F.pad(x, (self.pad, self.pad), mode='replicate')
+        x = self.ratio * F.conv_transpose1d(
+            x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
+        )
+        x = x[..., self.pad_left:-self.pad_right]
+
+        return x
+
+
+class DownSample1d(nn.Module):
+    def __init__(self, ratio=2, kernel_size=None):
+        super().__init__()
+        self.ratio = ratio
+        self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+        self.lowpass = LowPassFilter1d(
+            cutoff=0.5 / ratio,
+            half_width=0.6 / ratio,
+            stride=ratio,
+            kernel_size=self.kernel_size
+        )
+
+    def forward(self, x):
+        xx = self.lowpass(x)
+
+        return xx
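Shape-wise, UpSample1d maps T to ratio * T and DownSample1d maps it back, which is what lets Activation1d wrap an activation without changing tensor shapes. A brief sketch under that assumption:

import torch

up = UpSample1d(ratio=2)     # default kernel_size = int(6 * 2 // 2) * 2 = 12
down = DownSample1d(ratio=2)

x = torch.randn(3, 8, 50)    # [B, C, T]
hi = up(x)
assert hi.shape == (3, 8, 100)    # T -> ratio * T
assert down(hi).shape == x.shape  # back to the original length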
orator/src/orator/models/bigvgan/bigvgan.py ADDED
@@ -0,0 +1,212 @@
+# Copyright (c) 2022 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+import logging
+
+import torch
+import torch.nn as nn
+from torch.nn import Conv1d, ConvTranspose1d
+from torch.nn.utils import weight_norm, remove_weight_norm
+from torch.nn.utils.weight_norm import WeightNorm
+
+from .activations import SnakeBeta
+from .alias_free_torch import *
+
+
+LRELU_SLOPE = 0.1
+
+logger = logging.getLogger(__name__)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+
+class AMPBlock1(torch.nn.Module):
+    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+        super(AMPBlock1, self).__init__()
+
+        self.convs1 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                               padding=get_padding(kernel_size, dilation[0]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                               padding=get_padding(kernel_size, dilation[1]))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                               padding=get_padding(kernel_size, dilation[2])))
+        ])
+        self.convs1.apply(init_weights)
+
+        self.convs2 = nn.ModuleList([
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1))),
+            weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, padding=get_padding(kernel_size, 1)))
+        ])
+        self.convs2.apply(init_weights)
+
+        self.num_layers = len(self.convs1) + len(self.convs2)  # total number of conv layers
+
+        self.activations = nn.ModuleList([
+            Activation1d(activation=SnakeBeta(channels, alpha_logscale=True))
+            for _ in range(self.num_layers)
+        ])
+
+    def forward(self, x):
+        acts1, acts2 = self.activations[::2], self.activations[1::2]
+        for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2):
+            xt = a1(x)
+            xt = c1(xt)
+            xt = a2(xt)
+            xt = c2(xt)
+            x = xt + x
+
+        return x
+
+    def set_weight_norm(self, enabled: bool):
+        weight_norm_fn = weight_norm if enabled else remove_weight_norm
+        for l in self.convs1:
+            weight_norm_fn(l)
+        for l in self.convs2:
+            weight_norm_fn(l)
+
+
+class BigVGAN(nn.Module):
+    # this is our main BigVGAN model. Applies anti-aliased periodic activation for resblocks.
+
+    # We've got a model in prod that has the wrong hparams for this. It's simpler to add this check than to
+    # redistribute the model.
+    ignore_state_dict_unexpected = ("cond_layer.*",)
+
+    def __init__(self):
+        super().__init__()
+
+        input_dims = 80
+
+        upsample_rates = [10, 8, 4, 2]
+        upsample_kernel_sizes = [x * 2 for x in upsample_rates]
+        upsample_initial_channel = 1024
+
+        resblock_kernel_sizes = [3, 7, 11]
+        resblock_dilation_sizes = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+
+        # pre conv
+        self.conv_pre = weight_norm(Conv1d(input_dims, upsample_initial_channel, 7, 1, padding=3))
+        self.cond_layer = None
+
+        # transposed conv-based upsamplers. does not apply anti-aliasing
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            self.ups.append(nn.ModuleList([
+                weight_norm(ConvTranspose1d(upsample_initial_channel // (2 ** i),
+                                            upsample_initial_channel // (2 ** (i + 1)),
+                                            k, u, padding=(k - u) // 2))
+            ]))
+
+        # residual blocks using anti-aliased multi-periodicity composition modules (AMP)
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+                self.resblocks.append(AMPBlock1(ch, k, d))
+
+        # post conv
+        activation_post = SnakeBeta(ch, alpha_logscale=True)
+        self.activation_post = Activation1d(activation=activation_post)
+        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+
+        # weight initialization
+        for i in range(len(self.ups)):
+            self.ups[i].apply(init_weights)
+        self.conv_post.apply(init_weights)
+
+    def forward(self, x) -> torch.Tensor:
+        """
+        Args
+        ----
+        x: torch.Tensor of shape [B, C, T] (C = 80 mel bands)
+        """
+        with torch.inference_mode():
+            x = self.conv_pre(x)
+
+            for i in range(self.num_upsamples):
+                # upsampling
+                for i_up in range(len(self.ups[i])):
+                    x = self.ups[i][i_up](x)
+                # AMP blocks
+                xs = None
+                for j in range(self.num_kernels):
+                    if xs is None:
+                        xs = self.resblocks[i * self.num_kernels + j](x)
+                    else:
+                        xs += self.resblocks[i * self.num_kernels + j](x)
+                x = xs / self.num_kernels
+
+            # post conv
+            x = self.activation_post(x)
+            x = self.conv_post(x)
+
+            # Bound the output to [-1, 1]
+            x = torch.tanh(x)
+
+            return x
+
+    @property
+    def weight_norm_enabled(self) -> bool:
+        return any(
+            isinstance(hook, WeightNorm) and hook.name == "weight"
+            for k, hook in self.conv_pre._forward_pre_hooks.items()
+        )
+
+    def set_weight_norm(self, enabled: bool):
+        """
+        N.B.: weight norm modifies the state dict, causing incompatibilities. Conventions:
+        - BigVGAN runs with weight norm for training, without for inference (done automatically by instantiate())
+        - All checkpoints are saved with weight norm (allows resuming training)
+        """
+        if enabled != self.weight_norm_enabled:
+            weight_norm_fn = weight_norm if enabled else remove_weight_norm
+            logger.debug(f"{'Applying' if enabled else 'Removing'} weight norm...")
+
+            for l in self.ups:
+                for l_i in l:
+                    weight_norm_fn(l_i)
+            for l in self.resblocks:
+                l.set_weight_norm(enabled)
+            weight_norm_fn(self.conv_pre)
+            weight_norm_fn(self.conv_post)
+
+    def train_mode(self):
+        self.train()
+        self.set_weight_norm(enabled=True)
+
+    def inference_mode(self):
+        self.eval()
+        self.set_weight_norm(enabled=False)
+
+
+if __name__ == '__main__':
+    import soundfile as sf
+
+    model = BigVGAN()
+
+    state_dict = torch.load("bigvgan32k.pt")
+    msg = model.load_state_dict(state_dict)
+    model.eval()
+    model.set_weight_norm(enabled=False)
+
+    print(msg)
+    mels = torch.load("mels.pt")
+    with torch.inference_mode():
+        y = model(mels.cpu())
+
+    for i, wav in enumerate(y):
+        wav = wav.view(-1).detach().numpy()
+        sf.write(f"bigvgan_test{i}.flac", wav, samplerate=32_000, format="FLAC")
orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/__init__.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/const.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/decoder.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/f0_predictor.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/flow.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/flow_matching.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/hifigan.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/s3gen.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc and b/orator/src/orator/models/s3gen/__pycache__/xvector.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc and b/orator/src/orator/models/s3gen/matcha/__pycache__/decoder.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc and b/orator/src/orator/models/s3gen/matcha/__pycache__/flow_matching.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc and b/orator/src/orator/models/s3gen/matcha/__pycache__/transformer.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/__init__.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/activation.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/attention.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/convolution.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/embedding.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/encoder_layer.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/positionwise_feed_forward.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/subsampling.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc and b/orator/src/orator/models/s3gen/transformer/__pycache__/upsample_encoder.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc and b/orator/src/orator/models/s3gen/utils/__pycache__/class_utils.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc and b/orator/src/orator/models/s3gen/utils/__pycache__/mask.cpython-311.pyc differ
 
orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc and b/orator/src/orator/models/s3gen/utils/__pycache__/mel.cpython-311.pyc differ
 
orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/models/s3tokenizer/__pycache__/__init__.cpython-311.pyc differ
 
orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc and b/orator/src/orator/models/s3tokenizer/__pycache__/s3tokenizer.cpython-311.pyc differ
 
orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc and b/orator/src/orator/models/t3/__pycache__/__init__.cpython-311.pyc differ
 
orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc and b/orator/src/orator/models/t3/__pycache__/llama_configs.cpython-311.pyc differ
 
orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc and b/orator/src/orator/models/t3/__pycache__/t3.cpython-311.pyc differ
 
orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc and b/orator/src/orator/models/t3/inference/__pycache__/t3_hf_backend.cpython-311.pyc differ
 
orator/src/orator/models/t3/inference/alignment_stream_analyzer.py ADDED
@@ -0,0 +1,154 @@
+# Copyright (c) 2025 Resemble AI
+# Author: John Meade, Jeremy Hsu
+# MIT License
+import logging
+import torch
+from dataclasses import dataclass
+from types import MethodType
+
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class AlignmentAnalysisResult:
+    # was this frame detected as being part of a noisy beginning chunk with potential hallucinations?
+    false_start: bool
+    # was this frame detected as being part of a long tail with potential hallucinations?
+    long_tail: bool
+    # was this frame detected as repeating existing text content?
+    repetition: bool
+    # was the alignment position of this frame too far from the previous frame?
+    discontinuity: bool
+    # has inference reached the end of the text tokens? eg, this remains false if inference stops early
+    complete: bool
+    # approximate position in the text token sequence. Can be used for generating online timestamps.
+    position: int
+
+
+class AlignmentStreamAnalyzer:
+    def __init__(self, tfmr, queue, text_tokens_slice, alignment_layer_idx=9, eos_idx=0):
+        """
+        Some transformer TTS models implicitly solve text-speech alignment in one or more of their self-attention
+        activation maps. This module exploits this to perform online integrity checks while streaming.
+        A hook is injected into the specified attention layer, and heuristics are used to determine alignment
+        position, repetition, etc.
+
+        NOTE: currently requires no queues.
+        """
+        # self.queue = queue
+        self.text_tokens_slice = (i, j) = text_tokens_slice
+        self.eos_idx = eos_idx
+        self.alignment = torch.zeros(0, j - i)
+        # self.alignment_bin = torch.zeros(0, j - i)
+        self.curr_frame_pos = 0
+        self.text_position = 0
+
+        self.started = False
+        self.started_at = None
+
+        self.complete = False
+        self.completed_at = None
+
+        # Using `output_attentions=True` is incompatible with optimized attention kernels, so
+        # using it for all layers slows things down too much. We can apply it to just one layer
+        # by intercepting the kwargs and adding a forward hook (credit: jrm)
+        self.last_aligned_attn = None
+        self._add_attention_spy(tfmr, alignment_layer_idx)
+
+    def _add_attention_spy(self, tfmr, alignment_layer_idx):
+        """
+        Adds a forward hook to a specific attention layer to collect outputs.
+        Using `output_attentions=True` is incompatible with optimized attention kernels, so
+        using it for all layers slows things down too much.
+        (credit: jrm)
+        """
+
+        def attention_forward_hook(module, input, output):
+            """
+            See `LlamaAttention.forward`; the output is a 3-tuple: `attn_output, attn_weights, past_key_value`.
+            NOTE:
+            - When `output_attentions=True`, `LlamaSdpaAttention.forward` calls `LlamaAttention.forward`.
+            - `attn_weights` has shape [B, H, T0, T0] for the 0th step, and [B, H, 1, T0+i] for each subsequent step i.
+            """
+            step_attention = output[1].cpu()  # (B, 16, N, N)
+            self.last_aligned_attn = step_attention[0].mean(0)  # (N, N)
+
+        target_layer = tfmr.layers[alignment_layer_idx].self_attn
+        hook_handle = target_layer.register_forward_hook(attention_forward_hook)
+
+        # Backup original forward
+        original_forward = target_layer.forward
+        def patched_forward(self, *args, **kwargs):
+            kwargs['output_attentions'] = True
+            return original_forward(*args, **kwargs)
+
+        # TODO: how to unpatch it?
+        target_layer.forward = MethodType(patched_forward, target_layer)
+
+    def step(self, logits):
+        """
+        Emits an AlignmentAnalysisResult into the output queue, and potentially modifies the logits to force an EOS.
+        """
+        # extract approximate alignment matrix chunk (1 frame at a time after the first chunk)
+        aligned_attn = self.last_aligned_attn  # (N, N)
+        i, j = self.text_tokens_slice
+        if self.curr_frame_pos == 0:
+            # first chunk has conditioning info, text tokens, and BOS token
+            A_chunk = aligned_attn[j:, i:j].clone().cpu()  # (T, S)
+        else:
+            # subsequent chunks have 1 frame due to KV-caching
+            A_chunk = aligned_attn[:, i:j].clone().cpu()  # (1, S)
+
+        # TODO: monotonic masking; could have issue b/c spaces are often skipped.
+        A_chunk[:, self.curr_frame_pos + 1:] = 0
+
+        self.alignment = torch.cat((self.alignment, A_chunk), dim=0)
+
+        A = self.alignment
+        T, S = A.shape
+
+        # update position
+        cur_text_posn = A_chunk[-1].argmax()
+        discontinuity = not (-4 < cur_text_posn - self.text_position < 7)  # NOTE: very lenient!
+        if not discontinuity:
+            self.text_position = cur_text_posn
+
+        # Hallucinations at the start of speech show up as activations at the bottom of the attention maps!
+        # To mitigate this, we just wait until there are no activations far off-diagonal in the last 2 tokens,
+        # and there are some strong activations in the first few tokens.
+        false_start = (not self.started) and (A[-2:, -2:].max() > 0.1 or A[:, :4].max() < 0.5)
+        self.started = not false_start
+        if self.started and self.started_at is None:
+            self.started_at = T
+
+        # Is generation likely complete?
+        self.complete = self.complete or self.text_position >= S - 3
+        if self.complete and self.completed_at is None:
+            self.completed_at = T
+
+        # NOTE: EOS rarely assigned activations, and second-last token is often punctuation, so use last 3 tokens.
+        # NOTE: due to the false-start behaviour, we need to make sure we skip activations for the first few tokens.
+        last_text_token_duration = A[15:, -3:].sum()
+
+        # Activations for the final token that last too long are likely hallucinations.
+        long_tail = self.complete and (A[self.completed_at:, -3:].sum(dim=0).max() >= 10)  # 400ms
+
+        # If there are activations in previous tokens after generation has completed, assume this is a repetition error.
+        repetition = self.complete and (A[self.completed_at:, :-5].max(dim=1).values.sum() > 5)
+
+        # If a bad ending is detected, force emit EOS by modifying logits
+        # NOTE: this means logits may be inconsistent with latents!
+        if long_tail or repetition:
+            logger.warning(f"forcing EOS token, {long_tail=}, {repetition=}")
+            # (±2**15 is safe for all dtypes >= 16bit)
+            logits = -(2**15) * torch.ones_like(logits)
+            logits[..., self.eos_idx] = 2**15
+
+        # Suppress EoS to prevent early termination
+        if cur_text_posn < S - 3:  # FIXME: arbitrary
+            logits[..., self.eos_idx] = -2**15
+
+        self.curr_frame_pos += 1
+        return logits
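The spy pattern in _add_attention_spy can be tried in isolation on a tiny random Llama. A sketch under stated assumptions: the sizes are arbitrary, and attention weights are only materialized when the eager attention path runs, which is exactly the version-dependent behavior the NOTE in the hook docstring refers to; forcing attn_implementation to "eager" makes the demo robust:

import torch
from types import MethodType
from transformers import LlamaConfig, LlamaModel

cfg = LlamaConfig(hidden_size=64, num_hidden_layers=2, num_attention_heads=4,
                  intermediate_size=128, vocab_size=100)
cfg._attn_implementation = "eager"   # ensure attn_weights are computed
tfmr = LlamaModel(cfg)

captured = {}
def spy(module, args, output):
    # as in attention_forward_hook: output[1] holds the attention weights
    captured["attn"] = output[1].detach().cpu()

target = tfmr.layers[1].self_attn    # spy on a single layer only
target.register_forward_hook(spy)

original_forward = target.forward
def patched_forward(self, *args, **kwargs):
    kwargs["output_attentions"] = True   # pay the slow path for this layer only
    return original_forward(*args, **kwargs)
target.forward = MethodType(patched_forward, target)

tfmr(input_ids=torch.randint(0, 100, (1, 8)))
print(captured["attn"].shape)        # torch.Size([1, 4, 8, 8]): (B, heads, T, T)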
orator/src/orator/models/t3/inference/t3_hf_backend.py CHANGED
@@ -23,14 +23,14 @@ class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
         speech_head,
         latents_queue=None,
         logits_queue=None,
+        alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None,
     ):
         super().__init__(config)
         self.model = llama
         self.speech_enc = speech_enc
         self.speech_head = speech_head
-        self.latents_queue = latents_queue
-        self.logits_queue = logits_queue
         self._added_cond = False
+        self.alignment_stream_analyzer = alignment_stream_analyzer
 
     @torch.inference_mode()
     def prepare_inputs_for_generation(
@@ -101,12 +101,12 @@ class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
             return_dict=True,
         )
         hidden_states = tfmr_out.hidden_states[-1]  # (B, seq, dim)
-        if self.latents_queue is not None:
-            self.latents_queue.put(hidden_states)
 
         logits = self.speech_head(hidden_states)
-        if self.logits_queue is not None:
-            self.logits_queue.put(logits)
+        assert inputs_embeds.size(0) == 1
+
+        # NOTE: hallucination handler may modify logits to force emit an EOS token
+        logits = self.alignment_stream_analyzer.step(logits)
 
         return CausalLMOutputWithCrossAttentions(
             logits=logits,
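The NOTE above is load-bearing: when the analyzer detects a long tail or repetition, it rewrites the logits outright, so decoding cannot avoid terminating. The forcing branch in isolation (a sketch; eos_idx=0 matches the analyzer's default):

import torch

eos_idx = 0
logits = torch.randn(1, 1, 512)              # (B, 1, vocab), one step's logits

# mirror of the forcing branch in AlignmentStreamAnalyzer.step()
logits = -(2**15) * torch.ones_like(logits)  # floor every token (fp16-safe magnitude)
logits[..., eos_idx] = 2**15                 # raise only EOS
assert logits.argmax(dim=-1).item() == eos_idx  # greedy or sampled decoding emits EOS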
orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc and b/orator/src/orator/models/t3/modules/__pycache__/cond_enc.cpython-311.pyc differ
 
orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc and b/orator/src/orator/models/t3/modules/__pycache__/learned_pos_emb.cpython-311.pyc differ
 
orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc and b/orator/src/orator/models/t3/modules/__pycache__/perceiver.cpython-311.pyc differ
 
orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc CHANGED
Binary files a/orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc and b/orator/src/orator/models/t3/modules/__pycache__/t3_config.cpython-311.pyc differ