from langdetect import detect as get_language
from langdetect import DetectorFactory
from langdetect.lang_detect_exception import LangDetectException
from collections import namedtuple
from inspect import signature
import os
import subprocess
import logging
import re
import random
from string import ascii_letters, digits, punctuation
import requests
import sys
import warnings
import time
import asyncio
import math
from pathlib import Path
from functools import partial
from dataclasses import dataclass
from typing import Any
import pillow_heif
import spaces
import numpy as np
import numpy.typing as npt
import torch
from torch import nn
import gradio as gr
from lxml.html import fromstring
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file, save_file
from diffusers import FluxPipeline
from PIL import Image, ImageDraw, ImageFont
from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
from refiners.fluxion.utils import manual_seed
from refiners.foundationals.latent_diffusion import Solver, solvers
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
    MultiUpscaler,
    UpscalerCheckpoints,
)
from datetime import datetime

# langdetect is non-deterministic unless its factory is seeded; fixing the seed
# makes detection (and therefore the translate() task prefix) reproducible.
DetectorFactory.seed = 0

# T5 model/tokenizer shared by summarize() and translate().
model = T5ForConditionalGeneration.from_pretrained("t5-large")
tokenizer = T5Tokenizer.from_pretrained("t5-large")


def log(msg):
    """Print `msg` prefixed with the current wall-clock time."""
    print(f'{datetime.now().time()} {msg}')


# A tile is (offset, size, image); Tiles groups tiles into rows: (row_offset, row_size, tiles).
Tile = tuple[int, int, Image.Image]
Tiles = list[tuple[int, int, list[Tile]]]


def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
    """3x3 same-padding convolution followed by LeakyReLU(0.2)."""
    return nn.Sequential(
        nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
        nn.LeakyReLU(negative_slope=0.2, inplace=True),
    )


