Kokoro-API-2 / app.py
yaron123's picture
commit
5c223cd
raw
history blame
22.3 kB
# built-in
from inspect import signature
import os
import subprocess
import logging
import re
import random
from string import ascii_letters, digits, punctuation
import sys
import warnings
import time
import asyncio
import math
from pathlib import Path
from functools import partial
from dataclasses import dataclass
from typing import Any
from typing import NamedTuple

# third-party
import requests
import pillow_heif
import spaces
import numpy as np
import numpy.typing as npt
import torch
from torch import nn
import gradio as gr
from lxml.html import fromstring
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file, save_file
from diffusers import FluxPipeline
from PIL import Image, ImageDraw, ImageFont
from transformers import PegasusForConditionalGeneration, PegasusTokenizerFast
from refiners.fluxion.utils import manual_seed
from refiners.foundationals.latent_diffusion import Solver, solvers
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
    MultiUpscaler,
    UpscalerCheckpoints,
)
# One tile: (offset along its axis, nominal size along that axis, cropped image).
Tile = tuple[int, int, Image.Image]
# All tiles grouped by row: (row y offset, nominal row height, tiles of that row).
Tiles = list[tuple[int, int, list[Tile]]]
def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
    """Build a 3x3, padding-1 convolution followed by an in-place LeakyReLU.

    Args:
        in_nc: Number of input channels of the convolution.
        out_nc: Number of output channels of the convolution.

    Returns:
        An ``nn.Sequential`` holding the ``Conv2d`` at index 0 and the
        ``LeakyReLU`` (negative slope 0.2) at index 1.
    """
    convolution = nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1)
    activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
    return nn.Sequential(convolution, activation)
class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block with five convolutions.

    Building block from "Residual Dense Network for Image Super-Resolution" (CVPR 18).
    Related variants mentioned by the original authors:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
          {Rakotonirina} and A. {Rasoanaivo}

    Args:
        nf: Channel count of the block's input and output.
        gc: Channel count produced by each of the first four convolutions.
    """

    def __init__(self, nf: int = 64, gc: int = 32) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        # Each convolution sees the block input plus every earlier output,
        # so its input width grows by `gc` per stage.
        self.conv1 = conv_block(nf, gc)
        self.conv2 = conv_block(nf + gc, gc)
        self.conv3 = conv_block(nf + 2 * gc, gc)
        self.conv4 = conv_block(nf + 3 * gc, gc)
        # Wrapped in Sequential because of key in state dict.
        self.conv5 = nn.Sequential(nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the dense stack and add its output, scaled by 0.2, to the input."""
        features = [x, self.conv1(x)]
        for stage in (self.conv2, self.conv3, self.conv4):
            features.append(stage(torch.cat(features, 1)))
        fused = self.conv5(torch.cat(features, 1))
        return fused * 0.2 + x
class RRDB(nn.Module):
    """
    Residual in Residual Dense Block: three dense blocks in series with a skip.
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    """

    def __init__(self, nf: int) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.RDB1 = ResidualDenseBlock_5C(nf)
        self.RDB2 = ResidualDenseBlock_5C(nf)
        self.RDB3 = ResidualDenseBlock_5C(nf)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Chain the three dense blocks and add the result, scaled by 0.2, to `x`."""
        hidden = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            hidden = dense_block(hidden)
        return hidden * 0.2 + x
class Upsample2x(nn.Module):
    """Parameter-free layer that doubles the spatial size of its input."""

    def __init__(self) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `x` resized by a factor of 2.0 with `interpolate`'s default mode."""
        doubled = nn.functional.interpolate(x, scale_factor=2.0)  # type: ignore
        return doubled
