# built-in
from inspect import signature
import os
import subprocess
import logging
import re
import random
from string import ascii_letters, digits, punctuation
import sys
import warnings
import time
import asyncio
import math
from pathlib import Path
from functools import partial
from dataclasses import dataclass
from typing import Any, NamedTuple

# third-party
import requests
import pillow_heif
import spaces
import numpy as np
import numpy.typing as npt
import torch
from torch import nn
import gradio as gr
from lxml.html import fromstring
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file, save_file
from diffusers import FluxPipeline
from PIL import Image, ImageDraw, ImageFont
from transformers import PegasusForConditionalGeneration, PegasusTokenizerFast
from refiners.fluxion.utils import manual_seed
from refiners.foundationals.latent_diffusion import Solver, solvers
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
    MultiUpscaler,
    UpscalerCheckpoints,
)

# (x offset, nominal tile width, tile image) -- as produced by split_grid.
Tile = tuple[int, int, Image.Image]
# One entry per grid row: (y offset, nominal tile height, tiles of that row).
Tiles = list[tuple[int, int, list[Tile]]]


# ---------------------------------------------------------------------------
# ESRGAN (RRDBNet) network, used as the 4x pre-upscale stage of the enhancer.
# ---------------------------------------------------------------------------


def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
    """Return a 3x3 same-padding convolution followed by LeakyReLU(0.2)."""
    return nn.Sequential(
        nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
    )


class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
            {Rakotonirina} and A. {Rasoanaivo}
    """

    def __init__(self, nf: int = 64, gc: int = 32) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]

        # Each conv sees the block input plus every previous conv's output,
        # hence the growing input channel count (nf + k * gc).
        self.conv1 = conv_block(nf, gc)
        self.conv2 = conv_block(nf + gc, gc)
        self.conv3 = conv_block(nf + 2 * gc, gc)
        self.conv4 = conv_block(nf + 3 * gc, gc)
        # Wrapped in Sequential because of key in state dict.
        self.conv5 = nn.Sequential(nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        # Scaled residual connection.
        return x5 * 0.2 + x


class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    """

    def __init__(self, nf: int) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.RDB1 = ResidualDenseBlock_5C(nf)
        self.RDB2 = ResidualDenseBlock_5C(nf)
        self.RDB3 = ResidualDenseBlock_5C(nf)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class Upsample2x(nn.Module):
    """Upsample 2x."""

    def __init__(self) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.interpolate(x, scale_factor=2.0)  # type: ignore


class ShortcutBlock(nn.Module):
    """Elementwise sum the output of a submodule to its input"""

    def __init__(self, submodule: nn.Module) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.sub = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sub(x)


