ginipick commited on
Commit
9840d42
·
verified ·
1 Parent(s): 9469372

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -8
app.py CHANGED
@@ -24,8 +24,6 @@ from transformers import T5EncoderModel, T5Tokenizer
24
  # from optimum.quanto import freeze, qfloat8, quantize
25
  from transformers import pipeline
26
 
27
- ko_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
28
- ja_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en")
29
 
30
  class HFEmbedder(nn.Module):
31
  def __init__(self, version: str, max_length: int, **hf_kwargs):
@@ -749,9 +747,6 @@ model = Flux().to(dtype=torch.bfloat16, device="cuda")
749
  result = model.load_state_dict(sd)
750
  model_zero_init = False
751
 
752
- # model = Flux().to(dtype=torch.bfloat16, device="cuda")
753
- # result = model.load_state_dict(load_file("/storage/dev/nyanko/flux-dev/flux1-dev.sft"))
754
-
755
 
756
  @spaces.GPU
757
  @torch.no_grad()
@@ -762,14 +757,28 @@ def generate_image(
762
  ):
763
  translated_prompt = prompt
764
 
765
- # 한글 또는 일본어 문자 감지
766
  def contains_korean(text):
767
  return any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in text)
768
 
769
  def contains_japanese(text):
770
  return any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' or '\u4E00' <= c <= '\u9FFF' for c in text)
771
 
772
- # 한글이나 일본어가 있으면 번역
 
 
 
 
 
 
 
 
 
 
 
 
 
 
773
  if contains_korean(prompt):
774
  translated_prompt = ko_translator(prompt, max_length=512)[0]['translation_text']
775
  print(f"Translated Korean prompt: {translated_prompt}")
@@ -778,6 +787,14 @@ def generate_image(
778
  translated_prompt = ja_translator(prompt, max_length=512)[0]['translation_text']
779
  print(f"Translated Japanese prompt: {translated_prompt}")
780
  prompt = translated_prompt
 
 
 
 
 
 
 
 
781
 
782
  if seed == 0:
783
  seed = int(random.random() * 1000000)
@@ -785,7 +802,6 @@ def generate_image(
785
  device = "cuda" if torch.cuda.is_available() else "cpu"
786
  torch_device = torch.device(device)
787
 
788
-
789
 
790
  global model, model_zero_init
791
  if not model_zero_init:
 
24
  # from optimum.quanto import freeze, qfloat8, quantize
25
  from transformers import pipeline
26
 
 
 
27
 
28
  class HFEmbedder(nn.Module):
29
  def __init__(self, version: str, max_length: int, **hf_kwargs):
 
747
  result = model.load_state_dict(sd)
748
  model_zero_init = False
749
 
 
 
 
750
 
751
  @spaces.GPU
752
  @torch.no_grad()
 
757
  ):
758
  translated_prompt = prompt
759
 
760
+ # 한글, 일본어, 중국어, 스페인어 문자 감지
761
  def contains_korean(text):
762
  return any('\u3131' <= c <= '\u318E' or '\uAC00' <= c <= '\uD7A3' for c in text)
763
 
764
  def contains_japanese(text):
765
  return any('\u3040' <= c <= '\u309F' or '\u30A0' <= c <= '\u30FF' or '\u4E00' <= c <= '\u9FFF' for c in text)
766
 
767
+ def contains_chinese(text):
768
+ return any('\u4e00' <= c <= '\u9fff' for c in text)
769
+
770
+ def contains_spanish(text):
771
+ # 스페인어 특수 문자 포함 확인
772
+ spanish_chars = set('áéíóúüñ¿¡')
773
+ return any(c in spanish_chars for c in text.lower())
774
+
775
+ # 번역기 추가
776
+ ko_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")
777
+ ja_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en")
778
+ zh_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en")
779
+ es_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-es-en")
780
+
781
+ # 각 언어 감지 후 번역
782
  if contains_korean(prompt):
783
  translated_prompt = ko_translator(prompt, max_length=512)[0]['translation_text']
784
  print(f"Translated Korean prompt: {translated_prompt}")
 
787
  translated_prompt = ja_translator(prompt, max_length=512)[0]['translation_text']
788
  print(f"Translated Japanese prompt: {translated_prompt}")
789
  prompt = translated_prompt
790
+ elif contains_chinese(prompt):
791
+ translated_prompt = zh_translator(prompt, max_length=512)[0]['translation_text']
792
+ print(f"Translated Chinese prompt: {translated_prompt}")
793
+ prompt = translated_prompt
794
+ elif contains_spanish(prompt):
795
+ translated_prompt = es_translator(prompt, max_length=512)[0]['translation_text']
796
+ print(f"Translated Spanish prompt: {translated_prompt}")
797
+ prompt = translated_prompt
798
 
799
  if seed == 0:
800
  seed = int(random.random() * 1000000)
 
802
  device = "cuda" if torch.cuda.is_available() else "cpu"
803
  torch_device = torch.device(device)
804
 
 
805
 
806
  global model, model_zero_init
807
  if not model_zero_init: