Yaron Koresh commited on
Commit
8aa0947
·
verified ·
1 Parent(s): 8172c90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +189 -33
app.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
  Modified parts included from these sources:
3
  - https://github.com/nidhaloff/deep-translator
 
4
  """
5
 
6
  import urllib
@@ -9,7 +10,7 @@ from bs4 import BeautifulSoup
9
  from abc import ABC, abstractmethod
10
  from pathlib import Path
11
  from langdetect import detect as get_language
12
- from typing import List, Optional, Union
13
  from collections import namedtuple
14
  from inspect import signature
15
  import os
@@ -38,9 +39,9 @@ import gradio as gr
38
  from lxml.html import fromstring
39
  from huggingface_hub import hf_hub_download
40
  from safetensors.torch import load_file, save_file
41
- from diffusers import DiffusionPipeline, AutoencoderTiny
42
  from PIL import Image, ImageDraw, ImageFont
43
- from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer
44
  from refiners.fluxion.utils import manual_seed
45
  from refiners.foundationals.latent_diffusion import Solver, solvers
46
  from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
@@ -51,8 +52,169 @@ from datetime import datetime
51
 
52
  working = False
53
 
54
- model = T5ForConditionalGeneration.from_pretrained("t5-large")
55
- tokenizer = T5Tokenizer.from_pretrained("t5-large")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  def log(msg):
58
  print(f'{datetime.now().time()} {msg}')
@@ -446,8 +608,8 @@ MAX_SEED = np.iinfo(np.int32).max
446
  # precision data
447
 
448
  seq=512
449
- image_steps=50
450
- img_accu=7.0
451
 
452
  # ui data
453
 
@@ -508,10 +670,13 @@ function custom(){
508
  # torch pipes
509
 
510
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
 
511
  image_pipe = DiffusionPipeline.from_pretrained("ostris/Flex.1-alpha", torch_dtype=dtype, vae=taef1).to(device)
512
  image_pipe.enable_model_cpu_offload()
513
- image_pipe.enable_vae_slicing()
514
- image_pipe.enable_vae_tiling()
 
 
515
 
516
  # functionality
517
 
@@ -519,7 +684,7 @@ def upscaler(
519
  input_image: Image.Image,
520
  prompt: str = "Hyper realistic photography, Natural visual content.",
521
  negative_prompt: str = "Distorted, Discontinuous, Blurry, Doll-Like, Overly-Plastic, Low-Quality, Painted, Smoothed, Artificial, Phony, Gaudy, Digital Effects.",
522
- seed: int = int(str(random.random()).split(".")[1]),
523
  upscale_factor: int = 2,
524
  controlnet_scale: float = 0.6,
525
  controlnet_decay: float = 1.0,
@@ -527,7 +692,7 @@ def upscaler(
527
  tile_width: int = 112,
528
  tile_height: int = 144,
529
  denoise_strength: float = 0.35,
530
- num_inference_steps: int = 30,
531
  solver: str = "DDIM",
532
  ) -> Image.Image:
533
 
@@ -571,8 +736,8 @@ def _summarize(text):
571
  toks = tokenizer.encode( prefix + text, return_tensors="pt", truncation=False)
572
  gen = model.generate(
573
  toks,
574
- length_penalty=0.01,
575
- num_beams=8,
576
  early_stopping=True,
577
  max_length=512
578
  )
@@ -580,20 +745,11 @@ def _summarize(text):
580
  log(f'RET _summarize with ret as {ret}')
581
  return ret
582
 
583
- def summarize(text, max_words=10):
584
  log(f'CALL summarize')
585
 
586
  words = text.split()
587
-
588
- if len(words) < 5:
589
- print("Summarization Error: Text is too short, 5 words minimum.")
590
- return text
591
-
592
- if max_words < 5 or max_words > 500:
593
- print("Summarization Error: max_words value must be between 5 and 500 words.")
594
- return text
595
-
596
- words_length = len(text.split())
597
 
598
  if words_length >= 510:
599
  while words_length >= 510:
@@ -606,12 +762,11 @@ def summarize(text, max_words=10):
606
  text = summ
607
  words_length = len(text.split())
608
 
609
- while words_length > max_words:
610
  summ = _summarize(text)
611
  if summ == text:
612
  return text
613
  text = summ
614
- words_length = len(text.split())
615
 
616
  log(f'RET summarize with text as {text}')
617
  return text
@@ -621,8 +776,7 @@ def generate_random_string(length):
621
  return ''.join(random.choice(characters) for _ in range(length))
622
 
623
  def pipe_generate_image(p1,p2,h,w):
624
- log(f'CALL pipe_generate')
625
- imgs = image_pipe(
626
  prompt=p1,
627
  negative_prompt=p2,
628
  height=h,
@@ -632,9 +786,8 @@ def pipe_generate_image(p1,p2,h,w):
632
  num_inference_steps=image_steps,
633
  max_sequence_length=seq,
634
  generator=torch.Generator(device).manual_seed(random.randint(0, MAX_SEED))
635
- ).images
636
- log(f'RET pipe_generate')
637
- return imgs
638
 
639
  def add_song_cover_text(img,artist,song,h,w):
640
 
@@ -1273,6 +1426,9 @@ class GoogleTranslator(BaseTranslator):
1273
 
1274
  def translate(txt,to_lang="en",from_lang="auto"):
1275
  log(f'CALL translate')
 
 
 
1276
  translator = GoogleTranslator(from_lang=from_lang,to_lang=to_lang)
1277
  translation = ""
1278
  if len(txt) > 1000:
@@ -1323,7 +1479,7 @@ def handle_generation(artist,song,lyrics,h,w):
1323
  pos_lyrics = pos_lyrics if pos_lyrics == "" else summarize(translate(pos_lyrics))
1324
  pos_lyrics = re.sub(r"([ \t]){1,}", " ", pos_lyrics).lower().strip()
1325
 
1326
- neg = f"Sexuality, Nudity, Human body, Human, Textual, Text, Distorted, Fake, Discontinuous, Blurry, Doll-Like, Overly Plastic, Low Quality, Paint, Smoothed, Artificial, Phony, Gaudy, Digital Effects."
1327
  q = "\""
1328
  pos = f'HQ Hyper-realistic professional photograph{ pos_lyrics if pos_lyrics == "" else ": " + q + pos_lyrics + q }.'
1329
 
@@ -1336,7 +1492,7 @@ def handle_generation(artist,song,lyrics,h,w):
1336
  img = all_pipes(pos,neg,h,w)
1337
 
1338
  labeled_img = add_song_cover_text(img,pos_artist,pos_song,h,w)
1339
- name = f'{generate_random_string(8)}.png'
1340
  labeled_img.save(name)
1341
 
1342
  working = False
 
1
  """
