Yaron Koresh commited on
Commit
bf9773d
·
verified ·
1 Parent(s): aa7cc5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -12
app.py CHANGED
@@ -1,6 +1,5 @@
1
 
2
- # built-in
3
-
4
  from collections import namedtuple
5
  from inspect import signature
6
  import os
@@ -555,7 +554,7 @@ def get_tensor_length(tensor):
555
  ret = ret * num
556
  return ret
557
 
558
- def summarize_text(
559
  text, max_len=20, min_len=10
560
  ):
561
  log(f'CALL summarize_text')
@@ -564,12 +563,12 @@ def summarize_text(
564
  while get_tensor_length(inputs) > max_len:
565
  print(f'DBG summarize_text 1 {i}')
566
  outputs = model.generate(
567
- inputs[:512],
568
- max_length=max_len,
569
- min_length=min_len,
570
  length_penalty=2.0,
571
- num_beams=4,
572
- early_stopping=True
 
 
573
  )
574
  inputs = torch.tensor([[*list(outputs[0]), *list(inputs[0][512:])]])
575
  i = i + 1
@@ -631,6 +630,25 @@ def all_pipes(pos,neg,artist,song):
631
 
632
  return imgs
633
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
634
  @spaces.GPU(duration=300)
635
  def handle_generation(artist,song,genre,lyrics):
636
 
@@ -644,11 +662,11 @@ def handle_generation(artist,song,genre,lyrics):
644
  pos_genre = ' '.join(word[0].upper() + word[1:] for word in pos_genre.split())
645
 
646
  pos_lyrics = re.sub(f'[{punctuation}]', '', re.sub("([ \t\n]){1,}", " ", lyrics)).lower().strip()
647
- pos_lyrics_sum = pos_lyrics if pos_lyrics == "" else summarize_text(pos_lyrics)
648
 
649
  neg = f"Sexuality, Humanity, Textual, Labeled, Distorted, Discontinuous, Blurry, Doll-Like, Overly Plastic, Low-Quality, Painted, Smoothed, Artificial, Phony, Gaudy, Digital Effects."
650
  q = "\""
651
- pos = f'HQ Hyper-realistic { pos_genre } song "{ pos_song }"{ pos_lyrics_sum if pos_lyrics_sum == "" else ": " + pos_lyrics_sum }.'
652
 
653
  print(f"""
654
  Positive: {pos}
@@ -693,13 +711,13 @@ if __name__ == "__main__":
693
  max_lines=1
694
  )
695
  genre = gr.Textbox(
696
- placeholder="Genre (English)",
697
  value="",
698
  container=False,
699
  max_lines=1
700
  )
701
  lyrics = gr.Textbox(
702
- placeholder="Lyrics (English)",
703
  value="",
704
  container=False,
705
  max_lines=1
 
1
 
2
+ from langdetect import detect as get_language
 
3
  from collections import namedtuple
4
  from inspect import signature
5
  import os
 
554
  ret = ret * num
555
  return ret
556
 
557
+ def summarize(
558
  text, max_len=20, min_len=10
559
  ):
560
  log(f'CALL summarize_text')
 
563
  while get_tensor_length(inputs) > max_len:
564
  print(f'DBG summarize_text 1 {i}')
565
  outputs = model.generate(
566
+ inputs[0][:512],
 
 
567
  length_penalty=2.0,
568
+ num_beams=max(8,get_tensor_length(inputs)),
569
+ early_stopping=True,
570
+ max_length=max( get_tensor_length(inputs) // 4 , max_len ),
571
+ min_length=min_len
572
  )
573
  inputs = torch.tensor([[*list(outputs[0]), *list(inputs[0][512:])]])
574
  i = i + 1
 
630
 
631
  return imgs
632
 
633
+ def translate(txt,to_lang="en",from_lang=False):
634
+ log(f'CALL translate')
635
+ if not from_lang:
636
+ from_lang = get_language(txt)
637
+ if(from_lang == to_lang):
638
+ log(f'RET translate with txt as {txt}')
639
+ return txt
640
+ inputs = tokenizer.encode(f"translate {from_lang} to {to_lang}: " + text, return_tensors="pt", max_length=float('inf'), truncation=False)
641
+ chunks_length = math.ceil(get_tensor_length(inputs) / 512):
642
+ ret = ""
643
+ for index in range(chunks_length):
644
+ ret = ret + ("" if ret == "" else " ") + tokenizer.decode(
645
+ model.generate(
646
+ inputs[0][ index*512:index*512+512 ]
647
+ )[0]
648
+ )
649
+ log(f'RET translate with ret as {ret}')
650
+ return ret
651
+
652
  @spaces.GPU(duration=300)
653
  def handle_generation(artist,song,genre,lyrics):
654
 
 
662
  pos_genre = ' '.join(word[0].upper() + word[1:] for word in pos_genre.split())
663
 
664
  pos_lyrics = re.sub(f'[{punctuation}]', '', re.sub("([ \t\n]){1,}", " ", lyrics)).lower().strip()
665
+ pos_lyrics_sum = pos_lyrics if pos_lyrics == "" else summarize(pos_lyrics)
666
 
667
  neg = f"Sexuality, Humanity, Textual, Labeled, Distorted, Discontinuous, Blurry, Doll-Like, Overly Plastic, Low-Quality, Painted, Smoothed, Artificial, Phony, Gaudy, Digital Effects."
668
  q = "\""
669
+ pos = f'HQ Hyper-realistic { translate(pos_genre) } song "{ translate(pos_song) }"{ pos_lyrics_sum if pos_lyrics_sum == "" else ": " + translate(pos_lyrics_sum) }.'
670
 
671
  print(f"""
672
  Positive: {pos}
 
711
  max_lines=1
712
  )
713
  genre = gr.Textbox(
714
+ placeholder="Genre",
715
  value="",
716
  container=False,
717
  max_lines=1
718
  )
719
  lyrics = gr.Textbox(
720
+ placeholder="Lyrics",
721
  value="",
722
  container=False,
723
  max_lines=1