Update model.py
Browse files
model.py
CHANGED
@@ -133,29 +133,35 @@ class rotary(nn.Module):
|
|
133 |
def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, variable_radius=False,
|
134 |
learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = []):
|
135 |
super().__init__()
|
136 |
-
|
137 |
-
|
138 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
139 |
-
dtype = torch.float32
|
140 |
self.device = device
|
|
|
141 |
self.dtype = dtype
|
142 |
self.debug = debug
|
143 |
self._counter = 0
|
144 |
-
|
145 |
-
self.use_pbias = False
|
146 |
self.max_ctx = max_ctx
|
147 |
self.variable_radius = variable_radius
|
148 |
|
149 |
-
self.inv_freq = nn.Parameter(
|
|
|
150 |
requires_grad=learned_freq)
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
self.
|
|
|
|
|
|
|
|
|
156 |
|
157 |
if variable_radius:
|
158 |
-
self.radius = nn.Parameter(
|
|
|
|
|
159 |
|
160 |
def get_pitch_bias(self, f0):
|
161 |
if f0 is None:
|
@@ -194,27 +200,38 @@ class rotary(nn.Module):
|
|
194 |
t = torch.arange(x, device=self.device).float()
|
195 |
else:
|
196 |
t = x.float().to(self.inv_freq.device)
|
|
|
197 |
if f0 is not None:
|
198 |
f0_mean = f0.mean()
|
199 |
-
|
200 |
-
|
201 |
-
|
|
|
202 |
else:
|
203 |
inv_freq = self.inv_freq
|
204 |
freqs = torch.einsum('i,j->ij', t, inv_freq)
|
|
|
205 |
freqs = freqs.float()
|
206 |
if self.variable_radius:
|
207 |
-
if f0 is not None:
|
208 |
-
f0 = f0[0]
|
209 |
-
seq_len = x
|
210 |
-
f0 = torch.tensor(f0, device=device if isinstance(x, torch.Tensor) else device)
|
211 |
-
f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
|
212 |
-
radius = 1.0 / (f0 + 1)
|
213 |
-
freqs = torch.polar(radius, freqs)
|
214 |
-
else:
|
215 |
-
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
216 |
-
freqs = freqs.unsqueeze(0)
|
217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
if "rotary" in self.debug:
|
219 |
if f0 is not None:
|
220 |
key = f"{self._counter}_{f0_theta:.2f}"
|
@@ -222,12 +239,13 @@ class rotary(nn.Module):
|
|
222 |
if not hasattr(self, '_prev_f0_theta'):
|
223 |
self._prev_f0_theta = f0_theta
|
224 |
print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
|
225 |
-
elif abs(self._prev_f0_theta - f0_theta) >
|
226 |
print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
|
227 |
self._prev_f0_theta = f0_theta
|
228 |
rotary._seen.add(key)
|
229 |
self._counter += 1
|
230 |
-
|
|
|
231 |
|
232 |
@staticmethod
|
233 |
def apply_rotary(x, freqs):
|
@@ -240,11 +258,13 @@ class rotary(nn.Module):
|
|
240 |
x1 = x1 * freqs
|
241 |
x1 = torch.view_as_real(x1).flatten(-2)
|
242 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
|
|
243 |
else:
|
244 |
x1 = x[..., :freqs.shape[-1]*2]
|
245 |
x2 = x[..., freqs.shape[-1]*2:]
|
246 |
|
247 |
if x.ndim == 2:
|
|
|
248 |
x1 = x1.unsqueeze(0)
|
249 |
x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
|
250 |
x1 = torch.view_as_complex(x1)
|
|
|
133 |
def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, variable_radius=False,
|
134 |
learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = []):
|
135 |
super().__init__()
|
136 |
+
self.use_pbias = False
|
137 |
+
|
138 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
139 |
self.device = device
|
140 |
+
dtype = torch.float32
|
141 |
self.dtype = dtype
|
142 |
self.debug = debug
|
143 |
self._counter = 0
|
144 |
+
self.dims = dims
|
|
|
145 |
self.max_ctx = max_ctx
|
146 |
self.variable_radius = variable_radius
|
147 |
|
148 |
+
self.inv_freq = nn.Parameter(
|
149 |
+
1.0 / (10000 ** (torch.arange(0, dims, 2, device=device, dtype=dtype) / dims)),
|
150 |
requires_grad=learned_freq)
|
151 |
+
self.theta = nn.Parameter(
|
152 |
+
torch.tensor(float(theta)), requires_grad=learned_theta)
|
153 |
+
self.min_theta = nn.Parameter(
|
154 |
+
torch.tensor(600.0), requires_grad=learned_theta)
|
155 |
+
self.max_theta = nn.Parameter(
|
156 |
+
torch.tensor(2400.0), requires_grad=learned_theta)
|
157 |
+
|
158 |
+
self.pitch_scale = nn.Parameter(torch.tensor(1.0),
|
159 |
+
requires_grad=learned_pitch)
|
160 |
|
161 |
if variable_radius:
|
162 |
+
self.radius = nn.Parameter(
|
163 |
+
torch.ones(dims // 2),
|
164 |
+
requires_grad=learned_radius)
|
165 |
|
166 |
def get_pitch_bias(self, f0):
|
167 |
if f0 is None:
|
|
|
200 |
t = torch.arange(x, device=self.device).float()
|
201 |
else:
|
202 |
t = x.float().to(self.inv_freq.device)
|
203 |
+
|
204 |
if f0 is not None:
|
205 |
f0_mean = f0.mean()
|
206 |
+
f0_mean = torch.clamp(f0_mean, min=80.0, max=600.0)
|
207 |
+
perceptual_factor = torch.log(1 + f0_mean / 700.0) / torch.log(torch.tensor(1 + 300.0 / 700.0))
|
208 |
+
f0_theta = self.min_theta + perceptual_factor * (self.max_theta - self.min_theta)
|
209 |
+
inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
|
210 |
else:
|
211 |
inv_freq = self.inv_freq
|
212 |
freqs = torch.einsum('i,j->ij', t, inv_freq)
|
213 |
+
|
214 |
freqs = freqs.float()
|
215 |
if self.variable_radius:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
|
217 |
+
# if f0 is not None:
|
218 |
+
# f0 = f0[0]
|
219 |
+
# seq_len = x
|
220 |
+
# f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
|
221 |
+
# radius = f0
|
222 |
+
|
223 |
+
# freqs = torch.polar(radius, freqs)
|
224 |
+
# else:
|
225 |
+
|
226 |
+
# freqs = torch.polar(torch.ones_like(freqs), freqs)
|
227 |
+
# freqs = freqs.unsqueeze(0)
|
228 |
+
|
229 |
+
radius = F.softplus(self.radius)
|
230 |
+
freqs = torch.polar(radius.unsqueeze(0).expand_as(freqs), freqs)
|
231 |
+
else:
|
232 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
233 |
+
freqs = freqs.unsqueeze(0)
|
234 |
+
|
235 |
if "rotary" in self.debug:
|
236 |
if f0 is not None:
|
237 |
key = f"{self._counter}_{f0_theta:.2f}"
|
|
|
239 |
if not hasattr(self, '_prev_f0_theta'):
|
240 |
self._prev_f0_theta = f0_theta
|
241 |
print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
|
242 |
+
elif abs(self._prev_f0_theta - f0_theta) > 200.0:
|
243 |
print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
|
244 |
self._prev_f0_theta = f0_theta
|
245 |
rotary._seen.add(key)
|
246 |
self._counter += 1
|
247 |
+
|
248 |
+
return freqs
|
249 |
|
250 |
@staticmethod
|
251 |
def apply_rotary(x, freqs):
|
|
|
258 |
x1 = x1 * freqs
|
259 |
x1 = torch.view_as_real(x1).flatten(-2)
|
260 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
261 |
+
|
262 |
else:
|
263 |
x1 = x[..., :freqs.shape[-1]*2]
|
264 |
x2 = x[..., freqs.shape[-1]*2:]
|
265 |
|
266 |
if x.ndim == 2:
|
267 |
+
|
268 |
x1 = x1.unsqueeze(0)
|
269 |
x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
|
270 |
x1 = torch.view_as_complex(x1)
|