Kokoro-API-2 / app.py
Yaron Koresh
Update app.py
89e3e87 verified
raw
history blame
49.9 kB
"""
Modified parts included from these sources:
- https://github.com/nidhaloff/deep-translator
- https://huggingface.co/spaces/ostris/Flex.1-alpha
"""
import urllib
import requests
from bs4 import BeautifulSoup
from abc import ABC, abstractmethod
from pathlib import Path
from langdetect import detect as get_language
from typing import Any, Dict, List, Optional, Union
from collections import namedtuple
from inspect import signature
import os
import subprocess
import logging
import re
import random
from string import ascii_letters, digits, punctuation
import requests
import sys
import warnings
import time
import asyncio
import math
from pathlib import Path
from functools import partial
from dataclasses import dataclass
from typing import Any
import pillow_heif
import spaces
import numpy as np
import numpy.typing as npt
import torch
from torch import nn
import gradio as gr
from lxml.html import fromstring
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file, save_file
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, FluxPipeline, FlowMatchEulerDiscreteScheduler
from PIL import Image, ImageDraw, ImageFont
from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from refiners.fluxion.utils import manual_seed
from refiners.foundationals.latent_diffusion import Solver, solvers
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
MultiUpscaler,
UpscalerCheckpoints,
)
from datetime import datetime
# Module-level state shared by the functions below.
# Current image size: read by add_song_cover_text and rescaled by upscaler.
# Both start as None; whatever assigns them first is not visible here — TODO confirm.
_HEIGHT_ = None
_WIDTH_ = None
# Busy flag checked and set by upscaler and add_song_cover_text, intended as a
# guard against overlapping runs.
working = False
# T5-base model and tokenizer used by _summarize; loaded at import time.
model = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = T5Tokenizer.from_pretrained("t5-base")
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Linearly interpolate the scheduler shift ``mu`` for a sequence length.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``; ``image_seq_len`` is evaluated on it
    (values outside the range are extrapolated, not clamped).
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure ``scheduler`` and return ``(scheduler.timesteps, step_count)``.

    Exactly one source of truth is used: explicit ``timesteps``, explicit
    ``sigmas``, or ``num_inference_steps``. When an explicit schedule is
    supplied the step count is the length of the resulting timesteps;
    otherwise ``num_inference_steps`` is returned as given. Extra keyword
    arguments are forwarded to ``scheduler.set_timesteps``.

    Raises:
        ValueError: if both ``timesteps`` and ``sigmas`` are given.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    custom_schedule = timesteps is not None or sigmas is not None
    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    schedule = scheduler.timesteps
    if custom_schedule:
        num_inference_steps = len(schedule)
    return schedule, num_inference_steps
# FLUX pipeline function
@torch.inference_mode()
def flux_pipe_call_that_returns_an_iterable_of_images(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    max_sequence_length: int = 512,
    good_vae: Optional[Any] = None,
):
    """Generator variant of a FLUX pipeline ``__call__`` that streams previews.

    Meant to be bound to a FLUX diffusers pipeline instance (``self``). It
    relies on that pipeline's ``check_inputs``, ``encode_prompt``,
    ``prepare_latents``, ``_unpack_latents``, ``scheduler``, ``transformer``,
    ``vae`` and ``image_processor``.

    Yields:
        One preview per denoising step, decoded with ``self.vae`` from the
        latents *before* that step's scheduler update. After the loop, one
        final image decoded with ``good_vae``.

    Args:
        timesteps: optional explicit schedule. When given, the default sigma
            schedule is not built, because ``retrieve_timesteps`` rejects
            receiving both.
        good_vae: VAE used for the final decode (e.g. a full-quality VAE when
            ``self.vae`` is a tiny preview VAE). Defaults to ``self.vae``.
        return_dict: accepted for signature compatibility; unused.
    """
    # Without a fallback, the default good_vae=None crashed at the final decode.
    final_vae = good_vae if good_vae is not None else self.vae
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor
    # 1. Check inputs
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        max_sequence_length=max_sequence_length,
    )
    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False
    # 2. Define call parameters
    if isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None:
        batch_size = len(prompt)
    else:
        # Only precomputed embeddings were supplied; take the batch size from them.
        batch_size = prompt_embeds.shape[0]
    device = self._execution_device
    # 3. Encode prompt
    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )
    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )
    # 5. Prepare timesteps. Build the default sigma schedule only when the
    # caller did not pass explicit timesteps; passing both is rejected.
    sigmas = None if timesteps is not None else np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
    self._num_timesteps = len(timesteps)
    # Guidance embedding only for transformers configured with guidance_embeds.
    guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
    # 6. Denoising loop
    for i, t in enumerate(timesteps):
        # Once interrupted, remaining steps are skipped entirely (no previews).
        if self.interrupt:
            continue
        timestep = t.expand(latents.shape[0]).to(latents.dtype)
        noise_pred = self.transformer(
            hidden_states=latents,
            timestep=timestep / 1000,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=latent_image_ids,
            joint_attention_kwargs=self.joint_attention_kwargs,
            return_dict=False,
        )[0]
        # Yield intermediate result (decoded with the pipeline's own VAE)
        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents_for_image, return_dict=False)[0]
        yield self.image_processor.postprocess(image, output_type=output_type)[0]
        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        torch.cuda.empty_cache()
    # Final image decoded with final_vae (good_vae, or self.vae as fallback)
    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
    latents = (latents / final_vae.config.scaling_factor) + final_vae.config.shift_factor
    image = final_vae.decode(latents, return_dict=False)[0]
    self.maybe_free_model_hooks()
    torch.cuda.empty_cache()
    yield self.image_processor.postprocess(image, output_type=output_type)[0]
def log(msg):
    """Print ``msg`` to stdout, prefixed with the current local wall-clock time."""
    stamp = datetime.now().time()
    line = f'{stamp} {msg}'
    print(line)
# Type aliases for the tiling helpers (see split_grid / upscale_with_tiling):
# Tile  -> (x offset, tile width, tile image): one entry within a grid row.
# Tiles -> list of (y offset, tile height, [Tile, ...]): the rows of a Grid.
Tile = tuple[int, int, Image.Image]
Tiles = list[tuple[int, int, list[Tile]]]
def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
    """Return ``Sequential(3x3 same-padding Conv2d, LeakyReLU(0.2, inplace))``.

    The two layers sit at indices 0 and 1, which the ESRGAN state-dict keys rely on.
    """
    layers = [
        nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
    ]
    return nn.Sequential(*layers)
