yaron123 commited on
Commit
5c223cd
·
1 Parent(s): 12302ad
Files changed (2) hide show
  1. app.py +456 -34
  2. requirements.txt +4 -2
app.py CHANGED
@@ -14,11 +14,14 @@ import warnings
14
  import time
15
  import asyncio
16
  import math
 
17
  from functools import partial
18
-
19
- # external
20
-
21
  import spaces
 
 
22
  import torch
23
  import gradio as gr
24
  from lxml.html import fromstring
@@ -27,6 +30,396 @@ from safetensors.torch import load_file, save_file
27
  from diffusers import FluxPipeline
28
  from PIL import Image, ImageDraw, ImageFont
29
  from transformers import PegasusForConditionalGeneration, PegasusTokenizerFast
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  # logging
32
 
@@ -41,19 +434,14 @@ root.addHandler(handler)
41
 
42
  # constant data
43
 
44
- if torch.cuda.is_available():
45
- device = "cuda"
46
- else:
47
- device = "cpu"
48
-
49
  base = "black-forest-labs/FLUX.1-schnell"
50
  pegasus_name = "google/pegasus-xsum"
51
 
52
  # precision data
53
 
54
  seq=512
55
- width=2160
56
- height=2160
57
  image_steps=8
58
  img_accu=0
59
 
@@ -123,7 +511,44 @@ image_pipe.enable_model_cpu_offload()
123
 
124
  # functionality
125
 
