Spaces: Sleeping
IlayMalinyak committed · Commit 49ebc1f · 1 Parent(s): 707b3a3
kan
Browse files
- model/0.0_cache_data +0 -0
- model/0.0_config.yml +0 -0
- model/history.txt +2 -0
- tasks/audio.py +15 -10
- tasks/models/frugal_2025-01-21/CNNEncoder_frugal_2.json +0 -0
- tasks/models/frugal_2025-01-21/frugal_kan_2.pth +3 -0
- tasks/run.py +95 -0
- tasks/utils/config.yaml +18 -11
- tasks/utils/data.py +11 -5
- tasks/utils/kan/__init__.py +1 -0
- tasks/utils/kan/fasterkan.py +135 -0
- tasks/utils/kan/fasterkan_basis.py +112 -0
- tasks/utils/kan/fasterkan_layers.py +301 -0
- tasks/utils/kan/feature_extractor.py +112 -0
- tasks/utils/models.py +28 -1
- tasks/utils/train.py +12 -9
model/0.0_cache_data
ADDED
Binary file (840 Bytes).
model/0.0_config.yml
ADDED
The diff for this file is too large to render. See raw diff.
model/history.txt
ADDED
```text
### Round 0 ###
init => 0.0
```
tasks/audio.py
CHANGED
```diff
@@ -10,7 +10,7 @@ from torch.utils.data import DataLoader
 from .utils.evaluation import AudioEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
 from .utils.data import FFTDataset
-from .utils.models import DualEncoder
+from .utils.models import DualEncoder, CNNKan
 from .utils.train import Trainer
 from .utils.data_utils import collate_fn, Container
 import yaml
@@ -70,13 +70,14 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     model_args = Container(**yaml.safe_load(open(args_path, 'r'))['CNNEncoder'])
     model_args_f = Container(**yaml.safe_load(open(args_path, 'r'))['CNNEncoder_f'])
     conformer_args = Container(**yaml.safe_load(open(args_path, 'r'))['Conformer'])
+    kan_args = Container(**yaml.safe_load(open(args_path, 'r'))['KAN'])

     test_dataset = FFTDataset(test_dataset)
     test_dl = DataLoader(test_dataset, batch_size=data_args.batch_size, collate_fn=collate_fn)

-    model =
+    model = CNNKan(model_args, conformer_args, kan_args.get_dict())
     model = model.to(device)
-    state_dict = torch.load(
+    state_dict = torch.load(data_args.checkpoint_path)
     new_state_dict = OrderedDict()
     for key, value in state_dict.items():
         if key.startswith('module.'):
@@ -95,8 +96,12 @@ async def evaluate_audio(request: AudioEvaluationRequest):
                           accumulation_step=1, max_iter=np.inf,
                           exp_name=f"frugal_cnnencoder_inference")
     predictions, true_labels, acc = trainer.predict(test_dl, device=device)
+    # true_labels = test_dataset["label"]
+
     # Make random predictions (placeholder for actual model inference)
     print("accuracy: ", acc)
+    print("predictions: ", len(predictions))
+    print("true_labels: ", len(true_labels))

     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
@@ -128,15 +133,15 @@ async def evaluate_audio(request: AudioEvaluationRequest):

     return results

-
+if __name__ == "__main__":
     # with open("../logs//token.txt", "r") as f:
     #     api_key = f.read()
     # login(api_key)
     # # Create a sample request object
-
-
-
-
-
+    sample_request = AudioEvaluationRequest(
+        dataset_name="rfcx/frugalai",  # Replace with actual dataset name
+        test_size=0.2,  # Example values
+        test_seed=42
+    )
     #
-
+    asyncio.run(evaluate_audio(sample_request))
```
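The checkpoint loading added above follows the standard recipe for loading weights saved from a `DataParallel`/DDP-wrapped model into an unwrapped one: keys arrive prefixed with `module.`, and the loop strips that prefix. The diff truncates the loop body, so the renaming line in this sketch is an assumption; the surrounding pattern is the usual one:

```python
from collections import OrderedDict
import torch

# Sketch: load a checkpoint saved from a DataParallel/DDP model into a plain
# model by dropping the 'module.' name prefix. The checkpoint path is
# illustrative; the key-renaming line is an assumption (the diff cuts off here).
state_dict = torch.load("frugal_kan_2.pth", map_location="cpu")
new_state_dict = OrderedDict()
for key, value in state_dict.items():
    if key.startswith("module."):
        new_state_dict[key[len("module."):]] = value  # assumed rename
    else:
        new_state_dict[key] = value
# model.load_state_dict(new_state_dict)
```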
tasks/models/frugal_2025-01-21/CNNEncoder_frugal_2.json
ADDED
The diff for this file is too large to render. See raw diff.
tasks/models/frugal_2025-01-21/frugal_kan_2.pth
ADDED
```text
version https://git-lfs.github.com/spec/v1
oid sha256:28e0188edab4879996cc960d2dc79641460b270af9c5ac7d3eacad1f5e96da39
size 1714830
```
tasks/run.py
ADDED
```python
from torch.utils.data import DataLoader
from .utils.data import FFTDataset, SplitDataset
from datasets import load_dataset
from .utils.train import Trainer
from .utils.models import CNNKan, KanEncoder
from .utils.data_utils import *
from huggingface_hub import login
import yaml
import datetime
import json
import numpy as np

# local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
current_date = datetime.date.today().strftime("%Y-%m-%d")
datetime_dir = f"frugal_{current_date}"
args_dir = 'tasks/utils/config.yaml'
data_args = Container(**yaml.safe_load(open(args_dir, 'r'))['Data'])
exp_num = data_args.exp_num
model_name = data_args.model_name
model_args = Container(**yaml.safe_load(open(args_dir, 'r'))['CNNEncoder'])
model_args_f = Container(**yaml.safe_load(open(args_dir, 'r'))['CNNEncoder_f'])
conformer_args = Container(**yaml.safe_load(open(args_dir, 'r'))['Conformer'])
kan_args = Container(**yaml.safe_load(open(args_dir, 'r'))['KAN'])
if not os.path.exists(f"{data_args.log_dir}/{datetime_dir}"):
    os.makedirs(f"{data_args.log_dir}/{datetime_dir}")

with open("../logs//token.txt", "r") as f:
    api_key = f.read()

# local_rank, world_size, gpus_per_node = setup()
local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
login(api_key)
dataset = load_dataset("rfcx/frugalai", streaming=True)

train_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=True)
train_dl = DataLoader(train_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)

val_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=False)
val_dl = DataLoader(val_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)

test_ds = FFTDataset(dataset["test"])
test_dl = DataLoader(test_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)

# for i, batch in enumerate(train_dl):
#     x, x_f, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
#     print(x.shape, x_f.shape, y.shape)
#     if i > 10:
#         break
# exit()

# model = DualEncoder(model_args, model_args_f, conformer_args)
# model = FasterKAN([18000,64,64,16,1])
model = CNNKan(model_args, conformer_args, kan_args.get_dict())
# model.kan.speed()
# model = KanEncoder(kan_args.get_dict())
model = model.to(local_rank)
# model = DDP(model, device_ids=[local_rank], output_device=local_rank)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of parameters: {num_params}")

loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
total_steps = int(data_args.num_epochs) * 1000
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=total_steps,
                                                       eta_min=float((5e-4)/10))

# missing, unexpected = model.load_state_dict(torch.load(model_args.checkpoint_path))
# print(f"Missing keys: {missing}")
# print(f"Unexpected keys: {unexpected}")

trainer = Trainer(model=model, optimizer=optimizer,
                  criterion=loss_fn, output_dim=model_args.output_dim, scaler=None,
                  scheduler=None, train_dataloader=train_dl,
                  val_dataloader=val_dl, device=local_rank,
                  exp_num=datetime_dir, log_path=data_args.log_dir,
                  range_update=None,
                  accumulation_step=1, max_iter=np.inf,
                  exp_name=f"frugal_kan_{exp_num}")
fit_res = trainer.fit(num_epochs=100, device=local_rank,
                      early_stopping=10, only_p=False, best='loss', conf=True)
output_filename = f'{data_args.log_dir}/{datetime_dir}/{model_name}_frugal_{exp_num}.json'
with open(output_filename, "w") as f:
    json.dump(fit_res, f, indent=2)
preds, acc = trainer.predict(test_dl, local_rank)
print(f"Accuracy: {acc}")
```
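Note that `torch`, `os`, `Container`, and `collate_fn` are used here without direct imports; they presumably arrive via the star import from `.utils.data_utils`. Both the train and validation loaders wrap the same streaming `dataset["train"]` with `SplitDataset(..., is_train=...)`. `SplitDataset` lives in `tasks/utils/data.py` but its body is untouched by (and not visible in) this diff, so the following is only a hypothetical sketch of a deterministic split over a streaming `IterableDataset`; the constructor matches the call sites, while the 1-in-10 routing rule is an illustrative assumption:

```python
from torch.utils.data import IterableDataset

class SplitDataset(IterableDataset):
    """Hypothetical sketch only: the real SplitDataset is not shown in this
    commit. Routes every k-th streamed item to validation, the rest to train."""
    def __init__(self, dataset, is_train=True, every=10):
        self.dataset = dataset
        self.is_train = is_train
        self.every = every  # assumed split ratio (1 in `every` goes to val)

    def __iter__(self):
        for i, item in enumerate(self.dataset):
            in_val = (i % self.every == 0)  # deterministic split over a stream
            if self.is_train and not in_val:
                yield item
            elif not self.is_train and in_val:
                yield item
```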
tasks/utils/config.yaml
CHANGED
```diff
@@ -1,34 +1,41 @@
 Data:
   # Basics
-  log_dir: '/
+  log_dir: 'tasks/models'
   # Data
-  dataset: "
+  dataset: "FFTDataset"
-  data_dir:
+  data_dir: None
   model_name: "CNNEncoder"
-  batch_size:
+  batch_size: 32
-  num_epochs:
+  num_epochs: 10
   exp_num: 2
   max_len_spectra: 4096
   max_days_lc: 270
   lc_freq: 0.0208
   create_umap: True
+  checkpoint_path: 'tasks/models/frugal_2025-01-21/frugal_kan_2.pth'

 CNNEncoder:
   # Model
-  in_channels:
+  in_channels: 2
   num_layers: 4
   stride: 1
-  encoder_dims: [32,64,128
+  encoder_dims: [32,64,128]
   kernel_size: 3
   dropout_p: 0.3
   output_dim: 2
   beta: 1
-  load_checkpoint:
+  load_checkpoint: False
   checkpoint_num: 1
   activation: "silu"
   sine_w0: 1.0
-  avg_output:
+  avg_output: False
-
+
+KAN:
+  layers_hidden: [1125,32,8,8,1]
+  grid_min: -1.2
+  grid_max: 1.2
+  num_grids: 8
+  exponent: 2

 CNNEncoder_f:
   # Model
@@ -50,7 +57,7 @@ CNNEncoder_f:
 Conformer:
   encoder: ["mhsa_pro", "conv"]
   timeshift: false
-  num_layers:
+  num_layers: 4
   encoder_dim: 128
   num_heads: 8
   kernel_size: 3
```
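Each top-level YAML section is loaded independently and splatted into a `Container` (see the call sites in `run.py` and `audio.py`). `Container`'s definition is not part of this commit, so the class body below is an assumed minimal sketch consistent with how it is used: keyword-argument construction, attribute access, and a `get_dict()` helper:

```python
import yaml

class Container:
    """Hypothetical minimal Container matching the usage in this commit
    (the real implementation lives in tasks/utils/data_utils.py)."""
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def get_dict(self):
        return dict(self.__dict__)

# One Container per YAML section:
kan_args = Container(**yaml.safe_load(open('tasks/utils/config.yaml', 'r'))['KAN'])
print(kan_args.layers_hidden)  # [1125, 32, 8, 8, 1]
print(kan_args.get_dict())     # kwargs passed straight into FasterKAN(**...)
```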
tasks/utils/data.py
CHANGED
```diff
@@ -1,6 +1,7 @@
 import torch
 from torch.utils.data import IterableDataset
 from torch.fft import fft
+import torch.nn.functional as F
 from itertools import tee
 import random
 import torchaudio.transforms as T
@@ -24,20 +25,25 @@ class SplitDataset(IterableDataset):


 class FFTDataset(IterableDataset):
-    def __init__(self, original_dataset, orig_sample_rate=12000, target_sample_rate=
+    def __init__(self, original_dataset, max_len=72000, orig_sample_rate=12000, target_sample_rate=3000):
         self.dataset = original_dataset
         self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
+        self.max_len = max_len

     def __iter__(self):
         for item in self.dataset:
             # Assuming your audio data is in item['audio']
             # Modify this based on your actual data structure
             audio_data = torch.tensor(item['audio']['array']).float()
-
-
-
-
+            # pad audio
+            # if len(audio_data) == 0:
+            #     continue
+            pad_len = self.max_len - len(audio_data)
+            audio_data = F.pad(audio_data, (0, pad_len), mode='constant')
+            audio_data = self.resampler(audio_data)
+            fft_data = fft(audio_data)

             # Update the item with FFT data
             item['audio']['fft'] = fft_data
+            item['audio']['array'] = audio_data
             yield item
```
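For reference, the numbers in this pipeline are mutually consistent: a 6-second clip at 12 kHz has 72 000 samples (`max_len`), resampling to 3 kHz leaves 18 000 samples, and `fft` over that signal returns 18 000 complex bins, matching the commented-out `FasterKAN([18000,64,64,16,1])` alternative in `run.py`. A quick check:

```python
import torch
import torch.nn.functional as F
import torchaudio.transforms as T
from torch.fft import fft

max_len, orig_sr, target_sr = 72000, 12000, 3000
audio = torch.randn(50000)                        # a clip shorter than max_len
audio = F.pad(audio, (0, max_len - len(audio)))   # -> 72000 samples (6 s at 12 kHz)
audio = T.Resample(orig_freq=orig_sr, new_freq=target_sr)(audio)
print(audio.shape)       # torch.Size([18000])
print(fft(audio).shape)  # torch.Size([18000]), complex-valued bins
```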
tasks/utils/kan/__init__.py
ADDED
```python
from .fasterkan import FasterKAN, FasterKANLayer, FasterKANvolver
```
tasks/utils/kan/fasterkan.py
ADDED
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import *
from torch.autograd import Function
from .feature_extractor import EnhancedFeatureExtractor
from .fasterkan_layers import FasterKANLayer

class FasterKAN(nn.Module):
    def __init__(
        self,
        layers_hidden: List[int],
        grid_min: float = -1.2,
        grid_max: float = 1.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,
        train_inv_denominator: bool = False,
        # use_base_update: bool = True,
        base_activation = None,
        spline_weight_init_scale: float = 1.0,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            FasterKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                exponent=exponent,
                inv_denominator=inv_denominator,
                train_grid=train_grid,
                train_inv_denominator=train_inv_denominator,
                # use_base_update=use_base_update,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class FasterKANvolver(nn.Module):
    def __init__(
        self,
        layers_hidden: List[int],
        grid_min: float = -1.2,
        grid_max: float = 0.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,
        train_inv_denominator: bool = False,
        # use_base_update: bool = True,
        base_activation = None,
        spline_weight_init_scale: float = 1.0,
        view = [-1, 1, 28, 28],
    ) -> None:
        super(FasterKANvolver, self).__init__()

        self.view = view
        # Feature extractor with convolutional layers
        self.feature_extractor = EnhancedFeatureExtractor(colors=view[1])

        # Flattened feature size after the convolutional layers
        flat_features = 256  # 256 channels after global average pooling

        # Update layers_hidden with the correct input size from the conv layers
        layers_hidden = [flat_features] + layers_hidden

        # Define the FasterKAN layers
        self.faster_kan_layers = nn.ModuleList([
            FasterKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                exponent=exponent,
                inv_denominator=0.5,
                train_grid=False,
                train_inv_denominator=False,
                # use_base_update=use_base_update,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        # Reshape flat input to an image batch, e.g. [batch_size, 1, 28, 28] for MNIST
        x = x.view(self.view[0], self.view[1], self.view[2], self.view[3])
        # Apply the convolutional feature extractor
        x = self.feature_extractor(x)
        x = x.view(x.size(0), -1)  # Flatten the output from the conv layers

        # Pass through FasterKAN layers
        for layer in self.faster_kan_layers:
            x = layer(x)

        return x
```
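A quick smoke test of this stack as configured in `config.yaml` (the `KAN` section maps directly onto `FasterKAN`'s keyword arguments; the import path assumes the repo root is on `PYTHONPATH`):

```python
import torch
from tasks.utils.kan.fasterkan import FasterKAN

# Mirrors the KAN section of tasks/utils/config.yaml
kan = FasterKAN(layers_hidden=[1125, 32, 8, 8, 1],
                grid_min=-1.2, grid_max=1.2, num_grids=8, exponent=2)
x = torch.randn(4, 1125)  # a batch of 4 pooled feature vectors
print(kan(x).shape)       # torch.Size([4, 1]): one logit per sample
```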
tasks/utils/kan/fasterkan_basis.py
ADDED
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import *
from torch.autograd import Function

class RSWAFFunction(Function):
    @staticmethod
    def forward(ctx, input, grid, inv_denominator, train_grid, train_inv_denominator):
        # Compute the forward pass
        diff = (input[..., None] - grid)
        diff_mul = diff.mul(inv_denominator)
        tanh_diff = torch.tanh(diff)
        tanh_diff_deriviative = -tanh_diff.mul(tanh_diff) + 1  # sech^2(x) = 1 - tanh^2(x)

        # Save tensors for backward pass
        ctx.save_for_backward(input, tanh_diff, tanh_diff_deriviative, diff, inv_denominator)
        ctx.train_grid = train_grid
        ctx.train_inv_denominator = train_inv_denominator

        return tanh_diff_deriviative

    ##### SOS NOT SURE HOW grad_inv_denominator, grad_grid ARE CALCULATED CORRECTLY YET
    ##### MUST CHECK https://github.com/pytorch/pytorch/issues/74802
    ##### MUST CHECK https://www.changjiangcai.com/studynotes/2020-10-18-Custom-Function-Extending-PyTorch/
    ##### MUST CHECK https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html
    ##### MUST CHECK https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
    ##### MUST CHECK https://gist.github.com/Hanrui-Wang/bf225dc0ccb91cdce160539c0acc853a

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors
        input, tanh_diff, tanh_diff_deriviative, diff, inv_denominator = ctx.saved_tensors
        grad_grid = None
        grad_inv_denominator = None

        # Compute the backward pass for the input
        grad_input = -2 * tanh_diff * tanh_diff_deriviative * grad_output
        grad_input = grad_input.sum(dim=-1).mul(inv_denominator)

        # Compute the backward pass for grid
        if ctx.train_grid:
            grad_grid = -inv_denominator * grad_output.sum(dim=0).sum(dim=0)  # -(inv_denominator * grad_output * tanh_diff_deriviative).sum(dim=0)

        # Compute the backward pass for inv_denominator
        if ctx.train_inv_denominator:
            grad_inv_denominator = (grad_output * diff).sum()  # (grad_output * diff * tanh_diff_deriviative).sum()

        return grad_input, grad_grid, grad_inv_denominator, None, None  # same number as tensors or parameters


class ReflectionalSwitchFunction(nn.Module):
    def __init__(
        self,
        grid_min: float = -1.2,
        grid_max: float = 0.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,
        train_inv_denominator: bool = False,
    ):
        super().__init__()
        grid = torch.linspace(grid_min, grid_max, num_grids)
        self.train_grid = torch.tensor(train_grid, dtype=torch.bool)
        self.train_inv_denominator = torch.tensor(train_inv_denominator, dtype=torch.bool)
        self.grid = torch.nn.Parameter(grid, requires_grad=train_grid)
        self.inv_denominator = torch.nn.Parameter(torch.tensor(inv_denominator, dtype=torch.float32), requires_grad=train_inv_denominator)  # Cache the inverse of the denominator

    def forward(self, x):
        return RSWAFFunction.apply(x, self.grid, self.inv_denominator, self.train_grid, self.train_inv_denominator)


class SplineLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
        self.init_scale = init_scale
        super().__init__(in_features, out_features, bias=False, **kw)

    def reset_parameters(self) -> None:
        nn.init.xavier_uniform_(self.weight)  # Using Xavier Uniform initialization
```
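In math terms, `RSWAFFunction.forward` evaluates each input against every grid point with the reflectional-switch basis; the `# sech^2` comment is exactly this identity, and the input gradient returned by `backward` is its derivative per grid point (summed over grid points and scaled by `inv_denominator`):

```latex
b_i(x) \;=\; 1 - \tanh^{2}(x - g_i) \;=\; \operatorname{sech}^{2}(x - g_i),
\qquad
\frac{\partial b_i}{\partial x} \;=\; -2\,\tanh(x - g_i)\,\operatorname{sech}^{2}(x - g_i).
```

One detail worth flagging: `diff_mul = diff.mul(inv_denominator)` is computed but never used in the forward pass, while `backward` does multiply the summed input gradient by `inv_denominator`, so the denominator scaling enters only through the gradient here.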
tasks/utils/kan/fasterkan_layers.py
ADDED
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import *
from torch.autograd import Function
from .fasterkan_basis import ReflectionalSwitchFunction, SplineLinear

class FasterKANLayer(nn.Module):
    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        grid_min: float = -1.2,
        grid_max: float = 0.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,
        train_inv_denominator: bool = False,
        # use_base_update: bool = True,
        base_activation = F.silu,
        spline_weight_init_scale: float = 0.667,
    ) -> None:
        super().__init__()
        self.layernorm = nn.LayerNorm(input_dim)
        self.rbf = ReflectionalSwitchFunction(grid_min, grid_max, num_grids, exponent, inv_denominator, train_grid, train_inv_denominator)
        self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)
        # self.use_base_update = use_base_update
        # if use_base_update:
        #     self.base_activation = base_activation
        #     self.base_linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.layernorm(x)
        spline_basis = self.rbf(x).view(x.shape[0], -1)
        ret = self.spline_linear(spline_basis)
        # if self.use_base_update:
        #     base = self.base_linear(self.base_activation(x))
        #     ret += base
        return ret


class FasterKAN(nn.Module):
    def __init__(
        self,
        layers_hidden: List[int],
        grid_min: float = -1.2,
        grid_max: float = 0.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,
        train_inv_denominator: bool = False,
        # use_base_update: bool = True,
        base_activation = None,
        spline_weight_init_scale: float = 1.0,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            FasterKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                exponent=exponent,
                inv_denominator=inv_denominator,
                train_grid=train_grid,
                train_inv_denominator=train_inv_denominator,
                # use_base_update=use_base_update,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class BasicResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = self.downsample(x)

        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        out = F.relu(out)

        return out

class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(DepthwiseSeparableConv, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
                                   stride=stride, padding=padding, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, width, height = x.size()
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class EnhancedFeatureExtractor(nn.Module):
    def __init__(self):
        super(EnhancedFeatureExtractor, self).__init__()
        self.initial_layers = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),  # Increased number of filters
            nn.ReLU(),
            nn.BatchNorm2d(32),  # Added Batch Normalization
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25),  # Added Dropout
            BasicResBlock(32, 64),
            SEBlock(64, reduction=16),  # Squeeze-and-Excitation block
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25),  # Added Dropout
            DepthwiseSeparableConv(64, 128, kernel_size=3),  # Increased number of filters
            nn.ReLU(),
            BasicResBlock(128, 256),
            SEBlock(256, reduction=16),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25),  # Added Dropout
            SelfAttention(256),  # Added Self-Attention layer
        )
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)  # Global Average Pooling

    def forward(self, x):
        x = self.initial_layers(x)
        x = self.global_avg_pool(x)
        x = x.view(x.size(0), -1)  # Flatten the output for fully connected layers
        return x

class FasterKANvolver(nn.Module):
    def __init__(
        self,
        layers_hidden: List[int],
        grid_min: float = -1.2,
        grid_max: float = 0.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,
        train_inv_denominator: bool = False,
        # use_base_update: bool = True,
        base_activation = None,
        spline_weight_init_scale: float = 1.0,
    ) -> None:
        super(FasterKANvolver, self).__init__()

        # Feature extractor with convolutional layers
        self.feature_extractor = EnhancedFeatureExtractor()

        # Flattened feature size after the convolutional layers
        flat_features = 256  # 256 channels after global average pooling

        # Update layers_hidden with the correct input size from the conv layers
        layers_hidden = [flat_features] + layers_hidden

        # Define the FasterKAN layers
        self.faster_kan_layers = nn.ModuleList([
            FasterKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                exponent=exponent,
                inv_denominator=0.5,
                train_grid=False,
                train_inv_denominator=False,
                # use_base_update=use_base_update,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        # Reshape flat input to an image batch, e.g. [batch_size, 3, 32, 32] for CIFAR
        x = x.view(-1, 3, 32, 32)
        # Apply the convolutional feature extractor
        x = self.feature_extractor(x)
        x = x.view(x.size(0), -1)  # Flatten the output from the conv layers

        # Pass through FasterKAN layers
        for layer in self.faster_kan_layers:
            x = layer(x)

        return x
```
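The shape flow through a single `FasterKANLayer` is `[B, D]` → LayerNorm → RSWAF basis `[B, D, G]` → flatten to `[B, D·G]` → bias-free `SplineLinear` → `[B, out]`. A small check (import path assumes the repo root is on `PYTHONPATH`):

```python
import torch
from tasks.utils.kan.fasterkan_layers import FasterKANLayer

layer = FasterKANLayer(input_dim=1125, output_dim=32, num_grids=8)
x = torch.randn(4, 1125)
print(layer.rbf(x).shape)  # torch.Size([4, 1125, 8]): one basis value per grid point
print(layer(x).shape)      # torch.Size([4, 32]) via a 9000 -> 32 linear map
```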
tasks/utils/kan/feature_extractor.py
ADDED
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import *
from torch.autograd import Function


class BasicResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicResBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.downsample = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = self.downsample(x)

        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        out = F.relu(out)

        return out

class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(DepthwiseSeparableConv, self).__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
                                   stride=stride, padding=padding, groups=in_channels)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)
        return x

class SelfAttention(nn.Module):
    def __init__(self, in_channels):
        super(SelfAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, width, height = x.size()
        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
        energy = torch.bmm(proj_query, proj_key)
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, width, height)
        out = self.gamma * out + x
        return out

class EnhancedFeatureExtractor(nn.Module):
    def __init__(self,
                 colors = 3):
        super(EnhancedFeatureExtractor, self).__init__()
        self.initial_layers = nn.Sequential(
            nn.Conv2d(colors, 32, kernel_size=3, stride=1, padding=1),  # Increased number of filters
            nn.ReLU(),
            nn.BatchNorm2d(32),  # Added Batch Normalization
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25),  # Added Dropout
            BasicResBlock(32, 64),
            SEBlock(64, reduction=16),  # Squeeze-and-Excitation block
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25),  # Added Dropout
            DepthwiseSeparableConv(64, 128, kernel_size=3),  # Increased number of filters
            nn.ReLU(),
            BasicResBlock(128, 256),
            SEBlock(256, reduction=16),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25),  # Added Dropout
            SelfAttention(256),  # Added Self-Attention layer
        )
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)  # Global Average Pooling

    def forward(self, x):
        x = self.initial_layers(x)
        x = self.global_avg_pool(x)
        x = x.view(x.size(0), -1)  # Flatten the output for fully connected layers
        return x
```
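Regardless of input resolution, `EnhancedFeatureExtractor` ends with `AdaptiveAvgPool2d(1)` over 256 channels, so its output is always a `[B, 256]` vector, which is why `FasterKANvolver` hard-codes `flat_features = 256`. For example:

```python
import torch
from tasks.utils.kan.feature_extractor import EnhancedFeatureExtractor

fe = EnhancedFeatureExtractor(colors=3)
x = torch.randn(2, 3, 32, 32)  # a small RGB batch
print(fe(x).shape)             # torch.Size([2, 256]) after global average pooling
```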
tasks/utils/models.py
CHANGED
```diff
@@ -2,6 +2,8 @@ import torch
 import torch.nn as nn
 from .Modules.conformer import ConformerEncoder, ConformerDecoder
 from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding
+from .kan.fasterkan import FasterKAN
+from kan import KAN

 class ConvBlock(nn.Module):
     def __init__(self, args, num_layer) -> None:
@@ -111,4 +113,29 @@ class DualEncoder(nn.Module):
         x1 = self.encoder_x(x)
         x2, _ = self.encoder_f(x)
         logits = torch.cat([x1, x2], dim=-1)
-        return self.regressor(logits).squeeze()
+        return self.regressor(logits).squeeze()
+
+class CNNKan(nn.Module):
+    def __init__(self, args, conformer_args, kan_args):
+        super().__init__()
+        self.backbone = CNNEncoder(args)
+        # self.kan = KAN(width=kan_args['layers_hidden'])
+        self.kan = FasterKAN(**kan_args)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.backbone(x)
+        x = x.mean(dim=1)
+        return self.kan(x)
+
+class KanEncoder(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        self.kan_x = FasterKAN(**args)
+        self.kan_f = FasterKAN(**args)
+        self.kan_out = FasterKAN(layers_hidden=[args['layers_hidden'][-1]*2, 8,8,1])
+
+    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
+        x = self.kan_x(x)
+        f = self.kan_f(f)
+        out = torch.cat([x, f], dim=-1)
+        return self.kan_out(out)
+
```
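`CNNKan` chains the existing `CNNEncoder` backbone with a `FasterKAN` head, averaging over dim 1 in between. `CNNEncoder` itself is not shown in this diff, so the feature shape below is an assumption; for the KAN head to accept the pooled features with the config above, the backbone would need to emit a last dimension of 1125:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the CNNEncoder backbone (not visible in this diff):
# assume it maps a [B, 2, 18000] waveform+FFT batch to [B, C, 1125] features.
class FakeBackbone(nn.Module):
    def forward(self, x):
        return torch.randn(x.shape[0], 128, 1125)  # assumed output shape

x = torch.randn(4, 2, 18000)
feats = FakeBackbone()(x)   # [4, 128, 1125]
pooled = feats.mean(dim=1)  # [4, 1125], matching KAN layers_hidden[0] = 1125
```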
tasks/utils/train.py
CHANGED
```diff
@@ -74,8 +74,8 @@ class Trainer(object):
         lrs = []
         # self.optim_params['lr_history'] = []
         epochs_without_improvement = 0
-        main_proccess = (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0) or self.device == 'cpu'
-
+        # main_proccess = (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0) or self.device == 'cpu'
+        main_proccess = True  # change in a ddp setting
         print(f"Starting training for {num_epochs} epochs")
         print("is main process: ", main_proccess, flush=True)
         global_time = time.time()
@@ -221,7 +221,8 @@ class Trainer(object):
         x = x.to(device).float()
         fft = fft.to(device).float()
         y = y.to(device).float()
-
+        x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
+        y_pred = self.model(x_fft).squeeze()
         loss = self.criterion(y_pred, y)
         loss.backward()
         self.optimizer.step()
@@ -230,7 +231,7 @@ class Trainer(object):
         # get predicted classes
         probs = torch.sigmoid(y_pred)
         cls_pred = (probs > 0.5).float()
-        acc = (cls_pred == y).sum()
+        acc = (cls_pred == y).sum()
         return loss, acc, y

     def eval_epoch(self, device, epoch):
@@ -257,10 +258,11 @@ class Trainer(object):
         x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
         x = x.to(device).float()
         fft = fft.to(device).float()
+        x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
         y = y.to(device).float()
         with torch.no_grad():
-            y_pred = self.model(
-            loss = self.criterion(y_pred, y)
+            y_pred = self.model(x_fft).squeeze()
+            loss = self.criterion(y_pred.squeeze(), y)
         probs = torch.sigmoid(y_pred)
         cls_pred = (probs > 0.5).float()
         acc = (cls_pred == y).sum()
@@ -280,15 +282,16 @@ class Trainer(object):
         x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
         x = x.to(device).float()
         fft = fft.to(device).float()
+        x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
         y = y.to(device).float()
         with torch.no_grad():
-            y_pred = self.model(
+            y_pred = self.model(x_fft).squeeze()
             loss = self.criterion(y_pred, y)
             probs = torch.sigmoid(y_pred)
             cls_pred = (probs > 0.5).float()
             acc = (cls_pred == y).sum()
-            predictions.
-            true_labels.
+            predictions.extend(cls_pred.cpu().numpy())
+            true_labels.extend(y.cpu().numpy())
         all_accs += acc
         total += len(y)
         pbar.set_description("acc: {:.4f}".format(acc))
```
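The recurring `x_fft` construction stacks the raw waveform and its FFT (cast to float upstream, which keeps only the real part of the complex spectrum) as two channels, which is what `in_channels: 2` in the `CNNEncoder` config refers to:

```python
import torch

x = torch.randn(32, 18000)    # resampled waveform batch
fft = torch.randn(32, 18000)  # FFT coefficients, already cast to float
x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
print(x_fft.shape)            # torch.Size([32, 2, 18000]): a 2-channel input
```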