Spaces:
Sleeping
Sleeping
""" | |
Modified parts included from these sources: | |
- https://github.com/nidhaloff/deep-translator | |
- https://huggingface.co/spaces/ostris/Flex.1-alpha | |
""" | |
import urllib | |
import requests | |
from bs4 import BeautifulSoup | |
from abc import ABC, abstractmethod | |
from pathlib import Path | |
from langdetect import detect as get_language | |
from typing import Any, Dict, List, Optional, Union | |
from collections import namedtuple | |
from inspect import signature | |
import os | |
import subprocess | |
import logging | |
import re | |
import random | |
from string import ascii_letters, digits, punctuation | |
import requests | |
import sys | |
import warnings | |
import time | |
import asyncio | |
import math | |
from pathlib import Path | |
from functools import partial | |
from dataclasses import dataclass | |
from typing import Any | |
import pillow_heif | |
import spaces | |
import numpy as np | |
import numpy.typing as npt | |
import torch | |
from torch import nn | |
import gradio as gr | |
from lxml.html import fromstring | |
from huggingface_hub import hf_hub_download | |
from safetensors.torch import load_file, save_file | |
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, FluxPipeline, FlowMatchEulerDiscreteScheduler | |
from PIL import Image, ImageDraw, ImageFont | |
from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel | |
from refiners.fluxion.utils import manual_seed | |
from refiners.foundationals.latent_diffusion import Solver, solvers | |
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import ( | |
MultiUpscaler, | |
UpscalerCheckpoints, | |
) | |
from datetime import datetime | |
# Image dimensions read by `add_song_cover_text` and rescaled by `upscaler`.
# They start as None; whatever assigns real values is not visible in this
# part of the file.
_HEIGHT_ = None
_WIDTH_ = None
# Busy flag consulted by `upscaler` and `add_song_cover_text`.
working = False
# T5 checkpoint and tokenizer used by `_summarize` (loaded at import time).
model = T5ForConditionalGeneration.from_pretrained("t5-base")
tokenizer = T5Tokenizer.from_pretrained("t5-base")
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """Linearly interpolate the shift value for a given image sequence length.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``; values outside that range are extrapolated.

    Args:
        image_seq_len: Sequence length to evaluate the line at.
        base_seq_len: Sequence length at which the shift equals ``base_shift``.
        max_seq_len: Sequence length at which the shift equals ``max_shift``.
        base_shift: Shift at ``base_seq_len``.
        max_shift: Shift at ``max_seq_len``.

    Returns:
        The interpolated shift ``mu``.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure *scheduler* and return its timestep schedule.

    Exactly one of three modes is used, in this priority order: custom
    ``timesteps``, custom ``sigmas``, or a plain ``num_inference_steps`` count.

    Args:
        scheduler: Object exposing ``set_timesteps(...)`` and a ``timesteps``
            attribute that is populated by that call.
        num_inference_steps: Step count, used only when neither ``timesteps``
            nor ``sigmas`` is given.
        device: Forwarded to ``scheduler.set_timesteps``.
        timesteps: Custom timesteps forwarded to the scheduler.
        sigmas: Custom sigmas forwarded to the scheduler.
        **kwargs: Extra keyword arguments forwarded to ``set_timesteps``.

    Returns:
        ``(timesteps, num_inference_steps)``. With a custom schedule the count
        is ``len(scheduler.timesteps)``; otherwise it is the value passed in.

    Raises:
        ValueError: If both ``timesteps`` and ``sigmas`` are provided.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        # Plain step count: the caller's number is returned unchanged.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedule: the effective step count is whatever the scheduler built.
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)
# FLUX pipeline function
def flux_pipe_call_that_returns_an_iterable_of_images(
    self,
    prompt: Union[str, List[str]] = None,
    prompt_2: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 28,
    timesteps: List[int] = None,
    guidance_scale: float = 3.5,
    num_images_per_prompt: Optional[int] = 1,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    max_sequence_length: int = 512,
    good_vae: Optional[Any] = None,
):
    """Generator version of a Flux pipeline call that yields preview images.

    Meant to be bound to a pipeline instance (``self``). After every denoising
    step the current latents are decoded with ``self.vae`` and yielded; once
    the loop ends, the final latents are decoded with ``good_vae`` and yielded
    as the last item.

    Args:
        prompt, prompt_2: Text prompts forwarded to ``self.encode_prompt``.
        height, width: Output size; default to
            ``self.default_sample_size * self.vae_scale_factor``.
        num_inference_steps: Number of sigmas generated for the scheduler.
        timesteps: Forwarded to ``retrieve_timesteps``. Sigmas are always
            supplied as well, so a non-None value makes that helper raise
            ``ValueError``.
        guidance_scale: Value fed to the transformer when its config enables
            ``guidance_embeds``.
        num_images_per_prompt: Multiplier applied to the batch size.
        generator, latents: Forwarded to ``self.prepare_latents``.
        prompt_embeds, pooled_prompt_embeds: Optional precomputed embeddings.
        output_type: Forwarded to ``self.image_processor.postprocess``.
        return_dict: Accepted for signature compatibility; not used here.
        joint_attention_kwargs: Stored on the pipeline; its ``"scale"`` entry,
            if present, is used as the LoRA scale.
        max_sequence_length: Forwarded to input checking and prompt encoding.
        good_vae: VAE used for the final decode. Falls back to ``self.vae``
            when omitted.

    Yields:
        One post-processed image per executed denoising step, then the final
        image.
    """
    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 1. Check inputs
    self.check_inputs(
        prompt,
        prompt_2,
        height,
        width,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        max_sequence_length=max_sequence_length,
    )
    self._guidance_scale = guidance_scale
    self._joint_attention_kwargs = joint_attention_kwargs
    self._interrupt = False

    # 2. Define call parameters
    if isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None:
        batch_size = len(prompt)
    else:
        # No text prompt: the batch size comes from the precomputed embeddings.
        batch_size = prompt_embeds.shape[0]
    device = self._execution_device

    # 3. Encode prompt
    lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
    prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
        prompt_2=prompt_2,
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        lora_scale=lora_scale,
    )

    # 4. Prepare latent variables
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_image_ids = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 5. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    image_seq_len = latents.shape[1]
    mu = calculate_shift(
        image_seq_len,
        self.scheduler.config.base_image_seq_len,
        self.scheduler.config.max_image_seq_len,
        self.scheduler.config.base_shift,
        self.scheduler.config.max_shift,
    )
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        timesteps,
        sigmas,
        mu=mu,
    )
    self._num_timesteps = len(timesteps)

    # Handle guidance
    guidance = (
        torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0])
        if self.transformer.config.guidance_embeds
        else None
    )

    # 6. Denoising loop
    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue

        timestep = t.expand(latents.shape[0]).to(latents.dtype)
        noise_pred = self.transformer(
            hidden_states=latents,
            timestep=timestep / 1000,
            guidance=guidance,
            pooled_projections=pooled_prompt_embeds,
            encoder_hidden_states=prompt_embeds,
            txt_ids=text_ids,
            img_ids=latent_image_ids,
            joint_attention_kwargs=self.joint_attention_kwargs,
            return_dict=False,
        )[0]

        # Yield intermediate result (decoded from the latents *before* this
        # step's scheduler update is applied).
        latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
        latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        image = self.vae.decode(latents_for_image, return_dict=False)[0]
        yield self.image_processor.postprocess(image, output_type=output_type)[0]

        latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        torch.cuda.empty_cache()

    # Final image using good_vae (or the pipeline's own VAE if none was given)
    final_vae = good_vae if good_vae is not None else self.vae
    latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
    latents = (latents / final_vae.config.scaling_factor) + final_vae.config.shift_factor
    image = final_vae.decode(latents, return_dict=False)[0]
    self.maybe_free_model_hooks()
    torch.cuda.empty_cache()
    yield self.image_processor.postprocess(image, output_type=output_type)[0]
def log(msg):
    """Print *msg* to stdout, prefixed with the current local time of day."""
    print(f'{datetime.now().time()} {msg}')
# One tile as stored by `split_grid`: (x offset, tile width, tile image).
Tile = tuple[int, int, Image.Image]
# One entry per tile row: (y offset, row height, tiles of that row).
Tiles = list[tuple[int, int, list[Tile]]]
def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
    """Build a 3x3 convolution (padding 1) followed by LeakyReLU(0.2).

    Args:
        in_nc: Number of input channels.
        out_nc: Number of output channels.

    Returns:
        A two-layer ``nn.Sequential``; spatial size is preserved.
    """
    return nn.Sequential(
        nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
    )
class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)

    Five 3x3 convolutions are densely connected: each one receives the block
    input concatenated (along dim 1) with the outputs of all previous ones.
    The last convolution maps back to ``nf`` channels and its output is added
    to the input with a 0.2 scale.

    Args:
        nf: Channel count of the block input and output.
        gc: Channel count produced by each of the first four convolutions.
    """

    def __init__(self, nf: int = 64, gc: int = 32) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.conv1 = conv_block(nf, gc)
        self.conv2 = conv_block(nf + gc, gc)
        self.conv3 = conv_block(nf + 2 * gc, gc)
        self.conv4 = conv_block(nf + 3 * gc, gc)
        # Wrapped in Sequential because of key in state dict.
        self.conv5 = nn.Sequential(nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the dense block and return ``0.2 * conv5(...) + x``."""
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x
class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)

    Chains three ``ResidualDenseBlock_5C`` modules and adds their output,
    scaled by 0.2, back onto the input.

    Args:
        nf: Channel count passed to each inner dense block.
    """

    def __init__(self, nf: int) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.RDB1 = ResidualDenseBlock_5C(nf)
        self.RDB2 = ResidualDenseBlock_5C(nf)
        self.RDB3 = ResidualDenseBlock_5C(nf)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``0.2 * RDB3(RDB2(RDB1(x))) + x``."""
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x
class Upsample2x(nn.Module):
    """Upsample 2x."""

    def __init__(self) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Double the spatial size of *x* using ``interpolate``'s default mode."""
        return nn.functional.interpolate(x, scale_factor=2.0)  # type: ignore
class ShortcutBlock(nn.Module):
    """Elementwise sum the output of a submodule to its input"""

    def __init__(self, submodule: nn.Module) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        # Stored as `sub`, so its parameters appear under the "sub." prefix.
        self.sub = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + self.sub(x)``."""
        return x + self.sub(x)
class RRDBNet(nn.Module):
    """RRDB-based 4x super-resolution network.

    Layout: input conv, a shortcut-wrapped trunk of ``nb`` RRDB blocks plus a
    conv, two (2x upsample, conv, LeakyReLU) stages, one more conv +
    LeakyReLU, and the output conv.

    Args:
        in_nc: Input channel count (must not be a multiple of 4).
        out_nc: Output channel count.
        nf: Feature channel count used throughout the network.
        nb: Number of RRDB blocks in the trunk.
    """

    def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        assert in_nc % 4 != 0  # in_nc is 3
        self.model = nn.Sequential(
            nn.Conv2d(in_nc, nf, kernel_size=3, padding=1),
            ShortcutBlock(
                nn.Sequential(
                    *(RRDB(nf) for _ in range(nb)),
                    nn.Conv2d(nf, nf, kernel_size=3, padding=1),
                )
            ),
            Upsample2x(),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            Upsample2x(),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(nf, out_nc, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run *x* through the sequential model."""
        return self.model(x)
def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
    """Derive the RRDBNet hyper-parameters from a checkpoint's key names.

    Args:
        state_dict: Checkpoint mapping with keys such as ``model.0.weight``
            and ``model.1.sub.<i>...``.

    Returns:
        ``(in_nc, out_nc, nf, nb, scale)``.

    Raises:
        AssertionError: If a key contains ``conv1x1`` (ESRGANPlus layout), or
            if no output layer / trunk block index could be found.
    """
    # this code is adapted from https://github.com/victorca25/iNNfer
    scale2x = 0
    scalemin = 6
    n_uplayer = 0
    out_nc = 0
    nb = 0

    for key, tensor in state_dict.items():
        parts = key.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            # "model.<k>.sub.<i>.<param>": <i> of the last such key is taken as nb.
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            # Each conv weight located after index `scalemin` counts as one 2x stage.
            if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
                scale2x += 1
            # The highest-indexed layer seen so far determines the output channels.
            if part_num > n_uplayer:
                n_uplayer = part_num
                out_nc = tensor.shape[0]
        assert "conv1x1" not in key  # no ESRGANPlus

    nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]
    scale = 2**scale2x

    assert out_nc > 0
    assert nb > 0

    return in_nc, out_nc, nf, nb, scale  # 3, 3, 64, 23, 4
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
# Tiled view of an image: the tile rows plus the tile size, full image size
# and overlap (all in pixels) needed to stitch them back together.
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
# adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
    """Cut *image* into a grid of overlapping tiles.

    Tiles are spread evenly so that the first and last tile of each axis touch
    the image borders; an image smaller than one tile yields a single
    (smaller) tile on that axis.

    Args:
        image: Image to split.
        tile_w: Nominal tile width in pixels.
        tile_h: Nominal tile height in pixels.
        overlap: Minimum overlap between neighbouring tiles in pixels.

    Returns:
        A ``Grid`` whose ``tiles`` is a list of ``(y, tile_h, row)`` entries,
        each ``row`` being a list of ``(x, tile_w, tile_image)``.
    """
    w = image.width
    h = image.height

    non_overlap_width = tile_w - overlap
    non_overlap_height = tile_h - overlap

    cols = max(1, math.ceil((w - overlap) / non_overlap_width))
    rows = max(1, math.ceil((h - overlap) / non_overlap_height))

    # Distance between the origins of consecutive tiles on each axis.
    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        row_images: list[Tile] = []
        # Clamp so the tile never starts before 0 or extends past the image.
        y1 = max(min(int(row * dy), h - tile_h), 0)
        y2 = min(y1 + tile_h, h)
        for col in range(cols):
            x1 = max(min(int(col * dx), w - tile_w), 0)
            x2 = min(x1 + tile_w, w)
            tile = image.crop((x1, y1, x2, y2))
            row_images.append((x1, tile_w, tile))
        grid.tiles.append((y1, tile_h, row_images))

    return grid
# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
def combine_grid(grid: Grid):
    """Stitch the tiles of *grid* back into one RGB image.

    Tiles of a row are first merged left to right, then rows are merged top
    to bottom. In each overlap band the newer tile/row is pasted through a
    linear 0..255 ramp mask so it fades in over what is already there.

    Args:
        grid: Grid as produced by ``split_grid`` (possibly rescaled).

    Returns:
        The combined ``Image`` of size ``(grid.image_w, grid.image_h)``.
    """

    def make_mask_image(r: npt.NDArray[np.float32]) -> Image.Image:
        # Scale ramp values 0..overlap-1 into the 8-bit mask range.
        r = r * 255 / grid.overlap
        return Image.fromarray(r.astype(np.uint8), "L")

    # Horizontal ramp (left -> right) used when joining tiles inside a row.
    mask_w = make_mask_image(
        np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)
    )
    # Vertical ramp (top -> bottom) used when joining rows.
    mask_h = make_mask_image(
        np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)
    )

    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        combined_row = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                combined_row.paste(tile, (0, 0))
                continue

            # Blend the overlap strip, then paste the remainder opaquely.
            combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
            combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))

        if y == 0:
            combined_image.paste(combined_row, (0, 0))
            continue

        combined_image.paste(
            combined_row.crop((0, 0, combined_row.width, grid.overlap)),
            (0, y),
            mask=mask_h,
        )
        combined_image.paste(
            combined_row.crop((0, grid.overlap, combined_row.width, h)),
            (0, y + grid.overlap),
        )

    return combined_image
