# Kokoro-API-5 / app.py
# Hugging Face file-viewer header: yaron123's picture · commit ce2ea41 · raw · history blame · 7.16 kB
# built-in
from inspect import signature
import os
import subprocess
import logging
import re
import random
from string import ascii_letters, digits, punctuation
import requests
import sys
import warnings
import time
import asyncio
import math
from functools import partial
# external
import spaces
import torch
import gradio as gr
from lxml.html import fromstring
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file, save_file
from diffusers import FluxPipeline
from PIL import Image, ImageDraw, ImageFont
from transformers import PegasusForConditionalGeneration, PegasusTokenizerFast
# logging: ignore Python warnings and print WARNING+ log records to stderr
warnings.filterwarnings("ignore")

# Assemble the stderr handler first, then attach it to the root logger.
formatter = logging.Formatter('\n >>> [%(levelname)s] %(asctime)s %(name)s: %(message)s\n')
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.WARN)
handler.setFormatter(formatter)

root = logging.getLogger()
root.setLevel(logging.WARN)
root.addHandler(handler)
# constant data
# Run on the GPU whenever torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
base = "black-forest-labs/FLUX.1-schnell"  # checkpoint id loaded into the Flux pipeline
pegasus_name = "google/pegasus-xsum"  # same id that summarize_text spells out as a literal

# precision data
seq = 512  # passed to the image pipeline as max_sequence_length
width = 2160  # output image width in pixels (also used for the CSS aspect ratio)
height = 2160  # output image height in pixels (also used for the CSS aspect ratio)
image_steps = 8  # passed to the image pipeline as num_inference_steps
img_accu = 0  # passed to the image pipeline as guidance_scale
# ui data
# `css` is the stylesheet handed to gr.Blocks(css=...). It is assembled with
# str.join so the `.image-container` aspect ratio can be filled in from the
# `width` and `height` constants.
css="".join(["""
input, input::placeholder {
text-align: center !important;
}
*, *::placeholder {
font-family: Suez One !important;
}
h1,h2,h3,h4,h5,h6 {
width: 100%;
text-align: center;
}
footer {
display: none !important;
}
#col-container {
margin: 0 auto;
}
.image-container {
aspect-ratio: """,str(width),"/",str(height),""" !important;
}
.dropdown-arrow {
display: none !important;
}
*:has(>.btn) {
display: flex;
justify-content: space-evenly;
align-items: center;
}
.btn {
display: flex;
}
"""])
# `js` holds a browser-side input filter for two text inputs (`div#prompt` and
# `div#prompt2`): on keydown it remembers the current value, and on input it
# restores that value whenever the new text contains a character other than
# space, ASCII letters or comma, or two or more consecutive spaces/commas.
# NOTE(review): in the code visible here `js` is never passed to gr.Blocks and
# no component is created with elem_id "prompt" or "prompt2" — confirm whether
# this script is still meant to be wired in.
js="""
function custom(){
document.querySelector("div#prompt input").addEventListener("keydown",function(e){
e.target.setAttribute("last_value",e.target.value);
});
document.querySelector("div#prompt input").addEventListener("input",function(e){
if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
e.target.value = e.target.getAttribute("last_value");
e.target.removeAttribute("last_value");
}
});
document.querySelector("div#prompt2 input").addEventListener("keydown",function(e){
e.target.setAttribute("last_value",e.target.value);
});
document.querySelector("div#prompt2 input").addEventListener("input",function(e){
if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
e.target.value = e.target.getAttribute("last_value");
e.target.removeAttribute("last_value");
}
});
}
"""
# torch pipes
# Build the Flux pipeline in bfloat16, move it to the selected device, then
# switch on the pipeline's model CPU offload.
image_pipe = FluxPipeline.from_pretrained(base, torch_dtype=torch.bfloat16)
image_pipe = image_pipe.to(device)
image_pipe.enable_model_cpu_offload()
# functionality
@spaces.GPU(duration=70)
def summarize_text(
    text, max_length=30, num_beams=16, early_stopping=True,
    pegasus_tokenizer = PegasusTokenizerFast.from_pretrained("google/pegasus-xsum"),
    pegasus_model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
):
    """Summarize *text* with Pegasus and return the decoded summary.

    The tokenizer and model defaults are evaluated once, when this ``def``
    statement executes, so both checkpoints are loaded at import time and are
    shared by every call that does not supply its own.

    Args:
        text: The text to summarize.
        max_length: Forwarded to ``generate`` as ``max_length``.
        num_beams: Forwarded to ``generate`` as ``num_beams``.
        early_stopping: Forwarded to ``generate`` as ``early_stopping``.
        pegasus_tokenizer: Tokenizer used to encode *text* and decode the result.
        pegasus_model: Model whose ``generate`` method produces the summary ids.

    Returns:
        The first generated sequence decoded to a string, special tokens removed.
    """
    # Truncate to the tokenizer's maximum length so that arbitrarily long
    # lyrics are not fed to the model whole.
    encoded = pegasus_tokenizer(text, truncation=True, return_tensors="pt")
    summary_ids = pegasus_model.generate(
        encoded.input_ids,
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=early_stopping,
    )
    return pegasus_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
