comrender committed on
Commit
d45f4bc
·
verified ·
1 Parent(s): fbe598e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -107
app.py CHANGED
@@ -7,7 +7,6 @@ import numpy as np
7
  import spaces
8
  import torch
9
  from diffusers import FluxImg2ImgPipeline
10
- from transformers import AutoProcessor, AutoModelForCausalLM
11
  from gradio_imageslider import ImageSlider
12
  from PIL import Image
13
  from huggingface_hub import snapshot_download
@@ -40,82 +39,10 @@ device = "cpu"
40
  # Get HuggingFace token
41
  huggingface_token = os.getenv("HF_TOKEN")
42
 
43
- # Download FLUX model
44
- print("📥 Downloading FLUX model...")
45
- model_path = snapshot_download(
46
- repo_id="black-forest-labs/FLUX.1-dev",
47
- repo_type="model",
48
- ignore_patterns=["*.md", "*.gitattributes"],
49
- local_dir="FLUX.1-dev",
50
- token=huggingface_token,
51
- )
52
-
53
- # Load Florence-2 model for image captioning on CPU
54
- print("📥 Loading Florence-2 model...")
55
- florence_model = AutoModelForCausalLM.from_pretrained(
56
- "microsoft/Florence-2-large",
57
- torch_dtype=torch.float32, # Force CPU dtype
58
- trust_remote_code=True,
59
- attn_implementation="eager"
60
- ).to(device)
61
- florence_processor = AutoProcessor.from_pretrained(
62
- "microsoft/Florence-2-large",
63
- trust_remote_code=True
64
- )
65
-
66
- # Load FLUX Img2Img pipeline on CPU
67
- print("📥 Loading FLUX Img2Img...")
68
- pipe = FluxImg2ImgPipeline.from_pretrained(
69
- model_path,
70
- torch_dtype=torch.float32 # Force CPU dtype
71
- )
72
- pipe.enable_vae_tiling()
73
- pipe.enable_vae_slicing()
74
-
75
- print("✅ All models loaded successfully!")
76
-
77
- # Download ESRGAN model if using
78
- if USE_ESRGAN:
79
- esrgan_path = "4x-UltraSharp.pth"
80
- if not os.path.exists(esrgan_path):
81
- url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
82
- with open(esrgan_path, "wb") as f:
83
- f.write(requests.get(url).content)
84
- esrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
85
- state_dict = torch.load(esrgan_path)['params_ema']
86
- esrgan_model.load_state_dict(state_dict)
87
- esrgan_model.eval()
88
-
89
  MAX_SEED = 1000000
90
  MAX_PIXEL_BUDGET = 8192 * 8192 # Increased for tiling support
91
 
92
 
93
- def generate_caption(image):
94
- """Generate detailed caption using Florence-2"""
95
- try:
96
- task_prompt = "<MORE_DETAILED_CAPTION>"
97
- prompt = task_prompt
98
-
99
- inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(device)
100
-
101
- generated_ids = florence_model.generate(
102
- input_ids=inputs["input_ids"],
103
- pixel_values=inputs["pixel_values"],
104
- max_new_tokens=1024,
105
- num_beams=3,
106
- do_sample=True,
107
- )
108
-
109
- generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
110
- parsed_answer = florence_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
111
-
112
- caption = parsed_answer[task_prompt]
113
- return caption
114
- except Exception as e:
115
- print(f"Caption generation failed: {e}")
116
- return "a high quality detailed image"
117
-
118
-
119
  def process_input(input_image, upscale_factor):
120
  """Process input image and handle size constraints"""
121
  w, h = input_image.size
@@ -216,21 +143,54 @@ def enhance_image(
216
  num_inference_steps,
217
  upscale_factor,
218
  denoising_strength,
219
- use_generated_caption,
220
  custom_prompt,
221
  progress=gr.Progress(track_tqdm=True),
222
  ):
223
  """Main enhancement function"""
224
- # Move models to GPU with fallback to CPU
225
- try:
226
- device = "cuda"
227
- pipe.to(device)
228
- florence_model.to(device)
229
- if USE_ESRGAN:
230
- esrgan_model.to(device)
231
- except Exception as e:
232
- print(f"GPU error: {e}, falling back to CPU")
233
- device = "cpu"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
235
  # Handle image input
236
  if image_input is not None:
@@ -250,13 +210,7 @@ def enhance_image(
250
  input_image, upscale_factor
251
  )
252
 
253
- # Generate caption if requested
254
- if use_generated_caption:
255
- gr.Info("🔍 Generating image caption...")
256
- generated_caption = generate_caption(input_image)
257
- prompt = generated_caption
258
- else:
259
- prompt = custom_prompt if custom_prompt.strip() else ""
260
 
261
  generator = torch.Generator(device=device).manual_seed(seed)
262
 
@@ -289,21 +243,21 @@ def enhance_image(
289
  # Resize input image to match output size for slider alignment
290
  resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
291
 
292
- # Move back to CPU to release GPU
293
- pipe.to("cpu")
294
- florence_model.to("cpu")
295
- if USE_ESRGAN:
296
- esrgan_model.to("cpu")
297
 
298
  return [resized_input, image]
299
 
300
 
301
  # Create Gradio interface
302
- with gr.Blocks(css=css, title="🎨 AI Image Upscaler - Florence-2 + FLUX") as demo:
303
  gr.HTML("""
304
  <div class="main-header">
305
  <h1>🎨 AI Image Upscaler</h1>
306
- <p>Upload an image or provide a URL to upscale it using Florence-2 captioning and FLUX upscaling</p>
307
  <p>Currently running on <strong>{}</strong></p>
308
  </div>
309
  """.format(power_device))
@@ -327,17 +281,11 @@ with gr.Blocks(css=css, title="🎨 AI Image Upscaler - Florence-2 + FLUX") as d
327
  value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
328
  )
