aiqcamp commited on
Commit
6ef1366
·
verified ·
1 Parent(s): 4fd76e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -1382
app.py CHANGED
@@ -1,1401 +1,141 @@
1
- import requests
2
- from bs4 import BeautifulSoup
3
- from abc import ABC, abstractmethod
4
- from pathlib import Path
5
- from langdetect import detect as get_language
6
- from typing import Any, Dict, List, Optional, Union
7
- from collections import namedtuple
8
- from inspect import signature
9
- import os
10
- import subprocess
11
- import logging
12
- import re
13
- import random
14
- from string import ascii_letters, digits, punctuation
15
- import requests
16
- import sys
17
- import warnings
18
- import time
19
- import math
20
- from pathlib import Path
21
- from dataclasses import dataclass
22
- from typing import Any
23
- import pillow_heif
24
- import spaces
25
- import numpy as np
26
- import numpy.typing as npt
27
- import torch
28
- from torch import nn
29
  import gradio as gr
30
- from lxml.html import fromstring
31
- from huggingface_hub import hf_hub_download
32
- from safetensors.torch import load_file, save_file
33
- from diffusers import DiffusionPipeline
34
- from PIL import Image, ImageDraw, ImageFont
35
- from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
36
- from refiners.fluxion.utils import manual_seed
37
- from refiners.foundationals.latent_diffusion import Solver, solvers
38
- from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
39
- MultiUpscaler,
40
- UpscalerCheckpoints,
41
- )
42
- from datetime import datetime
43
-
44
- model = T5ForConditionalGeneration.from_pretrained("t5-large")
45
- tokenizer = T5Tokenizer.from_pretrained("t5-large")
46
-
47
- def log(msg):
48
- print(f'{datetime.now().time()} {msg}')
49
-
50
- Tile = tuple[int, int, Image.Image]
51
- Tiles = list[tuple[int, int, list[Tile]]]
52
-
53
- def conv_block(in_nc: int, out_nc: int) -> nn.Sequential:
54
- return nn.Sequential(
55
- nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
56
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
57
- )
58
-
59
- class ResidualDenseBlock_5C(nn.Module):
60
- """
61
- Residual Dense Block
62
- The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
63
- Modified options that can be used:
64
- - "Partial Convolution based Padding" arXiv:1811.11718
65
- - "Spectral normalization" arXiv:1802.05957
66
- - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
67
- {Rakotonirina} and A. {Rasoanaivo}
68
- """
69
 
70
- def __init__(self, nf: int = 64, gc: int = 32) -> None:
71
- super().__init__() # type: ignore[reportUnknownMemberType]
72
 
73
- self.conv1 = conv_block(nf, gc)
74
- self.conv2 = conv_block(nf + gc, gc)
75
- self.conv3 = conv_block(nf + 2 * gc, gc)
76
- self.conv4 = conv_block(nf + 3 * gc, gc)
77
- # Wrapped in Sequential because of key in state dict.
78
- self.conv5 = nn.Sequential(nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))
79
 
80
- def forward(self, x: torch.Tensor) -> torch.Tensor:
81
- x1 = self.conv1(x)
82
- x2 = self.conv2(torch.cat((x, x1), 1))
83
- x3 = self.conv3(torch.cat((x, x1, x2), 1))
84
- x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
85
- x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
86
- return x5 * 0.2 + x
87
 
 
 
88
 
89
- class RRDB(nn.Module):
90
- """
91
- Residual in Residual Dense Block
92
- (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
 
 
 
 
 
93
  """
