Spaces:
Running
Running
| """ | |
| Modified from https://github.com/philz1337x/clarity-upscaler | |
| which is a copy of https://github.com/AUTOMATIC1111/stable-diffusion-webui | |
| which is a copy of https://github.com/victorca25/iNNfer | |
| which is a copy of https://github.com/xinntao/ESRGAN | |
| """ | |
| import math | |
| from pathlib import Path | |
| from typing import NamedTuple | |
| import numpy as np | |
| import numpy.typing as npt | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
    """Build a 3x3 convolution followed by a LeakyReLU activation.

    Args:
        in_nc: Number of input channels of the convolution.
        out_nc: Number of output channels of the convolution.

    Returns:
        A two-element ``nn.Sequential`` holding the convolution at index 0
        and the activation at index 1.
    """
    convolution = nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1)
    activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(convolution, activation)
class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block with five convolutions.

    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)

    The first four convolutions each receive the block input concatenated
    (along the channel axis) with the outputs of all earlier convolutions.
    The fifth maps the full concatenation back to ``nf`` channels; its output
    is scaled by 0.2 and added to the block input.
    """

    def __init__(self, nf: int = 64, gc: int = 32) -> None:
        """
        Args:
            nf: Channel count of the block input and output.
            gc: Channel count produced by each of the first four convolutions.
        """
        super().__init__()  # type: ignore[reportUnknownMemberType]
        # Stage k consumes the input plus the (k - 1) earlier gc-channel outputs.
        self.conv1 = conv_block(nf + 0 * gc, gc)
        self.conv2 = conv_block(nf + 1 * gc, gc)
        self.conv3 = conv_block(nf + 2 * gc, gc)
        self.conv4 = conv_block(nf + 3 * gc, gc)
        # Wrapped in Sequential because of key in state dict.
        fusion = nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1)
        self.conv5 = nn.Sequential(fusion)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the densely connected convolutions and the scaled residual."""
        features = [x, self.conv1(x)]
        for stage in (self.conv2, self.conv3, self.conv4):
            features.append(stage(torch.cat(features, 1)))
        fused = self.conv5(torch.cat(features, 1))
        return fused * 0.2 + x
class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)

    Chains three residual dense blocks, then adds the chain's output,
    scaled by 0.2, back onto the block input.
    """

    def __init__(self, nf: int) -> None:
        """
        Args:
            nf: Channel count passed to each residual dense block.
        """
        super().__init__()  # type: ignore[reportUnknownMemberType]
        # Attribute names are part of the checkpoint key layout; keep them.
        self.RDB1 = ResidualDenseBlock_5C(nf)
        self.RDB2 = ResidualDenseBlock_5C(nf)
        self.RDB3 = ResidualDenseBlock_5C(nf)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the three dense blocks in sequence and apply the outer residual."""
        hidden = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            hidden = dense_block(hidden)
        return hidden * 0.2 + x