126
- @spaces.GPU(duration=70)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  def summarize_text(
128
  text, max_length=30, num_beams=16, early_stopping=True,
129
  pegasus_tokenizer = PegasusTokenizerFast.from_pretrained("google/pegasus-xsum"),
@@ -140,7 +565,7 @@ def generate_random_string(length):
140
  characters = str(ascii_letters + digits)
141
  return ''.join(random.choice(characters) for _ in range(length))
142
 
143
- @spaces.GPU(duration=140)
144
  def pipe_generate(p1,p2):
145
  return image_pipe(
146
  prompt=p1,
@@ -162,8 +587,8 @@ def handle_generate(artist,song,genre,lyrics):
162
  pos_genre = re.sub(f'[{punctuation}]', '', re.sub("([ \t\n]){1,}", " ", genre)).upper().strip()
163
  pos_lyrics = re.sub(f'[{punctuation}]', '', re.sub("([ \t\n]){1,}", " ", lyrics)).lower().strip()
164
  pos_lyrics_sum = summarize_text(pos_lyrics)
165
- neg = f"Textual Labeled Distorted Discontinuous Ugly Blurry"
166
- pos = f'Realistic Natural Genuine Reasonable Detailed { pos_genre } GENRE { pos_song } "{ pos_lyrics_sum }"'
167
 
168
  print(f"""
169
  Positive: {pos}
@@ -176,26 +601,28 @@ def handle_generate(artist,song,genre,lyrics):
176
  draw = ImageDraw.Draw(img)
177
 
178
  rows = 1
179
- labes_distance = math.ceil(1 / 3)
180
 
181
  textheight=min(math.ceil( width / 10 ), math.ceil( height / 5 ))
182
  font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
183
  textwidth = draw.textlength(pos_song,font)
184
  x = math.ceil((width - textwidth) / 2)
185
- y = math.ceil((height - math.ceil(textheight * rows / 2)) / 2)
186
- y = y - math.ceil(y / labes_distance)
187
- draw.text((x, y), pos_song, (255,255,255), font=font)
188
 
189
  textheight=min(math.ceil( width / 12 ), math.ceil( height / 6 ))
190
  font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
191
  textwidth = draw.textlength(pos_artist,font)
192
  x = math.ceil((width - textwidth) / 2)
193
- y = math.ceil((height - math.ceil(textheight * rows / 2)) / 2)
194
- y = y + math.ceil(y / labes_distance)
195
- draw.text((x, y), pos_artist, (255,255,255), font=font)
 
 
196
 
197
  name = generate_random_string(12) + ".png"
198
- img.save(name)
199
  return name
200
 
201
  # entry
@@ -205,36 +632,33 @@ if __name__ == "__main__":
205
  gr.Markdown(f"""
206
  # Song Cover Image Generator
207
  """)
208
- with gr.Row():
209
- with gr.Column():
210
  artist = gr.Textbox(
211
  placeholder="Artist name",
212
  container=False,
213
  max_lines=1
214
  )
215
- with gr.Column():
216
  song = gr.Textbox(
217
  placeholder="Song name",
218
  container=False,
219
  max_lines=1
220
  )
221
- with gr.Column():
222
- genre = gr.Textbox(
223
  placeholder="Genre",
224
  container=False,
225
  max_lines=1
226
- )
227
- with gr.Row():
228
  lyrics = gr.Textbox(
229
  placeholder="Lyrics (English)",
230
  container=False,
231
  max_lines=1
232
  )
233
- with gr.Row():
234
- run = gr.Button("Generate",elem_classes="btn")
235
- with gr.Row():
236
  cover = gr.Image(interactive=False,container=False,elem_classes="image-container", label="Result", show_label=True, type='filepath', show_share_button=False)
237
 
 
 
238
  run.click(
239
  fn=handle_generate,
240
  inputs=[artist,song,genre,lyrics],
@@ -242,5 +666,3 @@ if __name__ == "__main__":
242
  )
243
 
244
  demo.queue().launch()
245
-
246
- # end
 
14
  import time
15
  import asyncio
16
  import math
17
+ from pathlib import Path
18
  from functools import partial
19
+ from dataclasses import dataclass
20
+ from typing import Any
21
+ import pillow_heif
22
  import spaces
23
+ import numpy as np
24
+ import numpy.typing as npt
25
  import torch
26
  import gradio as gr
27
  from lxml.html import fromstring
 
30
  from diffusers import FluxPipeline
31
  from PIL import Image, ImageDraw, ImageFont
32
  from transformers import PegasusForConditionalGeneration, PegasusTokenizerFast
33
+ from refiners.fluxion.utils import manual_seed
34
+ from refiners.foundationals.latent_diffusion import Solver, solvers
35
+ from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
36
+ MultiUpscaler,
37
+ UpscalerCheckpoints,
38
+ )
39
+
40
+
41
+ Tile = tuple[int, int, Image.Image]
42
+ Tiles = list[tuple[int, int, list[Tile]]]
43
+
44
+ def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
45
+ return nn.Sequential(
46
+ nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
47
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
48
+ )
49
+
50
+
51
+ class ResidualDenseBlock_5C(nn.Module):
52
+ """
53
+ Residual Dense Block
54
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
55
+ Modified options that can be used:
56
+ - "Partial Convolution based Padding" arXiv:1811.11718
57
+ - "Spectral normalization" arXiv:1802.05957
58
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
59
+ {Rakotonirina} and A. {Rasoanaivo}
60
+ """
61
+
62
+ def __init__(self, nf: int = 64, gc: int = 32) -> None:
63
+ super().__init__() # type: ignore[reportUnknownMemberType]
64
+
65
+ self.conv1 = conv_block(nf, gc)
66
+ self.conv2 = conv_block(nf + gc, gc)
67
+ self.conv3 = conv_block(nf + 2 * gc, gc)
68
+ self.conv4 = conv_block(nf + 3 * gc, gc)
69
+ # Wrapped in Sequential because of key in state dict.
70
+ self.conv5 = nn.Sequential(nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ x1 = self.conv1(x)
74
+ x2 = self.conv2(torch.cat((x, x1), 1))
75
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
76
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
77
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
78
+ return x5 * 0.2 + x
79
+
80
+
81
+ class RRDB(nn.Module):
82
+ """
83
+ Residual in Residual Dense Block
84
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
85
+ """
86
+
87
+ def __init__(self, nf: int) -> None:
88
+ super().__init__() # type: ignore[reportUnknownMemberType]
89
+ self.RDB1 = ResidualDenseBlock_5C(nf)
90
+ self.RDB2 = ResidualDenseBlock_5C(nf)
91
+ self.RDB3 = ResidualDenseBlock_5C(nf)
92
+
93
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
94
+ out = self.RDB1(x)
95
+ out = self.RDB2(out)
96
+ out = self.RDB3(out)
97
+ return out * 0.2 + x
98
+
99
+
100
+ class Upsample2x(nn.Module):
101
+ """Upsample 2x."""
102
+
103
+ def __init__(self) -> None:
104
+ super().__init__() # type: ignore[reportUnknownMemberType]
105
+
106
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
107
+ return nn.functional.interpolate(x, scale_factor=2.0) # type: ignore
108
+
109
+
110
+ class ShortcutBlock(nn.Module):
111
+ """Elementwise sum the output of a submodule to its input"""
112
+
113
+ def __init__(self, submodule: nn.Module) -> None:
114
+ super().__init__() # type: ignore[reportUnknownMemberType]
115
+ self.sub = submodule
116
+
117
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
118
+ return x + self.sub(x)
119
+
120
+
121
+ class RRDBNet(nn.Module):
122
+ def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
123
+ super().__init__() # type: ignore[reportUnknownMemberType]
124
+ assert in_nc % 4 != 0 # in_nc is 3
125
+
126
+ self.model = nn.Sequential(
127
+ nn.Conv2d(in_nc, nf, kernel_size=3, padding=1),
128
+ ShortcutBlock(
129
+ nn.Sequential(
130
+ *(RRDB(nf) for _ in range(nb)),
131
+ nn.Conv2d(nf, nf, kernel_size=3, padding=1),
132
+ )
133
+ ),
134
+ Upsample2x(),
135
+ nn.Conv2d(nf, nf, kernel_size=3, padding=1),
136
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
137
+ Upsample2x(),
138
+ nn.Conv2d(nf, nf, kernel_size=3, padding=1),
139
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
140
+ nn.Conv2d(nf, nf, kernel_size=3, padding=1),
141
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
142
+ nn.Conv2d(nf, out_nc, kernel_size=3, padding=1),
143
+ )
144
+
145
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
146
+ return self.model(x)
147
+
148
+
149
+ def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
150
+ # this code is adapted from https://github.com/victorca25/iNNfer
151
+ scale2x = 0
152
+ scalemin = 6
153
+ n_uplayer = 0
154
+ out_nc = 0
155
+ nb = 0
156
+
157
+ for block in list(state_dict):
158
+ parts = block.split(".")
159
+ n_parts = len(parts)
160
+ if n_parts == 5 and parts[2] == "sub":
161
+ nb = int(parts[3])
162
+ elif n_parts == 3:
163
+ part_num = int(parts[1])
164
+ if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
165
+ scale2x += 1
166
+ if part_num > n_uplayer:
167
+ n_uplayer = part_num
168
+ out_nc = state_dict[block].shape[0]
169
+ assert "conv1x1" not in block # no ESRGANPlus
170
+
171
+ nf = state_dict["model.0.weight"].shape[0]
172
+ in_nc = state_dict["model.0.weight"].shape[1]
173
+ scale = 2**scale2x
174
+
175
+ assert out_nc > 0
176
+ assert nb > 0
177
+
178
+ return in_nc, out_nc, nf, nb, scale # 3, 3, 64, 23, 4
179
+
180
+ # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
181
+ class Grid(NamedTuple):
182
+ tiles: Tiles
183
+ tile_w: int
184
+ tile_h: int
185
+ image_w: int
186
+ image_h: int
187
+ overlap: int
188
+
189
+
190
+ # adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
191
+ def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
192
+ w = image.width
193
+ h = image.height
194
+
195
+ non_overlap_width = tile_w - overlap
196
+ non_overlap_height = tile_h - overlap
197
+
198
+ cols = max(1, math.ceil((w - overlap) / non_overlap_width))
199
+ rows = max(1, math.ceil((h - overlap) / non_overlap_height))
200
+
201
+ dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
202
+ dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
203
+
204
+ grid = Grid([], tile_w, tile_h, w, h, overlap)
205
+ for row in range(rows):
206
+ row_images: list[Tile] = []
207
+ y1 = max(min(int(row * dy), h - tile_h), 0)
208
+ y2 = min(y1 + tile_h, h)
209
+ for col in range(cols):
210
+ x1 = max(min(int(col * dx), w - tile_w), 0)
211
+ x2 = min(x1 + tile_w, w)
212
+ tile = image.crop((x1, y1, x2, y2))
213
+ row_images.append((x1, tile_w, tile))
214
+ grid.tiles.append((y1, tile_h, row_images))
215
+
216
+ return grid
217
+
218
+
219
+ # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
220
+ def combine_grid(grid: Grid):
221
+ def make_mask_image(r: npt.NDArray[np.float32]) -> Image.Image:
222
+ r = r * 255 / grid.overlap
223
+ return Image.fromarray(r.astype(np.uint8), "L")
224
+
225
+ mask_w = make_mask_image(
226
+ np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)
227
+ )
228
+ mask_h = make_mask_image(
229
+ np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)
230
+ )
231
+
232
+ combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
233
+ for y, h, row in grid.tiles:
234
+ combined_row = Image.new("RGB", (grid.image_w, h))
235
+ for x, w, tile in row:
236
+ if x == 0:
237
+ combined_row.paste(tile, (0, 0))
238
+ continue
239
+
240
+ combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
241
+ combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
242
+
243
+ if y == 0:
244
+ combined_image.paste(combined_row, (0, 0))
245
+ continue
246
+
247
+ combined_image.paste(
248
+ combined_row.crop((0, 0, combined_row.width, grid.overlap)),
249
+ (0, y),
250
+ mask=mask_h,
251
+ )
252
+ combined_image.paste(
253
+ combined_row.crop((0, grid.overlap, combined_row.width, h)),
254
+ (0, y + grid.overlap),
255
+ )
256
+
257
+ return combined_image
258
+
259
+
260
+ class UpscalerESRGAN:
261
+ def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
262
+ self.model_path = model_path
263
+ self.device = device
264
+ self.model = self.load_model(model_path)
265
+ self.to(device, dtype)
266
+
267
+ def __call__(self, img: Image.Image) -> Image.Image:
268
+ return self.upscale_without_tiling(img)
269
+
270
+ def to(self, device: torch.device, dtype: torch.dtype):
271
+ self.device = device
272
+ self.dtype = dtype
273
+ self.model.to(device=device, dtype=dtype)
274
+
275
+ def load_model(self, path: Path) -> RRDBNet:
276
+ filename = path
277
+ state_dict: dict[str, torch.Tensor] = torch.load(filename, weights_only=True, map_location=self.device) # type: ignore
278
+ in_nc, out_nc, nf, nb, upscale = infer_params(state_dict)
279
+ assert upscale == 4, "Only 4x upscaling is supported"
280
+ model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
281
+ model.load_state_dict(state_dict)
282
+ model.eval()
283
+
284
+ return model
285
+
286
+ def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
287
+ img_np = np.array(img)
288
+ img_np = img_np[:, :, ::-1]
289
+ img_np = np.ascontiguousarray(np.transpose(img_np, (2, 0, 1))) / 255
290
+ img_t = torch.from_numpy(img_np).float() # type: ignore
291
+ img_t = img_t.unsqueeze(0).to(device=self.device, dtype=self.dtype)
292
+ with torch.no_grad():
293
+ output = self.model(img_t)
294
+ output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
295
+ output = 255.0 * np.moveaxis(output, 0, 2)
296
+ output = output.astype(np.uint8)
297
+ output = output[:, :, ::-1]
298
+ return Image.fromarray(output, "RGB")
299
+
300
+ # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/esrgan_model.py#L208
301
+ def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
302
+ img = img.convert("RGB")
303
+ grid = split_grid(img)
304
+ newtiles: Tiles = []
305
+ scale_factor: int = 1
306
+
307
+ for y, h, row in grid.tiles:
308
+ newrow: list[Tile] = []
309
+ for tiledata in row:
310
+ x, w, tile = tiledata
311
+ output = self.upscale_without_tiling(tile)
312
+ scale_factor = output.width // tile.width
313
+ newrow.append((x * scale_factor, w * scale_factor, output))
314
+ newtiles.append((y * scale_factor, h * scale_factor, newrow))
315
+
316
+ newgrid = Grid(
317
+ newtiles,
318
+ grid.tile_w * scale_factor,
319
+ grid.tile_h * scale_factor,
320
+ grid.image_w * scale_factor,
321
+ grid.image_h * scale_factor,
322
+ grid.overlap * scale_factor,
323
+ )
324
+ output = combine_grid(newgrid)
325
+ return output
326
+
327
+ @dataclass(kw_only=True)
328
+ class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
329
+ esrgan: Path
330
+
331
+ class ESRGANUpscaler(MultiUpscaler):
332
+ def __init__(
333
+ self,
334
+ checkpoints: ESRGANUpscalerCheckpoints,
335
+ device: torch.device,
336
+ dtype: torch.dtype,
337
+ ) -> None:
338
+ super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
339
+ self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)
340
+
341
+ def to(self, device: torch.device, dtype: torch.dtype):
342
+ self.esrgan.to(device=device, dtype=dtype)
343
+ self.sd = self.sd.to(device=device, dtype=dtype)
344
+ self.device = device
345
+ self.dtype = dtype
346
+
347
+ def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
348
+ image = self.esrgan.upscale_with_tiling(image)
349
+ return super().pre_upscale(image=image, upscale_factor=upscale_factor / 4)
350
+
351
+ pillow_heif.register_heif_opener()
352
+ pillow_heif.register_avif_opener()
353
+
354
+ CHECKPOINTS = ESRGANUpscalerCheckpoints(
355
+ unet=Path(
356
+ hf_hub_download(
357
+ repo_id="refiners/juggernaut.reborn.sd1_5.unet",
358
+ filename="model.safetensors",
359
+ revision="347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
360
+ )
361
+ ),
362
+ clip_text_encoder=Path(
363
+ hf_hub_download(
364
+ repo_id="refiners/juggernaut.reborn.sd1_5.text_encoder",
365
+ filename="model.safetensors",
366
+ revision="744ad6a5c0437ec02ad826df9f6ede102bb27481",
367
+ )
368
+ ),
369
+ lda=Path(
370
+ hf_hub_download(
371
+ repo_id="refiners/juggernaut.reborn.sd1_5.autoencoder",
372
+ filename="model.safetensors",
373
+ revision="3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
374
+ )
375
+ ),
376
+ controlnet_tile=Path(
377
+ hf_hub_download(
378
+ repo_id="refiners/controlnet.sd1_5.tile",
379
+ filename="model.safetensors",
380
+ revision="48ced6ff8bfa873a8976fa467c3629a240643387",
381
+ )
382
+ ),
383
+ esrgan=Path(
384
+ hf_hub_download(
385
+ repo_id="philz1337x/upscaler",
386
+ filename="4x-UltraSharp.pth",
387
+ revision="011deacac8270114eb7d2eeff4fe6fa9a837be70",
388
+ )
389
+ ),
390
+ negative_embedding=Path(
391
+ hf_hub_download(
392
+ repo_id="philz1337x/embeddings",
393
+ filename="JuggernautNegative-neg.pt",
394
+ revision="203caa7e9cc2bc225031a4021f6ab1ded283454a",
395
+ )
396
+ ),
397
+ negative_embedding_key="string_to_param.*",
398
+ loras={
399
+ "more_details": Path(
400
+ hf_hub_download(
401
+ repo_id="philz1337x/loras",
402
+ filename="more_details.safetensors",
403
+ revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
404
+ )
405
+ ),
406
+ "sdxl_render": Path(
407
+ hf_hub_download(
408
+ repo_id="philz1337x/loras",
409
+ filename="SDXLrender_v2.0.safetensors",
410
+ revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
411
+ )
412
+ )
413
+ }
414
+ )
415
+
416
+ # initialize the enhancer, on the cpu
417
+ DEVICE_CPU = torch.device("cpu")
418
+ DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
419
+ enhancer = ESRGANUpscaler(checkpoints=CHECKPOINTS, device=DEVICE_CPU, dtype=DTYPE)
420
+
421
+ device = DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
422
+ enhancer.to(device=DEVICE, dtype=DTYPE)
423
 
424
  # logging
425
 
 
434
 
435
  # constant data
436
 
 
 
 
 
 
437
  base = "black-forest-labs/FLUX.1-schnell"
438
  pegasus_name = "google/pegasus-xsum"
439
 
440
  # precision data
441
 
442
  seq=512
443
+ width=1024
444
+ height=1024
445
  image_steps=8
446
  img_accu=0
447
 
 
511
 
512
  # functionality
513
 
514
+ @spaces.GPU(duration=180)
515
+ def upscaler(
516
+ input_image: Image.Image,
517
+ prompt: str = "masterpiece, best quality, highres",
518
+ negative_prompt: str = "worst quality, low quality, normal quality",
519
+ seed: int = 42,
520
+ upscale_factor: int = 8,
521
+ controlnet_scale: float = 0.6,
522
+ controlnet_decay: float = 1.0,
523
+ condition_scale: int = 6,
524
+ tile_width: int = 112,
525
+ tile_height: int = 144,
526
+ denoise_strength: float = 0.35,
527
+ num_inference_steps: int = 18,
528
+ solver: str = "DDIM",
529
+ ) -> Image.Image:
530
+ manual_seed(seed)
531
+
532
+ solver_type: type[Solver] = getattr(solvers, solver)
533
+
534
+ enhanced_image = enhancer.upscale(
535
+ image=input_image,
536
+ prompt=prompt,
537
+ negative_prompt=negative_prompt,
538
+ upscale_factor=upscale_factor,
539
+ controlnet_scale=controlnet_scale,
540
+ controlnet_scale_decay=controlnet_decay,
541
+ condition_scale=condition_scale,
542
+ tile_size=(tile_height, tile_width),
543
+ denoise_strength=denoise_strength,
544
+ num_inference_steps=num_inference_steps,
545
+ loras_scale={"more_details": 0.5, "sdxl_render": 1.0},
546
+ solver_type=solver_type,
547
+ )
548
+
549
+ return enhanced_image
550
+
551
+ @spaces.GPU(duration=180)
552
  def summarize_text(
553
  text, max_length=30, num_beams=16, early_stopping=True,
554
  pegasus_tokenizer = PegasusTokenizerFast.from_pretrained("google/pegasus-xsum"),
 
565
  characters = str(ascii_letters + digits)
566
  return ''.join(random.choice(characters) for _ in range(length))
567
 
568
+ @spaces.GPU(duration=180)
569
  def pipe_generate(p1,p2):
570
  return image_pipe(
571
  prompt=p1,
 
587
  pos_genre = re.sub(f'[{punctuation}]', '', re.sub("([ \t\n]){1,}", " ", genre)).upper().strip()
588
  pos_lyrics = re.sub(f'[{punctuation}]', '', re.sub("([ \t\n]){1,}", " ", lyrics)).lower().strip()
589
  pos_lyrics_sum = summarize_text(pos_lyrics)
590
+ neg = f"Textual Labeled Distorted Discontinuous Ugly Blurry Low-Quality Worst-Quality Low-Resolution Painted"
591
+ pos = f'Realistic Vivid Genuine Reasonable Detailed 4K { pos_genre } GENRE { pos_song }: "{ pos_lyrics_sum }"'
592
 
593
  print(f"""
594
  Positive: {pos}
 
601
  draw = ImageDraw.Draw(img)
602
 
603
  rows = 1
604
+ labels_distance = math.ceil(1 / 3)
605
 
606
  textheight=min(math.ceil( width / 10 ), math.ceil( height / 5 ))
607
  font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
608
  textwidth = draw.textlength(pos_song,font)
609
  x = math.ceil((width - textwidth) / 2)
610
+ y = height - math.ceil(textheight * rows / 2)
611
+ y = y - math.ceil(y / labels_distance)
612
+ draw.text((x, y), pos_song, (255,255,255), font=font, spacing=2, stroke_width=4, stroke_fill=(0,0,0))
613
 
614
  textheight=min(math.ceil( width / 12 ), math.ceil( height / 6 ))
615
  font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
616
  textwidth = draw.textlength(pos_artist,font)
617
  x = math.ceil((width - textwidth) / 2)
618
+ y = height - math.ceil(textheight * rows / 2)
619
+ y = y + math.ceil(y / labels_distance)
620
+ draw.text((x, y), pos_artist, (0,0,0), font=font, spacing=6, stroke_width=8, stroke_fill=(255,255,255))
621
+
622
+ enhanced_img = upscaler(img)
623
 
624
  name = generate_random_string(12) + ".png"
625
+ enhanced_img.save(name)
626
  return name
627
 
628
  # entry
 
632
  gr.Markdown(f"""
633
  # Song Cover Image Generator
634
  """)
635
+ with gr.Column():
636
+ with gr.Row():
637
  artist = gr.Textbox(
638
  placeholder="Artist name",
639
  container=False,
640
  max_lines=1
641
  )
 
642
  song = gr.Textbox(
643
  placeholder="Song name",
644
  container=False,
645
  max_lines=1
646
  )
647
+ genre = gr.Textbox(
 
648
  placeholder="Genre",
649
  container=False,
650
  max_lines=1
651
+ )
 
652
  lyrics = gr.Textbox(
653
  placeholder="Lyrics (English)",
654
  container=False,
655
  max_lines=1
656
  )
657
+ with gr.Column():
 
 
658
  cover = gr.Image(interactive=False,container=False,elem_classes="image-container", label="Result", show_label=True, type='filepath', show_share_button=False)
659
 
660
+ run = gr.Button("Generate",elem_classes="btn")
661
+
662
  run.click(
663
  fn=handle_generate,
664
  inputs=[artist,song,genre,lyrics],
 
666
  )
667
 
668
  demo.queue().launch()
 
 
requirements.txt CHANGED
@@ -1,11 +1,13 @@
1
  lxml
2
- pillow
 
 
3
  opencv-python
4
  gradio==5.12.0
5
  accelerate
6
  safetensors
7
  huggingface-hub
8
- numpy
9
  torch
10
  torchaudio
11
  torchvision
 
1
  lxml
2
+ pillow>=10.4.0
3
+ git+https://github.com/finegrain-ai/refiners
4
+ pillow-heif>=0.18.0
5
  opencv-python
6
  gradio==5.12.0
7
  accelerate
8
  safetensors
9
  huggingface-hub
10
+ numpy<2.0.0
11
  torch
12
  torchaudio
13
  torchvision