class ResidualDenseBlock_5C(nn.Module):
    """
    Residual Dense Block
    The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
    Modified options that can be used:
        - "Partial Convolution based Padding" arXiv:1811.11718
        - "Spectral normalization" arXiv:1802.05957
        - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
            {Rakotonirina} and A. {Rasoanaivo}
    """

    def __init__(self, nf: int = 64, gc: int = 32) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]

        self.conv1 = conv_block(nf, gc)
        self.conv2 = conv_block(nf + gc, gc)
        self.conv3 = conv_block(nf + 2 * gc, gc)
        self.conv4 = conv_block(nf + 3 * gc, gc)
        # Wrapped in Sequential because of key in state dict.
        self.conv5 = nn.Sequential(nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Each conv sees the concatenation of the input and all previous outputs (dim 1).
        x1 = self.conv1(x)
        x2 = self.conv2(torch.cat((x, x1), 1))
        x3 = self.conv3(torch.cat((x, x1, x2), 1))
        x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    """
    Residual in Residual Dense Block
    (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
    """

    def __init__(self, nf: int) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.RDB1 = ResidualDenseBlock_5C(nf)
        self.RDB2 = ResidualDenseBlock_5C(nf)
        self.RDB3 = ResidualDenseBlock_5C(nf)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class Upsample2x(nn.Module):
    """Upsample 2x."""

    def __init__(self) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return nn.functional.interpolate(x, scale_factor=2.0)  # type: ignore


class ShortcutBlock(nn.Module):
    """Elementwise sum the output of a submodule to its input"""

    def __init__(self, submodule: nn.Module) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        self.sub = submodule

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sub(x)


class RRDBNet(nn.Module):
    """ESRGAN generator: RRDB trunk followed by two 2x upsampling stages (4x total)."""

    def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
        super().__init__()  # type: ignore[reportUnknownMemberType]
        assert in_nc % 4 != 0  # in_nc is 3

        self.model = nn.Sequential(
            nn.Conv2d(in_nc, nf, kernel_size=3, padding=1),
            ShortcutBlock(
                nn.Sequential(
                    *(RRDB(nf) for _ in range(nb)),
                    nn.Conv2d(nf, nf, kernel_size=3, padding=1),
                )
            ),
            Upsample2x(),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            Upsample2x(),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(nf, out_nc, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
    """Infer (in_nc, out_nc, nf, nb, scale) of an ESRGAN checkpoint from its state-dict keys."""
    # this code is adapted from https://github.com/victorca25/iNNfer
    scale2x = 0
    scalemin = 6
    n_uplayer = 0
    out_nc = 0
    nb = 0

    for block in list(state_dict):
        parts = block.split(".")
        n_parts = len(parts)
        if n_parts == 5 and parts[2] == "sub":
            nb = int(parts[3])
        elif n_parts == 3:
            part_num = int(parts[1])
            if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
                scale2x += 1
            if part_num > n_uplayer:
                n_uplayer = part_num
                out_nc = state_dict[block].shape[0]
        assert "conv1x1" not in block  # no ESRGANPlus

    nf = state_dict["model.0.weight"].shape[0]
    in_nc = state_dict["model.0.weight"].shape[1]
    scale = 2**scale2x

    assert out_nc > 0
    assert nb > 0

    return in_nc, out_nc, nf, nb, scale  # 3, 3, 64, 23, 4


# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])


# adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
    """Cut `image` into overlapping tiles of at most `tile_w` x `tile_h` pixels."""
    w = image.width
    h = image.height

    non_overlap_width = tile_w - overlap
    non_overlap_height = tile_h - overlap

    cols = max(1, math.ceil((w - overlap) / non_overlap_width))
    rows = max(1, math.ceil((h - overlap) / non_overlap_height))

    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        row_images: list[Tile] = []
        y1 = max(min(int(row * dy), h - tile_h), 0)
        y2 = min(y1 + tile_h, h)
        for col in range(cols):
            x1 = max(min(int(col * dx), w - tile_w), 0)
            x2 = min(x1 + tile_w, w)
            tile = image.crop((x1, y1, x2, y2))
            row_images.append((x1, tile_w, tile))
        grid.tiles.append((y1, tile_h, row_images))

    return grid


# https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
def combine_grid(grid: Grid):
    """Reassemble a Grid into one RGB image, linearly blending the overlap regions."""

    def make_mask_image(r: npt.NDArray[np.float32]) -> Image.Image:
        r = r * 255 / grid.overlap
        return Image.fromarray(r.astype(np.uint8), "L")

    # Horizontal ramp for blending tiles within a row, vertical ramp for blending rows.
    mask_w = make_mask_image(
        np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)
    )
    mask_h = make_mask_image(
        np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)
    )

    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        combined_row = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                combined_row.paste(tile, (0, 0))
                continue

            combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
            combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))

        if y == 0:
            combined_image.paste(combined_row, (0, 0))
            continue

        combined_image.paste(
            combined_row.crop((0, 0, combined_row.width, grid.overlap)),
            (0, y),
            mask=mask_h,
        )
        combined_image.paste(
            combined_row.crop((0, grid.overlap, combined_row.width, h)),
            (0, y + grid.overlap),
        )

    return combined_image


