Spaces:
Sleeping
Sleeping
import requests | |
from bs4 import BeautifulSoup | |
from abc import ABC, abstractmethod | |
from pathlib import Path | |
from langdetect import detect as get_language | |
from typing import List, Optional, Union | |
from collections import namedtuple | |
from inspect import signature | |
import os | |
import subprocess | |
import logging | |
import re | |
import random | |
from string import ascii_letters, digits, punctuation | |
import requests | |
import sys | |
import warnings | |
import time | |
import asyncio | |
import math | |
from pathlib import Path | |
from functools import partial | |
from dataclasses import dataclass | |
from typing import Any | |
import pillow_heif | |
import spaces | |
import numpy as np | |
import numpy.typing as npt | |
import torch | |
from torch import nn | |
import gradio as gr | |
from lxml.html import fromstring | |
from huggingface_hub import hf_hub_download | |
from safetensors.torch import load_file, save_file | |
from diffusers import FluxPipeline | |
from PIL import Image, ImageDraw, ImageFont | |
from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer | |
from refiners.fluxion.utils import manual_seed | |
from refiners.foundationals.latent_diffusion import Solver, solvers | |
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import ( | |
MultiUpscaler, | |
UpscalerCheckpoints, | |
) | |
from datetime import datetime | |
# Load the T5-large checkpoint and its tokenizer once at import time.
# Both objects are module-level globals read by `summarize`.
model = T5ForConditionalGeneration.from_pretrained("t5-large")
tokenizer = T5Tokenizer.from_pretrained("t5-large")
def log(msg):
    """Print ``msg`` to stdout, prefixed with the current local time of day."""
    stamp = datetime.now().time()
    print(f'{stamp} {msg}')
# One tile of a split image: (offset along its axis, tile size along that axis, tile image).
Tile = tuple[int, int, Image.Image]
# One entry per row of tiles: (row y offset, row height, tiles of that row).
Tiles = list[tuple[int, int, list[Tile]]]
def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
    """Return a 3x3 convolution (padding 1) followed by an in-place LeakyReLU(0.2).

    Args:
        in_nc: number of input channels of the convolution.
        out_nc: number of output channels of the convolution.
    """
    convolution = nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1)
    activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(convolution, activation)
class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block with five convolutions.

    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
    - "Partial Convolution based Padding" arXiv:1811.11718
    - "Spectral normalization" arXiv:1802.05957
    - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
    {Rakotonirina} and A. {Rasoanaivo}

    Each convolution receives the block input concatenated (along dim 1) with
    the outputs of all previous convolutions; the last one maps back to ``nf``
    channels and is added to the input scaled by 0.2.
    """

    def __init__(self, nf: int = 64, gc: int = 32) -> None:
        """
        Args:
            nf: channel count of the block input and output.
            gc: channel count produced by each of the first four convolutions.
        """
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.conv1 = conv_block(nf, gc)
        self.conv2 = conv_block(nf + gc, gc)
        self.conv3 = conv_block(nf + 2 * gc, gc)
        self.conv4 = conv_block(nf + 3 * gc, gc)
        # Wrapped in Sequential because of key in state dict.
        final_conv = nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1)
        self.conv5 = nn.Sequential(final_conv)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the densely connected convolutions and the scaled residual."""
        features = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            # The first convolution sees the raw input; later ones see the
            # input plus every intermediate feature map produced so far.
            stacked = features[0] if len(features) == 1 else torch.cat(tuple(features), 1)
            features.append(conv(stacked))
        fused = self.conv5(torch.cat(tuple(features), 1))
        return fused * 0.2 + x
class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)

    Chains three ``ResidualDenseBlock_5C`` modules and adds the result,
    scaled by 0.2, back onto the input.
    """

    def __init__(self, nf: int) -> None:
        """
        Args:
            nf: channel count passed to each inner dense block.
        """
        super().__init__()  # type: ignore[reportUnknownMemberType]
        # Attribute names are part of the checkpoint's state-dict keys.
        self.RDB1 = ResidualDenseBlock_5C(nf)
        self.RDB2 = ResidualDenseBlock_5C(nf)
        self.RDB3 = ResidualDenseBlock_5C(nf)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the three dense blocks in sequence and apply the outer residual."""
        hidden = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            hidden = dense_block(hidden)
        return hidden * 0.2 + x