class UpscalerESRGAN:
    """4x ESRGAN upscaler operating on PIL images.

    Args:
        model_path: Path to the checkpoint loaded with ``torch.load``.
        device: Device the weights are mapped to and inference runs on.
        dtype: Dtype the model and its inputs are cast to.
    """

    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        self.model_path = model_path
        # Set before load_model(), which uses it as map_location.
        self.device = device
        self.model = self.load_model(model_path)
        self.to(device, dtype)

    def __call__(self, img: Image.Image) -> Image.Image:
        """Upscale *img* in a single pass (no tiling)."""
        return self.upscale_without_tiling(img)

    def to(self, device: torch.device, dtype: torch.dtype):
        """Move the model to *device*/*dtype* and remember both for inputs."""
        self.device = device
        self.dtype = dtype
        self.model.to(device=device, dtype=dtype)

    def load_model(self, path: Path) -> RRDBNet:
        """Load the checkpoint at *path* into an ``RRDBNet`` in eval mode.

        Raises:
            AssertionError: If the checkpoint is not a 4x model.
        """
        filename = path
        state_dict: dict[str, torch.Tensor] = torch.load(filename, weights_only=True, map_location=self.device)  # type: ignore
        in_nc, out_nc, nf, nb, upscale = infer_params(state_dict)
        assert upscale == 4, "Only 4x upscaling is supported"
        model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
        model.load_state_dict(state_dict)
        model.eval()

        return model

    def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
        """Run the network once over the whole image and return the result."""
        img_np = np.array(img)
        # Reverse the channel axis (presumably RGB -> BGR for the checkpoint;
        # the inverse flip is applied to the output below).
        img_np = img_np[:, :, ::-1]
        # HWC uint8 -> CHW float in [0, 1].
        img_np = np.ascontiguousarray(np.transpose(img_np, (2, 0, 1))) / 255
        img_t = torch.from_numpy(img_np).float()  # type: ignore
        img_t = img_t.unsqueeze(0).to(device=self.device, dtype=self.dtype)
        with torch.no_grad():
            output = self.model(img_t)
        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = 255.0 * np.moveaxis(output, 0, 2)
        output = output.astype(np.uint8)
        output = output[:, :, ::-1]
        return Image.fromarray(output, "RGB")

    # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
    def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
        """Upscale *img* tile by tile and stitch the results together.

        The image is converted to RGB, split with ``split_grid``, each tile is
        upscaled independently, and all grid coordinates are multiplied by the
        observed scale factor before ``combine_grid`` merges them.
        """
        img = img.convert("RGB")
        grid = split_grid(img)
        newtiles: Tiles = []
        scale_factor: int = 1

        for y, h, row in grid.tiles:
            newrow: list[Tile] = []
            for tiledata in row:
                x, w, tile = tiledata
                output = self.upscale_without_tiling(tile)
                # Measured from the actual output rather than assumed.
                scale_factor = output.width // tile.width
                newrow.append((x * scale_factor, w * scale_factor, output))
            newtiles.append((y * scale_factor, h * scale_factor, newrow))

        newgrid = Grid(
            newtiles,
            grid.tile_w * scale_factor,
            grid.tile_h * scale_factor,
            grid.image_w * scale_factor,
            grid.image_h * scale_factor,
            grid.overlap * scale_factor,
        )
        output = combine_grid(newgrid)
        return output
# NOTE(review): assumes `UpscalerCheckpoints` is a keyword-only dataclass.
# Without this decorator the `esrgan` annotation would not become an
# __init__ field, and the `esrgan=` keyword used when building CHECKPOINTS
# would be rejected. Confirm against the installed refiners version.
@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    """Upscaler checkpoint set extended with the ESRGAN weights file."""

    # Path to the ESRGAN checkpoint consumed by `UpscalerESRGAN`.
    esrgan: Path
class ESRGANUpscaler(MultiUpscaler):
    """``MultiUpscaler`` whose pre-upscale step goes through ESRGAN first.

    Args:
        checkpoints: Checkpoint paths, including the ``esrgan`` weights.
        device: Device for both the base upscaler and the ESRGAN model.
        dtype: Dtype for both the base upscaler and the ESRGAN model.
    """

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)

    def to(self, device: torch.device, dtype: torch.dtype):
        """Move the ESRGAN model and ``self.sd`` to *device*/*dtype*."""
        self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype

    def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
        """Upscale with tiled ESRGAN, then defer to the parent's pre-upscale.

        The ESRGAN pass is 4x (enforced when its checkpoint is loaded), so the
        parent only receives the remaining ``upscale_factor / 4``.
        """
        image = self.esrgan.upscale_with_tiling(image)
        return super().pre_upscale(image=image, upscale_factor=upscale_factor / 4)
# Register HEIF/AVIF support with Pillow.
pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