class UpscalerESRGAN:
    """Wraps an ESRGAN RRDBNet checkpoint for 4x image upscaling."""

    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
        self.model_path = model_path
        self.device = device
        self.model = self.load_model(model_path)
        self.to(device, dtype)

    def __call__(self, img: Image.Image) -> Image.Image:
        return self.upscale_without_tiling(img)

    def to(self, device: torch.device, dtype: torch.dtype):
        self.device = device
        self.dtype = dtype
        self.model.to(device=device, dtype=dtype)

    def load_model(self, path: Path) -> RRDBNet:
        """Load the checkpoint at `path`; only 4x checkpoints are accepted."""
        filename = path
        state_dict: dict[str, torch.Tensor] = torch.load(filename, weights_only=True, map_location=self.device)  # type: ignore
        in_nc, out_nc, nf, nb, upscale = infer_params(state_dict)
        assert upscale == 4, "Only 4x upscaling is supported"
        model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
        model.load_state_dict(state_dict)
        model.eval()

        return model

    def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
        """Run the network on the whole image in one pass."""
        img_np = np.array(img)
        # RGB -> BGR (channel order the network expects), HWC -> CHW, [0, 255] -> [0, 1].
        img_np = img_np[:, :, ::-1]
        img_np = np.ascontiguousarray(np.transpose(img_np, (2, 0, 1))) / 255
        img_t = torch.from_numpy(img_np).float()  # type: ignore
        img_t = img_t.unsqueeze(0).to(device=self.device, dtype=self.dtype)
        with torch.no_grad():
            output = self.model(img_t)
        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = 255.0 * np.moveaxis(output, 0, 2)
        output = output.astype(np.uint8)
        output = output[:, :, ::-1]
        return Image.fromarray(output, "RGB")

    # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
    def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
        """Upscale tile by tile (bounded memory) and stitch the results back together."""
        img = img.convert("RGB")
        grid = split_grid(img)
        newtiles: Tiles = []
        scale_factor: int = 1

        for y, h, row in grid.tiles:
            newrow: list[Tile] = []
            for tiledata in row:
                x, w, tile = tiledata
                output = self.upscale_without_tiling(tile)
                scale_factor = output.width // tile.width
                newrow.append((x * scale_factor, w * scale_factor, output))
            newtiles.append((y * scale_factor, h * scale_factor, newrow))

        newgrid = Grid(
            newtiles,
            grid.tile_w * scale_factor,
            grid.tile_h * scale_factor,
            grid.image_w * scale_factor,
            grid.image_h * scale_factor,
            grid.overlap * scale_factor,
        )
        output = combine_grid(newgrid)
        return output


@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    esrgan: Path


class ESRGANUpscaler(MultiUpscaler):
    """MultiUpscaler whose pre-upscale step is a 4x ESRGAN pass."""

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)

    def to(self, device: torch.device, dtype: torch.dtype):
        self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype

    def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
        image = self.esrgan.upscale_with_tiling(image)
        # ESRGAN already applied 4x; the parent handles the remaining factor.
        return super().pre_upscale(image=image, upscale_factor=upscale_factor / 4)


pillow_heif.register_heif_opener()
pillow_heif.register_avif_opener()

CHECKPOINTS = ESRGANUpscalerCheckpoints(
    unet=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.unet",
            filename="model.safetensors",
            revision="347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
        )
    ),
    clip_text_encoder=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.text_encoder",
            filename="model.safetensors",
            revision="744ad6a5c0437ec02ad826df9f6ede102bb27481",
        )
    ),
    lda=Path(
        hf_hub_download(
            repo_id="refiners/juggernaut.reborn.sd1_5.autoencoder",
            filename="model.safetensors",
            revision="3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
        )
    ),
    controlnet_tile=Path(
        hf_hub_download(
            repo_id="refiners/controlnet.sd1_5.tile",
            filename="model.safetensors",
            revision="48ced6ff8bfa873a8976fa467c3629a240643387",
        )
    ),
    esrgan=Path(
        hf_hub_download(
            repo_id="philz1337x/upscaler",
            filename="4x-UltraSharp.pth",
            revision="011deacac8270114eb7d2eeff4fe6fa9a837be70",
        )
    ),
    negative_embedding=Path(
        hf_hub_download(
            repo_id="philz1337x/embeddings",
            filename="JuggernautNegative-neg.pt",
            revision="203caa7e9cc2bc225031a4021f6ab1ded283454a",
        )
    ),
    negative_embedding_key="string_to_param.*",
    loras={
        "more_details": Path(
            hf_hub_download(
                repo_id="philz1337x/loras",
                filename="more_details.safetensors",
                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
            )
        ),
        "sdxl_render": Path(
            hf_hub_download(
                repo_id="philz1337x/loras",
                filename="SDXLrender_v2.0.safetensors",
                revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
            )
        )
    }
)

device = DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Check CUDA availability first: on some torch versions is_bf16_supported()
# raises instead of returning False when no GPU is present.
DTYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
enhancer = ESRGANUpscaler(checkpoints=CHECKPOINTS, device=DEVICE, dtype=DTYPE)

# logging

warnings.filterwarnings("ignore")
root = logging.getLogger()
root.setLevel(logging.WARN)
handler = logging.StreamHandler(sys.stderr)
handler.setLevel(logging.WARN)
formatter = logging.Formatter('\n >>> [%(levelname)s] %(asctime)s %(name)s: %(message)s\n')
handler.setFormatter(formatter)
root.addHandler(handler)