class RRDBNet(nn.Module):
    """ESRGAN generator: `nb` RRDB blocks followed by two 2x upsampling stages (4x total).

    The layer order inside `self.model` must match the checkpoint's
    `model.<index>.*` state-dict keys (see `infer_params`).
    """

    def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        assert in_nc % 4 != 0  # in_nc is 3

        self.model = nn.Sequential(
            nn.Conv2d(in_nc, nf, kernel_size=3, padding=1),
            ShortcutBlock(
                nn.Sequential(
                    *(RRDB(nf) for _ in range(nb)),
                    nn.Conv2d(nf, nf, kernel_size=3, padding=1),
                )
            ),
            Upsample2x(),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            Upsample2x(),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(nf, out_nc, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
    """Infer RRDBNet hyper-parameters from an ESRGAN state dict.

    Uses only the key names and weight shapes of the checkpoint.

    Returns:
        (in_nc, out_nc, nf, nb, scale), where scale is 2 ** (number of conv
        weights at `model.<i>.weight` with i > 6).

    Raises:
        AssertionError: for ESRGAN+ checkpoints (keys containing "conv1x1") or
        when out_nc / nb cannot be determined.
    """
    # this code is adapted from https://github.com/victorca25/iNNfer
    scale2x = 0
    scalemin = 6
    n_uplayer = 0
    out_nc = 0
    nb = 0

    for block in list(state_dict):
        parts = block.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            # "model.1.sub.<i>.weight": the last index is the conv after the RRDBs.
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
                scale2x += 1
            if part_num > n_uplayer:
                # The highest-indexed top-level layer is the output conv.
                n_uplayer = part_num
                out_nc = state_dict[block].shape[0]
        assert "conv1x1" not in block  # no ESRGANPlus

    nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]
    scale = 2**scale2x

    assert out_nc > 0
    assert nb > 0

    return in_nc, out_nc, nf, nb, scale  # 3, 3, 64, 23, 4


# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
class Grid(NamedTuple):
    tiles: Tiles  # rows of tiles, see the Tiles alias
    tile_w: int  # nominal tile width in pixels
    tile_h: int  # nominal tile height in pixels
    image_w: int  # width of the full image
    image_h: int  # height of the full image
    overlap: int  # pixels shared by neighbouring tiles


# adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
    """Split `image` into overlapping tiles of (at most) tile_w x tile_h.

    Tile origins are spread evenly so that the last column/row ends at the
    image's right/bottom edge. Offsets are clamped to stay within the image.
    """
    w = image.width
    h = image.height

    non_overlap_width = tile_w - overlap
    non_overlap_height = tile_h - overlap

    cols = max(1, math.ceil((w - overlap) / non_overlap_width))
    rows = max(1, math.ceil((h - overlap) / non_overlap_height))

    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        row_images: list[Tile] = []
        y1 = max(min(int(row * dy), h - tile_h), 0)
        y2 = min(y1 + tile_h, h)
        for col in range(cols):
            x1 = max(min(int(col * dx), w - tile_w), 0)
            x2 = min(x1 + tile_w, w)
            tile = image.crop((x1, y1, x2, y2))
            row_images.append((x1, tile_w, tile))
        grid.tiles.append((y1, tile_h, row_images))

    return grid


# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
def combine_grid(grid: Grid):
    """Reassemble a Grid into a single RGB image.

    Each overlap is blended with a linear ramp mask that runs from 0 up to
    just under 255 across `grid.overlap` pixels. The blend is horizontal
    between tiles of a row and vertical between rows.
    """

    def make_mask_image(r: npt.NDArray[np.float32]) -> Image.Image:
        r = r * 255 / grid.overlap
        return Image.fromarray(r.astype(np.uint8), "L")

    mask_w = make_mask_image(
        np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)
    )
    mask_h = make_mask_image(
        np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)
    )

    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        combined_row = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                combined_row.paste(tile, (0, 0))
                continue

            # Blend the left overlap strip, then paste the rest of the tile as-is.
            combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
            combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))

        if y == 0:
            combined_image.paste(combined_row, (0, 0))
            continue

        # Blend the top overlap strip of the row, then paste the remainder.
        combined_image.paste(
            combined_row.crop((0, 0, combined_row.width, grid.overlap)),
            (0, y),
            mask=mask_h,
        )
        combined_image.paste(
            combined_row.crop((0, grid.overlap, combined_row.width, h)),
            (0, y + grid.overlap),
        )

    return combined_image


class UpscalerESRGAN:
    """Load a 4x ESRGAN checkpoint into RRDBNet and run it on PIL images."""

    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        self.model_path = model_path
        # load_model reads self.device (as map_location), so set it first.
        self.device = device
        self.model = self.load_model(model_path)
        self.to(device, dtype)

    def __call__(self, img: Image.Image) -> Image.Image:
        return self.upscale_without_tiling(img)

    def to(self, device: torch.device, dtype: torch.dtype):
        """Move the network to `device`/`dtype` and remember them for inference."""
        self.device = device
        self.dtype = dtype
        self.model.to(device=device, dtype=dtype)

    def load_model(self, path: Path) -> RRDBNet:
        """Build an RRDBNet matching the checkpoint at `path` and load its weights.

        Raises:
            AssertionError: if the checkpoint is not a 4x model.
        """
        state_dict: dict[str, torch.Tensor] = torch.load(path, weights_only=True, map_location=self.device)  # type: ignore
        in_nc, out_nc, nf, nb, upscale = infer_params(state_dict)
        assert upscale == 4, "Only 4x upscaling is supported"
        model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
        model.load_state_dict(state_dict)
        model.eval()

        return model

    def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
        """Run the whole image through the network in one pass.

        Channels are reversed before the model and reversed back afterwards
        (presumably the checkpoint expects BGR input -- confirm). Pixel values
        are scaled to [0, 1] on the way in and clamped back to uint8 on the
        way out.
        """
        img_np = np.array(img)
        img_np = img_np[:, :, ::-1]
        img_np = np.ascontiguousarray(np.transpose(img_np, (2, 0, 1))) / 255
        img_t = torch.from_numpy(img_np).float()  # type: ignore
        img_t = img_t.unsqueeze(0).to(device=self.device, dtype=self.dtype)
        with torch.no_grad():
            output = self.model(img_t)
        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = 255.0 * np.moveaxis(output, 0, 2)
        output = output.astype(np.uint8)
        output = output[:, :, ::-1]
        return Image.fromarray(output, "RGB")

    # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
    def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
        """Upscale `img` tile by tile (see split_grid) and blend the tiles back together."""
        img = img.convert("RGB")
        grid = split_grid(img)
        newtiles: Tiles = []
        scale_factor: int = 1

        for y, h, row in grid.tiles:
            newrow: list[Tile] = []
            for tiledata in row:
                x, w, tile = tiledata
                output = self.upscale_without_tiling(tile)
                # Measured from the actual output so the grid geometry scales to match.
                scale_factor = output.width // tile.width
                newrow.append((x * scale_factor, w * scale_factor, output))
            newtiles.append((y * scale_factor, h * scale_factor, newrow))

        newgrid = Grid(
            newtiles,
            grid.tile_w * scale_factor,
            grid.tile_h * scale_factor,
            grid.image_w * scale_factor,
            grid.image_h * scale_factor,
            grid.overlap * scale_factor,
        )

        output = combine_grid(newgrid)
        return output


