Gemini899 commited on
Commit
3162525
·
verified ·
1 Parent(s): a237eb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -12
app.py CHANGED
@@ -3,12 +3,15 @@ import os
3
  import sys
4
  import cv2
5
  import torch
 
6
  import numpy as np
7
  from PIL import Image, ImageOps
8
  import io
9
  import base64
10
  import traceback
11
- import gradio as gr
 
 
12
  import spaces
13
 
14
  # Import model-specific libraries
@@ -19,7 +22,7 @@ try:
19
  print("Successfully imported model libraries.")
20
  except ImportError as e:
21
  print(f"Error importing model libraries: {e}")
22
- print("Please ensure basicsr, gfpgan, realesrgan are installed or in requirements.txt")
23
  sys.exit(1)
24
 
25
  # --- Constants ---
@@ -69,7 +72,7 @@ except Exception as e:
69
  print(traceback.format_exc())
70
  print("Warning: GFPGAN will run without background enhancement.")
71
 
72
- # --- IMPROVED: Enhanced processing function that handles both file paths and PIL images ---
73
  @spaces.GPU(duration=90)
74
  def process_image(input_image, version, scale):
75
  """
@@ -265,28 +268,111 @@ def process_image(input_image, version, scale):
265
  pass
266
  return error_img, error_msg
267
 
268
- # --- Improved API wrapper function to handle both web UI and API calls ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  @spaces.GPU(duration=90)
270
  def inference(input_image, version, scale):
271
  """
272
  API-friendly wrapper that ensures consistent behavior between web and API interfaces.
273
  """
274
- # Process the image using our universal processor
275
- output_pil, base64_or_msg = process_image(input_image, version, scale)
276
-
277
- # Return the processed results in a format suitable for both UI and API
278
- return output_pil, base64_or_msg
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
 
280
  # --- Gradio Interface Definition ---
281
  title = "GFPGAN: Practical Face Restoration"
282
  description = """Gradio demo for <a href='https://github.com/TencentARC/GFPGAN' target='_blank'><b>GFPGAN: Towards Real-World Blind Face Restoration with Generative Facial Prior</b></a>.
283
  <br>Restore your <b>old photos</b> or improve <b>AI-generated faces</b>. Upload an image to start.
284
  <br>If helpful, please ⭐ the <a href='https://github.com/TencentARC/GFPGAN' target='_blank'>Original Github Repo</a>.
285
- <br>API endpoint available at `/predict`. Expects image file, version string, scale number. Returns Image and Base64 data.
286
  """
287
  article = "Questions? Contact the original creators (see GFPGAN repo)."
288
 
289
- # --- IMPROVED: Dual input type handling (both "filepath" for UI and "pil" for API) ---
290
  inputs = [
291
  gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"]),
292
  gr.Radio(
@@ -325,6 +411,10 @@ demo = gr.Interface(
325
  allow_flagging='never'
326
  )
327
 
 
 
 
328
  # Launch the interface
329
  if __name__ == "__main__":
330
- demo.queue().launch(server_name="0.0.0.0", share=False)
 
 
3
  import sys
4
  import cv2
5
  import torch
6
+ import gradio as gr
7
  import numpy as np
8
  from PIL import Image, ImageOps
9
  import io
10
  import base64
11
  import traceback
12
+ import tempfile
13
+ from fastapi import FastAPI, File, UploadFile
14
+ from fastapi.middleware.cors import CORSMiddleware
15
  import spaces
16
 
17
  # Import model-specific libraries
 
22
  print("Successfully imported model libraries.")
23
  except ImportError as e:
24
  print(f"Error importing model libraries: {e}")
25
+ print("Please ensure basicsr, gfpgan, realesrgan are installed")
26
  sys.exit(1)
27
 
28
  # --- Constants ---
 
72
  print(traceback.format_exc())
73
  print("Warning: GFPGAN will run without background enhancement.")
74
 
75
+ # --- Universal processing function ---
76
  @spaces.GPU(duration=90)
77
  def process_image(input_image, version, scale):
78
  """
 
268
  pass
269
  return error_img, error_msg
270
 
