Yaron Koresh committed on
Commit
9642724
·
verified ·
1 Parent(s): 07d7428

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -6
app.py CHANGED
@@ -13,7 +13,8 @@ from pathos.threading import ThreadPool as Pool
13
  #from diffusers.utils import export_to_gif
14
  #from huggingface_hub import hf_hub_download
15
  #from safetensors.torch import load_file
16
- from diffusers import FlaxStableDiffusionPipeline, DiffusionPipeline
 
17
  #import jax
18
  #import jax.numpy as jnp
19
 
@@ -22,6 +23,11 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
22
  #pipe = FlaxStableDiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", dtype=jnp.bfloat16, token=os.getenv("hf_token")).to(device)
23
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision="refs/pr/1", token=os.getenv("hf_token")).to(device)
24
 
 
 
 
 
 
25
  def translate(text,lang):
26
 
27
  if text == None or lang == None:
@@ -81,7 +87,7 @@ def generate_random_string(length):
81
  return ''.join(random.choice(characters) for _ in range(length))
82
 
83
  @spaces.GPU(duration=35)
84
- def Piper(_do,neg):
85
  try:
86
  retu = pipe(
87
  _do,
@@ -97,6 +103,18 @@ def Piper(_do,neg):
97
  print(e)
98
  return None
99
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  @spaces.GPU(duration=25)
101
  def tok(txt):
102
  toks = pipe.tokenizer(txt)['input_ids']
@@ -110,15 +128,24 @@ def infer(p1,p2):
110
  _do.append(f'{p1}')
111
  if p2 != "":
112
  _dont = f'{p2} where in {p1}'
113
- neg = tok(_dont)
114
  else:
115
  neg = None
116
- output = Piper('A '+" ".join(_do),neg)
117
  if output == None:
118
- return output
119
  else:
120
  output.images[0].save(name)
121
- return name
 
 
 
 
 
 
 
 
 
122
 
123
  css="""
124
  input, input::placeholder {
 
13
  #from diffusers.utils import export_to_gif
14
  #from huggingface_hub import hf_hub_download
15
  #from safetensors.torch import load_file
16
+ from diffusers import DiffusionPipeline, StableDiffusionXLImg2ImgPipeline
17
+ from diffusers.utils import load_image
18
  #import jax
19
  #import jax.numpy as jnp
20
 
 
23
  #pipe = FlaxStableDiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", dtype=jnp.bfloat16, token=os.getenv("hf_token")).to(device)
24
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16, revision="refs/pr/1", token=os.getenv("hf_token")).to(device)
25
 
26
+ pipe2 = StableDiffusionXLImg2ImgPipeline.from_pretrained(
27
+ "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
28
+ ).to(device)
29
+ pipe2.unet = torch.compile(pipe2.unet, mode="reduce-overhead", fullgraph=True)
30
+
31
  def translate(text,lang):
32
 
33
  if text == None or lang == None:
 
87
  return ''.join(random.choice(characters) for _ in range(length))
88
 
89
  @spaces.GPU(duration=35)
90
+ def Piper(_do):
91
  try:
92
  retu = pipe(
93
  _do,
 
103
  print(e)
104
  return None
105
 
106
+ @spaces.GPU(duration=75)
107
+ def Piper2(name,neg):
108
+ try:
109
+ retu = pipe2(
110
+ negative_prompt=neg,
111
+ image=name
112
+ )
113
+ return retu
114
+ except Exception as e:
115
+ print(e)
116
+ return None
117
+
118
  @spaces.GPU(duration=25)
119
  def tok(txt):
120
  toks = pipe.tokenizer(txt)['input_ids']
 
128
  _do.append(f'{p1}')
129
  if p2 != "":
130
  _dont = f'{p2} where in {p1}'
131
+ neg = _dont
132
  else:
133
  neg = None
134
+ output = Piper('A '+" ".join(_do))
135
  if output == None:
136
+ return None
137
  else:
138
  output.images[0].save(name)
139
+ if neg == None:
140
+ return name
141
+
142
+ init_image = load_image(name).convert("RGB")
143
+ output2 = Piper2(init_image,neg)
144
+ if output2 == None:
145
+ return None
146
+ else:
147
+ output2.images[0].save("_"+name)
148
+ return "_"+name
149
 
150
  css="""
151
  input, input::placeholder {