Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,203 +1,364 @@
|
|
1 |
-
#
|
2 |
-
|
3 |
import os
|
4 |
import sys
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
#
|
11 |
-
import logging
|
12 |
-
import random
|
13 |
-
import warnings
|
14 |
import io
|
15 |
import base64
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
|
|
17 |
import gradio as gr
|
18 |
-
import numpy as np
|
19 |
import spaces
|
20 |
-
import
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
}
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
}
|
40 |
-
""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
else:
|
47 |
-
power_device = "CPU"
|
48 |
-
device = "cpu"
|
49 |
-
torch_dtype = torch.float32
|
50 |
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
-
#
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
pipe = None
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
try:
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
logging.info("ControlNet loaded.")
|
75 |
-
|
76 |
-
logging.info("Initializing FluxControlNetPipeline…")
|
77 |
-
pipe = FluxControlNetPipeline.from_pretrained(
|
78 |
-
model_path,
|
79 |
-
controlnet=controlnet,
|
80 |
-
torch_dtype=torch_dtype
|
81 |
-
).to(device)
|
82 |
-
logging.info("Pipeline ready.")
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
|
|
88 |
|
89 |
-
#
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
else:
|
111 |
-
was_resized = False
|
112 |
-
|
113 |
-
# round dimensions to multiples of 8
|
114 |
-
w2, h2 = img.size
|
115 |
-
w2 -= w2 % 8; h2 -= h2 % 8
|
116 |
-
if img.size != (w2,h2):
|
117 |
-
img = img.resize((w2,h2), Image.Resampling.LANCZOS)
|
118 |
-
|
119 |
-
return img, w, h, was_resized
|
120 |
-
|
121 |
-
@spaces.GPU(duration=75)
|
122 |
-
def infer(
|
123 |
-
seed,
|
124 |
-
randomize_seed,
|
125 |
-
input_image,
|
126 |
-
num_inference_steps,
|
127 |
-
final_upscale_factor,
|
128 |
-
controlnet_conditioning_scale,
|
129 |
-
progress=gr.Progress(track_tqdm=True),
|
130 |
-
):
|
131 |
-
global pipe
|
132 |
-
if pipe is None:
|
133 |
-
raise gr.Error("Pipeline not loaded.")
|
134 |
-
|
135 |
-
if randomize_seed:
|
136 |
-
seed = random.randint(0, MAX_SEED)
|
137 |
-
seed = int(seed)
|
138 |
-
final_upscale_factor = int(final_upscale_factor)
|
139 |
-
|
140 |
-
processed, w0, h0, resized_flag = process_input(input_image)
|
141 |
-
w_proc, h_proc = processed.size
|
142 |
-
|
143 |
-
# prepare control image at INTERNAL scale
|
144 |
-
cw, ch = w_proc*INTERNAL_PROCESSING_FACTOR, h_proc*INTERNAL_PROCESSING_FACTOR
|
145 |
-
control_img = processed.resize((cw, ch), Image.Resampling.LANCZOS)
|
146 |
-
|
147 |
-
gen = torch.Generator(device=device).manual_seed(seed)
|
148 |
-
with torch.inference_mode():
|
149 |
-
result = pipe(
|
150 |
-
prompt="",
|
151 |
-
control_image=control_img,
|
152 |
-
controlnet_conditioning_scale=float(controlnet_conditioning_scale),
|
153 |
-
num_inference_steps=int(num_inference_steps),
|
154 |
-
guidance_scale=0.0,
|
155 |
-
height=ch, width=cw,
|
156 |
-
generator=gen
|
157 |
-
).images[0]
|
158 |
-
|
159 |
-
# final resize to user factor
|
160 |
-
if resized_flag:
|
161 |
-
fw, fh = w_proc*final_upscale_factor, h_proc*final_upscale_factor
|
162 |
-
else:
|
163 |
-
fw, fh = w0*final_upscale_factor, h0*final_upscale_factor
|
164 |
-
if (fw, fh) != result.size:
|
165 |
-
result = result.resize((fw, fh), Image.Resampling.LANCZOS)
|
166 |
-
|
167 |
-
buf = io.BytesIO()
|
168 |
-
result.save(buf, format="WEBP", quality=90)
|
169 |
-
b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
|
170 |
-
|
171 |
-
return [[input_image, result], seed, f"data:image/webp;base64,{b64}"]
|
172 |
-
|
173 |
-
# --- Gradio UI ---
|
174 |
-
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as demo:
|
175 |
-
gr.Markdown(f"""
|
176 |
-
# ⚡ Flux.1‑dev Upscaler
|
177 |
-
**Device:** {power_device} · **Internal scale:** {INTERNAL_PROCESSING_FACTOR}x · **Budget:** {MAX_PIXEL_BUDGET} px
|
178 |
-
""")
|
179 |
-
with gr.Row():
|
180 |
-
with gr.Column(scale=2):
|
181 |
-
inp = gr.Image(label="Input Image", type="pil", sources=["upload","clipboard"], height=350)
|
182 |
-
with gr.Column(scale=1):
|
183 |
-
upf = gr.Slider("Final Upscale Factor", 1, INTERNAL_PROCESSING_FACTOR, step=1, value=2)
|
184 |
-
steps = gr.Slider("Inference Steps", 4, 50, step=1, value=15)
|
185 |
-
cscale= gr.Slider("ControlNet Scale", 0.0, 1.5, step=0.05, value=0.6)
|
186 |
-
with gr.Row():
|
187 |
-
sld = gr.Slider("Seed", 0, MAX_SEED, step=1, value=42)
|
188 |
-
rnd = gr.Checkbox("Randomize", value=True, scale=0, min_width=80)
|
189 |
-
btn = gr.Button("⚡ Upscale Image", variant="primary")
|
190 |
-
|
191 |
-
slider = ImageSlider("Input / Output", type="pil", interactive=False, show_label=True, position=0.5)
|
192 |
-
out_seed= gr.Textbox("Seed Used", interactive=False, visible=True)
|
193 |
-
out_b64 = gr.Textbox("API Base64 Output", interactive=False, visible=False)
|
194 |
-
|
195 |
-
btn.click(
|
196 |
-
fn=infer,
|
197 |
-
inputs=[sld, rnd, inp, steps, upf, cscale],
|
198 |
-
outputs=[slider, out_seed, out_b64],
|
199 |
-
api_name="upscale"
|
200 |
)
|
201 |
|
202 |
-
#
|
203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --- Imports ---
|
|
|
2 |
import os
|
3 |
import sys
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
# Delay Gradio import until after installation
|
7 |
+
# import gradio as gr
|
8 |
+
import numpy as np
|
9 |
+
from PIL import Image, ImageOps # <--- IMPORT ImageOps
|
|
|
|
|
|
|
10 |
import io
|
11 |
import base64
|
12 |
+
import traceback
|
13 |
+
# Delay spaces import until after installation
|
14 |
+
# import spaces
|
15 |
+
|
16 |
+
# --- Dependency Management ---
# Force-install the specific Gradio / huggingface-hub versions this app was
# written against. This must run BEFORE `gradio` and `spaces` are imported.
import subprocess  # imported here because it is only needed by this bootstrap step

