Kokoro-API-1 / app.py
yaron123's picture
commit
f3a00bf
raw
history blame
6.56 kB
# built-in
from inspect import signature
import os
import subprocess
import logging
import re
import random
from string import ascii_letters, digits, punctuation
import requests
import sys
import warnings
import time
import asyncio
from functools import partial
# external
import spaces
import torch
import gradio as gr
from numpy import asarray as array
from lxml.html import fromstring
from diffusers.utils import export_to_video, load_image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file, save_file
from diffusers import FluxPipeline, CogVideoXImageToVideoPipeline
from PIL import Image, ImageDraw, ImageFont
# logging
# Silence all Python warnings, then send WARNING-and-above records from every
# logger to stderr using a bracketed, timestamped format.
warnings.filterwarnings("ignore")
formatter = logging.Formatter('\n >>> [%(levelname)s] %(asctime)s %(name)s: %(message)s\n')
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.WARN)
handler.setFormatter(formatter)
root = logging.getLogger()
root.setLevel(logging.WARN)
root.addHandler(handler)
# constant data
# Prefer the GPU whenever CUDA is visible; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face Hub id of the checkpoint passed to FluxPipeline.from_pretrained.
base = "black-forest-labs/FLUX.1-schnell"
# variable data
# precision data: settings forwarded to the FLUX pipeline call.
seq = 512                    # max_sequence_length for the prompt encoder
width, height = 4320, 4320   # output size in pixels; also drives the CSS aspect ratio
image_steps = 8              # num_inference_steps
img_accu = 0                 # guidance_scale
# ui data
# Page stylesheet. The result image container keeps the module-level
# width:height ratio through the CSS ``aspect-ratio`` property.
css = (
    """
input, input::placeholder {
text-align: center !important;
}
*, *::placeholder {
font-family: Suez One !important;
}
h1,h2,h3,h4,h5,h6 {
width: 100%;
text-align: center;
}
footer {
display: none !important;
}
#col-container {
margin: 0 auto;
}
.image-container {
aspect-ratio: """
    + str(width)
    + "/"
    + str(height)
    + """ !important;
}
.dropdown-arrow {
display: none !important;
}
*:has(>.btn) {
display: flex;
justify-content: space-evenly;
align-items: center;
}
.btn {
display: flex;
}
"""
)
# Client-side script Gradio runs on page load. It limits the targeted text
# inputs to letters, single spaces and single commas: the value is captured on
# keydown and restored when an edit introduces anything else.
# ui() gives no element the ids "prompt"/"prompt2", so each lookup is guarded.
# Calling addEventListener on a null result would throw a TypeError in the
# browser.
js="""
function custom(){
    function restrict(selector){
        const input = document.querySelector(selector);
        if( !input ){
            return;
        }
        input.addEventListener("keydown",function(e){
            e.target.setAttribute("last_value",e.target.value);
        });
        input.addEventListener("input",function(e){
            if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
                e.target.value = e.target.getAttribute("last_value");
                e.target.removeAttribute("last_value");
            }
        });
    }
    restrict("div#prompt input");
    restrict("div#prompt2 input");
}
"""
# torch pipes
# Load the `base` checkpoint in bfloat16.
image_pipe = FluxPipeline.from_pretrained(base, torch_dtype=torch.bfloat16)
if device == "cuda":
    # Model CPU offload starts by moving the pipeline to the CPU. It then puts
    # each sub-model on the GPU only while it runs. Moving the whole pipeline
    # to the GPU first is therefore wasted work, and it needs the full model
    # in VRAM at once.
    image_pipe.enable_model_cpu_offload()
else:
    image_pipe.to(device)
# functionality
def generate_random_string(length):
    """Return ``length`` characters chosen at random from ASCII letters and digits.

    Uses the non-cryptographic ``random`` module.
    """
    alphabet = ascii_letters + digits
    picks = [random.choice(alphabet) for _ in range(length)]
    return "".join(picks)
@spaces.GPU()
def pipe_generate(p1,p2):
    """Render a single image with ``image_pipe``.

    Args:
        p1: positive prompt text.
        p2: negative prompt text.

    Returns:
        The one image produced (``num_images_per_prompt=1``).
    """
    # Draw an integer seed directly. Parsing the digits after the decimal point
    # of str(random.random()) raises ValueError whenever the float is printed
    # in scientific notation (e.g. "1.5e-05").
    seed = random.getrandbits(63)
    return image_pipe(
        prompt=p1,
        # NOTE(review): negative_prompt is only accepted by newer diffusers
        # FluxPipeline releases; confirm against the installed version.
        negative_prompt=p2,
        height=height,
        width=width,
        guidance_scale=img_accu,
        num_images_per_prompt=1,
        num_inference_steps=image_steps,
        max_sequence_length=seq,
        generator=torch.Generator(device).manual_seed(seed),
    ).images[0]