94
-
95
- def __init__(self, nf: int) -> None:
96
- super().__init__() # type: ignore[reportUnknownMemberType]
97
- self.RDB1 = ResidualDenseBlock_5C(nf)
98
- self.RDB2 = ResidualDenseBlock_5C(nf)
99
- self.RDB3 = ResidualDenseBlock_5C(nf)
100
-
101
- def forward(self, x: torch.Tensor) -> torch.Tensor:
102
- out = self.RDB1(x)
103
- out = self.RDB2(out)
104
- out = self.RDB3(out)
105
- return out * 0.2 + x
106
-
107
-
108
- class Upsample2x(nn.Module):
109
- """Upsample 2x."""
110
-
111
- def __init__(self) -> None:
112
- super().__init__() # type: ignore[reportUnknownMemberType]
113
-
114
- def forward(self, x: torch.Tensor) -> torch.Tensor:
115
- return nn.functional.interpolate(x, scale_factor=2.0) # type: ignore
116
-
117
-
118
- class ShortcutBlock(nn.Module):
119
- """Elementwise sum the output of a submodule to its input"""
120
-
121
- def __init__(self, submodule: nn.Module) -> None:
122
- super().__init__() # type: ignore[reportUnknownMemberType]
123
- self.sub = submodule
124
-
125
- def forward(self, x: torch.Tensor) -> torch.Tensor:
126
- return x + self.sub(x)
127
-
128
-
129
- class RRDBNet(nn.Module):
130
- def __init__(self, in_nc: int, out_nc: int, nf: int, nb: int) -> None:
131
- super().__init__() # type: ignore[reportUnknownMemberType]
132
- assert in_nc % 4 != 0 # in_nc is 3
133
-
134
- self.model = nn.Sequential(
135
- nn.Conv2d(in_nc, nf, kernel_size=3, padding=1),
136
- ShortcutBlock(
137
- nn.Sequential(
138
- *(RRDB(nf) for _ in range(nb)),
139
- nn.Conv2d(nf, nf, kernel_size=3, padding=1),
140
- )
141
- ),
142
- Upsample2x(),
143
- nn.Conv2d(nf, nf, kernel_size=3, padding=1),
144
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
145
- Upsample2x(),
146
- nn.Conv2d(nf, nf, kernel_size=3, padding=1),
147
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
148
- nn.Conv2d(nf, nf, kernel_size=3, padding=1),
149
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
150
- nn.Conv2d(nf, out_nc, kernel_size=3, padding=1),
151
- )
152
-
153
- def forward(self, x: torch.Tensor) -> torch.Tensor:
154
- return self.model(x)
155
-
156
-
157
- def infer_params(state_dict: dict[str, torch.Tensor]) -> tuple[int, int, int, int, int]:
158
- # this code is adapted from https://github.com/victorca25/iNNfer
159
- scale2x = 0
160
- scalemin = 6
161
- n_uplayer = 0
162
- out_nc = 0
163
- nb = 0
164
-
165
- for block in list(state_dict):
166
- parts = block.split(".")
167
- n_parts = len(parts)
168
- if n_parts == 5 and parts[2] == "sub":
169
- nb = int(parts[3])
170
- elif n_parts == 3:
171
- part_num = int(parts[1])
172
- if part_num > scalemin and parts[0] == "model" and parts[2] == "weight":
173
- scale2x += 1
174
- if part_num > n_uplayer:
175
- n_uplayer = part_num
176
- out_nc = state_dict[block].shape[0]
177
- assert "conv1x1" not in block # no ESRGANPlus
178
-
179
- nf = state_dict["model.0.weight"].shape[0]
180
- in_nc = state_dict["model.0.weight"].shape[1]
181
- scale = 2**scale2x
182
-
183
- assert out_nc > 0
184
- assert nb > 0
185
-
186
- return in_nc, out_nc, nf, nb, scale # 3, 3, 64, 23, 4
187
-
188
- # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L64
189
- Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
190
-
191
- # adapted from https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L67
192
- def split_grid(image: Image.Image, tile_w: int = 512, tile_h: int = 512, overlap: int = 64) -> Grid:
193
- w = image.width
194
- h = image.height
195
-
196
- non_overlap_width = tile_w - overlap
197
- non_overlap_height = tile_h - overlap
198
-
199
- cols = max(1, math.ceil((w - overlap) / non_overlap_width))
200
- rows = max(1, math.ceil((h - overlap) / non_overlap_height))
201
-
202
- dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
203
- dy = (h - tile_h) / (rows - 1) if rows > 1 else 0
204
-
205
- grid = Grid([], tile_w, tile_h, w, h, overlap)
206
- for row in range(rows):
207
- row_images: list[Tile] = []
208
- y1 = max(min(int(row * dy), h - tile_h), 0)
209
- y2 = min(y1 + tile_h, h)
210
- for col in range(cols):
211
- x1 = max(min(int(col * dx), w - tile_w), 0)
212
- x2 = min(x1 + tile_w, w)
213
- tile = image.crop((x1, y1, x2, y2))
214
- row_images.append((x1, tile_w, tile))
215
- grid.tiles.append((y1, tile_h, row_images))
216
-
217
- return grid
218
-
219
-
220
- # https://github.com/philz1337x/clarity-upscaler/blob/e0cd797198d1e0e745400c04d8d1b98ae508c73b/modules/images.py#L104
221
- def combine_grid(grid: Grid):
222
- def make_mask_image(r: npt.NDArray[np.float32]) -> Image.Image:
223
- r = r * 255 / grid.overlap
224
- return Image.fromarray(r.astype(np.uint8), "L")
225
-
226
- mask_w = make_mask_image(
227
- np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0)
228
- )
229
- mask_h = make_mask_image(
230
- np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1)
231
- )
232
-
233
- combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
234
- for y, h, row in grid.tiles:
235
- combined_row = Image.new("RGB", (grid.image_w, h))
236
- for x, w, tile in row:
237
- if x == 0:
238
- combined_row.paste(tile, (0, 0))
239
- continue
240
-
241
- combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
242
- combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
243
-
244
- if y == 0:
245
- combined_image.paste(combined_row, (0, 0))
246
- continue
247
-
248
- combined_image.paste(
249
- combined_row.crop((0, 0, combined_row.width, grid.overlap)),
250
- (0, y),
251
- mask=mask_h,
252
- )
253
- combined_image.paste(
254
- combined_row.crop((0, grid.overlap, combined_row.width, h)),
255
- (0, y + grid.overlap),
256
- )
257
-
258
- return combined_image
259
-
260
-
261
- class UpscalerESRGAN:
262
- def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
263
- self.model_path = model_path
264
- self.device = device
265
- self.model = self.load_model(model_path)
266
- self.to(device, dtype)
267
-
268
- def __call__(self, img: Image.Image) -> Image.Image:
269
- return self.upscale_without_tiling(img)
270
-
271
- def to(self, device: torch.device, dtype: torch.dtype):
272
- self.device = device
273
- self.dtype = dtype
274
- self.model.to(device=device, dtype=dtype)
275
-
276
- def load_model(self, path: Path) -> RRDBNet:
277
- filename = path
278
- state_dict: dict[str, torch.Tensor] = torch.load(filename, weights_only=True, map_location=self.device) # type: ignore
279
- in_nc, out_nc, nf, nb, upscale = infer_params(state_dict)
280
- assert upscale == 4, "Only 4x upscaling is supported"
281
- model = RRDBNet(in_nc=in_nc, out_nc=out_nc, nf=nf, nb=nb)
282
- model.load_state_dict(state_dict)
283
- model.eval()
284
-
285
- return model
286
-
287
- def upscale_without_tiling(self, img: Image.Image) -> Image.Image:
288
- img_np = np.array(img)
289
- img_np = img_np[:, :, ::-1]
290
- img_np = np.ascontiguousarray(np.transpose(img_np, (2, 0, 1))) / 255
291
- img_t = torch.from_numpy(img_np).float() # type: ignore
292
- img_t = img_t.unsqueeze(0).to(device=self.device, dtype=self.dtype)
293
- with torch.no_grad():
294
- output = self.model(img_t)
295
- output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
296
- output = 255.0 * np.moveaxis(output, 0, 2)
297
- output = output.astype(np.uint8)
298
- output = output[:, :, ::-1]
299
- return Image.fromarray(output, "RGB")
300
-
301
- def upscale_with_tiling(self, img: Image.Image) -> Image.Image:
302
- img = img.convert("RGB")
303
- grid = split_grid(img)
304
- newtiles: Tiles = []
305
- scale_factor: int = 1
306
-
307
- for y, h, row in grid.tiles:
308
- newrow: list[Tile] = []
309
- for tiledata in row:
310
- x, w, tile = tiledata
311
- output = self.upscale_without_tiling(tile)
312
- scale_factor = output.width // tile.width
313
- newrow.append((x * scale_factor, w * scale_factor, output))
314
- newtiles.append((y * scale_factor, h * scale_factor, newrow))
315
-
316
- newgrid = Grid(
317
- newtiles,
318
- grid.tile_w * scale_factor,
319
- grid.tile_h * scale_factor,
320
- grid.image_w * scale_factor,
321
- grid.image_h * scale_factor,
322
- grid.overlap * scale_factor,
323
- )
324
- output = combine_grid(newgrid)
325
- return output
326
-
327
- @dataclass(kw_only=True)
328
- class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
329
- esrgan: Path
330
-
331
- class ESRGANUpscaler(MultiUpscaler):
332
- def __init__(
333
- self,
334
- checkpoints: ESRGANUpscalerCheckpoints,
335
- device: torch.device,
336
- dtype: torch.dtype,
337
- ) -> None:
338
- super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
339
- self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)
340
-
341
- def to(self, device: torch.device, dtype: torch.dtype):
342
- self.esrgan.to(device=device, dtype=dtype)
343
- self.sd = self.sd.to(device=device, dtype=dtype)
344
- self.device = device
345
- self.dtype = dtype
346
-
347
- def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
348
- image = self.esrgan.upscale_with_tiling(image)
349
- return super().pre_upscale(image=image, upscale_factor=upscale_factor / 4)
350
-
351
- pillow_heif.register_heif_opener()
352
- pillow_heif.register_avif_opener()
353
-
354
- CHECKPOINTS = ESRGANUpscalerCheckpoints(
355
- unet=Path(
356
- hf_hub_download(
357
- repo_id="refiners/juggernaut.reborn.sd1_5.unet",
358
- filename="model.safetensors",
359
- revision="347d14c3c782c4959cc4d1bb1e336d19f7dda4d2",
360
- )
361
- ),
362
- clip_text_encoder=Path(
363
- hf_hub_download(
364
- repo_id="refiners/juggernaut.reborn.sd1_5.text_encoder",
365
- filename="model.safetensors",
366
- revision="744ad6a5c0437ec02ad826df9f6ede102bb27481",
367
- )
368
- ),
369
- lda=Path(
370
- hf_hub_download(
371
- repo_id="refiners/juggernaut.reborn.sd1_5.autoencoder",
372
- filename="model.safetensors",
373
- revision="3c1aae3fc3e03e4a2b7e0fa42b62ebb64f1a4c19",
374
- )
375
- ),
376
- controlnet_tile=Path(
377
- hf_hub_download(
378
- repo_id="refiners/controlnet.sd1_5.tile",
379
- filename="model.safetensors",
380
- revision="48ced6ff8bfa873a8976fa467c3629a240643387",
381
- )
382
- ),
383
- esrgan=Path(
384
- hf_hub_download(
385
- repo_id="philz1337x/upscaler",
386
- filename="4x-UltraSharp.pth",
387
- revision="011deacac8270114eb7d2eeff4fe6fa9a837be70",
388
- )
389
- ),
390
- negative_embedding=Path(
391
- hf_hub_download(
392
- repo_id="philz1337x/embeddings",
393
- filename="JuggernautNegative-neg.pt",
394
- revision="203caa7e9cc2bc225031a4021f6ab1ded283454a",
395
- )
396
- ),
397
- negative_embedding_key="string_to_param.*",
398
- loras={
399
- "more_details": Path(
400
- hf_hub_download(
401
- repo_id="philz1337x/loras",
402
- filename="more_details.safetensors",
403
- revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
404
- )
405
- ),
406
- "sdxl_render": Path(
407
- hf_hub_download(
408
- repo_id="philz1337x/loras",
409
- filename="SDXLrender_v2.0.safetensors",
410
- revision="a3802c0280c0d00c2ab18d37454a8744c44e474e",
411
- )
412
- )
413
- }
414
- )
415
-
416
- device = DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
417
- DTYPE = dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
418
-
419
- enhancer = ESRGANUpscaler(checkpoints=CHECKPOINTS, device=device, dtype=DTYPE)
420
-
421
- # logging
422
-
423
- warnings.filterwarnings("ignore")
424
- root = logging.getLogger()
425
- root.setLevel(logging.WARN)
426
- handler = logging.StreamHandler(sys.stderr)
427
- handler.setLevel(logging.WARN)
428
- formatter = logging.Formatter('\n >>> [%(levelname)s] %(asctime)s %(name)s: %(message)s\n')
429
- handler.setFormatter(formatter)
430
- root.addHandler(handler)
431
-
432
- # constant data
433
-
434
- MAX_SEED = np.iinfo(np.int32).max
435
-
436
- # precision data
437
-
438
- seq=512
439
- image_steps=40
440
- img_accu=6.5
441
-
442
- # ui data
443
-
444
- css="".join(["""
445
- input, textarea, input::placeholder, textarea::placeholder {
446
- text-align: center !important;
447
- }
448
- *, *::placeholder {
449
- font-family: Suez One !important;
450
- }
451
- h1,h2,h3,h4,h5,h6 {
452
- width: 100%;
453
- text-align: center;
454
- }
455
- footer {
456
- display: none !important;
457
- }
458
- .image-container {
459
- aspect-ratio: 1/1 !important;
460
- border: 2mm ridge black !important;
461
- }
462
- .dropdown-arrow {
463
- display: none !important;
464
- }
465
- *:has(>.btn) {
466
- display: flex;
467
- justify-content: space-evenly;
468
- align-items: center;
469
- }
470
- .btn {
471
- display: flex;
472
- }
473
-
474
- /* Added background gradient for a more colorful look */
475
- .gradio-container {
476
- background: linear-gradient(to right, #ffecd2, #fcb69f) !important;
477
- }
478
- """])
479
-
480
-
481
- # torch pipes
482
-
483
- image_pipe = DiffusionPipeline.from_pretrained("ostris/Flex.1-alpha", torch_dtype=dtype).to(device)
484
- image_pipe.enable_model_cpu_offload()
485
-
486
- torch.cuda.empty_cache()
487
-
488
- # functionality
489
-
490
- @spaces.GPU(duration=300)
491
- def hard_scaler(img):
492
- return upscaler(img)
493
-
494
- @spaces.GPU(duration=150)
495
- def easy_scaler(img):
496
- return upscaler(img)
497
-
498
- def handle_upscaler(img):
499
- w, h = img.size
500
- if w*h > 2 * (10 ** 6):
501
- return hard_scaler(img)
502
- return easy_scaler(img)
503
-
504
- def upscaler(
505
- input_image: Image.Image,
506
- prompt: str = "Accurate, Highly Detailed, Realistic, Best Quality, Hyper-Realistic, Super-Realistic, Natural, Reasonable, Logical.",
507
- negative_prompt: str = "Unreal, Exceptional, Irregular, Unusual, Blurry, Smoothed, Polished, Worst Quality, Worse Quality, Normal Quality, Painted, Movies Quality.",
508
- seed: int = random.randint(0, MAX_SEED),
509
- upscale_factor: int = 2,
510
- controlnet_scale: float = 0.6,
511
- controlnet_decay: float = 1.0,
512
- condition_scale: int = 6,
513
- tile_width: int = 112,
514
- tile_height: int = 144,
515
- denoise_strength: float = 0.35,
516
- num_inference_steps: int = 20,
517
- solver: str = "DDIM",
518
- ) -> Image.Image:
519
-
520
- log(f'CALL upscaler')
521
-
522
- manual_seed(seed)
523
- solver_type: type[Solver] = getattr(solvers, solver)
524
-
525
- log(f'DBG upscaler 1')
526
-
527
- enhanced_image = enhancer.upscale(
528
- image=input_image,
529
- prompt=prompt,
530
- negative_prompt=negative_prompt,
531
- upscale_factor=upscale_factor,
532
- controlnet_scale=controlnet_scale,
533
- controlnet_scale_decay=controlnet_decay,
534
- condition_scale=condition_scale,
535
- tile_size=(tile_height, tile_width),
536
- denoise_strength=denoise_strength,
537
- num_inference_steps=num_inference_steps,
538
- loras_scale={"more_details": 0.5, "sdxl_render": 1.0},
539
- solver_type=solver_type,
540
- )
541
-
542
- log(f'RET upscaler')
543
- return enhanced_image
544
-
545
- def get_tensor_length(tensor):
546
- nums = list(tensor.size())
547
- ret = 1
548
- for num in nums:
549
- ret *= num
550
- return ret
551
-
552
- def _summarize(text):
553
- log(f'CALL _summarize')
554
- prefix = "summarize: "
555
- toks = tokenizer.encode(prefix + text, return_tensors="pt", truncation=False)
556
- gen = model.generate(
557
- toks,
558
- length_penalty=0.1,
559
- num_beams=6,
560
- early_stopping=True,
561
- max_length=512
562
- )
563
- ret = tokenizer.decode(gen[0], skip_special_tokens=True)
564
- log(f'RET _summarize with ret as {ret}')
565
- return ret
566
-
567
- def summarize(text, max_words=100):
568
- log(f'CALL summarize')
569
 