# constant data

base = "black-forest-labs/FLUX.1-schnell"

# precision data

seq = 256
width = 2048
height = 2048
image_steps = 8
img_accu = 0

# Word-chunk size used by summarize() and translate() when feeding T5.
_CHUNK_WORDS = 512
# Upper bound on iterative re-summarization rounds per chunk (guarantees termination).
_MAX_SUMMARY_ROUNDS = 8

# ui data

css = "".join(["""
input, input::placeholder {
    text-align: center !important;
}

*, *::placeholder {
    font-family: Suez One !important;
}

h1,h2,h3,h4,h5,h6 {
    width: 100%;
    text-align: center;
}

footer {
    display: none !important;
}

.image-container {
    aspect-ratio: """, str(width), "/", str(height), """ !important;
}

.dropdown-arrow {
    display: none !important;
}

*:has(>.btn) {
    display: flex;
    justify-content: space-evenly;
    align-items: center;
}

.btn {
    display: flex;
}
"""])

# NOTE(review): `js` is never passed to gr.Blocks, and no component has
# elem_id "prompt"/"prompt2", so this input filter is currently inactive.
js = """
function custom(){
    document.querySelector("div#prompt input").addEventListener("keydown",function(e){
        e.target.setAttribute("last_value",e.target.value);
    });
    document.querySelector("div#prompt input").addEventListener("input",function(e){
        if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
            e.target.value = e.target.getAttribute("last_value");
            e.target.removeAttribute("last_value");
        }
    });
    document.querySelector("div#prompt2 input").addEventListener("keydown",function(e){
        e.target.setAttribute("last_value",e.target.value);
    });
    document.querySelector("div#prompt2 input").addEventListener("input",function(e){
        if( e.target.value.toString().match(/[^ a-zA-Z,]|( |,){2,}/gsm) ){
            e.target.value = e.target.getAttribute("last_value");
            e.target.removeAttribute("last_value");
        }
    });
}
"""

# torch pipes

image_pipe = FluxPipeline.from_pretrained(base, torch_dtype=torch.bfloat16).to(device)
image_pipe.enable_model_cpu_offload()
image_pipe.enable_vae_slicing()
image_pipe.enable_vae_tiling()

# text helpers

_WHITESPACE_RE = re.compile(r"[ \t\n]+")
# re.escape keeps characters such as "]" and "\\" literal inside the class.
_PUNCTUATION_RE = re.compile(f"[{re.escape(punctuation)}]")
# Characters that would let user input escape the working directory or break a filename.
_UNSAFE_FILENAME_RE = re.compile(r'[\\/:*?"<>|\x00-\x1f]')


def _collapse_whitespace(text):
    """Replace every run of spaces/tabs/newlines with a single space."""
    return _WHITESPACE_RE.sub(" ", text)


def _strip_punctuation(text):
    """Remove all characters from string.punctuation."""
    return _PUNCTUATION_RE.sub("", text)


def _capitalize_words(text):
    """Upper-case the first character of every whitespace-separated word."""
    return " ".join(word[0].upper() + word[1:] for word in text.split())


def _safe_filename_part(text):
    """Replace path separators and other filename-unsafe characters with '_'."""
    return _UNSAFE_FILENAME_RE.sub("_", text).strip()


def _random_seed():
    """Return a fresh random seed in [0, 2**32 - 1] (valid for torch and numpy seeding)."""
    return random.randint(0, 2**32 - 1)

# functionality