# Checkpoints for the enhancer; every download is pinned to a revision.
CHECKPOINTS = ESRGANUpscalerCheckpoints(
    unet=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.unet",
            filename="model.safetensors",
            revision="347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
        )
    ),
    clip_text_encoder=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.text_encoder",
            filename="model.safetensors",
            revision="744ad6a5c0437ec02ad826df9f6ede102bb27481",
        )
    ),
    lda=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.autoencoder",
            filename="model.safetensors",
            revision="3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
        )
    ),
    controlnet_tile=Path(
        hf_hub_download(
            repo_id="refiners/controlnet.sd1_5.tile",
            filename="model.safetensors",
            revision="48ced6ff8bfa873a8976fa467c3629a240643387",
        )
    ),
    esrgan=Path(
        hf_hub_download(
            repo_id="philz1337x/upscaler",
            filename="4x-UltraSharp.pth",
            revision="011deacac8270114eb7d2eeff4fe6fa9a837be70",
        )
    ),
    negative_embedding=Path(
        hf_hub_download(
            repo_id="philz1337x/embeddings",
            filename="JuggernautNegative-neg.pt",
            revision="203caa7e9cc2bc225031a4021f6ab1ded283454a",
        )
    ),
    negative_embedding_key="string_to_param.*",
    loras={
        "more_details": Path(
            hf_hub_download(
                repo_id="philz1337x/loras",
                filename="more_details.safetensors",
                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
            )
        ),
        "sdxl_render": Path(
            hf_hub_download(
                repo_id="philz1337x/loras",
                filename="SDXLrender_v2.0.safetensors",
                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
            )
        ),
    },
)

device = DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Only query bf16 support when CUDA is actually present, so a CPU-only host
# falls back to float32 instead of depending on how the query behaves there.
DTYPE = dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float32
)
# The enhancer is created on the CPU regardless of DEVICE.
enhancer = ESRGANUpscaler(checkpoints=CHECKPOINTS, device="cpu", dtype=DTYPE)

# logging
warnings.filterwarnings("ignore")
root = logging.getLogger()
root.setLevel(logging.WARN)
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.WARN)
formatter = logging.Formatter('\n >>> [%(levelname)s] %(asctime)s %(name)s: %(message)s\n')
handler.setFormatter(formatter)
root.addHandler(handler)

# constant data
MAX_SEED = np.iinfo(np.int32).max

# precision data
seq=512
image_steps=25
img_accu=3.5
# ui data
# Stylesheet for the UI: centred inputs/headings, hidden footer and dropdown
# arrows, 16:9 image containers, evenly spaced button rows.
css="".join(["""
input, textarea, input::placeholder, textarea::placeholder {
    text-align: center !important;
}
*, *::placeholder {
    font-family: Suez One !important;
}
h1,h2,h3,h4,h5,h6 {
    width: 100%;
    text-align: center;
}
footer {
    display: none !important;
}
.image-container {
    aspect-ratio: 16/9 !important;
}
.dropdown-arrow {
    display: none !important;
}
*:has(>.btn) {
    display: flex;
    justify-content: space-evenly;
    align-items: center;
}
.btn {
    display: flex;
}
"""])

# Client-side input filter for the #prompt and #prompt2 fields: the value is
# remembered on keydown and restored on input whenever the new value contains
# a character other than space, ASCII letters or comma, or two or more
# consecutive spaces/commas.
js="""
function custom(){
    document.querySelector("div#prompt input").addEventListener("keydown",function(e){
        e.target.setAttribute("last_value",e.target.value);
    });
    document.querySelector("div#prompt input").addEventListener("input",function(e){
        if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
            e.target.value = e.target.getAttribute("last_value");
            e.target.removeAttribute("last_value");
        }
    });
    document.querySelector("div#prompt2 input").addEventListener("keydown",function(e){
        e.target.setAttribute("last_value",e.target.value);
    });
    document.querySelector("div#prompt2 input").addEventListener("input",function(e){
        if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
            e.target.value = e.target.getAttribute("last_value");
            e.target.removeAttribute("last_value");
        }
    });
}
"""
# torch pipes
# Tiny autoencoder passed to the pipeline below as its VAE.
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
#good_vae = AutoencoderKL.from_pretrained("ostris/Flex.1-alpha", subfolder="vae", torch_dtype=dtype).to(device)
image_pipe = DiffusionPipeline.from_pretrained("ostris/Flex.1-alpha", torch_dtype=dtype, vae=taef1).to(device)
#image_pipe.enable_model_cpu_offload()
torch.cuda.empty_cache()
#image_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(image_pipe)
# functionality
def upscaler(
    input_image: Image.Image,
    prompt: str = "Hyper realistic photography, Natural visual content.",
    negative_prompt: str = "Distorted, Discontinuous, Blurry, Doll-Like, Overly-Plastic, Low-Quality, Painted, Smoothed, Artificial, Phony, Gaudy, Digital Effects.",
    seed: Optional[int] = None,
    upscale_factor: int = 2,
    controlnet_scale: float = 0.6,
    controlnet_decay: float = 1.0,
    condition_scale: int = 6,
    tile_width: int = 112,
    tile_height: int = 144,
    denoise_strength: float = 0.35,
    num_inference_steps: int = 15,
    solver: str = "DDIM",
) -> Optional[Image.Image]:
    """Upscale *input_image* with the module-level ``enhancer``.

    Uses the module-level ``working`` flag as a busy guard and multiplies the
    module-level ``_HEIGHT_`` / ``_WIDTH_`` by ``upscale_factor`` on success.

    Args:
        input_image: Image to enhance.
        prompt: Positive prompt forwarded to the enhancer.
        negative_prompt: Negative prompt forwarded to the enhancer.
        seed: Seed passed to ``manual_seed``. When omitted, a fresh random
            seed in ``[0, MAX_SEED]`` is drawn on every call.
        upscale_factor: Requested magnification.
        controlnet_scale: Forwarded as ``controlnet_scale``.
        controlnet_decay: Forwarded as ``controlnet_scale_decay``.
        condition_scale: Forwarded as ``condition_scale``.
        tile_width: Tile width; forwarded as the second item of ``tile_size``.
        tile_height: Tile height; forwarded as the first item of ``tile_size``.
        denoise_strength: Forwarded as ``denoise_strength``.
        num_inference_steps: Forwarded as ``num_inference_steps``.
        solver: Attribute name looked up on the ``solvers`` module.

    Returns:
        The enhanced image, or ``None`` when another guarded operation is
        already running.
    """
    # These module-level names are reassigned below; without this declaration
    # Python treats them as locals and the first read raises UnboundLocalError.
    global working, _HEIGHT_, _WIDTH_

    log(f'CALL upscaler')
    if working:
        return None

    working = True
    try:
        if seed is None:
            # Drawn per call; a default evaluated in the signature would be
            # fixed once at import time.
            seed = random.randint(0, MAX_SEED)
        manual_seed(seed)
        solver_type: type[Solver] = getattr(solvers, solver)
        log(f'DBG upscaler 1')
        enhanced_image = enhancer.upscale(
            image=input_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            upscale_factor=upscale_factor,
            controlnet_scale=controlnet_scale,
            controlnet_scale_decay=controlnet_decay,
            condition_scale=condition_scale,
            tile_size=(tile_height, tile_width),
            denoise_strength=denoise_strength,
            num_inference_steps=num_inference_steps,
            loras_scale={"more_details": 0.5, "sdxl_render": 1.0},
            solver_type=solver_type,
        )
        # The tracked dimensions start out as None; only scale real values.
        if _HEIGHT_ is not None:
            _HEIGHT_ = _HEIGHT_ * upscale_factor
        if _WIDTH_ is not None:
            _WIDTH_ = _WIDTH_ * upscale_factor
        log(f'RET upscaler')
        return enhanced_image
    finally:
        # Always release the busy flag, even if the enhancer raised.
        working = False