class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block: five densely-connected 3x3 convolutions whose
    output is scaled by 0.2 and added back to the input.
    From "Residual Dense Network for Image Super-Resolution" (CVPR 18).
    Modified options that can be used:
    - "Partial Convolution based Padding" arXiv:1811.11718
    - "Spectral normalization" arXiv:1802.05957
    - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
    {Rakotonirina} and A. {Rasoanaivo}
    """

    def __init__(self, nf: int = 64, gc: int = 32) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        # conv_k sees the input plus every previous growth output (nf + (k-1)*gc channels).
        self.conv1 = conv_block(nf, gc)
        self.conv2 = conv_block(nf + gc, gc)
        self.conv3 = conv_block(nf + 2 * gc, gc)
        self.conv4 = conv_block(nf + 3 * gc, gc)
        # Wrapped in Sequential so the state-dict key is "conv5.0.*".
        self.conv5 = nn.Sequential(nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = [x, self.conv1(x)]
        for layer in (self.conv2, self.conv3, self.conv4):
            features.append(layer(torch.cat(features, 1)))
        fused = self.conv5(torch.cat(features, 1))
        return fused * 0.2 + x
class RRDB(nn.Module):
    """
    Residual in Residual Dense Block: three ResidualDenseBlock_5C in series,
    with the result scaled by 0.2 and added back to the input.
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    """

    def __init__(self, nf: int) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.RDB1 = ResidualDenseBlock_5C(nf)
        self.RDB2 = ResidualDenseBlock_5C(nf)
        self.RDB3 = ResidualDenseBlock_5C(nf)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            h = dense_block(h)
        return h * 0.2 + x
class Upsample2x(nn.Module):
    """Double the spatial size using ``interpolate``'s default (nearest) mode."""

    def __init__(self) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        factor = 2.0
        return nn.functional.interpolate(x, scale_factor=factor)  # type: ignore
class ShortcutBlock(nn.Module):
    """Residual wrapper: returns ``x + submodule(x)``.

    The wrapped module is stored as ``sub`` (state-dict prefix "sub.").
    """

    def __init__(self, submodule: nn.Module) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.sub = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self.sub(x)
        return x + branch
class RRDBNet(nn.Module):
    """ESRGAN generator: RRDB trunk followed by two 2x upsampling stages (4x total).

    Layer positions inside ``self.model`` must match checkpoint keys:
    0 head conv, 1 shortcut trunk (``sub.0..nb-1`` RRDBs + ``sub.nb`` conv),
    2/5 upsample, 3/6/8 convs, 4/7/9 activations, 10 output conv.
    """

    def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        assert in_nc % 4 != 0  # in_nc is 3

        def conv(cin: int, cout: int) -> nn.Conv2d:
            return nn.Conv2d(cin, cout, kernel_size=3, padding=1)

        def act() -> nn.LeakyReLU:
            return nn.LeakyReLU(negative_slope=0.2, inplace=True)

        layers: list[nn.Module] = [conv(in_nc, nf)]
        trunk = nn.Sequential(*[RRDB(nf) for _ in range(nb)], conv(nf, nf))
        layers.append(ShortcutBlock(trunk))
        for _ in range(2):
            layers.extend([Upsample2x(), conv(nf, nf), act()])
        layers.extend([conv(nf, nf), act(), conv(nf, out_nc)])
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.model(x)
        return y
def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
    """Infer RRDBNet hyper-parameters from an ESRGAN state dict.

    Returns ``(in_nc, out_nc, nf, nb, scale)``, e.g. ``(3, 3, 64, 23, 4)``.
    ``nb`` comes from the ``model.1.sub.<nb>.*`` trunk-conv key. ``scale`` is
    ``2 ** (number of model.<i>.weight keys with i > 6)``. ``out_nc`` is the
    first dimension of the highest-indexed top-level layer.
    """
    # this code is adapted from https://github.com/victorca25/iNNfer
    upscale_weight_count = 0
    upscale_start_index = 6
    highest_layer = 0
    out_nc = 0
    nb = 0
    for key, weights in state_dict.items():
        pieces = key.split(".")
        if len(pieces) == 5 and pieces[2] == "sub":
            nb = int(pieces[3])
        elif len(pieces) == 3:
            index = int(pieces[1])
            if index > upscale_start_index and pieces[0] == "model" and pieces[2] == "weight":
                upscale_weight_count += 1
            if index > highest_layer:
                highest_layer = index
                out_nc = weights.shape[0]
        assert "conv1x1" not in key  # no ESRGANPlus
    nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]
    scale = 1 << upscale_weight_count
    assert out_nc > 0
    assert nb > 0
    return in_nc, out_nc, nf, nb, scale  # 3, 3, 64, 23, 4
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
# Tiled view of an image: rows of tiles plus the geometry needed to stitch them back.
Grid = namedtuple("Grid", "tiles tile_w tile_h image_w image_h overlap")
# adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
    """Cut ``image`` into overlapping ``tile_w`` x ``tile_h`` crops.

    Tile origins are spread evenly so that the first tile starts at 0 and the
    last ends at the image edge, and are clamped to stay inside the image.
    Each row entry is ``(y, tile_h, [(x, tile_w, crop), ...])``. Crops near the
    border of an image smaller than a tile are cut short by the image bounds.
    """
    width, height = image.width, image.height
    cols = max(1, math.ceil((width - overlap) / (tile_w - overlap)))
    rows = max(1, math.ceil((height - overlap) / (tile_h - overlap)))
    step_x = 0 if cols <= 1 else (width - tile_w) / (cols - 1)
    step_y = 0 if rows <= 1 else (height - tile_h) / (rows - 1)
    tiles = []
    for r in range(rows):
        top = max(min(int(r * step_y), height - tile_h), 0)
        bottom = min(top + tile_h, height)
        row_tiles: list[Tile] = []
        for c in range(cols):
            left = max(min(int(c * step_x), width - tile_w), 0)
            right = min(left + tile_w, width)
            row_tiles.append((left, tile_w, image.crop((left, top, right, bottom))))
        tiles.append((top, tile_h, row_tiles))
    return Grid(tiles, tile_w, tile_h, width, height, overlap)
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
def combine_grid(grid: Grid):
    """Stitch the tiles of ``grid`` back into one RGB image.

    Within a row, each non-first tile's leading ``overlap`` pixels are blended
    over the previous tile with a horizontal 0→255 linear ramp mask. Rows are
    blended onto the canvas the same way with a vertical ramp.
    """
    def to_mask(ramp: npt.NDArray[np.float32]) -> Image.Image:
        scaled = ramp * 255 / grid.overlap
        return Image.fromarray(scaled.astype(np.uint8), "L")

    ramp = np.arange(grid.overlap, dtype=np.float32)
    mask_w = to_mask(ramp.reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
    mask_h = to_mask(ramp.reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))

    canvas = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        strip = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                strip.paste(tile, (0, 0))
            else:
                strip.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
                strip.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
        if y == 0:
            canvas.paste(strip, (0, 0))
        else:
            canvas.paste(
                strip.crop((0, 0, strip.width, grid.overlap)),
                (0, y),
                mask=mask_h,
            )
            canvas.paste(
                strip.crop((0, grid.overlap, strip.width, h)),
                (0, y + grid.overlap),
            )
    return canvas
