title: Erav2s13
emoji: 🔥
colorFrom: yellow
colorTo: red
sdk: gradio
sdk_version: 4.27.0
app_file: app.py
pinned: false
license: mit
I used a Hugging Face repository to host the code and Gradio to build the UI.
Main.ipynb is the notebook that was run on Google Colab to build and train the model. The Gradio app uses the same structure and model files for inference.
Custom ResNet Model
The custom_resnet.py file defines a custom ResNet (Residual Network) model using PyTorch Lightning. The model is designed for image classification tasks, specifically for the CIFAR-10 dataset.
Model Architecture
The custom ResNet model consists of the following components:
Preparation Layer: The input data goes through a convolutional layer with 64 filters, followed by batch normalization, ReLU activation, and dropout.
Layer 1: The output from the preparation layer passes through another convolutional layer with 128 filters, max pooling, batch normalization, ReLU activation, and dropout. Additionally, a residual block is applied, which consists of two convolutional layers with 128 filters each, batch normalization, ReLU activation, and dropout. The output of the residual block is added to the output of the previous layer.
Layer 2: The output from Layer 1 goes through a convolutional layer with 256 filters, max pooling, batch normalization, ReLU activation, and dropout.
Layer 3: Similar to Layer 1, the output from Layer 2 passes through a convolutional layer with 512 filters, max pooling, batch normalization, ReLU activation, and dropout. Another residual block is applied, consisting of two convolutional layers with 512 filters each, batch normalization, ReLU activation, and dropout. The output of the residual block is added to the output of the previous layer.
Max Pooling: The output from Layer 3 goes through a max pooling layer with a kernel size of 4.
Fully Connected Layer: The output from the max pooling layer is flattened and passed through a fully connected layer with 10 output units, corresponding to the 10 classes in the CIFAR-10 dataset.
Log Softmax: Finally, the output from the fully connected layer goes through a log softmax function, which yields the log-probabilities of the 10 classes.
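The architecture above can be sketched in plain PyTorch as follows. This is a minimal sketch, not the code in custom_resnet.py: it assumes 3x3 convolutions with padding 1, 2x2 max pooling in Layers 1 to 3, a hypothetical dropout rate of 0.05, and 32x32 CIFAR-10 inputs, and it leaves out the PyTorch Lightning wrapper. The names CustomResNetSketch, conv_bn_relu, and ResBlock are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

DROPOUT = 0.05  # hypothetical dropout rate; the real value is not stated in this README


def conv_bn_relu(in_ch: int, out_ch: int, pool: bool = False) -> nn.Sequential:
    """Conv (3x3 assumed), optional 2x2 max pool, then BatchNorm, ReLU, dropout."""
    layers = [nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)]
    if pool:
        layers.append(nn.MaxPool2d(2))
    layers += [nn.BatchNorm2d(out_ch), nn.ReLU(), nn.Dropout(DROPOUT)]
    return nn.Sequential(*layers)


class ResBlock(nn.Module):
    """Two conv-BN-ReLU-dropout units whose output is added back to the block input."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.body = nn.Sequential(conv_bn_relu(channels, channels), conv_bn_relu(channels, channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.body(x)  # skip connection


class CustomResNetSketch(nn.Module):
    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        self.prep = conv_bn_relu(3, 64)                  # 64 x 32 x 32
        self.layer1 = conv_bn_relu(64, 128, pool=True)   # 128 x 16 x 16
        self.res1 = ResBlock(128)
        self.layer2 = conv_bn_relu(128, 256, pool=True)  # 256 x 8 x 8
        self.layer3 = conv_bn_relu(256, 512, pool=True)  # 512 x 4 x 4
        self.res3 = ResBlock(512)
        self.pool = nn.MaxPool2d(4)                      # 512 x 1 x 1
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.res1(self.layer1(self.prep(x)))
        x = self.res3(self.layer3(self.layer2(x)))
        x = self.pool(x).flatten(1)
        return F.log_softmax(self.fc(x), dim=1)  # log-probabilities per class


if __name__ == "__main__":
    model = CustomResNetSketch().eval()
    with torch.no_grad():
        out = model(torch.randn(2, 3, 32, 32))
    print(out.shape)  # torch.Size([2, 10])
```

With 32x32 inputs, the three 2x2 poolings shrink the feature map to 4x4, so the final max pooling with kernel size 4 leaves a 512-dimensional vector for the fully connected layer.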
Training and Evaluation
The model is trained using the PyTorch Lightning framework, which provides a high-level interface for training, validation, and testing. The model uses the Adam optimizer with a learning rate determined by the PREFERRED_START_LR configuration variable. The learning rate is adjusted using the OneCycleLR scheduler.
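A sketch of the optimizer and scheduler setup, assuming PREFERRED_START_LR is passed to OneCycleLR as its peak learning rate (max_lr). The README does not say exactly how the value is used, and the numbers below are placeholders, not the project's configuration. The helper names build_optimizer_and_scheduler and lr_trace are hypothetical.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import OneCycleLR

# Placeholder stand-ins for config.PREFERRED_START_LR / config.PREFERRED_WEIGHT_DECAY.
PREFERRED_START_LR = 1e-3
PREFERRED_WEIGHT_DECAY = 1e-4


def build_optimizer_and_scheduler(model: nn.Module, epochs: int, steps_per_epoch: int):
    optimizer = torch.optim.Adam(
        model.parameters(), lr=PREFERRED_START_LR, weight_decay=PREFERRED_WEIGHT_DECAY
    )
    # Assumption: PREFERRED_START_LR is the peak of the one-cycle schedule.
    scheduler = OneCycleLR(
        optimizer, max_lr=PREFERRED_START_LR, epochs=epochs, steps_per_epoch=steps_per_epoch
    )
    return optimizer, scheduler


def lr_trace(epochs: int = 2, steps_per_epoch: int = 50) -> list:
    """Record the learning rate at each batch step (no real training happens)."""
    model = nn.Linear(4, 2)
    optimizer, scheduler = build_optimizer_and_scheduler(model, epochs, steps_per_epoch)
    lrs = [optimizer.param_groups[0]["lr"]]
    for _ in range(epochs * steps_per_epoch - 1):
        optimizer.step()  # no gradients exist, so parameters are untouched
        scheduler.step()
        lrs.append(optimizer.param_groups[0]["lr"])
    return lrs
```

With OneCycleLR's defaults the schedule starts at max_lr / 25, warms up over the first 30% of steps, then anneals far below the starting value. In a LightningModule, this pair would typically be returned from configure_optimizers with the scheduler's "interval" set to "step", since OneCycleLR is meant to be stepped once per batch.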
The model computes the cross-entropy loss and accuracy during training, validation, and testing steps. The loss and accuracy values are logged and stored in the results dictionary for further analysis.
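Because the network ends in a log softmax, the cross-entropy loss can be computed as the negative log-likelihood of the log-probabilities. The sketch below shows one way to compute loss and accuracy for a batch and append them to a results dictionary; the function name evaluate_batch and the dictionary keys are hypothetical, since the README does not document the layout of results.

```python
import torch
import torch.nn.functional as F

# Hypothetical layout; the real results dictionary keys are not documented here.
results = {"loss": [], "accuracy": []}


def evaluate_batch(log_probs: torch.Tensor, targets: torch.Tensor, store: dict) -> tuple:
    """Cross-entropy on log-softmax outputs (via NLL loss) plus batch accuracy."""
    loss = F.nll_loss(log_probs, targets)  # equals cross-entropy when inputs are log-probabilities
    preds = log_probs.argmax(dim=1)
    acc = (preds == targets).float().mean()
    store["loss"].append(loss.item())
    store["accuracy"].append(acc.item())
    return loss, acc
```

Inside a LightningModule's training_step, validation_step, or test_step, the same values would usually also be passed to self.log so they appear in the progress bar and logger.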
Misclassified Images
During the testing phase, the model keeps track of misclassified images. It stores the misclassified images, their ground-truth labels, and the predicted labels in the misclassified_image_data dictionary. This information can be used for error analysis and model improvement.
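A minimal sketch of how misclassified samples could be collected during a test step. The function name collect_misclassified and the keys of misclassified_image_data used here are assumptions for illustration; the actual dictionary layout in custom_resnet.py may differ.

```python
import torch

# Hypothetical key names; the project's actual layout is not documented here.
misclassified_image_data = {"images": [], "ground_truths": [], "predicted_vals": []}


def collect_misclassified(images: torch.Tensor, targets: torch.Tensor,
                          log_probs: torch.Tensor, store: dict) -> int:
    """Append each sample whose predicted class differs from its label; return the count."""
    preds = log_probs.argmax(dim=1)
    wrong = preds != targets  # boolean mask over the batch
    for img, gt, pred in zip(images[wrong], targets[wrong], preds[wrong]):
        store["images"].append(img.detach().cpu())
        store["ground_truths"].append(int(gt))
        store["predicted_vals"].append(int(pred))
    return int(wrong.sum())
```

Moving images to the CPU before storing them keeps GPU memory from growing over the whole test set.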
Hyperparameters
The model uses hyperparameters defined in the config module, such as PREFERRED_START_LR for the initial learning rate and PREFERRED_WEIGHT_DECAY for weight decay regularization.
Model Summary
The detailed_model_summary function prints a detailed summary of the model architecture, including the input size, kernel size, output size, number of parameters, and trainable status of each layer.
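The README does not show how detailed_model_summary is implemented; its columns correspond to options of the torchinfo package's summary function, which the project may be using. As a dependency-free alternative, the sketch below gathers the same columns with PyTorch forward hooks on every leaf module. The function name detailed_summary_sketch is hypothetical, and it puts the model in eval mode as a side effect.

```python
import torch
import torch.nn as nn


def detailed_summary_sketch(model: nn.Module, input_size=(1, 3, 32, 32)) -> list:
    """Collect per-leaf-layer rows: input/output shape, kernel size, params, trainable."""
    rows, hooks = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            params = list(module.parameters(recurse=False))
            rows.append({
                "layer": name,
                "input_size": tuple(inputs[0].shape),
                "kernel_size": getattr(module, "kernel_size", None),
                "output_size": tuple(output.shape),
                "num_params": sum(p.numel() for p in params),
                "trainable": any(p.requires_grad for p in params),
            })
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # register on leaf layers only
            hooks.append(module.register_forward_hook(make_hook(name)))
    model.eval()
    with torch.no_grad():
        model(torch.zeros(*input_size))  # one dummy pass fires every hook in order
    for h in hooks:
        h.remove()
    for r in rows:
        print(f"{r['layer']:<12} {str(r['input_size']):<20} {str(r['kernel_size']):<8} "
              f"{str(r['output_size']):<20} {r['num_params']:>9} {r['trainable']}")
    return rows
```

Layers without parameters (ReLU, pooling, dropout) report zero parameters and a trainable status of False.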
This custom ResNet model serves as a starting point for image classification tasks and can be further modified and enhanced based on specific requirements and datasets.
Lightning Dataset Module
The lightning_dataset.py file contains the CIFARDataModule class, which is a PyTorch Lightning LightningDataModule for the CIFAR-10 dataset. This class encapsulates the data preparation, splitting, and loading logic for the CIFAR-10 dataset.
Class: CIFARDataModule
The CIFARDataModule class is responsible for preparing and providing data loaders for the CIFAR-10 dataset. It inherits from pl.LightningDataModule.
Parameters
data_path: Path to the directory where the CIFAR-10 dataset will be downloaded or loaded from.
batch_size: Batch size for the data loaders.
seed: Random seed for reproducibility.
val_split: Fraction of the training data to be used for validation (default: 0).
num_workers: Number of worker processes for data loading (default: 0).
Methods
prepare_data: Downloads the CIFAR-10 dataset if it doesn't already exist.
setup: Defines the data transformations and creates the training, validation, and testing datasets for the specified stage.
train_dataloader: Returns the data loader for the training dataset.
val_dataloader: Returns the data loader for the validation dataset.
test_dataloader: Returns the data loader for the testing dataset.
The CIFARDataModule class also includes utility methods:
_split_train_val: Splits the training dataset into training and validation subsets based on the specified validation split ratio.
_init_fn: Initializes the random seed for each worker process to ensure reproducibility.
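The two utility methods can be sketched as standalone functions. This is an illustration under assumptions, not the code in lightning_dataset.py: split_train_val and init_fn are hypothetical names, the split uses torch.utils.data.random_split with a seeded generator, and returning (dataset, None) when val_split is 0 is an assumed convention. The worker seeding follows the pattern from PyTorch's reproducibility notes.

```python
import random

import numpy as np
import torch
from torch.utils.data import Dataset, TensorDataset, random_split


def split_train_val(dataset: Dataset, val_split: float, seed: int):
    """Seeded train/validation split; returns (dataset, None) when val_split is 0."""
    if val_split <= 0:
        return dataset, None
    n_val = int(len(dataset) * val_split)
    n_train = len(dataset) - n_val
    generator = torch.Generator().manual_seed(seed)  # same seed -> same split
    train, val = random_split(dataset, [n_train, n_val], generator=generator)
    return train, val


def init_fn(worker_id: int) -> None:
    """Seed NumPy and Python's random module in each DataLoader worker."""
    worker_seed = torch.initial_seed() % 2**32  # torch already gives each worker a distinct seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
```

A data loader would then be built along the lines of DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=init_fn, generator=torch.Generator().manual_seed(seed)), so that both the shuffling order and any NumPy-based augmentations are reproducible.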