Sin2pi commited on
Commit
3ec2e71
·
verified ·
1 Parent(s): 64f14a0

Update modelA.py

Browse files
Files changed (1) hide show
  1. modelA.py +68 -0
modelA.py CHANGED
@@ -146,6 +146,74 @@ class rotary(nn.Module):
146
  self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2))
147
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  def theta_freqs(self, theta):
150
  freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
151
  freqs = nn.Parameter(torch.tensor(freq, device=device, dtype=dtype), requires_grad=True)
 
146
  self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2))
147
  self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
148
 
149
+
150
+ # def theta_freqs(self, theta):
151
+ # freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
152
+ # freqs = nn.Parameter(torch.tensor(freq, device=device, dtype=dtype), requires_grad=True)
153
+ # return freqs
154
+
155
+ # def mel_geodesic_rotary(f0, theta):
156
+ # mel_f0 = 1127.0 * torch.log(1.0 + f0 / 700.0)
157
+ # fisher_info = torch.var(mel_f0) + 1e-8
158
+ # adaptive_theta = theta * torch.sqrt(fisher_info)
159
+ # freqs = self.theta_freqs(adaptive_theta)
160
+ # return freqs
161
+
162
+ # def compute_pitch_fisher_info(f0, window_size=10):
163
+ # if f0.dim() == 1:
164
+ # f0 = f0.unsqueeze(0)
165
+ # mel_f0 = 1127.0 * torch.log(1.0 + f0 / 700.0)
166
+ # fisher_info = torch.nn.functional.avg_pool1d(
167
+ # mel_f0.unsqueeze(0),
168
+ # kernel_size=window_size,
169
+ # stride=1,
170
+ # padding=window_size//2
171
+ # ).squeeze(0)
172
+ # fisher_info = (fisher_info - fisher_info.min()) / (fisher_info.max() - fisher_info.min() + 1e-8)
173
+ # return fisher_info
174
+
175
+ # def compute_advanced_fisher_info(f0, window_size=10):
176
+ # mel_f0 = 1127.0 * torch.log(1.0 + f0 / 700.0)
177
+ # local_mean = torch.nn.functional.avg_pool1d(
178
+ # mel_f0.unsqueeze(0), window_size, 1, window_size//2
179
+ # ).squeeze(0)
180
+
181
+ # local_var = torch.nn.functional.avg_pool1d(
182
+ # (mel_f0 - local_mean).pow(2).unsqueeze(0),
183
+ # window_size, 1, window_size//2
184
+ # ).squeeze(0)
185
+
186
+ # fisher_info = 1.0 / (local_var + 1e-8)
187
+ # return fisher_info
188
+
189
+ # def test_fisher_info(self, f0):
190
+ # """Test Fisher information computation.""" # fisher_info = self.compute_pitch_fisher_info(f0)
191
+
192
+ # print(f"f0 range: {f0.min():.1f} - {f0.max():.1f}")
193
+ # print(f"Fisher info range: {fisher_info.min():.3f} - {fisher_info.max():.3f}")
194
+ # print(f"Fisher info mean: {fisher_info.mean():.3f}")
195
+
196
+ # # Visualize: high Fisher info = meaningful pitch changes
197
+ # return fisher_info
198
+
199
+ # def forward(self, x=None, enc=None, layer=None, feature_type="audio"):
200
+
201
+ # if f0 is not None:
202
+ # # Compute Fisher information
203
+ # fisher_info = self.compute_pitch_fisher_info(f0)
204
+
205
+ # # Use Fisher info to weight pitch influence
206
+ # f0_weighted = f0 * fisher_info
207
+
208
+ # # Apply to both theta and radius
209
+ # f0_mean = f0_weighted.mean()
210
+ # theta = f0_mean + self.theta
211
+
212
+ # if self.radii:
213
+ # radius = f0_weighted.to(device, dtype)
214
+
215
+
216
+
217
  def theta_freqs(self, theta):
218
  freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
219
  freqs = nn.Parameter(torch.tensor(freq, device=device, dtype=dtype), requires_grad=True)