class UpscalerESRGAN:
    """Wraps an RRDBNet (ESRGAN) checkpoint as a 4x PIL-image upscaler.

    Attributes:
        model_path: path the weights were loaded from.
        device, dtype: where the network currently lives (updated by ``to``).
        model: the loaded RRDBNet, in eval mode.
    """

    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        self.model_path = model_path
        # load_model maps the checkpoint onto self.device, so set it first.
        self.device = device
        self.model = self.load_model(model_path)
        self.to(device, dtype)

    def __call__(self, img: Image.Image) -> Image.Image:
        """Upscale ``img`` in a single pass (no tiling)."""
        return self.upscale_without_tiling(img)

    def to(self, device: torch.device, dtype: torch.dtype):
        """Move the network to ``device``/``dtype`` and remember both for inference."""
        self.device = device
        self.dtype = dtype
        self.model.to(device=device, dtype=dtype)

    def load_model(self, path: Path) -> RRDBNet:
        """Load a state dict from ``path`` and build the matching RRDBNet.

        Raises:
            AssertionError: if the checkpoint does not describe a 4x model.
        """
        state_dict: dict[str, torch.Tensor] = torch.load(path, weights_only=True, map_location=self.device)  # type: ignore
        in_nc, out_nc, nf, nb, upscale = infer_params(state_dict)
        assert upscale == 4, "Only 4x upscaling is supported"
        model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
        """Run the network once over the whole image and return the 4x RGB result.

        Non-RGB inputs (greyscale, palette, RGBA, ...) are converted to RGB
        first. Without the conversion, a 2-D greyscale array fails the channel
        flip below, and RGBA feeds 4 channels to the network.
        """
        if img.mode != "RGB":
            img = img.convert("RGB")
        # RGB -> BGR, HWC -> CHW, scale to [0, 1]. The flip is undone on the way
        # out (presumably the checkpoint was trained on BGR data).
        img_np = np.array(img)[:, :, ::-1]
        img_np = np.ascontiguousarray(np.transpose(img_np, (2, 0, 1))) / 255
        img_t = torch.from_numpy(img_np).float()  # type: ignore
        img_t = img_t.unsqueeze(0).to(device=self.device, dtype=self.dtype)
        with torch.no_grad():
            output = self.model(img_t)
        # Drop only the batch dimension.
        output = output.squeeze(0).float().cpu().clamp_(0, 1).numpy()
        output = 255.0 * np.moveaxis(output, 0, 2)
        output = output.astype(np.uint8)
        output = output[:, :, ::-1]  # BGR -> RGB
        return Image.fromarray(output, "RGB")

    # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
    def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
        """Upscale ``img`` tile by tile (see split_grid) and blend the results.

        Tile offsets, sizes and the overlap are multiplied by the observed
        output/input width ratio before recombining.
        """
        img = img.convert("RGB")
        grid = split_grid(img)
        newtiles: Tiles = []
        scale_factor: int = 1
        for y, h, row in grid.tiles:
            newrow: list[Tile] = []
            for x, w, tile in row:
                output = self.upscale_without_tiling(tile)
                scale_factor = output.width // tile.width
                newrow.append((x * scale_factor, w * scale_factor, output))
            newtiles.append((y * scale_factor, h * scale_factor, newrow))
        newgrid = Grid(
            newtiles,
            grid.tile_w * scale_factor,
            grid.tile_h * scale_factor,
            grid.image_w * scale_factor,
            grid.image_h * scale_factor,
            grid.overlap * scale_factor,
        )
        return combine_grid(newgrid)
# Checkpoint bundle for ESRGANUpscaler: the refiners UpscalerCheckpoints fields
# plus the ESRGAN weights file loaded by UpscalerESRGAN.
@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    esrgan: Path  # path to the ESRGAN state dict (.pth)
class ESRGANUpscaler(MultiUpscaler):
    """refiners MultiUpscaler whose pre-upscale step is a tiled 4x ESRGAN pass."""

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        # Build the ESRGAN model on the device/dtype the base class recorded.
        self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)

    def to(self, device: torch.device, dtype: torch.dtype):
        """Move the ESRGAN model and ``self.sd`` to ``device``/``dtype`` and record both.

        NOTE(review): ``self.sd`` is presumably the Stable Diffusion model set
        up by MultiUpscaler — confirm. The base-class ``to`` is not called.
        """
        self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype

    def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
        """Upscale 4x with tiled ESRGAN, then defer to the base class for the rest.

        Extra keyword arguments are accepted and ignored.
        """
        image = self.esrgan.upscale_with_tiling(image)
        # ESRGAN already applied 4x, so only the remaining factor is requested.
        return super().pre_upscale(image=image, upscale_factor=upscale_factor / 4)
