ginipick committed on
Commit 7d4bf39 · verified · 1 Parent(s): 1b3b25e

Update app.py

Files changed (1)
  1. app.py +236 -269

app.py CHANGED
@@ -82,6 +82,8 @@ def inference(
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     generator = torch.Generator(device=device).manual_seed(seed)
+
+    print(f"Running inference with prompt: {prompt}")

     try:
         image = pipeline(
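
For reference, the seeding logic in this hunk is the standard torch.Generator recipe: a fixed seed fed to a generator makes a run reproducible, and randomize_seed simply draws a fresh seed per call. A minimal standalone sketch (CPU-only for illustration; the value of MAX_SEED is an assumption, since the constant is defined outside this diff):

import random
import torch

MAX_SEED = 2**32 - 1  # assumption: the app defines its own MAX_SEED elsewhere
seed = random.randint(0, MAX_SEED)

# Two generators seeded identically produce identical draws, which is
# why reusing the reported seed reproduces an image.
g1 = torch.Generator(device="cpu").manual_seed(seed)
g2 = torch.Generator(device="cpu").manual_seed(seed)
assert torch.equal(torch.randn(4, generator=g1), torch.randn(4, generator=g2))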
@@ -102,70 +104,123 @@ def inference(

 # ----------------------------- Florence-2 Captioner ---------------------------
 import subprocess
-subprocess.run(
-    'pip install flash-attn --no-build-isolation',
-    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
-    shell=True
-)
+try:
+    subprocess.run(
+        'pip install flash-attn --no-build-isolation',
+        env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
+        shell=True
+    )
+except Exception as e:
+    print(f"Warning: Could not install flash-attn: {e}")

 from transformers import AutoProcessor, AutoModelForCausalLM

+# Function to safely load models
+def load_caption_model(model_name):
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name, trust_remote_code=True
+        ).eval()
+        processor = AutoProcessor.from_pretrained(
+            model_name, trust_remote_code=True
+        )
+        return model, processor
+    except Exception as e:
+        print(f"Error loading caption model {model_name}: {e}")
+        return None, None
+
 # Pre-load models and processors
-models = {
-    'gokaygokay/Florence-2-Flux-Large': AutoModelForCausalLM.from_pretrained(
-        'gokaygokay/Florence-2-Flux-Large', trust_remote_code=True
-    ).eval(),
-    'gokaygokay/Florence-2-Flux': AutoModelForCausalLM.from_pretrained(
-        'gokaygokay/Florence-2-Flux', trust_remote_code=True
-    ).eval(),
-}
+print("Loading captioning models...")
+default_caption_model = 'gokaygokay/Florence-2-Flux-Large'
+models = {}
+processors = {}

-processors = {
-    'gokaygokay/Florence-2-Flux-Large': AutoProcessor.from_pretrained(
-        'gokaygokay/Florence-2-Flux-Large', trust_remote_code=True
-    ),
-    'gokaygokay/Florence-2-Flux': AutoProcessor.from_pretrained(
-        'gokaygokay/Florence-2-Flux', trust_remote_code=True
-    ),
-}
+# Try to load the default model
+default_model, default_processor = load_caption_model(default_caption_model)
+if default_model is not None and default_processor is not None:
+    models[default_caption_model] = default_model
+    processors[default_caption_model] = default_processor
+    print(f"Successfully loaded default caption model: {default_caption_model}")
+else:
+    # Fallback to simpler model
+    fallback_model = 'gokaygokay/Florence-2-Flux'
+    fallback_model_obj, fallback_processor = load_caption_model(fallback_model)
+    if fallback_model_obj is not None and fallback_processor is not None:
+        models[fallback_model] = fallback_model_obj
+        processors[fallback_model] = fallback_processor
+        default_caption_model = fallback_model
+        print(f"Loaded fallback caption model: {fallback_model}")
+    else:
+        print("WARNING: Failed to load any caption model!")

 @spaces.GPU
-def caption_image(image, model_name='gokaygokay/Florence-2-Flux-Large'):
+def caption_image(image, model_name=default_caption_model):
     """
     Runs the selected Florence-2 model to generate a detailed caption.
     """
     from PIL import Image as PILImage
+    import numpy as np
+
+    print(f"Starting caption generation with model: {model_name}")
+
+    # Handle case where image is already a PIL image
+    if isinstance(image, PILImage.Image):
+        pil_image = image
+    else:
+        # Convert numpy array to PIL
+        if isinstance(image, np.ndarray):
+            pil_image = PILImage.fromarray(image)
+        else:
+            print(f"Unexpected image type: {type(image)}")
+            return "Error: Unsupported image type"

-    task_prompt = "<DESCRIPTION>"
-    user_prompt = task_prompt + "Describe this image in great detail."
-
     # Convert input to RGB if needed
-    image = PILImage.fromarray(image)
-    if image.mode != "RGB":
-        image = image.convert("RGB")
-
+    if pil_image.mode != "RGB":
+        pil_image = pil_image.convert("RGB")
+
+    # Check if model is available
+    if model_name not in models or model_name not in processors:
+        available_models = list(models.keys())
+        if available_models:
+            model_name = available_models[0]
+            print(f"Requested model not available, using: {model_name}")
+        else:
+            return "Error: No caption models available"
+
     model = models[model_name]
     processor = processors[model_name]
+
+    task_prompt = "<DESCRIPTION>"
+    user_prompt = task_prompt + "Describe this image in great detail."
+
+    try:
+        inputs = processor(text=user_prompt, images=pil_image, return_tensors="pt")
+
+        generated_ids = model.generate(
+            input_ids=inputs["input_ids"],
+            pixel_values=inputs["pixel_values"],
+            max_new_tokens=1024,
+            num_beams=3,
+            repetition_penalty=1.10,
+        )
+
+        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+        parsed_answer = processor.post_process_generation(
+            generated_text, task=task_prompt, image_size=(pil_image.width, pil_image.height)
+        )
+
+        # Extract the caption
+        caption = parsed_answer.get("<DESCRIPTION>", "")
+        print(f"Generated caption: {caption}")
+        return caption
+    except Exception as e:
+        print(f"Error during captioning: {e}")
+        return f"Error generating caption: {str(e)}"

-    inputs = processor(text=user_prompt, images=image, return_tensors="pt")
-    generated_ids = model.generate(
-        input_ids=inputs["input_ids"],
-        pixel_values=inputs["pixel_values"],
-        max_new_tokens=1024,
-        num_beams=3,
-        repetition_penalty=1.10,
-    )
-    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-    parsed_answer = processor.post_process_generation(
-        generated_text, task=task_prompt, image_size=(image.width, image.height)
-    )
-    return parsed_answer["<DESCRIPTION>"]
-
-# --------- NEW FUNCTION: Process uploaded image and generate Ghibli style image ---------
+# --------- Process uploaded image and generate Ghibli style image ---------
 @spaces.GPU(duration=120)
 def process_uploaded_image(
     image,
-    model_name,
     seed,
     randomize_seed,
     width,
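
The reworked captioning path above now accepts either a PIL image or a numpy array, falls back to whichever model actually loaded, and returns an "Error: ..." string instead of raising. A hypothetical smoke test of caption_image as defined above (the file name is a placeholder):

import numpy as np
from PIL import Image

pil_img = Image.open("example.jpg")  # hypothetical local test image
print(caption_image(pil_img))                  # PIL input path
print(caption_image(np.asarray(pil_img)))      # numpy input path
print(caption_image(pil_img, model_name="no-such-model"))  # rerouted to a loaded model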
@@ -174,245 +229,157 @@ def process_uploaded_image(
     num_inference_steps,
     lora_scale
 ):
+    if image is None:
+        print("No image provided")
+        return None, None, "No image provided", "No image provided"
+
+    print("Starting image processing workflow")
+
     # Step 1: Generate caption from the uploaded image
-    caption = caption_image(image, model_name)
+    try:
+        caption = caption_image(image)
+        if caption.startswith("Error:"):
+            print(f"Captioning failed: {caption}")
+            # Use a default caption as fallback
+            caption = "A beautiful scene"
+    except Exception as e:
+        print(f"Exception during captioning: {e}")
+        caption = "A beautiful scene"

     # Step 2: Append "ghibli style" to the caption
     ghibli_prompt = f"{caption}, ghibli style"
+    print(f"Final prompt for Ghibli generation: {ghibli_prompt}")

     # Step 3: Generate Ghibli-style image based on the caption
-    generated_image, used_seed = inference(
-        prompt=ghibli_prompt,
-        seed=seed,
-        randomize_seed=randomize_seed,
-        width=width,
-        height=height,
-        guidance_scale=guidance_scale,
-        num_inference_steps=num_inference_steps,
-        lora_scale=lora_scale
-    )
-
-    return generated_image, used_seed, caption, ghibli_prompt
+    try:
+        generated_image, used_seed = inference(
+            prompt=ghibli_prompt,
+            seed=seed,
+            randomize_seed=randomize_seed,
+            width=width,
+            height=height,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_inference_steps,
+            lora_scale=lora_scale
+        )
+
+        print(f"Image generation complete with seed: {used_seed}")
+        return generated_image, used_seed, caption, ghibli_prompt
+    except Exception as e:
+        print(f"Error generating image: {e}")
+        error_img = Image.new('RGB', (width, height), color='red')
+        return error_img, seed, caption, ghibli_prompt

 # ----------------------------- Gradio UI --------------------------------------
 with gr.Blocks(analytics_enabled=False) as demo:
-    with gr.Tabs():
-        # ------------------ TAB 1: Image Generation ----------------------------
-        with gr.TabItem("FLUX Ghibli LoRA Generator"):
-            gr.Markdown("## Generate an image with the FLUX Ghibli LoRA")
-
-            with gr.Row():
-                with gr.Column():
-                    prompt = gr.Textbox(
-                        label="Prompt",
-                        placeholder="Describe your Ghibli-style image...",
-                        lines=3
-                    )
-                    with gr.Row():
-                        seed = gr.Slider(
-                            label="Seed",
-                            minimum=0,
-                            maximum=MAX_SEED,
-                            step=1,
-                            value=42
-                        )
-                        randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-                    with gr.Row():
-                        width = gr.Slider(
-                            label="Width",
-                            minimum=256,
-                            maximum=MAX_IMAGE_SIZE,
-                            step=32,
-                            value=512
-                        )
-                        height = gr.Slider(
-                            label="Height",
-                            minimum=256,
-                            maximum=MAX_IMAGE_SIZE,
-                            step=32,
-                            value=512
-                        )
-                    with gr.Row():
-                        guidance_scale = gr.Slider(
-                            label="Guidance scale",
-                            minimum=0.0,
-                            maximum=10.0,
-                            step=0.1,
-                            value=3.5
-                        )
-                        num_inference_steps = gr.Slider(
-                            label="Steps",
-                            minimum=1,
-                            maximum=50,
-                            step=1,
-                            value=30
-                        )
-                        lora_scale = gr.Slider(
-                            label="LoRA scale",
-                            minimum=0.0,
-                            maximum=1.0,
-                            step=0.1,
-                            value=1.0
-                        )
-                    generate_button = gr.Button("Generate Image")
-
-                with gr.Column():
-                    output_image = gr.Image(label="Generated Image")
-                    output_seed = gr.Number(label="Seed Used")
-
-            # Link the button to the inference function
-            generate_button.click(
-                inference,
-                inputs=[
-                    prompt,
-                    seed,
-                    randomize_seed,
-                    width,
-                    height,
-                    guidance_scale,
-                    num_inference_steps,
-                    lora_scale,
-                ],
-                outputs=[output_image, output_seed]
-            )
-
-        # ------------------ TAB 2: Image Captioning ---------------------------
-        with gr.TabItem("Florence-2 Captioner"):
-            gr.Markdown("## Generate a caption for an uploaded image using Florence-2")
-
-            with gr.Row():
-                with gr.Column():
-                    input_img = gr.Image(label="Upload an Image")
-                    model_selector = gr.Dropdown(
-                        choices=list(models.keys()),
-                        value='gokaygokay/Florence-2-Flux-Large',
-                        label="Select Model"
-                    )
-                    caption_button = gr.Button("Generate Caption")
-                with gr.Column():
-                    caption_output = gr.Textbox(label="Caption")
-
-            caption_button.click(
-                caption_image,
-                inputs=[input_img, model_selector],
-                outputs=[caption_output]
-            )
-
-        # ------------------ NEW TAB 3: Image to Ghibli Style ---------------------------
-        with gr.TabItem("Image to Ghibli Style"):
-            gr.Markdown("## Upload an image and transform it to Ghibli style")
-
-            with gr.Row():
-                with gr.Column():
-                    upload_img = gr.Image(label="Upload an Image")
-                    caption_model_selector = gr.Dropdown(
-                        choices=list(models.keys()),
-                        value='gokaygokay/Florence-2-Flux-Large',
-                        label="Caption Model",
-                        visible=False  # Hidden as requested
-                    )
-                    with gr.Row():
-                        img2img_seed = gr.Slider(
-                            label="Seed",
-                            minimum=0,
-                            maximum=MAX_SEED,
-                            step=1,
-                            value=42
-                        )
-                        img2img_randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-                    with gr.Row():
-                        img2img_width = gr.Slider(
-                            label="Width",
-                            minimum=256,
-                            maximum=MAX_IMAGE_SIZE,
-                            step=32,
-                            value=512
-                        )
-                        img2img_height = gr.Slider(
-                            label="Height",
-                            minimum=256,
-                            maximum=MAX_IMAGE_SIZE,
-                            step=32,
-                            value=512
-                        )
-                    with gr.Row():
-                        img2img_guidance_scale = gr.Slider(
-                            label="Guidance scale",
-                            minimum=0.0,
-                            maximum=10.0,
-                            step=0.1,
-                            value=3.5
-                        )
-                        img2img_steps = gr.Slider(
-                            label="Steps",
-                            minimum=1,
-                            maximum=50,
-                            step=1,
-                            value=30
-                        )
-                        img2img_lora_scale = gr.Slider(
-                            label="LoRA scale",
-                            minimum=0.0,
-                            maximum=1.0,
-                            step=0.1,
-                            value=1.0
-                        )
-                    transform_button = gr.Button("Transform to Ghibli Style")
-
-                with gr.Column():
-                    ghibli_output_image = gr.Image(label="Generated Ghibli Image")
-                    ghibli_output_seed = gr.Number(label="Seed Used")
-                    extracted_caption = gr.Textbox(
-                        label="Extracted Description",
-                        visible=False  # Hidden as requested
-                    )
-                    ghibli_prompt = gr.Textbox(
-                        label="Generated Prompt",
-                        visible=False  # Hidden as requested
-                    )
-
-            # Auto-process when image is uploaded
-            upload_img.upload(
-                process_uploaded_image,
-                inputs=[
-                    upload_img,
-                    caption_model_selector,
-                    img2img_seed,
-                    img2img_randomize_seed,
-                    img2img_width,
-                    img2img_height,
-                    img2img_guidance_scale,
-                    img2img_steps,
-                    img2img_lora_scale,
-                ],
-                outputs=[
-                    ghibli_output_image,
-                    ghibli_output_seed,
-                    extracted_caption,
-                    ghibli_prompt,
-                ]
-            )
-
-            # Manual process button
-            transform_button.click(
-                process_uploaded_image,
-                inputs=[
-                    upload_img,
-                    caption_model_selector,
-                    img2img_seed,
-                    img2img_randomize_seed,
-                    img2img_width,
-                    img2img_height,
-                    img2img_guidance_scale,
-                    img2img_steps,
-                    img2img_lora_scale,
-                ],
-                outputs=[
-                    ghibli_output_image,
-                    ghibli_output_seed,
-                    extracted_caption,
-                    ghibli_prompt,
-                ]
-            )
+    gr.Markdown("# Image to Ghibli Style Conversion")
+    gr.Markdown("When you upload an image, a description is extracted automatically and the image is converted to Ghibli style.")
+
+    with gr.Row():
+        with gr.Column():
+            upload_img = gr.Image(label="Upload Image", type="pil")
+
+            with gr.Row():
+                img2img_seed = gr.Slider(
+                    label="Seed",
+                    minimum=0,
+                    maximum=MAX_SEED,
+                    step=1,
+                    value=42
+                )
+                img2img_randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+            with gr.Row():
+                img2img_width = gr.Slider(
+                    label="Width",
+                    minimum=256,
+                    maximum=MAX_IMAGE_SIZE,
+                    step=32,
+                    value=512
+                )
+                img2img_height = gr.Slider(
+                    label="Height",
+                    minimum=256,
+                    maximum=MAX_IMAGE_SIZE,
+                    step=32,
+                    value=512
+                )
+
+            with gr.Row():
+                img2img_guidance_scale = gr.Slider(
+                    label="Guidance scale",
+                    minimum=0.0,
+                    maximum=10.0,
+                    step=0.1,
+                    value=3.5
+                )
+                img2img_steps = gr.Slider(
+                    label="Steps",
+                    minimum=1,
+                    maximum=50,
+                    step=1,
+                    value=30
+                )
+
+            img2img_lora_scale = gr.Slider(
+                label="LoRA scale",
+                minimum=0.0,
+                maximum=1.0,
+                step=0.1,
+                value=1.0
+            )
+
+            transform_button = gr.Button("Transform to Ghibli Style")
+
+        with gr.Column():
+            ghibli_output_image = gr.Image(label="Generated Ghibli Style Image")
+            ghibli_output_seed = gr.Number(label="Seed Used")
+
+            # Debug elements (hidden by default)
+            with gr.Accordion("Debug Info", open=False):
+                extracted_caption = gr.Textbox(label="Extracted Image Description")
+                ghibli_prompt = gr.Textbox(label="Prompt Used for Generation")
+
+    # Auto-process when image is uploaded
+    upload_img.upload(
+        process_uploaded_image,
+        inputs=[
+            upload_img,
+            img2img_seed,
+            img2img_randomize_seed,
+            img2img_width,
+            img2img_height,
+            img2img_guidance_scale,
+            img2img_steps,
+            img2img_lora_scale,
+        ],
+        outputs=[
+            ghibli_output_image,
+            ghibli_output_seed,
+            extracted_caption,
+            ghibli_prompt,
+        ]
+    )
+
+    # Manual process button
+    transform_button.click(
+        process_uploaded_image,
+        inputs=[
+            upload_img,
+            img2img_seed,
+            img2img_randomize_seed,
+            img2img_width,
+            img2img_height,
+            img2img_guidance_scale,
+            img2img_steps,
+            img2img_lora_scale,
+        ],
+        outputs=[
+            ghibli_output_image,
+            ghibli_output_seed,
+            extracted_caption,
+            ghibli_prompt,
+        ]
+    )

 demo.launch(debug=True)
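
Because the upload event and the manual button are now wired to the same handler with the same eight inputs, the pipeline can also be exercised without the UI. A sketch, assuming the module's globals are in scope and a local test image exists (both assumptions):

from PIL import Image

img = Image.open("photo.jpg")  # hypothetical test image
result_img, used_seed, caption, prompt = process_uploaded_image(
    img,
    seed=42,
    randomize_seed=False,  # a fixed seed keeps the run reproducible
    width=512,
    height=512,
    guidance_scale=3.5,
    num_inference_steps=30,
    lora_scale=1.0,
)
print(used_seed, prompt)  # also surfaced in the UI's debug accordion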
 