{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "\"Open"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "1JryaVhtBHij",
    "outputId": "4fac7fb6-787c-4a1b-f6ef-12ec48024619"
   },
   "outputs": [],
   "source": [
    "# Clone the MIRNet branch of enhance-me only if it is not already present, so this cell can be re-run.\n",
    "!test -d enhance-me || git clone https://github.com/soumik12345/enhance-me -b mirnet\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install -q wandb streamlit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "G_c4VtXWHR5l"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "# Make the cloned repository importable; the guard keeps re-runs from appending duplicates.\n",
    "if \"./enhance-me\" not in sys.path:\n",
    "    sys.path.append(\"./enhance-me\")\n",
    "\n",
    "from PIL import Image\n",
    "from enhance_me import commons\n",
    "from enhance_me.mirnet import MIRNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZpBHbYaMIqP_"
   },
   "outputs": [],
   "source": [
    "#@title MIRNet Train Configs\n",
    "\n",
    "experiment_name = 'lol_dataset_128' #@param {type:\"string\"}\n",
    "image_size = 128 #@param {type:\"integer\"}\n",
    "dataset_label = 'lol' #@param [\"lol\"]\n",
    "apply_random_horizontal_flip = True #@param {type:\"boolean\"}\n",
    "apply_random_vertical_flip = True #@param {type:\"boolean\"}\n",
    "apply_random_rotation = True #@param {type:\"boolean\"}\n",
    "wandb_api_key = '' #@param {type:\"string\"}\n",
    "val_split = 0.1 #@param {type:\"slider\", min:0.1, max:0.5, step:0.1}\n",
    "batch_size = 4 #@param {type:\"integer\"}\n",
    "num_recursive_residual_groups = 3 #@param {type:\"slider\", min:1, max:5, step:1}\n",
    "num_multi_scale_residual_blocks = 2 #@param {type:\"slider\", min:1, max:5, step:1}\n",
    "learning_rate = 1e-4 #@param {type:\"number\"}\n",
    "epsilon = 1e-3 #@param {type:\"number\"}\n",
    "epochs = 50 #@param {type:\"slider\", min:10, max:100, step:5}\n",
    "\n",
    "# val_split is capped below 1.0 so some data is always left for training.\n",
    "# If the form field is left blank, fall back to the WANDB_API_KEY environment variable\n",
    "# so the key does not have to be typed into (and saved with) the notebook.\n",
    "wandb_api_key = wandb_api_key or os.environ.get(\"WANDB_API_KEY\", \"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 124
    },
    "id": "IVRoedqBIMuH",
    "outputId": "388a806f-f41f-420c-9c03-01024decb2d3"
   },
   "outputs": [],
"source": [
    "# Create the MIRNet experiment from the values in the config cell.\n",
    "# An empty W&B key is passed as None.\n",
    "mirnet = MIRNet(\n",
    "    experiment_name=experiment_name,\n",
    "    image_size=image_size,\n",
    "    dataset_label=dataset_label,\n",
    "    val_split=val_split,\n",
    "    batch_size=batch_size,\n",
    "    apply_random_horizontal_flip=apply_random_horizontal_flip,\n",
    "    apply_random_vertical_flip=apply_random_vertical_flip,\n",
    "    apply_random_rotation=apply_random_rotation,\n",
    "    wandb_api_key=None if wandb_api_key == '' else wandb_api_key\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tsfKrBCsL_Bb"
   },
   "outputs": [],
   "source": [
    "# Build the network with the architecture and training hyperparameters from the config cell.\n",
    "mirnet.build_model(\n",
    "    num_recursive_residual_groups=num_recursive_residual_groups,\n",
    "    num_multi_scale_residual_blocks=num_multi_scale_residual_blocks,\n",
    "    learning_rate=learning_rate,\n",
    "    epsilon=epsilon\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "y3L9wlpkNziL",
    "outputId": "65e7ba4d-1607-4c14-d5d7-e55c4641ad0a"
   },
   "outputs": [],
   "source": [
    "# Train for the configured number of epochs; `history` keeps whatever MIRNet.train returns.\n",
    "history = mirnet.train(epochs=epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true,
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "daFKbgBkiyzc",
    "outputId": "38c3fc7a-8cef-4332-8efe-35103c75f1a3"
   },
   "outputs": [],
   "source": [
    "# For every low-light test image, show the input, its reference image and the model's output side by side.\n",
    "# NOTE: mirnet.test_enhanced_images holds the reference (ground-truth) files, not model outputs.\n",
    "for index, low_image_file in enumerate(mirnet.test_low_images):\n",
    "    original_image = Image.open(low_image_file)\n",
    "    enhanced_image = mirnet.infer(original_image)\n",
    "    ground_truth = Image.open(mirnet.test_enhanced_images[index])\n",
    "    commons.plot_results(\n",
    "        [original_image, ground_truth, enhanced_image],\n",
    "        [\"Original Image\", \"Ground Truth\", \"Enhanced Image\"],\n",
    "        (18, 18)\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyMwNbyaCs348ucM56hcLJop",
   "collapsed_sections": [],
   "include_colab_link": true,
   "machine_shape": "hm",
   "name": "enhance-me-train.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}