print("Installing specific Gradio and huggingface-hub versions...")
for pip_args in (
    # Uninstall everything first so the pinned versions below win.
    ["uninstall", "-y", "gradio", "gradio-client", "huggingface-hub"],
    ["install", "gradio==4.13.0", "gradio-client==0.8.0"],
    ["install", "huggingface-hub==0.19.4"],
):
    # Use the running interpreter's pip (not whatever `pip` is first on PATH)
    # and an argument list instead of a shell string.
    completed = subprocess.run([sys.executable, "-m", "pip", *pip_args], check=False)
    if completed.returncode != 0:
        # Best effort: report the failure but keep going, as the original did.
        print(f"Warning: 'pip {' '.join(pip_args)}' exited with status {completed.returncode}")
print("Dependency installation complete.")
|
23 |
|
24 |
+
# Now import Gradio and spaces
|
25 |
import gradio as gr
|
|
|
26 |
import spaces
|
27 |
+
# from huggingface_hub import HfFolder # Example import from hub if needed
|
28 |
+
|
29 |
+
# Check installed versions (optional but helpful for debugging).
# Uses the standard-library importlib.metadata instead of the deprecated
# pkg_resources API. Any failure here is non-fatal and only logged.
try:
    from importlib import metadata as importlib_metadata

    gradio_version = importlib_metadata.version("gradio")
    hub_version = importlib_metadata.version("huggingface-hub")
    print(f"Using Gradio version: {gradio_version}")
    print(f"Using huggingface-hub version: {hub_version}")