# Register pillow_heif's HEIF and AVIF openers with Pillow at import time.
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()
# Weights for the ESRGAN + Stable Diffusion 1.5 tile upscaler. Each file is
# fetched from the Hugging Face Hub at a pinned revision, at import time and
# in the order listed.
CHECKPOINTS = ESRGANUpscalerCheckpoints(
    # SD1.5 (Juggernaut Reborn) UNet, text encoder and autoencoder
    unet=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.unet",
            filename="model.safetensors",
            revision="347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
        )
    ),
    clip_text_encoder=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.text_encoder",
            filename="model.safetensors",
            revision="744ad6a5c0437ec02ad826df9f6ede102bb27481",
        )
    ),
    lda=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.autoencoder",
            filename="model.safetensors",
            revision="3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
        )
    ),
    # Tile ControlNet for SD1.5
    controlnet_tile=Path(
        hf_hub_download(
            repo_id="refiners/controlnet.sd1_5.tile",
            filename="model.safetensors",
            revision="48ced6ff8bfa873a8976fa467c3629a240643387",
        )
    ),
    # 4x ESRGAN weights used by UpscalerESRGAN
    esrgan=Path(
        hf_hub_download(
            repo_id="philz1337x/upscaler",
            filename="4x-UltraSharp.pth",
            revision="011deacac8270114eb7d2eeff4fe6fa9a837be70",
        )
    ),
    # Negative textual-inversion embedding and the key pattern used to read it
    negative_embedding=Path(
        hf_hub_download(
            repo_id="philz1337x/embeddings",
            filename="JuggernautNegative-neg.pt",
            revision="203caa7e9cc2bc225031a4021f6ab1ded283454a",
        )
    ),
    negative_embedding_key="string_to_param.*",
    # LoRAs; their names must match the keys of loras_scale passed in upscaler
    loras={
        "more_details": Path(
            hf_hub_download(
                repo_id="philz1337x/loras",
                filename="more_details.safetensors",
                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
            )
        ),
        "sdxl_render": Path(
            hf_hub_download(
                repo_id="philz1337x/loras",
                filename="SDXLrender_v2.0.safetensors",
                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
            )
        )
    }
)
# Compute device and dtype shared by the whole app.
device = DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# bfloat16 only when a CUDA device that supports it is present. Checking
# availability first matters because some torch versions query the current
# CUDA device inside is_bf16_supported() and raise on CPU-only hosts.
DTYPE = dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
# The enhancer is deliberately built on CPU; it can be moved later via .to().
enhancer = ESRGANUpscaler(checkpoints=CHECKPOINTS, device="cpu", dtype=DTYPE)
# logging: silence Python warnings globally and send WARNING-and-above records
# from the root logger to stderr in a framed, timestamped format.
warnings.filterwarnings("ignore")
formatter = logging.Formatter('\n >>> [%(levelname)s] %(asctime)s %(name)s: %(message)s\n')
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(formatter)
handler.setLevel(logging.WARNING)
root = logging.getLogger()
root.addHandler(handler)
root.setLevel(logging.WARNING)
# constant data
# Largest signed 32-bit integer; upper bound for randomly drawn seeds.
MAX_SEED = np.iinfo(np.int32).max
# precision data — not referenced in the visible code; the names suggest the
# text sequence length, step count and guidance for image generation (TODO confirm)
seq = 512
image_steps = 25
img_accu = 3.5
# ui data
# Page stylesheet: centred input text, Suez One font everywhere, hidden footer
# and dropdown arrows, square image containers, flex layout around .btn.
css = """
input, textarea, input::placeholder, textarea::placeholder {
text-align: center !important;
}
*, *::placeholder {
font-family: Suez One !important;
}
h1,h2,h3,h4,h5,h6 {
width: 100%;
text-align: center;
}
footer {
display: none !important;
}
.image-container {
aspect-ratio: 1/1 !important;
}
.dropdown-arrow {
display: none !important;
}
*:has(>.btn) {
display: flex;
justify-content: space-evenly;
align-items: center;
}
.btn {
display: flex;
}
"""
# Client-side input filter injected into the Gradio page. For the text inputs
# inside #prompt and #prompt2 it remembers the value on every keydown and
# reverts an "input" event when the new value contains anything other than
# spaces, ASCII letters and commas, or a run of two or more spaces/commas.
# NOTE(review): if an invalid value arrives without a preceding keydown (e.g. a
# mouse paste), getAttribute("last_value") is null and the field is assigned
# from null; if either element is missing, querySelector returns null and
# addEventListener throws. Confirm both are acceptable.
js="""
function custom(){
document.querySelector("div#prompt input").addEventListener("keydown",function(e){
e.target.setAttribute("last_value",e.target.value);
});
document.querySelector("div#prompt input").addEventListener("input",function(e){
if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
e.target.value = e.target.getAttribute("last_value");
e.target.removeAttribute("last_value");
}
});
document.querySelector("div#prompt2 input").addEventListener("keydown",function(e){
e.target.setAttribute("last_value",e.target.value);
});
document.querySelector("div#prompt2 input").addEventListener("input",function(e){
if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
e.target.value = e.target.getAttribute("last_value");
e.target.removeAttribute("last_value");
}
});
}
"""
# torch pipes
# Tiny autoencoder (taef1) used as the pipeline's VAE (presumably chosen for
# fast, low-memory decoding — TODO confirm).
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
# Full autoencoder for final decodes — currently disabled.
#good_vae = AutoencoderKL.from_pretrained("ostris/Flex.1-alpha", subfolder="vae", torch_dtype=dtype).to(device)
# Flex.1-alpha diffusion pipeline, loaded at import time with taef1 as its VAE.
image_pipe = DiffusionPipeline.from_pretrained("ostris/Flex.1-alpha", torch_dtype=dtype, vae=taef1).to(device)
#image_pipe.enable_model_cpu_offload()
torch.cuda.empty_cache()
# NOTE: while the next line stays commented out, the streaming generator
# flux_pipe_call_that_returns_an_iterable_of_images is not bound to image_pipe.
#image_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(image_pipe)
# functionality
def upscaler(
    input_image: Image.Image,
    prompt: str = "Hyper realistic photography, Natural visual content.",
    negative_prompt: str = "Distorted, Discontinuous, Blurry, Doll-Like, Overly-Plastic, Low-Quality, Painted, Smoothed, Artificial, Phony, Gaudy, Digital Effects.",
    seed: Optional[int] = None,
    upscale_factor: int = 2,
    controlnet_scale: float = 0.6,
    controlnet_decay: float = 1.0,
    condition_scale: int = 6,
    tile_width: int = 112,
    tile_height: int = 144,
    denoise_strength: float = 0.35,
    num_inference_steps: int = 15,
    solver: str = "DDIM",
) -> Image.Image:
    """Enhance ``input_image`` with the ESRGAN + Stable Diffusion tile upscaler.

    Args:
        input_image: image to enhance.
        prompt, negative_prompt: text conditioning for the SD refinement.
        seed: RNG seed. ``None`` (the default) draws a fresh seed on every
            call. The old default was evaluated once at import time, so every
            call reused the same seed.
        solver: name of a solver class in refiners' ``solvers`` module.
        Remaining arguments are passed to ``enhancer.upscale``.

    Returns:
        The enhanced image. Its size is also stored in the module globals
        ``_WIDTH_``/``_HEIGHT_``. If the module-level ``working`` flag is
        already set, the call is skipped and ``input_image`` is returned
        unchanged, the same way add_song_cover_text skips when busy.
    """
    # Without this declaration every assignment below made these names local,
    # so the first read raised UnboundLocalError.
    global working, _HEIGHT_, _WIDTH_
    log(f'CALL upscaler')
    if working:
        log(f'RET upscaler (busy, input returned unchanged)')
        return input_image
    working = True
    try:
        if seed is None:
            seed = random.randint(0, MAX_SEED)
        manual_seed(seed)
        solver_type: type[Solver] = getattr(solvers, solver)
        log(f'DBG upscaler 1')
        enhanced_image = enhancer.upscale(
            image=input_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            upscale_factor=upscale_factor,
            controlnet_scale=controlnet_scale,
            controlnet_scale_decay=controlnet_decay,
            condition_scale=condition_scale,
            tile_size=(tile_height, tile_width),
            denoise_strength=denoise_strength,
            num_inference_steps=num_inference_steps,
            loras_scale={"more_details": 0.5, "sdxl_render": 1.0},
            solver_type=solver_type,
        )
        # Record the real output size rather than multiplying the previous
        # values, which start as None and may not match this input.
        _WIDTH_, _HEIGHT_ = enhanced_image.size
    finally:
        # Always release the busy flag, even if upscaling raised.
        working = False
    log(f'RET upscaler')
    return enhanced_image
def get_tensor_length(tensor):
    """Return the total number of elements of ``tensor`` (product of its sizes; 1 for a scalar)."""
    return math.prod(tensor.size())
def _summarize(text):
    """Summarise ``text`` once with the module-level T5 ``model``/``tokenizer``.

    Uses the "summarize: " task prefix with beam search (8 beams) and a
    512-token output cap. The input is not truncated.
    """
    log(f'CALL _summarize')
    input_ids = tokenizer.encode("summarize: " + text, return_tensors="pt", truncation=False)
    output_ids = model.generate(
        input_ids,
        length_penalty=0.9,
        num_beams=8,
        early_stopping=True,
        max_length=512,
    )
    summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    log(f'RET _summarize with ret as {summary}')
    return summary
