Yaron Koresh commited on
Commit
8ecb267
·
verified ·
1 Parent(s): f44b741

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -21
app.py CHANGED
@@ -9,7 +9,7 @@ import gradio as gr
9
  import numpy as np
10
  from lxml.html import fromstring
11
  from transformers import pipeline
12
- from torch.multiprocessing import Pool, Process, set_start_method
13
  #from pathos.multiprocessing import ProcessPool as Pool
14
  #from pathos.threading import ThreadPool as Pool
15
  #from diffusers.pipelines.flux import FluxPipeline
@@ -20,10 +20,15 @@ from diffusers import DiffusionPipeline, StableDiffusionXLImg2ImgPipeline
20
  from diffusers.utils import load_image
21
  #import jax
22
  #import jax.numpy as jnp
23
- import torch._dynamo
24
 
25
- set_start_method("spawn", force=True)
26
- torch._dynamo.config.suppress_errors = True
 
 
 
 
 
 
27
 
28
  #pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision="refs/pr/1", token=os.getenv("hf_token")).to(device)
29
  #pipe2 = StableDiffusionXLImg2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to(device)
@@ -126,7 +131,9 @@ def tok(txt):
126
  print(toks)
127
  return toks
128
 
129
- def infer(p1,p2):
 
 
130
  name = generate_random_string(12)+".png"
131
  _do = ['beautiful', 'playful', 'photographed', 'realistic', 'dynamic poze', 'deep field', 'reasonable coloring', 'rough texture', 'best quality', 'focused']
132
  if p1 != "":
@@ -218,29 +225,24 @@ with gr.Blocks(theme=gr.themes.Soft(),css=css,js=js) as demo:
218
  result.append(gr.Image(interactive=False,elem_classes="image-container", label="Result", show_label=False, type='filepath', show_share_button=False))
219
  result.append(gr.Image(interactive=False,elem_classes="image-container", label="Result", show_label=False, type='filepath', show_share_button=False))
220
 
221
- def _ret(p):
222
-
223
- print(f'Starting!')
224
- v = infer(p["a"],p["b"])
225
- print(f'Finished!')
226
- return v
227
-
228
- def _rets(p1,p2):
229
 
230
  p1_en = translate(p1,"english")
231
  p2_en = translate(p2,"english")
232
 
233
  p = {"a":p1_en,"b":p2_en}
234
-
235
  ln = len(result)
236
  rng = range(ln)
237
- p_arr = [p for _ in rng]
238
- pool = Pool(processes=ln)
239
- lst = list( pool.imap( _ret, p_arr ) )
240
- pool.clear()
241
- return lst
 
 
 
242
 
243
  #return list( _ret(p1_en,p2_en) )
244
 
245
- run_button.click(fn=_rets,inputs=[prompt,prompt2],outputs=result)
246
- demo.queue().launch(server_port=7861)
 
9
  import numpy as np
10
  from lxml.html import fromstring
11
  from transformers import pipeline
12
+ from torch import multiprocessing as mp, _dynamo
13
  #from pathos.multiprocessing import ProcessPool as Pool
14
  #from pathos.threading import ThreadPool as Pool
15
  #from diffusers.pipelines.flux import FluxPipeline
 
20
  from diffusers.utils import load_image
21
  #import jax
22
  #import jax.numpy as jnp
 
23
 
24
+ def init_pool(q):
25
+ global b
26
+ b = q
27
+
28
+ _dynamo.config.suppress_errors = True
29
+
30
+ mp.set_start_method("spawn", force=True)
31
+ b=mp.Queue()
32
 
33
  #pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision="refs/pr/1", token=os.getenv("hf_token")).to(device)
34
  #pipe2 = StableDiffusionXLImg2ImgPipeline.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to(device)
 
131
  print(toks)
132
  return toks
133
 
134
+ def infer():
135
+ p1 = b["a"]
136
+ p2 = b["b"]
137
  name = generate_random_string(12)+".png"
138
  _do = ['beautiful', 'playful', 'photographed', 'realistic', 'dynamic poze', 'deep field', 'reasonable coloring', 'rough texture', 'best quality', 'focused']
139
  if p1 != "":
 
225
  result.append(gr.Image(interactive=False,elem_classes="image-container", label="Result", show_label=False, type='filepath', show_share_button=False))
226
  result.append(gr.Image(interactive=False,elem_classes="image-container", label="Result", show_label=False, type='filepath', show_share_button=False))
227
 
228
+ def main(p1,p2):
 
 
 
 
 
 
 
229
 
230
  p1_en = translate(p1,"english")
231
  p2_en = translate(p2,"english")
232
 
233
  p = {"a":p1_en,"b":p2_en}
 
234
  ln = len(result)
235
  rng = range(ln)
236
+
237
+ for _ in rng:
238
+ b.put(p)
239
+
240
+ with mp.Pool(ln, initializer=init_pool, initargs=(b,)) as pool:
241
+ out = pool.map(infer)
242
+ pool.clear()
243
+ return list(out.get())
244
 
245
  #return list( _ret(p1_en,p2_en) )
246
 
247
+ run_button.click(fn=main,inputs=[prompt,prompt2],outputs=result)
248
+ demo.queue().launch()