Yaron Koresh commited on
Commit
9f1f2bf
·
verified ·
1 Parent(s): 19adb45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -34
app.py CHANGED
@@ -8,6 +8,8 @@ import requests
8
  import gradio as gr
9
  import numpy as np
10
  from lxml.html import fromstring
 
 
11
  #from pathos.multiprocessing import ProcessPool as Pool
12
  #from pathos.threading import ThreadPool as Pool
13
  #from diffusers.pipelines.flux import FluxPipeline
@@ -20,32 +22,37 @@ from diffusers.utils import load_image
20
  #import jax.numpy as jnp
21
  import torch._dynamo
22
 
 
23
  torch._dynamo.config.suppress_errors = True
24
 
25
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
26
 
27
- #pipe = FlaxStableDiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", dtype=jnp.bfloat16, token=os.getenv("hf_token")).to(device)
28
- pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision="refs/pr/1", token=os.getenv("hf_token")).to(device)
29
 
30
- pipe2 = StableDiffusionXLImg2ImgPipeline.from_pretrained(
31
- "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
32
- ).to(device)
33
- pipe2.unet = torch.compile(pipe2.unet, mode="reduce-overhead", fullgraph=True)
34
-
35
- def translate(text,lang):
 
 
 
 
 
 
36
 
 
37
  if text == None or lang == None:
38
- return ""
39
-
40
  text = re.sub(f'[{string.punctuation}]', '', re.sub('[\s+]', ' ', text)).lower().strip()
41
- lang = re.sub(f'[{string.punctuation}]', '', re.sub('[\s+]', ' ', lang)).lower().strip()
42
-
43
  if text == "" or lang == "":
44
  return ""
45
-
46
  if len(text) > 38:
47
  raise Exception("Translation Error: Too long text!")
