from collections import namedtuple
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from string import ascii_letters, digits, punctuation
from typing import Any
import logging
import math
import random
import re
import sys
import warnings

import gradio as gr
import numpy as np
import numpy.typing as npt
import pillow_heif
import spaces  # Hugging Face Spaces helper (kept although no @spaces.GPU decorator is used below)
import torch
from torch import nn
from langdetect import detect as get_language
from huggingface_hub import hf_hub_download
from diffusers import FluxPipeline
from PIL import Image, ImageDraw, ImageFont
from transformers import MT5ForConditionalGeneration, T5Tokenizer
from refiners.fluxion.utils import manual_seed
from refiners.foundationals.latent_diffusion import Solver, solvers
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
    MultiUpscaler,
    UpscalerCheckpoints,
)
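# Song-cover image generator: FLUX.1-schnell renders the base image, an
# ESRGAN + Stable Diffusion 1.5 multi-upscaler enhances it, and mT5 handles
# translation/summarization of the song metadata that feeds the prompt.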
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-xl")
tokenizer = T5Tokenizer.from_pretrained("google/mt5-xl")


def log(msg):
    print(f"{datetime.now().time()} {msg}")

Tile = tuple[int, int, Image.Image]
Tiles = list[tuple[int, int, list[Tile]]]
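# A Tile is (x_offset, tile_width, image); Tiles pairs each row's
# (y_offset, tile_height) with the tiles in that row.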

def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
    return nn.Sequential(
        nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
    )

class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block

    The core module of the paper "Residual Dense Network for Image
    Super-Resolution" (CVPR 2018).

    Modified options that can be used:
    - "Partial Convolution based Padding", arXiv:1811.11718
    - "Spectral Normalization", arXiv:1802.05957
    - "ESRGAN+: Further Improving ESRGAN" (ICASSP 2020),
      N. C. Rakotonirina and A. Rasoanaivo
    """

    def __init__(self, nf: int = 64, gc: int = 32) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.conv1 = conv_block(nf, gc)
        self.conv2 = conv_block(nf + gc, gc)
        self.conv3 = conv_block(nf + 2 * gc, gc)
        self.conv4 = conv_block(nf + 3 * gc, gc)
        # Wrapped in Sequential because of the key names in the state dict.
        self.conv5 = nn.Sequential(nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x

class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    """

    def __init__(self, nf: int) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.RDB1 = ResidualDenseBlock_5C(nf)
        self.RDB2 = ResidualDenseBlock_5C(nf)
        self.RDB3 = ResidualDenseBlock_5C(nf)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x

class Upsample2x(nn.Module):
    """Upsample 2x."""

    def __init__(self) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.interpolate(x, scale_factor=2.0)  # type: ignore

class ShortcutBlock(nn.Module):
    """Elementwise sum of a submodule's output with its input."""

    def __init__(self, submodule: nn.Module) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.sub = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sub(x)

class RRDBNet(nn.Module):
    def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        assert in_nc % 4 != 0  # in_nc is 3
        self.model = nn.Sequential(
            nn.Conv2d(in_nc, nf, kernel_size=3, padding=1),
            ShortcutBlock(
                nn.Sequential(
                    *(RRDB(nf) for _ in range(nb)),
                    nn.Conv2d(nf, nf, kernel_size=3, padding=1),
                )
            ),
            Upsample2x(),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            Upsample2x(),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(nf, out_nc, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)
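# NOTE: with the 4x-UltraSharp checkpoint used below, infer_params() yields
# RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23) and an overall 4x scale (the two
# Upsample2x stages above).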

def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
    # this code is adapted from https://github.com/victorca25/iNNfer
    scale2x = 0
    scalemin = 6
    n_uplayer = 0
    out_nc = 0
    nb = 0
    for block in list(state_dict):
        parts = block.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
                scale2x += 1
            if part_num > n_uplayer:
                n_uplayer = part_num
                out_nc = state_dict[block].shape[0]
        assert "conv1x1" not in block  # no ESRGANPlus
    nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]
    scale = 2**scale2x
    assert out_nc > 0
    assert nb > 0
    return in_nc, out_nc, nf, nb, scale  # 3, 3, 64, 23, 4
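# Heuristic notes: trunk depth nb comes from the highest "model.1.sub.N" key;
# for this layer layout each "model.N.weight" with N > scalemin marks one 2x
# upsampling stage, and the deepest conv's output channels give out_nc.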

# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])


# adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
    w = image.width
    h = image.height
    non_overlap_width = tile_w - overlap
    non_overlap_height = tile_h - overlap
    cols = max(1, math.ceil((w - overlap) / non_overlap_width))
    rows = max(1, math.ceil((h - overlap) / non_overlap_height))
    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        row_images: list[Tile] = []
        y1 = max(min(int(row * dy), h - tile_h), 0)
        y2 = min(y1 + tile_h, h)
        for col in range(cols):
            x1 = max(min(int(col * dx), w - tile_w), 0)
            x2 = min(x1 + tile_w, w)
            tile = image.crop((x1, y1, x2, y2))
            row_images.append((x1, tile_w, tile))
        grid.tiles.append((y1, tile_h, row_images))
    return grid
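# Worked example: a 1024x1024 image with the defaults (512px tiles, 64px
# overlap) gives cols = rows = ceil((1024 - 64) / 448) = 3, with tile
# origins spaced dx = dy = (1024 - 512) / 2 = 256px apart.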

# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
def combine_grid(grid: Grid):
    def make_mask_image(r: npt.NDArray[np.float32]) -> Image.Image:
        r = r * 255 / grid.overlap
        return Image.fromarray(r.astype(np.uint8), "L")

    mask_w = make_mask_image(
        np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)
    )
    mask_h = make_mask_image(
        np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)
    )

    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        combined_row = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                combined_row.paste(tile, (0, 0))
                continue
            combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
            combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
        if y == 0:
            combined_image.paste(combined_row, (0, 0))
            continue
        combined_image.paste(
            combined_row.crop((0, 0, combined_row.width, grid.overlap)),
            (0, y),
            mask=mask_h,
        )
        combined_image.paste(
            combined_row.crop((0, grid.overlap, combined_row.width, h)),
            (0, y + grid.overlap),
        )
    return combined_image
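# The masks implement a linear cross-fade: alpha ramps 0 -> 255 across the
# overlap band, so each tile blends smoothly into the one pasted before it.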

class UpscalerESRGAN:
    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        self.model_path = model_path
        self.device = device
        self.model = self.load_model(model_path)
        self.to(device, dtype)

    def __call__(self, img: Image.Image) -> Image.Image:
        return self.upscale_without_tiling(img)

    def to(self, device: torch.device, dtype: torch.dtype):
        self.device = device
        self.dtype = dtype
        self.model.to(device=device, dtype=dtype)

    def load_model(self, path: Path) -> RRDBNet:
        state_dict: dict[str, torch.Tensor] = torch.load(path, weights_only=True, map_location=self.device)  # type: ignore
        in_nc, out_nc, nf, nb, upscale = infer_params(state_dict)
        assert upscale == 4, "Only 4x upscaling is supported"
        model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
        img_np = np.array(img)
        img_np = img_np[:, :, ::-1]  # reverse channel order (RGB -> BGR) for the checkpoint
        img_np = np.ascontiguousarray(np.transpose(img_np, (2, 0, 1))) / 255
        img_t = torch.from_numpy(img_np).float()  # type: ignore
        img_t = img_t.unsqueeze(0).to(device=self.device, dtype=self.dtype)
        with torch.no_grad():
            output = self.model(img_t)
        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = 255.0 * np.moveaxis(output, 0, 2)
        output = output.astype(np.uint8)
        output = output[:, :, ::-1]  # BGR -> RGB
        return Image.fromarray(output, "RGB")

    # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
    def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
        img = img.convert("RGB")
        grid = split_grid(img)
        newtiles: Tiles = []
        scale_factor: int = 1
        for y, h, row in grid.tiles:
            newrow: list[Tile] = []
            for tiledata in row:
                x, w, tile = tiledata
                output = self.upscale_without_tiling(tile)
                scale_factor = output.width // tile.width
                newrow.append((x * scale_factor, w * scale_factor, output))
            newtiles.append((y * scale_factor, h * scale_factor, newrow))
        newgrid = Grid(
            newtiles,
            grid.tile_w * scale_factor,
            grid.tile_h * scale_factor,
            grid.image_w * scale_factor,
            grid.image_h * scale_factor,
            grid.overlap * scale_factor,
        )
        output = combine_grid(newgrid)
        return output
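# Hypothetical standalone usage (path and device are illustrative):
#   esrgan = UpscalerESRGAN(Path("4x-UltraSharp.pth"), torch.device("cpu"), torch.float32)
#   big = esrgan.upscale_with_tiling(Image.open("cover.png"))  # 4x the input size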

@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    esrgan: Path


class ESRGANUpscaler(MultiUpscaler):
    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)

    def to(self, device: torch.device, dtype: torch.dtype):
        self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype

    def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
        image = self.esrgan.upscale_with_tiling(image)
        return super().pre_upscale(image=image, upscale_factor=upscale_factor / 4)
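# pre_upscale runs the 4x ESRGAN pass first, then hands the parent class a
# factor divided by 4 so the net scale matches the request (e.g. a 2x request
# becomes 4x ESRGAN followed by a 0.5x resize before diffusion refinement).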

pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

CHECKPOINTS = ESRGANUpscalerCheckpoints(
    unet=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.unet",
            filename="model.safetensors",
            revision="347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
        )
    ),
    clip_text_encoder=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.text_encoder",
            filename="model.safetensors",
            revision="744ad6a5c0437ec02ad826df9f6ede102bb27481",
        )
    ),
    lda=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.autoencoder",
            filename="model.safetensors",
            revision="3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
        )
    ),
    controlnet_tile=Path(
        hf_hub_download(
            repo_id="refiners/controlnet.sd1_5.tile",
            filename="model.safetensors",
            revision="48ced6ff8bfa873a8976fa467c3629a240643387",
        )
    ),
    esrgan=Path(
        hf_hub_download(
            repo_id="philz1337x/upscaler",
            filename="4x-UltraSharp.pth",
            revision="011deacac8270114eb7d2eeff4fe6fa9a837be70",
        )
    ),
    negative_embedding=Path(
        hf_hub_download(
            repo_id="philz1337x/embeddings",
            filename="JuggernautNegative-neg.pt",
            revision="203caa7e9cc2bc225031a4021f6ab1ded283454a",
        )
    ),
    negative_embedding_key="string_to_param.*",
    loras={
        "more_details": Path(
            hf_hub_download(
                repo_id="philz1337x/loras",
                filename="more_details.safetensors",
                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
            )
        ),
        "sdxl_render": Path(
            hf_hub_download(
                repo_id="philz1337x/loras",
                filename="SDXLrender_v2.0.safetensors",
                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
            )
        ),
    },
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
enhancer = ESRGANUpscaler(checkpoints=CHECKPOINTS, device=DEVICE, dtype=DTYPE)

# logging
warnings.filterwarnings("ignore")
root = logging.getLogger()
root.setLevel(logging.WARNING)
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.WARNING)
formatter = logging.Formatter("\n >>> [%(levelname)s] %(asctime)s %(name)s: %(message)s\n")
handler.setFormatter(formatter)
root.addHandler(handler)

# constant data
base = "black-forest-labs/FLUX.1-schnell"

# generation parameters
seq = 256
width = 2048
height = 2048
image_steps = 8
img_accu = 0  # guidance scale; schnell is distilled to sample without classifier-free guidance

# ui data
css = "".join(["""
input, input::placeholder {
    text-align: center !important;
}
*, *::placeholder {
    font-family: Suez One !important;
}
h1, h2, h3, h4, h5, h6 {
    width: 100%;
    text-align: center;
}
footer {
    display: none !important;
}
.image-container {
    aspect-ratio: """, str(width), "/", str(height), """ !important;
}
.dropdown-arrow {
    display: none !important;
}
*:has(>.btn) {
    display: flex;
    justify-content: space-evenly;
    align-items: center;
}
.btn {
    display: flex;
}
"""])
| js=""" | |
| function custom(){ | |
| document.querySelector("div#prompt input").addEventListener("keydown",function(e){ | |
| e.target.setAttribute("last_value",e.target.value); | |
| }); | |
| document.querySelector("div#prompt input").addEventListener("input",function(e){ | |
| if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){ | |
| e.target.value = e.target.getAttribute("last_value"); | |
| e.target.removeAttribute("last_value"); | |
| } | |
| }); | |
| document.querySelector("div#prompt2 input").addEventListener("keydown",function(e){ | |
| e.target.setAttribute("last_value",e.target.value); | |
| }); | |
| document.querySelector("div#prompt2 input").addEventListener("input",function(e){ | |
| if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){ | |
| e.target.value = e.target.getAttribute("last_value"); | |
| e.target.removeAttribute("last_value"); | |
| } | |
| }); | |
| } | |
| """ | |

# torch pipes
image_pipe = FluxPipeline.from_pretrained(base, torch_dtype=torch.bfloat16)
image_pipe.enable_model_cpu_offload()  # manages device placement itself; no .to(DEVICE) needed
image_pipe.enable_vae_slicing()
image_pipe.enable_vae_tiling()
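# All three calls trade speed for memory: CPU offload keeps only the active
# submodule on the GPU, and VAE slicing/tiling decode the 2048x2048 latents
# in chunks rather than in one pass.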

# functionality
def upscaler(
    input_image: Image.Image,
    prompt: str = "Photorealistic, Hyperrealistic, Realistic Photography, High-Quality Photography, Natural.",
    negative_prompt: str = "Distorted, Discontinuous, Blurry, Doll-Like, Overly-Plastic, Low-Quality, Painted, Smoothed, Artificial, Phony, Gaudy, Digital Effects.",
    seed: int | None = None,
    upscale_factor: int = 2,
    controlnet_scale: float = 0.6,
    controlnet_decay: float = 1.0,
    condition_scale: int = 6,
    tile_width: int = 112,
    tile_height: int = 144,
    denoise_strength: float = 0.35,
    num_inference_steps: int = 30,
    solver: str = "DDIM",
) -> Image.Image:
    log("CALL upscaler")
    if seed is None:
        seed = random.randint(0, 2**32 - 1)  # fresh seed per call, not once at import time
    manual_seed(seed)
    solver_type: type[Solver] = getattr(solvers, solver)
    log("DBG upscaler 1")
    enhanced_image = enhancer.upscale(
        image=input_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        upscale_factor=upscale_factor,
        controlnet_scale=controlnet_scale,
        controlnet_scale_decay=controlnet_decay,
        condition_scale=condition_scale,
        tile_size=(tile_height, tile_width),
        denoise_strength=denoise_strength,
        num_inference_steps=num_inference_steps,
        loras_scale={"more_details": 0.5, "sdxl_render": 1.0},
        solver_type=solver_type,
    )
    log("RET upscaler")
    return enhanced_image

def get_tensor_length(tensor):
    # Total number of elements; equivalent to multiplying all dimensions.
    return int(tensor.numel())

def summarize(text, max_len=20, min_len=10):
    log("CALL summarize")
    words = text.split()
    if len(words) < 5:
        print("Summarization Error: Text is too short, 5 words minimum!")
        return text
    prefix = "summarize: "
    ret = ""
    # Summarize in 500-word chunks, shrinking each chunk iteratively until
    # its token count fits under max_len, then summarize the concatenation.
    for index in range(math.ceil(len(words) / 500)):
        chunk = " ".join(words[index * 500:(index + 1) * 500])
        inputs = tokenizer.encode(prefix + chunk, return_tensors="pt", truncation=False, add_special_tokens=True)
        while get_tensor_length(inputs) > max_len:
            inputs = model.generate(
                inputs,
                length_penalty=2.0,
                num_beams=4,
                early_stopping=True,
                max_length=max(get_tensor_length(inputs) // 4, max_len),
                min_length=min_len,
            )
        toks = tokenizer.decode(inputs[0], skip_special_tokens=True)
        ret = ret + ("" if ret == "" else " ") + toks
    inputs = tokenizer.encode(prefix + ret, return_tensors="pt", truncation=False)
    gen = model.generate(
        inputs,
        length_penalty=1.0,
        num_beams=4,
        early_stopping=True,
        max_length=max_len,
        min_length=min_len,
    )
    summary = tokenizer.decode(gen[0], skip_special_tokens=True)
    log(f"RET summarize with summary as {summary}")
    return summary
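# In this app summarize() always receives English text (handle_generation
# runs translate() on the lyrics first), so the "summarize:" task prefix is
# applied to English input.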

def generate_random_string(length):
    characters = ascii_letters + digits
    return "".join(random.choice(characters) for _ in range(length))

def pipe_generate_image(p1, p2):
    log("CALL pipe_generate")
    imgs = image_pipe(
        prompt=p1,
        negative_prompt=p2,
        height=height,
        width=width,
        guidance_scale=img_accu,
        num_images_per_prompt=1,
        num_inference_steps=image_steps,
        max_sequence_length=seq,
        generator=torch.Generator(DEVICE).manual_seed(random.randint(0, 2**32 - 1)),
    ).images
    log("RET pipe_generate")
    return imgs

def add_song_cover_text(img, artist, song, height, width):
    draw = ImageDraw.Draw(img, mode="RGBA")
    rows = 1
    labels_distance = 1 / 3
    textheight = min(math.ceil(width / 10), math.ceil(height / 5))
    font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
    # Song title: centered horizontally, a third of the way above the vertical midpoint.
    textwidth = draw.textlength(song, font)
    x = math.ceil((width - textwidth) / 2)
    y = height - (textheight * rows / 2) - (height / 2)
    y = math.ceil(y - (height / 2 * labels_distance))
    draw.text((x, y), song, (255, 255, 255, 85), font=font, spacing=2, stroke_width=math.ceil(textheight / 20), stroke_fill=(0, 0, 0, 170))
    # Artist name: same font metrics, mirrored below the midpoint with inverted colors.
    textwidth = draw.textlength(artist, font)
    x = math.ceil((width - textwidth) / 2)
    y = height - (textheight * rows / 2) - (height / 2)
    y = math.ceil(y + (height / 2 * labels_distance))
    draw.text((x, y), artist, (0, 0, 0, 85), font=font, spacing=2, stroke_width=math.ceil(textheight / 20), stroke_fill=(255, 255, 255, 170))
    return img

def all_pipes(pos, neg, artist, song):
    imgs = pipe_generate_image(pos, neg)
    for i in range(len(imgs)):
        imgs[i] = upscaler(imgs[i])
    return imgs

language_codes = {
    "af": "Afrikaans",
    "ar": "Arabic",
    "bg": "Bulgarian",
    "bn": "Bengali",
    "ca": "Catalan",
    "cs": "Czech",
    "cy": "Welsh",
    "da": "Danish",
    "de": "German",
    "el": "Greek",
    "en": "English",
    "es": "Spanish",
    "et": "Estonian",
    "fa": "Persian (Farsi)",
    "fi": "Finnish",
    "fr": "French",
    "gu": "Gujarati",
    "he": "Hebrew",
    "hi": "Hindi",
    "hr": "Croatian",
    "hu": "Hungarian",
    "id": "Indonesian",
    "it": "Italian",
    "ja": "Japanese",
    "kn": "Kannada",
    "ko": "Korean",
    "lt": "Lithuanian",
    "lv": "Latvian",
    "mk": "Macedonian",
    "ml": "Malayalam",
    "mr": "Marathi",
    "ne": "Nepali",
    "nl": "Dutch",
    "no": "Norwegian",
    "pa": "Punjabi",
    "pl": "Polish",
    "pt": "Portuguese",
    "ro": "Romanian",
    "ru": "Russian",
    "sk": "Slovak",
    "sl": "Slovenian",
    "so": "Somali",
    "sq": "Albanian",
    "sv": "Swedish",
    "sw": "Swahili",
    "ta": "Tamil",
    "te": "Telugu",
    "th": "Thai",
    "tl": "Tagalog (Filipino)",
    "tr": "Turkish",
    "uk": "Ukrainian",
    "ur": "Urdu",
    "vi": "Vietnamese",
    "zh-cn": "Chinese (Simplified)",
    "zh-tw": "Chinese (Traditional)",
}

def translate(txt, to_lang="en", from_lang=None):
    log("CALL translate")
    if not from_lang:
        from_lang = get_language(txt)
    print(f"translating from {from_lang} to {to_lang}")
    if from_lang == to_lang:
        log(f"RET translate with txt as {txt}")
        return txt
    prefix = f"translate {language_codes[from_lang]} to {language_codes[to_lang]}: "
    words = txt.split()
    ret = ""
    # Translate in 500-word chunks to stay within the model's input budget.
    for index in range(math.ceil(len(words) / 500)):
        chunk = " ".join(words[index * 500:(index + 1) * 500])
        log(f"DBG translate chunk is {chunk}")
        inputs = tokenizer.encode(prefix + chunk, return_tensors="pt", truncation=False, add_special_tokens=True)
        gen = model.generate(inputs, num_beams=3)
        toks = tokenizer.decode(gen[0], skip_special_tokens=True)
        ret = ret + ("" if ret == "" else " ") + toks
    log(f"RET translate with ret as {ret}")
    return ret
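# NOTE: langdetect can return a code missing from language_codes (e.g. a
# regional variant), which would raise a KeyError when building the prefix;
# the table above covers the languages this app expects to see.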

def handle_generation(artist, song, genre, lyrics):
    log("CALL handle_generation")
    pos_artist = re.sub(r"[ \t\n]+", " ", artist).upper().strip()
    pos_song = re.sub(r"[ \t\n]+", " ", song).lower().strip()
    pos_song = " ".join(word[0].upper() + word[1:] for word in pos_song.split())
    pos_genre = re.sub(f"[{re.escape(punctuation)}]", "", re.sub(r"[ \t\n]+", " ", genre)).lower().strip()
    pos_genre = " ".join(word[0].upper() + word[1:] for word in pos_genre.split())
    pos_lyrics = re.sub(f"[{re.escape(punctuation)}]", "", re.sub(r"[ \t\n]+", " ", lyrics)).lower().strip()
    pos_lyrics_sum = pos_lyrics if pos_lyrics == "" else summarize(translate(pos_lyrics))
    neg = "Sexuality, Humanity, Textual, Labeled, Distorted, Discontinuous, Blurry, Doll-Like, Overly Plastic, Low-Quality, Painted, Smoothed, Artificial, Phony, Gaudy, Digital Effects."
    pos = f'HQ Hyper-realistic {translate(pos_genre)} song "{translate(pos_song)}"{pos_lyrics_sum if pos_lyrics_sum == "" else ": " + pos_lyrics_sum}.'
    print(f"""
Positive: {pos}
Negative: {neg}
""")
    imgs = all_pipes(pos, neg, pos_artist, pos_song)
    index = 1
    names = []
    for img in imgs:
        scaled_by = 2  # matches the default upscale_factor in upscaler()
        labeled_img = add_song_cover_text(img, artist, song, height * scaled_by, width * scaled_by)
        name = f"{artist} - {song} ({index}).png"
        labeled_img.save(name)
        names.append(name)
        index = index + 1
    # return names
    return names[0]

# entry
if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
        gr.Markdown("""
        # Song Cover Image Generator
        """)
        with gr.Row():
            with gr.Column(scale=4):
                artist = gr.Textbox(
                    placeholder="Artist name",
                    value="",
                    container=False,
                    max_lines=1,
                )
                song = gr.Textbox(
                    placeholder="Song name",
                    value="",
                    container=False,
                    max_lines=1,
                )
                genre = gr.Textbox(
                    placeholder="Genre",
                    value="",
                    container=False,
                    max_lines=1,
                )
                lyrics = gr.Textbox(
                    placeholder="Lyrics",
                    value="",
                    container=False,
                    max_lines=1,
                )
            with gr.Column():
                cover = gr.Image(
                    interactive=False,
                    container=False,
                    elem_classes="image-container",
                    label="Result",
                    show_label=True,
                    type="filepath",
                    show_share_button=False,
                )
                run = gr.Button("Generate", elem_classes="btn")
        run.click(
            fn=handle_generation,
            inputs=[artist, song, genre, lyrics],
            outputs=[cover],
        )
    demo.queue().launch()