def get_tensor_length(tensor):
    """Return the total element count of *tensor*.

    Computed as the product of ``tensor.size()``; a zero-dimensional tensor
    therefore yields 1.
    """
    return math.prod(tensor.size())
def _summarize(text):
    """Summarize *text* once with the module-level T5 model.

    The text is prefixed with ``"summarize: "``, tokenized without
    truncation, and decoded from an 8-beam generation capped at 512 tokens.

    Returns:
        The decoded summary with special tokens removed.
    """
    log(f'CALL _summarize')
    prefix = "summarize: "
    toks = tokenizer.encode(prefix + text, return_tensors="pt", truncation=False)
    gen = model.generate(
        toks,
        length_penalty=0.9,
        num_beams=8,
        early_stopping=True,
        max_length=512
    )
    ret = tokenizer.decode(gen[0], skip_special_tokens=True)
    log(f'RET _summarize with ret as {ret}')
    return ret
def summarize(text, max_len=400):
    """Repeatedly summarize *text* until it is at most *max_len* characters.

    Phase 1: while the text has 510 or more words, the first 510 words are
    summarized and the remaining words are appended unchanged.
    Phase 2: while the text is longer than *max_len* characters, the whole
    text is summarized again.
    Either phase stops early, returning the current text, if a pass produces
    no change.

    Args:
        text: Text to shorten.
        max_len: Target maximum length in characters.

    Returns:
        The shortened text (possibly still longer than *max_len* if
        summarizing stopped making progress).
    """
    log(f'CALL summarize')

    words = text.split()
    while len(words) >= 510:
        head = _summarize(" ".join(words[:510]))
        tail = " ".join(words[510:])
        # Keep a separator so the summary's last word and the first remaining
        # word are not fused together.
        summ = f"{head} {tail}" if tail else head
        if summ == text:
            return text
        text = summ
        words = text.split()

    while len(text) > max_len:
        summ = _summarize(text)
        if summ == text:
            return text
        text = summ

    log(f'RET summarize with text as {text}')
    return text
def generate_random_string(length):
    """Return *length* random characters drawn from ASCII letters and digits.

    Uses the ``random`` module, so the result is not suitable for secrets.
    """
    alphabet = ascii_letters + digits
    return ''.join(random.choice(alphabet) for _ in range(length))
def add_song_cover_text(img,top_title=None,bottom_title=None):
    """Draw optional top and bottom titles onto *img*, centred horizontally.

    The top title is white with a black outline and sits above the vertical
    centre; the bottom title is black with a white outline and sits below
    it. Uses the module-level ``working`` flag as a busy guard and the
    module-level ``_HEIGHT_`` / ``_WIDTH_`` as canvas size, falling back to
    the image's own size while those are still None.

    Args:
        img: Image to draw on (modified in place).
        top_title: Text for the upper label; skipped when empty/None.
        bottom_title: Text for the lower label; skipped when empty/None.

    Returns:
        *img*, or ``None`` when another guarded operation is already running.
        NOTE(review): the original indentation was lost, so the busy-case
        return value is an assumption — confirm against callers.
    """
    # `working` is reassigned below; without this declaration the first read
    # raises UnboundLocalError.
    global working

    if working:
        return None

    working = True
    try:
        h = _HEIGHT_ if _HEIGHT_ is not None else img.height
        w = _WIDTH_ if _WIDTH_ is not None else img.width
        draw = ImageDraw.Draw(img,mode="RGBA")
        labels_distance = 1/3

        def draw_title(title, direction, fill, stroke_fill):
            # direction is -1 for the upper label and +1 for the lower one.
            lines = title.split("\n")
            rows = len(lines)
            textheight=min(math.ceil( w / 10 ), math.ceil( h / 5 ))
            font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
            # textlength() rejects multi-line text, so measure the widest line.
            textwidth = max(draw.textlength(line,font) for line in lines)
            x = math.ceil((w - textwidth) / 2)
            y = h - (textheight * rows / 2) - (h / 2)
            y = math.ceil(y + direction * (h / 2 * labels_distance))
            draw.text((x, y), title, fill, font=font, spacing=2, stroke_width=math.ceil(textheight/20), stroke_fill=stroke_fill)

        if top_title:
            draw_title(top_title, -1, (255,255,255), (0,0,0))
        if bottom_title:
            draw_title(bottom_title, 1, (0,0,0), (255,255,255))
        return img
    finally:
        # Always release the busy flag, even if drawing raised.
        working = False
# Base URL of the Google Translate endpoint.
google_translate_endpoint = "https://translate.google.com/m"

# Lower-case language name -> language code.
language_codes = {
    "afrikaans": "af",
    "albanian": "sq",
    "amharic": "am",
    "arabic": "ar",
    "armenian": "hy",
    "assamese": "as",
    "aymara": "ay",
    "azerbaijani": "az",
    "bambara": "bm",
    "basque": "eu",
    "belarusian": "be",
    "bengali": "bn",
    "bhojpuri": "bho",
    "bosnian": "bs",
    "bulgarian": "bg",
    "catalan": "ca",
    "cebuano": "ceb",
    "chichewa": "ny",
    "chinese (simplified)": "zh-CN",
    "chinese (traditional)": "zh-TW",
    "corsican": "co",
    "croatian": "hr",
    "czech": "cs",
    "danish": "da",
    "dhivehi": "dv",
    "dogri": "doi",
    "dutch": "nl",
    "english": "en",
    "esperanto": "eo",
    "estonian": "et",
    "ewe": "ee",
    "filipino": "tl",
    "finnish": "fi",
    "french": "fr",
    "frisian": "fy",
    "galician": "gl",
    "georgian": "ka",
    "german": "de",
    "greek": "el",
    "guarani": "gn",
    "gujarati": "gu",
    "haitian creole": "ht",
    "hausa": "ha",
    "hawaiian": "haw",
    "hebrew": "iw",
    "hindi": "hi",
    "hmong": "hmn",
    "hungarian": "hu",
    "icelandic": "is",
    "igbo": "ig",
    "ilocano": "ilo",
    "indonesian": "id",
    "irish": "ga",
    "italian": "it",
    "japanese": "ja",
    "javanese": "jw",
    "kannada": "kn",
    "kazakh": "kk",
    "khmer": "km",
    "kinyarwanda": "rw",
    "konkani": "gom",
    "korean": "ko",
    "krio": "kri",
    "kurdish (kurmanji)": "ku",
    "kurdish (sorani)": "ckb",
    "kyrgyz": "ky",
    "lao": "lo",
    "latin": "la",
    "latvian": "lv",
    "lingala": "ln",
    "lithuanian": "lt",
    "luganda": "lg",
    "luxembourgish": "lb",
    "macedonian": "mk",
    "maithili": "mai",
    "malagasy": "mg",
    "malay": "ms",
    "malayalam": "ml",
    "maltese": "mt",
    "maori": "mi",
    "marathi": "mr",
    "meiteilon (manipuri)": "mni-Mtei",
    "mizo": "lus",
    "mongolian": "mn",
    "myanmar": "my",
    "nepali": "ne",
    "norwegian": "no",
    "odia (oriya)": "or",
    "oromo": "om",
    "pashto": "ps",
    "persian": "fa",
    "polish": "pl",
    "portuguese": "pt",
    "punjabi": "pa",
    "quechua": "qu",
    "romanian": "ro",
    "russian": "ru",
    "samoan": "sm",
    "sanskrit": "sa",
    "scots gaelic": "gd",
    "sepedi": "nso",
    "serbian": "sr",
    "sesotho": "st",
    "shona": "sn",
    "sindhi": "sd",
    "sinhala": "si",
    "slovak": "sk",
    "slovenian": "sl",
    "somali": "so",
    "spanish": "es",
    "sundanese": "su",
    "swahili": "sw",
    "swedish": "sv",
    "tajik": "tg",
    "tamil": "ta",
    "tatar": "tt",
    "telugu": "te",
    "thai": "th",
    "tigrinya": "ti",
    "tsonga": "ts",
    "turkish": "tr",
    "turkmen": "tk",
    "twi": "ak",
    "ukrainian": "uk",
    "urdu": "ur",
    "uyghur": "ug",
    "uzbek": "uz",
    "vietnamese": "vi",
    "welsh": "cy",
    "xhosa": "xh",
    "yiddish": "yi",
    "yoruba": "yo",
    "zulu": "zu",
}
class BaseError(Exception):
    """
    base error structure class
    """

    def __init__(self, val, message):
        """
        @param val: actual value
        @param message: message shown to the user
        """
        self.val = val
        self.message = message
        super().__init__()

    def __str__(self):
        """Render the error as ``"<val> --> <message>"``."""
        return f"{self.val} --> {self.message}"