except Exception as e:
    print(f"Could not check package versions: {e}")
|
38 |
+
|
39 |
+
|
40 |
+
# Import model-specific libraries AFTER installing dependencies
|
41 |
+
try:
|
42 |
+
from basicsr.archs.srvgg_arch import SRVGGNetCompact
|
43 |
+
from gfpgan.utils import GFPGANer
|
44 |
+
from realesrgan.utils import RealESRGANer
|
45 |
+
print("Successfully imported model libraries.")
|
46 |
+
except ImportError as e:
|
47 |
+
print(f"Error importing model libraries: {e}")
|
48 |
+
print("Please ensure basicsr, gfpgan, realesrgan are installed or in requirements.txt")
|
49 |
+
sys.exit(1)
|
50 |
+
|
51 |
+
# --- Constants ---
OUTPUT_DIR = 'output'  # This might not be strictly needed if not saving via gr.File
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Model Weight Downloads ---
# Maps each local weight filename (stored in the working directory) to its
# download URL.
MODEL_FILES = {
    'realesr-general-x4v3.pth': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
    'GFPGANv1.2.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth',
    'GFPGANv1.3.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
    'GFPGANv1.4.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
    'RestoreFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
}
print("Downloading model weights...")
for filename, url in MODEL_FILES.items():
    if os.path.exists(filename):
        continue  # already present from a previous run
    print(f"Downloading {filename}...")
    try:
        # Same downloader the sample-image step uses. Unlike a shelled-out
        # wget it raises on failure, so the except clause below can actually
        # report the problem.
        torch.hub.download_url_to_file(url, filename, progress=False)
    except Exception as e:
        # Non-fatal for the optional GFPGAN variants; the mandatory
        # RealESRGAN weight is checked explicitly below.
        print(f"Error downloading {filename}: {e}")
if not os.path.exists('realesr-general-x4v3.pth'):
    print("FATAL: RealESRGAN model (realesr-general-x4v3.pth) not found after download attempt. Cannot proceed.")
    sys.exit(1)
print("Model weight download check complete.")
|
75 |
+
|
76 |
+
# --- Sample Image Downloads ---
# Example inputs offered in the UI, keyed by local filename. A failed download
# is only a warning.
SAMPLE_IMAGES = {
    'lincoln.jpg': 'https://upload.wikimedia.org/wikipedia/commons/thumb/a/ab/Abraham_Lincoln_O-77_matte_collodion_print.jpg/1024px-Abraham_Lincoln_O-77_matte_collodion_print.jpg',
    'AI-generate.jpg': 'https://user-images.githubusercontent.com/17445847/187400315-87a90ac9-d231-45d6-b377-38702bd1838f.jpg',
    'Blake_Lively.jpg': 'https://user-images.githubusercontent.com/17445847/187400981-8a58f7a4-ef61-42d9-af80-bc6234cef860.jpg',
    '10045.png': 'https://user-images.githubusercontent.com/17445847/187401133-8a3bf269-5b4d-4432-b2f0-6d26ee1d3307.png'
}
print("Downloading sample images...")
for filename, url in SAMPLE_IMAGES.items():
    if os.path.exists(filename):
        continue  # nothing to fetch
    try:
        torch.hub.download_url_to_file(url, filename, progress=False)
    except Exception as download_err:
        print(f"Warning: Error downloading sample image {filename}: {download_err}")
