Sin2pi committed · Commit 034db85 · verified · 1 Parent(s): 0f56c7e

Update README.md

Files changed (1):
  README.md +824 -11

README.md CHANGED
@@ -20,7 +20,6 @@ tags:
  - new

 ---
- ----
 NLP/ASR multimodal model with f0-modulated relative positional embeddings.
 For research/testing.

@@ -32,15 +31,13 @@ Questions:
32
  -Can we incorporate acoustic information directly into positional encodings?
33
 
34
  -Does pitch-conditioning improve speech recognition?
35
-
36
-
37
  ---
38
 
39
 
40
 
41
  <img width="780" alt="cc5" src="https:github.comuser-attachmentsassets106ebe75-f1db-4f85-bdae-818b114fedd2" >
42
 
43
- This plot illustrates the pattern similiarity of pitch waveform and spectrogram. librispeech - clean.
44
 
45
  To explore the relationship between pitch and rotary embeddings, the model implements three complementary pitch based enhancements:
46
 
@@ -74,15 +71,11 @@ if f0 is not None:
74
  else:
75
  theta = self.theta
76
 
77
- ## In text, theta=10,000 sets the base frequency for positional encoding, ensuring a wide range of periodicities for long sequences. I'm not sure if the specific number 10k was experimentally derived.
78
- ## For audio, especially speech, the relevant periodicities are determined by the pitch f0 neighborhood or f0 per frame might be more meaningful.
79
 
80
  freqs = theta.unsqueeze-1 220.0 * 700 *
81
  torch.pow10, torch.linspace0, 2595 * torch.log10torch.tensor1 + 8000700,
82
  self.dim 2, device=theta.device, dtype=theta.dtype 2595 - 1 1000
83
 
84
- ## This seems to give better results compared to the standard freqs = 1. theta torch.arange0, dim, 2[:dim 2].float dim.
85
- ## I thought a mel-scale version might be more perceptually meaningful for audio.. ie. using mel-scale to create a perceptually-relevant distance metric instead of Euclidean distance.
86
 
87
  t = torch.arangectx, device=device, dtype=dtype
88
  freqs = t[:, None] * freqs # dont repeat or use some other method here
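
For reference, here is a minimal sketch of the two frequency schedules those removed comments compare: the standard RoPE inverse-frequency spectrum and the mel-scaled, f0-conditioned variant from the snippet above. The function names and the 220 Hz reference are taken from the snippet; everything else is illustrative:

```python
import torch

def standard_rope_freqs(dim: int, theta: float = 10000.0) -> torch.Tensor:
    # classic RoPE: geometric frequency spectrum 1 / theta^(2i/dim)
    return 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))

def mel_rope_freqs(dim: int, theta: torch.Tensor) -> torch.Tensor:
    # mel-scaled spectrum scaled by theta relative to a 220 Hz reference,
    # mirroring the snippet above; mel-to-Hz: 700 * (10^(m/2595) - 1)
    mel_max = 2595 * torch.log10(torch.tensor(1 + 8000 / 700))
    mels = torch.linspace(0, float(mel_max), dim // 2)
    hz = 700 * (torch.pow(10, mels / 2595) - 1)
    return (theta.unsqueeze(-1) / 220.0) * hz / 1000

print(standard_rope_freqs(64)[:4])
print(mel_rope_freqs(64, torch.tensor([220.0]))[0, :4])
```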
@@ -245,19 +238,19 @@ The Complex Frequency Result:
 [Freqs] torch.Size([454, 64]) 2.17+1.17j


 Magnitude: sqrt(2.17² + 1.17²) ≈ 2.5
 Phase: atan2(1.17, 2.17) ≈ 0.49 radians

 Variable radius: each frame has a different magnitude


 Silence frames: radius ≈ 0 → freqs ≈ 0
 Voiced frames: radius ≈ 200-300 → freqs ≈ 2-3

 Variable attention: important frames get more attention

-
-
 Silence: no acoustic prominence → low radius
 Speech: high acoustic prominence → high radius
 Transitions: natural pitch changes
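
A quick, purely illustrative sanity check of those magnitude and phase numbers:

```python
import torch

z = torch.tensor(2.17 + 1.17j)
print(z.abs())    # ≈ 2.47, the radius of the complex frequency
print(z.angle())  # ≈ 0.49 rad, its phase
print(torch.polar(z.abs(), z.angle()))  # polar(radius, phase) reconstructs z
```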
@@ -288,6 +281,11 @@ Approximation methods like using cos/sin projections or fixed rotation matrices t
 ```
 This approach respects both the rotation phase and the scaling radius for each token/head, so the rotary embedding is applied correctly when the radius varies.


 ----

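A minimal sketch of what respecting both phase and radius means in practice: multiplying by a complex number rotates by its phase and scales by its magnitude. This mirrors the apply_rotary pattern in the listing below; shapes and values are illustrative:

```python
import torch

def apply_complex_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # x: (..., ctx, dim) real; freqs: (ctx, dim // 2) complex = polar(radius, angle)
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # complex multiply: rotates each pair by angle(freqs) AND scales by abs(freqs)
    return torch.view_as_real(xc * freqs).flatten(-2).type_as(x)

ctx, dim = 4, 8
angle = torch.rand(ctx, dim // 2)
radius = torch.tensor([0.0, 0.5, 1.0, 2.0]).unsqueeze(-1).expand(ctx, dim // 2)
freqs = torch.polar(radius, angle)  # unit radius would be plain RoPE
out = apply_complex_rope(torch.randn(1, ctx, dim), freqs)
print(out.norm(dim=-1))  # row norms scale with the per-frame radius
```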
@@ -295,4 +293,819 @@ This model sometimes uses:

 https://github.com/sine2pi/Maxfactor

- MaxFactor is a custom PyTorch optimizer with adaptive learning rates and specialized handling for matrix parameters. I wrote it for the model in the asr_model repository. I needed something that performs well and has a light memory footprint since I do everything from my laptop.
  - new

 ---
 NLP/ASR multimodal model with f0-modulated relative positional embeddings.
 For research/testing.

 -Can we incorporate acoustic information directly into positional encodings?

 -Does pitch-conditioning improve speech recognition?
+
 ---



 <img width="780" alt="cc5" src="https://github.com/user-attachments/assets/106ebe75-f1db-4f85-bdae-818b114fedd2" >


 To explore the relationship between pitch and rotary embeddings, the model implements three complementary pitch-based enhancements:

 else:
     theta = self.theta


 freqs = (theta.unsqueeze(-1) / 220.0) * 700 * (
     torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
     self.dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000


 t = torch.arange(ctx, device=device, dtype=dtype)
 freqs = t[:, None] * freqs  # don't repeat or use some other method here

 [Freqs] torch.Size([454, 64]) 2.17+1.17j


+
 Magnitude: sqrt(2.17² + 1.17²) ≈ 2.5
 Phase: atan2(1.17, 2.17) ≈ 0.49 radians

 Variable radius: each frame has a different magnitude


+
 Silence frames: radius ≈ 0 → freqs ≈ 0
 Voiced frames: radius ≈ 200-300 → freqs ≈ 2-3

 Variable attention: important frames get more attention


 Silence: no acoustic prominence → low radius
 Speech: high acoustic prominence → high radius
 Transitions: natural pitch changes

 ```
 This approach respects both the rotation phase and the scaling radius for each token/head, so the rotary embedding is applied correctly when the radius varies.

+ <img width="780" alt="cc4" src="https://github.com/user-attachments/assets/165a3f18-659a-4e2e-a154-a3456b667bae" >
+
+
+ ----
+ [https://huggingface.co/Sin2pi/Echo3/tensorboard?params=scalars](https://huggingface.co/Sin2pi/Echo3/tensorboard?params=scalars)

 ----


 https://github.com/sine2pi/Maxfactor

+ MaxFactor is a custom PyTorch optimizer with adaptive learning rates and specialized handling for matrix parameters.
+
+ **Note:** this model deviates in many ways from standard transformer models.
+
+
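Before the full listing below, a hypothetical usage sketch for MaxFactor. The import path and constructor signature here are assumptions, not the real interface; check the repository above:

```python
# Hypothetical usage: the import path and constructor signature are assumptions.
from maxfactor import MaxFactor

optimizer = MaxFactor(model.parameters(), lr=0.025)  # model: any nn.Module
loss = model(**batch)["loss"]                        # batch: your collated inputs
loss.backward()
optimizer.step()
optimizer.zero_grad()
```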
```python
import os
import math
import warnings
import logging
from itertools import chain
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from tensordict import TensorDict
from typing import Optional, Dict, Union, List, Tuple
import numpy as np
from functools import partial
from datetime import datetime
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
from echoutils import *

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.ERROR)

class rotary(nn.Module):
    def __init__(self, dims, head, max_ctx=1500, radii=False, debug: List[str] = [], use_pbias=False, axial=False, spec_shape=None):

        super(rotary, self).__init__()
        self.use_pbias = use_pbias
        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.radii = radii
        self.debug = debug
        self.counter = 0
        self.last_theta = None
        self.axial = axial

        self.bias = nn.Parameter(torch.zeros(max_ctx, dims // 2), requires_grad=True if use_pbias else False)
        theta = torch.tensor(10000, device=device, dtype=dtype)
        self.theta = nn.Parameter(theta, requires_grad=True)
        self.theta_values = []

        if axial and spec_shape is not None:
            time_frames, freq_bins = spec_shape
            self.time_frames = time_frames
            self.freq_bins = freq_bins

            time_theta = 50.0
            time_freqs = 1.0 / (time_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
            self.register_buffer('time_freqs', time_freqs)

            freq_theta = 100.0
            freq_freqs = 1.0 / (freq_theta ** (torch.arange(0, dims, 4)[:(dims // 4)].float() / dims))
            self.register_buffer('freq_freqs', freq_freqs)

    def pitch_bias(self, f0):
        if f0 is None:
            return None
        f0_flat = f0.squeeze().float()
        f0_norm = (f0_flat - f0_flat.mean()) / (f0_flat.std() + 1e-8)
        f0_sim = torch.exp(-torch.cdist(f0_norm.unsqueeze(1),
                                        f0_norm.unsqueeze(1)))
        return f0_sim.unsqueeze(0).unsqueeze(0)

    def theta_freqs(self, theta):
        if theta.dim() == 0:
            theta = theta.unsqueeze(0)
        freq = (theta.unsqueeze(-1) / 220.0) * 700 * (
            torch.pow(10, torch.linspace(0, 2595 * torch.log10(torch.tensor(1 + 8000/700)),
                self.head_dim // 2, device=theta.device, dtype=theta.dtype) / 2595) - 1) / 1000
        return freq

    def _apply_radii(self, freqs, f0, ctx):
        if self.radii and f0 is not None:
            radius = f0.to(device, dtype)
            L = radius.shape[0]
            if L != ctx:
                F = L / ctx
                idx = torch.arange(ctx, device=f0.device)
                idx = (idx * F).long().clamp(0, L - 1)
                radius = radius[idx]
            return torch.polar(radius.unsqueeze(-1), freqs), radius
        else:
            return torch.polar(torch.ones_like(freqs), freqs), None

    def check_f0(self, f0, f0t, ctx):
        if f0 is not None and f0.shape[1] == ctx:
            return f0
        elif f0t is not None and f0t.shape[1] == ctx:
            return f0t
        else:
            return None

    def axial_freqs(self, ctx):
        if not self.axial:
            return None
        time_frames = self.time_frames
        freq_bins = self.freq_bins

        t = torch.arange(ctx, device=device, dtype=dtype)
        t_x = (t % time_frames).float()
        t_y = torch.div(t, time_frames, rounding_mode='floor').float()
        freqs_x = torch.outer(t_x, self.time_freqs)
        freqs_y = torch.outer(t_y, self.freq_freqs)
        freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
        freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
        return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)

    def forward(self, x=None, en=None, f=None, layer=None) -> Tensor:
        ctx = x
        f0 = en.get("f0") if en is not None else None
        f0t = en.get("f0t") if en is not None else None

        f0 = self.check_f0(f0, f0t, ctx)
        if f0 is not None:
            if f0.dim() == 2:
                f0 = f0.squeeze(0)
            theta = f0 + self.theta
        else:
            theta = self.theta
        freqs = self.theta_freqs(theta)
        t = torch.arange(ctx, device=device, dtype=dtype)
        freqs = t[:, None] * freqs
        freqs, radius = self._apply_radii(freqs, f0, ctx)

        if self.axial and f == "spectrogram":
            freqs_2d = self.axial_freqs(ctx)
            if freqs_2d is not None:
                return freqs_2d.unsqueeze(0)

        if "radius" in self.debug and self.counter == 10:
            print(f" [{layer}] [Radius] {radius.shape if radius is not None else None} {radius.mean() if radius is not None else None} [Theta] {theta.mean() if theta is not None else None} [f0] {f0.shape if f0 is not None else None} [Freqs] {freqs.shape} {freqs.mean():.2f} [ctx] {ctx}")
        self.counter += 1
        return freqs.unsqueeze(0)

    @staticmethod
    def apply_rotary(x, freqs):
        x1 = x[..., :freqs.shape[-1]*2]
        x2 = x[..., freqs.shape[-1]*2:]
        orig_shape = x1.shape
        if x1.ndim == 2:
            x1 = x1.unsqueeze(0)
        x1 = x1.float().reshape(*x1.shape[:-1], -1, 2).contiguous()
        x1 = torch.view_as_complex(x1) * freqs
        x1 = torch.view_as_real(x1).flatten(-2)
        x1 = x1.view(orig_shape)
        return torch.cat([x1.type_as(x), x2], dim=-1)

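# Illustrative note (not part of the original listing): with dims=512, head=8,
# this module is typically driven as
#     rope = rotary(dims=512, head=8, radii=True)
#     freqs = rope(x=ctx, en={"f0": f0})   # complex tensor, (1, ctx, head_dim // 2)
#     q = rotary.apply_rotary(q, freqs)    # rotate by phase, scale by radius
# With radii=True the per-frame f0 becomes the radius passed to torch.polar, so
# the magnitude of each complex frequency carries acoustic prominence.
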
class MultiheadA(nn.Module):

    rbf = False
    def __init__(self, dims: int, head: int, rotary_emb: bool = True,
                 zero_val: float = 1e-7, minz: float = 1e-8, maxz: float = 1e-6, debug: List[str] = [], optim_attn=False, use_pbias=False):
        super(MultiheadA, self).__init__()

        self.dims = dims
        self.head = head
        self.head_dim = dims // head
        self.debug = debug
        self.counter = 0
        self.use_pbias = use_pbias

        self.q = nn.Linear(dims, dims).to(device, dtype)
        self.k = nn.Linear(dims, dims, bias=False).to(device, dtype)
        self.v = nn.Linear(dims, dims).to(device, dtype)
        self.o = nn.Linear(dims, dims).to(device, dtype)

        self.pad_token = 0
        self.rotary_emb = rotary_emb
        self.minz = minz
        self.maxz = maxz
        self.zero_val = zero_val
        self.optim_attn = optim_attn
        self.fzero = nn.Parameter(torch.tensor(zero_val, device=device, dtype=dtype), requires_grad=False)

        if rotary_emb:
            self.rope = rotary(
                dims=dims,
                head=head,
                debug=debug,
                radii=False,
            )
        else:
            self.rope = None

    def cos_sim(self, q: Tensor, k: Tensor, v: Tensor, mask) -> Tensor:
        q_norm = torch.nn.functional.normalize(q, dim=-1, eps=1e-12)
        k_norm = torch.nn.functional.normalize(k, dim=-1, eps=1e-12)
        qk_cosine = torch.matmul(q_norm, k_norm.transpose(-1, -2))
        qk_cosine = qk_cosine + mask
        weights = F.softmax(qk_cosine, dim=-1)
        out = torch.matmul(weights, v)
        return out

    def rbf_scores(self, q, k, rbf_sigma=1.0, rbf_ratio=0.0):
        scale = (self.dims // self.head) ** -0.25
        dot_scores = torch.matmul(q, k.transpose(-1, -2)) * scale
        if rbf_ratio <= 0.0:
            return dot_scores
        q_norm = q.pow(2).sum(dim=-1, keepdim=True)
        k_norm = k.pow(2).sum(dim=-1, keepdim=True)
        qk = torch.matmul(q, k.transpose(-1, -2))
        dist_sq = q_norm + k_norm.transpose(-1, -2) - 2 * qk
        rbf_scores = torch.exp(-dist_sq / (2 * rbf_sigma**2))
        return (1 - rbf_ratio) * dot_scores + rbf_ratio * rbf_scores

    def forward(self, x: Tensor, xa=None, mask=None, en=None, layer=None, f=None) -> tuple:

        x = x.to(device, dtype)
        if xa is not None:
            xa = xa.to(device, dtype)
        scale = (self.dims // self.head) ** -0.25

        z = default(xa, x).to(device, dtype)
        q = self.q(x)
        k = self.k(z)
        v = self.v(z)

        if self.rotary_emb:
            q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            q2 = q.shape[2]
            k2 = k.shape[2]

            q = self.rope.apply_rotary(q, (self.rope(x=q2, en=en, f=f, layer=layer)))
            k = self.rope.apply_rotary(k, (self.rope(x=k2, en=en, f=f, layer=layer)))
        else:
            q = q.view(*q.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            k = k.view(*k.shape[:2], self.head, -1).permute(0, 2, 1, 3)
            v = v.view(*v.shape[:2], self.head, -1).permute(0, 2, 1, 3)

        qk = (q * scale) @ (k * scale).transpose(-1, -2)

        if self.rbf:
            qk = self.rbf_scores(q * scale, k * scale, rbf_sigma=1.0, rbf_ratio=0.3)
        if self.use_pbias:
            pbias = self.rope.pitch_bias(f0=en.get("f0", None) if en is not None else None)
            if pbias is not None:
                qk = qk + pbias[:, :, :q2, :q2]

        token_ids = k[:, :, :, 0]
        zscale = torch.ones_like(token_ids)
        fzero = torch.clamp(F.softplus(self.fzero), self.minz, self.maxz)
        zscale[token_ids.float() == self.pad_token] = fzero

        if mask is not None:
            if mask.dim() == 4:
                mask = mask[0, 0]
            mask = mask[:q2, :k2] if xa is not None else mask[:q2, :q2]
            qk = qk + mask * zscale.unsqueeze(-2).expand(qk.shape)

        qk = qk * zscale.unsqueeze(-2)
        w = F.softmax(qk, dim=-1).to(q.dtype)
        wv = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)

        if "multihead" in self.debug and self.counter % 100 == 0:
            print(f"MHA: q={q.shape}, k={k.shape}, v={v.shape} - {qk.shape}, wv shape: {wv.shape}")
        self.counter += 1
        return self.o(wv), qk

    @staticmethod
    def split(X: Tensor) -> (Tensor, Tensor):
        half_dim = X.shape[-1] // 2
        return X[..., :half_dim], X[..., half_dim:]

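# Illustrative note (not part of the original listing): `default` is expected to
# come from the echoutils star import (presumably default(a, b) returns a when a
# is not None, else b), so xa=None gives self-attention and a provided xa gives
# cross-attention over that feature.
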
class t_gate(nn.Module):
    def __init__(self, dims, num_types=4, enabled=True):
        super().__init__()
        self.enabled = enabled
        self.gate_projections = nn.ModuleList([
            nn.Sequential(Linear(dims, 1), nn.Sigmoid())
            for _ in range(num_types)])
        self.type_classifier = nn.Sequential(
            Linear(dims, num_types),
            nn.Softmax(dim=-1))

    def forward(self, x):
        if not self.enabled:
            return None
        type_probs = self.type_classifier(x)
        gates = torch.stack([gate(x) for gate in self.gate_projections], dim=-1)
        comb_gate = torch.sum(gates * type_probs.unsqueeze(2), dim=-1)
        return comb_gate

class m_gate(nn.Module):
    def __init__(self, dims, mem_size=64, enabled=True):
        super().__init__()
        self.enabled = enabled
        if enabled:
            self.m_key = nn.Parameter(torch.randn(mem_size, dims))
            self.m_val = nn.Parameter(torch.randn(mem_size, 1))
            self.gate_proj = nn.Sequential(Linear(dims, dims//2), nn.SiLU(), Linear(dims//2, 1))

    def forward(self, x):
        if not self.enabled:
            return None
        d_gate = torch.sigmoid(self.gate_proj(x))
        attention = torch.matmul(x, self.m_key.transpose(0, 1))
        attention = F.softmax(attention / math.sqrt(x.shape[-1]), dim=-1)
        m_gate = torch.matmul(attention, self.m_val)
        m_gate = torch.sigmoid(m_gate)
        return 0.5 * (d_gate + m_gate)

class c_gate(nn.Module):
    def __init__(self, dims, enabled=True):
        super().__init__()
        self.enabled = enabled
        if enabled:
            self.s_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
            self.w_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
            self.p_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
            self.e_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
            self.ph_gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())
            self.integ = Linear(dims*5, dims)

    def forward(self, x, features):
        if not self.enabled:
            return None
        s_feat = features.get("spectrogram", x)
        w_feat = features.get("waveform", x)
        p_feat = features.get("pitch", x)
        e_feat = features.get("envelope", x)
        ph_feat = features.get("phase", x)
        s = self.s_gate(x) * s_feat
        w = self.w_gate(x) * w_feat
        p = self.p_gate(x) * p_feat
        e = self.e_gate(x) * e_feat
        ph = self.ph_gate(x) * ph_feat
        comb = torch.cat([s, w, p, e, ph], dim=-1)
        return self.integ(comb)

class mlp_gate(nn.Module):
    def __init__(self, dims, head, enabled=True, one_shot=True):
        super().__init__()
        self.enabled = enabled
        if enabled:
            self.gate = nn.Sequential(Linear(dims, 1), nn.Sigmoid())

    def forward(self, x, xa=None, f=None):
        if not self.enabled:
            return None
        return self.gate(x)

class Residual(nn.Module):
    _seen = set()
    def __init__(self, ctx, dims, head, act, debug: List[str] = [],
                 tgate=True, mgate=False, cgate=False, mem_size=512, features=None, one_shot=False):
        super().__init__()

        self.dims = dims
        self.head = head
        self.ctx = ctx
        self.head_dim = dims // head
        self.features = features
        self.debug = debug
        self.counter = 0
        self.dropout = 0.01
        self.one_shot = one_shot

        self.blend = nn.Parameter(torch.tensor(0.5))
        act_fn = get_activation(act)
        self.attn = MultiheadA(dims, head, rotary_emb=True, debug=debug)
        self.curiosity = curiosity(dims, head)

        mlp = dims * 4
        self.mlp = nn.Sequential(Linear(dims, mlp), act_fn, Linear(mlp, dims))

        self.t_gate = t_gate(dims=dims, num_types=4*2, enabled=tgate)
        self.m_gate = m_gate(dims=dims, mem_size=mem_size, enabled=mgate)
        self.c_gate = c_gate(dims=dims, enabled=cgate)
        self.mlp_gate = mlp_gate(dims=dims, head=head, enabled=not any([tgate, mgate, cgate]), one_shot=True)

        self.lna = RMSNorm(dims)
        self.lnb = RMSNorm(dims)
        self.lnc = RMSNorm(dims)

    def forward(self, x, xa=None, mask=None, en=None, layer=None, f=None) -> Tensor:

        b = torch.sigmoid(self.blend)
        ax = x + self.attn(self.lna(x), xa=xa, mask=mask, en=en, layer=layer, f=f)[0]
        bx = b * ax + (1 - b) * x
        cx = self.lnb(bx)
        dx = self.mlp(cx)
        ex = self.t_gate(cx)
        if ex is None:  # t_gate disabled: fall back to the memory / mlp gates
            ex = default(self.m_gate(cx), self.mlp_gate(cx))
        fx = x + ex + dx
        gx = self.lnc(fx)
        return gx

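# Illustrative note (not part of the original listing): the block blends the
# attention branch with its input through the learned sigmoid gate `blend`,
# then adds the MLP output plus one gate signal (t_gate by default, falling
# back to m_gate / mlp_gate when t_gate is disabled).
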
class OneShot(nn.Module):
    def __init__(self, dims: int, head: int, scale: float = 0.3):
        super().__init__()
        self.head = head
        self.hdim = dims // head
        self.scale = scale
        self.q_proj = Linear(dims, dims)
        self.k_proj = Linear(dims, dims)

    def forward(self, x: Tensor, guide: Tensor, f=None) -> Tensor | None:
        B, Q, _ = x.shape
        K = guide.size(1)
        q = self.q_proj(x).view(B, Q, self.head, self.hdim).transpose(1, 2)
        k = self.k_proj(guide).view(B, K, self.head, self.hdim).transpose(1, 2)
        bias = (q @ k.transpose(-1, -2)) * self.scale / math.sqrt(self.hdim)
        return bias

class curiosity(nn.Module):
    def __init__(self, d, h, bias=True):
        super().__init__()
        self.h = h
        self.dh = d // h
        self.qkv = nn.Linear(d, d * 3, bias=bias)
        self.qkv_aux = nn.Linear(d, d * 3, bias=bias)
        self.o = nn.Linear(d, d, bias=bias)
        self.g = nn.Parameter(torch.zeros(h))

    def split(self, x):
        b, t, _ = x.shape
        return x.view(b, t, self.h, self.dh).transpose(1, 2)

    def merge(self, x):
        b, h, t, dh = x.shape
        return x.transpose(1, 2).contiguous().view(b, t, h * dh)

    def forward(self, x, xa, mask=None):
        q, k, v = self.qkv(x).chunk(3, -1)
        qa, ka, va = self.qkv_aux(xa).chunk(3, -1)
        q, k, v = map(self.split, (q, k, v))
        qa, ka, va = map(self.split, (qa, ka, va))
        dots = (q @ k.transpose(-2, -1)) / self.dh**0.5
        dots_aux = (q @ ka.transpose(-2, -1)) / self.dh**0.5
        if mask is not None:
            dots = dots.masked_fill(mask, -9e15)
        p = dots.softmax(-1)
        pa = dots_aux.softmax(-1)
        h_main = p @ v
        h_aux = pa @ va
        g = torch.sigmoid(self.g).view(1, -1, 1, 1)
        out = self.merge(h_main * (1 - g) + h_aux * g)
        return self.o(out)

class PositionalEncoding(nn.Module):
    def __init__(self, dims, ctx):
        super(PositionalEncoding, self).__init__()
        self.dims = dims
        self.ctx = ctx
        self.pe = self.get_positional_encoding(max_ctx=ctx)

    def get_positional_encoding(self, max_ctx):
        pe = torch.zeros(max_ctx, self.dims)
        position = torch.arange(0, max_ctx, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.dims, 2, dtype=torch.float32)
            * (-math.log(10000.0) / self.dims)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        return pe.to(device)

    def forward(self, x):
        ctx = x.size(1)
        pe = self.pe[:, :ctx, :]
        x = x * math.sqrt(self.dims)
        x = x + pe
        return x

class FEncoder(nn.Module):
    def __init__(self, mels, dims, head, layer, kernel_size, act, stride=1, use_rope=False, spec_shape=None, debug=[]):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims
        self.debug = debug
        act_fn = get_activation(act)
        self.attend_pitch = False

        if self.attend_pitch:
            self.q, self.k, self.v, self.o, self.scale = qkv_init(dims, head)
            self.mlp = nn.Sequential(
                nn.Linear(dims, dims),
                nn.ReLU(),
                nn.Linear(dims, dims),
            )
        else:
            self.q, self.k, self.v, self.o, self.scale = None, None, None, None, None
            self.mlp = None

        self.encoder = nn.Sequential(
            Conv1d(mels, dims, kernel_size=3, stride=1, padding=1), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

        if use_rope:
            if spec_shape is not None:
                self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
        else:
            self.rope = None
            self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        self.norm = RMSNorm(dims)

    def apply_rope_to_features(self, x, en=None, f=None, layer="audio"):
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, en=en, f=f, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x: Tensor, en=None, f=None, layer=None):
        x = self.encoder(x).permute(0, 2, 1)
        if self.use_rope:
            x = self.apply_rope_to_features(x, en=en, f=f, layer=layer)
        else:
            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)

        if self.mlp is not None:
            x = self.mlp(x)

        if self.attend_pitch:
            xa = en["input_ids"]
            if xa is not None:
                q, k, v = create_qkv(self.q, self.k, self.v, x=xa, xa=x, head=self.head)
                out, _ = calculate_attention(q, k, v, mask=None, temperature=1.0, is_causal=True)
                out = self.o(out)
                x = x + out

        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        return x

class WEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=False, debug=[], spec_shape=None):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dropout = 0.01
        self.use_rope = use_rope
        self.dims = dims
        self.debug = debug
        act_fn = get_activation(act)
        self.target_length = None
        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims//4, kernel_size=15, stride=4, padding=7), act_fn,
            Conv1d(dims//4, dims//2, kernel_size=7, stride=2, padding=3), act_fn,
            Conv1d(dims//2, dims, kernel_size=5, stride=2, padding=2), act_fn)

        if use_rope:
            if spec_shape is not None:
                self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
        else:
            self.rope = None
            self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)
        self.norm = RMSNorm(dims)

    def apply_rope_to_features(self, x, en=None, f=None, layer="audio"):
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, en=en, f=f, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x: Tensor, en=None, f=None, layer=None):
        x = self.encoder(x).permute(0, 2, 1)
        if self.target_length and x.shape[1] != self.target_length:
            x = F.adaptive_avg_pool1d(x.transpose(1, 2), self.target_length).transpose(1, 2)
        if self.use_rope:
            x = self.apply_rope_to_features(x, en=en, f=f, layer=layer)
        else:
            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)

        print(f"X: {x.shape} {f}") if "encoder" in self.debug else None
        return self.norm(x)

class PEncoder(nn.Module):
    def __init__(self, input_dims, dims, head, layer, kernel_size, act, use_rope=True, debug=[], one_shot=False, spec_shape=None):
        super().__init__()

        self.head = head
        self.head_dim = dims // head
        self.dims = dims
        self.dropout = 0.01
        self.use_rope = use_rope
        self.debug = debug
        act_fn = get_activation(act)

        self.encoder = nn.Sequential(
            Conv1d(input_dims, dims, kernel_size=7, stride=1, padding=3), act_fn,
            Conv1d(dims, dims, kernel_size=5, stride=1, padding=2), act_fn,
            Conv1d(dims, dims, kernel_size=3, stride=1, padding=1, groups=dims), act_fn)

        if use_rope:
            self.rope = rotary(dims=dims, head=head, radii=False, debug=[], use_pbias=False, axial=False, spec_shape=spec_shape)
        else:
            self.rope = None
            self.positional = lambda length, dims, max_tscale: sinusoids(length, dims, max_tscale)

        self.norm = RMSNorm(dims)

    def rope_to_feature(self, x, en=None, f="pitch", layer="PEncoder"):
        batch, ctx, dims = x.shape
        x = x.view(batch, ctx, self.head, self.head_dim).permute(0, 2, 1, 3)
        freqs = self.rope(ctx, en=en, f=f, layer=layer)
        x = self.rope.apply_rotary(x, freqs)
        x = x.permute(0, 2, 1, 3).contiguous().view(batch, ctx, dims)
        return x

    def forward(self, x: Tensor, en=None, f="pitch", layer="PEncoder"):

        if x.dim() == 2:
            x = x.unsqueeze(0)

        x = self.encoder(x).permute(0, 2, 1)
        if self.use_rope:
            x = self.rope_to_feature(x, en=en, f=f, layer=layer)
        else:
            x = x + self.positional(x.shape[1], x.shape[-1], 10000).to(device, dtype)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = self.norm(x)
        print(f"X: {x.shape} {f}") if "PEncoder" in self.debug else None
        return x

class theBridge(nn.Module):
    def __init__(self, vocab: int, mels: int, ctx: int, dims: int, head: int, layer: int,
                 debug: List[str], features: List[str], act: str = "gelu"):
        super(theBridge, self).__init__()

        tgate = True
        mgate = False
        cgate = False

        self.debug = debug
        self.counter = 0
        self.dropout = 0.01
        self.features = features
        self.do_blend = "no_blend" not in self.debug
        self.sequential = "sequential" in self.debug
        self.layer = layer

        self.token = nn.Embedding(vocab, dims, device=device, dtype=dtype)
        self.positional = nn.Parameter(torch.empty(ctx, dims, device=device, dtype=dtype), requires_grad=True)
        self.blend = nn.Parameter(torch.tensor(0.5, device=device, dtype=dtype), requires_grad=True)
        self.norm = RMSNorm(dims)
        self.sinusoid_pos = lambda length, dims, max_tscale: sinusoids(length, dims, 10000)
        self.rotary = rotary(dims=dims, head=head, debug=debug, radii=False)

        with torch.no_grad():
            self.token.weight[0].zero_()

        act_fn = get_activation(act)
        if features == ["spectrogram", "waveform", "pitch"]:
            cgate = True
        else:
            cgate = False

        self.blockA = nn.ModuleDict()
        self.blockA["waveform"] = nn.ModuleList(
            [WEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=11, act=act_fn)] +
            [Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
             for _ in range(layer)] if "waveform" in features else None)

        for feature_type in ["spectrogram", "aperiodic", "harmonic"]:
            if feature_type in features:
                self.blockA[feature_type] = nn.ModuleList(
                    [FEncoder(mels=mels, dims=dims, head=head, layer=layer, kernel_size=3, act=act_fn)] +
                    [Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features) for _ in range(layer)])
            else:
                self.blockA[feature_type] = None

        for feature_type in ["pitch", "phase"]:
            if feature_type in features:
                self.blockA[feature_type] = nn.ModuleList(
                    [PEncoder(input_dims=1, dims=dims, head=head, layer=layer, kernel_size=9, act=act_fn)] +
                    [Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features) for _ in range(layer)])
            else:
                self.blockA[feature_type] = None

        self.blockB = nn.ModuleList([
            Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
            for _ in range(layer)])

        self.modal = nn.ModuleList([
            Residual(ctx=ctx, dims=dims, head=head, act=act_fn, tgate=tgate, mgate=mgate, cgate=cgate, debug=debug, features=features)
            for _ in range(layer)])

        mask = torch.tril(torch.ones(ctx, ctx), diagonal=0)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x, xa, en, f, sequential=False) -> Tensor:
        mask = self.mask[:x.shape[1], :x.shape[1]]
        x = self.token(x.long()) + self.positional[:x.shape[1]]

        out = {}
        out["input_ids"] = x
        out.update(en)

        # features without a dedicated encoder stack (e.g. f0) skip blockA
        fa = self.blockA[f] if f in self.blockA else None
        for b in chain(fa or []):
            xa = b(x=xa, en=out, f=f, layer="en")

        for b in chain(self.blockB or []):
            x = b(x=x, xa=None, mask=mask, en=out, f=f, layer="dec")
            y = b(x, xa=xa, mask=None, en=out, f=f, layer="cross")
            if sequential:
                x = y
            else:
                a = torch.sigmoid(self.blend)
                x = a * y + (1 - a) * x
        for b in self.modal:
            xc = b(x=torch.cat([x, xa], dim=1), xa=None, mask=None, en=out, f=f, layer="modal")
            xm = b(x=xc[:, :x.shape[1]], xa=xc[:, x.shape[1]:], mask=None, en=out, f=f, layer="modal")
            if sequential:
                x = xm
            else:
                a = torch.sigmoid(self.blend)
                x = a * x + (1 - a) * xm

        if self.counter < 1 and "encoder" in self.debug:
            shapes = {k: v.shape for k, v in en.items()}
            print(f"Step {self.counter}: mode: {list(en.keys())}: shapes: {shapes}")
        self.counter += 1

        x = self.norm(x)
        x = x @ torch.transpose(self.token.weight.to(dtype), 0, 1).float()

        return x

class Echo(nn.Module):
    def __init__(self, param: Dimensions):
        super().__init__()
        self.param = param

        self.processor = theBridge(
            vocab=param.vocab,
            mels=param.mels,
            ctx=param.ctx,
            dims=param.dims,
            head=param.head,
            layer=param.layer,
            features=param.features,
            act=param.act,
            debug=param.debug,
        )

    def forward(self,
                labels=None,
                input_ids=None,
                waveform: Optional[torch.Tensor] = None,
                spectrogram: Optional[torch.Tensor] = None,
                pitch: Optional[torch.Tensor] = None,
                f0: Optional[torch.Tensor] = None,
                f0t: Optional[torch.Tensor] = None,
                harmonic: Optional[torch.Tensor] = None,
                aperiodic: Optional[torch.Tensor] = None,
                phase: Optional[torch.Tensor] = None,
                ) -> Dict[str, Optional[torch.Tensor]]:

        en = {}
        if f0 is not None:
            en["f0"] = f0
        if f0t is not None:
            en["f0t"] = f0t
        if harmonic is not None:
            en["harmonic"] = harmonic
        if aperiodic is not None:
            en["aperiodic"] = aperiodic
        if phase is not None:
            en["phase"] = phase
        if pitch is not None:
            en["pitch"] = pitch
        if waveform is not None:
            en["waveform"] = waveform
        if spectrogram is not None:
            en["spectrogram"] = spectrogram

        x = input_ids
        for f, xa in en.items():
            logits = self.processor(x, xa, en, f)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=0)

        return {"logits": logits, "loss": loss}

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype
```
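
For orientation, a minimal usage sketch. Dimensions, RMSNorm, sinusoids, and friends come from echoutils, and the Dimensions fields below simply mirror the constructor arguments above; all concrete values are illustrative:

```python
# Illustrative only: Dimensions is assumed to be a simple config container
# (e.g. a dataclass) carrying the fields theBridge expects.
param = Dimensions(vocab=40000, mels=128, ctx=1500, dims=512, head=8,
                   layer=4, act="gelu", features=["spectrogram"], debug=[])
model = Echo(param).to(device)

spectrogram = torch.randn(1, 128, 1500, device=device)       # (batch, mels, frames)
input_ids = torch.randint(0, 40000, (1, 64), device=device)  # token ids
out = model(input_ids=input_ids, labels=input_ids, spectrogram=spectrogram)
print(out["logits"].shape, out["loss"])
```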