class LanguageNotSupportedException(BaseError):
    """
    exception thrown if the user uses a language
    that is not supported by the deep_translator
    """

    def __init__(
        self, val, message="There is no support for the chosen language"
    ):
        super().__init__(val, message)
class NotValidPayload(BaseError):
    """
    exception thrown if the user enters an invalid payload
    """

    def __init__(
        self,
        val,
        # The trailing space keeps the two concatenated fragments from fusing.
        message="text must be a valid text with maximum 5000 character, "
        "otherwise it cannot be translated",
    ):
        super().__init__(val, message)
class InvalidSourceOrTargetLanguage(BaseError):
    """
    exception thrown if the user enters an invalid source or target language
    """

    def __init__(self, val, message="Invalid source or target language!"):
        super().__init__(val, message)
class TranslationNotFound(BaseError):
    """
    exception thrown if no translation was found for the text provided by the user
    """

    def __init__(
        self,
        val,
        message="No translation was found using the current translator. Try another translator?",
    ):
        super().__init__(val, message)
class ElementNotFoundInGetRequest(BaseError):
    """
    exception thrown if the html element was not found in the body parsed by beautifulsoup
    """

    def __init__(
        self, val, message="Required element was not found in the API response"
    ):
        super().__init__(val, message)
class NotValidLength(BaseError):
    """
    exception thrown if the provided text exceed the length limit of the translator
    """

    def __init__(self, val, min_chars, max_chars):
        """
        @param val: the offending text
        @param min_chars: lower bound reported in the message
        @param max_chars: upper bound reported in the message
        """
        message = f"Text length need to be between {min_chars} and {max_chars} characters"
        super().__init__(val, message)
class RequestError(Exception):
    """
    exception thrown if an error occurred during the request call, e.g a connection problem.
    """

    def __init__(
        self,
        message="Request exception can happen due to an api connection error. "
        "Please check your connection and try again",
    ):
        self.message = message

    def __str__(self):
        """Return the stored message."""
        return self.message
class TooManyRequests(Exception):
    """
    Exception raised when the server rejects a call because too many
    requests were made (``GoogleTranslator.translate`` raises it on HTTP 429).
    """

    def __init__(
        self,
        message="Server Error: You made too many requests to the server. "
        "According to google, you are allowed to make 5 requests per second "
        "and up to 200k requests per day. You can wait and try again later or "
        "you can try the translate_batch function",
    ):
        """
        @param message: text returned by ``str()`` for this exception
        """
        # Adjacent literals are concatenated verbatim, hence the trailing
        # space on every fragment; the base class also receives the message
        # so ``args`` and ``repr()`` are populated.
        super().__init__(message)
        self.message = message

    def __str__(self):
        return self.message
class ServerException(Exception):
    """
    Exception carrying a symbolic error name for an HTTP status code.

    The original note describes it as the default YandexTranslate exception
    from the official website.
    """

    # Status code -> symbolic error name used as the exception message.
    errors = {
        400: "ERR_BAD_REQUEST",
        401: "ERR_KEY_INVALID",
        402: "ERR_KEY_BLOCKED",
        403: "ERR_DAILY_REQ_LIMIT_EXCEEDED",
        404: "ERR_DAILY_CHAR_LIMIT_EXCEEDED",
        413: "ERR_TEXT_TOO_LONG",
        429: "ERR_TOO_MANY_REQUESTS",
        422: "ERR_UNPROCESSABLE_TEXT",
        500: "ERR_INTERNAL_SERVER_ERROR",
        501: "ERR_LANG_NOT_SUPPORTED",
        503: "ERR_SERVICE_NOT_AVAIBLE",
    }

    def __init__(self, status_code, *args):
        """
        @param status_code: status code looked up in ``errors``
        @param args: extra positional values appended to the exception args
        """
        # Unknown codes fall back to a generic description.
        description = self.errors.get(status_code, "API server error")
        super().__init__(description, *args)
def is_empty(text: str) -> bool:
    """
    Tell whether the given text is exactly the empty string.
    Whitespace-only text is not considered empty.
    @param text: text to inspect
    @return: bool
    """
    empty = ""
    return text == empty
def request_failed(status_code: int) -> bool:
    """Report whether a request failed, judging by its status code.

    Only status codes from 200 to 299 inclusive count as success.

    Args:
        status_code (int): status code of the request

    Returns:
        bool: True when the status code lies outside the 2** range
    """
    return status_code < 200 or status_code > 299
def is_input_valid(
    text: str, min_chars: int = 0, max_chars: Optional[int] = None
) -> bool:
    """
    Check that the text to translate is a string of acceptable length.
    The length is only checked when max_chars is truthy; in that case it must
    satisfy min_chars <= len(text) < max_chars (upper bound excluded).
    @param text: text to translate
    @param min_chars: min characters
    @param max_chars: max characters
    @return: True when the text is valid, otherwise an exception is raised
    """
    if not isinstance(text, str):
        raise NotValidPayload(text)

    # Without an upper bound there is nothing left to verify.
    if not max_chars:
        return True

    within_limits = min_chars <= len(text) < max_chars
    if within_limits:
        return True
    raise NotValidLength(text, min_chars, max_chars)