print("Sample image download check complete.")
|
91 |
+
|
92 |
+
|
93 |
+
# --- Model Initialization (Background Enhancer) ---
# `upsampler` stays None when RealESRGAN cannot be created; `inference` then
# passes None to GFPGAN as its background upsampler.
upsampler = None
try:
    print("Initializing RealESRGAN upsampler...")
    model = SRVGGNetCompact(
        num_in_ch=3,
        num_out_ch=3,
        num_feat=64,
        num_conv=32,
        upscale=4,
        act_type='prelu',
    )
    model_path = 'realesr-general-x4v3.pth'
    # Half precision is requested exactly when CUDA reports as available.
    half = torch.cuda.is_available()
    print(f"CUDA available: {half}, Using half precision: {half}")
    upsampler = RealESRGANer(
        scale=4,
        model_path=model_path,
        model=model,
        tile=0,
        tile_pad=10,
        pre_pad=0,
        half=half,
    )
    print("Successfully created RealESRGAN upsampler.")
except Exception as init_err:
    print(f"Error creating RealESRGAN upsampler: {init_err}")
    print(traceback.format_exc())
    print("Warning: GFPGAN will run without background enhancement.")
|
107 |
+
|
108 |
+
|
109 |
+
# --- Inference Function (Handles API, PIL, Base64, EXIF) ---

# GFPGAN variant name -> (weight file in the working directory, GFPGANer arch).
_GFPGAN_VARIANTS = {
    'v1.2': ('GFPGANv1.2.pth', 'clean'),
    'v1.3': ('GFPGANv1.3.pth', 'clean'),
    'v1.4': ('GFPGANv1.4.pth', 'clean'),
    'RestoreFormer': ('RestoreFormer.pth', 'RestoreFormer'),
}


def _encode_webp_data_uri(image, quality):
    """Encode a PIL image as WEBP and return it as a base64 ``data:`` URI string."""
    buffered = io.BytesIO()
    image.save(buffered, format="WEBP", quality=quality)
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/webp;base64,{img_str}"


