Kokoro-API-3 / app.py
yaron123's picture
commit
a8e1c8c
raw
history blame
10.1 kB
# built-in
from inspect import signature
import os
import subprocess
import logging
import re
import random
from string import ascii_letters, digits, punctuation
import requests
import sys
import warnings
import time
import asyncio
from functools import partial
# external
import spaces
import torch
import gradio as gr
from numpy import asarray as array
from lxml.html import fromstring
from diffusers.utils import export_to_video, load_image
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file, save_file
from diffusers import FluxPipeline, CogVideoXImageToVideoPipeline
from PIL import Image, ImageDraw, ImageFont
# logging
# Suppress Python warnings and send WARNING-and-above log records to stderr
# using a custom, blank-line-padded format.
warnings.filterwarnings("ignore")

formatter = logging.Formatter('\n >>> [%(levelname)s] %(asctime)s %(name)s: %(message)s\n')

handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.WARNING)
handler.setFormatter(formatter)

root = logging.getLogger()
root.setLevel(logging.WARNING)
root.addHandler(handler)
# constant data
# Use the GPU when torch reports one; the dtype is bfloat16 on either device.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
base = "black-forest-labs/FLUX.1-schnell"

# variable data
additional_image = None

# precision data
seq = 512          # max_sequence_length passed to the image pipeline
fps = 18           # frames per second for exported videos
width = 768        # generated image width
height = 768       # generated image height
image_steps = 8    # inference steps for the image pipeline
video_steps = 15   # inference steps for the video pipeline
accu = 6.5         # guidance_scale passed to both pipelines
# ui data
# Stylesheet handed to gr.Blocks(css=...). It is assembled with str.join so the
# `.image-container` aspect ratio can be built from the module-level
# `width` / `height` values.
css="".join(["""
input, input::placeholder {
text-align: center !important;
}
*, *::placeholder {
font-family: Suez One !important;
}
h1,h2,h3,h4,h5,h6 {
width: 100%;
text-align: center;
}
footer {
display: none !important;
}
#col-container {
margin: 0 auto;
}
.image-container {
aspect-ratio: """,str(width),"/",str(height),""" !important;
}
.dropdown-arrow {
display: none !important;
}
*:has(>.btn) {
display: flex;
justify-content: space-evenly;
align-items: center;
}
.btn {
display: flex;
}
"""])
# Script handed to gr.Blocks(js=...). For the inputs inside `div#prompt` and
# `div#prompt2` it stores the current value on keydown and, on input, restores
# that stored value when the new text contains a character other than a space,
# an ASCII letter or a comma, or two or more consecutive spaces/commas.
js="""
function custom(){
document.querySelector("div#prompt input").addEventListener("keydown",function(e){
e.target.setAttribute("last_value",e.target.value);
});
document.querySelector("div#prompt input").addEventListener("input",function(e){
if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
e.target.value = e.target.getAttribute("last_value");
e.target.removeAttribute("last_value");
}
});
document.querySelector("div#prompt2 input").addEventListener("keydown",function(e){
e.target.setAttribute("last_value",e.target.value);
});
document.querySelector("div#prompt2 input").addEventListener("input",function(e){
if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
e.target.value = e.target.getAttribute("last_value");
e.target.removeAttribute("last_value");
}
});
}
"""
# torch pipes
# Both pipelines are created at import time, so importing this module
# triggers the model loading below.
# Text-to-image pipeline loaded from `base`, using `dtype`, moved to `device`.
image_pipe = FluxPipeline.from_pretrained(base, torch_dtype=dtype).to(device)
# Image-to-video pipeline loaded from the "THUDM/CogVideoX-5b-I2V" checkpoint
# with the same dtype and device.
video_pipe = CogVideoXImageToVideoPipeline.from_pretrained(
"THUDM/CogVideoX-5b-I2V",
torch_dtype=dtype
).to(device)
# Turn on tiled and sliced processing in the video pipeline's VAE
# (presumably to lower peak memory during decoding — confirm).
video_pipe.vae.enable_tiling()
video_pipe.vae.enable_slicing()
# NOTE(review): the pipeline was already moved with .to(device) above and
# model CPU offload is enabled afterwards — confirm that combining the two is
# intended for the installed diffusers version.
video_pipe.enable_model_cpu_offload()
# functionality
def run(cmd):
    """Run *cmd* through the shell and return its standard output as text.

    Args:
        cmd: Command line to execute.

    Returns:
        The captured stdout decoded as UTF-8 (undecodable bytes are replaced).
        Previously the bytes object was passed to ``str()``, which yields the
        ``"b'...'"`` repr with escaped newlines instead of the actual output.
    """
    # NOTE: shell=True hands the string to the shell; only pass trusted
    # commands, never text built from user input.
    completed = subprocess.run(cmd, shell=True, capture_output=True, env=None)
    return completed.stdout.decode("utf-8", errors="replace")