class ShortcutBlock(nn.Module):
    """Skip connection: returns the input plus the wrapped submodule's output."""

    def __init__(self, submodule: nn.Module) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        # Attribute name is part of the state-dict keys ("sub.*").
        self.sub = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute `x + self.sub(x)`."""
        branch = self.sub(x)
        return x + branch
class RRDBNet(nn.Module):
    """RRDB-based network with two 2x upsampling stages.

    Args:
        in_nc: Number of input channels (must not be a multiple of 4).
        out_nc: Number of output channels.
        nf: Number of feature channels used throughout the network.
        nb: Number of `RRDB` blocks in the trunk.
    """

    def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        assert in_nc % 4 != 0  # in_nc is 3

        def conv(channels_in: int, channels_out: int) -> nn.Conv2d:
            return nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1)

        def lrelu() -> nn.LeakyReLU:
            return nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # The position of each layer in this list becomes its index in the
        # state-dict keys ("model.<index>..."), so the order must not change.
        head = conv(in_nc, nf)
        trunk_blocks = [RRDB(nf) for _ in range(nb)]
        trunk = ShortcutBlock(nn.Sequential(*trunk_blocks, conv(nf, nf)))
        layers = [
            head,
            trunk,
            Upsample2x(),
            conv(nf, nf),
            lrelu(),
            Upsample2x(),
            conv(nf, nf),
            lrelu(),
            conv(nf, nf),
            lrelu(),
            conv(nf, out_nc),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the full layer stack to `x`."""
        return self.model(x)
def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
    """Derive the network hyper-parameters from a checkpoint's keys and shapes.

    Args:
        state_dict: Mapping from parameter name to tensor. Keys are split on
            "."; three-part keys must carry an integer as their second part.

    Returns:
        ``(in_nc, out_nc, nf, nb, scale)`` where ``scale`` is 2 raised to the
        number of ``model.<i>.weight`` keys with ``i > 6``.

    Raises:
        AssertionError: If a key contains "conv1x1", or if no output channel
            count or block count could be determined.
        KeyError: If "model.0.weight" is missing.
    """
    # this code is adapted from https://github.com/victorca25/iNNfer
    min_upscale_index = 6
    upscale_count = 0
    last_layer_index = 0
    out_nc = 0
    nb = 0

    for key, tensor in state_dict.items():
        pieces = key.split(".")
        if len(pieces) == 5 and pieces[2] == "sub":
            # e.g. "model.1.sub.<nb>.weight": the index of the trunk's last conv.
            nb = int(pieces[3])
        elif len(pieces) == 3:
            layer_index = int(pieces[1])
            if layer_index > min_upscale_index and pieces[0] == "model" and pieces[2] == "weight":
                upscale_count += 1
            if layer_index > last_layer_index:
                # The highest-indexed layer seen so far defines the output width.
                last_layer_index = layer_index
                out_nc = tensor.shape[0]
        assert "conv1x1" not in key  # no ESRGANPlus

    first_weight = state_dict["model.0.weight"]
    nf, in_nc = first_weight.shape[0], first_weight.shape[1]

    assert out_nc > 0
    assert nb > 0
    return in_nc, out_nc, nf, nb, 2**upscale_count  # 3, 3, 64, 23, 4
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
class Grid(NamedTuple):
    """Tiling of an image, as built by `split_grid` and merged by `combine_grid`."""

    tiles: Tiles  # rows of tiles, each row tagged with its y offset and height
    tile_w: int  # nominal tile width
    tile_h: int  # nominal tile height
    image_w: int  # width of the full image
    image_h: int  # height of the full image
    overlap: int  # width of the band that `combine_grid` blends between neighbouring tiles
# adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
    """Cut `image` into a grid of overlapping tiles.

    Args:
        image: Image to split.
        tile_w: Nominal tile width.
        tile_h: Nominal tile height.
        overlap: Overlap used to derive the number of rows and columns.

    Returns:
        A `Grid` whose rows hold ``(x, tile_w, crop)`` entries and are tagged
        with ``(y, tile_h)``. Crops are bounded by the image size, so they can
        be smaller than the nominal tile size recorded next to them.
    """
    width, height = image.width, image.height

    stride_x = tile_w - overlap
    stride_y = tile_h - overlap
    n_cols = max(1, math.ceil((width - overlap) / stride_x))
    n_rows = max(1, math.ceil((height - overlap) / stride_y))

    # Spread the tiles evenly so the last one ends at the image border.
    step_x = (width - tile_w) / (n_cols - 1) if n_cols > 1 else 0
    step_y = (height - tile_h) / (n_rows - 1) if n_rows > 1 else 0

    def clamp_start(position: float, limit: int) -> int:
        return max(min(int(position), limit), 0)

    grid = Grid([], tile_w, tile_h, width, height, overlap)
    for row_index in range(n_rows):
        top = clamp_start(row_index * step_y, height - tile_h)
        bottom = min(top + tile_h, height)
        row_tiles: list[Tile] = []
        for col_index in range(n_cols):
            left = clamp_start(col_index * step_x, width - tile_w)
            right = min(left + tile_w, width)
            row_tiles.append((left, tile_w, image.crop((left, top, right, bottom))))
        grid.tiles.append((top, tile_h, row_tiles))

    return grid
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
def combine_grid(grid: Grid):
    """Reassemble the tiles of `grid` into one RGB image.

    Tiles are first merged left-to-right into row strips, then the strips are
    merged top-to-bottom. Wherever a tile or strip does not start at offset 0,
    its first `grid.overlap` pixels are pasted through a linear ramp mask and
    the remainder is pasted as-is.

    Returns:
        A new ``"RGB"`` image of size ``(grid.image_w, grid.image_h)``.
    """
    overlap = grid.overlap

    def make_mask_image(ramp: npt.NDArray[np.float32]) -> Image.Image:
        scaled = ramp * 255 / overlap
        return Image.fromarray(scaled.astype(np.uint8), "L")

    ramp = np.arange(overlap, dtype=np.float32)
    # Horizontal ramp (one tile high) and vertical ramp (one image wide).
    mask_w = make_mask_image(ramp.reshape((1, overlap)).repeat(grid.tile_h, axis=0))
    mask_h = make_mask_image(ramp.reshape((overlap, 1)).repeat(grid.image_w, axis=1))

    canvas = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        strip = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                strip.paste(tile, (0, 0))
            else:
                blended_part = tile.crop((0, 0, overlap, h))
                plain_part = tile.crop((overlap, 0, w, h))
                strip.paste(blended_part, (x, 0), mask=mask_w)
                strip.paste(plain_part, (x + overlap, 0))

        if y == 0:
            canvas.paste(strip, (0, 0))
        else:
            canvas.paste(
                strip.crop((0, 0, strip.width, overlap)),
                (0, y),
                mask=mask_h,
            )
            canvas.paste(
                strip.crop((0, overlap, strip.width, h)),
                (0, y + overlap),
            )

    return canvas
class UpscalerESRGAN:
    """ESRGAN upscaler backed by an `RRDBNet` loaded from a checkpoint file.

    Args:
        model_path: Path of the checkpoint to load.
        device: Device used both to map the checkpoint and to run the model.
        dtype: Floating-point type the model is cast to.
    """

    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        self.model_path = model_path
        # `load_model` reads `self.device`, so it must be set before the call.
        self.device = device
        self.model = self.load_model(model_path)
        self.to(device, dtype)

    def __call__(self, img: Image.Image) -> Image.Image:
        """Upscale `img` in a single pass (no tiling)."""
        return self.upscale_without_tiling(img)

    def to(self, device: torch.device, dtype: torch.dtype):
        """Move the model to `device`/`dtype` and remember both on the instance."""
        self.device, self.dtype = device, dtype
        self.model.to(device=device, dtype=dtype)

    def load_model(self, path: Path) -> RRDBNet:
        """Load the checkpoint at `path` into a freshly built `RRDBNet` in eval mode.

        Raises:
            AssertionError: If the inferred upscale factor is not 4.
        """
        state_dict: dict[str, torch.Tensor] = torch.load(path, weights_only=True, map_location=self.device)  # type: ignore
        in_nc, out_nc, nf, nb, upscale = infer_params(state_dict)
        assert upscale == 4, "Only 4x upscaling is supported"
        network = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
        network.load_state_dict(state_dict)
        network.eval()
        return network

    def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
        """Run the model on the whole of `img` and return the result as an RGB image."""
        pixels = np.array(img)
        # Reverse the channel axis, go to channels-first and scale to [0, 1].
        reversed_channels = pixels[:, :, ::-1]
        channels_first = np.ascontiguousarray(np.transpose(reversed_channels, (2, 0, 1))) / 255
        batch = torch.from_numpy(channels_first).float()  # type: ignore
        batch = batch.unsqueeze(0).to(device=self.device, dtype=self.dtype)
        with torch.no_grad():
            prediction = self.model(batch)
        prediction = prediction.squeeze().float().cpu().clamp_(0, 1).numpy()
        # Back to channels-last 8-bit values, undoing the channel reversal.
        channels_last = (255.0 * np.moveaxis(prediction, 0, 2)).astype(np.uint8)
        return Image.fromarray(channels_last[:, :, ::-1], "RGB")

    # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
    def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
        """Upscale `img` tile by tile and stitch the upscaled tiles back together."""
        grid = split_grid(img.convert("RGB"))
        scale_factor: int = 1
        upscaled_rows: Tiles = []
        for row_y, row_h, row in grid.tiles:
            upscaled_row: list[Tile] = []
            for tile_x, tile_w, tile in row:
                upscaled = self.upscale_without_tiling(tile)
                # Measured from the actual output rather than assumed.
                scale_factor = upscaled.width // tile.width
                upscaled_row.append((tile_x * scale_factor, tile_w * scale_factor, upscaled))
            upscaled_rows.append((row_y * scale_factor, row_h * scale_factor, upscaled_row))

        scaled_dims = [
            value * scale_factor
            for value in (grid.tile_w, grid.tile_h, grid.image_w, grid.image_h, grid.overlap)
        ]
        return combine_grid(Grid(upscaled_rows, *scaled_dims))