@spaces.GPU(duration=90)  # Keep ZeroGPU decorator
def inference(input_image_filepath, version, scale):
    """
    Restore faces in an image file with GFPGAN and rescale the result.

    The image is loaded from disk, EXIF-transposed, converted to RGB/BGR,
    enhanced by GFPGAN (which itself runs at upscale=2), and finally resized
    to ``scale`` times the EXIF-corrected input size when that differs from
    the GFPGAN output by more than 2 pixels in either dimension.

    Args:
        input_image_filepath (str): Path to the input image file (Gradio
            passes a temporary file path for ``type="filepath"`` inputs).
        version (str): GFPGAN model version ('v1.2', 'v1.3', 'v1.4', 'RestoreFormer').
        scale (float): Rescaling factor for the final output relative to the original.

    Returns:
        tuple: ``(image, text)`` where
            - on success, ``image`` is the RGB PIL output and ``text`` is a
              ``data:image/webp;base64,...`` URI of that image;
            - if the input cannot be loaded or converted, the version is
              unknown, the weight file is missing, or GFPGAN returns no
              image, ``image`` is None and ``text`` is an error message;
            - if the final OpenCV resize fails, ``image`` is the un-resized
              GFPGAN output and ``text`` is its data URI (or an error message
              if encoding fails too);
            - if base64 encoding of the final image fails, ``image`` is the
              output image and ``text`` is an error message;
            - on any other exception during processing, ``image`` is a small
              red placeholder (or None if even that fails) and ``text`` is an
              error message.
    """
    # --- Load image from filepath ---
    try:
        print(f"Loading image from filepath: {input_image_filepath}")
        if not input_image_filepath or not isinstance(input_image_filepath, str) or not os.path.exists(input_image_filepath):
            error_msg = f"Error: Input image filepath is invalid or file does not exist: '{input_image_filepath}'"
            print(error_msg)
            # None for the Image output, error message for the Textbox output
            return None, error_msg
        input_pil_image = Image.open(input_image_filepath)
        print(f"Successfully loaded image. Initial mode: {input_pil_image.mode}, size: {input_pil_image.size}")
    except Exception as load_err:
        error_msg = f"Error loading image from filepath {input_image_filepath}: {load_err}"
        print(error_msg)
        print(traceback.format_exc())
        return None, error_msg

    print(f"Processing image with GFPGAN version: {version}, scale: {scale}")

    # --- Handle EXIF Orientation (best effort; a failure only warns) ---
    original_size_before_exif = input_pil_image.size
    try:
        input_pil_image = ImageOps.exif_transpose(input_pil_image)
        if input_pil_image.size != original_size_before_exif:
            print(f"Image size changed by EXIF transpose: {original_size_before_exif} -> {input_pil_image.size}")
    except Exception as exif_err:
        print(f"Warning: Could not apply EXIF transpose: {exif_err}")

    w_orig, h_orig = input_pil_image.size
    print(f"Input size for processing (WxH): {w_orig}x{h_orig}")

    # --- Convert PIL Image to OpenCV format (BGR numpy array) ---
    try:
        img_mode = input_pil_image.mode
        if img_mode != 'RGB':
            # RGBA is converted silently; every other mode is logged first.
            if img_mode != 'RGBA':
                print(f"Converting input image from {img_mode} to RGB")
            input_pil_image = input_pil_image.convert('RGB')
        # Reverse the channel axis (RGB -> BGR) and copy so the array is contiguous.
        img_bgr = np.array(input_pil_image)[:, :, ::-1].copy()
    except Exception as conversion_err:
        error_msg = f"Error converting PIL image to OpenCV format: {conversion_err}"
        print(error_msg)
        print(traceback.format_exc())
        return None, error_msg

    # --- Start GFPGAN Processing ---
    try:
        h, w = img_bgr.shape[0:2]
        if h > 4000 or w > 4000:
            print(f'Warning: Image size ({w}x{h}) is very large, processing might be slow or fail.')

        if version not in _GFPGAN_VARIANTS:
            error_msg = f"Error: Unknown version selected: {version}"
            print(error_msg)
            return None, error_msg
        model_path, arch = _GFPGAN_VARIANTS[version]
        if not os.path.exists(model_path):
            error_msg = f"Error: Model file not found for version {version}: {model_path}"
            print(error_msg)
            return None, error_msg

        # Parse the scale BEFORE running the model so an unusable value
        # (None, text, NaN, inf) fails fast through the handler below instead
        # of after the expensive GPU enhancement has already run.
        target_scale_factor = float(scale)
        target_w = int(w_orig * target_scale_factor)
        target_h = int(h_orig * target_scale_factor)

        current_bg_upsampler = upsampler
        if not current_bg_upsampler:
            print("Warning: RealESRGAN upsampler not available. Background enhancement disabled.")

        face_enhancer = GFPGANer(
            model_path=model_path, upscale=2, arch=arch,
            channel_multiplier=2, bg_upsampler=current_bg_upsampler
        )

        print(f"Running GFPGAN enhancement with {version}...")
        _, _, output_bgr = face_enhancer.enhance(
            img_bgr, has_aligned=False, only_center_face=False, paste_back=True
        )
        if output_bgr is None:
            error_msg = "Error: GFPGAN enhancement returned None."
            print(error_msg)
            return None, error_msg
        print(f"Enhancement complete. Intermediate output shape (HxWxC BGR): {output_bgr.shape}")

        # --- Post-processing (Resizing) ---
        h_gfpgan, w_gfpgan = output_bgr.shape[0:2]
        if target_w <= 0 or target_h <= 0:
            print(f"Warning: Invalid target size ({target_w}x{target_h}) calculated from scale {scale}. Using GFPGAN output size {w_gfpgan}x{h_gfpgan}.")
            target_w, target_h = w_gfpgan, h_gfpgan

        # Skip the resize when the target is within 2 px of what GFPGAN produced.
        if abs(target_w - w_gfpgan) > 2 or abs(target_h - h_gfpgan) > 2:
            print(f"Resizing GFPGAN output ({w_gfpgan}x{h_gfpgan}) to target ({target_w}x{target_h}) based on scale {target_scale_factor}...")
            # Lanczos when enlarging, area averaging when shrinking.
            interpolation = cv2.INTER_LANCZOS4 if (target_w * target_h) > (w_gfpgan * h_gfpgan) else cv2.INTER_AREA
            try:
                output_bgr = cv2.resize(output_bgr, (target_w, target_h), interpolation=interpolation)
            except cv2.error as resize_err:
                # Fall back to the un-resized GFPGAN output and still try to encode it.
                error_msg = f"Error during OpenCV resize: {resize_err}. Returning image before final resize attempt."
                print(error_msg)
                output_pil = Image.fromarray(cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB))
                try:
                    return output_pil, _encode_webp_data_uri(output_pil, quality=85)
                except Exception as enc_err:
                    print(f"Error encoding fallback image: {enc_err}")
                    # Fallback image plus the combined error message
                    return output_pil, error_msg + f" | Encoding Error: {enc_err}"

        # --- Convert final result back to PIL (RGB) ---
        output_pil = Image.fromarray(cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB))
        print(f"Final output image size (WxH PIL): {output_pil.size}")

        # --- Encode final PIL image to Base64 for API ---
        try:
            base64_output = _encode_webp_data_uri(output_pil, quality=90)
        except Exception as enc_err:
            error_msg = f"Error encoding final image to base64: {enc_err}"
            print(error_msg)
            print(traceback.format_exc())
            # Return the PIL image anyway, but with an error message
            return output_pil, error_msg

        return output_pil, base64_output

    except Exception as error:
        error_msg = f"Error during GFPGAN processing: {error}"
        print(error_msg)
        print(traceback.format_exc())
        # Best effort: show a small red placeholder; fall back to None if
        # even that cannot be created.
        error_img = None
        try:
            error_img = Image.new('RGB', (100, 50), color='red')
        except Exception as placeholder_err:
            print(f"Warning: Could not create placeholder image: {placeholder_err}")
        return error_img, error_msg
