Gemini899 commited on
Commit
1d8f921
·
verified ·
1 Parent(s): 55f9bde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +346 -185
app.py CHANGED
@@ -1,203 +1,364 @@
1
- # app.py
2
-
3
  import os
4
  import sys
5
-
6
- # --- Install Dependencies ---
7
- print("Installing required packages: diffusers, gradio_imageslider, huggingface-hub…")
8
- os.system("pip install --no-input diffusers gradio_imageslider huggingface-hub")
9
-
10
- # --- Standard Imports ---
11
- import logging
12
- import random
13
- import warnings
14
  import io
15
  import base64
 
 
 
 
 
 
 
 
 
 
 
16
 
 
17
  import gradio as gr
18
- import numpy as np
19
  import spaces
20
- import torch
21
- from diffusers import FluxControlNetModel
22
- from diffusers.pipelines import FluxControlNetPipeline
23
- from gradio_imageslider import ImageSlider
24
- from PIL import Image, ImageOps
25
- from huggingface_hub import snapshot_download
26
-
27
- # --- Logging & Device Setup ---
28
- logging.basicConfig(level=logging.INFO)
29
- warnings.filterwarnings("ignore")
30
-
31
- css = """
32
- #col-container {
33
- margin: 0 auto;
34
- max-width: 512px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  }
36
- .gradio-container {
37
- max-width: 900px !important;
38
- margin: auto !important;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  }
40
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- if torch.cuda.is_available():
43
- power_device = "GPU"
44
- device = "cuda"
45
- torch_dtype = torch.bfloat16
46
- else:
47
- power_device = "CPU"
48
- device = "cpu"
49
- torch_dtype = torch.float32
50
 
51
- logging.info(f"Running on device={device} with dtype={torch_dtype}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- # --- Model IDs & Download (no token) ---
54
- flux_model_id = "black-forest-labs/FLUX.1-dev"
55
- controlnet_model_id = "jasperai/Flux.1-dev-Controlnet-Upscaler"
56
- local_model_dir = flux_model_id.split("/")[-1]
57
- pipe = None
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  try:
60
- logging.info(f"Downloading base model: {flux_model_id}")
61
- model_path = snapshot_download(
62
- repo_id=flux_model_id,
63
- repo_type="model",
64
- local_dir=local_model_dir,
65
- ignore_patterns=["*.md", "*.gitattributes"],
66
- )
67
- logging.info(f"Downloaded base model to: {model_path}")
68
-
69
- logging.info(f"Loading ControlNet: {controlnet_model_id}")
70
- controlnet = FluxControlNetModel.from_pretrained(
71
- controlnet_model_id,
72
- torch_dtype=torch_dtype
73
- ).to(device)
74
- logging.info("ControlNet loaded.")
75
-
76
- logging.info("Initializing FluxControlNetPipeline…")
77
- pipe = FluxControlNetPipeline.from_pretrained(
78
- model_path,
79
- controlnet=controlnet,
80
- torch_dtype=torch_dtype
81
- ).to(device)
82
- logging.info("Pipeline ready.")
83
 
84
- except Exception as e:
85
- logging.error(f"Error loading models: {e}", exc_info=True)
86
- print(f"FATAL: could not load models: {e}")
87
- sys.exit(1)
 
 
 
 
88
 
89
- # --- Constants & Helpers ---
90
- MAX_SEED = 2**32 - 1
91
- MAX_PIXEL_BUDGET = 1280 * 1280
92
- INTERNAL_PROCESSING_FACTOR = 4
93
-
94
- def process_input(input_image):
95
- if input_image is None:
96
- raise gr.Error("No input image provided!")
97
- img = ImageOps.exif_transpose(input_image)
98
- if img.mode != "RGB":
99
- img = img.convert("RGB")
100
- w, h = img.size
101
-
102
- # enforce intermediate‐scale budget
103
- target_px = (w*INTERNAL_PROCESSING_FACTOR)*(h*INTERNAL_PROCESSING_FACTOR)
104
- if target_px > MAX_PIXEL_BUDGET:
105
- max_in = MAX_PIXEL_BUDGET / (INTERNAL_PROCESSING_FACTOR**2)
106
- scale = (max_in / (w*h))**0.5
107
- w2, h2 = max(8,int(w*scale)), max(8,int(h*scale))
108
- img = img.resize((w2,h2), Image.Resampling.LANCZOS)
109
- was_resized = True
110
- else:
111
- was_resized = False
112
-
113
- # round dimensions to multiples of 8
114
- w2, h2 = img.size
115
- w2 -= w2 % 8; h2 -= h2 % 8
116
- if img.size != (w2,h2):
117
- img = img.resize((w2,h2), Image.Resampling.LANCZOS)
118
-
119
- return img, w, h, was_resized
120
-
121
- @spaces.GPU(duration=75)
122
- def infer(
123
- seed,
124
- randomize_seed,
125
- input_image,
126
- num_inference_steps,
127
- final_upscale_factor,
128
- controlnet_conditioning_scale,
129
- progress=gr.Progress(track_tqdm=True),
130
- ):
131
- global pipe
132
- if pipe is None:
133
- raise gr.Error("Pipeline not loaded.")
134
-
135
- if randomize_seed:
136
- seed = random.randint(0, MAX_SEED)
137
- seed = int(seed)
138
- final_upscale_factor = int(final_upscale_factor)
139
-
140
- processed, w0, h0, resized_flag = process_input(input_image)
141
- w_proc, h_proc = processed.size
142
-
143
- # prepare control image at INTERNAL scale
144
- cw, ch = w_proc*INTERNAL_PROCESSING_FACTOR, h_proc*INTERNAL_PROCESSING_FACTOR
145
- control_img = processed.resize((cw, ch), Image.Resampling.LANCZOS)
146
-
147
- gen = torch.Generator(device=device).manual_seed(seed)
148
- with torch.inference_mode():
149
- result = pipe(
150
- prompt="",
151
- control_image=control_img,
152
- controlnet_conditioning_scale=float(controlnet_conditioning_scale),
153
- num_inference_steps=int(num_inference_steps),
154
- guidance_scale=0.0,
155
- height=ch, width=cw,
156
- generator=gen
157
- ).images[0]
158
-
159
- # final resize to user factor
160
- if resized_flag:
161
- fw, fh = w_proc*final_upscale_factor, h_proc*final_upscale_factor
162
- else:
163
- fw, fh = w0*final_upscale_factor, h0*final_upscale_factor
164
- if (fw, fh) != result.size:
165
- result = result.resize((fw, fh), Image.Resampling.LANCZOS)
166
-
167
- buf = io.BytesIO()
168
- result.save(buf, format="WEBP", quality=90)
169
- b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
170
-
171
- return [[input_image, result], seed, f"data:image/webp;base64,{b64}"]
172
-
173
- # --- Gradio UI ---
174
- with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Flux Upscaler Demo") as demo:
175
- gr.Markdown(f"""
176
- # ⚡ Flux.1‑dev Upscaler
177
- **Device:** {power_device} · **Internal scale:** {INTERNAL_PROCESSING_FACTOR}x · **Budget:** {MAX_PIXEL_BUDGET} px
178
- """)
179
- with gr.Row():
180
- with gr.Column(scale=2):
181
- inp = gr.Image(label="Input Image", type="pil", sources=["upload","clipboard"], height=350)
182
- with gr.Column(scale=1):
183
- upf = gr.Slider("Final Upscale Factor", 1, INTERNAL_PROCESSING_FACTOR, step=1, value=2)
184
- steps = gr.Slider("Inference Steps", 4, 50, step=1, value=15)
185
- cscale= gr.Slider("ControlNet Scale", 0.0, 1.5, step=0.05, value=0.6)
186
- with gr.Row():
187
- sld = gr.Slider("Seed", 0, MAX_SEED, step=1, value=42)
188
- rnd = gr.Checkbox("Randomize", value=True, scale=0, min_width=80)
189
- btn = gr.Button("⚡ Upscale Image", variant="primary")
190
-
191
- slider = ImageSlider("Input / Output", type="pil", interactive=False, show_label=True, position=0.5)
192
- out_seed= gr.Textbox("Seed Used", interactive=False, visible=True)
193
- out_b64 = gr.Textbox("API Base64 Output", interactive=False, visible=False)
194
-
195
- btn.click(
196
- fn=infer,
197
- inputs=[sld, rnd, inp, steps, upf, cscale],
198
- outputs=[slider, out_seed, out_b64],
199
- api_name="upscale"
200
  )
201
 
202
- # Expose JSON API at /run/upscale
203
- demo.queue(max_size=10).launch(share=False, show_api=True)
 
 
 
 
 
 
 
 
1
+ # --- Imports ---
 
2
  import os
3
  import sys
4
+ import cv2
5
+ import torch
6
+ # Delay Gradio import until after installation
7
+ # import gradio as gr
8
+ import numpy as np
9
+ from PIL import Image, ImageOps # <--- IMPORT ImageOps
 
 
 
10
  import io
11
  import base64
12
+ import traceback
13
+ # Delay spaces import until after installation
14
+ # import spaces
15
+
16
+ # --- Dependency Management ---
17
+ # Force install specific compatible versions
18
+ print("Installing specific Gradio and huggingface-hub versions...")
19
+ os.system("pip uninstall -y gradio gradio-client huggingface-hub") # Uninstall all first
20
+ os.system("pip install gradio==4.13.0 gradio-client==0.8.0") # Install specific Gradio
21
+ os.system("pip install huggingface-hub==0.19.4") # Install older huggingface-hub
22
+ print("Dependency installation complete.")
23
 
24
+ # Now import Gradio and spaces
25
  import gradio as gr
 
26
  import spaces
27
+ # from huggingface_hub import HfFolder # Example import from hub if needed
28
+
29
+ # Check installed versions (optional but helpful for debugging)
30
+ try:
31
+ import pkg_resources
32
+ gradio_version = pkg_resources.get_distribution("gradio").version
33
+ hub_version = pkg_resources.get_distribution("huggingface-hub").version
34
+ print(f"Using Gradio version: {gradio_version}")
35
+ print(f"Using huggingface-hub version: {hub_version}")
36
+ except Exception as e:
37
+ print(f"Could not check package versions: {e}")
38
+
39
+
40
+ # Import model-specific libraries AFTER installing dependencies
41
+ try:
42
+ from basicsr.archs.srvgg_arch import SRVGGNetCompact
43
+ from gfpgan.utils import GFPGANer
44
+ from realesrgan.utils import RealESRGANer
45
+ print("Successfully imported model libraries.")
46
+ except ImportError as e:
47
+ print(f"Error importing model libraries: {e}")
48
+ print("Please ensure basicsr, gfpgan, realesrgan are installed or in requirements.txt")
49
+ sys.exit(1)
50
+
51
+ # --- Constants ---
52
+ OUTPUT_DIR = 'output' # This might not be strictly needed if not saving via gr.File
53
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
54
+
55
+ # --- Model Weight Downloads ---
56
+ MODEL_FILES = {
57
+ 'realesr-general-x4v3.pth': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
58
+ 'GFPGANv1.2.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth',
59
+ 'GFPGANv1.3.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
60
+ 'GFPGANv1.4.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
61
+ 'RestoreFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
62
  }
63
+ print("Downloading model weights...")
64
+ for filename, url in MODEL_FILES.items():
65
+ try:
66
+ if not os.path.exists(filename):
67
+ print(f"Downloading {filename}...")
68
+ os.system(f"wget -q {url} -P .") # Use -q for quiet
69
+ except Exception as e:
70
+ print(f"Error downloading {filename}: {e}")
71
+ if not os.path.exists('realesr-general-x4v3.pth'):
72
+ print("FATAL: RealESRGAN model (realesr-general-x4v3.pth) not found after download attempt. Cannot proceed.")
73
+ sys.exit(1)
74
+ print("Model weight download check complete.")
75
+
76
+ # --- Sample Image Downloads ---
77
+ SAMPLE_IMAGES = {
78
+ 'lincoln.jpg': 'https://upload.wikimedia.org/wikipedia/commons/thumb/a/ab/Abraham_Lincoln_O-77_matte_collodion_print.jpg/1024px-Abraham_Lincoln_O-77_matte_collodion_print.jpg',
79
+ 'AI-generate.jpg': 'https://user-images.githubusercontent.com/17445847/187400315-87a90ac9-d231-45d6-b377-38702bd1838f.jpg',
80
+ 'Blake_Lively.jpg': 'https://user-images.githubusercontent.com/17445847/187400981-8a58f7a4-ef61-42d9-af80-bc6234cef860.jpg',
81
+ '10045.png': 'https://user-images.githubusercontent.com/17445847/187401133-8a3bf269-5b4d-4432-b2f0-6d26ee1d3307.png'
82
  }
83
+ print("Downloading sample images...")
84
+ for filename, url in SAMPLE_IMAGES.items():
85
+ try:
86
+ if not os.path.exists(filename):
87
+ torch.hub.download_url_to_file(url, filename, progress=False)
88
+ except Exception as e:
89
+ print(f"Warning: Error downloading sample image {filename}: {e}")
90
+ print("Sample image download check complete.")
91
+
92
+
93
+ # --- Model Initialization (Background Enhancer) ---
94
+ upsampler = None
95
+ try:
96
+ print("Initializing RealESRGAN upsampler...")
97
+ model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
98
+ model_path = 'realesr-general-x4v3.pth'
99
+ half = torch.cuda.is_available()
100
+ print(f"CUDA available: {torch.cuda.is_available()}, Using half precision: {half}")
101
+ upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)
102
+ print("Successfully created RealESRGAN upsampler.")
103
+ except Exception as e:
104
+ print(f"Error creating RealESRGAN upsampler: {e}")
105
+ print(traceback.format_exc())
106
+ print("Warning: GFPGAN will run without background enhancement.")
107
+
108
+
109
+ # --- Inference Function (Handles API, PIL, Base64, EXIF) ---
110
+ @spaces.GPU(duration=90) # Keep ZeroGPU decorator
111
+ # --- MODIFIED Function Signature: Accepts filepath instead of PIL Image ---
112
+ def inference(input_image_filepath, version, scale):
113
+ """
114
+ Processes an input image file using GFPGAN, handling EXIF and aspect ratio.
115
 