def xpath_finder(str,pattern):
    """Return the lower-cased, stripped text of the first XPath match.

    Args:
        str: HTML markup to parse (the name shadows the builtin but is kept
            because it is part of the function's signature).
        pattern: XPath expression evaluated against the parsed document.

    Returns:
        The text content of the first matching node, lower-cased and stripped,
        or ``""`` when parsing fails, nothing matches, or the match has no
        text content.
    """
    try:
        matches = fromstring(str).xpath(pattern)
        return matches[0].text_content().lower().strip()
    except Exception:
        # Best-effort lookup: any parsing/matching failure means "not found".
        # Unlike a bare ``except:``, this lets KeyboardInterrupt/SystemExit
        # propagate.
        return ""
def _normalize_phrase(value):
    """Lower-case *value*, remove ASCII punctuation and collapse runs of spaces."""
    # re.escape is required: unescaped, the ``\]`` inside string.punctuation is
    # read as an escaped bracket, so the backslash itself is never removed.
    without_punctuation = re.sub(f'[{re.escape(punctuation)}]', '', value)
    # Collapse spaces after removing punctuation so "a , b" becomes "a b".
    return re.sub('[ ]+', ' ', without_punctuation).lower().strip()


def translate(text,lang):
    """Translate *text* into *lang* by scraping a Google search result page.

    Args:
        text: Phrase to translate; punctuation is removed and it is lower-cased.
        lang: Target language name (e.g. ``"english"``), normalized the same way.

    Returns:
        The normalized translation when the page reports a target language
        equal to *lang*; otherwise the normalized input text. Returns ``""``
        when either argument is ``None`` or empty after normalization.

    Raises:
        ValueError: If the normalized text is longer than 38 characters.
    """
    if text is None or lang is None:
        return ""
    text = _normalize_phrase(text)
    lang = _normalize_phrase(lang)
    if text == "" or lang == "":
        return ""
    if len(text) > 38:
        raise ValueError("Translation Error: Too long text!")
    user_agents = [
        'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36',
        'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36',
        'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36',
        'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/16.1 Safari/605.1.15',
        'Mozilla/5.0 (Macintosh; Intel Mac OS X 13_1) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/16.1 Safari/605.1.15'
    ]
    query_text = f'Please translate {text}, into {lang}'
    url = f'https://www.google.com/search?q={query_text}'
    print(url)
    response = requests.get(
        url=url,
        headers={'User-Agent': random.choice(user_agents)},
        # Without a timeout a stalled connection would block the caller forever.
        timeout=10,
    )
    # Use the decoded body; str(response.content) would parse the bytes repr
    # ("b'...'" with escaped newlines and non-ASCII bytes) instead of the page.
    content = response.text
    translated = text
    trgt_lang = xpath_finder(content, '//*[@class="target-language"]')
    if trgt_lang == lang:
        translated = xpath_finder(content, '//*[@id="tw-target-text"]/*')
    ret = _normalize_phrase(translated)
    print(ret)
    return ret
def generate_random_string(length):
    """Return *length* characters picked at random from ASCII letters and digits."""
    alphabet = ascii_letters + digits
    picked = []
    for _ in range(length):
        picked.append(random.choice(alphabet))
    return ''.join(picked)
def _random_seed():
    """Return a fresh non-negative 63-bit integer seed for a torch generator."""
    # Replaces parsing the digits of str(random.random()), which fails when the
    # float is printed in scientific notation (e.g. "1.2e-05").
    return random.getrandbits(63)