2
  Modified parts included from these sources:
3
  - https://github.com/nidhaloff/deep-translator
4
+ - https://huggingface.co/spaces/ostris/Flex.1-alpha
5
  """
6
 
7
  import urllib
 
10
  from abc import ABC, abstractmethod
11
  from pathlib import Path
12
  from langdetect import detect as get_language
13
+ from typing import Any, Dict, List, Optional, Union
14
  from collections import namedtuple
15
  from inspect import signature
16
  import os
 
39
  from lxml.html import fromstring
40
  from huggingface_hub import hf_hub_download
41
  from safetensors.torch import load_file, save_file
42
+ from diffusers import DiffusionPipeline, AutoencoderTiny, FluxPipeline, FlowMatchEulerDiscreteScheduler
43
  from PIL import Image, ImageDraw, ImageFont
44
+ from transformers import pipeline, T5ForConditionalGeneration, T5Tokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
45
  from refiners.fluxion.utils import manual_seed
46
  from refiners.foundationals.latent_diffusion import Solver, solvers
47
  from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
 
52
 
53
  working = False
54
 
55
+ model = T5ForConditionalGeneration.from_pretrained("t5-base")
56
+ tokenizer = T5Tokenizer.from_pretrained("t5-base")
57
+
58
+ def calculate_shift(
59
+ image_seq_len,
60
+ base_seq_len: int = 256,
61
+ max_seq_len: int = 4096,
62
+ base_shift: float = 0.5,
63
+ max_shift: float = 1.16,
64
+ ):
65
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
66
+ b = base_shift - m * base_seq_len
67
+ mu = image_seq_len * m + b
68
+ return mu
69
+
70
+ def retrieve_timesteps(
71
+ scheduler,
72
+ num_inference_steps: Optional[int] = None,
73
+ device: Optional[Union[str, torch.device]] = None,
74
+ timesteps: Optional[List[int]] = None,
75
+ sigmas: Optional[List[float]] = None,
76
+ **kwargs,
77
+ ):
78
+ if timesteps is not None and sigmas is not None:
79
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
80
+ if timesteps is not None:
81
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
82
+ timesteps = scheduler.timesteps
83
+ num_inference_steps = len(timesteps)
84
+ elif sigmas is not None:
85
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
86
+ timesteps = scheduler.timesteps
87
+ num_inference_steps = len(timesteps)
88
+ else:
89
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
90
+ timesteps = scheduler.timesteps
91
+ return timesteps, num_inference_steps
92
+
93
+ # FLUX pipeline function
94
+ @torch.inference_mode()
95
+ def flux_pipe_call_that_returns_an_iterable_of_images(
96
+ self,
97
+ prompt: Union[str, List[str]] = None,
98
+ prompt_2: Optional[Union[str, List[str]]] = None,
99
+ height: Optional[int] = None,
100
+ width: Optional[int] = None,
101
+ num_inference_steps: int = 28,
102
+ timesteps: List[int] = None,
103
+ guidance_scale: float = 3.5,
104
+ num_images_per_prompt: Optional[int] = 1,
105
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
106
+ latents: Optional[torch.FloatTensor] = None,
107
+ prompt_embeds: Optional[torch.FloatTensor] = None,
108
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
109
+ output_type: Optional[str] = "pil",
110
+ return_dict: bool = True,
111
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
112
+ max_sequence_length: int = 512,
113
+ good_vae: Optional[Any] = None,
114
+ ):
115
+ height = height or self.default_sample_size * self.vae_scale_factor
116
+ width = width or self.default_sample_size * self.vae_scale_factor
117
+
118
+ # 1. Check inputs
119
+ self.check_inputs(
120
+ prompt,
121
+ prompt_2,
122
+ height,
123
+ width,
124
+ prompt_embeds=prompt_embeds,
125
+ pooled_prompt_embeds=pooled_prompt_embeds,
126
+ max_sequence_length=max_sequence_length,
127
+ )
128
+
129
+ self._guidance_scale = guidance_scale
130
+ self._joint_attention_kwargs = joint_attention_kwargs
131
+ self._interrupt = False
132
+
133
+ # 2. Define call parameters
134
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
135
+ device = self._execution_device
136
+
137
+ # 3. Encode prompt
138
+ lora_scale = joint_attention_kwargs.get("scale", None) if joint_attention_kwargs is not None else None
139
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
140
+ prompt=prompt,
141
+ prompt_2=prompt_2,
142
+ prompt_embeds=prompt_embeds,
143
+ pooled_prompt_embeds=pooled_prompt_embeds,
144
+ device=device,
145
+ num_images_per_prompt=num_images_per_prompt,
146
+ max_sequence_length=max_sequence_length,
147
+ lora_scale=lora_scale,
148
+ )
149
+ # 4. Prepare latent variables
150
+ num_channels_latents = self.transformer.config.in_channels // 4
151
+ latents, latent_image_ids = self.prepare_latents(
152
+ batch_size * num_images_per_prompt,
153
+ num_channels_latents,
154
+ height,
155
+ width,
156
+ prompt_embeds.dtype,
157
+ device,
158
+ generator,
159
+ latents,
160
+ )
161
+ # 5. Prepare timesteps
162
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
163
+ image_seq_len = latents.shape[1]
164
+ mu = calculate_shift(
165
+ image_seq_len,
166
+ self.scheduler.config.base_image_seq_len,
167
+ self.scheduler.config.max_image_seq_len,
168
+ self.scheduler.config.base_shift,
169
+ self.scheduler.config.max_shift,
170
+ )
171
+ timesteps, num_inference_steps = retrieve_timesteps(
172
+ self.scheduler,
173
+ num_inference_steps,
174
+ device,
175
+ timesteps,
176
+ sigmas,
177
+ mu=mu,
178
+ )
179
+ self._num_timesteps = len(timesteps)
180
+
181
+ # Handle guidance
182
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
183
+
184
+ # 6. Denoising loop
185
+ for i, t in enumerate(timesteps):
186
+ if self.interrupt:
187
+ continue
188
+
189
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
190
+
191
+ noise_pred = self.transformer(
192
+ hidden_states=latents,
193
+ timestep=timestep / 1000,
194
+ guidance=guidance,
195
+ pooled_projections=pooled_prompt_embeds,
196
+ encoder_hidden_states=prompt_embeds,
197
+ txt_ids=text_ids,
198
+ img_ids=latent_image_ids,
199
+ joint_attention_kwargs=self.joint_attention_kwargs,
200
+ return_dict=False,
201
+ )[0]
202
+ # Yield intermediate result
203
+ latents_for_image = self._unpack_latents(latents, height, width, self.vae_scale_factor)
204
+ latents_for_image = (latents_for_image / self.vae.config.scaling_factor) + self.vae.config.shift_factor
205
+ image = self.vae.decode(latents_for_image, return_dict=False)[0]
206
+ yield self.image_processor.postprocess(image, output_type=output_type)[0]
207
+
208
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
209
+ torch.cuda.empty_cache()
210
+
211
+ # Final image using good_vae
212
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
213
+ latents = (latents / good_vae.config.scaling_factor) + good_vae.config.shift_factor
214
+ image = good_vae.decode(latents, return_dict=False)[0]
215
+ self.maybe_free_model_hooks()
216
+ torch.cuda.empty_cache()
217
+ yield self.image_processor.postprocess(image, output_type=output_type)[0]
218
 
219
  def log(msg):
220
  print(f'{datetime.now().time()} {msg}')
 
608
  # precision data
609
 
610
  seq=512
611
+ image_steps=25
612
+ img_accu=3.5
613
 
614
  # ui data
615
 
 
670
  # torch pipes
671
 
672
  taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
673
+ good_vae = AutoencoderKL.from_pretrained("ostris/Flex.1-alpha", subfolder="vae", torch_dtype=dtype).to(device)
674
  image_pipe = DiffusionPipeline.from_pretrained("ostris/Flex.1-alpha", torch_dtype=dtype, vae=taef1).to(device)
675
  image_pipe.enable_model_cpu_offload()
676
+
677
+ torch.cuda.empty_cache()
678
+
679
+ image_pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(image_pipe)
680
 
681
  # functionality
682
 
 
684
  input_image: Image.Image,
685
  prompt: str = "Hyper realistic photography, Natural visual content.",
686
  negative_prompt: str = "Distorted, Discontinuous, Blurry, Doll-Like, Overly-Plastic, Low-Quality, Painted, Smoothed, Artificial, Phony, Gaudy, Digital Effects.",
687
+ seed: int = random.randint(0, MAX_SEED),
688
  upscale_factor: int = 2,
689
  controlnet_scale: float = 0.6,
690
  controlnet_decay: float = 1.0,
 
692
  tile_width: int = 112,
693
  tile_height: int = 144,
694
  denoise_strength: float = 0.35,
695
+ num_inference_steps: int = 15,
696
  solver: str = "DDIM",
697
  ) -> Image.Image:
698
 
 
736
  toks = tokenizer.encode( prefix + text, return_tensors="pt", truncation=False)
737
  gen = model.generate(
738
  toks,
739
+ length_penalty=0.5,
740
+ num_beams=4,
741
  early_stopping=True,
742
  max_length=512
743
  )
 
745
  log(f'RET _summarize with ret as {ret}')
746
  return ret
747
 
748
+ def summarize(text, max_len=500):
749
  log(f'CALL summarize')
750
 
751
  words = text.split()
752
+ words_length = len(words)
 
 
 
 
 
 
 
 
 
753
 
754
  if words_length >= 510:
755
  while words_length >= 510:
 
762
  text = summ
763
  words_length = len(text.split())
764
 
765
+ while len(text) > max_len:
766
  summ = _summarize(text)
767
  if summ == text:
768
  return text
769
  text = summ
 
770
 
771
  log(f'RET summarize with text as {text}')
772
  return text
 
776
  return ''.join(random.choice(characters) for _ in range(length))
777
 
778
  def pipe_generate_image(p1,p2,h,w):
779
+ for img in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
 
780
  prompt=p1,
781
  negative_prompt=p2,
782
  height=h,
 
786
  num_inference_steps=image_steps,
787
  max_sequence_length=seq,
788
  generator=torch.Generator(device).manual_seed(random.randint(0, MAX_SEED))
789
+ ):
790
+ yield img
 
791
 
792
  def add_song_cover_text(img,artist,song,h,w):
793
 
 
1426
 
1427
  def translate(txt,to_lang="en",from_lang="auto"):
1428
  log(f'CALL translate')
1429
+ if from_lang == to_lang or get_language(txt) == to_lang:
1430
+ print("Translation failed!")
1431
+ return txt.strip().lower()
1432
  translator = GoogleTranslator(from_lang=from_lang,to_lang=to_lang)
1433
  translation = ""
1434
  if len(txt) > 1000:
 
1479
  pos_lyrics = pos_lyrics if pos_lyrics == "" else summarize(translate(pos_lyrics))
1480
  pos_lyrics = re.sub(r"([ \t]){1,}", " ", pos_lyrics).lower().strip()
1481
 
1482
+ neg = f"Textual, Text, Distorted, Fake, Discontinuous, Blurry, Doll-Like, Overly Plastic, Low Quality, Paint, Smoothed, Artificial, Phony, Gaudy, Digital Effects."
1483
  q = "\""
1484
  pos = f'HQ Hyper-realistic professional photograph{ pos_lyrics if pos_lyrics == "" else ": " + q + pos_lyrics + q }.'
1485
 
 
1492
  img = all_pipes(pos,neg,h,w)
1493
 
1494
  labeled_img = add_song_cover_text(img,pos_artist,pos_song,h,w)
1495
+ name = f'{generate_random_string(16)}.png'
1496
  labeled_img.save(name)
1497
 
1498
  working = False