comrender committed · verified
Commit c1ad781 · Parent: 9ebfb7f

Create app.py

Files changed (1): app.py (+394, -0)
app.py ADDED
@@ -0,0 +1,394 @@
import os
import random
import sys
from typing import Sequence, Mapping, Any, Union, Optional
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces

# Download required models from Hugging Face
hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir="models/vae")
hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="flux1-dev.safetensors", local_dir="models/diffusion_models")
hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="clip_l.safetensors", local_dir="models/text_encoders")
hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp16.safetensors", local_dir="models/text_encoders")
hf_hub_download(repo_id="kim2091/UltraSharp", filename="4x-UltraSharp.pth", local_dir="models/upscale_models")

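# The downloads above land in the standard ComfyUI model folders:
# ae.safetensors is the FLUX autoencoder (VAE), flux1-dev.safetensors the
# diffusion transformer, clip_l/t5xxl the two FLUX text encoders, and
# 4x-UltraSharp.pth an ESRGAN-style model for the upscale loader used below.
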
def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Returns the value at the given index of a sequence or mapping."""
    try:
        return obj[index]
    except KeyError:
        return obj["result"][index]

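# Illustration (hypothetical values): most ComfyUI node methods return a tuple
# of outputs, but some return a dict that holds the tuple under "result";
# get_value_at_index hides that difference:
#
#     get_value_at_index((10, 20), 1)               # -> 20
#     get_value_at_index({"result": (10, 20)}, 1)   # -> 20
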
def find_path(name: str, path: Optional[str] = None) -> Optional[str]:
    """Recursively looks at parent folders starting from the given path until it finds the given name."""
    if path is None:
        path = os.getcwd()

    if name in os.listdir(path):
        path_name = os.path.join(path, name)
        print(f"{name} found: {path_name}")
        return path_name

    # Stop once the filesystem root is reached (the root is its own parent)
    parent_directory = os.path.dirname(path)
    if parent_directory == path:
        return None

    return find_path(name, parent_directory)

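# Example (hypothetical layout): with /srv/app/ComfyUI on disk,
# find_path("ComfyUI", "/srv/app/some/subdir") checks /srv/app/some/subdir,
# then /srv/app/some, then /srv/app, where it returns "/srv/app/ComfyUI";
# if nothing matches all the way up to the filesystem root, it returns None.
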
def add_comfyui_directory_to_sys_path() -> None:
    """Add 'ComfyUI' to the sys.path"""
    comfyui_path = find_path("ComfyUI")
    if comfyui_path is not None and os.path.isdir(comfyui_path):
        sys.path.append(comfyui_path)
        print(f"'{comfyui_path}' added to sys.path")

def add_extra_model_paths() -> None:
    """Parse the optional extra_model_paths.yaml file and register the extra model search paths it lists."""
    try:
        from main import load_extra_path_config
    except ImportError:
        print("Could not import load_extra_path_config from main.py. Looking in utils.extra_config instead.")
        from utils.extra_config import load_extra_path_config

    extra_model_paths = find_path("extra_model_paths.yaml")
    if extra_model_paths is not None:
        load_extra_path_config(extra_model_paths)
    else:
        print("Could not find the extra_model_paths config file.")

add_comfyui_directory_to_sys_path()
add_extra_model_paths()

def import_custom_nodes() -> None:
    """Find all custom nodes in the custom_nodes folder and add those node objects to NODE_CLASS_MAPPINGS"""
    import asyncio
    import execution
    from nodes import init_extra_nodes
    import server

    # Create a new event loop and set it as the default
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Create a PromptServer instance with a queue bound to it, then load nodes
    server_instance = server.PromptServer(loop)
    execution.PromptQueue(server_instance)
    init_extra_nodes()

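# Note: the PromptServer and PromptQueue built above are not used directly in
# this script; they exist because many custom node packs expect a server
# object (e.g. server.PromptServer.instance) to be present when
# init_extra_nodes() imports them.
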
from nodes import NODE_CLASS_MAPPINGS

# Pre-load models outside the decorated function for ZeroGPU efficiency
import_custom_nodes()

# Initialize model loaders
dualcliploader = NODE_CLASS_MAPPINGS["DualCLIPLoader"]()
dualcliploader_54 = dualcliploader.load_clip(
    clip_name1="clip_l.safetensors",
    clip_name2="t5xxl_fp16.safetensors",
    type="flux",
    device="default",
)

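# Naming note: the numeric suffixes on these variables (_54, _44, _55, _58,
# _52) follow the usual ComfyUI-export convention of reusing the node ids from
# the source workflow graph, which makes cross-referencing the workflow easy.
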
upscalemodelloader = NODE_CLASS_MAPPINGS["UpscaleModelLoader"]()
upscalemodelloader_44 = upscalemodelloader.load_model(model_name="4x-UltraSharp.pth")

vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
vaeloader_55 = vaeloader.load_vae(vae_name="ae.safetensors")

unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
unetloader_58 = unetloader.load_unet(
    unet_name="flux1-dev.safetensors", weight_dtype="default"
)

downloadandloadflorence2model = NODE_CLASS_MAPPINGS["DownloadAndLoadFlorence2Model"]()
downloadandloadflorence2model_52 = downloadandloadflorence2model.loadmodel(
    model="microsoft/Florence-2-large", precision="fp16", attention="sdpa"
)

# Pre-load models to GPU for efficiency
from comfy import model_management

model_loaders = [dualcliploader_54, vaeloader_55, unetloader_58, downloadandloadflorence2model_52]
valid_models = [
    getattr(loader[0], 'patcher', loader[0])
    for loader in model_loaders
    if not isinstance(loader[0], dict) and not isinstance(getattr(loader[0], 'patcher', None), dict)
]
model_management.load_models_gpu(valid_models)

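# What the comprehension above does: each loader result is a tuple whose first
# element is either a model patcher that load_models_gpu can take directly, an
# object that wraps one in a .patcher attribute, or (as with the Florence-2
# loader) a plain dict, which cannot be preloaded this way and is filtered out.
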
@spaces.GPU(duration=120)  # Adjust duration based on your workflow speed
def enhance_image(image_input, upscale_factor, steps, cfg_scale, denoise_strength, guidance_scale):
    """
    Main function to enhance and upscale images using Florence-2 captioning and FLUX upscaling
    """
    try:
        with torch.inference_mode():
            # Handle different input types (file upload vs URL)
            if isinstance(image_input, str) and image_input.startswith(('http://', 'https://')):
                # Load from URL
                load_image_from_url_mtb = NODE_CLASS_MAPPINGS["Load Image From Url (mtb)"]()
                load_image_result = load_image_from_url_mtb.load(url=image_input)
            else:
                # Load from uploaded file
                loadimage = NODE_CLASS_MAPPINGS["LoadImage"]()
                load_image_result = loadimage.load_image(image=image_input)

            # Generate a detailed caption using Florence-2
            florence2run = NODE_CLASS_MAPPINGS["Florence2Run"]()
            florence2run_51 = florence2run.encode(
                text_input="",
                task="more_detailed_caption",
                fill_mask=True,
                keep_model_loaded=False,
                max_new_tokens=1024,
                num_beams=3,
                do_sample=True,
                output_mask_select="",
                seed=random.randint(1, 2**64 - 1),  # stay within the unsigned 64-bit seed range
                image=get_value_at_index(load_image_result, 0),
                florence2_model=get_value_at_index(downloadandloadflorence2model_52, 0),
            )

            # Encode the generated caption as the positive prompt
            cliptextencode = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
            cliptextencode_6 = cliptextencode.encode(
                text=get_value_at_index(florence2run_51, 2),
                clip=get_value_at_index(dualcliploader_54, 0),
            )

            # Encode an empty negative prompt
            cliptextencode_42 = cliptextencode.encode(
                text="", clip=get_value_at_index(dualcliploader_54, 0)
            )

            # Set up the upscale factor
            primitivefloat = NODE_CLASS_MAPPINGS["PrimitiveFloat"]()
            primitivefloat_60 = primitivefloat.execute(value=upscale_factor)

            # Apply FLUX guidance to the positive conditioning
            fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
            fluxguidance_26 = fluxguidance.append(
                guidance=guidance_scale,
                conditioning=get_value_at_index(cliptextencode_6, 0)
            )

            # Perform tiled upscaling with UltimateSDUpscale
            ultimatesdupscale = NODE_CLASS_MAPPINGS["UltimateSDUpscale"]()
            ultimatesdupscale_50 = ultimatesdupscale.upscale(
                upscale_by=get_value_at_index(primitivefloat_60, 0),
                seed=random.randint(1, 2**64 - 1),  # stay within the unsigned 64-bit seed range
                steps=steps,
                cfg=cfg_scale,
                sampler_name="euler",
                scheduler="normal",
                denoise=denoise_strength,
                mode_type="Linear",
                tile_width=1024,
                tile_height=1024,
                mask_blur=8,
                tile_padding=32,
                seam_fix_mode="None",
                seam_fix_denoise=1,
                seam_fix_width=64,
                seam_fix_mask_blur=8,
                seam_fix_padding=16,
                force_uniform_tiles=True,
                tiled_decode=False,
                image=get_value_at_index(load_image_result, 0),
                model=get_value_at_index(unetloader_58, 0),
                positive=get_value_at_index(fluxguidance_26, 0),
                negative=get_value_at_index(cliptextencode_42, 0),
                vae=get_value_at_index(vaeloader_55, 0),
                upscale_model=get_value_at_index(upscalemodelloader_44, 0),
            )

            # Save the result
            saveimage = NODE_CLASS_MAPPINGS["SaveImage"]()
            saveimage_43 = saveimage.save_images(
                filename_prefix="enhanced_image",
                images=get_value_at_index(ultimatesdupscale_50, 0),
            )

            # Build the path to the saved image from the node's report
            saved_image = saveimage_43["ui"]["images"][0]
            saved_path = os.path.join("output", saved_image["subfolder"], saved_image["filename"])

            # Also return the generated caption for user feedback
            generated_caption = get_value_at_index(florence2run_51, 2)

            return saved_path, generated_caption

    except Exception as e:
        print(f"Error in enhance_image: {e}")
        raise gr.Error(f"Enhancement failed: {e}")

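# Minimal sketch (hypothetical input path) of driving the pipeline without the
# Gradio UI, e.g. for smoke-testing on a GPU machine:
#
#     path, caption = enhance_image("input/photo.png", 2.0, 25, 1.0, 0.3, 3.5)
#     print(caption)
#     print("saved to", path)
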
# Create the Gradio interface
def create_interface():
    with gr.Blocks(
        title="🚀 AI Image Enhancer - Florence-2 + FLUX",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        .main-header {
            text-align: center;
            margin-bottom: 2rem;
        }
        .result-gallery {
            min-height: 400px;
        }
        """
    ) as app:

        gr.HTML("""
        <div class="main-header">
            <h1>🎨 AI Image Enhancer</h1>
            <p>Upload an image or provide a URL to enhance it using Florence-2 captioning and FLUX upscaling</p>
        </div>
        """)

        with gr.Row():
            with gr.Column(scale=1):
                gr.HTML("<h3>📤 Input Settings</h3>")

                with gr.Tabs():
                    with gr.TabItem("📁 Upload Image"):
                        image_upload = gr.Image(
                            label="Upload Image",
                            type="filepath",
                            height=300
                        )

                    with gr.TabItem("🔗 Image URL"):
                        image_url = gr.Textbox(
                            label="Image URL",
                            placeholder="https://example.com/image.jpg",
                            value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
                        )

                gr.HTML("<h3>⚙️ Enhancement Settings</h3>")

                upscale_factor = gr.Slider(
                    minimum=1.0,
                    maximum=4.0,
                    value=2.0,
                    step=0.5,
                    label="Upscale Factor",
                    info="How much to upscale the image"
                )

                steps = gr.Slider(
                    minimum=10,
                    maximum=50,
                    value=25,
                    step=5,
                    label="Steps",
                    info="Number of denoising steps"
                )

                cfg_scale = gr.Slider(
                    minimum=0.5,
                    maximum=10.0,
                    value=1.0,
                    step=0.5,
                    label="CFG Scale",
                    info="Classifier-free guidance scale"
                )

                denoise_strength = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                    label="Denoise Strength",
                    info="How much to denoise the image"
                )

                guidance_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=3.5,
                    step=0.5,
                    label="Guidance Scale",
                    info="FLUX guidance strength"
                )

                enhance_btn = gr.Button(
                    "🚀 Enhance Image",
                    variant="primary",
                    size="lg"
                )

            with gr.Column(scale=1):
                gr.HTML("<h3>📊 Results</h3>")

                output_image = gr.Image(
                    label="Enhanced Image",
                    type="filepath",
                    height=400,
                    interactive=False
                )

                generated_caption = gr.Textbox(
                    label="Generated Caption",
                    placeholder="The AI-generated caption will appear here...",
                    lines=3,
                    interactive=False
                )

                gr.HTML("""
                <div style="margin-top: 1rem; padding: 1rem; background: #f0f0f0; border-radius: 8px;">
                    <h4>💡 How it works:</h4>
                    <ol>
                        <li>Florence-2 analyzes your image and generates a detailed caption</li>
                        <li>FLUX uses this caption to guide the upscaling process</li>
                        <li>The result is an enhanced, higher-resolution image</li>
                    </ol>
                </div>
                """)

        # Event handlers
        def process_image(img_upload, img_url, upscale_f, steps_val, cfg_val, denoise_val, guidance_val):
            # Determine the input source: an uploaded file takes precedence over the URL
            image_input = img_upload if img_upload is not None else img_url

            if not image_input:
                raise gr.Error("Please provide an image (upload or URL)")

            return enhance_image(image_input, upscale_f, steps_val, cfg_val, denoise_val, guidance_val)

        enhance_btn.click(
            fn=process_image,
            inputs=[
                image_upload,
                image_url,
                upscale_factor,
                steps,
                cfg_scale,
                denoise_strength,
                guidance_scale
            ],
            outputs=[output_image, generated_caption]
        )

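        # The example rows below only prefill the input widgets; since
        # gr.Examples is given no fn, clicking a row does not start a run.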
        # Example inputs
        gr.Examples(
            examples=[
                [None, "https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg", 2.0, 25, 1.0, 0.3, 3.5],
                [None, "https://picsum.photos/512/512", 2.0, 20, 1.5, 0.4, 4.0],
            ],
            inputs=[
                image_upload,
                image_url,
                upscale_factor,
                steps,
                cfg_scale,
                denoise_strength,
                guidance_scale
            ]
        )

    return app

if __name__ == "__main__":
    app = create_interface()
    # share=True only matters for local runs; on Hugging Face Spaces the
    # platform serves the app and the share link is unnecessary.
    app.launch(share=True, server_name="0.0.0.0", server_port=7860)