1inkusFace commited on
Commit
7bfef4d
·
verified ·
1 Parent(s): 5e1c614

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -45
app.py CHANGED
@@ -51,43 +51,32 @@ if GCS_SA_KEY and GCS_BUCKET_NAME:
51
  except Exception as e:
52
  print(f"❌ Failed to initialize GCS client: {e}")
53
 
54
- # --- New GCS Upload Function (runs on CPU) ---
55
  def upload_to_gcs(image_object, filename):
56
- """Uploads a PIL Image object to GCS from memory."""
57
  if not gcs_client:
58
  print("⚠️ GCS client not initialized. Skipping upload.")
59
  return
60
-
61
  try:
62
  print(f"--> Starting GCS upload for {filename}...")
63
  bucket = gcs_client.bucket(GCS_BUCKET_NAME)
64
  blob = bucket.blob(f"stablediff/{filename}")
65
-
66
- # Convert PIL image to bytes stream
67
  img_byte_arr = io.BytesIO()
68
  image_object.save(img_byte_arr, format='PNG', optimize=False, compress_level=0)
69
  img_byte_arr = img_byte_arr.getvalue()
70
-
71
- # Upload from the in-memory string
72
  blob.upload_from_string(img_byte_arr, content_type='image/png')
73
  print(f"✅ Successfully uploaded {filename} to GCS.")
74
-
75
  except Exception as e:
76
  print(f"❌ An error occurred during GCS upload: {e}")
77
 
78
- # --- Model Loading ---
79
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
80
- hftoken = os.getenv("HF_AUTH_TOKEN")
81
 
82
  pipe = StableDiffusion3Pipeline.from_pretrained(
83
  "ford442/stable-diffusion-3.5-large-bf16",
84
  trust_remote_code=True,
85
  transformer=None, # Load transformer separately
86
- use_safetensors=True,
87
- # token=hftoken
88
  )
89
- # Load transformer separately and move to device with specified dtype
90
- ll_transformer=SD3Transformer2DModel.from_pretrained("ford442/stable-diffusion-3.5-large-bf16", subfolder='transformer', token=hftoken).to(device, dtype=torch.bfloat16)
91
  pipe.transformer=ll_transformer
92
  pipe.load_lora_weights("ford442/sdxl-vae-bf16", weight_name="LoRA/UltraReal.safetensors")
93
  pipe.to(device=device, dtype=torch.bfloat16)