@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    """`UpscalerCheckpoints` extended with the location of the ESRGAN weights."""

    esrgan: Path  # handed to `UpscalerESRGAN` as its `model_path`
class ESRGANUpscaler(MultiUpscaler):
    """`MultiUpscaler` that runs a tiled ESRGAN pass before the base pre-upscale step."""

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        esrgan_weights = checkpoints.esrgan
        self.esrgan = UpscalerESRGAN(esrgan_weights, device=self.device, dtype=self.dtype)

    def to(self, device: torch.device, dtype: torch.dtype):
        """Move the ESRGAN model and `self.sd` to `device`/`dtype` and record both."""
        self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device, self.dtype = device, dtype

    def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
        """Upscale `image` with ESRGAN, then delegate with the factor divided by 4."""
        esrgan_output = self.esrgan.upscale_with_tiling(image)
        remaining_factor = upscale_factor / 4
        return super().pre_upscale(image=esrgan_output, upscale_factor=remaining_factor)
# Register pillow_heif's HEIF and AVIF openers.
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()


def _hub_file(repo_id: str, filename: str, revision: str) -> Path:
    """Fetch one file at a pinned revision via `hf_hub_download` and return its path."""
    return Path(hf_hub_download(repo_id=repo_id, filename=filename, revision=revision))


# Both LoRA files come from the same repository at the same revision.
_LORAS_REPO = "philz1337x/loras"
_LORAS_REVISION = "a3802c0280c0d00c2ab18d37454a8744c44e474e"