@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    esrgan: Path  # path to the ESRGAN .pth checkpoint


class ESRGANUpscaler(MultiUpscaler):
    """MultiUpscaler whose pre-upscale step is a tiled 4x ESRGAN pass."""

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)

    def to(self, device: torch.device, dtype: torch.dtype):
        """Move both the ESRGAN network and `self.sd` to `device`/`dtype`."""
        self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype

    def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
        image = self.esrgan.upscale_with_tiling(image)
        # ESRGAN already contributed 4x; the base class handles the remaining factor.
        return super().pre_upscale(image=image, upscale_factor=upscale_factor / 4)


# ---------------------------------------------------------------------------
# Model setup (runs at import time: downloads checkpoints and loads models).
# ---------------------------------------------------------------------------

pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

CHECKPOINTS = ESRGANUpscalerCheckpoints(
    unet=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.unet",
            filename="model.safetensors",
            revision="347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
        )
    ),
    clip_text_encoder=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.text_encoder",
            filename="model.safetensors",
            revision="744ad6a5c0437ec02ad826df9f6ede102bb27481",
        )
    ),
    lda=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.autoencoder",
            filename="model.safetensors",
            revision="3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
        )
    ),
    controlnet_tile=Path(
        hf_hub_download(
            repo_id="refiners/controlnet.sd1_5.tile",
            filename="model.safetensors",
            revision="48ced6ff8bfa873a8976fa467c3629a240643387",
        )
    ),
    esrgan=Path(
        hf_hub_download(
            repo_id="philz1337x/upscaler",
            filename="4x-UltraSharp.pth",
            revision="011deacac8270114eb7d2eeff4fe6fa9a837be70",
        )
    ),
    negative_embedding=Path(
        hf_hub_download(
            repo_id="philz1337x/embeddings",
            filename="JuggernautNegative-neg.pt",
            revision="203caa7e9cc2bc225031a4021f6ab1ded283454a",
        )
    ),
    negative_embedding_key="string_to_param.*",
    loras={
        "more_details": Path(
            hf_hub_download(
                repo_id="philz1337x/loras",
                filename="more_details.safetensors",
                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
            )
        ),
        "sdxl_render": Path(
            hf_hub_download(
                repo_id="philz1337x/loras",
                filename="SDXLrender_v2.0.safetensors",
                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
            )
        ),
    },
)

# initialize the enhancer, on the cpu
DEVICE_CPU = torch.device("cpu")
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
enhancer = ESRGANUpscaler(checkpoints=CHECKPOINTS, device=DEVICE_CPU, dtype=DTYPE)

# then move it to the GPU when one is available
device = DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enhancer.to(device=DEVICE, dtype=DTYPE)

# logging

warnings.filterwarnings("ignore")
root = logging.getLogger()
root.setLevel(logging.WARN)
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.WARN)
formatter = logging.Formatter('\n >>> [%(levelname)s] %(asctime)s %(name)s: %(message)s\n')
handler.setFormatter(formatter)
root.addHandler(handler)

# constant data

base = "black-forest-labs/FLUX.1-schnell"
pegasus_name = "google/pegasus-xsum"

# precision data

seq = 512  # max_sequence_length for the FLUX prompt
width = 1024  # generated image width (pixels)
height = 1024  # generated image height (pixels)
image_steps = 8  # FLUX inference steps
img_accu = 0  # passed as guidance_scale

# ui data