@@ -97,13 +86,10 @@ upscaler_2 = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(devic
97
  MAX_SEED = np.iinfo(np.int32).max
98
  MAX_IMAGE_SIZE = 4096
99
 
100
- # --- Refactored GPU Inference Function ---
101
- @spaces.GPU(duration=120)
102
- def generate_images(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, progress=gr.Progress(track_tqdm=True)):
103
- """Generates the main image and its upscaled version on the GPU."""
104
  seed = random.randint(0, MAX_SEED)
105
  generator = torch.Generator(device=device).manual_seed(seed)
106
-
107
  print('-- generating image --')
108
  sd_image = pipe(
109
  prompt=prompt, prompt_2=prompt, prompt_3=prompt,
@@ -113,48 +99,102 @@ def generate_images(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, hei
113
  max_sequence_length=512
114
  ).images[0]
115
  print('-- got image --')
116
-
117
  with torch.no_grad():
118
  upscale = upscaler_2(sd_image, tiling=True, tile_width=256, tile_height=256)
119
  upscale2 = upscaler_2(upscale, tiling=True, tile_width=256, tile_height=256)
120
  print('-- got upscaled image --')
121
  downscaled_upscale = upscale2.resize((upscale2.width // 4, upscale2.height // 4), Image.LANCZOS)
122
-
123
  return sd_image, downscaled_upscale, prompt
124
 
125
- # --- Main Gradio Handler (runs on CPU) ---
126
- def run_inference_and_upload(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, save_consent, progress=gr.Progress(track_tqdm=True)):
127
- """
128
- Orchestrates the process: calls the GPU function, then handles the upload if consented.
129
- """
130
- # 1. Call the GPU-bound function to get the images
131
- sd_image, upscaled_image, expanded_prompt = generate_images(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, progress)
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
- # 2. If user consented, start uploads in background threads
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  if save_consent:
135
  print("✅ User consented to save. Preparing uploads...")
136
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
137
  sd_filename = f"sd35ll_{timestamp}.png"
138
  upscale_filename = f"sd35ll_upscale_{timestamp}.png"
139
-
140
- # Create and start threads for each upload
141
  sd_thread = threading.Thread(target=upload_to_gcs, args=(sd_image, sd_filename))
142
  upscale_thread = threading.Thread(target=upload_to_gcs, args=(upscaled_image, upscale_filename))
143
-
144
  sd_thread.start()
145
  upscale_thread.start()
146
  else:
147
  print("ℹ️ User did not consent to save. Skipping upload.")
 
148
 
149
- # 3. Return the primary image to the UI immediately
 
 
 
 
 
 
 
 
 
 
 
 
150
  return sd_image, expanded_prompt
151
 
152
- # --- Gradio UI Definition ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  css = """
154
  #col-container {margin: 0 auto;max-width: 640px;}
155
  body{background-color: blue;}
156
  """
157
-
158
  with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
159
  with gr.Column(elem_id="col-container"):
160
  gr.Markdown(" # StableDiffusion 3.5 Large with UltraReal lora test")
@@ -164,18 +204,15 @@ with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
164
  label="Prompt", show_label=False, max_lines=1,
165
  placeholder="Enter your prompt", container=False,
166
  )
167
- # Use a single run button for simplicity or keep multiple if durations are critical
168
- run_button = gr.Button("Run", scale=0, variant="primary")
169
-
170
  result = gr.Image(label="Result", show_label=False, type="pil")
171
-
172
- # --- New Consent Checkbox ---
173
  save_consent_checkbox = gr.Checkbox(
174
  label="✅ Anonymously upload result to a public gallery",
175
- value=False, # Default to not uploading
176
  info="Check this box to help us by contributing your image."
177
  )
178
-
179
  with gr.Accordion("Advanced Settings", open=True):
180
  negative_prompt_1 = gr.Text(label="Negative prompt 1", max_lines=1, placeholder="Enter a negative prompt", value="bad anatomy, poorly drawn hands, distorted face, blurry, out of frame, low resolution, grainy, pixelated, disfigured, mutated, extra limbs, bad composition")
181
  negative_prompt_2 = gr.Text(label="Negative prompt 2", max_lines=1, placeholder="Enter a second negative prompt", value="unrealistic, cartoon, anime, sketch, painting, drawing, illustration, graphic, digital art, render, 3d, blurry, deformed, disfigured, poorly drawn, bad anatomy, mutated, extra limbs, ugly, out of frame, bad composition, low resolution, grainy, pixelated, noisy, oversaturated, undersaturated, (worst quality, low quality:1.3), (bad hands, missing fingers:1.2)")
@@ -187,9 +224,8 @@ with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
187
  guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=30.0, step=0.1, value=4.2)
188
  num_inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=150, step=1, value=60)
189
 
190
- # Connect the button to the main handler function
191
- run_button.click(
192
- fn=run_inference_and_upload,
193
  inputs=[
194
  prompt,
195
  negative_prompt_1,
@@ -204,5 +240,38 @@ with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
204
  outputs=[result, expanded_prompt_output],
205
  )
206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  if __name__ == "__main__":
208
  demo.launch()
 
51
  except Exception as e:
52
  print(f"❌ Failed to initialize GCS client: {e}")
53
 
 
54
  def upload_to_gcs(image_object, filename):
 
55
  if not gcs_client:
56
  print("⚠️ GCS client not initialized. Skipping upload.")
57
  return
 
58
  try:
59
  print(f"--> Starting GCS upload for {filename}...")
60
  bucket = gcs_client.bucket(GCS_BUCKET_NAME)
61
  blob = bucket.blob(f"stablediff/{filename}")
 
 
62
  img_byte_arr = io.BytesIO()
63
  image_object.save(img_byte_arr, format='PNG', optimize=False, compress_level=0)
64
  img_byte_arr = img_byte_arr.getvalue()
 
 
65
  blob.upload_from_string(img_byte_arr, content_type='image/png')
66
  print(f"✅ Successfully uploaded {filename} to GCS.")
 
67
  except Exception as e:
68
  print(f"❌ An error occurred during GCS upload: {e}")
69
 
 
70
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
71
 
72
  pipe = StableDiffusion3Pipeline.from_pretrained(
73
  "ford442/stable-diffusion-3.5-large-bf16",
74
  trust_remote_code=True,
75
  transformer=None, # Load transformer separately
76
+ use_safetensors=True
 
77
  )
78
+
79
+ ll_transformer=SD3Transformer2DModel.from_pretrained("ford442/stable-diffusion-3.5-large-bf16", subfolder='transformer').to(device, dtype=torch.bfloat16)
80
  pipe.transformer=ll_transformer
81
  pipe.load_lora_weights("ford442/sdxl-vae-bf16", weight_name="LoRA/UltraReal.safetensors")
82
  pipe.to(device=device, dtype=torch.bfloat16)
 
86
  MAX_SEED = np.iinfo(np.int32).max
87
  MAX_IMAGE_SIZE = 4096
88
 
89
+ @spaces.GPU(duration=45)
90
+ def generate_images_30(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, progress=gr.Progress(track_tqdm=True)):
 
 
91
  seed = random.randint(0, MAX_SEED)
92
  generator = torch.Generator(device=device).manual_seed(seed)
 
93
  print('-- generating image --')
94
  sd_image = pipe(
95
  prompt=prompt, prompt_2=prompt, prompt_3=prompt,
 
99
  max_sequence_length=512
100
  ).images[0]
101
  print('-- got image --')
 
102
  with torch.no_grad():
103
  upscale = upscaler_2(sd_image, tiling=True, tile_width=256, tile_height=256)
104
  upscale2 = upscaler_2(upscale, tiling=True, tile_width=256, tile_height=256)
105
  print('-- got upscaled image --')
106
  downscaled_upscale = upscale2.resize((upscale2.width // 4, upscale2.height // 4), Image.LANCZOS)
 
107
  return sd_image, downscaled_upscale, prompt
108
 
109
+ @spaces.GPU(duration=70)
110
+ def generate_images_60(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, progress=gr.Progress(track_tqdm=True)):
111
+ seed = random.randint(0, MAX_SEED)
112
+ generator = torch.Generator(device=device).manual_seed(seed)
113
+ print('-- generating image --')
114
+ sd_image = pipe(
115
+ prompt=prompt, prompt_2=prompt, prompt_3=prompt,
116
+ negative_prompt=neg_prompt_1, negative_prompt_2=neg_prompt_2, negative_prompt_3=neg_prompt_3,
117
+ guidance_scale=guidance, num_inference_steps=steps,
118
+ width=width, height=height, generator=generator,
119
+ max_sequence_length=512
120
+ ).images[0]
121
+ print('-- got image --')
122
+ with torch.no_grad():
123
+ upscale = upscaler_2(sd_image, tiling=True, tile_width=256, tile_height=256)
124
+ upscale2 = upscaler_2(upscale, tiling=True, tile_width=256, tile_height=256)
125
+ print('-- got upscaled image --')
126
+ downscaled_upscale = upscale2.resize((upscale2.width // 4, upscale2.height // 4), Image.LANCZOS)
127
+ return sd_image, downscaled_upscale, prompt
128
 
129
+ @spaces.GPU(duration=110)
130
+ def generate_images_100(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, progress=gr.Progress(track_tqdm=True)):
131
+ seed = random.randint(0, MAX_SEED)
132
+ generator = torch.Generator(device=device).manual_seed(seed)
133
+ print('-- generating image --')
134
+ sd_image = pipe(
135
+ prompt=prompt, prompt_2=prompt, prompt_3=prompt,
136
+ negative_prompt=neg_prompt_1, negative_prompt_2=neg_prompt_2, negative_prompt_3=neg_prompt_3,
137
+ guidance_scale=guidance, num_inference_steps=steps,
138
+ width=width, height=height, generator=generator,
139
+ max_sequence_length=512
140
+ ).images[0]
141
+ print('-- got image --')
142
+ with torch.no_grad():
143
+ upscale = upscaler_2(sd_image, tiling=True, tile_width=256, tile_height=256)
144
+ upscale2 = upscaler_2(upscale, tiling=True, tile_width=256, tile_height=256)
145
+ print('-- got upscaled image --')
146
+ downscaled_upscale = upscale2.resize((upscale2.width // 4, upscale2.height // 4), Image.LANCZOS)
147
+ return sd_image, downscaled_upscale, prompt
148
+
149
+ def run_inference_and_upload_30(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, save_consent, progress=gr.Progress(track_tqdm=True)):
150
+ sd_image, upscaled_image, expanded_prompt = generate_images_30(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, progress)
151
  if save_consent:
152
  print("✅ User consented to save. Preparing uploads...")
153
  timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
154
  sd_filename = f"sd35ll_{timestamp}.png"
155
  upscale_filename = f"sd35ll_upscale_{timestamp}.png"
 
 
156
  sd_thread = threading.Thread(target=upload_to_gcs, args=(sd_image, sd_filename))
157
  upscale_thread = threading.Thread(target=upload_to_gcs, args=(upscaled_image, upscale_filename))
 
158
  sd_thread.start()
159
  upscale_thread.start()
160
  else:
161
  print("ℹ️ User did not consent to save. Skipping upload.")
162
+ return sd_image, expanded_prompt
163
 
164
+ def run_inference_and_upload_60(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, save_consent, progress=gr.Progress(track_tqdm=True)):
165
+ sd_image, upscaled_image, expanded_prompt = generate_images_60(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, progress)
166
+ if save_consent:
167
+ print("✅ User consented to save. Preparing uploads...")
168
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
169
+ sd_filename = f"sd35ll_{timestamp}.png"
170
+ upscale_filename = f"sd35ll_upscale_{timestamp}.png"
171
+ sd_thread = threading.Thread(target=upload_to_gcs, args=(sd_image, sd_filename))
172
+ upscale_thread = threading.Thread(target=upload_to_gcs, args=(upscaled_image, upscale_filename))
173
+ sd_thread.start()
174
+ upscale_thread.start()
175
+ else:
176
+ print("ℹ️ User did not consent to save. Skipping upload.")
177
  return sd_image, expanded_prompt
178
 
179
+ def run_inference_and_upload_100(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, save_consent, progress=gr.Progress(track_tqdm=True)):
180
+ sd_image, upscaled_image, expanded_prompt = generate_images_100(prompt, neg_prompt_1, neg_prompt_2, neg_prompt_3, width, height, guidance, steps, progress)
181
+ if save_consent:
182
+ print("✅ User consented to save. Preparing uploads...")
183
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
184
+ sd_filename = f"sd35ll_{timestamp}.png"
185
+ upscale_filename = f"sd35ll_upscale_{timestamp}.png"
186
+ sd_thread = threading.Thread(target=upload_to_gcs, args=(sd_image, sd_filename))
187
+ upscale_thread = threading.Thread(target=upload_to_gcs, args=(upscaled_image, upscale_filename))
188
+ sd_thread.start()
189
+ upscale_thread.start()
190
+ else:
191
+ print("ℹ️ User did not consent to save. Skipping upload.")
192
+ return sd_image, expanded_prompt
193
+
194
  css = """
195
  #col-container {margin: 0 auto;max-width: 640px;}
196
  body{background-color: blue;}
197
  """
 
198
  with gr.Blocks(theme=gr.themes.Origin(), css=css) as demo:
199
  with gr.Column(elem_id="col-container"):
200
  gr.Markdown(" # StableDiffusion 3.5 Large with UltraReal lora test")
 
204
  label="Prompt", show_label=False, max_lines=1,
205
  placeholder="Enter your prompt", container=False,
206
  )
207
+ run_button_30 = gr.Button("Run30", scale=0, variant="primary")
208
+ run_button_60 = gr.Button("Run60", scale=0, variant="primary")
209
+ run_button_100 = gr.Button("Run100", scale=0, variant="primary")
210
  result = gr.Image(label="Result", show_label=False, type="pil")
 
 
211
  save_consent_checkbox = gr.Checkbox(
212
  label="✅ Anonymously upload result to a public gallery",
213
+ value=True, # Default to uploading (user can opt out)
214
  info="Check this box to help us by contributing your image."
215
  )
 
216
  with gr.Accordion("Advanced Settings", open=True):
217
  negative_prompt_1 = gr.Text(label="Negative prompt 1", max_lines=1, placeholder="Enter a negative prompt", value="bad anatomy, poorly drawn hands, distorted face, blurry, out of frame, low resolution, grainy, pixelated, disfigured, mutated, extra limbs, bad composition")
218
  negative_prompt_2 = gr.Text(label="Negative prompt 2", max_lines=1, placeholder="Enter a second negative prompt", value="unrealistic, cartoon, anime, sketch, painting, drawing, illustration, graphic, digital art, render, 3d, blurry, deformed, disfigured, poorly drawn, bad anatomy, mutated, extra limbs, ugly, out of frame, bad composition, low resolution, grainy, pixelated, noisy, oversaturated, undersaturated, (worst quality, low quality:1.3), (bad hands, missing fingers:1.2)")
 
224
  guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=30.0, step=0.1, value=4.2)
225
  num_inference_steps = gr.Slider(label="Inference steps", minimum=1, maximum=150, step=1, value=60)
226
 
227
+ run_button_30.click(
228
+ fn=run_inference_and_upload_30,
 
229
  inputs=[
230
  prompt,
231
  negative_prompt_1,
 
240
  outputs=[result, expanded_prompt_output],
241
  )
242
 
243
+ run_button_60.click(
244
+ fn=run_inference_and_upload_60,
245
+ inputs=[
246
+ prompt,
247
+ negative_prompt_1,
248
+ negative_prompt_2,
249
+ negative_prompt_3,
250
+ width,
251
+ height,
252
+ guidance_scale,
253
+ num_inference_steps,
254
+ save_consent_checkbox # Pass the checkbox value
255
+ ],
256
+ outputs=[result, expanded_prompt_output],
257
+ )
258
+
259
+ run_button_100.click(
260
+ fn=run_inference_and_upload_100,
261
+ inputs=[
262
+ prompt,
263
+ negative_prompt_1,
264
+ negative_prompt_2,
265
+ negative_prompt_3,
266
+ width,
267
+ height,
268
+ guidance_scale,
269
+ num_inference_steps,
270
+ save_consent_checkbox # Pass the checkbox value
271
+ ],
272
+ outputs=[result, expanded_prompt_output],
273
+ )
274
+
275
+
276
  if __name__ == "__main__":
277
  demo.launch()