def upscaler(
    input_image: Image.Image,
    prompt: str = "Photorealistic, Hyperrealistic, Realistic Photography, High-Quality Photography, Natural.",
    negative_prompt: str = "Distorted, Discontinuous, Blurry, Doll-Like, Overly-Plastic, Low-Quality, Painted, Smoothed, Artificial, Phony, Gaudy, Digital Effects.",
    seed: int | None = None,
    upscale_factor: int = 2,
    controlnet_scale: float = 0.6,
    controlnet_decay: float = 1.0,
    condition_scale: int = 6,
    tile_width: int = 112,
    tile_height: int = 144,
    denoise_strength: float = 0.35,
    num_inference_steps: int = 30,
    solver: str = "DDIM",
) -> Image.Image:
    """Upscale `input_image` with the global ESRGAN + SD1.5 `enhancer`.

    Args:
        seed: RNG seed. When None (the default), a new random seed is drawn on
            every call; previously the default was evaluated once at import,
            so all calls shared the same seed.
        solver: name of a solver class looked up on `refiners...solvers`.
    Returns:
        The enhanced PIL image.
    """
    log(f'CALL upscaler')

    if seed is None:
        seed = _random_seed()
    manual_seed(seed)

    solver_type: type[Solver] = getattr(solvers, solver)

    log(f'DBG upscaler 1')

    enhanced_image = enhancer.upscale(
        image=input_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        upscale_factor=upscale_factor,
        controlnet_scale=controlnet_scale,
        controlnet_scale_decay=controlnet_decay,
        condition_scale=condition_scale,
        tile_size=(tile_height, tile_width),
        denoise_strength=denoise_strength,
        num_inference_steps=num_inference_steps,
        loras_scale={"more_details": 0.5, "sdxl_render": 1.0},
        solver_type=solver_type,
    )

    log(f'RET upscaler')

    return enhanced_image


def get_tensor_length(tensor):
    """Return the total number of elements of `tensor` (product of its .size())."""
    return math.prod(tensor.size())