def _collapse_whitespace(text):
    """Return ``text`` with every run of spaces, tabs and newlines replaced by one space."""
    return re.sub(r"[ \t\n]+", " ", text)


def _strip_punctuation(text):
    """Return ``text`` with every character of ``string.punctuation`` removed."""
    # re.escape keeps "]", "\\", "^" and "-" literal inside the character class.
    return re.sub(f"[{re.escape(punctuation)}]", "", text)


def _draw_label(draw, text, textheight, direction):
    """Draw ``text`` in white, horizontally centred, shifted up or down from the centre.

    Args:
        draw: ``ImageDraw.Draw`` bound to the cover image.
        text: label to render.
        textheight: font size passed to ``ImageFont.truetype`` for Alef-Bold.ttf.
        direction: -1 shifts the label upward, +1 shifts it downward.
    """
    rows = 1
    # The offset is one third of the centred y coordinate. The divisor was
    # previously written as ``1 // 3``, which is 0 and made the division raise
    # ZeroDivisionError. NOTE(review): 3 puts the labels near the upper/lower
    # thirds of the image; confirm this is the intended layout.
    labes_distance = 3
    font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
    textwidth = draw.textlength(text, font)
    x = (width - textwidth) // 2
    y = (height - (textheight * rows // 2)) // 2
    y = y + direction * (y // labes_distance)
    draw.text((x, y), text, (255,255,255), font=font)


def handle_generate(artist,song,genre,lyrics):
    """Generate a song cover, label it, and save it as a randomly named PNG.

    Args:
        artist: artist name, drawn below the centre of the image.
        song: song title, title-cased per word, used in the prompt and drawn
            above the centre.
        genre: genre, upper-cased with punctuation removed, used in the prompt.
        lyrics: lyrics, lower-cased with punctuation removed, used in the prompt.

    Returns:
        The file name of the saved PNG (12 random alphanumerics + ".png").
    """
    pos_artist = _collapse_whitespace(artist).strip()
    pos_song = _collapse_whitespace(song).strip()
    # Upper-case the first letter of each word and leave the rest untouched.
    pos_song = ' '.join(word[0].upper() + word[1:] for word in pos_song.split())
    pos_genre = _strip_punctuation(_collapse_whitespace(genre)).upper().strip()
    # Lyrics come from the lyrics field (this was previously read from `genre`).
    pos_lyrics = _strip_punctuation(_collapse_whitespace(lyrics)).lower().strip()
    neg = "Textual Labeled Distorted Discontinuous Ugly Blurry"
    pos = f'Realistic Natural Genuine Reasonable Detailed { pos_genre } GENRE SONG COVER FOR { pos_song }: "{ pos_lyrics }"'
    print(f"""
Positive: {pos}
Negative: {neg}
""")
    img = pipe_generate(pos,neg)
    draw = ImageDraw.Draw(img)
    # Song title: larger font, above centre.
    _draw_label(draw, pos_song, min(width // 10, height // 5), -1)
    # Artist name: smaller font, below centre.
    _draw_label(draw, pos_artist, min(width // 12, height // 6), +1)
    name = generate_random_string(12) + ".png"
    img.save(name)
    return name
def ui():
    """Build the Song Cover Image Generator Blocks app and launch it with queueing enabled."""

    def single_line(placeholder):
        # Borderless one-line textbox labelled only by its placeholder.
        return gr.Textbox(placeholder=placeholder, container=False, max_lines=1)

    with gr.Blocks(theme=gr.themes.Citrus(),css=css,js=js) as demo:
        gr.Markdown("""
# Song Cover Image Generator
""")
        with gr.Row():
            with gr.Column():
                artist = single_line("Artist name")
            with gr.Column():
                song = single_line("Song name")
            with gr.Column():
                genre = single_line("Genre")
        with gr.Row():
            lyrics = single_line("Lyrics (English)")
        with gr.Row():
            cover = gr.Image(
                interactive=False,
                container=False,
                elem_classes="image-container",
                label="Result",
                show_label=True,
                type='filepath',
                show_share_button=False,
            )
        with gr.Row():
            run = gr.Button("Generate",elem_classes="btn")
        # Clicking Generate sends the four text fields to handle_generate,
        # which returns the saved image path shown in `cover`.
        gr.on(
            triggers=[run.click],
            fn=handle_generate,
            inputs=[artist,song,genre,lyrics],
            outputs=[cover],
        )
    demo.queue().launch()
# entry
def _main():
    """Build and launch the web UI."""
    ui()


# Launch only when executed as a script; importing the module starts nothing.
if __name__ == "__main__":
    _main()
# end