570
- words = text.split()
571
- words_length = len(words)
572
-
573
- if words_length >= 510:
574
- while words_length >= 510:
575
- words = text.split()
576
- summ = _summarize(" ".join(words[0:510])) + " ".join(words[510:])
577
- if summ == text:
578
- return text
579
- text = summ
580
- words_length = len(text.split())
581
 
582
- while words_length > max_words:
583
- summ = _summarize(text)
584
- if summ == text:
585
- return text
586
- text = summ
587
- words_length = len(text.split())
588
 
589
- log(f'RET summarize with text as {text}')
590
- return text
591
-
592
- def generate_random_string(length):
593
- characters = str(ascii_letters + digits)
594
- return ''.join(random.choice(characters) for _ in range(length))
595
-
596
- def add_text_above_image(img,top_title=None,bottom_title=None):
597
- w, h = img.size
598
- draw = ImageDraw.Draw(img,mode="RGBA")
599
-
600
- labels_distance = 1/3
601
-
602
- if top_title:
603
- rows = len(top_title.split("\n"))
604
- textheight=min(math.ceil( w / 10 ), math.ceil( h / 5 ))
605
- font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
606
- textwidth = draw.textlength(top_title,font)
607
- x = math.ceil((w - textwidth) / 2)
608
- y = h - (textheight * rows / 2) - (h / 2)
609
- y = math.ceil(y - (h / 2 * labels_distance))
610
- draw.text(
611
- (x, y),
612
- top_title,
613
- (255,255,255),
614
- font=font,
615
- spacing=2,
616
- stroke_width=math.ceil(textheight/20),
617
- stroke_fill=(0,0,0)
618
- )
619
-
620
- if bottom_title:
621
- rows = len(bottom_title.split("\n"))
622
- textheight=min(math.ceil( w / 10 ), math.ceil( h / 5 ))
623
- font = ImageFont.truetype(r"Alef-Bold.ttf", textheight)
624
- textwidth = draw.textlength(bottom_title,font)
625
- x = math.ceil((w - textwidth) / 2)
626
- y = h - (textheight * rows / 2) - (h / 2)
627
- y = math.ceil(y + (h / 2 * labels_distance))
628
- draw.text(
629
- (x, y),
630
- bottom_title,
631
- (0,0,0),
632
- font=font,
633
- spacing=2,
634
- stroke_width=math.ceil(textheight/20),
635
- stroke_fill=(255,255,255)
636
- )
637
 