271
+ # --- Function to handle file upload for API ---
272
+ def handle_file_upload(file_data):
273
+ """Save uploaded file to temporary directory and return path"""
274
+ try:
275
+ print(f"Handling file upload: {type(file_data)}")
276
+
277
+ # Create a temporary file
278
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
279
+ temp_path = temp_file.name
280
+
281
+ # If it's bytes, write directly
282
+ if isinstance(file_data, bytes):
283
+ with open(temp_path, 'wb') as f:
284
+ f.write(file_data)
285
+ # If it's a file-like object (from FastAPI/Gradio)
286
+ elif hasattr(file_data, 'file'):
287
+ content = file_data.file.read()
288
+ with open(temp_path, 'wb') as f:
289
+ f.write(content)
290
+ # If it's a string path, it's already saved
291
+ elif isinstance(file_data, str) and os.path.exists(file_data):
292
+ return file_data
293
+ else:
294
+ raise ValueError(f"Unsupported file data type: {type(file_data)}")
295
+
296
+ print(f"File saved to temporary path: {temp_path}")
297
+ return temp_path
298
+
299
+ except Exception as e:
300
+ print(f"Error handling file upload: {e}")
301
+ print(traceback.format_exc())
302
+ raise
303
+
304
+ # --- API inference function ---
305
  @spaces.GPU(duration=90)
306
  def inference(input_image, version, scale):
307
  """
308
  API-friendly wrapper that ensures consistent behavior between web and API interfaces.
309
  """
310
+ try:
311
+ # If input is a file upload (from API), save it to a temporary path
312
+ if not isinstance(input_image, (str, Image.Image, np.ndarray)) and not (hasattr(input_image, 'name') and os.path.exists(input_image.name)):
313
+ file_path = handle_file_upload(input_image)
314
+ input_image = file_path
315
+
316
+ # Process the image
317
+ output_pil, base64_or_msg = process_image(input_image, version, scale)
318
+
319
+ # Return the processed results
320
+ return output_pil, base64_or_msg
321
+ except Exception as e:
322
+ print(f"Error in inference: {e}")
323
+ print(traceback.format_exc())
324
+ # Return a placeholder error image and message
325
+ error_img = Image.new('RGB', (100, 50), color='red')
326
+ return error_img, f"Error: {str(e)}"
327
+
328
+ # --- Get the FastAPI app from Gradio ---
329
+ app = FastAPI()
330
+
331
+ # Add CORS middleware to allow cross-origin requests
332
+ app.add_middleware(
333
+ CORSMiddleware,
334
+ allow_origins=["*"], # Allows all origins
335
+ allow_credentials=True,
336
+ allow_methods=["*"], # Allows all methods
337
+ allow_headers=["*"], # Allows all headers
338
+ )
339
+
340
+ # --- Direct API endpoint for file upload ---
341
+ @app.post("/api/direct-process")
342
+ async def direct_process(file: UploadFile = File(...), version: str = "v1.4", scale: float = 2.0):
343
+ try:
344
+ # Save the uploaded file
345
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.jpg')
346
+ temp_path = temp_file.name
347
+ with open(temp_path, 'wb') as f:
348
+ f.write(await file.read())
349
+
350
+ # Process the image
351
+ _, base64_image = process_image(temp_path, version, scale)
352
+
353
+ # Clean up
354
+ os.unlink(temp_path)
355
+
356
+ # Return base64 image data
357
+ if base64_image and base64_image.startswith('data:image'):
358
+ return {"success": True, "image": base64_image}
359
+ else:
360
+ return {"success": False, "error": base64_image or "Unknown error"}
361
+ except Exception as e:
362
+ print(f"Error in direct-process API: {e}")
363
+ print(traceback.format_exc())
364
+ return {"success": False, "error": str(e)}
365
 
366
  # --- Gradio Interface Definition ---
367
  title = "GFPGAN: Practical Face Restoration"
368
  description = """Gradio demo for <a href='https://github.com/TencentARC/GFPGAN' target='_blank'><b>GFPGAN: Towards Real-World Blind Face Restoration with Generative Facial Prior</b></a>.
369
  <br>Restore your <b>old photos</b> or improve <b>AI-generated faces</b>. Upload an image to start.
370
  <br>If helpful, please ⭐ the <a href='https://github.com/TencentARC/GFPGAN' target='_blank'>Original Github Repo</a>.
371
+ <br>API endpoint available at `/predict` or `/api/direct-process`. Returns processed image and base64 data.
372
  """
373
  article = "Questions? Contact the original creators (see GFPGAN repo)."
374
 
375
+ # Use upload component for more compatibility
376
  inputs = [
377
  gr.Image(type="pil", label="Input Image", sources=["upload", "clipboard"]),
378
  gr.Radio(
 
411
  allow_flagging='never'
412
  )
413
 
414
+ # Mount the Gradio app
415
+ app = gr.mount_gradio_app(app, demo, path="/")
416
+
417
  # Launch the interface
418
  if __name__ == "__main__":
419
+ import uvicorn
420
+ uvicorn.run(app, host="0.0.0.0", port=7860)