|
296 |
+
|
297 |
+
|
298 |
+
# --- Gradio Interface Definition ---
title = "GFPGAN: Practical Face Restoration"
description = """Gradio demo for <a href='https://github.com/TencentARC/GFPGAN' target='_blank'><b>GFPGAN: Towards Real-World Blind Face Restoration with Generative Facial Prior</b></a>.
<br>Restore your <b>old photos</b> or improve <b>AI-generated faces</b>. Upload an image to start.
<br>If helpful, please ⭐ the <a href='https://github.com/TencentARC/GFPGAN' target='_blank'>Original Github Repo</a>.
<br>API endpoint available at `/predict`. Expects image filepath, version string, scale number. Returns Image and Textbox data.
"""
article = "Questions? Contact the original creators (see GFPGAN repo)."

print("Creating Gradio interface...")
try:
    # Inputs, in the order `inference` expects them: image filepath, GFPGAN
    # version, rescaling factor.
    image_input = gr.Image(type="filepath", label="Input Image", sources=["upload", "clipboard"])
    version_input = gr.Radio(
        ['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'],
        type="value",
        value='v1.4',
        label='GFPGAN Version',
        info="v1.4 recommended. RestoreFormer for diverse poses.",
    )
    scale_input = gr.Number(
        label="Rescaling Factor",
        value=2,
        info="Final output size multiplier relative to original input size (e.g., 2 = 2x original WxH).",
    )
    inputs = [image_input, version_input, scale_input]

    # Outputs: the restored image, plus a visible textbox that receives the
    # second value returned by `inference` (base64 data URI or error message).
    outputs = [
        gr.Image(type="pil", label="Output Image"),
        gr.Textbox(label="Output Info / Base64 Data", interactive=False, visible=True),
    ]

    # Each example row is [image path, version, scale].
    examples = [
        [sample_name, 'v1.4', 2]
        for sample_name in ('AI-generate.jpg', 'lincoln.jpg', 'Blake_Lively.jpg', '10045.png')
    ]

    # --- Gradio Interface Instantiation ---
    demo = gr.Interface(
        fn=inference,
        inputs=inputs,
        outputs=outputs,
        title=title,
        description=description,
        article=article,
        examples=examples,
        cache_examples=False,
        allow_flagging='never',
    )

    # Launch the interface.
    print("Launching Gradio interface...")
    demo.queue().launch(server_name="0.0.0.0", share=False)
    # NOTE(review): if launch() blocks in this environment, the next message
    # is only printed once the server stops — confirm.
    print("Gradio app launched successfully and should be accessible.")

except Exception as setup_err:
    print(f"Error setting up or launching Gradio interface: {setup_err}")
    print(traceback.format_exc())
    sys.exit(1)
|