def generate_random_string(length):
    """Return a string of *length* characters picked at random from ASCII letters and digits."""
    alphabet = ascii_letters + digits
    # One random.choice call per output character, in order.
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)
@spaces.GPU(duration=140)
def pipe_generate(p1,p2):
    """Run the image pipeline once and return the first generated image.

    Args:
        p1: Text passed to the pipeline as ``prompt``.
        p2: Text passed to the pipeline as ``negative_prompt``.

    Returns:
        ``images[0]`` of the pipeline output, generated at the module-level
        ``width`` x ``height`` with ``image_steps`` inference steps and a
        freshly drawn random seed.
    """
    # Take the seed as an integer straight from the RNG. Parsing the digits of
    # str(random.random()) breaks whenever the float prints in scientific
    # notation (values below 1e-4, e.g. "1.2e-05").
    seed = random.randrange(2**63)
    return image_pipe(
        prompt=p1,
        negative_prompt=p2,
        height=height,
        width=width,
        guidance_scale=img_accu,
        num_images_per_prompt=1,
        num_inference_steps=image_steps,
        max_sequence_length=seq,
        generator=torch.Generator(device).manual_seed(seed)
    ).images[0]
def _collapse_whitespace(text):
    """Replace each run of spaces, tabs and newlines with one space and trim both ends."""
    return re.sub("([ \t\n]){1,}", " ", text).strip()


def _strip_punctuation(text):
    """Delete every ``string.punctuation`` character from *text*."""
    # re.escape keeps "]", "\\", "^" and "-" literal inside the character class.
    return re.sub(f"[{re.escape(punctuation)}]", "", text)


def _draw_centered_label(draw, text, width_divisor, height_divisor, direction):
    """Draw *text* in white, horizontally centred, shifted up or down from the middle.

    The font size is the smaller of ``width / width_divisor`` and
    ``height / height_divisor`` (rounded up). ``direction`` is -1 to move the
    label up from the vertically centred position and +1 to move it down.
    """
    rows = 1
    # NOTE(review): math.ceil(1 / 3) evaluates to 1, so the shift equals the
    # whole centred offset: the "up" label lands at y = 0 and the "down" label
    # at twice the centred offset. Confirm the intended spacing.
    labes_distance = math.ceil(1 / 3)
    textheight = min(math.ceil(width / width_divisor), math.ceil(height / height_divisor))
    font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
    textwidth = draw.textlength(text, font)
    x = math.ceil((width - textwidth) / 2)
    y = math.ceil((height - math.ceil(textheight * rows / 2)) / 2)
    y = y + direction * math.ceil(y / labes_distance)
    draw.text((x, y), text, (255, 255, 255), font=font)


def handle_generate(artist,song,genre,lyrics):
    """Generate a cover image for a song and return the saved file's name.

    The inputs are normalised, the lyrics are summarised with
    ``summarize_text``, an image is produced by ``pipe_generate`` from the
    resulting prompt, and the song title and artist name are drawn onto it.

    Args:
        artist: Artist name; whitespace runs are collapsed.
        song: Song title; whitespace runs are collapsed and each word's first
            character is upper-cased.
        genre: Genre; punctuation is removed and the text is upper-cased.
        lyrics: Lyrics; punctuation is removed and the text is lower-cased
            before summarisation.

    Returns:
        The name of the PNG file (12 random characters plus ``.png``) that the
        image was saved to, relative to the current working directory.
    """
    pos_artist = _collapse_whitespace(artist)
    pos_song = ' '.join(word[0].upper() + word[1:] for word in _collapse_whitespace(song).split())
    # Punctuation is removed before collapsing whitespace so that a removed
    # symbol between two spaces does not leave a double space behind.
    pos_genre = _collapse_whitespace(_strip_punctuation(genre)).upper()
    pos_lyrics = _collapse_whitespace(_strip_punctuation(lyrics)).lower()
    pos_lyrics_sum = summarize_text(pos_lyrics)
    neg = "Textual Labeled Distorted Discontinuous Ugly Blurry"
    pos = f'Realistic Natural Genuine Reasonable Detailed { pos_genre } GENRE { pos_song } "{ pos_lyrics_sum }"'
    print(f"""
Positive: {pos}
Negative: {neg}
""")
    img = pipe_generate(pos,neg)
    draw = ImageDraw.Draw(img)
    # Song title in the upper half (larger font), artist name in the lower half.
    _draw_centered_label(draw, pos_song, 10, 5, -1)
    _draw_centered_label(draw, pos_artist, 12, 6, 1)
    name = generate_random_string(12) + ".png"
    img.save(name)
    return name
# entry
if __name__ == "__main__":
    # All four text inputs are borderless single-line boxes that differ only
    # in their placeholder text.
    one_line_box = partial(gr.Textbox, container=False, max_lines=1)

    with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
        gr.Markdown("\n# Song Cover Image Generator\n")

        # First row: artist, song and genre side by side.
        with gr.Row():
            with gr.Column():
                artist = one_line_box(placeholder="Artist name")
            with gr.Column():
                song = one_line_box(placeholder="Song name")
            with gr.Column():
                genre = one_line_box(placeholder="Genre")

        with gr.Row():
            lyrics = one_line_box(placeholder="Lyrics (English)")

        with gr.Row():
            run = gr.Button("Generate", elem_classes="btn")

        with gr.Row():
            cover = gr.Image(
                interactive=False,
                container=False,
                elem_classes="image-container",
                label="Result",
                show_label=True,
                type='filepath',
                show_share_button=False,
            )

        # The button feeds the four text inputs to handle_generate and shows
        # the returned file path in the image component.
        run.click(
            fn=handle_generate,
            inputs=[artist, song, genre, lyrics],
            outputs=[cover],
        )

    demo.queue().launch()
# end