class Upsample2x(nn.Module):
    """Upsample 2x.

    Parameter-free module that doubles the spatial size of its input with
    ``nn.functional.interpolate`` (default interpolation mode).
    """

    def __init__(self) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` interpolated with a scale factor of 2."""
        upsampled = nn.functional.interpolate(x, scale_factor=2.0)  # type: ignore
        return upsampled
class ShortcutBlock(nn.Module):
    """Elementwise sum the output of a submodule to its input"""

    def __init__(self, submodule: nn.Module) -> None:
        """
        Args:
            submodule: module whose output is added to its own input; stored as
                ``sub`` (the attribute name appears in state-dict keys).
        """
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.sub = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + sub(x)``."""
        branch = self.sub(x)
        return x + branch
class RRDBNet(nn.Module):
    """ESRGAN-style generator: RRDB trunk followed by two 2x upsampling stages.

    The layers live in a single ``nn.Sequential`` stored as ``model``; the
    positions inside that sequence determine the state-dict keys
    (``model.0``, ``model.1.sub.<i>``, ``model.3`` ...), so the order below
    must not change.
    """

    def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
        """
        Args:
            in_nc: input channel count (asserted not to be a multiple of 4).
            out_nc: output channel count of the last convolution.
            nf: feature channel count used throughout the network.
            nb: number of ``RRDB`` blocks in the trunk.
        """
        super().__init__()  # type: ignore[reportUnknownMemberType]
        assert in_nc % 4 != 0  # in_nc is 3

        def lrelu() -> nn.LeakyReLU:
            return nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Built in the same order as the final sequence so that parameter
        # initialisation order is unchanged.
        head = nn.Conv2d(in_nc, nf, kernel_size=3, padding=1)
        trunk = ShortcutBlock(
            nn.Sequential(
                *(RRDB(nf) for _ in range(nb)),
                nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            )
        )
        layers: list[nn.Module] = [head, trunk]
        for _ in range(2):
            # One 2x upsampling stage: interpolate, convolve, activate.
            layers.append(Upsample2x())
            layers.append(nn.Conv2d(nf, nf, kernel_size=3, padding=1))
            layers.append(lrelu())
        layers.append(nn.Conv2d(nf, nf, kernel_size=3, padding=1))
        layers.append(lrelu())
        layers.append(nn.Conv2d(nf, out_nc, kernel_size=3, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the full network on ``x``."""
        return self.model(x)
def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
    """Derive the ``RRDBNet`` hyper-parameters from a checkpoint's key names.

    Args:
        state_dict: mapping from dotted parameter names to tensors; must
            contain ``"model.0.weight"``.

    Returns:
        ``(in_nc, out_nc, nf, nb, scale)``.

    Raises:
        AssertionError: if a key contains ``"conv1x1"``, or if no output
            channel count / block count could be derived.
    """
    # this code is adapted from https://github.com/victorca25/iNNfer
    upsample_convs = 0
    min_scale_index = 6
    highest_index = 0
    out_nc = 0
    nb = 0
    for key in list(state_dict):
        pieces = key.split(".")
        if len(pieces) == 5 and pieces[2] == "sub":
            # e.g. "model.1.sub.<i>.weight": the last such key seen gives nb.
            nb = int(pieces[3])
        elif len(pieces) == 3:
            layer_index = int(pieces[1])
            is_late_weight = (
                layer_index > min_scale_index and pieces[0] == "model" and pieces[2] == "weight"
            )
            if is_late_weight:
                upsample_convs += 1
            if layer_index > highest_index:
                # The highest-indexed top-level layer provides out_nc.
                highest_index = layer_index
                out_nc = state_dict[key].shape[0]
        assert "conv1x1" not in key  # no ESRGANPlus
    first_weight_shape = state_dict["model.0.weight"].shape
    nf = first_weight_shape[0]
    in_nc = first_weight_shape[1]
    scale = 2**upsample_convs
    assert out_nc > 0
    assert nb > 0
    return in_nc, out_nc, nf, nb, scale  # 3, 3, 64, 23, 4
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])


# adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
    """Cut ``image`` into rows of overlapping tiles.

    Args:
        image: image to split.
        tile_w: nominal tile width in pixels.
        tile_h: nominal tile height in pixels.
        overlap: minimum number of pixels shared by neighbouring tiles.

    Returns:
        A ``Grid`` whose ``tiles`` list holds one ``(y, tile_h, row)`` entry
        per row, each ``row`` being a list of ``(x, tile_w, tile_image)``.
        Crops are clamped to the image, so tiles of an image smaller than the
        nominal tile size are smaller than ``tile_w`` x ``tile_h``.
    """
    width, height = image.width, image.height

    step_w = tile_w - overlap
    step_h = tile_h - overlap

    n_cols = max(1, math.ceil((width - overlap) / step_w))
    n_rows = max(1, math.ceil((height - overlap) / step_h))

    # Distance between consecutive tile origins so that the last tile ends
    # exactly at the image border.
    stride_x = (width - tile_w) / (n_cols - 1) if n_cols > 1 else 0
    stride_y = (height - tile_h) / (n_rows - 1) if n_rows > 1 else 0

    grid = Grid([], tile_w, tile_h, width, height, overlap)
    for row_index in range(n_rows):
        top = max(min(int(row_index * stride_y), height - tile_h), 0)
        bottom = min(top + tile_h, height)
        row_tiles: list[Tile] = []
        for col_index in range(n_cols):
            left = max(min(int(col_index * stride_x), width - tile_w), 0)
            right = min(left + tile_w, width)
            row_tiles.append((left, tile_w, image.crop((left, top, right, bottom))))
        grid.tiles.append((top, tile_h, row_tiles))
    return grid
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
def combine_grid(grid: Grid) -> Image.Image:
    """Reassemble the tiles of ``grid`` into one RGB image.

    Tiles are first pasted left-to-right into a row image, then rows are
    pasted top-to-bottom into the result. For every tile/row that does not
    start at offset 0, its first ``grid.overlap`` pixels are pasted through a
    linear 0..255 ramp mask so that they blend with what is already there; the
    remainder is pasted opaquely.

    @param grid: a ``Grid`` as produced by ``split_grid`` (possibly rescaled)
    @return: the combined ``grid.image_w`` x ``grid.image_h`` image
    """
    def make_mask_image(r: npt.NDArray[np.float32]) -> Image.Image:
        # Scale the 0..overlap-1 ramp to the 0..255 range of an "L" mask.
        r = r * 255 / grid.overlap
        return Image.fromarray(r.astype(np.uint8), "L")

    # Horizontal ramp (overlap wide, tile_h high) used between tiles of a row.
    mask_w = make_mask_image(
        np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)
    )
    # Vertical ramp (image_w wide, overlap high) used between rows.
    mask_h = make_mask_image(
        np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)
    )

    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        combined_row = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                # Leftmost tile: nothing to blend with.
                combined_row.paste(tile, (0, 0))
                continue

            combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
            combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))

        if y == 0:
            # Topmost row: nothing to blend with.
            combined_image.paste(combined_row, (0, 0))
            continue

        combined_image.paste(
            combined_row.crop((0, 0, combined_row.width, grid.overlap)),
            (0, y),
            mask=mask_h,
        )
        combined_image.paste(
            combined_row.crop((0, grid.overlap, combined_row.width, h)),
            (0, y + grid.overlap),
        )

    return combined_image
class UpscalerESRGAN:
    """Thin wrapper that loads an ``RRDBNet`` checkpoint and upscales PIL images."""

    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        """
        Args:
            model_path: path of the checkpoint handed to ``torch.load``.
            device: device the weights are mapped to and inference runs on.
            dtype: dtype the model is converted to after loading.
        """
        self.model_path = model_path
        # load_model() reads self.device for map_location, so set it first.
        self.device = device
        self.model = self.load_model(model_path)
        self.to(device, dtype)

    def __call__(self, img: Image.Image) -> Image.Image:
        """Upscale ``img`` in a single pass (no tiling)."""
        return self.upscale_without_tiling(img)

    def to(self, device: torch.device, dtype: torch.dtype):
        """Move the model to ``device``/``dtype`` and remember both."""
        self.device = device
        self.dtype = dtype
        self.model.to(device=device, dtype=dtype)

    def load_model(self, path: Path) -> RRDBNet:
        """Load the checkpoint at ``path`` into a freshly built ``RRDBNet`` in eval mode.

        Raises:
            AssertionError: if the inferred upscale factor is not 4.
        """
        weights: dict[str, torch.Tensor] = torch.load(path, weights_only=True, map_location=self.device)  # type: ignore
        in_nc, out_nc, nf, nb, upscale = infer_params(weights)
        assert upscale == 4, "Only 4x upscaling is supported"
        network = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
        network.load_state_dict(weights)
        network.eval()
        return network

    def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
        """Run the network on the whole image and return the result as RGB.

        The channel axis is reversed before and after inference and pixel
        values are scaled to [0, 1] for the network.
        """
        pixels = np.array(img)
        reversed_channels = pixels[:, :, ::-1]
        channels_first = np.ascontiguousarray(np.transpose(reversed_channels, (2, 0, 1))) / 255
        batch = torch.from_numpy(channels_first).float()  # type: ignore
        batch = batch.unsqueeze(0).to(device=self.device, dtype=self.dtype)
        with torch.no_grad():
            prediction = self.model(batch)
        result = prediction.squeeze().float().cpu().clamp_(0, 1).numpy()
        result = 255.0 * np.moveaxis(result, 0, 2)
        result = result.astype(np.uint8)
        result = result[:, :, ::-1]
        return Image.fromarray(result, "RGB")

    # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
    def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
        """Upscale ``img`` tile by tile and blend the upscaled tiles back together."""
        img = img.convert("RGB")
        grid = split_grid(img)
        upscaled_rows: Tiles = []
        # Stays 1 only if the grid has no tiles; otherwise taken from the last tile.
        scale_factor: int = 1

        for y, h, row in grid.tiles:
            upscaled_row: list[Tile] = []
            for x, w, tile in row:
                upscaled_tile = self.upscale_without_tiling(tile)
                scale_factor = upscaled_tile.width // tile.width
                upscaled_row.append((x * scale_factor, w * scale_factor, upscaled_tile))
            upscaled_rows.append((y * scale_factor, h * scale_factor, upscaled_row))

        # Same grid geometry, expressed in upscaled pixels.
        scaled_grid = Grid(
            upscaled_rows,
            grid.tile_w * scale_factor,
            grid.tile_h * scale_factor,
            grid.image_w * scale_factor,
            grid.image_h * scale_factor,
            grid.overlap * scale_factor,
        )
        return combine_grid(scaled_grid)
# The decorator is required: without it `esrgan` is only an annotation, so the
# `esrgan=` keyword used to build CHECKPOINTS would be rejected and
# `checkpoints.esrgan` would not exist. `kw_only=True` lets this field, which
# has no default, follow inherited fields that do.
# NOTE(review): assumes `UpscalerCheckpoints` is itself a keyword-only
# dataclass (as in refiners' multi_upscaler) - confirm against the installed version.
@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    """Upscaler checkpoints extended with the path of the ESRGAN weights."""

    # Path of the ESRGAN checkpoint loaded by `UpscalerESRGAN`.
    esrgan: Path
class ESRGANUpscaler(MultiUpscaler):
    """``MultiUpscaler`` that pre-upscales the input with ESRGAN before the parent's own pre-upscale."""

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        """
        Args:
            checkpoints: checkpoint paths, including ``esrgan``.
            device: device for both the parent upscaler and the ESRGAN model.
            dtype: dtype for both the parent upscaler and the ESRGAN model.
        """
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        esrgan_weights = checkpoints.esrgan
        self.esrgan = UpscalerESRGAN(esrgan_weights, device=self.device, dtype=self.dtype)

    def to(self, device: torch.device, dtype: torch.dtype):
        """Move the ESRGAN model and ``self.sd`` to ``device``/``dtype``."""
        self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype

    def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
        """Upscale with ESRGAN, then delegate to the parent with the factor divided by 4."""
        esrgan_output = self.esrgan.upscale_with_tiling(image)
        # ESRGAN already contributed a 4x enlargement (see UpscalerESRGAN.load_model).
        remaining_factor = upscale_factor / 4
        return super().pre_upscale(image=esrgan_output, upscale_factor=remaining_factor)
# Let PIL open HEIF and AVIF images.
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

# Download (or reuse from the local Hugging Face cache) every checkpoint the
# upscaler needs. Each download is pinned to an explicit revision.
CHECKPOINTS = ESRGANUpscalerCheckpoints(
    unet=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.unet",
            filename="model.safetensors",
            revision="347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
        )
    ),
    clip_text_encoder=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.text_encoder",
            filename="model.safetensors",
            revision="744ad6a5c0437ec02ad826df9f6ede102bb27481",
        )
    ),
    lda=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.autoencoder",
            filename="model.safetensors",
            revision="3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
        )
    ),
    controlnet_tile=Path(
        hf_hub_download(
            repo_id="refiners/controlnet.sd1_5.tile",
            filename="model.safetensors",
            revision="48ced6ff8bfa873a8976fa467c3629a240643387",
        )
    ),
    esrgan=Path(
        hf_hub_download(
            repo_id="philz1337x/upscaler",
            filename="4x-UltraSharp.pth",
            revision="011deacac8270114eb7d2eeff4fe6fa9a837be70",
        )
    ),
    negative_embedding=Path(
        hf_hub_download(
            repo_id="philz1337x/embeddings",
            filename="JuggernautNegative-neg.pt",
            revision="203caa7e9cc2bc225031a4021f6ab1ded283454a",
        )
    ),
    negative_embedding_key="string_to_param.*",
    # The keys below are the names used in `loras_scale` when upscaling.
    loras={
        "more_details": Path(
            hf_hub_download(
                repo_id="philz1337x/loras",
                filename="more_details.safetensors",
                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
            )
        ),
        "sdxl_render": Path(
            hf_hub_download(
                repo_id="philz1337x/loras",
                filename="SDXLrender_v2.0.safetensors",
                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
            )
        )
    }
)

# Prefer CUDA when available; `device` and `DEVICE` are the same object.
device = DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): is_bf16_supported() is called even when CUDA is unavailable;
# confirm it returns False (rather than raising) on CPU-only hosts.
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
# Shared upscaler instance used by `upscaler`.
enhancer = ESRGANUpscaler(checkpoints=CHECKPOINTS, device=DEVICE, dtype=DTYPE)
# logging
# Suppress all Python warnings, and send WARNING-and-above log records from
# the root logger to stderr using a custom line format.
warnings.filterwarnings("ignore")
root = logging.getLogger()
root.setLevel(logging.WARN)
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.WARN)
formatter = logging.Formatter('\n >>> [%(levelname)s] %(asctime)s %(name)s: %(message)s\n')
handler.setFormatter(formatter)
root.addHandler(handler)
# constant data
# Model id loaded by FluxPipeline.from_pretrained.
base = "black-forest-labs/FLUX.1-schnell"
# precision data
# Passed to the image pipeline as max_sequence_length.
seq=256
# Generated image size in pixels; also used for the CSS aspect ratio.
width=768
height=768
# Passed to the image pipeline as num_inference_steps.
image_steps=8
# Passed to the image pipeline as guidance_scale.
img_accu=0
# ui data
# Stylesheet handed to gr.Blocks; the image container's aspect ratio is built
# from the `width` and `height` constants.
css="".join(["""
input, input::placeholder {
text-align: center !important;
}
*, *::placeholder {
font-family: Suez One !important;
}
h1,h2,h3,h4,h5,h6 {
width: 100%;
text-align: center;
}
footer {
display: none !important;
}
.image-container {
aspect-ratio: """,str(width),"/",str(height),""" !important;
}
.dropdown-arrow {
display: none !important;
}
*:has(>.btn) {
display: flex;
justify-content: space-evenly;
align-items: center;
}
.btn {
display: flex;
}
"""])
# Client-side script that reverts an input's value when it contains anything
# other than letters, spaces and commas, or two consecutive separators.
# NOTE(review): `js` is not referenced anywhere in the visible file (gr.Blocks
# receives only `css`), and it targets `div#prompt` / `div#prompt2`, ids that
# the visible UI never assigns - confirm whether it is still needed.
js="""
function custom(){
document.querySelector("div#prompt input").addEventListener("keydown",function(e){
e.target.setAttribute("last_value",e.target.value);
});
document.querySelector("div#prompt input").addEventListener("input",function(e){
if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
e.target.value = e.target.getAttribute("last_value");
e.target.removeAttribute("last_value");
}
});
document.querySelector("div#prompt2 input").addEventListener("keydown",function(e){
e.target.setAttribute("last_value",e.target.value);
});
document.querySelector("div#prompt2 input").addEventListener("input",function(e){
if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
e.target.value = e.target.getAttribute("last_value");
e.target.removeAttribute("last_value");
}
});
}
"""
# torch pipes
# Load the FLUX pipeline in bfloat16, move it to `device`, then enable the
# pipeline's CPU-offload, VAE-slicing and VAE-tiling options.
# NOTE(review): the pipeline is moved to `device` and CPU offload is enabled
# right after - confirm both are intended together.
image_pipe = FluxPipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to(device)
image_pipe.enable_model_cpu_offload()
image_pipe.enable_vae_slicing()
image_pipe.enable_vae_tiling()
# functionality | |
def upscaler(
    input_image: Image.Image,
    prompt: str = "Photorealistic, Hyperrealistic, Realistic Photography, High-Quality Photography, Natural.",
    negative_prompt: str = "Distorted, Discontinuous, Blurry, Doll-Like, Overly-Plastic, Low-Quality, Painted, Smoothed, Artificial, Phony, Gaudy, Digital Effects.",
    seed: Optional[int] = None,
    upscale_factor: int = 2,
    controlnet_scale: float = 0.6,
    controlnet_decay: float = 1.0,
    condition_scale: int = 6,
    tile_width: int = 112,
    tile_height: int = 144,
    denoise_strength: float = 0.35,
    num_inference_steps: int = 30,
    solver: str = "DDIM",
) -> Image.Image:
    """Upscale ``input_image`` with the shared ``enhancer``.

    Args:
        input_image: image to upscale.
        prompt: positive prompt forwarded to the enhancer.
        negative_prompt: negative prompt forwarded to the enhancer.
        seed: seed passed to ``manual_seed``. ``None`` (the default) draws a
            fresh 32-bit seed on every call. The previous default was computed
            once at function-definition time, so every call without an
            explicit seed reused the same value.
        upscale_factor: overall enlargement factor.
        controlnet_scale: forwarded as ``controlnet_scale``.
        controlnet_decay: forwarded as ``controlnet_scale_decay``.
        condition_scale: forwarded as ``condition_scale``.
        tile_width: tile width; forwarded as the second item of ``tile_size``.
        tile_height: tile height; forwarded as the first item of ``tile_size``.
        denoise_strength: forwarded as ``denoise_strength``.
        num_inference_steps: forwarded as ``num_inference_steps``.
        solver: name of an attribute of ``solvers`` to use as the solver type.

    Returns:
        The image returned by ``enhancer.upscale``.

    Raises:
        AttributeError: if ``solver`` does not name an attribute of ``solvers``.
    """
    # NOTE(review): `spaces` is imported by this file but no `@spaces.GPU`
    # decorator is visible on this function - confirm whether one is expected.
    log('CALL upscaler')
    if seed is None:
        # SystemRandom is independent of the global `random` state, which
        # manual_seed() may reseed. The value is kept below 2**32.
        seed = random.SystemRandom().randrange(2**32)
    manual_seed(seed)
    solver_type: type[Solver] = getattr(solvers, solver)
    log('DBG upscaler 1')
    enhanced_image = enhancer.upscale(
        image=input_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        upscale_factor=upscale_factor,
        controlnet_scale=controlnet_scale,
        controlnet_scale_decay=controlnet_decay,
        condition_scale=condition_scale,
        tile_size=(tile_height, tile_width),
        denoise_strength=denoise_strength,
        num_inference_steps=num_inference_steps,
        loras_scale={"more_details": 0.5, "sdxl_render": 1.0},
        solver_type=solver_type,
    )
    log('RET upscaler')
    return enhanced_image