css = "".join([
    """
input, input::placeholder {
    text-align: center !important;
}

*, *::placeholder {
    font-family: Suez One !important;
}

h1,h2,h3,h4,h5,h6 {
    width: 100%;
    text-align: center;
}

footer {
    display: none !important;
}

#col-container {
    margin: 0 auto;
}

.image-container {
    aspect-ratio: """,
    str(width),
    "/",
    str(height),
    """ !important;
}

.dropdown-arrow {
    display: none !important;
}

*:has(>.btn) {
    display: flex;
    justify-content: space-evenly;
    align-items: center;
}

.btn {
    display: flex;
}
""",
])

# NOTE(review): `js` is never passed to gr.Blocks below, and no component sets
# elem_id "prompt"/"prompt2", so this input filter is currently unused.
js = """
function custom(){
    document.querySelector("div#prompt input").addEventListener("keydown",function(e){
        e.target.setAttribute("last_value",e.target.value);
    });
    document.querySelector("div#prompt input").addEventListener("input",function(e){
        if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
            e.target.value = e.target.getAttribute("last_value");
            e.target.removeAttribute("last_value");
        }
    });
    document.querySelector("div#prompt2 input").addEventListener("keydown",function(e){
        e.target.setAttribute("last_value",e.target.value);
    });
    document.querySelector("div#prompt2 input").addEventListener("input",function(e){
        if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
            e.target.value = e.target.getAttribute("last_value");
            e.target.removeAttribute("last_value");
        }
    });
}
"""

# torch pipes

image_pipe = FluxPipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to(device)
image_pipe.enable_model_cpu_offload()

# text normalisation used by handle_generate

_WHITESPACE_RE = re.compile(r"[ \t\n]+")
# re.escape is required: string.punctuation contains "\\", "]", "^" and "-",
# which are special inside a character class. Unescaped, the "\\]" pair
# swallowed the backslash, so backslashes were never stripped.
_PUNCTUATION_RE = re.compile(f"[{re.escape(punctuation)}]")

# functionality


@spaces.GPU(duration=180)
def upscaler(
    input_image: Image.Image,
    prompt: str = "masterpiece, best quality, highres",
    negative_prompt: str = "worst quality, low quality, normal quality",
    seed: int = 42,
    upscale_factor: int = 8,
    controlnet_scale: float = 0.6,
    controlnet_decay: float = 1.0,
    condition_scale: int = 6,
    tile_width: int = 112,
    tile_height: int = 144,
    denoise_strength: float = 0.35,
    num_inference_steps: int = 18,
    solver: str = "DDIM",
) -> Image.Image:
    """Refine-upscale `input_image` with the module-level ESRGAN + SD1.5 `enhancer`.

    `solver` is the name of a class in refiners' `solvers` module. Both LoRAs
    from CHECKPOINTS are applied ("more_details" at 0.5, "sdxl_render" at 1.0).
    """
    manual_seed(seed)

    solver_type: type[Solver] = getattr(solvers, solver)

    enhanced_image = enhancer.upscale(
        image=input_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        upscale_factor=upscale_factor,
        controlnet_scale=controlnet_scale,
        controlnet_scale_decay=controlnet_decay,
        condition_scale=condition_scale,
        tile_size=(tile_height, tile_width),
        denoise_strength=denoise_strength,
        num_inference_steps=num_inference_steps,
        loras_scale={"more_details": 0.5, "sdxl_render": 1.0},
        solver_type=solver_type,
    )

    return enhanced_image


@spaces.GPU(duration=180)
def summarize_text(
    text,
    max_length=30,
    num_beams=16,
    early_stopping=True,
    # These defaults are evaluated once, when the function is defined, so the
    # Pegasus tokenizer/model load at import time and are shared by all calls.
    pegasus_tokenizer=PegasusTokenizerFast.from_pretrained(pegasus_name),
    pegasus_model=PegasusForConditionalGeneration.from_pretrained(pegasus_name),
):
    """Summarize `text` with Pegasus beam search and return the decoded summary."""
    input_ids = pegasus_tokenizer(text, return_tensors="pt").input_ids
    summary_ids = pegasus_model.generate(
        input_ids,
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=early_stopping,
    )
    return pegasus_tokenizer.decode(summary_ids[0], skip_special_tokens=True)


def generate_random_string(length):
    """Return `length` random ASCII letters/digits (used for output file names)."""
    characters = ascii_letters + digits
    return ''.join(random.choice(characters) for _ in range(length))


