# Hugging Face Spaces status: Runtime error
import os
import subprocess
import sys
import traceback

import cv2
import gradio as gr
import requests
import torch
# ------------------------------------------------------------------------------
# Dependency Management
# ------------------------------------------------------------------------------
# Installing packages at runtime is a workaround: in production, pin these in a
# requirements.txt file instead. For this demo we make sure numpy and
# torchvision are in a state the model packages imported further down can use.


def _pip_install(*requirements):
    """Run ``pip install`` for *requirements* in this interpreter's environment.

    ``sys.executable -m pip`` guarantees the packages land in the environment
    this script is running from; a bare ``pip`` on PATH may belong to a
    different Python. Arguments are passed as a list (no shell), so specs such
    as ``numpy<2`` need no quoting.

    Returns:
        True if pip exited successfully, otherwise False. Failures are printed,
        not raised, so setup stays best-effort.
    """
    command = [sys.executable, "-m", "pip", "install", *requirements]
    result = subprocess.run(command, check=False)
    if result.returncode != 0:
        print(f"Command failed with exit code {result.returncode}: {' '.join(command)}")
    return result.returncode == 0


def _ensure_torchvision_functional_tensor():
    """Make ``torchvision.transforms.functional_tensor`` importable.

    Fixes: ModuleNotFoundError for torchvision.transforms.functional_tensor.
    Recent torchvision releases removed that public module but keep its
    functions in the private ``torchvision.transforms._functional_tensor``, so
    the private module is registered under the old name. This avoids
    reinstalling an old torchvision at runtime: torchvision wheels pin an exact
    torch version, so pip could replace the ``torch`` package this process has
    already imported. Installing ``torchvision==0.12.0`` is kept only as a last
    resort when neither module can be imported.
    """
    try:
        import torchvision.transforms.functional_tensor  # noqa: F401
    except ImportError:
        try:
            from torchvision.transforms import _functional_tensor
        except ImportError:
            _pip_install("torchvision==0.12.0")
        else:
            sys.modules["torchvision.transforms.functional_tensor"] = _functional_tensor


_pip_install("--upgrade", "numpy<2")
_ensure_torchvision_functional_tensor()
# ------------------------------------------------------------------------------
# Utility Function: Download Weight Files
# ------------------------------------------------------------------------------
def download_file(filename, url, timeout=60):
    """
    ELI5: If the file (like a model weight) isn't on your computer, download it!

    The body is streamed into a temporary ``<filename>.part`` file next to the
    target, which is renamed into place only after the whole download has been
    written. An interrupted or failed download therefore never leaves a
    truncated file behind that a later run would skip as "already downloaded".

    Args:
        filename: Local path to save the file to. Nothing is fetched if it
            already exists.
        url: Address to download the file from.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a stalled
            server cannot hang startup forever.

    Returns:
        True if ``filename`` exists when this function returns, else False.
        Failures are printed rather than raised (best-effort, as before).
    """
    if os.path.exists(filename):
        return True
    print(f"Downloading {filename} from {url}...")
    partial_path = f"{filename}.part"
    try:
        # Using the response as a context manager releases the streamed
        # connection even when we bail out early.
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"Failed to download {filename}")
                return False
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        f.write(chunk)
        # Atomic rename: the final name only ever refers to a complete file.
        os.replace(partial_path, filename)
    except (requests.RequestException, OSError) as error:
        print(f"Failed to download {filename}: {error}")
        # Best-effort cleanup of the partial file; a leftover .part is harmless
        # because it is never mistaken for the real weight file.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        return False
    return True
# ------------------------------------------------------------------------------
# Download Required Model Weights
# ------------------------------------------------------------------------------
# Local weight file name -> release URL it is fetched from. download_file()
# leaves any entry alone whose file is already on disk.
weights = {
    # Real-ESRGAN weights, loaded below as the background upsampler.
    "realesr-general-x4v3.pth": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
    # Face restoration weights, selectable in the UI and loaded through GFPGANer.
    "GFPGANv1.2.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth",
    "GFPGANv1.3.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth",
    "GFPGANv1.4.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
    "RestoreFormer.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth",
    "CodeFormer.pth": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth",
}

for weight_name in weights:
    download_file(weight_name, weights[weight_name])
# ------------------------------------------------------------------------------
# Import Model-Related Modules After Ensuring Dependencies
# ------------------------------------------------------------------------------
# Deliberately imported here rather than at the top of the file, so these
# packages are loaded only after the numpy/torchvision setup above has run.
from basicsr.archs.srvgg_arch import SRVGGNetCompact
from gfpgan.utils import GFPGANer
from realesrgan.utils import RealESRGANer

# ------------------------------------------------------------------------------
# Initialize ESRGAN Upsampler
# ------------------------------------------------------------------------------
# ELI5: We build a mini brain (model) to help make images look better.
# Half precision is requested only when a CUDA GPU is available.
half = torch.cuda.is_available()
model_path = 'realesr-general-x4v3.pth'
model = SRVGGNetCompact(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_conv=32,
    upscale=4,
    act_type='prelu',
)
# Shared by every request: used directly or as GFPGAN's background upsampler.
upsampler = RealESRGANer(
    scale=4,
    model=model,
    model_path=model_path,
    half=half,
    tile=0,
    tile_pad=10,
    pre_pad=0,
)

# Enhanced images are written into this directory.
os.makedirs('output', exist_ok=True)
| # ------------------------------------------------------------------------------ | |
| # Image Inference Function | |
| # ------------------------------------------------------------------------------ | |
| def inference(img, version, scale): | |
| """ | |
| ELI5: This function takes your uploaded image, picks a model version, | |
| and a scaling factor. It then: | |
| 1. Reads your image. | |
| 2. Checks if it's in a special format (like with transparency). | |
| 3. Resizes small images for better processing. | |
| 4. Uses a face enhancement model (GFPGAN) and a background upscaler (RealESRGAN) | |
| to make the image look better. | |
| 5. Optionally resizes the final image. | |
| 6. Saves and returns the enhanced image. | |
| """ | |
| try: | |
| # Read the image from the provided file path. | |
| img_path = str(img) | |
| extension = os.path.splitext(os.path.basename(img_path))[1] | |
| img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) | |
| if img is None: | |
| print("Error: Could not read the image. Please check the file.") | |
| return None, None | |
| # Determine the image mode: RGBA (has transparency) or not. | |
| if len(img.shape) == 3 and img.shape[2] == 4: | |
| img_mode = 'RGBA' | |
| elif len(img.shape) == 2: | |
| # If the image is grayscale, convert it to a color image. | |
| img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) | |
| img_mode = None | |
| else: | |
| img_mode = None | |
| # If the image is too small, double its size. | |
| h, w = img.shape[:2] | |
| if h < 300: | |
| img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4) | |
| # Map the selected model version to its weight file. | |
| model_paths = { | |
| 'v1.2': 'GFPGANv1.2.pth', | |
| 'v1.3': 'GFPGANv1.3.pth', | |
| 'v1.4': 'GFPGANv1.4.pth', | |
| 'RestoreFormer': 'RestoreFormer.pth', | |
| 'CodeFormer': 'CodeFormer.pth', | |
| 'RealESR-General-x4v3': 'realesr-general-x4v3.pth' | |
| } | |
| # Initialize GFPGAN for face enhancement. | |
| face_enhancer = GFPGANer( | |
| model_path=model_paths[version], | |
| upscale=2, # Face region upscale factor. | |
| arch='clean' if version.startswith('v1') else version, | |
| channel_multiplier=2, | |
| bg_upsampler=upsampler # Use the ESRGAN upsampler for background. | |
| ) | |
| # Enhance the image. | |
| _, _, output = face_enhancer.enhance( | |
| img, has_aligned=False, only_center_face=False, paste_back=True | |
| ) | |
| # Optionally, further rescale the enhanced image. | |
| if scale != 2: | |
| interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4 | |
| h, w = output.shape[:2] | |
| output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation) | |
| # Decide on file extension based on image mode. | |
| extension = 'png' if img_mode == 'RGBA' else 'jpg' | |
| save_path = os.path.join('output', f'out.{extension}') | |
| # Save the enhanced image. | |
| cv2.imwrite(save_path, output) | |
| # Convert BGR to RGB for proper display in Gradio. | |
| output_rgb = cv2.cvtColor(output, cv2.COLOR_BGR2RGB) | |
| return output_rgb, save_path | |
| except Exception as error: | |
| print("Error during inference:", error) | |
| return None, None | |
# ------------------------------------------------------------------------------
# Build the Gradio UI
# ------------------------------------------------------------------------------
# Layout: inputs and the trigger button in the left column, results in the
# right one. The button calls inference(image, version, scale) and routes its
# two return values to the preview image and the download widget.
with gr.Blocks() as demo:
    gr.Markdown("## 📸 Image Upscaling & Restoration")
    gr.Markdown("### Enhance your images using GFPGAN & RealESRGAN with a friendly UI!")

    with gr.Row():
        # Inputs.
        with gr.Column():
            image_input = gr.Image(label="Upload Your Image", type="filepath")
            version_selector = gr.Radio(
                label="Select Model Version",
                value="v1.4",
                choices=['v1.2', 'v1.3', 'v1.4', 'RestoreFormer', 'CodeFormer', 'RealESR-General-x4v3'],
            )
            scale_factor = gr.Number(label="Rescaling Factor (e.g., 2 for default)", value=2)
            enhance_button = gr.Button("Enhance Image 🚀")
        # Outputs.
        with gr.Column():
            output_image = gr.Image(label="Enhanced Image", type="numpy")
            download_link = gr.File(label="Download Enhanced Image")

    enhance_button.click(
        inference,
        [image_input, version_selector, scale_factor],
        [output_image, download_link],
    )
# ------------------------------------------------------------------------------
# Launch the Gradio App
# ------------------------------------------------------------------------------
# Start the server only when this file is run as a script, so importing the
# module (e.g. from tests or tooling) does not block inside launch().
# queue() sends requests through Gradio's queue, which is intended for slow
# handlers like model inference; it returns the same Blocks object.
if __name__ == "__main__":
    demo.queue().launch()