329
 
330
- gr.HTML("<h3>🎛️ Caption Settings</h3>")
331
-
332
- use_generated_caption = gr.Checkbox(
333
- label="Use AI-generated caption (Florence-2)",
334
- value=True,
335
- info="Generate detailed caption automatically"
336
- )
337
 
338
  custom_prompt = gr.Textbox(
339
  label="Custom Prompt (optional)",
340
- placeholder="Enter custom prompt or leave empty for generated caption",
341
  lines=2
342
  )
343
 
@@ -412,7 +360,6 @@ with gr.Blocks(css=css, title="🎨 AI Image Upscaler - Florence-2 + FLUX") as d
412
  num_inference_steps,
413
  upscale_factor,
414
  denoising_strength,
415
- use_generated_caption,
416
  custom_prompt,
417
  ],
418
  outputs=[result_slider]
 
7
  import spaces
8
  import torch
9
  from diffusers import FluxImg2ImgPipeline
 
10
  from gradio_imageslider import ImageSlider
11
  from PIL import Image
12
  from huggingface_hub import snapshot_download
 
39
  # Get HuggingFace token
40
  huggingface_token = os.getenv("HF_TOKEN")
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  MAX_SEED = 1000000
43
  MAX_PIXEL_BUDGET = 8192 * 8192 # Increased for tiling support
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def process_input(input_image, upscale_factor):
47
  """Process input image and handle size constraints"""
48
  w, h = input_image.size
 
143
  num_inference_steps,
144
  upscale_factor,
145
  denoising_strength,
 
146
  custom_prompt,
147
  progress=gr.Progress(track_tqdm=True),
148
  ):
149
  """Main enhancement function"""
150
+ # Lazy loading of models
151
+ global pipe, esrgan_model
152
+ if 'pipe' not in globals():
153
+ try:
154
+ device = "cuda" if torch.cuda.is_available() else "cpu"
155
+ dtype = torch.bfloat16 if device == "cuda" else torch.float32
156
+
157
+ print(f"📥 Loading FLUX Img2Img on {device}...")
158
+ pipe = FluxImg2ImgPipeline.from_pretrained(
159
+ "black-forest-labs/FLUX.1-dev",
160
+ torch_dtype=dtype,
161
+ low_cpu_mem_usage=True,
162
+ device_map="auto"
163
+ )
164
+ pipe.enable_vae_tiling()
165
+ pipe.enable_vae_slicing()
166
+ pipe.enable_model_cpu_offload() if device == "cuda" else None
167
+
168
+ if USE_ESRGAN:
169
+ esrgan_path = "4x-UltraSharp.pth"
170
+ if not os.path.exists(esrgan_path):
171
+ url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
172
+ with open(esrgan_path, "wb") as f:
173
+ f.write(requests.get(url).content)
174
+ esrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
175
+ state_dict = torch.load(esrgan_path)['params_ema']
176
+ esrgan_model.load_state_dict(state_dict)
177
+ esrgan_model.eval()
178
+ esrgan_model.to(device)
179
+
180
+ print("✅ Models loaded successfully!")
181
+ except Exception as e:
182
+ print(f"Model loading error: {e}, falling back to CPU")
183
+ device = "cpu"
184
+ dtype = torch.float32
185
+ # Reload on CPU if needed
186
+ pipe = FluxImg2ImgPipeline.from_pretrained(
187
+ "black-forest-labs/FLUX.1-dev",
188
+ torch_dtype=dtype,
189
+ low_cpu_mem_usage=True,
190
+ device_map="auto"
191
+ )
192
+ pipe.enable_vae_tiling()
193
+ pipe.enable_vae_slicing()
194
 
195
  # Handle image input
196
  if image_input is not None:
 
210
  input_image, upscale_factor
211
  )
212
 
213
+ prompt = custom_prompt if custom_prompt.strip() else ""
 
 
 
 
 
 
214
 
215
  generator = torch.Generator(device=device).manual_seed(seed)
216
 
 
243
  # Resize input image to match output size for slider alignment
244
  resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
245
 
246
+ # Move back to CPU to release GPU if possible
247
+ if device == "cuda":
248
+ pipe.to("cpu")
249
+ if USE_ESRGAN:
250
+ esrgan_model.to("cpu")
251
 
252
  return [resized_input, image]
253
 
254
 
255
  # Create Gradio interface
256
+ with gr.Blocks(css=css, title="🎨 AI Image Upscaler - FLUX") as demo:
257
  gr.HTML("""
258
  <div class="main-header">
259
  <h1>🎨 AI Image Upscaler</h1>
260
+ <p>Upload an image or provide a URL to upscale it using FLUX upscaling</p>
261
  <p>Currently running on <strong>{}</strong></p>
262
  </div>
263
  """.format(power_device))
 
281
  value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
282
  )
283
 
284
+ gr.HTML("<h3>🎛️ Prompt Settings</h3>")
 
 
 
 
 
 
285
 
286
  custom_prompt = gr.Textbox(
287
  label="Custom Prompt (optional)",
288
+ placeholder="Enter custom prompt or leave empty",
289
  lines=2
290
  )
291
 
 
360
  num_inference_steps,
361
  upscale_factor,
362
  denoising_strength,
 
363
  custom_prompt,
364
  ],
365
  outputs=[result_slider]