# Kokoro-API-4 / app.py
# Author: Yaron Koresh - commit 209a22d ("Update app.py", verified), 6.25 kB.
import gradio as gr
import os
import re
#from tempfile import NamedTemporaryFile
import numpy as np
import spaces
import random
import string
from diffusers import StableDiffusion3Pipeline
import torch
from pathos.multiprocessing import ProcessingPool as ProcessPoolExecutor
import requests
from lxml.html import fromstring
# A four-worker pathos process pool. It is entered by hand and never exited
# anywhere in this file.
# NOTE(review): `pool` is not referenced elsewhere in this file - confirm it
# is still needed before relying on it (it keeps worker processes alive).
pool = ProcessPoolExecutor(4)
pool.__enter__()

# Stable Diffusion 3 (medium) weights in diffusers format.
# An earlier revision used "runwayml/stable-diffusion-v1-5".
model_id = "stabilityai/stable-diffusion-3-medium-diffusers"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loader options shared by the GPU and CPU paths; the Hub access token is
# read from the `hf_token` environment variable.
_pipe_load_kwargs = dict(use_safetensors=True, token=os.getenv('hf_token'))
if torch.cuda.is_available():
    # The return value of this query is not used.
    torch.cuda.max_memory_allocated(device=device)
    # On CUDA, load the fp16 weight variant in half precision.
    _pipe_load_kwargs.update(torch_dtype=torch.float16, variant="fp16")
pipe = StableDiffusion3Pipeline.from_pretrained(model_id, **_pipe_load_kwargs)
pipe = pipe.to(device)
def translate(text,lang):
text = re.sub(f'[{string.punctuation}]', '', re.sub('[\s+]', ' ', text)).lower().strip()
lang = re.sub(f'[{string.punctuation}]', '', re.sub('[\s+]', ' ', lang)).lower().strip()
user_agents = [
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36',
'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/16.1 Safari/605.1.15',
'Mozilla/5.0 (Macintosh; Intel Mac OS X 13_1) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/16.1 Safari/605.1.15'
]
url = f'https://www.google.com/search?q=translate to {lang}: {text}'
print(url)
resp = requests.get(
url = url,
headers = {
'User-Agent': random.choice(user_agents)
}
)
print(resp)
content = resp.content
html = fromstring(content)
rslt = html.xpath('//pre[@aria-label="Translated text"]/span')
translated = text
try:
t = rslt[0].text.strip()
translated = t
except:
print(f'"{text}" is already in {lang}!')
ret = re.sub(f'[{string.punctuation}]', '', re.sub('[\s+]', ' ', translated)).lower().strip()
print(ret)
return ret
def generate_random_string(length):
    """Return *length* characters drawn uniformly from ASCII letters and digits.

    Uses the non-cryptographic ``random`` module. The result is intended for
    throwaway file names, not for secrets.
    """
    alphabet = string.ascii_letters + string.digits
    picks = [random.choice(alphabet) for _ in range(length)]
    return "".join(picks)
@spaces.GPU
def Piper(
    _do,
    *,
    negative_prompt="",
    height=480,
    width=480,
    num_inference_steps=100,
    guidance_scale=10,
):
    """Run the module-level ``pipe`` on a single text prompt.

    ``spaces.GPU`` presumably requests GPU access when running on Hugging
    Face Spaces (ZeroGPU) - confirm against the deployment.

    Args:
        _do: The positive text prompt.
        negative_prompt: Text the pipeline should steer away from (empty by default).
        height: Output image height in pixels.
        width: Output image width in pixels.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance strength.

    Returns:
        Whatever ``pipe(...)`` returns; callers read ``.images[0]`` from it.
    """
    # Keyword-only options keep the original one-argument call working while
    # letting callers override the previously hard-coded settings.
    return pipe(
        _do,
        height=height,
        width=width,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
    )
def infer(prompt1, prompt2, prompt3, prompt4):
    """Compose an English prompt from the dropdown selections, render it and save the image.

    Each argument is the value of one multiselect ``gr.Dropdown``: a list of
    strings. ``None`` or an empty list means nothing was selected, and a
    default phrase is used instead. Every selected value is translated to
    English with ``translate`` before it is placed in the prompt.

    Args:
        prompt1: Foreground element(s) to focus on.
        prompt2: Background element(s) to include.
        prompt3: Background event(s) to show.
        prompt4: Forbidden elements/events to remove.

    Returns:
        File name of the PNG written to the current working directory.
    """
    name = generate_random_string(12) + ".png"

    def _english(values):
        # One translate() request per selected value.
        return [translate(v, "english") for v in values]

    # Use a falsy check, not `is None`: an empty selection must fall back to
    # the default instead of producing fragments like "Focus on ;".
    focus = ", and ".join(v.upper() for v in _english(prompt1)) if prompt1 else "element"
    include = ", and ".join(v.upper() for v in _english(prompt2)) if prompt2 else "elements"
    event = (" ".join(v.upper() for v in _english(prompt3)) + " event") if prompt3 else "event"

    # Forbidden items go into the positive prompt's "Remove ..." clause;
    # no negative prompt is passed to Piper.
    remove = "text and logos"
    if prompt4:
        remove += " and " + " and ".join(_english(prompt4))

    _do = f'Show {event}; Include {include}; Focus on {focus}; Remove {remove}.'
    print(_do)
    Piper(_do).images[0].save(name)
    return name
# Page stylesheet passed to gr.Blocks below. It centres the "col-container"
# column and caps its width at 13cm, and keeps the "image-container" result
# image square. It also hides elements with the `dropdown-arrow` class
# (presumably the dropdown chevron icons - confirm against the Gradio version in use).
css="""
#col-container {
margin: 0 auto;
max-width: 13cm;
}
#image-container {
aspect-ratio: 1 / 1;
}
.dropdown-arrow {
display: none !important;
}
"""
# Label for the compute backend, shown in the page header.
power_device = "GPU" if torch.cuda.is_available() else "CPU"
def _selection_box(label, max_choices):
    """Build a labelled multi-select dropdown that also accepts free-typed values."""
    return gr.Dropdown(
        multiselect=True,
        allow_custom_value=True,
        max_choices=max_choices,
        label=label,
        show_label=True,
        container=True,
    )


# (label, max_choices) for the four prompt inputs, in the order infer() takes them.
_PROMPT_FIELDS = (
    ("Foreground Elements", 1),
    ("Background Elements", 3),
    ("Background Events", 1),
    ("Forbidden Elements/Events", 4),
)

with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
# Image Generator
Currently running on {power_device}.
""")
        with gr.Column():
            # Each dropdown sits in its own row.
            _prompt_boxes = []
            for _label, _limit in _PROMPT_FIELDS:
                with gr.Row():
                    _prompt_boxes.append(_selection_box(_label, _limit))
            prompt1, prompt2, prompt3, prompt4 = _prompt_boxes
        with gr.Row():
            run_button = gr.Button("Run")
        result = gr.Image(elem_id="image-container", label="Result", show_label=False, type='filepath')
    # infer returns a file path, which the filepath-typed Image displays.
    run_button.click(
        fn=infer,
        inputs=[prompt1, prompt2, prompt3, prompt4],
        outputs=[result],
    )

demo.queue().launch()