def get_tensor_length(tensor):
    """Return the total number of elements, i.e. the product of ``tensor.size()``.

    An empty size (zero dimensions) yields 1.
    """
    dimensions = list(tensor.size())
    return math.prod(dimensions)
def summarize(
    text, max_len=20, min_len=10
):
    """Summarize ``text`` with the module-level T5 ``model`` / ``tokenizer``.

    The text is split into chunks of 500 whitespace-separated words. Each
    chunk is encoded with a ``"summarize: "`` prefix and repeatedly passed
    through ``model.generate`` until it holds at most ``max_len`` tokens; the
    decoded chunk summaries are joined with spaces and summarized once more.

    NOTE(review): the source lost its indentation; the nesting of the
    ``while`` loop below (only the ``generate`` call inside it) is a
    reconstruction - confirm against the original file.

    @param text: text to summarize
    @param max_len: token limit for each chunk and ``max_length`` of the final generation
    @param min_len: ``min_length`` passed to every generation
    @return: the summary, or ``text`` unchanged when it has fewer than 5 words
    """
    log(f'CALL summarize')
    words = text.split()
    if len(words) < 5:
        print("Summarization Error: Text is too short, 5 words minimum!")
        return text
    prefix = "summarize: "
    ret = ""
    for index in range(math.ceil( len(words) / 500 )):
        chunk = " ".join(words[ index*500:(index+1)*500 ])
        inputs = tokenizer.encode( prefix + chunk, return_tensors="pt", truncation=False, add_special_tokens=True)
        # Each pass caps the output at a quarter of the current length,
        # but never below max_len.
        while get_tensor_length(inputs) > max_len:
            inputs = model.generate(
                inputs,
                length_penalty=2.0,
                num_beams=4,
                early_stopping=True,
                max_length=max( get_tensor_length(inputs) // 4 , max_len ),
                min_length=min_len
            )
        toks = tokenizer.decode(inputs[0], skip_special_tokens=True)
        ret = ret + ("" if ret == "" else " ") + toks
    # Final pass over the concatenated chunk summaries.
    inputs = tokenizer.encode( prefix + ret, return_tensors="pt", truncation=False)
    gen = model.generate(
        inputs,
        length_penalty=1.0,
        num_beams=4,
        early_stopping=True,
        max_length=max_len,
        min_length=min_len
    )
    summary = tokenizer.decode(gen[0], skip_special_tokens=True)
    log(f'RET summarize with summary as {summary}')
    return summary
def generate_random_string(length):
    """Return ``length`` characters drawn at random from ASCII letters and digits."""
    alphabet = str(ascii_letters + digits)
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return ''.join(picked)
def pipe_generate_image(p1, p2):
    """Generate images with the module-level FLUX pipeline.

    Args:
        p1: positive prompt.
        p2: negative prompt.

    Returns:
        The ``images`` list of the pipeline output.
    """
    log('CALL pipe_generate')
    # Draw the seed as an integer. The previous digit-string parsing of
    # random.random() raised for values printed in scientific notation
    # (e.g. "1e-05"). It also used the global `random` state, which
    # manual_seed() may have reseeded.
    seed = random.SystemRandom().randrange(2**63)
    generator = torch.Generator(device).manual_seed(seed)
    imgs = image_pipe(
        prompt=p1,
        negative_prompt=p2,
        height=height,
        width=width,
        guidance_scale=img_accu,
        num_images_per_prompt=1,
        num_inference_steps=image_steps,
        max_sequence_length=seq,
        generator=generator,
    ).images
    log('RET pipe_generate')
    return imgs
def add_song_cover_text(img,artist,song,height,width):
    """Draw the song title and the artist name onto ``img`` and return it.

    Both labels are horizontally centred. The title sits a sixth of
    ``height`` above the vertical centre line (white fill, black outline)
    and the artist the same distance below it (black fill, white outline).
    The image is modified in place.

    Args:
        img: image to draw on.
        artist: text of the lower label.
        song: text of the upper label.
        height: image height used for the layout.
        width: image width used for the layout.
    """
    draw = ImageDraw.Draw(img,mode="RGBA")
    rows = 1
    labels_distance = 1/3
    text_height = min(math.ceil( width / 10 ), math.ceil( height / 5 ))
    font = ImageFont.truetype(r"Alef-Bold.ttf", text_height)
    outline = math.ceil(text_height/20)
    # Top edge of a label that would sit exactly on the vertical centre line.
    centred_top = height - (text_height * rows / 2) - (height / 2)
    shift = height / 2 * labels_distance

    def put_label(label, top, fill, stroke_fill):
        left = math.ceil((width - draw.textlength(label,font)) / 2)
        draw.text((left, top), label, fill, font=font, spacing=2, stroke_width=outline, stroke_fill=stroke_fill)

    put_label(song, math.ceil(centred_top - shift), (255,255,255,85), (0,0,0,170))
    put_label(artist, math.ceil(centred_top + shift), (0,0,0,85), (255,255,255,170))
    return img
def all_pipes(pos,neg,artist,song):
    """Generate images for the prompts and upscale each one.

    ``artist`` and ``song`` are accepted but not used here. The list returned
    by ``pipe_generate_image`` is updated in place and returned.
    """
    imgs = pipe_generate_image(pos,neg)
    for position, generated in enumerate(imgs):
        imgs[position] = upscaler(generated)
    return imgs
# URL that GoogleTranslator sends its GET requests to.
google_translate_endpoint = "https://translate.google.com/m"
# Lower-case language name -> language code accepted by the translators.
# Used as the default `languages` mapping of BaseTranslator.
language_codes = {
    "afrikaans": "af",
    "albanian": "sq",
    "amharic": "am",
    "arabic": "ar",
    "armenian": "hy",
    "assamese": "as",
    "aymara": "ay",
    "azerbaijani": "az",
    "bambara": "bm",
    "basque": "eu",
    "belarusian": "be",
    "bengali": "bn",
    "bhojpuri": "bho",
    "bosnian": "bs",
    "bulgarian": "bg",
    "catalan": "ca",
    "cebuano": "ceb",
    "chichewa": "ny",
    "chinese (simplified)": "zh-CN",
    "chinese (traditional)": "zh-TW",
    "corsican": "co",
    "croatian": "hr",
    "czech": "cs",
    "danish": "da",
    "dhivehi": "dv",
    "dogri": "doi",
    "dutch": "nl",
    "english": "en",
    "esperanto": "eo",
    "estonian": "et",
    "ewe": "ee",
    "filipino": "tl",
    "finnish": "fi",
    "french": "fr",
    "frisian": "fy",
    "galician": "gl",
    "georgian": "ka",
    "german": "de",
    "greek": "el",
    "guarani": "gn",
    "gujarati": "gu",
    "haitian creole": "ht",
    "hausa": "ha",
    "hawaiian": "haw",
    "hebrew": "iw",
    "hindi": "hi",
    "hmong": "hmn",
    "hungarian": "hu",
    "icelandic": "is",
    "igbo": "ig",
    "ilocano": "ilo",
    "indonesian": "id",
    "irish": "ga",
    "italian": "it",
    "japanese": "ja",
    "javanese": "jw",
    "kannada": "kn",
    "kazakh": "kk",
    "khmer": "km",
    "kinyarwanda": "rw",
    "konkani": "gom",
    "korean": "ko",
    "krio": "kri",
    "kurdish (kurmanji)": "ku",
    "kurdish (sorani)": "ckb",
    "kyrgyz": "ky",
    "lao": "lo",
    "latin": "la",
    "latvian": "lv",
    "lingala": "ln",
    "lithuanian": "lt",
    "luganda": "lg",
    "luxembourgish": "lb",
    "macedonian": "mk",
    "maithili": "mai",
    "malagasy": "mg",
    "malay": "ms",
    "malayalam": "ml",
    "maltese": "mt",
    "maori": "mi",
    "marathi": "mr",
    "meiteilon (manipuri)": "mni-Mtei",
    "mizo": "lus",
    "mongolian": "mn",
    "myanmar": "my",
    "nepali": "ne",
    "norwegian": "no",
    "odia (oriya)": "or",
    "oromo": "om",
    "pashto": "ps",
    "persian": "fa",
    "polish": "pl",
    "portuguese": "pt",
    "punjabi": "pa",
    "quechua": "qu",
    "romanian": "ro",
    "russian": "ru",
    "samoan": "sm",
    "sanskrit": "sa",
    "scots gaelic": "gd",
    "sepedi": "nso",
    "serbian": "sr",
    "sesotho": "st",
    "shona": "sn",
    "sindhi": "sd",
    "sinhala": "si",
    "slovak": "sk",
    "slovenian": "sl",
    "somali": "so",
    "spanish": "es",
    "sundanese": "su",
    "swahili": "sw",
    "swedish": "sv",
    "tajik": "tg",
    "tamil": "ta",
    "tatar": "tt",
    "telugu": "te",
    "thai": "th",
    "tigrinya": "ti",
    "tsonga": "ts",
    "turkish": "tr",
    "turkmen": "tk",
    "twi": "ak",
    "ukrainian": "uk",
    "urdu": "ur",
    "uyghur": "ug",
    "uzbek": "uz",
    "vietnamese": "vi",
    "welsh": "cy",
    "xhosa": "xh",
    "yiddish": "yi",
    "yoruba": "yo",
    "zulu": "zu",
}
# NOTE(review): BaseTranslator raises the two exceptions below, but the visible
# file neither defines nor imports them, so raising them would fail with a
# NameError. They are defined here; drop these definitions if the project
# already provides them elsewhere.
class InvalidSourceOrTargetLanguage(Exception):
    """Raised when the source or target language is empty."""

    def __init__(self, val, message="Invalid source or target language!"):
        self.val = val
        self.message = message
        super().__init__(self.message)

    def __str__(self):
        return f"{self.val} --> {self.message}"


class LanguageNotSupportedException(Exception):
    """Raised when a language is neither a known name nor a known code."""

    def __init__(self, val, message="There is no support for the chosen language"):
        self.val = val
        self.message = message
        super().__init__(self.message)

    def __str__(self):
        return f"{self.val} --> {self.message}"


class BaseTranslator(ABC):
    """
    Abstract class that serve as a base translator for other different translators
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        languages: dict = language_codes,
        source: str = "auto",
        target: str = "en",
        payload_key: Optional[str] = None,
        element_tag: Optional[str] = None,
        element_query: Optional[dict] = None,
        **url_params,
    ):
        """
        @param base_url: endpoint the concrete translator sends requests to
        @param languages: mapping of language names to language codes
        @param source: source language to translate from
        @param target: target language to translate to
        @param payload_key: name of the url parameter that carries the text
        @param element_tag: html tag that holds the translation in the response
        @param element_query: attributes used to locate that tag
        @param url_params: extra parameters sent with every request
        @raise InvalidSourceOrTargetLanguage: if source or target is empty
        @raise LanguageNotSupportedException: if source or target is unknown
        """
        self._base_url = base_url
        self._languages = languages
        self._supported_languages = list(self._languages.keys())
        if not source:
            raise InvalidSourceOrTargetLanguage(source)
        if not target:
            raise InvalidSourceOrTargetLanguage(target)

        # _map_language_to_code is a generator; unpacking drives it to completion.
        self._source, self._target = self._map_language_to_code(source, target)

        self._url_params = url_params
        self._element_tag = element_tag
        self._element_query = element_query
        self.payload_key = payload_key
        super().__init__()

    # `source` and `target` were each defined twice as plain methods, so the
    # second definition silently replaced the getter. They are exposed as
    # properties with setters instead.
    @property
    def source(self):
        return self._source

    @source.setter
    def source(self, lang):
        self._source = lang

    @property
    def target(self):
        return self._target

    @target.setter
    def target(self, lang):
        self._target = lang

    def _type(self):
        return self.__class__.__name__

    def _map_language_to_code(self, *languages):
        """
        map language to its corresponding code (abbreviation) if the language was passed
        by its full name by the user
        @param languages: list of languages
        @return: mapped value of the language or raise an exception if the language is
        not supported
        """
        for language in languages:
            if language in self._languages.values() or language == "auto":
                yield language
            elif language in self._languages:
                yield self._languages[language]
            else:
                raise LanguageNotSupportedException(
                    language,
                    message=f"No support for the provided language.\n"
                    f"Please select on of the supported languages:\n"
                    f"{self._languages}",
                )

    def _same_source_target(self) -> bool:
        return self._source == self._target

    def get_supported_languages(
        self, as_dict: bool = False, **kwargs
    ) -> Union[list, dict]:
        """
        return the supported languages by the Google translator
        @param as_dict: if True, the languages will be returned as a dictionary
        mapping languages to their abbreviations
        @return: list or dict
        """
        return self._supported_languages if not as_dict else self._languages

    def is_language_supported(self, language: str, **kwargs) -> bool:
        """
        check if the language is supported by the translator
        @param language: a string for 1 language
        @return: bool
        """
        return (
            language == "auto"
            or language in self._languages
            or language in self._languages.values()
        )

    def translate(self, text: str, **kwargs) -> str:
        """
        translate a text using a translator under the hood and return
        the translated text
        @param text: text to translate
        @param kwargs: additional arguments
        @return: str
        @raise NotImplementedError: always; subclasses must override this method
        """
        # `NotImplemented` is a non-callable constant; calling it raised a
        # confusing TypeError instead of signalling a missing override.
        raise NotImplementedError("You need to implement the translate method!")

    def _read_docx(self, f: str):
        # Imported lazily so the dependency is only needed for .docx files.
        import docx2txt

        return docx2txt.process(f)

    def _read_pdf(self, f: str):
        # Imported lazily so the dependency is only needed for .pdf files.
        import pypdf

        reader = pypdf.PdfReader(f)
        # Only the first page of the document is read.
        page = reader.pages[0]
        return page.extract_text()

    def _translate_file(self, path: str, **kwargs) -> str:
        """
        translate directly from file
        @param path: path to the target file
        @type path: str
        @param kwargs: additional args
        @return: str
        """
        if not isinstance(path, Path):
            path = Path(path)

        if not path.exists():
            # Kept as in the original: a missing file terminates the process.
            print("Path to the file is wrong!")
            exit(1)

        ext = path.suffix

        if ext == ".docx":
            text = self._read_docx(f=str(path))
        elif ext == ".pdf":
            text = self._read_pdf(f=str(path))
        else:
            with open(path, "r", encoding="utf-8") as f:
                text = f.read().strip()

        return self.translate(text)

    def _translate_batch(self, batch: List[str], **kwargs) -> List[str]:
        """
        translate a list of texts
        @param batch: list of texts you want to translate
        @return: list of translations
        """
        if not batch:
            raise Exception("Enter your text list that you want to translate")
        return [self.translate(text, **kwargs) for text in batch]
# NOTE(review): GoogleTranslator.translate uses the helpers and exceptions
# below, but the visible file neither defines nor imports them, so the first
# call would fail with a NameError. They are defined here; drop these
# definitions if the project already provides them elsewhere.
class TooManyRequests(Exception):
    """Raised when the server answers with HTTP 429."""

    def __init__(
        self,
        message="Server Error: You made too many requests to the server. Try again later.",
    ):
        self.message = message
        super().__init__(self.message)


class RequestError(Exception):
    """Raised when the server answers with a non-2xx status code."""

    def __init__(
        self,
        message="Request exception can happen due to an api connection error. "
        "Please check your connection and try again",
    ):
        self.message = message
        super().__init__(self.message)


class TranslationNotFound(Exception):
    """Raised when no translation element is found in the response."""

    def __init__(self, val, message="No translation was found using the current translator. Try another translator?"):
        self.val = val
        self.message = message
        super().__init__(self.message)

    def __str__(self):
        return f"{self.val} --> {self.message}"


def is_input_valid(text, min_chars: int = 0, max_chars: Optional[int] = None) -> bool:
    """Return True when ``text`` is a string whose length is in ``[min_chars, max_chars)``.

    @raise TypeError: if ``text`` is not a string
    @raise ValueError: if ``max_chars`` is given and the length is out of range
    """
    if not isinstance(text, str):
        raise TypeError(f"text must be a string, got {type(text).__name__}")
    if max_chars is not None and not (min_chars <= len(text) < max_chars):
        raise ValueError(
            f"text length must be between {min_chars} and {max_chars} characters, got {len(text)}"
        )
    return True


def is_empty(text: str) -> bool:
    """Return True when ``text`` is the empty string."""
    return text == ""


def request_failed(status_code: int) -> bool:
    """Return True when ``status_code`` is outside the 2xx range."""
    return status_code < 200 or status_code > 299


class GoogleTranslator(BaseTranslator):
    """
    class that wraps functions, which use Google Translate under the hood to translate text(s)
    """

    def __init__(
        self,
        source: str = "auto",
        target: str = "en",
        proxies: Optional[dict] = None,
        **kwargs
    ):
        """
        @param source: source language to translate from
        @param target: target language to translate to
        @param proxies: proxies mapping forwarded to ``requests.get``
        @param kwargs: extra url parameters sent with every request
        """
        self.proxies = proxies
        super().__init__(
            base_url=google_translate_endpoint,
            source=source,
            target=target,
            element_tag="div",
            element_query={"class": "t0"},
            payload_key="q",  # key of text in the url
            **kwargs
        )

        # Fallback selector used when the primary element is not in the page.
        self._alt_element_query = {"class": "result-container"}

    def translate(self, text: str, **kwargs) -> str:
        """
        function to translate a text
        @param text: desired text to translate
        @return: str: translated text
        @raise TooManyRequests: if the server answers with HTTP 429
        @raise RequestError: if the server answers with another non-2xx status
        @raise TranslationNotFound: if the response holds no translation element
        """
        if not is_input_valid(text, max_chars=1000):
            return text

        text = text.strip()
        if self._same_source_target() or is_empty(text):
            return text

        self._url_params["tl"] = self._target
        self._url_params["sl"] = self._source
        if self.payload_key:
            self._url_params[self.payload_key] = text

        response = requests.get(
            self._base_url, params=self._url_params, proxies=self.proxies
        )
        if response.status_code == 429:
            raise TooManyRequests()
        if request_failed(status_code=response.status_code):
            raise RequestError()

        soup = BeautifulSoup(response.text, "html.parser")
        element = soup.find(self._element_tag, self._element_query)
        response.close()

        if not element:
            element = soup.find(self._element_tag, self._alt_element_query)
            if not element:
                raise TranslationNotFound(text)

        translated = element.get_text(strip=True)
        if translated != text:
            return translated

        # The service echoed the input. If the text has alphanumeric content
        # and an "hl" url parameter is set, retry once without it.
        has_alnum = any(ch.isalnum() for ch in text)
        if has_alnum and "hl" in self._url_params:
            del self._url_params["hl"]
            return self.translate(text)

        # Previously this path returned None when the text had no alphanumeric
        # characters, which broke callers that concatenate the result.
        return text

    def translate_file(self, path: str, **kwargs) -> str:
        """
        translate directly from file
        @param path: path to the target file
        @type path: str
        @param kwargs: additional args
        @return: str
        """
        return self._translate_file(path, **kwargs)

    def translate_batch(self, batch: List[str], **kwargs) -> List[str]:
        """
        translate a list of texts
        @param batch: list of texts you want to translate
        @return: list of translations
        """
        return self._translate_batch(batch, **kwargs)
def _split_into_chunks(text, limit):
    """Greedily pack the whitespace-separated words of ``text`` into chunks shorter than ``limit``.

    A single word of ``limit`` characters or more becomes its own chunk.
    """
    chunks = []
    current = ""
    for word in text.split():
        candidate = f"{current} {word}" if current else word
        if current and len(candidate) >= limit:
            chunks.append(current)
            current = word
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks


def translate(txt,to_lang="en",from_lang="auto"):
    """Translate ``txt`` with ``GoogleTranslator``, splitting long texts into chunks.

    Args:
        txt: text to translate.
        to_lang: target language name or code.
        from_lang: source language name or code, or ``"auto"``.

    Returns:
        The translated text, stripped of surrounding whitespace.
    """
    log('CALL translate')
    # The translator's parameters are named `source` / `target`. The previous
    # `from_lang=` / `to_lang=` keywords fell into its **kwargs and were sent
    # as url parameters, so the requested languages were ignored.
    translator = GoogleTranslator(source=from_lang, target=to_lang)
    limit = 1000
    if len(txt) >= limit:
        # The previous loop never terminated when a single word was longer
        # than the limit; the helper always makes progress.
        pieces = [translator.translate(chunk) for chunk in _split_into_chunks(txt, limit)]
        translation = " ".join(pieces)
    else:
        translation = translator.translate(txt)
    translation = translation.strip()
    log(f'RET translate with translation as {translation}')
    return translation
def handle_generation(artist,song,genre,lyrics):
    """Build the prompts, generate and label the cover image(s), and save them as PNG files.

    Args:
        artist: artist name as typed by the user.
        song: song name as typed by the user.
        genre: genre as typed by the user.
        lyrics: lyrics as typed by the user (may be empty).

    Returns:
        The file name of the first saved image, or ``None`` if no image was
        generated.
    """
    log('CALL handle_generate')

    whitespace_run = "([ \t\n]){1,}"
    # re.escape keeps characters such as "\\", "]" and "^" literal inside the
    # character class; unescaped, the backslash was never removed.
    punctuation_class = f'[{re.escape(punctuation)}]'

    def capitalize_words(value):
        return ' '.join(word[0].upper() + word[1:] for word in value.split())

    pos_artist = re.sub(whitespace_run, " ", artist).upper().strip()
    pos_song = capitalize_words(re.sub(whitespace_run, " ", song).lower().strip())
    pos_genre = capitalize_words(
        re.sub(punctuation_class, '', re.sub(whitespace_run, " ", genre)).lower().strip()
    )
    pos_lyrics = re.sub(punctuation_class, '', re.sub(whitespace_run, " ", lyrics)).lower().strip()
    pos_lyrics_sum = pos_lyrics if pos_lyrics == "" else summarize(translate(pos_lyrics))
    neg = "Sexuality, Humanity, Textual, Labeled, Distorted, Discontinuous, Blurry, Doll-Like, Overly Plastic, Low-Quality, Painted, Smoothed, Artificial, Phony, Gaudy, Digital Effects."
    pos = f'HQ Hyper-realistic { translate(pos_genre) } song "{ translate(pos_song) }"{ pos_lyrics_sum if pos_lyrics_sum == "" else ": " + pos_lyrics_sum }.'

    print(f"""
Positive: {pos}
Negative: {neg}
""")

    imgs = all_pipes(pos,neg,pos_artist,pos_song)

    scaled_by = 2
    names = []
    for index, img in enumerate(imgs, start=1):
        labeled_img = add_song_cover_text(img,artist,song,height*scaled_by,width*scaled_by)
        # artist/song come straight from the web form. Replace path separators
        # and other characters that are invalid in file names, so that names
        # such as "AC/DC" cannot break the save or escape the working directory.
        stem = re.sub(r'[\\/:*?"<>|\x00-\x1f]', "_", f'{artist} - {song} ({index})')
        name = f'{stem}.png'
        labeled_img.save(name)
        names.append(name)
    # Only the first image is returned to the UI.
    return names[0] if names else None
# entry
# Builds the Gradio UI: four single-line text inputs, a result image and a
# "Generate" button wired to `handle_generation`.
# NOTE(review): the source lost its indentation; the nesting of the layout
# blocks below (in particular where the button sits) is a reconstruction -
# confirm against the original file.
if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Citrus(),css=css) as demo:
        gr.Markdown(f"""
# Song Cover Image Generator
""")
        with gr.Row():
            # Input column.
            with gr.Column(scale=4):
                artist = gr.Textbox(
                    placeholder="Artist name",
                    value="",
                    container=False,
                    max_lines=1
                )
                song = gr.Textbox(
                    placeholder="Song name",
                    value="",
                    container=False,
                    max_lines=1
                )
                genre = gr.Textbox(
                    placeholder="Genre",
                    value="",
                    container=False,
                    max_lines=1
                )
                lyrics = gr.Textbox(
                    placeholder="Lyrics",
                    value="",
                    container=False,
                    max_lines=1
                )
            # Output column; the image component receives a file path.
            with gr.Column():
                cover = gr.Image(interactive=False,container=False,elem_classes="image-container", label="Result", show_label=True, type='filepath', show_share_button=False)
        run = gr.Button("Generate",elem_classes="btn")
        run.click(
            fn=handle_generation,
            inputs=[artist,song,genre,lyrics],
            outputs=[cover]
        )
    demo.queue().launch()