# Every checkpoint the enhancer needs, each pinned to an exact revision.
CHECKPOINTS = ESRGANUpscalerCheckpoints(
    unet=_hub_file(
        "refiners/juggernaut.reborn.sd1_5.unet",
        "model.safetensors",
        "347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
    ),
    clip_text_encoder=_hub_file(
        "refiners/juggernaut.reborn.sd1_5.text_encoder",
        "model.safetensors",
        "744ad6a5c0437ec02ad826df9f6ede102bb27481",
    ),
    lda=_hub_file(
        "refiners/juggernaut.reborn.sd1_5.autoencoder",
        "model.safetensors",
        "3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
    ),
    controlnet_tile=_hub_file(
        "refiners/controlnet.sd1_5.tile",
        "model.safetensors",
        "48ced6ff8bfa873a8976fa467c3629a240643387",
    ),
    esrgan=_hub_file(
        "philz1337x/upscaler",
        "4x-UltraSharp.pth",
        "011deacac8270114eb7d2eeff4fe6fa9a837be70",
    ),
    negative_embedding=_hub_file(
        "philz1337x/embeddings",
        "JuggernautNegative-neg.pt",
        "203caa7e9cc2bc225031a4021f6ab1ded283454a",
    ),
    negative_embedding_key="string_to_param.*",
    loras={
        "more_details": _hub_file(_LORAS_REPO, "more_details.safetensors", _LORAS_REVISION),
        "sdxl_render": _hub_file(_LORAS_REPO, "SDXLrender_v2.0.safetensors", _LORAS_REVISION),
    },
)
# initialize the enhancer, on the cpu
DEVICE_CPU = torch.device("cpu")
# bfloat16 when CUDA reports support for it, float32 otherwise.
DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
enhancer = ESRGANUpscaler(checkpoints=CHECKPOINTS, device=DEVICE_CPU, dtype=DTYPE)
# Then move it to CUDA when available; `device` and `DEVICE` name the same object.
device = DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
enhancer.to(device=DEVICE, dtype=DTYPE)
# logging
# Suppress all Python warnings, then route WARNING-and-above log records of the
# root logger to stderr using a custom one-record-per-block format.
warnings.filterwarnings("ignore")
root = logging.getLogger()
root.setLevel(logging.WARN)
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.WARN)
formatter = logging.Formatter('\n >>> [%(levelname)s] %(asctime)s %(name)s: %(message)s\n')
handler.setFormatter(formatter)
root.addHandler(handler)
# constant data
base = "black-forest-labs/FLUX.1-schnell"  # model id loaded into `image_pipe`
# NOTE(review): `summarize_text` spells this id out literally instead of using this name.
pegasus_name = "google/pegasus-xsum"
# precision data
seq=512  # passed to the image pipeline as max_sequence_length
width=1024  # generated image width; also used in the CSS aspect ratio and label layout
height=1024  # generated image height; also used in the CSS aspect ratio and label layout
image_steps=8  # passed to the image pipeline as num_inference_steps
img_accu=0  # passed to the image pipeline as guidance_scale
# ui data
css="".join(["""
input, input::placeholder {
text-align: center !important;
}
*, *::placeholder {
font-family: Suez One !important;
}
h1,h2,h3,h4,h5,h6 {
width: 100%;
text-align: center;
}
footer {
display: none !important;
}
#col-container {
margin: 0 auto;
}
.image-container {
aspect-ratio: """,str(width),"/",str(height),""" !important;
}
.dropdown-arrow {
display: none !important;
}
*:has(>.btn) {
display: flex;
justify-content: space-evenly;
align-items: center;
}
.btn {
display: flex;
}
"""])
js="""
function custom(){
document.querySelector("div#prompt input").addEventListener("keydown",function(e){
e.target.setAttribute("last_value",e.target.value);
});
document.querySelector("div#prompt input").addEventListener("input",function(e){
if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
e.target.value = e.target.getAttribute("last_value");
e.target.removeAttribute("last_value");
}
});
document.querySelector("div#prompt2 input").addEventListener("keydown",function(e){
e.target.setAttribute("last_value",e.target.value);
});
document.querySelector("div#prompt2 input").addEventListener("input",function(e){
if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
e.target.value = e.target.getAttribute("last_value");
e.target.removeAttribute("last_value");
}
});
}
"""
# torch pipes
# Load the FLUX pipeline named by `base` in bfloat16 and place it on `device`.
image_pipe = FluxPipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to(device)
# NOTE(review): the pipeline is moved to `device` and CPU offloading is then enabled
# on top of that — confirm both steps are wanted with the installed diffusers version.
image_pipe.enable_model_cpu_offload()
# functionality
@spaces.GPU(duration=180)
def upscaler(
    input_image: Image.Image,
    prompt: str = "masterpiece, best quality, highres",
    negative_prompt: str = "worst quality, low quality, normal quality",
    seed: int = 42,
    upscale_factor: int = 8,
    controlnet_scale: float = 0.6,
    controlnet_decay: float = 1.0,
    condition_scale: int = 6,
    tile_width: int = 112,
    tile_height: int = 144,
    denoise_strength: float = 0.35,
    num_inference_steps: int = 18,
    solver: str = "DDIM",
) -> Image.Image:
    """Upscale `input_image` with the module-level `enhancer`.

    The random state is seeded with `seed` first, `solver` is resolved to the
    attribute of that name on `solvers`, and all remaining arguments are
    forwarded to `enhancer.upscale` together with fixed LoRA scales.

    Returns:
        The image returned by `enhancer.upscale`.

    Raises:
        AttributeError: If `solvers` has no attribute named `solver`.
    """
    manual_seed(seed)
    solver_cls: type[Solver] = getattr(solvers, solver)
    upscale_kwargs = {
        "image": input_image,
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "upscale_factor": upscale_factor,
        "controlnet_scale": controlnet_scale,
        "controlnet_scale_decay": controlnet_decay,
        "condition_scale": condition_scale,
        # Height comes first in the tile size pair.
        "tile_size": (tile_height, tile_width),
        "denoise_strength": denoise_strength,
        "num_inference_steps": num_inference_steps,
        "loras_scale": {"more_details": 0.5, "sdxl_render": 1.0},
        "solver_type": solver_cls,
    }
    return enhancer.upscale(**upscale_kwargs)