def summarize(text, max_len=400, chunk_words=510):
    """Shrink ``text`` with repeated T5 summarisation passes.

    Phase 1: while the text has at least ``chunk_words`` words, summarise the
    first ``chunk_words`` words and keep the remaining words verbatim. The
    summary and the remainder are joined with a space; previously they were
    concatenated without one.
    Phase 2: while the text is longer than ``max_len`` characters, summarise
    the whole text.

    Each phase stops early and returns the current text when a pass makes no
    progress (the word count or character length does not decrease). This
    replaces the previous "identical output" check, which could loop forever
    on a summary that changed but never shrank.

    Args:
        text: input text.
        max_len: target maximum length in characters.
        chunk_words: words summarised per phase-1 pass (default 510, as before).

    Returns:
        The (possibly) summarised text.
    """
    log(f'CALL summarize')
    words = text.split()
    while len(words) >= chunk_words:
        head = _summarize(" ".join(words[:chunk_words]))
        new_words = head.split() + words[chunk_words:]
        if len(new_words) >= len(words):
            log(f'RET summarize with text as {text}')
            return text
        words = new_words
        text = " ".join(words)
    while len(text) > max_len:
        summ = _summarize(text)
        if len(summ) >= len(text):
            break
        text = summ
    log(f'RET summarize with text as {text}')
    return text
def generate_random_string(length):
    """Return ``length`` characters drawn with replacement from ASCII letters and digits.

    Uses the non-cryptographic ``random`` module.
    """
    alphabet = ascii_letters + digits
    picks = [random.choice(alphabet) for _ in range(length)]
    return "".join(picks)
def _draw_cover_label(draw, text, w, h, font, textheight, y_shift, fill, stroke_width, stroke_fill):
    """Draw ``text`` horizontally centred on a ``w`` x ``h`` canvas, vertically
    centred and then moved by ``y_shift`` pixels (negative moves it up)."""
    lines = text.split("\n")
    # ImageDraw.textlength rejects multi-line strings, so measure the widest line.
    textwidth = max(draw.textlength(line, font) for line in lines)
    x = math.ceil((w - textwidth) / 2)
    y = h - (textheight * len(lines) / 2) - (h / 2)
    y = math.ceil(y + y_shift)
    draw.text((x, y), text, fill, font=font, spacing=2, stroke_width=stroke_width, stroke_fill=stroke_fill)


def add_song_cover_text(img, top_title=None, bottom_title=None):
    """Draw an optional title above and an optional subtitle below the image centre.

    The top title is white with a black outline and the bottom title is black
    with a white outline. Both use Alef-Bold at a height of
    ``min(ceil(w/10), ceil(h/5))`` and sit a sixth of the height from the centre.
    Text is laid out using ``img``'s own size. The module globals
    ``_WIDTH_``/``_HEIGHT_`` start as None and may be stale.

    Args:
        img: PIL image, drawn on in place.
        top_title: text for the upper label (may contain newlines), or None.
        bottom_title: text for the lower label (may contain newlines), or None.

    Returns:
        ``img``. The call is skipped (``img`` is returned untouched) when there
        is no title or the module-level ``working`` flag is already set.
    """
    # Without this declaration the assignment below made `working` local and
    # the first read raised UnboundLocalError.
    global working
    if working or not (top_title or bottom_title):
        return img
    working = True
    try:
        w, h = img.size
        draw = ImageDraw.Draw(img, mode="RGBA")
        labels_distance = 1 / 3
        textheight = min(math.ceil(w / 10), math.ceil(h / 5))
        font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
        stroke_width = math.ceil(textheight / 20)
        offset = h / 2 * labels_distance
        if top_title:
            _draw_cover_label(draw, top_title, w, h, font, textheight, -offset, (255, 255, 255), stroke_width, (0, 0, 0))
        if bottom_title:
            _draw_cover_label(draw, bottom_title, w, h, font, textheight, offset, (0, 0, 0), stroke_width, (255, 255, 255))
    finally:
        # Always release the busy flag, even if font loading or drawing raised.
        working = False
    return img
# Google Translate mobile endpoint. It is not referenced in the visible part of
# this file; presumably a translator subclass further down uses it.
google_translate_endpoint = "https://translate.google.com/m"
# Lower-case language name -> language code. BaseTranslator accepts either
# form: _map_language_to_code passes codes through and maps names to codes.
language_codes = {
    "afrikaans": "af",
    "albanian": "sq",
    "amharic": "am",
    "arabic": "ar",
    "armenian": "hy",
    "assamese": "as",
    "aymara": "ay",
    "azerbaijani": "az",
    "bambara": "bm",
    "basque": "eu",
    "belarusian": "be",
    "bengali": "bn",
    "bhojpuri": "bho",
    "bosnian": "bs",
    "bulgarian": "bg",
    "catalan": "ca",
    "cebuano": "ceb",
    "chichewa": "ny",
    "chinese (simplified)": "zh-CN",
    "chinese (traditional)": "zh-TW",
    "corsican": "co",
    "croatian": "hr",
    "czech": "cs",
    "danish": "da",
    "dhivehi": "dv",
    "dogri": "doi",
    "dutch": "nl",
    "english": "en",
    "esperanto": "eo",
    "estonian": "et",
    "ewe": "ee",
    "filipino": "tl",
    "finnish": "fi",
    "french": "fr",
    "frisian": "fy",
    "galician": "gl",
    "georgian": "ka",
    "german": "de",
    "greek": "el",
    "guarani": "gn",
    "gujarati": "gu",
    "haitian creole": "ht",
    "hausa": "ha",
    "hawaiian": "haw",
    "hebrew": "iw",
    "hindi": "hi",
    "hmong": "hmn",
    "hungarian": "hu",
    "icelandic": "is",
    "igbo": "ig",
    "ilocano": "ilo",
    "indonesian": "id",
    "irish": "ga",
    "italian": "it",
    "japanese": "ja",
    "javanese": "jw",
    "kannada": "kn",
    "kazakh": "kk",
    "khmer": "km",
    "kinyarwanda": "rw",
    "konkani": "gom",
    "korean": "ko",
    "krio": "kri",
    "kurdish (kurmanji)": "ku",
    "kurdish (sorani)": "ckb",
    "kyrgyz": "ky",
    "lao": "lo",
    "latin": "la",
    "latvian": "lv",
    "lingala": "ln",
    "lithuanian": "lt",
    "luganda": "lg",
    "luxembourgish": "lb",
    "macedonian": "mk",
    "maithili": "mai",
    "malagasy": "mg",
    "malay": "ms",
    "malayalam": "ml",
    "maltese": "mt",
    "maori": "mi",
    "marathi": "mr",
    "meiteilon (manipuri)": "mni-Mtei",
    "mizo": "lus",
    "mongolian": "mn",
    "myanmar": "my",
    "nepali": "ne",
    "norwegian": "no",
    "odia (oriya)": "or",
    "oromo": "om",
    "pashto": "ps",
    "persian": "fa",
    "polish": "pl",
    "portuguese": "pt",
    "punjabi": "pa",
    "quechua": "qu",
    "romanian": "ro",
    "russian": "ru",
    "samoan": "sm",
    "sanskrit": "sa",
    "scots gaelic": "gd",
    "sepedi": "nso",
    "serbian": "sr",
    "sesotho": "st",
    "shona": "sn",
    "sindhi": "sd",
    "sinhala": "si",
    "slovak": "sk",
    "slovenian": "sl",
    "somali": "so",
    "spanish": "es",
    "sundanese": "su",
    "swahili": "sw",
    "swedish": "sv",
    "tajik": "tg",
    "tamil": "ta",
    "tatar": "tt",
    "telugu": "te",
    "thai": "th",
    "tigrinya": "ti",
    "tsonga": "ts",
    "turkish": "tr",
    "turkmen": "tk",
    "twi": "ak",
    "ukrainian": "uk",
    "urdu": "ur",
    "uyghur": "ug",
    "uzbek": "uz",
    "vietnamese": "vi",
    "welsh": "cy",
    "xhosa": "xh",
    "yiddish": "yi",
    "yoruba": "yo",
    "zulu": "zu",
}
class BaseError(Exception):
    """
    base error structure class: stores the offending value and a message,
    rendered as "<val> --> <message>"
    """

    def __init__(self, val, message):
        """
        @param val: actual value
        @param message: message shown to the user
        """
        self.val = val
        self.message = message
        super().__init__()

    def __str__(self):
        return "{} --> {}".format(self.val, self.message)