638
- return img
639
-
640
- # Modified parts from https://github.com/nidhaloff/deep-translator:
641
-
642
- google_translate_endpoint = "https://translate.google.com/m"
643
- language_codes = {
644
- "afrikaans": "af",
645
- "albanian": "sq",
646
- "amharic": "am",
647
- "arabic": "ar",
648
- "armenian": "hy",
649
- "assamese": "as",
650
- "aymara": "ay",
651
- "azerbaijani": "az",
652
- "bambara": "bm",
653
- "basque": "eu",
654
- "belarusian": "be",
655
- "bengali": "bn",
656
- "bhojpuri": "bho",
657
- "bosnian": "bs",
658
- "bulgarian": "bg",
659
- "catalan": "ca",
660
- "cebuano": "ceb",
661
- "chichewa": "ny",
662
- "chinese (simplified)": "zh-CN",
663
- "chinese (traditional)": "zh-TW",
664
- "corsican": "co",
665
- "croatian": "hr",
666
- "czech": "cs",
667
- "danish": "da",
668
- "dhivehi": "dv",
669
- "dogri": "doi",
670
- "dutch": "nl",
671
- "english": "en",
672
- "esperanto": "eo",
673
- "estonian": "et",
674
- "ewe": "ee",
675
- "filipino": "tl",
676
- "finnish": "fi",
677
- "french": "fr",
678
- "frisian": "fy",
679
- "galician": "gl",
680
- "georgian": "ka",
681
- "german": "de",
682
- "greek": "el",
683
- "guarani": "gn",
684
- "gujarati": "gu",
685
- "haitian creole": "ht",
686
- "hausa": "ha",
687
- "hawaiian": "haw",
688
- "hebrew": "iw",
689
- "hindi": "hi",
690
- "hmong": "hmn",
691
- "hungarian": "hu",
692
- "icelandic": "is",
693
- "igbo": "ig",
694
- "ilocano": "ilo",
695
- "indonesian": "id",
696
- "irish": "ga",
697
- "italian": "it",
698
- "japanese": "ja",
699
- "javanese": "jw",
700
- "kannada": "kn",
701
- "kazakh": "kk",
702
- "khmer": "km",
703
- "kinyarwanda": "rw",
704
- "konkani": "gom",
705
- "korean": "ko",
706
- "krio": "kri",
707
- "kurdish (kurmanji)": "ku",
708
- "kurdish (sorani)": "ckb",
709
- "kyrgyz": "ky",
710
- "lao": "lo",
711
- "latin": "la",
712
- "latvian": "lv",
713
- "lingala": "ln",
714
- "lithuanian": "lt",
715
- "luganda": "lg",
716
- "luxembourgish": "lb",
717
- "macedonian": "mk",
718
- "maithili": "mai",
719
- "malagasy": "mg",
720
- "malay": "ms",
721
- "malayalam": "ml",
722
- "maltese": "mt",
723
- "maori": "mi",
724
- "marathi": "mr",
725
- "meiteilon (manipuri)": "mni-Mtei",
726
- "mizo": "lus",
727
- "mongolian": "mn",
728
- "myanmar": "my",
729
- "nepali": "ne",
730
- "norwegian": "no",
731
- "odia (oriya)": "or",
732
- "oromo": "om",
733
- "pashto": "ps",
734
- "persian": "fa",
735
- "polish": "pl",
736
- "portuguese": "pt",
737
- "punjabi": "pa",
738
- "quechua": "qu",
739
- "romanian": "ro",
740
- "russian": "ru",
741
- "samoan": "sm",
742
- "sanskrit": "sa",
743
- "scots gaelic": "gd",
744
- "sepedi": "nso",
745
- "serbian": "sr",
746
- "sesotho": "st",
747
- "shona": "sn",
748
- "sindhi": "sd",
749
- "sinhala": "si",
750
- "slovak": "sk",
751
- "slovenian": "sl",
752
- "somali": "so",
753
- "spanish": "es",
754
- "sundanese": "su",
755
- "swahili": "sw",
756
- "swedish": "sv",
757
- "tajik": "tg",
758
- "tamil": "ta",
759
- "tatar": "tt",
760
- "telugu": "te",
761
- "thai": "th",
762
- "tigrinya": "ti",
763
- "tsonga": "ts",
764
- "turkish": "tr",
765
- "turkmen": "tk",
766
- "twi": "ak",
767
- "ukrainian": "uk",
768
- "urdu": "ur",
769
- "uyghur": "ug",
770
- "uzbek": "uz",
771
- "vietnamese": "vi",
772
- "welsh": "cy",
773
- "xhosa": "xh",
774
- "yiddish": "yi",
775
- "yoruba": "yo",
776
- "zulu": "zu",
777
- }
778
-
779
- class BaseError(Exception):
780
- """
781
- base error structure class
782
- """
783
-
784
- def __init__(self, val, message):
785
- self.val = val
786
- self.message = message
787
- super().__init__()
788
-
789
- def __str__(self):
790
- return "{} --> {}".format(self.val, self.message)
791
-
792
-
793
- class LanguageNotSupportedException(BaseError):
794
- """
795
- exception thrown if the user uses a language
796
- that is not supported by the deep_translator
797
- """
798
 
