Update model.py
Browse files
model.py
CHANGED
@@ -130,38 +130,32 @@ def sinusoids(length, channels, max_timescale=10000):
|
|
130 |
|
131 |
class rotary(nn.Module):
|
132 |
_seen = set()
|
133 |
-
def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=
|
134 |
learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = []):
|
135 |
super().__init__()
|
136 |
-
|
137 |
-
|
138 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
139 |
self.device = device
|
140 |
-
dtype = torch.float32
|
141 |
self.dtype = dtype
|
142 |
self.debug = debug
|
143 |
self._counter = 0
|
144 |
-
|
|
|
145 |
self.max_ctx = max_ctx
|
146 |
self.variable_radius = variable_radius
|
147 |
|
148 |
-
self.inv_freq = nn.Parameter(
|
149 |
-
1.0 / (19000 ** (torch.arange(0, dims, 2, device=device, dtype=dtype) / dims)),
|
150 |
requires_grad=learned_freq)
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
self.
|
156 |
-
torch.tensor(2400.0), requires_grad=learned_theta)
|
157 |
-
|
158 |
-
self.pitch_scale = nn.Parameter(torch.tensor(1.0),
|
159 |
-
requires_grad=learned_pitch)
|
160 |
|
161 |
if variable_radius:
|
162 |
-
self.radius = nn.Parameter(
|
163 |
-
torch.ones(dims // 2),
|
164 |
-
requires_grad=learned_radius)
|
165 |
|
166 |
def get_pitch_bias(self, f0):
|
167 |
if f0 is None:
|
@@ -189,31 +183,38 @@ class rotary(nn.Module):
|
|
189 |
rotary.get_sim = get_sim
|
190 |
rotary.fwd_sim = fwd_sim
|
191 |
|
192 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
if isinstance(x, int):
|
194 |
t = torch.arange(x, device=self.device).float()
|
195 |
else:
|
196 |
t = x.float().to(self.inv_freq.device)
|
197 |
-
|
198 |
if f0 is not None:
|
199 |
-
|
200 |
f0_mean = f0.mean()
|
201 |
-
|
202 |
-
|
203 |
-
f0_theta = self.theta + perceptual_factor * (max_theta - min_theta)
|
204 |
-
inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
|
205 |
else:
|
206 |
inv_freq = self.inv_freq
|
207 |
freqs = torch.einsum('i,j->ij', t, inv_freq)
|
208 |
-
|
209 |
freqs = freqs.float()
|
210 |
if self.variable_radius:
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
freqs = freqs.unsqueeze(0)
|
216 |
-
|
217 |
if "rotary" in self.debug:
|
218 |
if f0 is not None:
|
219 |
key = f"{self._counter}_{f0_theta:.2f}"
|
@@ -221,13 +222,12 @@ class rotary(nn.Module):
|
|
221 |
if not hasattr(self, '_prev_f0_theta'):
|
222 |
self._prev_f0_theta = f0_theta
|
223 |
print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
|
224 |
-
elif abs(self._prev_f0_theta - f0_theta) >
|
225 |
print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
|
226 |
self._prev_f0_theta = f0_theta
|
227 |
rotary._seen.add(key)
|
228 |
self._counter += 1
|
229 |
-
|
230 |
-
return freqs
|
231 |
|
232 |
@staticmethod
|
233 |
def apply_rotary(x, freqs):
|
@@ -240,13 +240,11 @@ class rotary(nn.Module):
|
|
240 |
x1 = x1 * freqs
|
241 |
x1 = torch.view_as_real(x1).flatten(-2)
|
242 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
243 |
-
|
244 |
else:
|
245 |
x1 = x[..., :freqs.shape[-1]*2]
|
246 |
x2 = x[..., freqs.shape[-1]*2:]
|
247 |
|
248 |
if x.ndim == 2:
|
249 |
-
|
250 |
x1 = x1.unsqueeze(0)
|
251 |
x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
|
252 |
x1 = torch.view_as_complex(x1)
|
@@ -260,7 +258,7 @@ class rotary(nn.Module):
|
|
260 |
x1 = x1 * freqs
|
261 |
x1 = torch.view_as_real(x1).flatten(-2)
|
262 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
263 |
-
|
264 |
class SliceAttention(nn.Module):
|
265 |
def __init__(self, dims, heads, dropout=0.0):
|
266 |
super().__init__()
|
|
|
130 |
|
131 |
class rotary(nn.Module):
|
132 |
_seen = set()
|
133 |
+
def __init__(self, dims, max_ctx=1500, theta=10000, learned_freq=False, variable_radius=False,
|
134 |
learned_radius=False, learned_theta=False, learned_pitch=False, debug: List[str] = []):
|
135 |
super().__init__()
|
136 |
+
|
137 |
+
self.dims = dims
|
138 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
139 |
+
dtype = torch.float32
|
140 |
self.device = device
|
|
|
141 |
self.dtype = dtype
|
142 |
self.debug = debug
|
143 |
self._counter = 0
|
144 |
+
|
145 |
+
self.use_pbias = False
|
146 |
self.max_ctx = max_ctx
|
147 |
self.variable_radius = variable_radius
|
148 |
|
149 |
+
self.inv_freq = nn.Parameter(1.0 / (theta ** (torch.arange(0, dims, 2, device=device, dtype=dtype) / dims)),
|
|
|
150 |
requires_grad=learned_freq)
|
151 |
+
|
152 |
+
self.theta = nn.Parameter(torch.tensor(float(theta)),
|
153 |
+
requires_grad=learned_theta)
|
154 |
+
|
155 |
+
self.pitch_scale = nn.Parameter(torch.tensor(1.0), requires_grad=learned_pitch)
|
|
|
|
|
|
|
|
|
156 |
|
157 |
if variable_radius:
|
158 |
+
self.radius = nn.Parameter(torch.ones(dims // 2), requires_grad=learned_radius)
|
|
|
|
|
159 |
|
160 |
def get_pitch_bias(self, f0):
|
161 |
if f0 is None:
|
|
|
183 |
rotary.get_sim = get_sim
|
184 |
rotary.fwd_sim = fwd_sim
|
185 |
|
186 |
+
def align_f0_to_tokens(self, f0, token_length):
    """Resample ``f0`` to ``token_length`` entries by nearest-lower index.

    Output position ``i`` takes ``f0[int(i * len(f0) / token_length)]``,
    clamped to the last valid index.

    Args:
        f0: Non-empty sequence that supports ``len()`` and indexing with a
            list of integers (e.g. a 1-D tensor or array).
        token_length: Number of entries to produce.

    Returns:
        ``f0`` indexed at the selected positions; an empty slice of ``f0``
        when ``token_length`` is zero or negative.
    """
    if token_length <= 0:
        # Nothing to align to; also avoids dividing by zero below.
        return f0[:0]

    n_frames = len(f0)
    ratio = n_frames / token_length
    last = n_frames - 1
    # Clamp guards against float rounding pushing an index past the end.
    indices = [min(int(i * ratio), last) for i in range(token_length)]
    return f0[indices]
|
191 |
+
|
192 |
+
def forward(self, x=None, f0=None, stage=None) -> Tensor:
|
193 |
if isinstance(x, int):
|
194 |
t = torch.arange(x, device=self.device).float()
|
195 |
else:
|
196 |
t = x.float().to(self.inv_freq.device)
|
|
|
197 |
if f0 is not None:
|
|
|
198 |
f0_mean = f0.mean()
|
199 |
+
f0_theta = f0_mean * (f0_mean / self.theta) * self.theta * self.pitch_scale
|
200 |
+
inv_freq = 1.0 / (f0_theta ** (torch.arange(0, self.dims, 2, device=self.device) / self.dims))
|
|
|
|
|
201 |
else:
|
202 |
inv_freq = self.inv_freq
|
203 |
freqs = torch.einsum('i,j->ij', t, inv_freq)
|
|
|
204 |
freqs = freqs.float()
|
205 |
if self.variable_radius:
|
206 |
+
|
207 |
+
if f0 is not None:
|
208 |
+
f0 = f0[0]
|
209 |
+
seq_len = x
|
210 |
+
f0 = torch.tensor(f0, device=x.device if isinstance(x, torch.Tensor) else device)
|
211 |
+
f0 = self.align_f0_to_tokens(f0, freqs.shape[-1])
|
212 |
+
radius = 1.0 / (f0 + 1)
|
213 |
+
freqs = torch.polar(radius, freqs)
|
214 |
+
else:
|
215 |
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
216 |
freqs = freqs.unsqueeze(0)
|
217 |
+
|
218 |
if "rotary" in self.debug:
|
219 |
if f0 is not None:
|
220 |
key = f"{self._counter}_{f0_theta:.2f}"
|
|
|
222 |
if not hasattr(self, '_prev_f0_theta'):
|
223 |
self._prev_f0_theta = f0_theta
|
224 |
print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
|
225 |
+
elif abs(self._prev_f0_theta - f0_theta) > 0.0:
|
226 |
print(f"Step {self._counter}: Using raw F0 as theta: {f0_theta:.2f} Hz")
|
227 |
self._prev_f0_theta = f0_theta
|
228 |
rotary._seen.add(key)
|
229 |
self._counter += 1
|
230 |
+
return freqs
|
|
|
231 |
|
232 |
@staticmethod
|
233 |
def apply_rotary(x, freqs):
|
|
|
240 |
x1 = x1 * freqs
|
241 |
x1 = torch.view_as_real(x1).flatten(-2)
|
242 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
|
|
243 |
else:
|
244 |
x1 = x[..., :freqs.shape[-1]*2]
|
245 |
x2 = x[..., freqs.shape[-1]*2:]
|
246 |
|
247 |
if x.ndim == 2:
|
|
|
248 |
x1 = x1.unsqueeze(0)
|
249 |
x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
|
250 |
x1 = torch.view_as_complex(x1)
|
|
|
258 |
x1 = x1 * freqs
|
259 |
x1 = torch.view_as_real(x1).flatten(-2)
|
260 |
return torch.cat([x1.type_as(x), x2], dim=-1)
|
261 |
+
|
262 |
class SliceAttention(nn.Module):
|
263 |
def __init__(self, dims, heads, dropout=0.0):
|
264 |
super().__init__()
|