class LanguageNotSupportedException(BaseError):
    """
    exception thrown if the user uses a language
    that is not supported by the deep_translator
    """

    def __init__(
        self, val, message="There is no support for the chosen language"
    ):
        super().__init__(val, message)


class NotValidPayload(BaseError):
    """
    exception thrown if the user enters an invalid payload
    """

    def __init__(
        self,
        val,
        # The concatenated literals previously produced "character,otherwise".
        message="text must be a valid text with maximum 5000 characters, "
        "otherwise it cannot be translated",
    ):
        super(NotValidPayload, self).__init__(val, message)


class InvalidSourceOrTargetLanguage(BaseError):
    """
    exception thrown if the source or target language is missing or invalid
    """

    def __init__(self, val, message="Invalid source or target language!"):
        super(InvalidSourceOrTargetLanguage, self).__init__(val, message)


class TranslationNotFound(BaseError):
    """
    exception thrown if no translation was found for the text provided by the user
    """

    def __init__(
        self,
        val,
        message="No translation was found using the current translator. Try another translator?",
    ):
        super(TranslationNotFound, self).__init__(val, message)


class ElementNotFoundInGetRequest(BaseError):
    """
    exception thrown if the html element was not found in the body parsed by beautifulsoup
    """

    def __init__(
        self, val, message="Required element was not found in the API response"
    ):
        super(ElementNotFoundInGetRequest, self).__init__(val, message)


class NotValidLength(BaseError):
    """
    exception thrown if the provided text exceeds the length limits of the translator
    """

    def __init__(self, val, min_chars, max_chars):
        message = f"Text length needs to be between {min_chars} and {max_chars} characters"
        super(NotValidLength, self).__init__(val, message)


class RequestError(Exception):
    """
    exception thrown if an error occurred during the request call, e.g a connection problem.
    """

    def __init__(
        self,
        message="Request exception can happen due to an api connection error. "
        "Please check your connection and try again",
    ):
        self.message = message
        # Populate args so the exception pickles/reprs with its message.
        super().__init__(message)

    def __str__(self):
        return self.message


class TooManyRequests(Exception):
    """
    exception thrown if the server rejects the request because too many
    requests were made (rate limiting).
    """

    def __init__(
        self,
        # Trailing spaces added: the fragments previously ran together
        # ("server.According", "secondand", "oryou").
        message="Server Error: You made too many requests to the server. "
        "According to google, you are allowed to make 5 requests per second "
        "and up to 200k requests per day. You can wait and try again later or "
        "you can try the translate_batch function",
    ):
        self.message = message
        super().__init__(message)

    def __str__(self):
        return self.message


class ServerException(Exception):
    """
    API server error: maps an HTTP status code to an error-code name
    (codes as listed for YandexTranslate on its official website), falling
    back to "API server error" for unknown codes.
    """

    errors = {
        400: "ERR_BAD_REQUEST",
        401: "ERR_KEY_INVALID",
        402: "ERR_KEY_BLOCKED",
        403: "ERR_DAILY_REQ_LIMIT_EXCEEDED",
        404: "ERR_DAILY_CHAR_LIMIT_EXCEEDED",
        413: "ERR_TEXT_TOO_LONG",
        429: "ERR_TOO_MANY_REQUESTS",
        422: "ERR_UNPROCESSABLE_TEXT",
        500: "ERR_INTERNAL_SERVER_ERROR",
        501: "ERR_LANG_NOT_SUPPORTED",
        503: "ERR_SERVICE_NOT_AVAIBLE",
    }

    def __init__(self, status_code, *args):
        message = self.errors.get(status_code, "API server error")
        super(ServerException, self).__init__(message, *args)
def is_empty(text: str) -> bool:
    """Return True only when ``text`` equals the empty string."""
    empty = text == ""
    return empty
def request_failed(status_code: int) -> bool:
    """Tell whether an HTTP request failed.

    A request counts as successful only when its status code is in the
    2xx range (200-299 inclusive).

    Args:
        status_code (int): status code of the request
    Returns:
        bool: True when the request failed
    """
    succeeded = 200 <= status_code <= 299
    return not succeeded
def is_input_valid(
    text: str, min_chars: int = 0, max_chars: Optional[int] = None
) -> bool:
    """
    validate the target text to translate
    @param min_chars: min characters (inclusive); always enforced
    @param max_chars: max characters (exclusive upper bound); None or 0 disables it
    @param text: text to translate
    @raise NotValidPayload: if text is not a str
    @raise NotValidLength: if the length is outside the allowed range
    @return: bool (True when valid)
    """
    if not isinstance(text, str):
        raise NotValidPayload(text)
    length = len(text)
    # min_chars used to be checked only when max_chars was also given.
    too_short = length < min_chars
    too_long = bool(max_chars) and length >= max_chars
    if too_short or too_long:
        raise NotValidLength(text, min_chars, max_chars)
    return True
