Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -3,12 +3,15 @@ import os
|
|
3 |
import sys
|
4 |
import cv2
|
5 |
import torch
|
|
|
6 |
import numpy as np
|
7 |
from PIL import Image, ImageOps
|
8 |
import io
|
9 |
import base64
|
10 |
import traceback
|
11 |
-
import
|
|
|
|
|
12 |
import spaces
|
13 |
|
14 |
# Import model-specific libraries
|
@@ -19,7 +22,7 @@ try:
|
|
19 |
print("Successfully imported model libraries.")
|
20 |
except ImportError as e:
|
21 |
print(f"Error importing model libraries: {e}")
|
22 |
-
print("Please ensure basicsr, gfpgan, realesrgan are installed
|
23 |
sys.exit(1)
|
24 |
|
25 |
# --- Constants ---
|
@@ -69,7 +72,7 @@ except Exception as e:
|
|
69 |
print(traceback.format_exc())
|
70 |
print("Warning: GFPGAN will run without background enhancement.")
|
71 |
|
72 |
-
# ---
|
73 |
@spaces.GPU(duration=90)
|
74 |
def process_image(input_image, version, scale):
|
75 |
"""
|
@@ -265,28 +268,111 @@ def process_image(input_image, version, scale):
|
|
265 |
pass
|
266 |
return error_img, error_msg
|
267 |
|
268 |
-
# ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
269 |
@spaces.GPU(duration=90)
|
270 |
def inference(input_image, version, scale):
|
271 |
"""
|
272 |
API-friendly wrapper that ensures consistent behavior between web and API interfaces.
|
273 |
"""
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
|
280 |
# --- Gradio Interface Definition ---
|
281 |
title = "GFPGAN: Practical Face Restoration"
|
282 |
description = """Gradio demo for <a href='https://github.com/TencentARC/GFPGAN' target='_blank'><b>GFPGAN: Towards Real-World Blind Face Restoration with Generative Facial Prior</b></a>.
|
283 |
<br>Restore your <b>old photos</b> or improve <b>AI-generated faces</b>. Upload an image to start.
|
284 |
<br>If helpful, please ⭐ the <a href='https://github.com/TencentARC/GFPGAN' target='_blank'>Original Github Repo</a>.
|
285 |
-
<br>API endpoint available at `/predict
|
286 |
"""
|
287 |
article = "Questions? Contact the original creators (see GFPGAN repo)."
|
288 |
|
289 |
-
#
|
290 |
inputs = [
|
291 |
gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"]),
|
292 |
gr.Radio(
|
@@ -325,6 +411,10 @@ demo = gr.Interface(
|
|
325 |
allow_flagging='never'
|
326 |
)
|
327 |
|
|
|
|
|
|
|
328 |
# Launch the interface
|
329 |
if __name__ == "__main__":
|
330 |
-
|
|
|
|
3 |
import sys
|
4 |
import cv2
|
5 |
import torch
|
6 |
+
import gradio as gr
|
7 |
import numpy as np
|
8 |
from PIL import Image, ImageOps
|
9 |
import io
|
10 |
import base64
|
11 |
import traceback
|
12 |
+
import tempfile
|
13 |
+
from fastapi import FastAPI, File, UploadFile
|
14 |
+
from fastapi.middleware.cors import CORSMiddleware
|
15 |
import spaces
|
16 |
|
17 |
# Import model-specific libraries
|
|
|
22 |
print("Successfully imported model libraries.")
|
23 |
except ImportError as e:
|
24 |
print(f"Error importing model libraries: {e}")
|
25 |
+
print("Please ensure basicsr, gfpgan, realesrgan are installed")
|
26 |
sys.exit(1)
|
27 |
|
28 |
# --- Constants ---
|
|
|
72 |
print(traceback.format_exc())
|
73 |
print("Warning: GFPGAN will run without background enhancement.")
|
74 |
|
75 |
+
# --- Universal processing function ---
|
76 |
@spaces.GPU(duration=90)
|
77 |
def process_image(input_image, version, scale):
|
78 |
"""
|
|
|
268 |
pass
|
269 |
return error_img, error_msg
|
270 |
|
271 |
+
# --- Function to handle file upload for API ---
|
272 |
+
def handle_file_upload(file_data):
|
273 |
+
"""Save uploaded file to temporary directory and return path"""
|
274 |
+
try:
|
275 |
+
print(f"Handling file upload: {type(file_data)}")
|
276 |
+
|
277 |
+
# Create a temporary file
|
278 |
+
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
|
279 |
+
temp_path = temp_file.name
|
280 |
+
|
281 |
+
# If it's bytes, write directly
|
282 |
+
if isinstance(file_data, bytes):
|
283 |
+
with open(temp_path, 'wb') as f:
|
284 |
+
f.write(file_data)
|
285 |
+
# If it's a file-like object (from FastAPI/Gradio)
|
286 |
+
elif hasattr(file_data, 'file'):
|
287 |
+
content = file_data.file.read()
|
288 |
+
with open(temp_path, 'wb') as f:
|
289 |
+
f.write(content)
|
290 |
+
# If it's a string path, it's already saved
|
291 |
+
elif isinstance(file_data, str) and os.path.exists(file_data):
|
292 |
+
return file_data
|
293 |
+
else:
|
294 |
+
raise ValueError(f"Unsupported file data type: {type(file_data)}")
|
295 |
+
|
296 |
+
print(f"File saved to temporary path: {temp_path}")
|
297 |
+
return temp_path
|
298 |
+
|
299 |
+
except Exception as e:
|
300 |
+
print(f"Error handling file upload: {e}")
|
301 |
+
print(traceback.format_exc())
|
302 |
+
raise
|
303 |
+
|
304 |
+
# --- API inference function ---
|
305 |
@spaces.GPU(duration=90)
|
306 |
def inference(input_image, version, scale):
|
307 |
"""
|
308 |
API-friendly wrapper that ensures consistent behavior between web and API interfaces.
|
309 |
"""
|
310 |
+
try:
|
311 |
+
# If input is a file upload (from API), save it to a temporary path
|
312 |
+
if not isinstance(input_image, (str, Image.Image, np.ndarray)) and not (hasattr(input_image, 'name') and os.path.exists(input_image.name)):
|
313 |
+
file_path = handle_file_upload(input_image)
|
314 |
+
input_image = file_path
|
315 |
+
|
316 |
+
# Process the image
|
317 |
+
output_pil, base64_or_msg = process_image(input_image, version, scale)
|
318 |
+
|
319 |
+
# Return the processed results
|
320 |
+
return output_pil, base64_or_msg
|
321 |
+
except Exception as e:
|
322 |
+
print(f"Error in inference: {e}")
|
323 |
+
print(traceback.format_exc())
|
324 |
+
# Return a placeholder error image and message
|
325 |
+
error_img = Image.new('RGB', (100, 50), color='red')
|
326 |
+
return error_img, f"Error: {str(e)}"
|
327 |
+
|
328 |
+
# --- Get the FastAPI app from Gradio ---
|
329 |
+
app = FastAPI()
|
330 |
+
|
331 |
+
# Add CORS middleware to allow cross-origin requests
|
332 |
+
app.add_middleware(
|
333 |
+
CORSMiddleware,
|
334 |
+
allow_origins=["*"], # Allows all origins
|
335 |
+
allow_credentials=True,
|
336 |
+
allow_methods=["*"], # Allows all methods
|
337 |
+
allow_headers=["*"], # Allows all headers
|
338 |
+
)
|
339 |
+
|
340 |
+
# --- Direct API endpoint for file upload ---
|
341 |
+
@app.post("/api/direct-process")
|
342 |
+
async def direct_process(file: UploadFile = File(...), version: str = "v1.4", scale: float = 2.0):
|
343 |
+
try:
|
344 |
+
# Save the uploaded file
|
345 |
+
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
|
346 |
+
temp_path = temp_file.name
|
347 |
+
with open(temp_path, 'wb') as f:
|
348 |
+
f.write(await file.read())
|
349 |
+
|
350 |
+
# Process the image
|
351 |
+
_, base64_image = process_image(temp_path, version, scale)
|
352 |
+
|
353 |
+
# Clean up
|
354 |
+
os.unlink(temp_path)
|
355 |
+
|
356 |
+
# Return base64 image data
|
357 |
+
if base64_image and base64_image.startswith('data:image'):
|
358 |
+
return {"success": True, "image": base64_image}
|
359 |
+
else:
|
360 |
+
return {"success": False, "error": base64_image or "Unknown error"}
|
361 |
+
except Exception as e:
|
362 |
+
print(f"Error in direct-process API: {e}")
|
363 |
+
print(traceback.format_exc())
|
364 |
+
return {"success": False, "error": str(e)}
|
365 |
|
366 |
# --- Gradio Interface Definition ---
|
367 |
title = "GFPGAN: Practical Face Restoration"
|
368 |
description = """Gradio demo for <a href='https://github.com/TencentARC/GFPGAN' target='_blank'><b>GFPGAN: Towards Real-World Blind Face Restoration with Generative Facial Prior</b></a>.
|
369 |
<br>Restore your <b>old photos</b> or improve <b>AI-generated faces</b>. Upload an image to start.
|
370 |
<br>If helpful, please ⭐ the <a href='https://github.com/TencentARC/GFPGAN' target='_blank'>Original Github Repo</a>.
|
371 |
+
<br>API endpoint available at `/predict` or `/api/direct-process`. Returns processed image and base64 data.
|
372 |
"""
|
373 |
article = "Questions? Contact the original creators (see GFPGAN repo)."
|
374 |
|
375 |
+
# Use upload component for more compatibility
|
376 |
inputs = [
|
377 |
gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"]),
|
378 |
gr.Radio(
|
|
|
411 |
allow_flagging='never'
|
412 |
)
|
413 |
|
414 |
+
# Mount the Gradio app
|
415 |
+
app = gr.mount_gradio_app(app, demo, path="/")
|
416 |
+
|
417 |
# Launch the interface
|
418 |
if __name__ == "__main__":
|
419 |
+
import uvicorn
|
420 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|