799
- def __init__(
800
- self, val, message="There is no support for the chosen language"
801
- ):
802
- super().__init__(val, message)
803
-
804
-
805
- class NotValidPayload(BaseError):
806
- """
807
- exception thrown if the user enters an invalid payload
808
- """
809
-
810
- def __init__(
811
- self,
812
- val,
813
- message="text must be a valid text with maximum 5000 character,"
814
- "otherwise it cannot be translated",
815
- ):
816
- super(NotValidPayload, self).__init__(val, message)
817
-
818
-
819
- class InvalidSourceOrTargetLanguage(BaseError):
820
- """
821
- exception thrown if the user enters an invalid payload
822
- """
823
-
824
- def __init__(self, val, message="Invalid source or target language!"):
825
- super(InvalidSourceOrTargetLanguage, self).__init__(val, message)
826
-
827
-
828
- class TranslationNotFound(BaseError):
829
- """
830
- exception thrown if no translation was found for the text provided by the user
831
- """
832
-
833
- def __init__(
834
- self,
835
- val,
836
- message="No translation was found using the current translator. Try another translator?",
837
- ):
838
- super(TranslationNotFound, self).__init__(val, message)
839
-
840
-
841
- class ElementNotFoundInGetRequest(BaseError):
842
- """
843
- exception thrown if the html element was not found in the body parsed by beautifulsoup
844
- """
845
-
846
- def __init__(
847
- self, val, message="Required element was not found in the API response"
848
- ):
849
- super(ElementNotFoundInGetRequest, self).__init__(val, message)
850
-
851
-
852
- class NotValidLength(BaseError):
853
- """
854
- exception thrown if the provided text exceed the length limit of the translator
855
- """
856
-
857
- def __init__(self, val, min_chars, max_chars):
858
- message = f"Text length need to be between {min_chars} and {max_chars} characters"
859
- super(NotValidLength, self).__init__(val, message)
860
-
861
-
862
- class RequestError(Exception):
863
- """
864
- exception thrown if an error occurred during the request call, e.g a connection problem.
865
- """
866
-
867
- def __init__(
868
- self,
869
- message="Request exception can happen due to an api connection error. "
870
- "Please check your connection and try again",
871
- ):
872
- self.message = message
873
-
874
- def __str__(self):
875
- return self.message
876
-
877
-
878
- class TooManyRequests(Exception):
879
- """
880
- exception thrown if an error occurred during the request call, e.g a connection problem.
881
- """
882
-
883
- def __init__(
884
- self,
885
- message="Server Error: You made too many requests to the server."
886
- "According to google, you are allowed to make 5 requests per second"
887
- "and up to 200k requests per day. You can wait and try again later or"
888
- "you can try the translate_batch function",
889
- ):
890
- self.message = message
891
-
892
- def __str__(self):
893
- return self.message
894
-
895
-
896
- class ServerException(Exception):
897
- """
898
- Default YandexTranslate exception from the official website
899
- """
900
-
901
- errors = {
902
- 400: "ERR_BAD_REQUEST",
903
- 401: "ERR_KEY_INVALID",
904
- 402: "ERR_KEY_BLOCKED",
905
- 403: "ERR_DAILY_REQ_LIMIT_EXCEEDED",
906
- 404: "ERR_DAILY_CHAR_LIMIT_EXCEEDED",
907
- 413: "ERR_TEXT_TOO_LONG",
908
- 429: "ERR_TOO_MANY_REQUESTS",
909
- 422: "ERR_UNPROCESSABLE_TEXT",
910
- 500: "ERR_INTERNAL_SERVER_ERROR",
911
- 501: "ERR_LANG_NOT_SUPPORTED",
912
- 503: "ERR_SERVICE_NOT_AVAIBLE",
913
- }
914
-
915
- def __init__(self, status_code, *args):
916
- message = self.errors.get(status_code, "API server error")
917
- super(ServerException, self).__init__(message, *args)
918
-
919
- def is_empty(text: str) -> bool:
920
- return text == ""
921
-
922
-
923
- def request_failed(status_code: int) -> bool:
924
- """Check if a request has failed or not.
925
- A request is considered successful if the status code is in the 2** range."""
926
- if status_code > 299 or status_code < 200:
927
- return True
928
- return False
929
-
930
-
931
- def is_input_valid(
932
- text: str, min_chars: int = 0, max_chars: Optional[int] = None
933
- ) -> bool:
934
- """
935
- validate the target text to translate
936
- @param min_chars: min characters
937
- @param max_chars: max characters
938
- @param text: text to translate
939
- @return: bool
940
- """
941
- if not isinstance(text, str):
942
- raise NotValidPayload(text)
943
- if max_chars and (not min_chars <= len(text) < max_chars):
944
- raise NotValidLength(text, min_chars, max_chars)
945
- return True
946
-
947
- class BaseTranslator(ABC):
948
- """
949
- Abstract class that serve as a base translator for other different translators
950
- """
951
-
952
- def __init__(
953
- self,
954
- base_url: str = None,
955
- languages: dict = language_codes,
956
- source: str = "auto",
957
- target: str = "en",
958
- payload_key: Optional[str] = None,
959
- element_tag: Optional[str] = None,
960
- element_query: Optional[dict] = None,
961
- **url_params,
962
- ):
963
- """
964
- @param source: source language to translate from
965
- @param target: target language to translate to
966
- """
967
- self._base_url = base_url
968
- self._languages = languages
969
- self._supported_languages = list(self._languages.keys())
970
- if not source:
971
- raise InvalidSourceOrTargetLanguage(source)
972
- if not target:
973
- raise InvalidSourceOrTargetLanguage(target)
974
-
975
- self._source, self._target = self._map_language_to_code(source, target)
976
- self._url_params = url_params
977
- self._element_tag = element_tag
978
- self._element_query = element_query
979
- self.payload_key = payload_key
980
- super().__init__()
981
-
982
- @property
983
- def source(self):
984
- return self._source
985
-
986
- @source.setter
987
- def source(self, lang):
988
- self._source = lang
989
-
990
- @property
991
- def target(self):
992
- return self._target
993
-
994
- @target.setter
995
- def target(self, lang):
996
- self._target = lang
997
-
998
- def _type(self):
999
- return self.__class__.__name__
1000
-
1001
- def _map_language_to_code(self, *languages):
1002
- """
1003
- map language to its corresponding code (abbreviation) if the language was passed
1004
- by its full name by the user
1005
- @param languages: list of languages
1006
- @return: mapped value of the language or raise an exception if the language is
1007
- not supported
1008
- """
1009
- for language in languages:
1010
- if language in self._languages.values() or language == "auto":
1011
- yield language
1012
- elif language in self._languages.keys():
1013
- yield self._languages[language]
1014
- else:
1015
- raise LanguageNotSupportedException(
1016
- language,
1017
- message=f"No support for the provided language.\n"
1018
- f"Please select on of the supported languages:\n"
1019
- f"{self._languages}",
1020
- )
1021
-
1022
- def _same_source_target(self) -> bool:
1023
- return self._source == self._target
1024
-
1025
- def get_supported_languages(
1026
- self, as_dict: bool = False, **kwargs
1027
- ) -> Union[list, dict]:
1028
- """
1029
- return the supported languages by the Google translator
1030
- @param as_dict: if True, the languages will be returned as a dictionary
1031
- mapping languages to their abbreviations
1032
- @return: list or dict
1033
- """
1034
- return self._supported_languages if not as_dict else self._languages
1035
-
1036
- def is_language_supported(self, language: str, **kwargs) -> bool:
1037
- """
1038
- check if the language is supported by the translator
1039
- @param language: a string for 1 language
1040
- @return: bool
1041
- """
1042
- if (
1043
- language == "auto"
1044
- or language in self._languages.keys()
1045
- or language in self._languages.values()
1046
- ):
1047
- return True
1048
- else:
1049
- return False
1050
-
1051
- @abstractmethod
1052
- def translate(self, text: str, **kwargs) -> str:
1053
- """
1054
- translate a text using a translator under the hood and return
1055
- the translated text
1056
- @param text: text to translate
1057
- @param kwargs: additional arguments
1058
- @return: str
1059
- """
1060
- return NotImplemented("You need to implement the translate method!")
1061
-
1062
- def _read_docx(self, f: str):
1063
- import docx2txt
1064
- return docx2txt.process(f)
1065
-
1066
- def _read_pdf(self, f: str):
1067
- import pypdf
1068
- reader = pypdf.PdfReader(f)
1069
- page = reader.pages[0]
1070
- return page.extract_text()
1071
-
1072
- def _translate_file(self, path: str, **kwargs) -> str:
1073
- """
1074
- translate directly from file
1075
- @param path: path to the target file
1076
- @type path: str
1077
- @param kwargs: additional args
1078
- @return: str
1079
- """
1080
- if not isinstance(path, Path):
1081
- path = Path(path)
1082
-
1083
- if not path.exists():
1084
- print("Path to the file is wrong!")
1085
- exit(1)
1086
-
1087
- ext = path.suffix
1088
-
1089
- if ext == ".docx":
1090
- text = self._read_docx(f=str(path))
1091
- elif ext == ".pdf":
1092
- text = self._read_pdf(f=str(path))
1093
- else:
1094
- with open(path, "r", encoding="utf-8") as f:
1095
- text = f.read().strip()
1096
-
1097
- return self.translate(text)
1098
-
1099
- def _translate_batch(self, batch: List[str], **kwargs) -> List[str]:
1100
- """
1101
- translate a list of texts
1102
- @param batch: list of texts you want to translate
1103
- @return: list of translations
1104
- """
1105
- if not batch:
1106
- raise Exception("Enter your text list that you want to translate")
1107
- arr = []
1108
- for i, text in enumerate(batch):
1109
- translated = self.translate(text, **kwargs)
1110
- arr.append(translated)
1111
- return arr
1112
-
1113
- class GoogleTranslator(BaseTranslator):
1114
- """
1115
- class that wraps functions, which use Google Translate under the hood to translate text(s)
1116
  """