@spaces.GPU(duration=180)
def pipe_generate(p1, p2):
    """Render one width x height FLUX image for prompt `p1` / negative prompt `p2`."""
    # A fresh random seed per call. The previous derivation parsed the digits
    # of str(random.random()), which raises when the float is printed in
    # exponent notation (e.g. "5e-05" or "5.3e-05").
    seed = random.randrange(2**63)
    return image_pipe(
        prompt=p1,
        negative_prompt=p2,
        height=height,
        width=width,
        guidance_scale=img_accu,
        num_images_per_prompt=1,
        num_inference_steps=image_steps,
        max_sequence_length=seq,
        generator=torch.Generator(device).manual_seed(seed),
    ).images[0]


def handle_generate(artist, song, genre, lyrics):
    """Generate a song cover and return the file name it was saved under.

    Builds prompts from the form fields (lyrics are summarized with Pegasus)
    and renders a FLUX image. It then draws the song title and artist name on
    the image, upscales it with `upscaler`, and saves it as a random
    12-character .png in the working directory.
    """
    pos_artist = _WHITESPACE_RE.sub(" ", artist).strip()
    pos_song = _WHITESPACE_RE.sub(" ", song).strip()
    # Capitalize the first letter of every word, leaving the rest untouched.
    pos_song = ' '.join(word[0].upper() + word[1:] for word in pos_song.split())
    pos_genre = _PUNCTUATION_RE.sub('', _WHITESPACE_RE.sub(" ", genre)).upper().strip()
    pos_lyrics = _PUNCTUATION_RE.sub('', _WHITESPACE_RE.sub(" ", lyrics)).lower().strip()
    pos_lyrics_sum = summarize_text(pos_lyrics)

    neg = "Textual Labeled Distorted Discontinuous Ugly Blurry Low-Quality Worst-Quality Low-Resolution Painted"
    pos = f'Realistic Vivid Genuine Reasonable Detailed 4K { pos_genre } GENRE { pos_song }: "{ pos_lyrics_sum }"'

    print(f"""
Positive: {pos}

Negative: {neg}
""")

    img = pipe_generate(pos, neg)

    draw = ImageDraw.Draw(img)

    rows = 1
    # NOTE(review): math.ceil(1 / 3) == 1, so each label's offset below equals
    # its own y: the title lands at y == 0 and the artist at roughly 2 * y.
    labels_distance = math.ceil(1 / 3)

    # Song title: white text with a black outline.
    textheight = min(math.ceil(width / 10), math.ceil(height / 5))
    font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
    textwidth = draw.textlength(pos_song, font)
    x = math.ceil((width - textwidth) / 2)
    y = height - math.ceil(textheight * rows / 2)
    y = y - math.ceil(y / labels_distance)
    draw.text((x, y), pos_song, (255, 255, 255), font=font, spacing=2, stroke_width=4, stroke_fill=(0, 0, 0))

    # Artist name: black text with a white outline.
    artist_stroke = 8
    textheight = min(math.ceil(width / 12), math.ceil(height / 6))
    font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
    textwidth = draw.textlength(pos_artist, font)
    x = math.ceil((width - textwidth) / 2)
    y = height - math.ceil(textheight * rows / 2)
    y = y + math.ceil(y / labels_distance)
    # The offset above puts y near 2 * height (1962 for the 1024 canvas), i.e.
    # below the image, so the artist name was never visible. Clamp it so the
    # label (plus its outline) stays on the canvas.
    y = min(y, height - textheight - artist_stroke)
    draw.text((x, y), pos_artist, (0, 0, 0), font=font, spacing=6, stroke_width=artist_stroke, stroke_fill=(255, 255, 255))

    enhanced_img = upscaler(img)

    name = generate_random_string(12) + ".png"
    enhanced_img.save(name)

    return name


# entry

if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
        gr.Markdown("# Song Cover Image Generator")

        with gr.Column():
            with gr.Row():
                artist = gr.Textbox(
                    placeholder="Artist name",
                    container=False,
                    max_lines=1,
                )
                song = gr.Textbox(
                    placeholder="Song name",
                    container=False,
                    max_lines=1,
                )
                genre = gr.Textbox(
                    placeholder="Genre",
                    container=False,
                    max_lines=1,
                )
                lyrics = gr.Textbox(
                    placeholder="Lyrics (English)",
                    container=False,
                    max_lines=1,
                )

        with gr.Column():
            cover = gr.Image(
                interactive=False,
                container=False,
                elem_classes="image-container",
                label="Result",
                show_label=True,
                type='filepath',
                show_share_button=False,
            )
            run = gr.Button("Generate", elem_classes="btn")

        run.click(
            fn=handle_generate,
            inputs=[artist, song, genre, lyrics],
            outputs=[cover],
        )

    demo.queue().launch()