Update modelA.py
Browse files
modelA.py
CHANGED
@@ -146,6 +146,74 @@ class rotary(nn.Module):
|
|
146 |
self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2))
|
147 |
self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
|
148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
def theta_freqs(self, theta):
|
150 |
freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
151 |
freqs = nn.Parameter(torch.tensor(freq, device=device, dtype=dtype), requires_grad=True)
|
|
|
146 |
self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2))
|
147 |
self.theta = nn.Parameter(torch.tensor(theta, device=device, dtype=dtype), requires_grad=True)
|
148 |
|
149 |
+
|
150 |
+
# def theta_freqs(self, theta):
|
151 |
+
# freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
152 |
+
# freqs = nn.Parameter(torch.tensor(freq, device=device, dtype=dtype), requires_grad=True)
|
153 |
+
# return freqs
|
154 |
+
|
155 |
+
# def mel_geodesic_rotary(f0, theta):
|
156 |
+
# mel_f0 = 1127.0 * torch.log(1.0 + f0 / 700.0)
|
157 |
+
# fisher_info = torch.var(mel_f0) + 1e-8
|
158 |
+
# adaptive_theta = theta * torch.sqrt(fisher_info)
|
159 |
+
# freqs = self.theta_freqs(adaptive_theta)
|
160 |
+
# return freqs
|
161 |
+
|
162 |
+
# def compute_pitch_fisher_info(f0, window_size=10):
|
163 |
+
# if f0.dim() == 1:
|
164 |
+
# f0 = f0.unsqueeze(0)
|
165 |
+
# mel_f0 = 1127.0 * torch.log(1.0 + f0 / 700.0)
|
166 |
+
# fisher_info = torch.nn.functional.avg_pool1d(
|
167 |
+
# mel_f0.unsqueeze(0),
|
168 |
+
# kernel_size=window_size,
|
169 |
+
# stride=1,
|
170 |
+
# padding=window_size//2
|
171 |
+
# ).squeeze(0)
|
172 |
+
# fisher_info = (fisher_info - fisher_info.min()) / (fisher_info.max() - fisher_info.min() + 1e-8)
|
173 |
+
# return fisher_info
|
174 |
+
|
175 |
+
# def compute_advanced_fisher_info(f0, window_size=10):
|
176 |
+
# mel_f0 = 1127.0 * torch.log(1.0 + f0 / 700.0)
|
177 |
+
# local_mean = torch.nn.functional.avg_pool1d(
|
178 |
+
# mel_f0.unsqueeze(0), window_size, 1, window_size//2
|
179 |
+
# ).squeeze(0)
|
180 |
+
|
181 |
+
# local_var = torch.nn.functional.avg_pool1d(
|
182 |
+
# (mel_f0 - local_mean).pow(2).unsqueeze(0),
|
183 |
+
# window_size, 1, window_size//2
|
184 |
+
# ).squeeze(0)
|
185 |
+
|
186 |
+
# fisher_info = 1.0 / (local_var + 1e-8)
|
187 |
+
# return fisher_info
|
188 |
+
|
189 |
+
# def test_fisher_info(self, f0):
|
190 |
+
# """Test Fisher information computation.""" # fisher_info = self.compute_pitch_fisher_info(f0)
|
191 |
+
|
192 |
+
# print(f"f0 range: {f0.min():.1f} - {f0.max():.1f}")
|
193 |
+
# print(f"Fisher info range: {fisher_info.min():.3f} - {fisher_info.max():.3f}")
|
194 |
+
# print(f"Fisher info mean: {fisher_info.mean():.3f}")
|
195 |
+
|
196 |
+
# # Visualize: high Fisher info = meaningful pitch changes
|
197 |
+
# return fisher_info
|
198 |
+
|
199 |
+
# def forward(self, x=None, enc=None, layer=None, feature_type="audio"):
|
200 |
+
|
201 |
+
# if f0 is not None:
|
202 |
+
# # Compute Fisher information
|
203 |
+
# fisher_info = self.compute_pitch_fisher_info(f0)
|
204 |
+
|
205 |
+
# # Use Fisher info to weight pitch influence
|
206 |
+
# f0_weighted = f0 * fisher_info
|
207 |
+
|
208 |
+
# # Apply to both theta and radius
|
209 |
+
# f0_mean = f0_weighted.mean()
|
210 |
+
# theta = f0_mean + self.theta
|
211 |
+
|
212 |
+
# if self.radii:
|
213 |
+
# radius = f0_weighted.to(device, dtype)
|
214 |
+
|
215 |
+
|
216 |
+
|
217 |
def theta_freqs(self, theta):
|
218 |
freq = (theta / 220.0) * 700 * (torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)), self.dim // 2, device=device, dtype=dtype) / 2595) - 1) / 1000
|
219 |
freqs = nn.Parameter(torch.tensor(freq, device=device, dtype=dtype), requires_grad=True)
|