1117
-
1118
- def __init__(
1119
- self,
1120
- source: str = "auto",
1121
- target: str = "en",
1122
- proxies: Optional[dict] = None,
1123
- **kwargs
1124
- ):
1125
- self.proxies = proxies
1126
- super().__init__(
1127
- base_url=google_translate_endpoint,
1128
- source=source,
1129
- target=target,
1130
- element_tag="div",
1131
- element_query={"class": "t0"},
1132
- payload_key="q",
1133
- **kwargs
1134
- )
1135
-
1136
- self._alt_element_query = {"class": "result-container"}
1137
-
1138
- def translate(self, text: str, **kwargs) -> str:
1139
- if is_input_valid(text, max_chars=1000):
1140
- text = text.strip()
1141
- if self._same_source_target() or is_empty(text):
1142
- return text
1143
- self._url_params["tl"] = self._target
1144
- self._url_params["sl"] = self._source
1145
-
1146
- if self.payload_key:
1147
- self._url_params[self.payload_key] = text
1148
-
1149
- response = requests.get(
1150
- self._base_url, params=self._url_params, proxies=self.proxies
1151
  )
1152
- if response.status_code == 429:
1153
- raise TooManyRequests()
1154
-
1155
- if request_failed(status_code=response.status_code):
1156
- raise RequestError()
1157
-
1158
- soup = BeautifulSoup(response.text, "html.parser")
1159
-
1160
- element = soup.find(self._element_tag, self._element_query)
1161
- response.close()
1162
-
1163
- if not element:
1164
- element = soup.find(self._element_tag, self._alt_element_query)
1165
- if not element:
1166
- raise TranslationNotFound(text)
1167
-
1168
- if element.get_text(strip=True) == text.strip():
1169
- to_translate_alpha = "".join(ch for ch in text.strip() if ch.isalnum())
1170
- translated_alpha = "".join(ch for ch in element.get_text(strip=True) if ch.isalnum())
1171
- if (
1172
- to_translate_alpha
1173
- and translated_alpha
1174
- and to_translate_alpha == translated_alpha
1175
- ):
1176
- self._url_params["tl"] = self._target
1177
- if "hl" not in self._url_params:
1178
- return text.strip()
1179
- del self._url_params["hl"]
1180
- return self.translate(text)
1181
- else:
1182
- return element.get_text(strip=True)
1183
-
1184
- def translate_file(self, path: str, **kwargs) -> str:
1185
- return self._translate_file(path, **kwargs)
1186
-
1187
- def translate_batch(self, batch: List[str], **kwargs) -> List[str]:
1188
- return self._translate_batch(batch, **kwargs)
1189
-
1190
-
1191
- def translate(txt,to_lang="en",from_lang="auto"):
1192
- log(f'CALL translate')
1193
- if len(txt) == 0:
1194
- print("Translated text is empty. Skipping translation...")
1195
- return txt.strip().lower()
1196
- if from_lang == to_lang or get_language(txt) == to_lang:
1197
- print("Same languages. Skipping translation...")
1198
- return txt.strip().lower()
1199
- translator = GoogleTranslator(from_lang=from_lang,to_lang=to_lang)
1200
- translation = ""
1201
- if len(txt) > 1000:
1202
- words = txt.split()
1203
- while len(words) > 0:
1204
- chunk = ""
1205
- while len(words) > 0 and len(chunk) < 1000:
1206
- chunk = chunk + " " + words[0]
1207
- words = words[1:]
1208
- if len(chunk) > 1000:
1209
- _words = chunk.split()
1210
- words = [_words[-1], *words]
1211
- chunk = " ".join(_words[:-1])
1212
- translation = translation + " " + translator.translate(chunk)
1213
- else:
1214
- translation = translator.translate(txt)
1215
- translation = translation.strip()
1216
- log(f'RET translate with translation as {translation}')
1217
- return translation.lower()
1218
-
1219
- def handle_generation(h,w,d):
1220
- log(f'CALL handle_generate')
1221
- difficulty_points = 0
1222
-
1223
- toks_len = get_tensor_length(tokenizer.encode(d, return_tensors="pt", truncation=False))
1224
- if toks_len > 500:
1225
- difficulty_points += 2
1226
- elif toks_len > 50:
1227
- difficulty_points += 1
1228
-
1229
- pxs = h*w
1230
- if pxs > 2 * (10 ** 6):
1231
- difficulty_points += 2
1232
- elif pxs > 1 * (10 ** 6):
1233
- difficulty_points += 1
1234
-
1235
- if difficulty_points < 2:
1236
- return easy_generation(h,w,d)
1237
- elif difficulty_points < 4:
1238
- return balanced_generation(h,w,d)
1239
- else:
1240
- return hard_generation(h,w,d)
1241
-
1242
- @spaces.GPU(duration=150)
1243
- def easy_generation(h,w,d):
1244
- return generation(h,w,d)
1245
-
1246
- @spaces.GPU(duration=210)
1247
- def balanced_generation(h,w,d):
1248
- return generation(h,w,d)
1249
-
1250
- @spaces.GPU(duration=270)
1251
- def hard_generation(h,w,d):
1252
- return generation(h,w,d)
1253
-
1254
- def generation(h,w,d):
1255
- if len(d) > 0:
1256
- d = re.sub(r",( ){1,}",". ",d)
1257
- d_lines = re.split(r"([\n]){1,}", d)
1258
 
