comrender committed
Commit 93af3e2 · verified · 1 Parent(s): be0c600

Update app.py

Files changed (1)
  1. app.py +463 -191
app.py CHANGED
@@ -1,219 +1,491 @@
 
 
  import warnings
  import gradio as gr
- import torch
- from PIL import Image
- from transformers import AutoProcessor, AutoModelForCausalLM
- from diffusers import FluxImg2ImgPipeline
- import random
  import numpy as np
- import os
  import spaces
- import huggingface_hub
- import time
-
- huggingface_hub.constants.HF_HUB_DOWNLOAD_TIMEOUT = 60

  try:
-     import basicsr
-     # Assume basicsr interpolation setup
-     interpolation = "basicsr" # Placeholder for actual basicsr usage
  except ImportError:
      warnings.warn("basicsr not installed; falling back to LANCZOS interpolation.")
-     interpolation = Image.LANCZOS

- # Initialize models
- device = "cuda" if torch.cuda.is_available() else "cpu"
- dtype = torch.bfloat16
- huggingface_token = os.getenv("HUGGINGFACE_TOKEN")

- # Load FLUX img2img pipeline directly to avoid auto_pipeline issues
- pipe = FluxImg2ImgPipeline.from_pretrained(
-     "black-forest-labs/FLUX.1-dev",
-     torch_dtype=dtype,
-     token=huggingface_token
  ).to(device)
- pipe.enable_vae_tiling() # To help with memory for large images

- # Initialize Florence model with float32 to avoid dtype mismatch, with retry
- for attempt in range(5):
      try:
-         florence_model = AutoModelForCausalLM.from_pretrained(
-             'microsoft/Florence-2-large',
-             trust_remote_code=True,
-             torch_dtype=torch.float32
-         ).to(device).eval()
-         florence_processor = AutoProcessor.from_pretrained(
-             'microsoft/Florence-2-large',
-             trust_remote_code=True
          )
-         break
      except Exception as e:
-         print(f"Attempt {attempt+1} to load Florence-2 failed: {e}")
-         time.sleep(10)
- else:
-     raise RuntimeError("Failed to load Florence-2 after multiple attempts")
-
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 2048
-
- # Florence caption function
- @spaces.GPU
- def florence_caption(image):
-     if not isinstance(image, Image.Image):
-         image = Image.fromarray(image)
-     inputs = florence_processor(text="<DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
-     generated_ids = florence_model.generate(
-         input_ids=inputs["input_ids"],
-         pixel_values=inputs["pixel_values"],
-         max_new_tokens=1024,
-         early_stopping=False,
-         do_sample=False,
-         num_beams=3,
      )
-     generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-     parsed_answer = florence_processor.post_process_generation(
-         generated_text,
-         task="<DETAILED_CAPTION>",
-         image_size=(image.width, image.height)
      )
-     return parsed_answer["<DETAILED_CAPTION>"]
-
- # Tiled FLUX img2img function with fix for small dimensions and overlap
- def tiled_flux_img2img(image, prompt, strength, num_inference_steps, guidance_scale, tile_size=512, overlap=64):
-     width, height = image.size
-     # Resize to multiple of 16 to avoid dimension warnings
-     width = (width // 16) * 16 if width >= 16 else 16
-     height = (height // 16) * 16 if height >= 16 else 16
-     if width != image.size[0] or height != image.size[1]:
-         image = image.resize((width, height), resample=interpolation)

-     result = Image.new('RGB', (width, height))
-     stride = tile_size - overlap

-     # Tile in both directions, handling small sizes
-     for y in range(0, height, stride):
-         for x in range(0, width, stride):
-             tile_left = x
-             tile_top = y
-             tile_right = min(x + tile_size, width)
-             tile_bottom = min(y + tile_size, height)
-             tile = image.crop((tile_left, tile_top, tile_right, tile_bottom))

-             # Skip if tile is too small
-             if tile.width < 16 or tile.height < 16:
-                 continue

-             # Generate with img2img
-             generated_tile = pipe(
-                 prompt,
-                 image=tile,
-                 strength=strength,
-                 guidance_scale=guidance_scale,
-                 num_inference_steps=num_inference_steps
-             ).images[0]
-             generated_tile = generated_tile.resize(tile.size) # Ensure size match

-             # Paste without blend if first tile
-             if x == 0 and y == 0:
-                 result.paste(generated_tile, (tile_left, tile_top))
-                 continue

-             # Vertical blend
-             if y > 0:
-                 effective_overlap = min(overlap, tile_bottom - tile_top, height - tile_top)
-                 if effective_overlap > 0:
-                     mask = Image.new('L', (tile_right - tile_left, effective_overlap))
-                     for i in range(mask.width):
-                         for j in range(mask.height):
-                             divisor = effective_overlap - 1 if effective_overlap > 1 else 1
-                             mask.putpixel((i, j), int(255 * (j / divisor)))
-                     blend_region = Image.composite(
-                         generated_tile.crop((0, 0, mask.width, mask.height)),
-                         result.crop((tile_left, tile_top, tile_right, tile_top + mask.height)),
-                         mask
-                     )
-                     result.paste(blend_region, (tile_left, tile_top))
-                     result.paste(generated_tile.crop((0, effective_overlap, generated_tile.width, generated_tile.height)), (tile_left, tile_top + effective_overlap))
-                 else:
-                     result.paste(generated_tile, (tile_left, tile_top))

-             # Horizontal blend
-             if x > 0:
-                 effective_overlap_h = min(overlap, tile_right - tile_left, width - tile_left)
-                 if effective_overlap_h > 0:
-                     mask_h = Image.new('L', (effective_overlap_h, tile_bottom - tile_top))
-                     for i in range(mask_h.width):
-                         for j in range(mask_h.height):
-                             divisor_h = effective_overlap_h - 1 if effective_overlap_h > 1 else 1
-                             mask_h.putpixel((i, j), int(255 * (i / divisor_h)))
-                     blend_region_h = Image.composite(
-                         generated_tile.crop((0, 0, mask_h.width, mask_h.height)),
-                         result.crop((tile_left, tile_top, tile_left + mask_h.width, tile_bottom)),
-                         mask_h
-                     )
-                     result.paste(blend_region_h, (tile_left, tile_top))
-                     result.paste(generated_tile.crop((effective_overlap_h, 0, generated_tile.width, generated_tile.height)), (tile_left + effective_overlap_h, tile_top))
-                 else:
-                     result.paste(generated_tile, (tile_left, tile_top))

-     return result
-
- # Main enhance function
- @spaces.GPU(duration=190)
- def enhance_image(image, text_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, strength, progress=gr.Progress(track_tqdm=True)):
-     prompt = text_prompt
-     if image is not None:
-         prompt = florence_caption(image)
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
-     generator = torch.Generator(device=device).manual_seed(seed)

-     # Use tiled if large, else direct
-     if image and (image.size[0] > MAX_IMAGE_SIZE or image.size[1] > MAX_IMAGE_SIZE):
-         output_image = tiled_flux_img2img(image, prompt, strength, num_inference_steps, guidance_scale)
-     else:
-         kw = {}
-         if image is not None:
-             kw['image'] = image
-             kw['strength'] = strength
-         else:
-             kw['width'] = width
-             kw['height'] = height
-         output_image = pipe(
-             prompt,
-             generator=generator,
-             num_inference_steps=num_inference_steps,
-             guidance_scale=guidance_scale,
-             **kw
-         ).images[0]
-     return output_image, prompt, seed
-
- # Gradio interface
- title = "<h1 align='center'>FLUX Image Enhancer with Florence-2 Captioner</h1>"
- with gr.Blocks() as demo:
-     gr.HTML(title)
-     with gr.Row():
-         with gr.Column():
-             input_image = gr.Image(label="Upload Image")
-             text_prompt = gr.Textbox(label="Text Prompt (if no image)")
-             strength = gr.Slider(label="Strength", minimum=0.1, maximum=1.0, value=0.8)
-             guidance_scale = gr.Slider(label="Guidance Scale", minimum=1, maximum=10, value=5.0)
-             num_inference_steps = gr.Slider(label="Steps", minimum=10, maximum=50, value=20)
-             seed = gr.Number(value=42, label="Seed")
-             randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-             width = gr.Slider(minimum=256, maximum=1024, step=16, value=512, label="Width")
-             height = gr.Slider(minimum=256, maximum=1024, step=16, value=512, label="Height")
-             submit = gr.Button("Enhance")
-         with gr.Column():
-             output_image = gr.Image(label="Enhanced Image")
-             output_prompt = gr.Textbox(label="Generated Prompt")
-             output_seed = gr.Number(label="Used Seed")

-     submit.click(
-         enhance_image,
-         inputs=[input_image, text_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, strength],
-         outputs=[output_image, output_prompt, output_seed]
-     )

- print("✅ All models loaded successfully!")
- demo.launch(server_port=7860, server_name="0.0.0.0")
 
+ import logging
+ import random
  import warnings
+ import os
  import gradio as gr
  import numpy as np
  import spaces
+ import torch
+ from diffusers import FluxImg2ImgPipeline
+ from transformers import AutoProcessor, AutoModelForCausalLM
+ from gradio_imageslider import ImageSlider
+ from PIL import Image
+ from huggingface_hub import snapshot_download
+ import requests

+ # For ESRGAN (requires pip install basicsr gfpgan)
  try:
+     from basicsr.archs.rrdbnet_arch import RRDBNet
+     from basicsr.utils import img2tensor, tensor2img
+     USE_ESRGAN = True
  except ImportError:
+     USE_ESRGAN = False
      warnings.warn("basicsr not installed; falling back to LANCZOS interpolation.")

+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 800px;
+ }
+ .main-header {
+     text-align: center;
+     margin-bottom: 2rem;
+ }
+ """

+ # Device setup - Force CPU for startup in ZeroGPU
+ power_device = "ZeroGPU"
+ device = "cpu"
+
+ # Get HuggingFace token
+ huggingface_token = os.getenv("HF_TOKEN")
+
+ # Download FLUX model
+ print("📥 Downloading FLUX model...")
+ model_path = snapshot_download(
+     repo_id="black-forest-labs/FLUX.1-dev",
+     repo_type="model",
+     ignore_patterns=["*.md", "*.gitattributes"],
+     local_dir="FLUX.1-dev",
+     token=huggingface_token,
+ )
+
+ # Load Florence-2 model for image captioning on CPU
+ print("📥 Loading Florence-2 model...")
+ florence_model = AutoModelForCausalLM.from_pretrained(
+     "microsoft/Florence-2-large",
+     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+     trust_remote_code=True,
+     attn_implementation="eager"
  ).to(device)
+ florence_processor = AutoProcessor.from_pretrained(
+     "microsoft/Florence-2-large",
+     trust_remote_code=True
+ )
+
+ # Load FLUX Img2Img pipeline on CPU
+ print("📥 Loading FLUX Img2Img...")
+ pipe = FluxImg2ImgPipeline.from_pretrained(
+     model_path,
+     torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
+ )
+ pipe.enable_vae_tiling()
+ pipe.enable_vae_slicing()
+
+ print("✅ All models loaded successfully!")
+
+ # Download ESRGAN model if using
+ if USE_ESRGAN:
+     esrgan_path = "4x-UltraSharp.pth"
+     if not os.path.exists(esrgan_path):
+         url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
+         with open(esrgan_path, "wb") as f:
+             f.write(requests.get(url).content)
+     esrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
+     state_dict = torch.load(esrgan_path)['params_ema']
+     esrgan_model.load_state_dict(state_dict)
+     esrgan_model.eval()
+
+ MAX_SEED = 1000000
+ MAX_PIXEL_BUDGET = 8192 * 8192 # Increased for tiling support
+

+ def generate_caption(image):
+     """Generate detailed caption using Florence-2"""
      try:
+         task_prompt = "<MORE_DETAILED_CAPTION>"
+         prompt = task_prompt
+
+         inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(florence_model.device) # Fixed: Use model's current device instead of static 'device'
+
+         generated_ids = florence_model.generate(
+             input_ids=inputs["input_ids"],
+             pixel_values=inputs["pixel_values"],
+             max_new_tokens=1024,
+             num_beams=3,
+             do_sample=True,
+         )
+
+         generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+         parsed_answer = florence_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
+
+         caption = parsed_answer[task_prompt]
+         return caption
+     except Exception as e:
+         print(f"Caption generation failed: {e}")
+         return "a high quality detailed image"
+
+
+ def process_input(input_image, upscale_factor):
+     """Process input image and handle size constraints"""
+     w, h = input_image.size
+     w_original, h_original = w, h
+     aspect_ratio = w / h
+
+     was_resized = False
+
+     if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
+         warnings.warn(
+             f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing to fit budget."
+         )
+         gr.Info(
+             f"Requested output image is too large. Resizing input to fit within pixel budget."
          )
+         target_input_pixels = MAX_PIXEL_BUDGET / (upscale_factor ** 2)
+         scale = (target_input_pixels / (w * h)) ** 0.5
+         new_w = int(w * scale) - int(w * scale) % 16 # Fixed: Use % 16 for FLUX alignment (was % 8)
+         new_h = int(h * scale) - int(h * scale) % 16 # Fixed: Use % 16 for FLUX alignment (was % 8)
+         input_image = input_image.resize((new_w, new_h), resample=Image.LANCZOS)
+         was_resized = True
+
+     return input_image, w_original, h_original, was_resized
+
+
+ def load_image_from_url(url):
+     """Load image from URL"""
+     try:
+         response = requests.get(url, stream=True)
+         response.raise_for_status()
+         return Image.open(response.raw)
      except Exception as e:
+         raise gr.Error(f"Failed to load image from URL: {e}")
+
+
+ def esrgan_upscale(image, scale=4):
+     if not USE_ESRGAN:
+         return image.resize((image.width * scale, image.height * scale), resample=Image.LANCZOS)
+     img = img2tensor(np.array(image) / 255., bgr2rgb=False, float32=True)
+     with torch.no_grad():
+         output = esrgan_model(img.unsqueeze(0)).squeeze()
+     output_img = tensor2img(output, rgb2bgr=False, min_max=(0, 1))
+     return Image.fromarray(output_img)
+
+
+ def tiled_flux_img2img(pipe, prompt, image, strength, steps, guidance, generator, tile_size=1024, overlap=32):
+     """Tiled Img2Img to mimic Ultimate SD Upscaler tiling"""
+     w, h = image.size
+     output = image.copy() # Start with the control image
+
+     for x in range(0, w, tile_size - overlap):
+         for y in range(0, h, tile_size - overlap):
+             tile_w = min(tile_size, w - x)
+             tile_h = min(tile_size, h - y)
+             tile = image.crop((x, y, x + tile_w, y + tile_h))
+
+             # Run Flux on tile
+             gen_tile = pipe(
+                 prompt=prompt,
+                 image=tile,
+                 strength=strength,
+                 num_inference_steps=steps,
+                 guidance_scale=guidance,
+                 height=tile_h,
+                 width=tile_w,
+                 generator=generator,
+             ).images[0]
+
+             # Fixed: Resize generated tile back to exact tile dimensions if pipeline auto-resized for multiple-of-16 requirement
+             gen_tile = gen_tile.resize((tile_w, tile_h), resample=Image.LANCZOS)
+
+             # Paste with blending if overlap
+             if overlap > 0:
+                 paste_box = (x, y, x + tile_w, y + tile_h)
+                 if x > 0 or y > 0:
+                     # Simple linear blend on overlaps
+                     mask = Image.new('L', (tile_w, tile_h), 255)
+                     if x > 0:
+                         for i in range(overlap):
+                             for j in range(tile_h):
+                                 mask.putpixel((i, j), int(255 * (i / overlap)))
+                     if y > 0:
+                         for i in range(tile_w):
+                             for j in range(overlap):
+                                 mask.putpixel((i, j), int(255 * (j / overlap)))
+                     output.paste(gen_tile, paste_box, mask)
+                 else:
+                     output.paste(gen_tile, paste_box)
+             else:
+                 output.paste(gen_tile, (x, y))
+
+     return output
+
+
+ @spaces.GPU(duration=120)
+ def enhance_image(
+     image_input,
+     image_url,
+     seed,
+     randomize_seed,
+     num_inference_steps,
+     upscale_factor,
+     denoising_strength,
+     use_generated_caption,
+     custom_prompt,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     """Main enhancement function"""
+     # Move models to GPU inside the function
+     pipe.to("cuda")
+     florence_model.to("cuda")
+
+     # Handle image input
+     if image_input is not None:
+         input_image = image_input
+     elif image_url:
+         input_image = load_image_from_url(image_url)
+     else:
+         raise gr.Error("Please provide an image (upload or URL)")
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     true_input_image = input_image
+
+     # Process input image
+     input_image, w_original, h_original, was_resized = process_input(
+         input_image, upscale_factor
      )
+
+     # Generate caption if requested
+     if use_generated_caption:
+         gr.Info("🔍 Generating image caption...")
+         generated_caption = generate_caption(input_image)
+         prompt = generated_caption
+     else:
+         prompt = custom_prompt if custom_prompt.strip() else ""
+
+     generator = torch.Generator(device="cuda").manual_seed(seed)
+
+     gr.Info("🚀 Upscaling image...")
+
+     # Initial upscale
+     if USE_ESRGAN and upscale_factor == 4:
+         esrgan_model.to("cuda")
+         control_image = esrgan_upscale(input_image, upscale_factor)
+         esrgan_model.to("cpu")
+     else:
+         w, h = input_image.size
+         control_image = input_image.resize((w * upscale_factor, h * upscale_factor), resample=Image.LANCZOS)
+
+     # Tiled Flux Img2Img for refinement
+     image = tiled_flux_img2img(
+         pipe,
+         prompt,
+         control_image,
+         denoising_strength,
+         num_inference_steps,
+         1.0, # Hardcoded guidance_scale to 1
+         generator,
+         tile_size=1024,
+         overlap=32
      )
+
+     if was_resized:
+         gr.Info(f"📏 Resizing output to target size: {w_original * upscale_factor}x{h_original * upscale_factor}")
+         image = image.resize((w_original * upscale_factor, h_original * upscale_factor), resample=Image.LANCZOS)
+
+     # Resize input image to match output size for slider alignment
+     resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)

+     # Move back to CPU to release GPU
+     pipe.to("cpu")
+     florence_model.to("cpu")

+     return [resized_input, image]
+
+
+ # Create Gradio interface
+ with gr.Blocks(css=css, title="🎨 AI Image Upscaler - Florence-2 + FLUX") as demo:
+     gr.HTML("""
+     <div class="main-header">
+         <h1>🎨 AI Image Upscaler</h1>
+         <p>Upload an image or provide a URL to upscale it using Florence-2 captioning and FLUX upscaling</p>
+         <p>Currently running on <strong>{}</strong></p>
+     </div>
+     """.format(power_device))
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.HTML("<h3>📤 Input</h3>")

+             with gr.Tabs():
+                 with gr.TabItem("📁 Upload Image"):
+                     input_image = gr.Image(
+                         label="Upload Image",
+                         type="pil",
+                         height=200 # Made smaller
+                     )
+
+                 with gr.TabItem("🔗 Image URL"):
+                     image_url = gr.Textbox(
+                         label="Image URL",
+                         placeholder="https://example.com/image.jpg",
+                         value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
+                     )

+             gr.HTML("<h3>🎛️ Caption Settings</h3>")

+             use_generated_caption = gr.Checkbox(
+                 label="Use AI-generated caption (Florence-2)",
+                 value=True,
+                 info="Generate detailed caption automatically"
+             )

+             custom_prompt = gr.Textbox(
+                 label="Custom Prompt (optional)",
+                 placeholder="Enter custom prompt or leave empty for generated caption",
+                 lines=2
+             )

+             gr.HTML("<h3>⚙️ Upscaling Settings</h3>")

+             upscale_factor = gr.Slider(
+                 label="Upscale Factor",
+                 minimum=1,
+                 maximum=4,
+                 step=1,
+                 value=2,
+                 info="How much to upscale the image"
+             )
+
+             num_inference_steps = gr.Slider(
+                 label="Number of Inference Steps",
+                 minimum=8,
+                 maximum=50,
+                 step=1,
+                 value=25,
+                 info="More steps = better quality but slower"
+             )
+
+             denoising_strength = gr.Slider(
+                 label="Denoising Strength",
+                 minimum=0.0,
+                 maximum=1.0,
+                 step=0.05,
+                 value=0.3,
+                 info="Controls how much the image is transformed"
+             )
+
+             with gr.Row():
+                 randomize_seed = gr.Checkbox(
+                     label="Randomize seed",
+                     value=True
+                 )
+                 seed = gr.Slider(
+                     label="Seed",
+                     minimum=0,
+                     maximum=MAX_SEED,
+                     step=1,
+                     value=42,
+                     interactive=True
+                 )
+
+             enhance_btn = gr.Button(
+                 "🚀 Upscale Image",
+                 variant="primary",
+                 size="lg"
+             )
+
+         with gr.Column(scale=2): # Larger scale for results
+             gr.HTML("<h3>📊 Results</h3>")
+
+             result_slider = ImageSlider(
+                 type="pil",
+                 interactive=False, # Disable interactivity to prevent uploads
+                 height=600, # Made larger
+                 elem_id="result_slider",
+                 label=None # Remove default label
+             )
+
+     # Event handler
+     enhance_btn.click(
+         fn=enhance_image,
+         inputs=[
+             input_image,
+             image_url,
+             seed,
+             randomize_seed,
+             num_inference_steps,
+             upscale_factor,
+             denoising_strength,
+             use_generated_caption,
+             custom_prompt,
+         ],
+         outputs=[result_slider]
+     )

+     gr.HTML("""
+     <div style="margin-top: 2rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
+         <p><strong>Note:</strong> This upscaler uses the Flux dev model. Users are responsible for obtaining commercial rights if used commercially under their license.</p>
+     </div>
+     """)

+     # Custom CSS for slider
+     gr.HTML("""
+     <style>
+     #result_slider .slider {
+         width: 100% !important;
+         max-width: inherit !important;
+     }
+     #result_slider img {
+         object-fit: contain !important;
+         width: 100% !important;
+         height: auto !important;
+     }
+     #result_slider .gr-button-tool {
+         display: none !important;
+     }
+     #result_slider .gr-button-undo {
+         display: none !important;
+     }
+     #result_slider .gr-button-clear {
+         display: none !important;
+     }
+     #result_slider .badge-container .badge {
+         display: none !important;
+     }
+     #result_slider .badge-container::before {
+         content: "Before";
+         position: absolute;
+         top: 10px;
+         left: 10px;
+         background: rgba(0,0,0,0.5);
+         color: white;
+         padding: 5px;
+         border-radius: 5px;
+         z-index: 10;
+     }
+     #result_slider .badge-container::after {
+         content: "After";
+         position: absolute;
+         top: 10px;
+         right: 10px;
+         background: rgba(0,0,0,0.5);
+         color: white;
+         padding: 5px;
+         border-radius: 5px;
+         z-index: 10;
+     }
+     #result_slider .fullscreen img {
+         object-fit: contain !important;
+         width: 100vw !important;
+         height: 100vh !important;
+     }
+     </style>
+     """)
+
+     # JS to set slider default position to middle
+     gr.HTML("""
+     <script>
+     document.addEventListener('DOMContentLoaded', function() {
+         const sliderInput = document.querySelector('#result_slider input[type="range"]');
+         if (sliderInput) {
+             sliderInput.value = 50;
+             sliderInput.dispatchEvent(new Event('input'));
+         }
+     });
+     </script>
+     """)

+ if __name__ == "__main__":
+     demo.queue().launch(share=True, server_name="0.0.0.0", server_port=7860)