class BaseTranslator(ABC):
    """
    Abstract class that serves as a base translator for other different translators.

    It stores the endpoint, the language table and the resolved source/target
    languages, and provides file and batch helpers built on ``translate``.
    """

    def __init__(
        self,
        base_url: Optional[str] = None,
        languages: dict = language_codes,
        source: str = "auto",
        target: str = "en",
        payload_key: Optional[str] = None,
        element_tag: Optional[str] = None,
        element_query: Optional[dict] = None,
        **url_params,
    ):
        """
        @param base_url: endpoint used by the concrete translator
        @param languages: mapping of language names to their codes
        @param source: source language to translate from
        @param target: target language to translate to
        @param payload_key: name of the url parameter that carries the text
        @param element_tag: html tag looked up in the response
        @param element_query: attributes used to locate that html tag
        @param url_params: extra url parameters kept for the requests
        """
        self._base_url = base_url
        self._languages = languages
        self._supported_languages = list(self._languages.keys())
        if not source:
            raise InvalidSourceOrTargetLanguage(source)
        if not target:
            raise InvalidSourceOrTargetLanguage(target)

        # Unpacking consumes the generator right away, so an unsupported
        # language is reported here, at construction time.
        self._source, self._target = self._map_language_to_code(source, target)
        self._url_params = url_params
        self._element_tag = element_tag
        self._element_query = element_query
        self.payload_key = payload_key
        super().__init__()

    # ``source`` and ``target`` are exposed as properties: two plain methods
    # sharing one name would leave only the second (the setter) reachable.
    @property
    def source(self):
        return self._source

    @source.setter
    def source(self, lang):
        self._source = lang

    @property
    def target(self):
        return self._target

    @target.setter
    def target(self, lang):
        self._target = lang

    def _type(self):
        return self.__class__.__name__

    def _map_language_to_code(self, *languages):
        """
        map language to its corresponding code (abbreviation) if the language was passed
        by its full name by the user
        @param languages: list of languages
        @return: generator yielding the mapped value of each language; raises an
        exception while iterating if a language is not supported
        """
        for language in languages:
            if language in self._languages.values() or language == "auto":
                yield language
            elif language in self._languages.keys():
                yield self._languages[language]
            else:
                raise LanguageNotSupportedException(
                    language,
                    message=f"No support for the provided language.\n"
                    f"Please select on of the supported languages:\n"
                    f"{self._languages}",
                )

    def _same_source_target(self) -> bool:
        return self._source == self._target

    def get_supported_languages(
        self, as_dict: bool = False, **kwargs
    ) -> Union[list, dict]:
        """
        return the languages supported by the translator
        @param as_dict: if True, the languages will be returned as a dictionary
        mapping languages to their abbreviations
        @return: list or dict
        """
        return self._languages if as_dict else self._supported_languages

    def is_language_supported(self, language: str, **kwargs) -> bool:
        """
        check if the language is supported by the translator
        @param language: a string for 1 language
        @return: bool
        """
        return (
            language == "auto"
            or language in self._languages.keys()
            or language in self._languages.values()
        )

    def translate(self, text: str, **kwargs) -> str:
        """
        translate a text using a translator under the hood and return
        the translated text
        @param text: text to translate
        @param kwargs: additional arguments
        @return: str
        """
        # ``NotImplemented`` is a non-callable singleton; calling it fails
        # with a TypeError. The exception class is what is meant here.
        raise NotImplementedError("You need to implement the translate method!")

    def _read_docx(self, f: str):
        import docx2txt

        return docx2txt.process(f)

    def _read_pdf(self, f: str):
        import pypdf

        reader = pypdf.PdfReader(f)
        # NOTE: only the first page of the document is extracted.
        page = reader.pages[0]
        return page.extract_text()

    def _translate_file(self, path: str, **kwargs) -> str:
        """
        translate directly from file
        @param path: path to the target file
        @type path: str
        @param kwargs: additional args
        @return: str
        """
        if not isinstance(path, Path):
            path = Path(path)

        if not path.exists():
            # Raise instead of terminating the interpreter so the caller can
            # decide how to handle a wrong path.
            raise FileNotFoundError(f"Path to the file is wrong: {path}")

        ext = path.suffix
        if ext == ".docx":
            text = self._read_docx(f=str(path))
        elif ext == ".pdf":
            text = self._read_pdf(f=str(path))
        else:
            with open(path, "r", encoding="utf-8") as f:
                text = f.read().strip()

        return self.translate(text)

    def _translate_batch(self, batch: List[str], **kwargs) -> List[str]:
        """
        translate a list of texts
        @param batch: list of texts you want to translate
        @return: list of translations
        """
        if not batch:
            raise Exception("Enter your text list that you want to translate")
        return [self.translate(text, **kwargs) for text in batch]
class GoogleTranslator(BaseTranslator):
    """
    class that wraps functions, which use Google Translate under the hood to translate text(s)
    """

    def __init__(
        self,
        source: str = "auto",
        target: str = "en",
        proxies: Optional[dict] = None,
        **kwargs
    ):
        """
        @param source: source language to translate from
        @param target: target language to translate to
        @param proxies: proxies passed to ``requests.get``
        @param kwargs: extra url parameters sent with every request
        """
        self.proxies = proxies
        super().__init__(
            base_url=google_translate_endpoint,
            source=source,
            target=target,
            element_tag="div",
            element_query={"class": "t0"},
            payload_key="q",  # key of text in the url
            **kwargs
        )

        # Fallback selector tried when the primary element is not in the page.
        self._alt_element_query = {"class": "result-container"}

    def translate(self, text: str, **kwargs) -> str:
        """
        function to translate a text
        @param text: desired text to translate (fewer than 1000 characters)
        @return: str: translated text
        """
        # Raises NotValidPayload / NotValidLength on bad input.
        is_input_valid(text, max_chars=1000)

        text = text.strip()
        if self._same_source_target() or is_empty(text):
            return text

        self._url_params["tl"] = self._target
        self._url_params["sl"] = self._source
        if self.payload_key:
            self._url_params[self.payload_key] = text

        response = requests.get(
            self._base_url, params=self._url_params, proxies=self.proxies
        )
        try:
            if response.status_code == 429:
                raise TooManyRequests()
            if request_failed(status_code=response.status_code):
                raise RequestError()
            soup = BeautifulSoup(response.text, "html.parser")
        finally:
            # Release the response on the error paths as well.
            response.close()

        element = soup.find(self._element_tag, self._element_query)
        if not element:
            element = soup.find(self._element_tag, self._alt_element_query)
        if not element:
            raise TranslationNotFound(text)

        translated = element.get_text(strip=True)
        if translated != text:
            return translated

        # The page echoed the input back. When the text has alphanumeric
        # content and an "hl" url parameter is set, drop that parameter and
        # retry once without it.
        if any(ch.isalnum() for ch in text) and "hl" in self._url_params:
            del self._url_params["hl"]
            return self.translate(text)

        # Nothing else to try: return the text itself rather than falling
        # off the end of the function with None.
        return text

    def translate_file(self, path: str, **kwargs) -> str:
        """
        translate directly from file
        @param path: path to the target file
        @type path: str
        @param kwargs: additional args
        @return: str
        """
        return self._translate_file(path, **kwargs)

    def translate_batch(self, batch: List[str], **kwargs) -> List[str]:
        """
        translate a list of texts
        @param batch: list of texts you want to translate
        @return: list of translations
        """
        return self._translate_batch(batch, **kwargs)