def _summarize_chunk(prefix, text, max_len, min_len):
    """Repeatedly summarize `text` until `prefix + text` fits in `max_len` tokens.

    Each round asks T5 for roughly a quarter of the current token length (but
    never less than `max_len`). The prefix is re-applied every round so the
    model always sees the task instruction. The loop stops once the target
    reaches `max_len`, or after _MAX_SUMMARY_ROUNDS rounds, so it always
    terminates.
    """
    for _ in range(_MAX_SUMMARY_ROUNDS):
        inputs = tokenizer.encode(prefix + text, return_tensors="pt", truncation=False)
        length = get_tensor_length(inputs)
        if length <= max_len:
            break
        target = max(length // 4, max_len)
        gen = model.generate(
            inputs,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
            max_length=target,
            min_length=min_len,
        )
        text = tokenizer.decode(gen[0], skip_special_tokens=True)
        if target == max_len:
            break
    return text


def summarize(text, max_len=20, min_len=10):
    """Summarize `text` with T5 into at most `max_len` generated tokens.

    Texts with fewer than 5 words are returned unchanged. Longer texts are
    split into _CHUNK_WORDS-word chunks, each chunk is shrunk independently,
    and the joined chunk summaries are summarized once more.
    """
    log(f'CALL summarize')

    words = text.split()

    # Was get_tensor_length(words), which calls .size() on a list and crashes.
    if len(words) < 5:
        print("Summarization Error: Text is too short, 5 words minimum!")
        return text

    prefix = "summarize: "
    # Was `for index in math.ceil(...)`, which iterates an int and crashes.
    chunk_summaries = [
        _summarize_chunk(prefix, " ".join(words[start:start + _CHUNK_WORDS]), max_len, min_len)
        for start in range(0, len(words), _CHUNK_WORDS)
    ]
    ret = " ".join(chunk_summaries)

    inputs = tokenizer.encode(prefix + ret, return_tensors="pt", truncation=False)
    gen = model.generate(inputs, length_penalty=1.0, num_beams=4, early_stopping=True, max_length=max_len, min_length=min_len)
    summary = tokenizer.decode(gen[0], skip_special_tokens=True)
    log(f'RET summarize with summary as {summary}')
    return summary


def generate_random_string(length):
    """Return `length` random characters drawn from ASCII letters and digits."""
    characters = str(ascii_letters + digits)
    return ''.join(random.choice(characters) for _ in range(length))


def pipe_generate_image(p1, p2):
    """Generate images with the FLUX pipe from positive prompt `p1` and negative prompt `p2`."""
    log(f'CALL pipe_generate')
    imgs = image_pipe(
        prompt=p1,
        negative_prompt=p2,
        height=height,
        width=width,
        guidance_scale=img_accu,
        num_images_per_prompt=1,
        num_inference_steps=image_steps,
        max_sequence_length=seq,
        # Previously parsed str(random.random()), which breaks on values such as "5e-05".
        generator=torch.Generator(device).manual_seed(_random_seed())
    ).images
    log(f'RET pipe_generate')
    return imgs


def add_song_cover_text(img, artist, song, height, width):
    """Draw the song title above and the artist name below the centre of `img`.

    `height`/`width` are the image dimensions used for layout; the text size is
    min(ceil(width / 10), ceil(height / 5)). Draws in place and returns `img`.
    """
    draw = ImageDraw.Draw(img, mode="RGBA")

    rows = 1
    labels_distance = 1 / 3

    # Both labels use the same size, so the font is loaded once.
    textheight = min(math.ceil(width / 10), math.ceil(height / 5))
    font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
    stroke_width = math.ceil(textheight / 20)
    center_y = height - (textheight * rows / 2) - (height / 2)
    offset_y = height / 2 * labels_distance

    textwidth = draw.textlength(song, font)
    x = math.ceil((width - textwidth) / 2)
    y = math.ceil(center_y - offset_y)
    draw.text((x, y), song, (255, 255, 255, 85), font=font, spacing=2, stroke_width=stroke_width, stroke_fill=(0, 0, 0, 170))

    textwidth = draw.textlength(artist, font)
    x = math.ceil((width - textwidth) / 2)
    y = math.ceil(center_y + offset_y)
    draw.text((x, y), artist, (0, 0, 0, 85), font=font, spacing=2, stroke_width=stroke_width, stroke_fill=(255, 255, 255, 170))

    return img


def all_pipes(pos, neg, artist, song):
    """Generate images for the prompts and upscale each one (`artist`/`song` are unused)."""
    imgs = pipe_generate_image(pos, neg)
    return [upscaler(img) for img in imgs]


language_codes = {
    "af": "Afrikaans",
    "ar": "Arabic",
    "bg": "Bulgarian",
    "bn": "Bengali",
    "ca": "Catalan",
    "cs": "Czech",
    "cy": "Welsh",
    "da": "Danish",
    "de": "German",
    "el": "Greek",
    "en": "English",
    "es": "Spanish",
    "et": "Estonian",
    "fa": "Persian (Farsi)",
    "fi": "Finnish",
    "fr": "French",
    "gu": "Gujarati",
    "he": "Hebrew",
    "hi": "Hindi",
    "hr": "Croatian",
    "hu": "Hungarian",
    "id": "Indonesian",
    "it": "Italian",
    "ja": "Japanese",
    "kn": "Kannada",
    "ko": "Korean",
    "lt": "Lithuanian",
    "lv": "Latvian",
    "mk": "Macedonian",
    "ml": "Malayalam",
    "mr": "Marathi",
    "ne": "Nepali",
    "nl": "Dutch",
    "no": "Norwegian",
    "pa": "Punjabi",
    "pl": "Polish",
    "pt": "Portuguese",
    "ro": "Romanian",
    "ru": "Russian",
    "sk": "Slovak",
    "sl": "Slovenian",
    "so": "Somali",
    "sq": "Albanian",
    "sv": "Swedish",
    "sw": "Swahili",
    "ta": "Tamil",
    "te": "Telugu",
    "th": "Thai",
    "tl": "Tagalog (Filipino)",
    "tr": "Turkish",
    "uk": "Ukrainian",
    "ur": "Urdu",
    "vi": "Vietnamese",
    "zh-cn": "Chinese (Simplified)",
    "zh-tw": "Chinese (Traditional)",
}


def translate(txt, to_lang="en", from_lang=False):
    """Translate `txt` into `to_lang` using a T5 "translate X to Y:" prompt.

    The source language is auto-detected when `from_lang` is falsy. The text
    is returned unchanged when it is blank, when its language cannot be
    detected, when it is already in `to_lang`, or when either language is
    missing from `language_codes`.
    """
    log(f'CALL translate')

    # langdetect raises on empty / feature-less input (e.g. an empty genre field).
    if not txt or not txt.strip():
        log(f'RET translate with txt as {txt}')
        return txt

    if not from_lang:
        try:
            from_lang = get_language(txt)
        except LangDetectException:
            log(f'RET translate (language not detected) with txt as {txt}')
            return txt

    if from_lang == to_lang:
        log(f'RET translate with txt as {txt}')
        return txt

    if from_lang not in language_codes or to_lang not in language_codes:
        log(f'RET translate (unsupported pair {from_lang}->{to_lang}) with txt as {txt}')
        return txt

    prefix = f"translate {language_codes[from_lang]} to {language_codes[to_lang]}: "

    words = txt.split()
    parts = []
    # Was `for index in math.ceil(...)` (TypeError) and `model.generate(chunk, input)`,
    # which passed a str and the builtin `input` instead of the encoded tokens.
    for start in range(0, len(words), _CHUNK_WORDS):
        chunk = " ".join(words[start:start + _CHUNK_WORDS])
        inputs = tokenizer.encode(prefix + chunk, return_tensors="pt", truncation=False)
        # Without an explicit budget, generate() falls back to a 20-token default
        # and truncates longer translations.
        gen = model.generate(inputs, max_new_tokens=2 * get_tensor_length(inputs))
        parts.append(tokenizer.decode(gen[0], skip_special_tokens=True))
    ret = " ".join(parts)

    log(f'RET translate with ret as {ret}')
    return ret


@spaces.GPU(duration=300)
def handle_generation(artist, song, genre, lyrics):
    """Build prompts from the form fields, render and label the cover, and return its file path."""
    log(f'CALL handle_generate')

    pos_artist = _collapse_whitespace(artist).upper().strip()
    pos_song = _capitalize_words(_collapse_whitespace(song).lower().strip())
    pos_genre = _capitalize_words(_strip_punctuation(_collapse_whitespace(genre)).lower().strip())
    pos_lyrics = _strip_punctuation(_collapse_whitespace(lyrics)).lower().strip()
    pos_lyrics_sum = pos_lyrics if pos_lyrics == "" else summarize(translate(pos_lyrics))

    neg = f"Sexuality, Humanity, Textual, Labeled, Distorted, Discontinuous, Blurry, Doll-Like, Overly Plastic, Low-Quality, Painted, Smoothed, Artificial, Phony, Gaudy, Digital Effects."
    pos = f'HQ Hyper-realistic { translate(pos_genre) } song "{ translate(pos_song) }"{ pos_lyrics_sum if pos_lyrics_sum == "" else ": " + pos_lyrics_sum }.'

    print(f"""
Positive: {pos}

Negative: {neg}
""")

    imgs = all_pipes(pos, neg, pos_artist, pos_song)

    names = []
    for index, img in enumerate(imgs, start=1):
        # Lay the text out on the actual (upscaled) image size instead of assuming 2x.
        labeled_img = add_song_cover_text(img, artist, song, img.height, img.width)
        # Sanitize user input: a raw "/" or ".." in artist/song must not escape the cwd.
        name = f'{_safe_filename_part(artist)} - {_safe_filename_part(song)} ({index}).png'
        labeled_img.save(name)
        names.append(name)

    # return names
    return names[0]

# entry


if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:

        gr.Markdown(f"""
        # Song Cover Image Generator
        """)

        with gr.Row():
            with gr.Column(scale=4):
                artist = gr.Textbox(
                    placeholder="Artist name",
                    value="",
                    container=False,
                    max_lines=1
                )
                song = gr.Textbox(
                    placeholder="Song name",
                    value="",
                    container=False,
                    max_lines=1
                )
                genre = gr.Textbox(
                    placeholder="Genre",
                    value="",
                    container=False,
                    max_lines=1
                )
                lyrics = gr.Textbox(
                    placeholder="Lyrics",
                    value="",
                    container=False,
                    max_lines=1
                )
            with gr.Column():
                cover = gr.Image(interactive=False, container=False, elem_classes="image-container", label="Result", show_label=True, type='filepath', show_share_button=False)

        run = gr.Button("Generate", elem_classes="btn")

        run.click(
            fn=handle_generation,
            inputs=[artist, song, genre, lyrics],
            outputs=[cover]
        )

    demo.queue().launch()