class Upsample2x(nn.Module):
    """Parameter-free module that doubles the spatial size of its input."""

    def __init__(self) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` interpolated by a factor of 2 (default interpolation mode)."""
        resize = nn.functional.interpolate
        return resize(x, scale_factor=2.0)  # type: ignore
class ShortcutBlock(nn.Module):
    """Skip connection: returns the input plus the wrapped submodule's output."""

    def __init__(self, submodule: nn.Module) -> None:
        """
        Args:
            submodule: Module whose output is added element-wise to its input.
        """
        super().__init__()  # type: ignore[reportUnknownMemberType]
        # The attribute name "sub" appears in checkpoint keys; do not rename.
        self.sub = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute ``x + sub(x)``."""
        branch = self.sub(x)
        return x + branch
class RRDBNet(nn.Module):
    """ESRGAN generator: RRDB trunk followed by two 2x upsampling stages.

    The layers live in ``self.model`` (an ``nn.Sequential``). Their positions
    determine the checkpoint key names, so the order below must not change.
    """

    def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
        """
        Args:
            in_nc: Number of input image channels; must not be a multiple of 4.
            out_nc: Number of output image channels.
            nf: Number of feature channels used throughout the network.
            nb: Number of RRDB blocks in the trunk.
        """
        super().__init__()  # type: ignore[reportUnknownMemberType]
        # NOTE(review): presumably rejects checkpoint variants with 4k input
        # channels that this implementation does not support - confirm upstream.
        assert in_nc % 4 != 0  # in_nc is 3

        def conv3x3(channels_in: int, channels_out: int) -> nn.Conv2d:
            return nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1)

        def leaky_relu() -> nn.LeakyReLU:
            return nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Modules are created in the same order as they appear in the
        # Sequential so that parameter initialisation order is unchanged.
        head = conv3x3(in_nc, nf)
        trunk = nn.Sequential(*[RRDB(nf) for _ in range(nb)], conv3x3(nf, nf))
        tail = [
            Upsample2x(),
            conv3x3(nf, nf),
            leaky_relu(),
            Upsample2x(),
            conv3x3(nf, nf),
            leaky_relu(),
            conv3x3(nf, nf),
            leaky_relu(),
            conv3x3(nf, out_nc),
        ]
        self.model = nn.Sequential(head, ShortcutBlock(trunk), *tail)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full network on ``x``."""
        return self.model(x)
def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
    """Derive the network hyper-parameters from the key layout of a checkpoint.

    Keys are split on ``.`` and inspected as follows:

    - five-part keys whose third part is ``sub`` give ``nb`` (fourth part;
      the last such key wins);
    - three-part keys ``model.<i>.weight`` with ``i > 6`` each count as one
      2x upscaling step;
    - the three-part key with the highest index gives ``out_nc`` (first
      dimension of its tensor);
    - ``model.0.weight`` gives ``nf`` (dim 0) and ``in_nc`` (dim 1).

    Args:
        state_dict: Checkpoint mapping parameter names to tensors.

    Returns:
        ``(in_nc, out_nc, nf, nb, scale)``.

    Raises:
        AssertionError: If a key contains ``conv1x1``, or if ``out_nc`` or
            ``nb`` could not be determined.
        KeyError: If ``model.0.weight`` is missing.
        ValueError: If an inspected index part is not an integer.
    """
    # this code is adapted from https://github.com/victorca25/iNNfer
    min_scale_index = 6
    doublings = 0
    highest_index = 0
    out_nc = 0
    nb = 0

    for key, tensor in state_dict.items():
        segments = key.split(".")
        depth = len(segments)
        if depth == 5 and segments[2] == "sub":
            nb = int(segments[3])
        elif depth == 3:
            index = int(segments[1])
            is_model_weight = segments[0] == "model" and segments[2] == "weight"
            if is_model_weight and index > min_scale_index:
                doublings += 1
            if index > highest_index:
                highest_index = index
                out_nc = tensor.shape[0]
        assert "conv1x1" not in key  # no ESRGANPlus

    first_weight = state_dict["model.0.weight"]
    nf = first_weight.shape[0]
    in_nc = first_weight.shape[1]

    assert out_nc > 0
    assert nb > 0

    return in_nc, out_nc, nf, nb, 2**doublings  # 3, 3, 64, 23, 4
# One tile as stored by split_grid: (x offset in the image, nominal tile
# width, cropped tile image).
Tile = tuple[int, int, Image.Image]
# All rows of a grid; each row is (y offset in the image, nominal tile
# height, tiles of that row from left to right).
Tiles = list[tuple[int, int, list[Tile]]]
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
class Grid(NamedTuple):
    """An image cut into rows of overlapping tiles, plus the tiling geometry."""

    # Rows of tiles: (y, tile height, [(x, tile width, tile image), ...]).
    tiles: Tiles
    # Nominal tile size in pixels.
    tile_w: int
    tile_h: int
    # Size in pixels of the full image that the tiles cover.
    image_w: int
    image_h: int
    # Number of pixels shared by neighbouring tiles.
    overlap: int
# adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
    """Cut ``image`` into a grid of overlapping tiles.

    Args:
        image: Image to split.
        tile_w: Nominal tile width in pixels.
        tile_h: Nominal tile height in pixels.
        overlap: Minimum number of pixels shared by neighbouring tiles.

    Returns:
        A ``Grid`` whose rows hold ``(x, tile_w, tile)`` entries. Tiles are
        cropped to the image bounds, so they can be smaller than the nominal
        size when the image is smaller than one tile.
    """
    width, height = image.width, image.height

    # Tiles needed per axis when consecutive tiles share `overlap` pixels.
    n_cols = max(1, math.ceil((width - overlap) / (tile_w - overlap)))
    n_rows = max(1, math.ceil((height - overlap) / (tile_h - overlap)))

    # Distance between tile origins; tiles are spread evenly over the image.
    step_x = (width - tile_w) / (n_cols - 1) if n_cols > 1 else 0
    step_y = (height - tile_h) / (n_rows - 1) if n_rows > 1 else 0

    # Clamp origins so that no tile starts before 0 or runs past the far edge.
    lefts = [max(min(int(col * step_x), width - tile_w), 0) for col in range(n_cols)]
    tops = [max(min(int(row * step_y), height - tile_h), 0) for row in range(n_rows)]

    rows: Tiles = []
    for top in tops:
        bottom = min(top + tile_h, height)
        row_tiles: list[Tile] = [
            (left, tile_w, image.crop((left, top, min(left + tile_w, width), bottom)))
            for left in lefts
        ]
        rows.append((top, tile_h, row_tiles))

    return Grid(rows, tile_w, tile_h, width, height, overlap)
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
def combine_grid(grid: Grid):
    """Reassemble the tiles of ``grid`` into one RGB image.

    Tiles are first joined left to right into full-width rows, then the rows
    are stacked top to bottom. Wherever a tile or row overlaps the previous
    one, the first ``grid.overlap`` pixels are pasted through a linear ramp
    mask so the new content fades in over what is already there.

    Args:
        grid: Tiling to reassemble.

    Returns:
        A new ``grid.image_w`` x ``grid.image_h`` RGB image.
    """
    overlap = grid.overlap

    def ramp_mask(values: npt.NDArray[np.float32]) -> Image.Image:
        # Map 0..overlap-1 onto the 8-bit range used by paste masks.
        scaled = values * 255 / overlap
        return Image.fromarray(scaled.astype(np.uint8), "L")

    ramp = np.arange(overlap, dtype=np.float32)
    # Horizontal ramp for joining tiles, vertical ramp for joining rows.
    mask_w = ramp_mask(ramp.reshape((1, overlap)).repeat(grid.tile_h, axis=0))
    mask_h = ramp_mask(ramp.reshape((overlap, 1)).repeat(grid.image_w, axis=1))

    canvas = Image.new("RGB", (grid.image_w, grid.image_h))
    for top, row_h, row in grid.tiles:
        strip = Image.new("RGB", (grid.image_w, row_h))
        for left, tile_w, tile in row:
            if left == 0:
                # First tile of the row: nothing to blend with.
                strip.paste(tile, (0, 0))
            else:
                blended_edge = tile.crop((0, 0, overlap, row_h))
                strip.paste(blended_edge, (left, 0), mask=mask_w)
                remainder = tile.crop((overlap, 0, tile_w, row_h))
                strip.paste(remainder, (left + overlap, 0))

        if top == 0:
            # First row: nothing to blend with.
            canvas.paste(strip, (0, 0))
        else:
            blended_band = strip.crop((0, 0, strip.width, overlap))
            canvas.paste(blended_band, (0, top), mask=mask_h)
            remainder_band = strip.crop((0, overlap, strip.width, row_h))
            canvas.paste(remainder_band, (0, top + overlap))

    return canvas
class UpscalerESRGAN:
    """4x ESRGAN upscaler operating on PIL images.

    Loads an ``RRDBNet`` checkpoint, moves it to the requested device and
    dtype, and offers whole-image and tiled upscaling. Calling the instance
    runs the whole-image path.
    """

    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        """Load the checkpoint and place the model on ``device`` with ``dtype``.

        Args:
            model_path: Path of the checkpoint file to load.
            device: Device used both to load the weights and to run inference.
            dtype: Floating-point type the model is cast to.
        """
        self.model_path = model_path
        # load_model reads self.device for map_location, so set it first.
        self.device = device
        self.model = self.load_model(model_path)
        self.to(device, dtype)

    def __call__(self, img: Image.Image) -> Image.Image:
        """Upscale ``img`` in a single pass (no tiling)."""
        return self.upscale_without_tiling(img)

    def to(self, device: torch.device, dtype: torch.dtype):
        """Move the model to ``device``/``dtype`` and remember both for inference."""
        self.device = device
        self.dtype = dtype
        self.model.to(device=device, dtype=dtype)

    def load_model(self, path: Path) -> RRDBNet:
        """Build an ``RRDBNet`` matching the checkpoint at ``path`` and load its weights.

        Args:
            path: Path of the checkpoint file.

        Returns:
            The model in eval mode.

        Raises:
            AssertionError: If the checkpoint is not a 4x model (or is rejected
                by ``infer_params``).
        """
        # weights_only=True restricts unpickling to tensors and plain containers.
        state_dict: dict[str, torch.Tensor] = torch.load(path, weights_only=True, map_location=self.device)  # type: ignore
        in_nc, out_nc, nf, nb, upscale = infer_params(state_dict)
        assert upscale == 4, "Only 4x upscaling is supported"
        model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
        """Upscale the whole image with one forward pass of the model.

        Args:
            img: Input image. Any mode is accepted; non-RGB images are
                converted to RGB first.

        Returns:
            The upscaled RGB image.
        """
        # The array handling below needs an (H, W, 3) layout. Grayscale or
        # palette images are 2-D and RGBA has 4 channels, so normalise the
        # mode here instead of failing later on.
        if img.mode != "RGB":
            img = img.convert("RGB")

        img_np = np.array(img)
        # Reverse the channel order before the model; undone on the output.
        img_np = img_np[:, :, ::-1]
        # HWC -> CHW, scaled to [0, 1].
        img_np = np.ascontiguousarray(np.transpose(img_np, (2, 0, 1))) / 255
        img_t = torch.from_numpy(img_np).float()  # type: ignore
        img_t = img_t.unsqueeze(0).to(device=self.device, dtype=self.dtype)

        with torch.no_grad():
            output = self.model(img_t)

        # Drop only the batch axis so that no other size-1 axis is removed.
        output = output.squeeze(0).float().cpu().clamp_(0, 1).numpy()
        # CHW -> HWC, back to the 0..255 range.
        output = 255.0 * np.moveaxis(output, 0, 2)
        output = output.astype(np.uint8)
        output = output[:, :, ::-1]
        return Image.fromarray(output, "RGB")

    # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
    def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
        """Upscale ``img`` tile by tile and blend the results back together.

        The image is split with ``split_grid``, each tile is upscaled on its
        own, and a grid with all coordinates multiplied by the observed scale
        factor is reassembled with ``combine_grid``.

        Args:
            img: Input image; converted to RGB before splitting.

        Returns:
            The upscaled RGB image.
        """
        img = img.convert("RGB")
        grid = split_grid(img)
        newtiles: Tiles = []
        # Measured from the tiles as they are upscaled; stays 1 only if the
        # grid has no tiles.
        scale_factor: int = 1

        for y, h, row in grid.tiles:
            newrow: list[Tile] = []
            for x, w, tile in row:
                output = self.upscale_without_tiling(tile)
                scale_factor = output.width // tile.width
                newrow.append((x * scale_factor, w * scale_factor, output))
            newtiles.append((y * scale_factor, h * scale_factor, newrow))

        newgrid = Grid(
            newtiles,
            grid.tile_w * scale_factor,
            grid.tile_h * scale_factor,
            grid.image_w * scale_factor,
            grid.image_h * scale_factor,
            grid.overlap * scale_factor,
        )
        return combine_grid(newgrid)