# Source scraped from a Hugging Face Space page (status banner: "Spaces: Running").
from langdetect import detect as get_language | |
from collections import namedtuple | |
from inspect import signature | |
import os | |
import subprocess | |
import logging | |
import re | |
import random | |
from string import ascii_letters, digits, punctuation | |
import requests | |
import sys | |
import warnings | |
import time | |
import asyncio | |
import math | |
from pathlib import Path | |
from functools import partial | |
from dataclasses import dataclass | |
from typing import Any | |
import pillow_heif | |
import spaces | |
import numpy as np | |
import numpy.typing as npt | |
import torch | |
from torch import nn | |
import gradio as gr | |
from lxml.html import fromstring | |
from huggingface_hub import hf_hub_download | |
from safetensors.torch import load_file, save_file | |
from diffusers import FluxPipeline | |
from PIL import Image, ImageDraw, ImageFont | |
from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer | |
from refiners.fluxion.utils import manual_seed | |
from refiners.foundationals.latent_diffusion import Solver, solvers | |
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import ( | |
MultiUpscaler, | |
UpscalerCheckpoints, | |
) | |
from datetime import datetime | |
# One T5 checkpoint backs both `summarize` and `translate`; the weights and
# the tokenizer are loaded once, at import time.
_T5_CHECKPOINT = "t5-large"
model = T5ForConditionalGeneration.from_pretrained(_T5_CHECKPOINT)
tokenizer = T5Tokenizer.from_pretrained(_T5_CHECKPOINT)
def log(msg):
    """Print ``msg`` to stdout, prefixed with the current wall-clock time."""
    timestamp = datetime.now().time()
    print(f'{timestamp} {msg}')
# A tile is (offset, size, image): the x offset, the nominal tile width and
# the cropped picture itself.
Tile = tuple[int, int, Image.Image]
# A row of tiles plus its y offset and nominal height; `Tiles` is the list of
# such rows that makes up a whole grid.
_TileRow = tuple[int, int, list[Tile]]
Tiles = list[_TileRow]
def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
    """Return a 3x3 convolution (padding 1) followed by an in-place LeakyReLU(0.2).

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.
    """
    convolution = nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1)
    activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(convolution, activation)
class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
    - "Partial Convolution based Padding" arXiv:1811.11718
    - "Spectral normalization" arXiv:1802.05957
    - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
    {Rakotonirina} and A. {Rasoanaivo}

    Five convolutions are chained densely: each one receives the block input
    concatenated (on dim 1) with the outputs of all previous convolutions.
    The last output is scaled by 0.2 and added back to the input.
    """

    def __init__(self, nf: int = 64, gc: int = 32) -> None:
        """
        Args:
            nf: Channel count of the block input and output.
            gc: Channel count produced by each of the first four convolutions.
        """
        super().__init__()  # type: ignore[reportUnknownMemberType]
        # conv1..conv4: the input width grows by `gc` with every earlier layer
        # whose output gets concatenated in `forward`.
        for depth in range(4):
            setattr(self, f"conv{depth + 1}", conv_block(nf + depth * gc, gc))
        # Wrapped in Sequential because of key in state dict.
        self.conv5 = nn.Sequential(nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the dense chain and return ``conv5_output * 0.2 + x``."""
        features = [x, self.conv1(x)]
        for layer in (self.conv2, self.conv3, self.conv4):
            features.append(layer(torch.cat(features, 1)))
        fused = self.conv5(torch.cat(features, 1))
        return fused * 0.2 + x
class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)

    Three ``ResidualDenseBlock_5C`` modules applied in sequence; their result
    is scaled by 0.2 and added to the block input.
    """

    def __init__(self, nf: int) -> None:
        """
        Args:
            nf: Channel count passed to each inner dense block.
        """
        super().__init__()  # type: ignore[reportUnknownMemberType]
        # Attribute names are kept as-is: they become state-dict key segments.
        self.RDB1 = ResidualDenseBlock_5C(nf)
        self.RDB2 = ResidualDenseBlock_5C(nf)
        self.RDB3 = ResidualDenseBlock_5C(nf)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``RDB3(RDB2(RDB1(x))) * 0.2 + x``."""
        out = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            out = dense_block(out)
        return out * 0.2 + x