116
+ Args:
117
+ input_image_filepath (str): Path to the temporary input image file provided by Gradio.
118
+ version (str): GFPGAN model version ('v1.2', 'v1.3', 'v1.4', 'RestoreFormer').
119
+ scale (float): Rescaling factor for the final output relative to original.
 
 
 
 
120
 
121
+ Returns:
122
+ tuple: (PIL.Image.Image | None, str | None)
123
+ - Output PIL image (RGB) or None on error.
124
+ - Base64 encoded output image string (data URI) or an error message string.
125
+ """
126
+ # --- ADDED: Load image from filepath ---
127
+ input_pil_image = None
128
+ try:
129
+ print(f"Loading image from filepath: {input_image_filepath}")
130
+ if not input_image_filepath or not isinstance(input_image_filepath, str) or not os.path.exists(input_image_filepath):
131
+ error_msg = f"Error: Input image filepath is invalid or file does not exist: '{input_image_filepath}'"
132
+ print(error_msg)
133
+ # Return None for Image output, error message for Textbox output
134
+ return None, error_msg
135
+ input_pil_image = Image.open(input_image_filepath)
136
+ print(f"Successfully loaded image. Initial mode: {input_pil_image.mode}, size: {input_pil_image.size}")
137
+ except Exception as load_err:
138
+ error_msg = f"Error loading image from filepath {input_image_filepath}: {load_err}"
139
+ print(error_msg)
140
+ print(traceback.format_exc())
141
+ # Return None for Image output, error message for Textbox output
142
+ return None, error_msg
143
+ # --- End Added Image Loading ---
144
 
145
+ # Check if loading failed (redundant due to try/except, but safe)
146
+ if input_pil_image is None:
147
+ print("Error: No input image could be loaded.")
148
+ return None, "Error: Failed to load input image."
 
149
 
150
+ print(f"Processing image with GFPGAN version: {version}, scale: {scale}")
151
+
152
+ # --- Handle EXIF Orientation ---
153
+ original_size_before_exif = input_pil_image.size
154
+ try:
155
+ input_pil_image = ImageOps.exif_transpose(input_pil_image)
156
+ if input_pil_image.size != original_size_before_exif:
157
+ print(f"Image size changed by EXIF transpose: {original_size_before_exif} -> {input_pil_image.size}")
158
+ except Exception as exif_err:
159
+ print(f"Warning: Could not apply EXIF transpose: {exif_err}")
160
+ # -----------------------------
161
+
162
+ w_orig, h_orig = input_pil_image.size
163
+ print(f"Input size for processing (WxH): {w_orig}x{h_orig}")
164
+
165
+ # Convert PIL Image to OpenCV format (BGR numpy array)
166
+ try:
167
+ img_mode = input_pil_image.mode
168
+ if img_mode == 'RGBA':
169
+ input_pil_image = input_pil_image.convert('RGB')
170
+ elif img_mode != 'RGB':
171
+ print(f"Converting input image from {img_mode} to RGB")
172
+ input_pil_image = input_pil_image.convert('RGB')
173
+ img_bgr = np.array(input_pil_image)[:, :, ::-1].copy()
174
+ except Exception as conversion_err:
175
+ error_msg = f"Error converting PIL image to OpenCV format: {conversion_err}"
176
+ print(error_msg)
177
+ print(traceback.format_exc())
178
+ return None, error_msg # Return None for Image, error string for Textbox
179
+
180
+ # --- Start GFPGAN Processing ---
181
+ try:
182
+ h, w = img_bgr.shape[0:2]
183
+ if h > 4000 or w > 4000:
184
+ print(f'Warning: Image size ({w}x{h}) is very large, processing might be slow or fail.')
185
+
186
+ model_map = {
187
+ 'v1.2': 'GFPGANv1.2.pth', 'v1.3': 'GFPGANv1.3.pth',
188
+ 'v1.4': 'GFPGANv1.4.pth', 'RestoreFormer': 'RestoreFormer.pth'
189
+ }
190
+ arch_map = {
191
+ 'v1.2': 'clean', 'v1.3': 'clean', 'v1.4': 'clean',
192
+ 'RestoreFormer': 'RestoreFormer'
193
+ }
194
+
195
+ if version not in model_map:
196
+ error_msg = f"Error: Unknown version selected: {version}"
197
+ print(error_msg)
198
+ return None, error_msg
199
+ model_path = model_map[version]
200
+ arch = arch_map[version]
201
+ if not os.path.exists(model_path):
202
+ error_msg = f"Error: Model file not found for version {version}: {model_path}"
203
+ print(error_msg)
204
+ return None, error_msg
205
+
206
+ current_bg_upsampler = upsampler
207
+ if not current_bg_upsampler:
208
+ print("Warning: RealESRGAN upsampler not available. Background enhancement disabled.")
209
+
210
+ face_enhancer = GFPGANer(
211
+ model_path=model_path, upscale=2, arch=arch,
212
+ channel_multiplier=2, bg_upsampler=current_bg_upsampler
213
+ )
214
+
215
+ print(f"Running GFPGAN enhancement with {version}...")
216
+ _, _, output_bgr = face_enhancer.enhance(
217
+ img_bgr, has_aligned=False, only_center_face=False, paste_back=True
218
+ )
219
+ if output_bgr is None:
220
+ error_msg = "Error: GFPGAN enhancement returned None."
221
+ print(error_msg)
222
+ return None, error_msg
223
+ print(f"Enhancement complete. Intermediate output shape (HxWxC BGR): {output_bgr.shape}")
224
+
225
+ # --- Post-processing (Resizing) ---
226
+ target_scale_factor = float(scale)
227
+ h_gfpgan, w_gfpgan = output_bgr.shape[0:2]
228
+ target_w = int(w_orig * target_scale_factor)
229
+ target_h = int(h_orig * target_scale_factor)
230
+
231
+ if target_w <= 0 or target_h <= 0:
232
+ print(f"Warning: Invalid target size ({target_w}x{target_h}) calculated from scale {scale}. Using GFPGAN output size {w_gfpgan}x{h_gfpgan}.")
233
+ target_w, target_h = w_gfpgan, h_gfpgan
234
+
235
+ if abs(target_w - w_gfpgan) > 2 or abs(target_h - h_gfpgan) > 2:
236
+ print(f"Resizing GFPGAN output ({w_gfpgan}x{h_gfpgan}) to target ({target_w}x{target_h}) based on scale {target_scale_factor}...")
237
+ interpolation = cv2.INTER_LANCZOS4 if (target_w * target_h) > (w_gfpgan * h_gfpgan) else cv2.INTER_AREA
238
+ try:
239
+ output_bgr = cv2.resize(output_bgr, (target_w, target_h), interpolation=interpolation)
240
+ except cv2.error as resize_err:
241
+ # --- MODIFIED Resize Error Handling ---
242
+ error_msg = f"Error during OpenCV resize: {resize_err}. Returning image before final resize attempt."
243
+ print(error_msg)
244
+ output_pil = Image.fromarray(cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB))
245
+ # Still try to encode this fallback image
246
+ base64_output = None
247
+ try:
248
+ buffered = io.BytesIO()
249
+ output_pil.save(buffered, format="WEBP", quality=85)
250
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
251
+ base64_output = f"data:image/webp;base64,{img_str}"
252
+ except Exception as enc_err:
253
+ print(f"Error encoding fallback image: {enc_err}")
254
+ error_msg += f" | Encoding Error: {enc_err}" # Append encoding error
255
+ # Return fallback image and combined error message
256
+ return output_pil, base64_output if base64_output else error_msg
257
+ # --- End Modified Resize Error Handling ---
258
+
259
+ # --- Convert final result back to PIL (RGB) ---
260
+ output_pil = Image.fromarray(cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB))
261
+ print(f"Final output image size (WxH PIL): {output_pil.size}")
262
+
263
+ # --- Encode final PIL image to Base64 for API ---
264
+ # This replaces the need for gr.File saving/output
265
+ base64_output = None
266
+ try:
267
+ buffered = io.BytesIO()
268
+ output_pil.save(buffered, format="WEBP", quality=90)
269
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
270
+ base64_output = f"data:image/webp;base64,{img_str}"
271
+ except Exception as enc_err:
272
+ error_msg = f"Error encoding final image to base64: {enc_err}"
273
+ print(error_msg)
274
+ print(traceback.format_exc())
275
+ # Return the PIL image anyway, but with an error message
276
+ return output_pil, error_msg
277
+
278
+ # --- MODIFIED RETURN: Return PIL image and base64 string ---
279
+ # Remove save_path which is no longer needed for outputs
280
+ # Return the base64 string for the Textbox output (or success message)
281
+ success_msg = f"Success! Output size: {output_pil.size[0]}x{output_pil.size[1]}"
282
+ return output_pil, base64_output if base64_output else success_msg
283
+
284
+ except Exception as error:
285
+ # --- MODIFIED Main Exception Handling ---
286
+ error_msg = f"Error during GFPGAN processing: {error}"
287
+ print(error_msg)
288
+ print(traceback.format_exc())
289
+ # Return placeholder image (or None) and error message string
290
+ error_img = None
291
+ try:
292
+ error_img = Image.new('RGB', (100, 50), color = 'red')
293
+ except Exception: pass # Ignore if placeholder fails
294
+ return error_img, error_msg # Return placeholder/None and error string
295
+ # --- End Modified Main Exception Handling ---
296
+
297
+
298
+ # --- Gradio Interface Definition ---
299
+ title = "GFPGAN: Practical Face Restoration"
300
+ description = """Gradio demo for <a href='https://github.com/TencentARC/GFPGAN' target='_blank'><b>GFPGAN: Towards Real-World Blind Face Restoration with Generative Facial Prior</b></a>.
301
+ <br>Restore your <b>old photos</b> or improve <b>AI-generated faces</b>. Upload an image to start.
302
+ <br>If helpful, please ⭐ the <a href='https://github.com/TencentARC/GFPGAN' target='_blank'>Original Github Repo</a>.
303
+ <br>API endpoint available at `/predict`. Expects image filepath, version string, scale number. Returns Image and Textbox data.
304
+ """
305
+ article = "Questions? Contact the original creators (see GFPGAN repo)."
306
+
307
+ print("Creating Gradio interface...")
308
  try:
309
+ # --- MODIFIED INPUTS ---
310
+ # Changed gr.Image type to 'filepath'
311
+ inputs = [
312
+ gr.Image(type="filepath", label="Input Image", sources=["upload", "clipboard"]), # <-- TYPE CHANGED
313
+ gr.Radio(
314
+ ['v1.2', 'v1.3', 'v1.4', 'RestoreFormer'],
315
+ type="value", value='v1.4', label='GFPGAN Version',
316
+ info="v1.4 recommended. RestoreFormer for diverse poses."
317
+ ),
318
+ gr.Number(
319
+ label="Rescaling Factor", value=2,
320
+ info="Final output size multiplier relative to original input size (e.g., 2 = 2x original WxH)."
321
+ ),
322
+ ]
 
 
 
 
 
 
 
 
 
323
 
324
+ # --- MODIFIED OUTPUTS ---
325
+ # Removed gr.File, kept gr.Image and gr.Textbox
326
+ # Made Textbox visible for easier debugging
327
+ outputs = [
328
+ gr.Image(type="pil", label="Output Image"),
329
+ # gr.File(label="Download Output Image (Server Path)"), # <-- REMOVED
330
+ gr.Textbox(label="Output Info / Base64 Data", interactive=False, visible=True) # <-- KEPT (made visible)
331
+ ]
332
 
333
+ # Define examples using file paths (Gradio handles loading them)
334
+ # These should still work even with type="filepath" for the input component
335
+ examples = [
336
+ ['AI-generate.jpg', 'v1.4', 2],
337
+ ['lincoln.jpg', 'v1.4', 2],
338
+ ['Blake_Lively.jpg', 'v1.4', 2],
339
+ ['10045.png', 'v1.4', 2]
340
+ ]
341
+
342
+ # --- Gradio Interface Instantiation ---
343
+ # Ensure fn=inference points to the modified function
344
+ demo = gr.Interface(
345
+ fn=inference, # Should now accept filepath and return 2 items
346
+ inputs=inputs, # Use modified inputs list
347
+ outputs=outputs, # Use modified outputs list
348
+ title=title,
349
+ description=description,
350
+ article=article,
351
+ examples=examples,
352
+ cache_examples=False, # Caching might be complex with filepaths vs PIL
353
+ allow_flagging='never'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  )
355
 
356
+ # Launch the interface
357
+ print("Launching Gradio interface...")
358
+ demo.queue().launch(server_name="0.0.0.0", share=False)
359
+ print("Gradio app launched successfully and should be accessible.")
360
+
361
+ except Exception as e:
362
+ print(f"Error setting up or launching Gradio interface: {e}")
363
+ print(traceback.format_exc())
364
+ sys.exit(1)