def _split_into_chunks(text, limit):
    """Split text on whitespace into pieces of at most `limit` characters."""
    chunks = []
    current = ""
    for word in text.split():
        # A single word longer than the limit can never fit in a chunk:
        # cut it into fixed-size slices so the loop always makes progress.
        while len(word) > limit:
            if current:
                chunks.append(current)
                current = ""
            chunks.append(word[:limit])
            word = word[limit:]
        candidate = f"{current} {word}" if current else word
        if len(candidate) > limit:
            chunks.append(current)
            current = word
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks


def translate(txt, to_lang="en", from_lang="auto"):
    """
    Translate a text with GoogleTranslator and return it lower-cased.

    Long texts are split on whitespace into chunks that fit the translator's
    length limit; the translated chunks are joined with single spaces.
    @param txt: text to translate
    @param to_lang: language to translate to
    @param from_lang: language to translate from ("auto" by default)
    @return: str: stripped, lower-cased translation; the stripped, lower-cased
    input when it is empty or already in the target language
    """
    # Local import: the exception type is not imported at module level.
    from langdetect.lang_detect_exception import LangDetectException

    log(f'CALL translate')

    if len(txt) == 0:
        print("Translated text is empty. Skipping translation...")
        return txt.strip().lower()

    same_language = from_lang == to_lang
    if not same_language:
        try:
            same_language = get_language(txt) == to_lang
        except LangDetectException:
            # Detection fails on text without usable features (digits,
            # punctuation, whitespace); let the translator handle it.
            same_language = False

    if same_language:
        print("Same languages. Skipping translation...")
        return txt.strip().lower()

    # GoogleTranslator takes `source`/`target`; any other keyword is treated
    # as a url parameter and would leave the target at its "en" default.
    translator = GoogleTranslator(source=from_lang, target=to_lang)

    # GoogleTranslator.translate rejects inputs of 1000 characters or more,
    # so every piece handed to it must stay strictly below that.
    max_len = 999
    chunks = [txt] if len(txt) <= max_len else _split_into_chunks(txt, max_len)

    pieces = []
    for chunk in chunks:
        result = translator.translate(chunk)
        # Defensive: keep the source chunk if no string comes back.
        pieces.append(chunk if result is None else result)

    translation = " ".join(pieces).strip()
    log(f'RET translate with translation as {translation}')
    return translation.lower()
def handle_generation(h, w, d):
    """
    Build the prompts from the user's description and run the image pipeline.

    The description is normalised into sentences, translated and summarised,
    then embedded into the positive prompt. While a generation is running,
    further calls return immediately.
    @param h: image height passed to the pipeline
    @param w: image width passed to the pipeline
    @param d: free-text description entered by the user
    @return: the value returned by ``image_pipe``, or None when another
    generation is already in progress
    """
    # These names are module-level state; without this declaration the
    # assignments below would make them locals and the first read of
    # ``working`` would raise UnboundLocalError.
    global working, _HEIGHT_, _WIDTH_

    log(f'CALL handle_generate')

    if working:
        return None

    working = True
    try:
        # Turn comma-separated fragments into separate sentences.
        d = re.sub(r",( ){1,}", ". ", d)

        # Split on runs of newlines (no capturing group, so the separators
        # themselves are not returned) and end every non-empty line with a
        # full stop; blank lines are dropped instead of becoming ".".
        sentences = []
        for line in re.split(r"\n+", d):
            line = line.strip()
            if not line:
                continue
            if not line.endswith("."):
                line += "."
            sentences.append(line)
        d = " ".join(sentences)

        pos_d = re.sub(r"([ \t]){1,}", " ", d).lower().strip()
        pos_d = pos_d if pos_d == "" else summarize(translate(pos_d))
        pos_d = re.sub(r"([ \t]){1,}", " ", pos_d).lower().strip()

        neg = f"Textual, Text, Distorted, Fake, Discontinuous, Blurry, Doll-Like, Overly Plastic, Low Quality, Paint, Smoothed, Artificial, Phony, Gaudy, Digital Effects."
        pos = f'HQ Hyper-realistic professional photograph{ pos_d if pos_d == "" else ": " + pos_d }.'

        print(f"\nPositive: {pos}\nNegative: {neg}\n")

        img = image_pipe(
            prompt=pos,
            negative_prompt=neg,
            height=h,
            width=w,
            output_type="pil",
            guidance_scale=img_accu,
            num_images_per_prompt=1,
            num_inference_steps=image_steps,
            max_sequence_length=seq,
            generator=torch.Generator(device).manual_seed(random.randint(0, MAX_SEED))
        )
    finally:
        # Always clear the busy flag, even if the pipeline raised, so later
        # requests are not blocked forever.
        working = False

    _HEIGHT_ = h
    _WIDTH_ = w
    return img
# entry
if __name__ == "__main__":
    # NOTE(review): the nesting of the layout below was reconstructed from the
    # statement order; confirm against the intended UI.

    # Settings shared by both size sliders and by both title boxes.
    size_slider = dict(minimum=512, maximum=2048, step=16, value=1024)
    title_box = dict(value="", container=False, max_lines=1)

    with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
        gr.Markdown("\n# Text-to-Image generator\n")
        gr.Markdown("\n### Realistic. Upscalable. Multilingual.\n")

        with gr.Row():
            # Left column: generation controls and text inputs.
            with gr.Column(scale=2):
                height = gr.Slider(label="Height (px)", **size_slider)
                width = gr.Slider(label="Width (px)", **size_slider)
                run = gr.Button("Generate", elem_classes="btn")
                top = gr.Textbox(placeholder="Top title", **title_box)
                bottom = gr.Textbox(placeholder="Bottom title", **title_box)
                add_titles = gr.Button("Add title(s)", elem_classes="btn")
                data = gr.Textbox(
                    placeholder="Input data",
                    value="",
                    container=False,
                    max_lines=100,
                )

            # Right column: resulting image and the upscale action.
            with gr.Column():
                cover = gr.Image(
                    interactive=False,
                    container=False,
                    elem_classes="image-container",
                    label="Result",
                    show_label=True,
                    type='pil',
                    show_share_button=False,
                )
                upscale_now = gr.Button("Upscale", elem_classes="btn")

        # Every action writes its result back into the same image component.
        gr.on(
            triggers=[run.click],
            fn=handle_generation,
            inputs=[height, width, data],
            outputs=[cover],
        )
        upscale_now.click(fn=upscaler, inputs=[cover], outputs=[cover])
        add_titles.click(
            fn=add_song_cover_text,
            inputs=[cover, top, bottom],
            outputs=[cover],
        )

    demo.queue().launch()