@spaces.GPU(duration=80)
def pipe_generate(img,p1,p2,time,title):
    """Generate an image, optionally stamp a title on it, optionally animate it.

    Args:
        img: Source PIL image, or ``None`` to generate one with ``image_pipe``.
        p1: Positive prompt.
        p2: Negative prompt.
        time: Video duration in seconds; ``0.0`` means "return the image only".
            (The name shadows the ``time`` module inside this function; it is
            kept because it is part of the signature.)
        title: Text drawn on the image; empty/``None`` draws nothing.

    Returns:
        The PIL image when ``time == 0.0``; otherwise the first entry of the
        video pipeline's ``frames`` output.

    Side effects:
        Sets the module-level ``additional_image`` to the generated image, or
        to ``None`` when the image was supplied by the caller.
    """
    # Without this declaration the assignment below created a function-local
    # name, which raised UnboundLocalError whenever an image was uploaded.
    # NOTE(review): if the GPU decorator runs this function in another process,
    # the module-level write will not be visible to the caller — confirm.
    global additional_image

    generated = img is None
    if generated:
        img = image_pipe(
            prompt=p1,
            negative_prompt=p2,
            height=height,
            width=width,
            guidance_scale=accu,
            num_images_per_prompt=1,
            num_inference_steps=image_steps,
            max_sequence_length=seq,
            generator=torch.Generator(device).manual_seed(_random_seed()),
        ).images[0]

    if title:
        draw = ImageDraw.Draw(img)
        textheight = 72
        rows = 1
        font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
        textwidth = draw.textlength(title, font)
        # Centre against the actual image: an uploaded photo need not match
        # the module-level width/height used for generated images.
        img_width, img_height = img.size
        x = (img_width - textwidth) // 2
        y = (img_height - (textheight * rows // 2)) // 2
        draw.text((x, y), title, (255, 255, 255), font=font)

    additional_image = img if generated else None

    if time == 0.0:
        return img

    return video_pipe(
        prompt=p1,
        # Drop the "textual content" entry whether or not a space follows the
        # comma; the caller in this file joins its keywords with bare commas,
        # so the former replace("textual content, ", "") never matched.
        negative_prompt=re.sub(r"textual content,\s*", "", p2),
        image=img,
        num_inference_steps=video_steps,
        guidance_scale=accu,
        num_videos_per_prompt=1,
        # The duration may arrive as a float; a frame count must be an integer.
        num_frames=int(fps * time),
        generator=torch.Generator(device).manual_seed(_random_seed()),
    ).frames[0]
def _with_user_keywords(base_keywords, extra):
    """Return *base_keywords* with the user's keywords appended after a comma."""
    # Trim commas and spaces together so input such as " ,cat, " yields "cat".
    extra = (extra or "").strip(", ")
    if not extra:
        return base_keywords
    return f"{base_keywords},{extra}"


def handle_generate(*_inp):
    """Build the final prompts, run generation and save the result to a file.

    Args:
        *_inp: The UI inputs in order: image, included keywords, excluded
            keywords, duration in seconds, title.

    Returns:
        ``(png_path, None)`` when the duration is ``0.0``; otherwise
        ``(additional_image, mp4_path)``, where ``additional_image`` is the
        module-level value after generation.
    """
    # Reset the module-level value; a plain local assignment here shadowed it,
    # so the returned preview could never be anything but None.
    global additional_image
    additional_image = None

    inp = list(_inp)
    # Translation of the user keywords is currently disabled:
    #inp[1] = translate(inp[1],"english")
    #inp[2] = translate(inp[2],"english")
    inp[2] = _with_user_keywords(
        "textual content,unrealistic content,divined creatures,unrealistic creatures,creatures out of this world,demon,angel,cgi quality,anime quality,cartoon quality,drawing quality,cropped photo,cropped content,worst quality,low quality,duplicating elements,weird,non-standard human body,non-standard object structure,blur,wrong body anatomy,too big, too small,text,written content",
        inp[2],
    )
    inp[1] = _with_user_keywords(
        'looks real,feels real,similar to real photographs,dark vivid colors,looks beautiful and pretty,look genuine and authentic,reasonable logic,natural,masterpiece,highly detailed',
        inp[1],
    )
    print(f"\nPositive: {inp[1]}\nNegative: {inp[2]}\n")

    pipe_out = pipe_generate(*inp)

    as_image = inp[3] == 0.0
    name = generate_random_string(12) + (".png" if as_image else ".mp4")
    if as_image:
        pipe_out.save(name)
        return name, None
    export_to_video(pipe_out, name, fps=fps)
    return additional_image, name
def ui():
    """Build the Gradio Blocks interface, wire up generation and launch it.

    Layout, top to bottom: heading, title and included-keywords boxes,
    excluded-keywords box, duration slider, upload/result columns, start
    button. Clicking the button or submitting either keyword box calls
    ``handle_generate``.
    """
    # Settings shared by the three single-line text boxes.
    single_line = {"container": False, "max_lines": 1}
    # Settings shared by the read-only image and video result components.
    result_view = {
        "interactive": False,
        "container": False,
        "elem_classes": "image-container",
        "label": "Result",
        "show_label": True,
        "show_share_button": False,
    }

    with gr.Blocks(theme=gr.themes.Citrus(), css=css, js=js) as demo:
        gr.Markdown("\n# Photo Motion - PNG/MP4 Generator\n")

        with gr.Row():
            title_box = gr.Textbox(placeholder="Logo title", **single_line)
            include_box = gr.Textbox(
                elem_id="prompt",
                placeholder="Included keywords",
                **single_line,
            )

        with gr.Row():
            exclude_box = gr.Textbox(
                elem_id="prompt2",
                placeholder="Excluded keywords",
                **single_line,
            )

        with gr.Row():
            duration = gr.Slider(
                minimum=0.0,
                maximum=3.0,
                value=0.0,
                step=1.0,
                label="Duration (0s = PNG)",
            )

        with gr.Row(elem_id="col-container"):
            with gr.Column():
                photo = gr.Image(
                    label="Upload photo",
                    show_label=True,
                    container=False,
                    type="pil",
                )
            with gr.Column():
                result_image = gr.Image(type='filepath', **result_view)
            with gr.Column():
                result_video = gr.Video(**result_view)

        with gr.Row():
            start_button = gr.Button("Start!", elem_classes="btn", scale=0)

        # Input order must match handle_generate's positional arguments.
        gr.on(
            triggers=[start_button.click, include_box.submit, exclude_box.submit],
            fn=handle_generate,
            inputs=[photo, include_box, exclude_box, duration, title_box],
            outputs=[result_image, result_video],
        )

    demo.queue().launch()
# entry
if __name__ == "__main__":
    # Switch to the directory containing this script before starting the UI,
    # so relative paths are resolved next to the script.
    script_dir = os.path.dirname(__file__)
    os.chdir(os.path.abspath(script_dir))
    ui()
# end