1259
- for line_index in range(len(d_lines)):
1260
- d_lines[line_index] = d_lines[line_index].strip()
1261
- if d_lines[line_index] != "" and re.sub(r'[\.]$', '', d_lines[line_index]) == d_lines[line_index]:
1262
- d_lines[line_index] += "."
1263
- d = " ".join(d_lines)
1264
 
1265
- d = re.sub(r"([ \t]){1,}", " ", d).lower().strip()
1266
- d = d if d == "" else summarize(translate(d), max_words=50)
1267
- d = re.sub(r"([ \t]){1,}", " ", d)
1268
- d = re.sub(r"(\. \.)", ".", d)
1269
-
1270
- neg = f"Textual, Text, Signs, Labels, Titles, Unreal, Exceptional, Irregular, Unusual, Blurry, Smoothed, Polished, Worst Quality, Worse Quality, Painted, Movies Quality."
1271
- pos = f'Accurate, Detailed, Realistic.{ "" if d == "" else " " + d }'
1272
-
1273
- print(f"""
1274
- Positive: {pos}
1275
-
1276
- Negative: {neg}
1277
- """)
1278
-
1279
- img = image_pipe(
1280
- prompt=pos,
1281
- negative_prompt=neg,
1282
- height=h,
1283
- width=w,
1284
- output_type="pil",
1285
- guidance_scale=img_accu,
1286
- num_images_per_prompt=1,
1287
- num_inference_steps=image_steps,
1288
- max_sequence_length=seq,
1289
- generator=torch.Generator(device).manual_seed(random.randint(0, MAX_SEED))
1290
- ).images[0]
1291
- return img
1292
-
1293
- # entry
1294
-
1295
- if __name__ == "__main__":
1296
- # Changed the theme to a more colorful one and updated the title to English
1297
- with gr.Blocks(theme=gr.themes.Soft(primary_hue="lime"), css=css) as demo:
1298
- gr.Markdown(f"""
1299
- # Multilingual Images
1300
- """)
1301
- gr.Markdown(f"""
1302
- ### Realistic. Upscalable. Multilingual.
1303
- """)
1304
-
1305
- gr.HTML("""<a href="https://visitorbadge.io/status?path=https%3A%2F%2Faiqcamp-Multilingual-Images.hf.space">
1306
- <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Faiqcamp-Multilingual-Images.hf.space&countColor=%23263759" />
1307
- </a>""")
1308
-
1309
-
1310
- with gr.Row():
1311
- with gr.Column(scale=2):
1312
- height = gr.Slider(
1313
- label="Height (px)",
1314
- minimum=512,
1315
- maximum=1536,
1316
- step=16,
1317
- value=1024,
1318
- )
1319
- width = gr.Slider(
1320
- label="Width (px)",
1321
- minimum=512,
1322
- maximum=1536,
1323
- step=16,
1324
- value=1024,
1325
- )
1326
-
1327
- run = gr.Button("Generate", elem_classes="btn")
1328
-
1329
- top = gr.Textbox(
1330
- placeholder="Top Title",
1331
- value="",
1332
- container=False,
1333
- max_lines=1
1334
- )
1335
- bottom = gr.Textbox(
1336
- placeholder="Bottom Title",
1337
- value="",
1338
- container=False,
1339
- max_lines=1
1340
- )
1341
-
1342
- data = gr.Textbox(
1343
- placeholder="Enter your text/prompt (multiple languages allowed)",
1344
- value="",
1345
- container=False,
1346
- max_lines=100
1347
- )
1348
-
1349
- with gr.Column():
1350
- cover = gr.Image(
1351
- interactive=False,
1352
- container=False,
1353
- elem_classes="image-container",
1354
- label="Result",
1355
- show_label=True,
1356
- type='pil',
1357
- show_share_button=False
1358
- )
1359
- upscale_now = gr.Button("Upscale x2", elem_classes="btn")
1360
- add_titles = gr.Button("Add title(s)", elem_classes="btn")
1361
-
1362
- gr.Markdown("---")
1363
-
1364
- # Bottom row explanation or details in English
1365
- gr.Markdown("""
1366
- ## Features
1367
- 1. **Text Input**: You can input text in various languages; it will be automatically translated and summarized before generating an image.
1368
- 2. **Image Size Adjustment**: Use sliders to specify the width and height of the output image.
1369
- 3. **Overlay Text**: Easily add top/bottom titles to the generated image with a simple button click.
1370
- 4. **High-Quality Upscaling**: Increase the resolution with the "Upscale x2" feature.
1371
- 5. **Automatic GPU Resource Management**: The system automatically adjusts GPU usage time depending on input text length and image size.
1372
- ---
1373
- """)
1374
-
1375
  gr.Markdown("""
1376
- ### Usage Guide
1377
- 1. Set the desired image dimensions and text prompt, then click **Generate**.
1378
- 2. After viewing the generated image, you can select **Upscale x2** to improve its resolution.
1379
- 3. Use **Add title(s)** to place custom titles at the top or bottom of the image.
1380
- 4. You can test all available features in the interface above.
 
 
 
 
 
 
 
 
 
 
 
 
1381
  """)
