import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List
from torch.autograd import Function

from .feature_extractor import EnhancedFeatureExtractor
from .fasterkan_layers import FasterKANLayer


class FasterKAN(nn.Module):
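    """Sequential stack of FasterKANLayer modules.

    Each consecutive pair of sizes in ``layers_hidden`` defines one FasterKANLayer.
    """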
    def __init__(
        self,
        layers_hidden: List[int],
        grid_min: float = -1.2,
        grid_max: float = 1.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,
        train_inv_denominator: bool = False,
        # use_base_update: bool = True,
        base_activation=None,
        spline_weight_init_scale: float = 1.0,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            FasterKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                exponent=exponent,
                inv_denominator=inv_denominator,
                train_grid=train_grid,
                train_inv_denominator=train_inv_denominator,
                # use_base_update=use_base_update,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])
| #print(f"FasterKAN layers_hidden[1:] shape: ", len(layers_hidden[1:])) | |
| #print(f"FasterKAN layers_hidden[:-1] shape: ", len(layers_hidden[:-1])) | |
| #print("FasterKAN zip shape: \n", *[(in_dim, out_dim) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])]) | |
| #print(f"FasterKAN self.faster_kan_layers shape: \n", len(self.layers)) | |
| #print(f"FasterKAN self.faster_kan_layers: \n", self.layers) | |
    def forward(self, x):
        # Apply each FasterKAN layer in sequence.
        for layer in self.layers:
            x = layer(x)
        return x
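
# Example usage (illustrative sketch; the layer sizes are example values and
# assume each FasterKANLayer maps in_dim features to out_dim features):
#
#     model = FasterKAN(layers_hidden=[784, 64, 10])
#     out = model(torch.randn(32, 784))   # -> torch.Size([32, 10])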


class FasterKANvolver(nn.Module):
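    """EnhancedFeatureExtractor convolutional front end followed by a stack of
    FasterKANLayer modules operating on the flattened convolutional features.
    """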
    def __init__(
        self,
        layers_hidden: List[int],
        grid_min: float = -1.2,
        grid_max: float = 0.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,
        train_inv_denominator: bool = False,
        # use_base_update: bool = True,
        base_activation=None,
        spline_weight_init_scale: float = 1.0,
        view=[-1, 1, 28, 28],
    ) -> None:
        super(FasterKANvolver, self).__init__()
        self.view = view

        # Feature extractor with convolutional layers; view[1] is the number of input channels.
        self.feature_extractor = EnhancedFeatureExtractor(colors=view[1])
        # Earlier, simpler extractor kept for reference:
        # nn.Sequential(
        #     nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),  # 1 input channel (grayscale), 16 output channels
        #     nn.ReLU(),
        #     nn.MaxPool2d(2, 2),
        #     nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
        #     nn.ReLU(),
        #     nn.MaxPool2d(2, 2),
        # )

        # Flattened feature size produced by the convolutional feature extractor.
        flat_features = 256

        # Prepend the flattened feature size so the first FasterKAN layer matches the conv output.
        layers_hidden = [flat_features] + layers_hidden
| #print(f"FasterKANvolver layers_hidden shape: \n", layers_hidden) | |
| #print(f"FasterKANvolver layers_hidden[1:] shape: ", len(layers_hidden[1:])) | |
| #print(f"FasterKANvolver layers_hidden[:-1] shape: ", len(layers_hidden[:-1])) | |
| #print("FasterKANvolver zip shape: \n", *[(in_dim, out_dim) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])]) | |
        # Define the FasterKAN layers.
        self.faster_kan_layers = nn.ModuleList([
            FasterKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                exponent=exponent,
                inv_denominator=inv_denominator,
                train_grid=train_grid,
                train_inv_denominator=train_inv_denominator,
                # use_base_update=use_base_update,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])
| #print(f"FasterKANvolver self.faster_kan_layers shape: \n", len(self.faster_kan_layers)) | |
| #print(f"FasterKANvolver self.faster_kan_layers: \n", self.faster_kan_layers) | |
    def forward(self, x):
        # Reshape the flat input to an image batch according to self.view,
        # e.g. [batch_size, 1, 28, 28] for MNIST.
        x = x.view(self.view[0], self.view[1], self.view[2], self.view[3])

        # Apply the convolutional feature extractor and flatten its output.
        x = self.feature_extractor(x)
        x = x.view(x.size(0), -1)

        # Pass through the FasterKAN layers.
        for layer in self.faster_kan_layers:
            x = layer(x)
        return x
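
# Illustrative smoke test (a minimal sketch, not part of the original module).
# It assumes EnhancedFeatureExtractor flattens a 1x28x28 input to the 256
# features expected above, and that FasterKANLayer(in_dim, out_dim) maps
# in_dim features to out_dim features; the layer sizes are example values.
if __name__ == "__main__":
    model = FasterKANvolver(layers_hidden=[64, 10], view=[-1, 1, 28, 28])
    dummy = torch.randn(8, 1 * 28 * 28)   # batch of 8 flattened 28x28 images
    logits = model(dummy)
    print(logits.shape)                    # expected: torch.Size([8, 10])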