Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,6 @@ import numpy as np
|
|
7 |
import spaces
|
8 |
import torch
|
9 |
from diffusers import FluxImg2ImgPipeline
|
10 |
-
from transformers import AutoProcessor, AutoModelForCausalLM
|
11 |
from gradio_imageslider import ImageSlider
|
12 |
from PIL import Image
|
13 |
from huggingface_hub import snapshot_download
|
@@ -40,82 +39,10 @@ device = "cpu"
|
|
40 |
# Get HuggingFace token
|
41 |
huggingface_token = os.getenv("HF_TOKEN")
|
42 |
|
43 |
-
# Download FLUX model
|
44 |
-
print("π₯ Downloading FLUX model...")
|
45 |
-
model_path = snapshot_download(
|
46 |
-
repo_id="black-forest-labs/FLUX.1-dev",
|
47 |
-
repo_type="model",
|
48 |
-
ignore_patterns=["*.md", "*.gitattributes"],
|
49 |
-
local_dir="FLUX.1-dev",
|
50 |
-
token=huggingface_token,
|
51 |
-
)
|
52 |
-
|
53 |
-
# Load Florence-2 model for image captioning on CPU
|
54 |
-
print("π₯ Loading Florence-2 model...")
|
55 |
-
florence_model = AutoModelForCausalLM.from_pretrained(
|
56 |
-
"microsoft/Florence-2-large",
|
57 |
-
torch_dtype=torch.float32, # Force CPU dtype
|
58 |
-
trust_remote_code=True,
|
59 |
-
attn_implementation="eager"
|
60 |
-
).to(device)
|
61 |
-
florence_processor = AutoProcessor.from_pretrained(
|
62 |
-
"microsoft/Florence-2-large",
|
63 |
-
trust_remote_code=True
|
64 |
-
)
|
65 |
-
|
66 |
-
# Load FLUX Img2Img pipeline on CPU
|
67 |
-
print("π₯ Loading FLUX Img2Img...")
|
68 |
-
pipe = FluxImg2ImgPipeline.from_pretrained(
|
69 |
-
model_path,
|
70 |
-
torch_dtype=torch.float32 # Force CPU dtype
|
71 |
-
)
|
72 |
-
pipe.enable_vae_tiling()
|
73 |
-
pipe.enable_vae_slicing()
|
74 |
-
|
75 |
-
print("β
All models loaded successfully!")
|
76 |
-
|
77 |
-
# Download ESRGAN model if using
|
78 |
-
if USE_ESRGAN:
|
79 |
-
esrgan_path = "4x-UltraSharp.pth"
|
80 |
-
if not os.path.exists(esrgan_path):
|
81 |
-
url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
|
82 |
-
with open(esrgan_path, "wb") as f:
|
83 |
-
f.write(requests.get(url).content)
|
84 |
-
esrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
85 |
-
state_dict = torch.load(esrgan_path)['params_ema']
|
86 |
-
esrgan_model.load_state_dict(state_dict)
|
87 |
-
esrgan_model.eval()
|
88 |
-
|
89 |
MAX_SEED = 1000000
|
90 |
MAX_PIXEL_BUDGET = 8192 * 8192 # Increased for tiling support
|
91 |
|
92 |
|
93 |
-
def generate_caption(image):
|
94 |
-
"""Generate detailed caption using Florence-2"""
|
95 |
-
try:
|
96 |
-
task_prompt = "<MORE_DETAILED_CAPTION>"
|
97 |
-
prompt = task_prompt
|
98 |
-
|
99 |
-
inputs = florence_processor(text=prompt, images=image, return_tensors="pt").to(device)
|
100 |
-
|
101 |
-
generated_ids = florence_model.generate(
|
102 |
-
input_ids=inputs["input_ids"],
|
103 |
-
pixel_values=inputs["pixel_values"],
|
104 |
-
max_new_tokens=1024,
|
105 |
-
num_beams=3,
|
106 |
-
do_sample=True,
|
107 |
-
)
|
108 |
-
|
109 |
-
generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
|
110 |
-
parsed_answer = florence_processor.post_process_generation(generated_text, task=task_prompt, image_size=(image.width, image.height))
|
111 |
-
|
112 |
-
caption = parsed_answer[task_prompt]
|
113 |
-
return caption
|
114 |
-
except Exception as e:
|
115 |
-
print(f"Caption generation failed: {e}")
|
116 |
-
return "a high quality detailed image"
|
117 |
-
|
118 |
-
|
119 |
def process_input(input_image, upscale_factor):
|
120 |
"""Process input image and handle size constraints"""
|
121 |
w, h = input_image.size
|
@@ -216,21 +143,54 @@ def enhance_image(
|
|
216 |
num_inference_steps,
|
217 |
upscale_factor,
|
218 |
denoising_strength,
|
219 |
-
use_generated_caption,
|
220 |
custom_prompt,
|
221 |
progress=gr.Progress(track_tqdm=True),
|
222 |
):
|
223 |
"""Main enhancement function"""
|
224 |
-
#
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
|
235 |
# Handle image input
|
236 |
if image_input is not None:
|
@@ -250,13 +210,7 @@ def enhance_image(
|
|
250 |
input_image, upscale_factor
|
251 |
)
|
252 |
|
253 |
-
|
254 |
-
if use_generated_caption:
|
255 |
-
gr.Info("π Generating image caption...")
|
256 |
-
generated_caption = generate_caption(input_image)
|
257 |
-
prompt = generated_caption
|
258 |
-
else:
|
259 |
-
prompt = custom_prompt if custom_prompt.strip() else ""
|
260 |
|
261 |
generator = torch.Generator(device=device).manual_seed(seed)
|
262 |
|
@@ -289,21 +243,21 @@ def enhance_image(
|
|
289 |
# Resize input image to match output size for slider alignment
|
290 |
resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
|
291 |
|
292 |
-
# Move back to CPU to release GPU
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
|
298 |
return [resized_input, image]
|
299 |
|
300 |
|
301 |
# Create Gradio interface
|
302 |
-
with gr.Blocks(css=css, title="π¨ AI Image Upscaler -
|
303 |
gr.HTML("""
|
304 |
<div class="main-header">
|
305 |
<h1>π¨ AI Image Upscaler</h1>
|
306 |
-
<p>Upload an image or provide a URL to upscale it using
|
307 |
<p>Currently running on <strong>{}</strong></p>
|
308 |
</div>
|
309 |
""".format(power_device))
|
@@ -327,17 +281,11 @@ with gr.Blocks(css=css, title="π¨ AI Image Upscaler - Florence-2 + FLUX") as d
|
|
327 |
value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
|
328 |
)
|
329 |
|
330 |
-
gr.HTML("<h3>ποΈ
|
331 |
-
|
332 |
-
use_generated_caption = gr.Checkbox(
|
333 |
-
label="Use AI-generated caption (Florence-2)",
|
334 |
-
value=True,
|
335 |
-
info="Generate detailed caption automatically"
|
336 |
-
)
|
337 |
|
338 |
custom_prompt = gr.Textbox(
|
339 |
label="Custom Prompt (optional)",
|
340 |
-
placeholder="Enter custom prompt or leave empty
|
341 |
lines=2
|
342 |
)
|
343 |
|
@@ -412,7 +360,6 @@ with gr.Blocks(css=css, title="π¨ AI Image Upscaler - Florence-2 + FLUX") as d
|
|
412 |
num_inference_steps,
|
413 |
upscale_factor,
|
414 |
denoising_strength,
|
415 |
-
use_generated_caption,
|
416 |
custom_prompt,
|
417 |
],
|
418 |
outputs=[result_slider]
|
|
|
7 |
import spaces
|
8 |
import torch
|
9 |
from diffusers import FluxImg2ImgPipeline
|
|
|
10 |
from gradio_imageslider import ImageSlider
|
11 |
from PIL import Image
|
12 |
from huggingface_hub import snapshot_download
|
|
|
39 |
# Get HuggingFace token
|
40 |
huggingface_token = os.getenv("HF_TOKEN")
|
41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
MAX_SEED = 1000000
|
43 |
MAX_PIXEL_BUDGET = 8192 * 8192 # Increased for tiling support
|
44 |
|
45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
def process_input(input_image, upscale_factor):
|
47 |
"""Process input image and handle size constraints"""
|
48 |
w, h = input_image.size
|
|
|
143 |
num_inference_steps,
|
144 |
upscale_factor,
|
145 |
denoising_strength,
|
|
|
146 |
custom_prompt,
|
147 |
progress=gr.Progress(track_tqdm=True),
|
148 |
):
|
149 |
"""Main enhancement function"""
|
150 |
+
# Lazy loading of models
|
151 |
+
global pipe, esrgan_model
|
152 |
+
if 'pipe' not in globals():
|
153 |
+
try:
|
154 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
155 |
+
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
156 |
+
|
157 |
+
print(f"π₯ Loading FLUX Img2Img on {device}...")
|
158 |
+
pipe = FluxImg2ImgPipeline.from_pretrained(
|
159 |
+
"black-forest-labs/FLUX.1-dev",
|
160 |
+
torch_dtype=dtype,
|
161 |
+
low_cpu_mem_usage=True,
|
162 |
+
device_map="auto"
|
163 |
+
)
|
164 |
+
pipe.enable_vae_tiling()
|
165 |
+
pipe.enable_vae_slicing()
|
166 |
+
pipe.enable_model_cpu_offload() if device == "cuda" else None
|
167 |
+
|
168 |
+
if USE_ESRGAN:
|
169 |
+
esrgan_path = "4x-UltraSharp.pth"
|
170 |
+
if not os.path.exists(esrgan_path):
|
171 |
+
url = "https://huggingface.co/uwg/upscaler/resolve/main/ESRGAN/4x-UltraSharp.pth"
|
172 |
+
with open(esrgan_path, "wb") as f:
|
173 |
+
f.write(requests.get(url).content)
|
174 |
+
esrgan_model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
|
175 |
+
state_dict = torch.load(esrgan_path)['params_ema']
|
176 |
+
esrgan_model.load_state_dict(state_dict)
|
177 |
+
esrgan_model.eval()
|
178 |
+
esrgan_model.to(device)
|
179 |
+
|
180 |
+
print("β
Models loaded successfully!")
|
181 |
+
except Exception as e:
|
182 |
+
print(f"Model loading error: {e}, falling back to CPU")
|
183 |
+
device = "cpu"
|
184 |
+
dtype = torch.float32
|
185 |
+
# Reload on CPU if needed
|
186 |
+
pipe = FluxImg2ImgPipeline.from_pretrained(
|
187 |
+
"black-forest-labs/FLUX.1-dev",
|
188 |
+
torch_dtype=dtype,
|
189 |
+
low_cpu_mem_usage=True,
|
190 |
+
device_map="auto"
|
191 |
+
)
|
192 |
+
pipe.enable_vae_tiling()
|
193 |
+
pipe.enable_vae_slicing()
|
194 |
|
195 |
# Handle image input
|
196 |
if image_input is not None:
|
|
|
210 |
input_image, upscale_factor
|
211 |
)
|
212 |
|
213 |
+
prompt = custom_prompt if custom_prompt.strip() else ""
|
|
|
|
|
|
|
|
|
|
|
|
|
214 |
|
215 |
generator = torch.Generator(device=device).manual_seed(seed)
|
216 |
|
|
|
243 |
# Resize input image to match output size for slider alignment
|
244 |
resized_input = true_input_image.resize(image.size, resample=Image.LANCZOS)
|
245 |
|
246 |
+
# Move back to CPU to release GPU if possible
|
247 |
+
if device == "cuda":
|
248 |
+
pipe.to("cpu")
|
249 |
+
if USE_ESRGAN:
|
250 |
+
esrgan_model.to("cpu")
|
251 |
|
252 |
return [resized_input, image]
|
253 |
|
254 |
|
255 |
# Create Gradio interface
|
256 |
+
with gr.Blocks(css=css, title="π¨ AI Image Upscaler - FLUX") as demo:
|
257 |
gr.HTML("""
|
258 |
<div class="main-header">
|
259 |
<h1>π¨ AI Image Upscaler</h1>
|
260 |
+
<p>Upload an image or provide a URL to upscale it using FLUX upscaling</p>
|
261 |
<p>Currently running on <strong>{}</strong></p>
|
262 |
</div>
|
263 |
""".format(power_device))
|
|
|
281 |
value="https://upload.wikimedia.org/wikipedia/commons/thumb/a/a7/Example.jpg/800px-Example.jpg"
|
282 |
)
|
283 |
|
284 |
+
gr.HTML("<h3>ποΈ Prompt Settings</h3>")
|
|
|
|
|
|
|
|
|
|
|
|
|
285 |
|
286 |
custom_prompt = gr.Textbox(
|
287 |
label="Custom Prompt (optional)",
|
288 |
+
placeholder="Enter custom prompt or leave empty",
|
289 |
lines=2
|
290 |
)
|
291 |
|
|
|
360 |
num_inference_steps,
|
361 |
upscale_factor,
|
362 |
denoising_strength,
|
|
|
363 |
custom_prompt,
|
364 |
],
|
365 |
outputs=[result_slider]
|