Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import deepinv as dinv | |
| import torch | |
| import numpy as np | |
| import PIL.Image | |
def pil_to_torch(image, ref_size=256):
    """Convert an 8-bit image into a (1, 3, H', W') float tensor in [0, 1].

    The image is resized with bilinear interpolation so that its longer side is
    ``ref_size`` pixels. The aspect ratio is preserved up to integer rounding,
    and the shorter side is at least 1 pixel.

    Args:
        image: a PIL image or array-like of shape (H, W) or (H, W, C) holding
            0..255 values (presumably uint8, as delivered by ``gr.Image``).
            Grayscale inputs (2-D, or 1-2 channels such as L/LA) are expanded
            to 3 channels by replicating the first channel. Channels beyond
            the first three (e.g. alpha) are dropped.
        ref_size: target length, in pixels, of the longer side.

    Returns:
        A float tensor of shape (1, 3, H', W').
    """
    array = np.asarray(image)
    if array.ndim == 2:
        array = array[:, :, np.newaxis]
    if array.shape[2] < 3:
        # Grayscale, optionally with alpha: replicate luminance so RGB denoisers accept it.
        array = np.repeat(array[:, :, :1], 3, axis=2)
    else:
        # Keep RGB only; 4-channel input would break the 3-channel denoisers.
        array = array[:, :, :3]

    tensor = torch.tensor(array.transpose((2, 0, 1))).float() / 255
    tensor = tensor.unsqueeze(0)

    height, width = tensor.shape[2], tensor.shape[3]
    if height > width:
        size = (ref_size, max(1, ref_size * width // height))
    else:
        # max(1, ...) guards against a zero-sized output for very elongated images.
        size = (max(1, ref_size * height // width), ref_size)
    return torch.nn.functional.interpolate(tensor, size=size, mode='bilinear')
def torch_to_pil(image):
    """Convert a channels-first tensor image back into an 8-bit PIL image.

    ``image`` is expected to be (1, C, H, W) or (C, H, W): a leading axis of
    size 1 is squeezed away. Values are clipped to [0, 1] and scaled to 0..255
    before conversion.
    """
    chw = image.squeeze(0).detach().cpu().numpy()
    hwc = np.transpose(chw, (1, 2, 0))  # PIL wants channels last
    as_bytes = (np.clip(hwc, 0, 1) * 255).astype(np.uint8)
    return PIL.Image.fromarray(as_bytes)
def _make_dncnn():
    """Wrap DnCNN so it can be called as ``f(x, sigma)`` at any noise level."""
    den = dinv.models.DnCNN()
    sigma0 = 2 / 255
    # Rescale the input so noise of std ``sigma`` becomes std ``sigma0``, denoise,
    # then undo the scaling. Presumably the pretrained weights target sigma0.
    return lambda x, sigma: den(x * sigma0 / sigma) * sigma / sigma0


# Zero-argument constructors for each denoiser offered in the UI. Lambdas keep
# construction lazy, so a model is only built when it is actually selected.
_DENOISER_FACTORIES = {
    'DnCNN': _make_dncnn,
    'MedianFilter': lambda: dinv.models.MedianFilter(kernel_size=5),
    'BM3D': lambda: dinv.models.BM3D(),
    'TV': lambda: dinv.models.TVDenoiser(),
    'TGV': lambda: dinv.models.TGVDenoiser(),
    'Wavelets': lambda: dinv.models.WaveletPrior(),
    'SwinIR': lambda: dinv.models.SwinIR(img_size=256),
    'DRUNet': lambda: dinv.models.DRUNet(),
}

# Denoisers built on first use and reused across requests, so model weights
# are loaded once per process instead of on every call.
_DENOISER_CACHE = {}


def _get_denoiser(name):
    """Return the cached ``f(x, sigma)`` denoiser for ``name``; raise ValueError if unknown."""
    if name not in _DENOISER_CACHE:
        try:
            factory = _DENOISER_FACTORIES[name]
        except KeyError:
            raise ValueError("Invalid denoiser") from None
        _DENOISER_CACHE[name] = factory()
    return _DENOISER_CACHE[name]


def image_mod(image, noise_level, denoiser):
    """Add Gaussian noise to ``image`` and denoise it with the selected method.

    Args:
        image: input image as delivered by the Gradio ``Image`` component.
            It is resized by ``pil_to_torch`` and scaled to [0, 1].
        noise_level: standard deviation of the additive Gaussian noise, in
            [0, 1] intensity units. Coerced with ``float`` because Dropdown
            values may arrive as strings.
        denoiser: name of the denoiser, one of the keys of ``_DENOISER_FACTORIES``.

    Returns:
        A ``(noisy, estimated)`` pair of PIL images.

    Raises:
        ValueError: if ``denoiser`` is not a known name.
    """
    noise_level = float(noise_level)
    model = _get_denoiser(denoiser)
    tensor = pil_to_torch(image)
    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        noisy = tensor + torch.randn_like(tensor) * noise_level
        estimated = model(noisy, noise_level)
    return torch_to_pil(noisy), torch_to_pil(estimated)
# Gradio UI: the input image, noise level and denoiser choice feed ``image_mod``,
# whose two outputs are shown as the noisy and the denoised image.
input_image = gr.Image(label='Input Image')
output_images = gr.Image(label='Denoised Image')
noise_image = gr.Image(label='Noisy Image')
noise_levels = gr.Dropdown(choices=[0.1, 0.2, 0.3, 0.5, 1], value=0.1, label='Noise Level')
denoiser = gr.Dropdown(
    choices=['DnCNN', 'DRUNet', 'SwinIR', 'BM3D', 'MedianFilter', 'TV', 'TGV', 'Wavelets'],
    value='DnCNN',
    label='Denoiser',
)
demo = gr.Interface(
    image_mod,
    inputs=[input_image, noise_levels, denoiser],
    examples=[['https://deepinv.github.io/deepinv/_static/deepinv_logolarge.png', 0.1, 'DnCNN']],
    outputs=[noise_image, output_images],
    title="Image Denoising with DeepInverse",
    description=(
        "Denoise an image using a variety of denoisers and noise levels using the "
        "deepinverse library (https://deepinv.github.io/). This example is intended "
        "to be run on a CPU, so heavier models such as DRUNet, SwinIR and BM3D can be "
        "noticeably slower than DnCNN or MedianFilter. We also automatically resize "
        "the input image so that its longer side is 256 pixels to reduce the "
        "computation time. For more advanced models, please run the code locally."
    ),
)

# Launch only when run as a script, so importing this module does not start a
# blocking server.
if __name__ == "__main__":
    demo.launch()