class BaseTranslator(ABC):
    """
    Abstract class that serves as a base translator for other different translators.

    Subclasses must implement ``translate``; the file and batch helpers below
    are built on top of it.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        languages: Optional[dict] = None,
        source: str = "auto",
        target: str = "en",
        payload_key: Optional[str] = None,
        element_tag: Optional[str] = None,
        element_query: Optional[dict] = None,
        **url_params,
    ):
        """
        @param base_url: endpoint stored as ``_base_url`` for subclasses to query
        @param languages: mapping of language name -> language code; when omitted
            the module-level ``language_codes`` table is used
        @param source: source language to translate from (name, code or "auto")
        @param target: target language to translate to (name or code)
        @param payload_key: name of the URL parameter that carries the text
        @param element_tag: HTML tag searched for in the response
        @param element_query: attribute filter used together with element_tag
        @param url_params: extra query parameters kept in ``_url_params``
        @raise InvalidSourceOrTargetLanguage: if source or target is empty
        @raise LanguageNotSupportedException: if source or target is unknown
        """
        if languages is None:
            # Resolved at call time instead of binding a shared mutable dict as
            # a default argument at class-definition time.
            languages = language_codes
        self._base_url = base_url
        self._languages = languages
        self._supported_languages = list(self._languages.keys())
        if not source:
            raise InvalidSourceOrTargetLanguage(source)
        if not target:
            raise InvalidSourceOrTargetLanguage(target)
        self._source, self._target = self._map_language_to_code(source, target)
        self._url_params = url_params
        self._element_tag = element_tag
        self._element_query = element_query
        self.payload_key = payload_key
        super().__init__()

    @property
    def source(self):
        return self._source

    @source.setter
    def source(self, lang):
        self._source = lang

    @property
    def target(self):
        return self._target

    @target.setter
    def target(self, lang):
        self._target = lang

    def _type(self):
        return self.__class__.__name__

    def _map_language_to_code(self, *languages):
        """
        map language to its corresponding code (abbreviation) if the language was passed
        by its full name by the user
        @param languages: list of languages
        @return: generator yielding the code of each language, in order
        @raise LanguageNotSupportedException: for a language that is neither a
            known name, a known code, nor "auto"
        """
        for language in languages:
            if language in self._languages.values() or language == "auto":
                yield language
            elif language in self._languages.keys():
                yield self._languages[language]
            else:
                raise LanguageNotSupportedException(
                    language,
                    message=f"No support for the provided language.\n"
                    f"Please select on of the supported languages:\n"
                    f"{self._languages}",
                )

    def _same_source_target(self) -> bool:
        return self._source == self._target

    def get_supported_languages(
        self, as_dict: bool = False, **kwargs
    ) -> Union[list, dict]:
        """
        return the supported languages of this translator
        @param as_dict: if True, the languages will be returned as a dictionary
        mapping languages to their abbreviations
        @return: list of language names, or the name -> code dict
        """
        return self._languages if as_dict else self._supported_languages

    def is_language_supported(self, language: str, **kwargs) -> bool:
        """
        check if the language is supported by the translator
        @param language: a language name, a language code, or "auto"
        @return: bool
        """
        return (
            language == "auto"
            or language in self._languages
            or language in self._languages.values()
        )

    @abstractmethod
    def translate(self, text: str, **kwargs) -> str:
        """
        translate a text using a translator under the hood and return
        the translated text
        @param text: text to translate
        @param kwargs: additional arguments
        @return: str
        @raise NotImplementedError: always; subclasses must override this
        """
        raise NotImplementedError("You need to implement the translate method!")

    def _read_docx(self, f: str):
        import docx2txt
        return docx2txt.process(f)

    def _read_pdf(self, f: str):
        # Only the first page of the PDF is extracted.
        import pypdf
        reader = pypdf.PdfReader(f)
        page = reader.pages[0]
        return page.extract_text()

    def _translate_file(self, path: str, **kwargs) -> str:
        """
        translate directly from file
        @param path: path to the target file (.docx, .pdf, anything else is
            read as UTF-8 text); the extension match is case-insensitive
        @type path: str
        @param kwargs: additional args forwarded to ``translate``
        @return: str
        @raise FileNotFoundError: if ``path`` does not exist
        """
        path = Path(path)
        if not path.exists():
            # Raise instead of exiting so one bad path cannot terminate a
            # long-running process.
            raise FileNotFoundError(f"Path to the file is wrong: {path}")
        ext = path.suffix.lower()
        if ext == ".docx":
            text = self._read_docx(f=str(path))
        elif ext == ".pdf":
            text = self._read_pdf(f=str(path))
        else:
            text = path.read_text(encoding="utf-8").strip()
        return self.translate(text, **kwargs)

    def _translate_batch(self, batch: List[str], **kwargs) -> List[str]:
        """
        translate a list of texts
        @param batch: list of texts you want to translate
        @return: list of translations, in input order
        @raise Exception: if ``batch`` is empty
        """
        if not batch:
            raise Exception("Enter your text list that you want to translate")
        return [self.translate(text, **kwargs) for text in batch]
class GoogleTranslator(BaseTranslator):
    """
    class that wraps functions, which use Google Translate under the hood to translate text(s)
    """

    def __init__(
        self,
        source: str = "auto",
        target: str = "en",
        proxies: Optional[dict] = None,
        **kwargs
    ):
        """
        @param source: source language to translate from
        @param target: target language to translate to
        @param proxies: optional proxies mapping passed to ``requests.get``
        @param kwargs: forwarded to ``BaseTranslator``; keys it does not name
            become query parameters of every request
        """
        self.proxies = proxies
        super().__init__(
            base_url=google_translate_endpoint,
            source=source,
            target=target,
            element_tag="div",
            element_query={"class": "t0"},
            payload_key="q",  # key of text in the url
            **kwargs
        )
        # Fallback selector used when no element matches the primary query.
        self._alt_element_query = {"class": "result-container"}

    def translate(self, text: str, **kwargs) -> str:
        """
        function to translate a text
        @param text: desired text to translate (must be a str shorter than 1000 chars)
        @return: str: translated text; the stripped input when no request is
            needed or when Google returns the text unchanged
        @raise TooManyRequests: on HTTP 429
        @raise RequestError: when ``request_failed`` flags the status code
        @raise TranslationNotFound: when neither result selector matches
        """
        # Raises on non-str or too-long input; returns True otherwise.
        is_input_valid(text, max_chars=1000)
        text = text.strip()
        if self._same_source_target() or is_empty(text):
            return text
        self._url_params["tl"] = self._target
        self._url_params["sl"] = self._source
        if self.payload_key:
            self._url_params[self.payload_key] = text
        # The context manager closes the response even when we raise below.
        with requests.get(
            self._base_url, params=self._url_params, proxies=self.proxies
        ) as response:
            if response.status_code == 429:
                raise TooManyRequests()
            if request_failed(status_code=response.status_code):
                raise RequestError()
            soup = BeautifulSoup(response.text, "html.parser")
        element = soup.find(self._element_tag, self._element_query)
        if not element:
            element = soup.find(self._element_tag, self._alt_element_query)
            if not element:
                raise TranslationNotFound(text)
        translated = element.get_text(strip=True)
        if (
            translated == text
            and any(ch.isalnum() for ch in text)
            and "hl" in self._url_params
        ):
            # The result equals the input: retry once without the "hl" query
            # parameter. It is deleted first, so the recursion is bounded.
            del self._url_params["hl"]
            return self.translate(text)
        # Also covers an unchanged, punctuation-only result (previously None).
        return translated

    def translate_file(self, path: str, **kwargs) -> str:
        """
        translate directly from file
        @param path: path to the target file
        @type path: str
        @param kwargs: additional args
        @return: str
        """
        return self._translate_file(path, **kwargs)

    def translate_batch(self, batch: List[str], **kwargs) -> List[str]:
        """
        translate a list of texts
        @param batch: list of texts you want to translate
        @return: list of translations
        """
        return self._translate_batch(batch, **kwargs)
def _chunk_text_for_translation(text, limit):
    """Split *text* on whitespace into space-joined chunks of at most *limit* chars.

    Words longer than *limit* are cut into *limit*-sized pieces, so every chunk
    fits and the split always terminates.
    """
    chunks = []
    current = []
    current_len = 0
    for word in text.split():
        for start in range(0, len(word), limit):
            piece = word[start:start + limit]
            needed = current_len + 1 + len(piece) if current else len(piece)
            if needed > limit:
                chunks.append(" ".join(current))
                current, current_len = [piece], len(piece)
            else:
                current.append(piece)
                current_len = needed
    if current:
        chunks.append(" ".join(current))
    return chunks


def translate(txt, to_lang="en", from_lang="auto"):
    """Translate *txt* with Google Translate; return it stripped and lower-cased.

    Text that is empty, has ``from_lang == to_lang``, or is detected as already
    being in ``to_lang`` is returned untranslated. Text too long for a single
    request is translated in whitespace-delimited chunks joined with spaces.

    @param txt: text to translate
    @param to_lang: target language (name or code accepted by GoogleTranslator)
    @param from_lang: source language, or "auto" to let Google detect it
    @return: translated text, stripped and lower-cased
    """
    from langdetect.lang_detect_exception import LangDetectException

    log(f'CALL translate')
    txt = txt.strip()
    if not txt:
        print("Translated text is empty. Skipping translation...")
        return txt.lower()
    if from_lang == to_lang:
        print("Same languages. Skipping translation...")
        return txt.lower()
    try:
        detected = get_language(txt)
    except LangDetectException:
        # langdetect cannot classify some inputs (e.g. digits/punctuation only);
        # let the translator handle them instead of failing the request.
        detected = None
    if detected == to_lang:
        print("Same languages. Skipping translation...")
        return txt.lower()
    # GoogleTranslator's language parameters are `source`/`target`; any other
    # keyword is just forwarded as a URL query parameter.
    translator = GoogleTranslator(source=from_lang, target=to_lang)
    # GoogleTranslator.translate rejects input of 1000 or more characters.
    limit = 999
    if len(txt) <= limit:
        translation = translator.translate(txt)
    else:
        translation = " ".join(
            translator.translate(chunk)
            for chunk in _chunk_text_for_translation(txt, limit)
        )
    translation = translation.strip()
    log(f'RET translate with translation as {translation}')
    return translation.lower()
def _normalize_description(d):
    """Rewrite raw prompt text as period-terminated sentences on a single line.

    A comma followed by spaces becomes ". ", each non-empty line is stripped
    and given a trailing period if it lacks one, and the lines are joined with
    single spaces. Empty or whitespace-only input yields "".
    """
    d = re.sub(r",( ){1,}", ". ", d)
    # No capturing group: re.split would otherwise return the "\n" separators
    # as list items, and each would turn into a stray "." sentence.
    lines = [line.strip() for line in re.split(r"\n+", d)]
    return " ".join(line if line.endswith(".") else line + "." for line in lines if line)


@spaces.GPU(duration=300)
def handle_generation(h, w, d):
    """Generate one photorealistic image from the user's description.

    The description is normalized into sentences, then translated
    (``translate``) and summarized (``summarize``), and embedded in a fixed
    positive prompt. A fixed negative prompt is always used.

    @param h: image height in pixels
    @param w: image width in pixels
    @param d: free-form description text
    @return: the value produced by ``image_pipe``. NOTE(review): the UI feeds
        it to a ``gr.Image(type='pil')``; confirm ``image_pipe`` returns a PIL
        image rather than a pipeline-output wrapper.
    @raise gr.Error: if another generation is already running
    """
    # Without `global`, the assignments below make these names local, and
    # reading `working` first raises UnboundLocalError.
    # NOTE(review): under `spaces.GPU` this may execute in a separate worker
    # process, in which case these globals may not be shared with the app.
    global working, _HEIGHT_, _WIDTH_
    log(f'CALL handle_generate')
    if working:
        raise gr.Error("A generation is already running. Please try again shortly.")
    working = True
    try:
        pos_d = re.sub(r"([ \t]){1,}", " ", _normalize_description(d)).lower().strip()
        pos_d = pos_d if pos_d == "" else summarize(translate(pos_d))
        pos_d = re.sub(r"([ \t]){1,}", " ", pos_d).lower().strip()
        neg = "Textual, Text, Distorted, Fake, Discontinuous, Blurry, Doll-Like, Overly Plastic, Low Quality, Paint, Smoothed, Artificial, Phony, Gaudy, Digital Effects."
        pos = f'HQ Hyper-realistic professional photograph{ pos_d if pos_d == "" else ": " + pos_d }.'
        print(f"\nPositive: {pos}\nNegative: {neg}\n")
        img = image_pipe(
            prompt=pos,
            negative_prompt=neg,
            height=h,
            width=w,
            output_type="pil",
            guidance_scale=img_accu,
            num_images_per_prompt=1,
            num_inference_steps=image_steps,
            max_sequence_length=seq,
            generator=torch.Generator(device).manual_seed(random.randint(0, MAX_SEED))
        )
        _HEIGHT_ = h
        _WIDTH_ = w
        return img
    finally:
        # Always release the busy flag, even when the pipeline raises.
        working = False
# entry
# Build the Gradio interface and start the server only when run as a script.
if __name__ == "__main__":
# NOTE(review): `css`, `upscaler` and `add_song_cover_text` are referenced below but
# defined outside this section; confirm they exist at module level.
with gr.Blocks(theme=gr.themes.Citrus(),css=css) as demo:
gr.Markdown(f"""
# Text-to-Image generator
""")
gr.Markdown(f"""
### Realistic. Upscalable. Multilingual.
""")
with gr.Row():
with gr.Column(scale=3):
# Requested output size in pixels (512-2048, multiples of 16); passed to
# handle_generation as h and w.
height = gr.Slider(
label="Height (px)",
minimum=512,
maximum=2048,
step=16,
value=1024,
)
width = gr.Slider(
label="Width (px)",
minimum=512,
maximum=2048,
step=16,
value=1024,
)
# Generate button; its click event is wired to handle_generation below.
run = gr.Button("Generate",elem_classes="btn")
# Free-form description text, passed to handle_generation as d.
data = gr.Textbox(
placeholder="Input data",
value="",
container=False,
max_lines=100
)
with gr.Column():
# Single PIL image slot shared by generation, upscaling and titling: each
# handler wired below writes its result back into it.
cover = gr.Image(interactive=False,container=False,elem_classes="image-container", label="Result", show_label=True, type='pil', show_share_button=False)
with gr.Column():
# Replaces the current image with the output of `upscaler`.
upscale_now = gr.Button("Upscale",elem_classes="btn")
with gr.Column():
# Optional title texts, passed together with the current image to
# add_song_cover_text.
top = gr.Textbox(
placeholder="Top title",
value="",
container=False,
max_lines=1
)
bottom = gr.Textbox(
placeholder="Bottom title",
value="",
container=False,
max_lines=1
)
add_titles = gr.Button("Add title(s)",elem_classes="btn")
# Event wiring: every action writes its result to the `cover` image.
gr.on(
triggers=[run.click],
fn=handle_generation,
inputs=[height,width,data],
outputs=[cover]
)
upscale_now.click(
fn=upscaler,
inputs=[cover],
outputs=[cover]
)
add_titles.click(
fn=add_song_cover_text,
inputs=[cover,top,bottom],
outputs=[cover]
)
# Enable request queuing, then launch the app.
demo.queue().launch()