48
-
49
  user_agents = [
50
  'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36',
51
  'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36',
@@ -56,32 +63,24 @@ def translate(text,lang):
56
  padded_chars = re.sub("[(^\-)(\-$)]","",text.replace("","-").replace("- -"," ")).strip()
57
  query_text = f'Please translate {padded_chars}, into {lang}'
58
  url = f'https://www.google.com/search?q={query_text}'
59
-
60
- print(url)
61
-
62
  resp = requests.get(
63
  url = url,
64
  headers = {
65
  'User-Agent': random.choice(user_agents)
66
  }
67
  )
68
-
69
  content = resp.content
70
  html = fromstring(content)
71
-
72
  translated = text
73
-
74
  try:
75
  src_lang = html.xpath('//*[@class="source-language"]')[0].text_content().lower().strip()
76
  trgt_lang = html.xpath('//*[@class="target-language"]')[0].text_content().lower().strip()
77
  src_text = html.xpath('//*[@id="tw-source-text"]/*')[0].text_content().lower().strip()
78
  trgt_text = html.xpath('//*[@id="tw-target-text"]/*')[0].text_content().lower().strip()
79
-
80
  if trgt_lang == lang:
81
  translated = trgt_text
82
  except:
83
  print(f'Translation Warning: Failed To Translate!')
84
-
85
  ret = re.sub(f'[{string.punctuation}]', '', re.sub('[\s+]', ' ', translated)).lower().strip()
86
  print(ret)
87
  return ret
@@ -92,6 +91,7 @@ def generate_random_string(length):
92
 
93
  @spaces.GPU(duration=35)
94
  def Piper(_do):
 
95
  try:
96
  retu = pipe(
97
  _do,
@@ -108,6 +108,7 @@ def Piper(_do):
108
 
109
  @spaces.GPU(duration=35)
110
  def Piper2(img,posi,neg):
 
111
  try:
112
  retu = pipe2(
113
  prompt=posi,
@@ -119,7 +120,7 @@ def Piper2(img,posi,neg):
119
  print(e)
120
  return None
121
 
122
- @spaces.GPU(duration=25)
123
  def tok(txt):
124
  toks = pipe.tokenizer(txt)['input_ids']
125
  print(toks)
@@ -143,8 +144,8 @@ def infer(p1,p2):
143
  if neg == None:
144
  return name
145
 
146
- init_image = load_image(name).convert("RGB")
147
- output2 = Piper2(init_image,p1,neg)
148
  if output2 == None:
149
  return None
150
  else:
@@ -214,6 +215,8 @@ with gr.Blocks(theme=gr.themes.Soft(),css=css,js=js) as demo:
214
  run_button = gr.Button("START",elem_classes="btn",scale=0)
215
  with gr.Row():
216
  result.append(gr.Image(interactive=False,elem_classes="image-container", label="Result", show_label=False, type='filepath', show_share_button=False))
 
 
217
 
218
  def _ret(p1,p2):
219
 
@@ -227,14 +230,16 @@ with gr.Blocks(theme=gr.themes.Soft(),css=css,js=js) as demo:
227
  p1_en = translate(p1,"english")
228
  p2_en = translate(p2,"english")
229
 
230
- #ln = len(result)
231
- #idxs = list(range(ln))
232
- #p1s = [p1_en for _ in idxs]
233
- #p2s = [p2_en for _ in idxs]
234
- #pool = Pool(ln)
235
- #return list( pool.imap( _ret, p1s, p2s ) )
236
-
237
- return list( _ret(p1_en,p2_en) )
 
 
238
 
239
  run_button.click(fn=_rets,inputs=[prompt,prompt2],outputs=result)
240
 
 
8
  import gradio as gr
9
  import numpy as np
10
  from lxml.html import fromstring
11
+ from transformers import pipeline
12
+ from torch.multiprocessing import Pool, Process, set_start_method
13
  #from pathos.multiprocessing import ProcessPool as Pool
14
  #from pathos.threading import ThreadPool as Pool
15
  #from diffusers.pipelines.flux import FluxPipeline
 
22
  #import jax.numpy as jnp
23
  import torch._dynamo
24
 
25
+ set_start_method("spawn", force=True)
26
  torch._dynamo.config.suppress_errors = True
27
 
28
+ #pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision="refs/pr/1", token=os.getenv("hf_token")).to(device)
29
+ #pipe2 = StableDiffusionXLImg2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to(device)
30
+ #pipe2.unet = torch.compile(pipe2.unet, mode="reduce-overhead", fullgraph=True)
31
 
32
+ PIPE = None
 
33
 
34
+ def pipe_t2i():
35
+ global PIPE
36
+ if PIPE is None:
37
+ PIPE = pipeline("text-to-image", model="black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision="refs/pr/1", tokenizer="black-forest-labs/FLUX.1-schnell", device=-1, token=os.getenv("hf_token"))
38
+ return PIPE
39
+
40
+ def pipe_i2i():
41
+ global PIPE
42
+ if PIPE is None:
43
+ PIPE = pipeline("image-to-image", model="stabilityai/stable-diffusion-xl-refiner-1.0", tokenizer="stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, device=-1, variant="fp16", use_safetensors=True)
44
+ PIPE.unet = torch.compile(PIPE.unet, mode="reduce-overhead", fullgraph=True)
45
+ return PIPE
46
 
47
+ def translate(text,lang):
48
  if text == None or lang == None:
49
+ return ""
 
50
  text = re.sub(f'[{string.punctuation}]', '', re.sub('[\s+]', ' ', text)).lower().strip()
51
+ lang = re.sub(f'[{string.punctuation}]', '', re.sub('[\s+]', ' ', lang)).lower().strip()
 
52
  if text == "" or lang == "":
53
  return ""
 
54
  if len(text) > 38:
55
  raise Exception("Translation Error: Too long text!")
 
56
  user_agents = [
57
  'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36',
58
  'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/126.0.0.0 Safari/537.36',
 
63
  padded_chars = re.sub("[(^\-)(\-$)]","",text.replace("","-").replace("- -"," ")).strip()
64
  query_text = f'Please translate {padded_chars}, into {lang}'
65
  url = f'https://www.google.com/search?q={query_text}'
 
 
 
66
  resp = requests.get(
67
  url = url,
68
  headers = {
69
  'User-Agent': random.choice(user_agents)
70
  }
71
  )
 
72
  content = resp.content
73
  html = fromstring(content)
 
74
  translated = text
 
75
  try:
76
  src_lang = html.xpath('//*[@class="source-language"]')[0].text_content().lower().strip()
77
  trgt_lang = html.xpath('//*[@class="target-language"]')[0].text_content().lower().strip()
78
  src_text = html.xpath('//*[@id="tw-source-text"]/*')[0].text_content().lower().strip()
79
  trgt_text = html.xpath('//*[@id="tw-target-text"]/*')[0].text_content().lower().strip()
 
80
  if trgt_lang == lang:
81
  translated = trgt_text
82
  except:
83
  print(f'Translation Warning: Failed To Translate!')
 
84
  ret = re.sub(f'[{string.punctuation}]', '', re.sub('[\s+]', ' ', translated)).lower().strip()
85
  print(ret)
86
  return ret
 
91
 
92
  @spaces.GPU(duration=35)
93
  def Piper(_do):
94
+ pipe = pipe_t2i()
95
  try:
96
  retu = pipe(
97
  _do,
 
108
 
109
  @spaces.GPU(duration=35)
110
  def Piper2(img,posi,neg):
111
+ pipe = pipe_i2i()
112
  try:
113
  retu = pipe2(
114
  prompt=posi,
 
120
  print(e)
121
  return None
122
 
123
+ @spaces.GPU(duration=35)
124
  def tok(txt):
125
  toks = pipe.tokenizer(txt)['input_ids']
126
  print(toks)
 
144
  if neg == None:
145
  return name
146
 
147
+ img = load_image(name).convert("RGB")
148
+ output2 = Piper2(img,p1,neg)
149
  if output2 == None:
150
  return None
151
  else:
 
215
  run_button = gr.Button("START",elem_classes="btn",scale=0)
216
  with gr.Row():
217
  result.append(gr.Image(interactive=False,elem_classes="image-container", label="Result", show_label=False, type='filepath', show_share_button=False))
218
+ result.append(gr.Image(interactive=False,elem_classes="image-container", label="Result", show_label=False, type='filepath', show_share_button=False))
219
+ result.append(gr.Image(interactive=False,elem_classes="image-container", label="Result", show_label=False, type='filepath', show_share_button=False))
220
 
221
  def _ret(p1,p2):
222
 
 
230
  p1_en = translate(p1,"english")
231
  p2_en = translate(p2,"english")
232
 
233
+ ln = len(result)
234
+ idxs = list(range(ln))
235
+ p1s = [p1_en for _ in idxs]
236
+ p2s = [p2_en for _ in idxs]
237
+ pool = Pool(ln)
238
+ lst = list( pool.imap( _ret, p1s, p2s ) )
239
+ pool.clear()
240
+ return lst
241
+
242
+ #return list( _ret(p1_en,p2_en) )
243
 
244
  run_button.click(fn=_rets,inputs=[prompt,prompt2],outputs=result)
245