@spaces.GPU(duration=180)
def summarize_text(
    text, max_length=30, num_beams=16, early_stopping=True,
    pegasus_tokenizer = PegasusTokenizerFast.from_pretrained("google/pegasus-xsum"),
    pegasus_model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
):
    """Summarize `text` with Pegasus and return the decoded summary string.

    Args:
        text: Text to summarize.
        max_length: Forwarded to `generate` as `max_length`.
        num_beams: Forwarded to `generate` as `num_beams`.
        early_stopping: Forwarded to `generate` as `early_stopping`.
        pegasus_tokenizer: Tokenizer to use. The default is created once, when
            the function is defined, and shared by every call.
        pegasus_model: Model to use. The default is created once, when the
            function is defined, and shared by every call.

    Returns:
        The first generated sequence, decoded without special tokens.
    """
    # Truncate to the tokenizer's maximum length so that long lyrics are cut
    # off instead of being handed to the model over-length.
    encoded = pegasus_tokenizer(text, truncation=True, return_tensors="pt")
    summary_ids = pegasus_model.generate(
        encoded.input_ids,
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=early_stopping
    )
    return pegasus_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
def generate_random_string(length):
    """Return `length` characters picked with `random.choice` from ASCII letters and digits."""
    alphabet = ascii_letters + digits
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
@spaces.GPU(duration=180)
def pipe_generate(p1,p2):
    """Generate one image with `image_pipe` from a positive and a negative prompt.

    Args:
        p1: Positive prompt.
        p2: Negative prompt.

    Returns:
        The first image of the pipeline result.
    """
    # Draw the seed directly as a non-negative integer. Parsing the digits of
    # `str(random.random())` fails whenever the float is rendered in
    # scientific notation (e.g. "5e-05" has no "." to split on, and
    # "5.3e-05" leaves "3e-05", which `int` rejects).
    seed = random.getrandbits(63)
    return image_pipe(
        prompt=p1,
        negative_prompt=p2,
        height=height,
        width=width,
        guidance_scale=img_accu,
        num_images_per_prompt=1,
        num_inference_steps=image_steps,
        max_sequence_length=seq,
        generator=torch.Generator(device).manual_seed(seed)
    ).images[0]
# Runs of spaces, tabs and newlines.
_WHITESPACE_RUN = re.compile("[ \t\n]+")
# Deletes every character of `string.punctuation`, backslash included.
_PUNCTUATION_TABLE = str.maketrans("", "", punctuation)


def _collapse_whitespace(text):
    """Replace each run of spaces, tabs and newlines with one space and trim the ends."""
    return _WHITESPACE_RUN.sub(" ", text).strip()


def _strip_punctuation(text):
    """Remove every `string.punctuation` character from `text`."""
    return text.translate(_PUNCTUATION_TABLE)


def _draw_centered_label(draw, text, font_size, y, fill, stroke_fill, spacing, stroke_width):
    """Draw `text` horizontally centered at `y`, keeping its top edge inside the image."""
    font = ImageFont.truetype(r"Alef-Bold.ttf", font_size)
    x = math.ceil((width - draw.textlength(text, font)) / 2)
    # Clamp so that a label positioned past the bottom edge is still visible.
    y = max(0, min(y, height - font_size - stroke_width))
    draw.text((x, y), text, fill, font=font, spacing=spacing, stroke_width=stroke_width, stroke_fill=stroke_fill)


def handle_generate(artist,song,genre,lyrics):
    """Generate, label, upscale and save a cover image for a song.

    Args:
        artist: Artist name, drawn on the image.
        song: Song name, drawn on the image and included in the prompt.
        genre: Genre, included in the prompt in upper case without punctuation.
        lyrics: Lyrics, summarized and included in the prompt.

    Returns:
        The file name (relative to the working directory) of the saved PNG.
    """
    pos_artist = _collapse_whitespace(artist)
    # Upper-case the first letter of every word, leaving the rest untouched.
    pos_song = ' '.join(word[0].upper() + word[1:] for word in song.split())
    # Punctuation is removed before whitespace is collapsed so that a dropped
    # character cannot leave a double space behind.
    pos_genre = _collapse_whitespace(_strip_punctuation(genre)).upper()
    pos_lyrics = _collapse_whitespace(_strip_punctuation(lyrics)).lower()
    pos_lyrics_sum = summarize_text(pos_lyrics)
    neg = "Textual Labeled Distorted Discontinuous Ugly Blurry Low-Quality Worst-Quality Low-Resolution Painted"
    pos = f'Realistic Vivid Genuine Reasonable Detailed 4K { pos_genre } GENRE { pos_song }: "{ pos_lyrics_sum }"'
    print(f"""
Positive: {pos}
Negative: {neg}
""")
    img = pipe_generate(pos,neg)
    draw = ImageDraw.Draw(img)

    rows = 1
    labels_distance = math.ceil(1 / 3)

    # Song title: white text with a black outline.
    song_size = min(math.ceil(width / 10), math.ceil(height / 5))
    song_y = height - math.ceil(song_size * rows / 2)
    song_y = song_y - math.ceil(song_y / labels_distance)
    _draw_centered_label(draw, pos_song, song_size, song_y, (255,255,255), (0,0,0), spacing=2, stroke_width=4)

    # Artist name: black text with a white outline. The raw offset lands
    # below the canvas, so the helper's clamp places it at the bottom edge.
    artist_size = min(math.ceil(width / 12), math.ceil(height / 6))
    artist_y = height - math.ceil(artist_size * rows / 2)
    artist_y = artist_y + math.ceil(artist_y / labels_distance)
    _draw_centered_label(draw, pos_artist, artist_size, artist_y, (0,0,0), (255,255,255), spacing=6, stroke_width=8)

    enhanced_img = upscaler(img)
    name = generate_random_string(12) + ".png"
    enhanced_img.save(name)
    return name
# entry
# Build and launch the Gradio UI only when the file is run as a script.
# NOTE(review): the nesting of the layout blocks below is reconstructed — confirm
# which textboxes belong inside the Row.
if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Citrus(),css=css) as demo:
        gr.Markdown(f"""
# Song Cover Image Generator
""")
        # Inputs: four single-line textboxes.
        with gr.Column():
            with gr.Row():
                artist = gr.Textbox(
                    placeholder="Artist name",
                    container=False,
                    max_lines=1
                )
                song = gr.Textbox(
                    placeholder="Song name",
                    container=False,
                    max_lines=1
                )
                genre = gr.Textbox(
                    placeholder="Genre",
                    container=False,
                    max_lines=1
                )
            lyrics = gr.Textbox(
                placeholder="Lyrics (English)",
                container=False,
                max_lines=1
            )
        # Output: the generated cover (shown from a file path) and the trigger button.
        with gr.Column():
            cover = gr.Image(interactive=False,container=False,elem_classes="image-container", label="Result", show_label=True, type='filepath', show_share_button=False)
            run = gr.Button("Generate",elem_classes="btn")
        # Clicking the button feeds the four textboxes to `handle_generate`
        # and shows the file it returns in `cover`.
        run.click(
            fn=handle_generate,
            inputs=[artist,song,genre,lyrics],
            outputs=[cover]
        )
    demo.queue().launch()