1382
 
1383
- # Event wiring
1384
- gr.on(
1385
- triggers=[run.click],
1386
- fn=handle_generation,
1387
- inputs=[height, width, data],
1388
- outputs=[cover]
1389
- )
1390
- upscale_now.click(
1391
- fn=handle_upscaler,
1392
- inputs=[cover],
1393
- outputs=[cover]
1394
- )
1395
- add_titles.click(
1396
- fn=add_text_above_image,
1397
- inputs=[cover, top, bottom],
1398
- outputs=[cover]
1399
- )
1400
-
1401
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import os
5
+ import io
6
+ import base64
7
+ from kokoro import KModel, KPipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # Check if CUDA is available
10
+ CUDA_AVAILABLE = torch.cuda.is_available()
11
 
12
+ # Initialize the model
13
+ model = KModel().to('cuda' if CUDA_AVAILABLE else 'cpu').eval()
 
 
 
 
14
 
15
+ # Initialize pipelines for different language codes (using 'a' for English)
16
+ pipelines = {'a': KPipeline(lang_code='a', model=False)}
 
 
 
 
 
17
 
18
+ # Custom pronunciation for "kokoro"
19
+ pipelines['a'].g2p.lexicon.golds['kokoro'] = 'kˈOkəɹO'
20
 
21
+ def text_to_audio(text, speed=1.0):
22
+ """Convert text to audio using Kokoro model.
23
+
24
+ Args:
25
+ text: The text to convert to speech
26
+ speed: Speech speed multiplier (0.5-2.0, where 1.0 is normal speed)
27
+
28
+ Returns:
29
+ Audio data as a tuple of (sample_rate, audio_array)
30
  """
31
+ if not text:
32
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ pipeline = pipelines['a'] # Use English pipeline
35
+ voice = "af_heart" # Default voice (US English, female, Heart)
 
 
 
 
 
 
 
 
 
36
 
37
+ # Process the text
38
+ pack = pipeline.load_voice(voice)
39
+
40
+ for _, ps, _ in pipeline(text, voice, speed):
41
+ ref_s = pack[len(ps)-1]
 
42
 
43
+ # Generate audio
44
+ try:
45
+ audio = model(ps, ref_s, speed)
46
+ except Exception as e:
47
+ raise gr.Error(f"Error generating audio: {str(e)}")
48
+
49
+ # Return the audio with 24kHz sample rate
50
+ return 24000, audio.numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
+ def text_to_audio_b64(text, speed=1.0):
55
+ """Convert text to audio and return as base64 encoded WAV file.
56
+
57
+ Args:
58
+ text: The text to convert to speech
59
+ speed: Speech speed multiplier (0.5-2.0, where 1.0 is normal speed)
60
+
61
+ Returns:
62
+ Base64 encoded WAV file as a string
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  """
64
+ import soundfile as sf
65
+
66
+ result = text_to_audio(text, speed)
67
+ if result is None:
68
+ return None
69
+
70
+ sample_rate, audio_data = result
71
+
72
+ # Save to BytesIO object
73
+ wav_io = io.BytesIO()
74
+ sf.write(wav_io, audio_data, sample_rate, format='WAV')
75
+ wav_io.seek(0)
76
+
77
+ # Convert to base64
78
+ wav_b64 = base64.b64encode(wav_io.read()).decode('utf-8')
79
+ return wav_b64
80
+
81
+ # Create Gradio interface
82
+ with gr.Blocks(title="Kokoro Text-to-Audio MCP") as app:
83
+ gr.Markdown("# 🎵 Kokoro Text-to-Audio MCP")
84
+ gr.Markdown("Convert text to speech using the Kokoro-82M model")
85
+
86
+ with gr.Row():
87
+ with gr.Column():
88
+ text_input = gr.Textbox(
89
+ label="Enter your text",
90
+ placeholder="Type something to convert to audio...",
91
+ lines=5
 
 
 
 
 
 
92
  )
93
+ speed_slider = gr.Slider(
94
+ minimum=0.5,
95
+ maximum=2.0,
96
+ value=1.0,
97
+ step=0.1,
98
+ label="Speech Speed"
99
+ )
100
+ submit_btn = gr.Button("Generate Audio")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
+ with gr.Column():
103
+ audio_output = gr.Audio(label="Generated Audio", type="numpy")
 
 
 
104
 
105
+ submit_btn.click(
106
+ fn=text_to_audio,
107
+ inputs=[text_input, speed_slider],
108
+ outputs=[audio_output]
109
+ )
110
+
111
+ gr.Markdown("### Usage Tips")
112
+ gr.Markdown("- Adjust the speed slider to modify the pace of speech")
113
+
114
+ # Add section about MCP support
115
+ with gr.Accordion("MCP Support (for LLMs)", open=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  gr.Markdown("""
117
+ ### MCP Support
118
+
119
+ This app supports the Model Context Protocol (MCP), allowing Large Language Models like Claude Desktop to use it as a tool.
120
+
121
+ To use this app with an MCP client, add the following configuration:
122
+
123
+ ```json
124
+ {
125
+ "mcpServers": {
126
+ "kokoroTTS": {
127
+ "url": "https://fdaudens-kokoro-mcp.hf.space/gradio_api/mcp/sse"
128
+ }
129
+ }
130
+ }
131
+ ```
132
+
133
+ Replace `your-app-url.hf.space` with your actual Hugging Face Space URL.
134
  """)
135
 
136
+ # Launch the app with MCP support
137
+ if __name__ == "__main__":
138
+ # Check for environment variable to enable MCP
139
+ enable_mcp = os.environ.get('GRADIO_MCP_SERVER', 'False').lower() in ('true', '1', 't')
140
+
141
+ app.launch(mcp_server=True)