class Upsample2x(nn.Module):
    """Upsample 2x.

    Parameter-free layer: interpolates its input with ``scale_factor=2.0``
    using ``interpolate``'s default mode.
    """

    def __init__(self) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` interpolated by a factor of two."""
        upsampled = torch.nn.functional.interpolate(x, scale_factor=2.0)  # type: ignore
        return upsampled
class ShortcutBlock(nn.Module):
    """Elementwise sum the output of a submodule to its input"""

    def __init__(self, submodule: nn.Module) -> None:
        """
        Args:
            submodule: Module whose output is added to its own input. It is
                registered as ``sub``, which is also its state-dict key segment.
        """
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.sub = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + sub(x)``."""
        branch = self.sub(x)
        return x + branch
class RRDBNet(nn.Module):
    """4x super-resolution network built from a trunk of ``RRDB`` blocks.

    Layout of ``self.model`` (indices matter, they are state-dict keys):
    input conv, a ``ShortcutBlock`` around ``nb`` RRDBs plus a conv, then two
    (``Upsample2x``, conv, LeakyReLU) stages, one more conv + LeakyReLU and
    the output conv.
    """

    def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
        """
        Args:
            in_nc: Input channel count.
            out_nc: Output channel count.
            nf: Feature channel count used throughout the network.
            nb: Number of RRDB blocks in the trunk.
        """
        super().__init__()  # type: ignore[reportUnknownMemberType]
        assert in_nc % 4 != 0  # in_nc is 3

        # Modules are created in the same order as they appear in the
        # Sequential so parameter initialisation order is unchanged.
        head = nn.Conv2d(in_nc, nf, kernel_size=3, padding=1)
        trunk_layers = [RRDB(nf) for _ in range(nb)]
        trunk_layers.append(nn.Conv2d(nf, nf, kernel_size=3, padding=1))
        layers: list[nn.Module] = [head, ShortcutBlock(nn.Sequential(*trunk_layers))]

        # Two 2x upsampling stages -> 4x overall.
        for _ in range(2):
            layers.append(Upsample2x())
            layers.append(nn.Conv2d(nf, nf, kernel_size=3, padding=1))
            layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))

        layers.append(nn.Conv2d(nf, nf, kernel_size=3, padding=1))
        layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
        layers.append(nn.Conv2d(nf, out_nc, kernel_size=3, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the whole sequential model."""
        return self.model(x)
def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
    """Derive the RRDBNet constructor arguments from a checkpoint's key names.

    Args:
        state_dict: Mapping of parameter names (``model.<idx>...``) to tensors.
            Must contain ``model.0.weight``.

    Returns:
        ``(in_nc, out_nc, nf, nb, scale)``.

    Raises:
        AssertionError: If a key contains ``conv1x1`` (ESRGANPlus), or if no
            output layer / no trunk block index could be found.
    """
    # this code is adapted from https://github.com/victorca25/iNNfer
    scalemin = 6  # only three-part weight keys with an index above this count towards the scale
    upsample_count = 0
    highest_index = 0
    out_nc = 0
    nb = 0

    for key in state_dict:
        fields = key.split(".")
        if len(fields) == 5 and fields[2] == "sub":
            # The last five-part "...sub.<i>..." key seen provides nb.
            nb = int(fields[3])
        elif len(fields) == 3:
            layer_index = int(fields[1])
            if layer_index > scalemin and fields[0] == "model" and fields[2] == "weight":
                upsample_count += 1
            if layer_index > highest_index:
                # The highest-indexed three-part key is taken as the output layer.
                highest_index = layer_index
                out_nc = state_dict[key].shape[0]
        assert "conv1x1" not in key  # no ESRGANPlus

    first_weight_shape = state_dict["model.0.weight"].shape
    nf = first_weight_shape[0]
    in_nc = first_weight_shape[1]
    scale = 2**upsample_count

    assert out_nc > 0
    assert nb > 0

    return in_nc, out_nc, nf, nb, scale  # 3, 3, 64, 23, 4
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
# Tiling of an image: the rows of tiles, the nominal tile size, the full image
# size and the overlap (in pixels) between neighbouring tiles.
Grid = namedtuple("Grid", "tiles tile_w tile_h image_w image_h overlap")
# adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
    """Cut ``image`` into a grid of overlapping tiles.

    Args:
        image: Picture to split.
        tile_w: Nominal tile width in pixels.
        tile_h: Nominal tile height in pixels.
        overlap: Minimum overlap between neighbouring tiles in pixels.

    Returns:
        A ``Grid`` whose ``tiles`` is a list of rows ``(y, tile_h, row)``,
        each row being a list of ``(x, tile_w, cropped_image)``. The recorded
        sizes are the nominal ones even when the crop is clipped to the image.
    """
    img_w, img_h = image.width, image.height

    stride_w = tile_w - overlap
    stride_h = tile_h - overlap

    n_cols = max(1, math.ceil((img_w - overlap) / stride_w))
    n_rows = max(1, math.ceil((img_h - overlap) / stride_h))

    # Spread the tiles evenly so the last one ends exactly at the image edge.
    step_x = (img_w - tile_w) / (n_cols - 1) if n_cols > 1 else 0
    step_y = (img_h - tile_h) / (n_rows - 1) if n_rows > 1 else 0

    grid = Grid([], tile_w, tile_h, img_w, img_h, overlap)
    for row_index in range(n_rows):
        top = max(min(int(row_index * step_y), img_h - tile_h), 0)
        bottom = min(top + tile_h, img_h)

        strip: list[Tile] = []
        for col_index in range(n_cols):
            left = max(min(int(col_index * step_x), img_w - tile_w), 0)
            right = min(left + tile_w, img_w)
            strip.append((left, tile_w, image.crop((left, top, right, bottom))))

        grid.tiles.append((top, tile_h, strip))

    return grid
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
def combine_grid(grid: Grid):
    """Reassemble the tiles of ``grid`` into a single RGB image.

    Tiles are first pasted into one strip per row, then the strips are pasted
    onto the final canvas. Wherever a tile/strip does not start at offset 0,
    its first ``grid.overlap`` pixels are pasted through a linear ramp mask
    and the remainder is pasted as-is.

    Returns:
        A new ``PIL.Image`` of size ``(grid.image_w, grid.image_h)``.
    """
    overlap = grid.overlap

    def make_mask_image(r: npt.NDArray[np.float32]) -> Image.Image:
        # Rescale 0..overlap-1 to the 0..255 range of an "L" mask.
        scaled = r * 255 / overlap
        return Image.fromarray(scaled.astype(np.uint8), "L")

    ramp = np.arange(overlap, dtype=np.float32)
    # Horizontal ramp, one tile high: blends a tile into its left neighbour.
    mask_w = make_mask_image(ramp.reshape((1, overlap)).repeat(grid.tile_h, axis=0))
    # Vertical ramp, full image width: blends a strip into the one above.
    mask_h = make_mask_image(ramp.reshape((overlap, 1)).repeat(grid.image_w, axis=1))

    canvas = Image.new("RGB", (grid.image_w, grid.image_h))
    for top, row_h, row_tiles in grid.tiles:
        strip = Image.new("RGB", (grid.image_w, row_h))
        for left, tile_w, tile in row_tiles:
            if left == 0:
                strip.paste(tile, (0, 0))
            else:
                strip.paste(tile.crop((0, 0, overlap, row_h)), (left, 0), mask=mask_w)
                strip.paste(tile.crop((overlap, 0, tile_w, row_h)), (left + overlap, 0))

        if top == 0:
            canvas.paste(strip, (0, 0))
        else:
            canvas.paste(
                strip.crop((0, 0, strip.width, overlap)),
                (0, top),
                mask=mask_h,
            )
            canvas.paste(
                strip.crop((0, overlap, strip.width, row_h)),
                (0, top + overlap),
            )

    return canvas
class UpscalerESRGAN:
    """ESRGAN (RRDBNet) 4x image upscaler loaded from a checkpoint file."""

    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        """
        Args:
            model_path: Checkpoint file readable by ``torch.load``.
            device: Device the model runs on.
            dtype: Floating-point type the model is cast to.
        """
        self.model_path = model_path
        # load_model() reads self.device (map_location), so set it first.
        self.device = device
        self.model = self.load_model(model_path)
        self.to(device, dtype)

    def __call__(self, img: Image.Image) -> Image.Image:
        """Upscale ``img`` in a single pass (no tiling)."""
        return self.upscale_without_tiling(img)

    def to(self, device: torch.device, dtype: torch.dtype):
        """Move the model to ``device``/``dtype`` and remember both."""
        self.device, self.dtype = device, dtype
        self.model.to(device=device, dtype=dtype)

    def load_model(self, path: Path) -> RRDBNet:
        """Load the checkpoint at ``path`` into a fresh ``RRDBNet`` in eval mode.

        Raises:
            AssertionError: If the checkpoint is not a 4x model.
        """
        weights: dict[str, torch.Tensor] = torch.load(path, weights_only=True, map_location=self.device)  # type: ignore
        in_nc, out_nc, nf, nb, upscale = infer_params(weights)
        assert upscale == 4, "Only 4x upscaling is supported"
        network = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
        network.load_state_dict(weights)
        network.eval()
        return network

    def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
        """Run the whole image through the model and return the result as RGB."""
        # The channel axis is reversed on the way in and again on the way out
        # (presumably RGB <-> BGR for the checkpoint -- confirm).
        pixels = np.array(img)[:, :, ::-1]
        # HWC -> CHW, scaled to 0..1.
        planes = np.ascontiguousarray(np.transpose(pixels, (2, 0, 1))) / 255
        batch = torch.from_numpy(planes).float().unsqueeze(0)  # type: ignore
        batch = batch.to(device=self.device, dtype=self.dtype)
        with torch.no_grad():
            result = self.model(batch)
        result = result.squeeze().float().cpu().clamp_(0, 1).numpy()
        # CHW -> HWC, back to 0..255 (truncating cast).
        result = (255.0 * np.moveaxis(result, 0, 2)).astype(np.uint8)
        return Image.fromarray(result[:, :, ::-1], "RGB")

    # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
    def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
        """Upscale ``img`` tile by tile and stitch the upscaled tiles together."""
        source = img.convert("RGB")
        grid = split_grid(source)

        upscaled_rows: Tiles = []
        # Taken from the last processed tile; stays 1 if the grid has no tiles.
        scale_factor: int = 1
        for top, row_h, row_tiles in grid.tiles:
            upscaled_row: list[Tile] = []
            for left, tile_w, tile in row_tiles:
                enlarged = self.upscale_without_tiling(tile)
                scale_factor = enlarged.width // tile.width
                upscaled_row.append((left * scale_factor, tile_w * scale_factor, enlarged))
            upscaled_rows.append((top * scale_factor, row_h * scale_factor, upscaled_row))

        # Same grid geometry, every measure multiplied by the scale factor.
        scaled_grid = Grid(
            upscaled_rows,
            grid.tile_w * scale_factor,
            grid.tile_h * scale_factor,
            grid.image_w * scale_factor,
            grid.image_h * scale_factor,
            grid.overlap * scale_factor,
        )
        return combine_grid(scaled_grid)
# NOTE(review): assumes refiners' UpscalerCheckpoints is itself a keyword-only
# dataclass -- confirm against the installed refiners version. Without the
# decorator the `esrgan` annotation never becomes an __init__ argument, and
# the `esrgan=...` keyword used to build CHECKPOINTS is rejected.
@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    """Upscaler checkpoint paths extended with the ESRGAN weights file."""

    # Path to the ESRGAN checkpoint consumed by UpscalerESRGAN.
    esrgan: Path
class ESRGANUpscaler(MultiUpscaler):
    """``MultiUpscaler`` that pre-upscales with ESRGAN before the parent's own step."""

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        """
        Args:
            checkpoints: Checkpoint paths, including ``esrgan``.
            device: Device for both the parent upscaler and the ESRGAN model.
            dtype: Dtype for both the parent upscaler and the ESRGAN model.
        """
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        # Built from self.device / self.dtype as set up by the parent constructor.
        self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)

    def to(self, device: torch.device, dtype: torch.dtype):
        """Move the ESRGAN model and ``self.sd`` to ``device``/``dtype``."""
        self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device, self.dtype = device, dtype

    def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
        """Upscale 4x with tiled ESRGAN, then defer to the parent for the rest."""
        esrgan_output = self.esrgan.upscale_with_tiling(image)
        # ESRGAN already contributed a factor of 4 (load_model asserts it), so
        # the parent only gets the remaining ratio.
        remaining_factor = upscale_factor / 4
        return super().pre_upscale(image=esrgan_output, upscale_factor=remaining_factor)
# Register pillow_heif's HEIF and AVIF openers with Pillow.
for _register_opener in (pillow_heif.register_heif_opener, pillow_heif.register_avif_opener):
    _register_opener()
def _hub_file(repo_id: str, filename: str, revision: str) -> Path:
    """Resolve one revision-pinned Hugging Face Hub file to a local path."""
    return Path(hf_hub_download(repo_id=repo_id, filename=filename, revision=revision))


# Every file is pinned to an explicit revision; all of them are resolved at
# import time, in the order written here.
CHECKPOINTS = ESRGANUpscalerCheckpoints(
    unet=_hub_file(
        "refiners/juggernaut.reborn.sd1_5.unet",
        "model.safetensors",
        "347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
    ),
    clip_text_encoder=_hub_file(
        "refiners/juggernaut.reborn.sd1_5.text_encoder",
        "model.safetensors",
        "744ad6a5c0437ec02ad826df9f6ede102bb27481",
    ),
    lda=_hub_file(
        "refiners/juggernaut.reborn.sd1_5.autoencoder",
        "model.safetensors",
        "3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
    ),
    controlnet_tile=_hub_file(
        "refiners/controlnet.sd1_5.tile",
        "model.safetensors",
        "48ced6ff8bfa873a8976fa467c3629a240643387",
    ),
    esrgan=_hub_file(
        "philz1337x/upscaler",
        "4x-UltraSharp.pth",
        "011deacac8270114eb7d2eeff4fe6fa9a837be70",
    ),
    negative_embedding=_hub_file(
        "philz1337x/embeddings",
        "JuggernautNegative-neg.pt",
        "203caa7e9cc2bc225031a4021f6ab1ded283454a",
    ),
    negative_embedding_key="string_to_param.*",
    loras={
        # Keys must match the names used in `loras_scale` inside `upscaler`.
        "more_details": _hub_file(
            "philz1337x/loras",
            "more_details.safetensors",
            "a3802c0280c0d00c2ab18d37454a8744c44e474e",
        ),
        "sdxl_render": _hub_file(
            "philz1337x/loras",
            "SDXLrender_v2.0.safetensors",
            "a3802c0280c0d00c2ab18d37454a8744c44e474e",
        ),
    },
)
# Pick CUDA when present, otherwise fall back to the CPU. `device` is the
# lowercase alias used by the FLUX pipeline code further down.
_cuda_available = torch.cuda.is_available()
device = DEVICE = torch.device("cuda" if _cuda_available else "cpu")
# Only probe bf16 support when CUDA exists: on some torch builds
# `is_bf16_supported()` queries the current CUDA device and raises on
# CPU-only hosts instead of returning False.
DTYPE = torch.bfloat16 if _cuda_available and torch.cuda.is_bf16_supported() else torch.float32
# Shared upscaler instance used by `upscaler()`.
enhancer = ESRGANUpscaler(checkpoints=CHECKPOINTS, device=DEVICE, dtype=DTYPE)
# logging
# Silence Python warnings and route WARNING-and-above log records from every
# logger to stderr using a visually separated, single-record format.
warnings.filterwarnings("ignore")

formatter = logging.Formatter('\n >>> [%(levelname)s] %(asctime)s %(name)s: %(message)s\n')

handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(formatter)
handler.setLevel(logging.WARN)

root = logging.getLogger()
root.addHandler(handler)
root.setLevel(logging.WARN)
# constant data
# Hub id of the text-to-image checkpoint loaded into `image_pipe`.
base = "black-forest-labs/FLUX.1-schnell"

# precision data
seq = 256  # passed as max_sequence_length to the image pipeline
width = height = 1536  # generated image size in pixels (also used in the CSS aspect ratio)
image_steps = 8  # passed as num_inference_steps to the image pipeline
img_accu = 0  # passed as guidance_scale to the image pipeline
# ui data | |
css="".join([""" | |
input, input::placeholder { | |
text-align: center !important; | |
} | |
*, *::placeholder { | |
font-family: Suez One !important; | |
} | |
h1,h2,h3,h4,h5,h6 { | |
width: 100%; | |
text-align: center; | |
} | |
footer { | |
display: none !important; | |
} | |
.image-container { | |
aspect-ratio: """,str(width),"/",str(height),""" !important; | |
} | |
.dropdown-arrow { | |
display: none !important; | |
} | |
*:has(>.btn) { | |
display: flex; | |
justify-content: space-evenly; | |
align-items: center; | |
} | |
.btn { | |
display: flex; | |
} | |
"""]) | |
js=""" | |
function custom(){ | |
document.querySelector("div#prompt input").addEventListener("keydown",function(e){ | |
e.target.setAttribute("last_value",e.target.value); | |
}); | |
document.querySelector("div#prompt input").addEventListener("input",function(e){ | |
if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){ | |
e.target.value = e.target.getAttribute("last_value"); | |
e.target.removeAttribute("last_value"); | |
} | |
}); | |
document.querySelector("div#prompt2 input").addEventListener("keydown",function(e){ | |
e.target.setAttribute("last_value",e.target.value); | |
}); | |
document.querySelector("div#prompt2 input").addEventListener("input",function(e){ | |
if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){ | |
e.target.value = e.target.getAttribute("last_value"); | |
e.target.removeAttribute("last_value"); | |
} | |
}); | |
} | |
""" | |
# torch pipes
# FLUX text-to-image pipeline, loaded once at import time in bfloat16 and
# configured with model CPU offload plus sliced and tiled VAE processing.
image_pipe = FluxPipeline.from_pretrained(base, torch_dtype=torch.bfloat16)
image_pipe = image_pipe.to(device)
for _enable_option in (
    image_pipe.enable_model_cpu_offload,
    image_pipe.enable_vae_slicing,
    image_pipe.enable_vae_tiling,
):
    _enable_option()
# functionality | |
def upscaler(
    input_image: Image.Image,
    prompt: str = "Photorealistic, Hyperrealistic, Realistic Photography, High-Quality Photography, Natural.",
    negative_prompt: str = "Distorted, Discontinuous, Blurry, Doll-Like, Overly-Plastic, Low-Quality, Painted, Smoothed, Artificial, Phony, Gaudy, Digital Effects.",
    seed: int | None = None,
    upscale_factor: int = 2,
    controlnet_scale: float = 0.6,
    controlnet_decay: float = 1.0,
    condition_scale: int = 6,
    tile_width: int = 112,
    tile_height: int = 144,
    denoise_strength: float = 0.35,
    num_inference_steps: int = 30,
    solver: str = "DDIM",
) -> Image.Image:
    """Upscale ``input_image`` with the shared ``enhancer``.

    Args:
        input_image: Picture to enhance.
        prompt: Positive prompt forwarded to the enhancer.
        negative_prompt: Negative prompt forwarded to the enhancer.
        seed: Seed given to ``manual_seed``. ``None`` (the default) draws a
            fresh 32-bit seed on every call.
        upscale_factor: Overall enlargement factor.
        controlnet_scale: Forwarded as ``controlnet_scale``.
        controlnet_decay: Forwarded as ``controlnet_scale_decay``.
        condition_scale: Forwarded as ``condition_scale``.
        tile_width: Tile width; forwarded as the second item of ``tile_size``.
        tile_height: Tile height; forwarded as the first item of ``tile_size``.
        denoise_strength: Forwarded as ``denoise_strength``.
        num_inference_steps: Forwarded as ``num_inference_steps``.
        solver: Name of a solver class looked up on ``solvers``.

    Returns:
        The enhanced image.

    Raises:
        AttributeError: If ``solver`` names no attribute of ``solvers``.
    """
    log('CALL upscaler')

    if seed is None:
        # A default computed in the signature is evaluated once at import, so
        # every call would share one seed. Draw it per call instead, from the
        # OS entropy source so earlier manual_seed() calls cannot make it
        # repeat, and keep it within 32 bits, which every common seeding API
        # accepts.
        seed = random.SystemRandom().getrandbits(32)
    manual_seed(seed)

    solver_type: type[Solver] = getattr(solvers, solver)

    log('DBG upscaler 1')

    enhanced_image = enhancer.upscale(
        image=input_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        upscale_factor=upscale_factor,
        controlnet_scale=controlnet_scale,
        controlnet_scale_decay=controlnet_decay,
        condition_scale=condition_scale,
        tile_size=(tile_height, tile_width),
        denoise_strength=denoise_strength,
        num_inference_steps=num_inference_steps,
        # Keys refer to the LoRAs declared in CHECKPOINTS.
        loras_scale={"more_details": 0.5, "sdxl_render": 1.0},
        solver_type=solver_type,
    )

    log('RET upscaler')

    return enhanced_image
def get_tensor_length(tensor):
    """Return the total number of elements of ``tensor``.

    Computed as the product of the dimensions reported by ``tensor.size()``;
    an empty size yields 1.
    """
    return math.prod(tensor.size())
def summarize(
    text, max_len=20, min_len=10
):
    """Shorten ``text`` with the T5 model until it fits in ``max_len`` tokens.

    The tokenised input is summarised 512 tokens at a time: the first 512
    tokens are replaced by their generated summary and the loop repeats while
    the sequence is still longer than ``max_len``.

    Args:
        text: Text to summarise.
        max_len: Token budget; also the lower bound of the generation
            ``max_length``.
        min_len: Minimum generated length per pass.

    Returns:
        The summary as plain text (special tokens removed). Text that already
        fits the budget is returned unchanged.
    """
    log('CALL summarize_text')

    # `max_length` must be an int; with truncation disabled it is not needed.
    inputs = tokenizer.encode("summarize: " + text, return_tensors="pt", truncation=False)

    if get_tensor_length(inputs) <= max_len:
        # Nothing to shorten. Decoding `inputs` here would hand back the
        # "summarize: " task prefix and the end-of-sequence marker.
        log(f'RET summarize_text with summary as {text}')
        return text

    i = 1
    while get_tensor_length(inputs) > max_len:
        print(f'DBG summarize_text 1 {i}')
        total = get_tensor_length(inputs)

        # NOTE(review): num_beams grows with the input length (at least 8),
        # which is very costly for long lyrics -- confirm this is intended.
        outputs = model.generate(
            inputs[:, :512],  # keep the batch dimension: generate() expects (batch, seq)
            length_penalty=2.0,
            num_beams=max(8, total),
            early_stopping=True,
            max_length=max(total // 4, max_len),
            min_length=min_len,
        )

        # Summary of the first 512 tokens followed by the untouched remainder.
        shortened = torch.cat([outputs[0], inputs[0][512:]]).unsqueeze(0)
        made_progress = get_tensor_length(shortened) < total
        inputs = shortened
        if not made_progress:
            # The pass did not shrink the sequence; stop instead of looping forever.
            break
        i = i + 1

    # Drop <pad>/</s> markers so they do not end up in the image prompt.
    summary = tokenizer.decode(inputs[0], skip_special_tokens=True)
    log(f'RET summarize_text with summary as {summary}')
    return summary
def generate_random_string(length):
    """Return ``length`` characters drawn at random from ASCII letters and digits."""
    alphabet = ascii_letters + digits
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return ''.join(picked)
def pipe_generate_image(p1, p2):
    """Generate images with the FLUX pipeline.

    Args:
        p1: Positive prompt.
        p2: Negative prompt.

    Returns:
        The ``images`` list of the pipeline output (one image per call, as
        ``num_images_per_prompt`` is 1).
    """
    log('CALL pipe_generate')

    # Fresh seed for every call, from the OS entropy source:
    # - parsing the digits of `random.random()` raises ValueError whenever the
    #   float prints in scientific notation (values below 1e-4);
    # - the global `random` state is reseeded by `manual_seed` in `upscaler`,
    #   which would make consecutive requests reuse the same seed.
    noise_seed = random.SystemRandom().getrandbits(63)
    generator = torch.Generator(device).manual_seed(noise_seed)

    imgs = image_pipe(
        prompt=p1,
        negative_prompt=p2,
        height=height,
        width=width,
        guidance_scale=img_accu,
        num_images_per_prompt=1,
        num_inference_steps=image_steps,
        max_sequence_length=seq,
        generator=generator,
    ).images

    log('RET pipe_generate')
    return imgs
def add_song_cover_text(img, artist, song, height, width):
    """Draw the song title and the artist name onto ``img`` and return it.

    Both labels are centred horizontally. The song title sits above the
    vertical centre and the artist below it, each offset by a third of half
    the height. The title is drawn in translucent white with a black outline,
    the artist in translucent black with a white outline.

    Args:
        img: Image to draw on (modified in place).
        artist: Artist label.
        song: Song label.
        height: Height used for the layout computations.
        width: Width used for the layout computations.
    """
    draw = ImageDraw.Draw(img, mode="RGBA")
    rows = 1
    labels_distance = 1 / 3

    # Both labels share one font size, stroke width and base line.
    textheight = min(math.ceil(width / 10), math.ceil(height / 5))
    font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
    outline = math.ceil(textheight / 20)
    middle_y = height - (textheight * rows / 2) - (height / 2)
    shift = height / 2 * labels_distance

    # (text, direction of the vertical shift, fill colour, outline colour)
    labels = (
        (song, -1, (255, 255, 255, 85), (0, 0, 0, 170)),
        (artist, 1, (0, 0, 0, 85), (255, 255, 255, 170)),
    )
    for label, direction, fill, outline_fill in labels:
        textwidth = draw.textlength(label, font)
        x = math.ceil((width - textwidth) / 2)
        y = math.ceil(middle_y + direction * shift)
        draw.text((x, y), label, fill, font=font, spacing=2, stroke_width=outline, stroke_fill=outline_fill)

    return img
def all_pipes(pos, neg, artist, song):
    """Generate images for the prompts and upscale each of them.

    Args:
        pos: Positive prompt.
        neg: Negative prompt.
        artist: Accepted but not used by this function.
        song: Accepted but not used by this function.

    Returns:
        The list returned by ``pipe_generate_image``, with every entry
        replaced in place by its upscaled version.
    """
    imgs = pipe_generate_image(pos, neg)
    for slot, generated in enumerate(imgs):
        imgs[slot] = upscaler(generated)
    return imgs
def translate(txt, to_lang="en", from_lang=False):
    """Translate ``txt`` with the T5 model.

    Args:
        txt: Text to translate.
        to_lang: Target language, as written into the T5 task prefix.
        from_lang: Source language; detected from ``txt`` when falsy.

    Returns:
        The translated text, or ``txt`` unchanged when it is blank or already
        in the target language.
    """
    log('CALL translate')

    if not txt.strip():
        # Language detection cannot work on blank text; nothing to translate.
        log(f'RET translate with txt as {txt}')
        return txt

    if not from_lang:
        from_lang = get_language(txt)
    if from_lang == to_lang:
        log(f'RET translate with txt as {txt}')
        return txt

    # NOTE(review): the prefix is built from whatever `get_language` returns;
    # confirm the model understands that form of language name.
    inputs = tokenizer.encode(
        f"translate {from_lang} to {to_lang}: " + txt,
        return_tensors="pt",
        truncation=False,
    )

    chunk_count = math.ceil(get_tensor_length(inputs) / 512)
    pieces = []
    for index in range(chunk_count):
        # Slice on the sequence axis only: generate() expects (batch, seq).
        chunk = inputs[:, index * 512:(index + 1) * 512]
        # Without an explicit budget the library default output length applies
        # (20 tokens when the model config sets none), which would cut
        # translations short.
        generated = model.generate(chunk, max_new_tokens=512)
        # Strip <pad>/</s> so they do not leak into the prompt.
        piece = tokenizer.decode(generated[0], skip_special_tokens=True)
        if piece:
            pieces.append(piece)
    ret = " ".join(pieces)

    log(f'RET translate with ret as {ret}')
    return ret
def handle_generation(artist, song, genre, lyrics):
    """Build the prompts, generate the cover image(s) and save them as PNG.

    Args:
        artist: Artist name as typed by the user.
        song: Song name as typed by the user.
        genre: Genre as typed by the user.
        lyrics: Lyrics as typed by the user (may be empty).

    Returns:
        File name of the first saved image.
    """
    log('CALL handle_generate')

    def collapse_whitespace(value):
        return re.sub("([ \t\n]){1,}", " ", value)

    def capitalise_words(value):
        return ' '.join(word[0].upper() + word[1:] for word in value.split())

    # Escape the punctuation set: unescaped, its "\]" is read as an escaped
    # bracket, so a backslash in the input would never be removed.
    punctuation_class = f'[{re.escape(punctuation)}]'

    pos_artist = collapse_whitespace(artist).upper().strip()
    pos_song = capitalise_words(collapse_whitespace(song).lower().strip())
    pos_genre = capitalise_words(re.sub(punctuation_class, '', collapse_whitespace(genre)).lower().strip())
    pos_lyrics = re.sub(punctuation_class, '', collapse_whitespace(lyrics)).lower().strip()
    pos_lyrics_sum = summarize(pos_lyrics) if pos_lyrics else ""

    neg = "Sexuality, Humanity, Textual, Labeled, Distorted, Discontinuous, Blurry, Doll-Like, Overly Plastic, Low-Quality, Painted, Smoothed, Artificial, Phony, Gaudy, Digital Effects."

    # Empty fields are not sent to `translate`: there is nothing to translate
    # and language detection needs actual text.
    genre_part = translate(pos_genre) if pos_genre else ""
    song_part = translate(pos_song) if pos_song else ""
    lyrics_part = (": " + translate(pos_lyrics_sum)) if pos_lyrics_sum else ""
    pos = f'HQ Hyper-realistic {genre_part} song "{song_part}"{lyrics_part}.'

    print(f"""
Positive: {pos}
Negative: {neg}
""")

    imgs = all_pipes(pos, neg, pos_artist, pos_song)

    # Images come back from `upscaler` at its default factor of 2, so the text
    # layout uses twice the generation size.
    scaled_by = 2
    # The names come straight from the user: replace path separators and other
    # characters that are invalid in file names so the image is always written
    # next to the app, under a valid name.
    file_stem = re.sub(r'[\\/:*?"<>|\x00-\x1f]', '_', f'{artist} - {song}')

    names = []
    for index, img in enumerate(imgs, start=1):
        labeled_img = add_song_cover_text(img, artist, song, height * scaled_by, width * scaled_by)
        name = f'{file_stem} ({index}).png'
        labeled_img.save(name)
        names.append(name)

    # Only the first image is shown in the UI.
    return names[0]
# entry
if __name__ == "__main__":
    # NOTE(review): the nesting of the widgets below is reconstructed; in
    # particular, confirm which container the "Generate" button belongs to.

    def _single_line_box(hint):
        """Create an empty, container-less, one-line textbox with a placeholder."""
        return gr.Textbox(
            placeholder=hint,
            value="",
            container=False,
            max_lines=1
        )

    with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
        gr.Markdown("""
# Song Cover Image Generator
""")
        with gr.Row():
            # Left: the four text inputs. Right: the generated cover.
            with gr.Column(scale=4):
                artist = _single_line_box("Artist name")
                song = _single_line_box("Song name")
                genre = _single_line_box("Genre")
                lyrics = _single_line_box("Lyrics")
            with gr.Column():
                cover = gr.Image(
                    interactive=False,
                    container=False,
                    elem_classes="image-container",
                    label="Result",
                    show_label=True,
                    type='filepath',
                    show_share_button=False,
                )
        run = gr.Button("Generate", elem_classes="btn")

        # handle_generation returns the saved file name, which the image
        # component (type='filepath') displays.
        run.click(
            fn=handle_generation,
            inputs=[artist, song, genre, lyrics],
            outputs=[cover]
        )
    demo.queue().launch()