"""Command-line super-resolution for anime images using a Real-ESRGAN model.

Model checkpoints are looked up in a YAML configuration file and fetched from
the Hugging Face Hub when they are not already present locally.
"""
import argparse
import os
import sys

import cv2
import numpy as np
import torch
import yaml
from huggingface_hub import hf_hub_download
from PIL import Image

# Make packages located next to this script (e.g. ``inference``) importable
# regardless of the current working directory.
sys.path.append(os.path.abspath(os.path.dirname(__file__)))

from inference.real_esrgan_inference import RealESRGAN  # noqa: E402


def _fail(message):
    """Print *message* to stderr and terminate the process with status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def get_model_checkpoint(model_id, models_config_path):
    """Return the local path of the checkpoint configured for *model_id*.

    The YAML file at *models_config_path* is expected to hold a list of
    mappings, each with ``model_id``, ``local_dir`` and ``filename`` keys.
    The checkpoint path is ``local_dir/filename``. If that file does not
    exist yet, it is downloaded from the Hugging Face Hub repository named
    by the entry's ``model_id``.

    Args:
        model_id: Identifier of the entry to look up in the config file.
        models_config_path: Path to the models configuration YAML file.

    Returns:
        The filesystem path of the checkpoint file.

    Raises:
        SystemExit: (status 1) if the config file cannot be read or parsed,
            or if no entry matches *model_id*.
    """
    try:
        with open(models_config_path, 'r', encoding='utf-8') as file:
            config_list = yaml.safe_load(file)
    except OSError as e:
        _fail(f"Error reading config file {models_config_path!r}: {e}")
    except yaml.YAMLError as e:
        _fail(f"Error loading YAML: {e}")

    # An empty YAML document loads as None; treat it as "no models configured".
    # Non-mapping entries or entries without a model_id are skipped rather
    # than raising KeyError/TypeError.
    model_config = next(
        (item for item in (config_list or [])
         if isinstance(item, dict) and item.get('model_id') == model_id),
        None,
    )
    if model_config is None:
        _fail(f"Error: Model ID '{model_id}' not found in configuration.")

    model_path = os.path.join(model_config["local_dir"], model_config["filename"])
    if not os.path.exists(model_path):
        hf_hub_download(repo_id=model_config["model_id"],
                        filename=model_config["filename"],
                        local_dir=model_config["local_dir"],
                        local_dir_use_symlinks=False)
        # Only report a download when one actually happened.
        print('Weights downloaded to:', model_path)
    return model_path


def _rescale(image, factor):
    """Resize a PIL image by *factor* with OpenCV and return a new PIL image."""
    # Clamp to at least 1 pixel: cv2.resize rejects a zero-sized target.
    new_size = (max(1, int(image.width * factor)),
                max(1, int(image.height * factor)))
    # Bicubic when enlarging; area averaging reduces aliasing when shrinking.
    interpolation = cv2.INTER_CUBIC if factor > 1 else cv2.INTER_AREA
    resized = cv2.resize(np.array(image), new_size, interpolation=interpolation)
    return Image.fromarray(resized)


def infer(input_path, model_id, models_config_path, outer_scale, inner_scale=4, output_path=None):
    """Super-resolve the image at *input_path* and return the result.

    The network upscales by ``inner_scale``. When ``outer_scale`` differs,
    the network output is resized by ``outer_scale / inner_scale``, so the
    overall magnification is ``outer_scale``.

    Args:
        input_path: Path of the image to upscale; it is converted to RGB.
        model_id: Model identifier passed to :func:`get_model_checkpoint`.
        models_config_path: Path to the models configuration YAML file.
        outer_scale: Desired overall upscaling factor.
        inner_scale: Scale the RealESRGAN network is built with.
        output_path: If given, the result is saved there. Otherwise nothing
            is written to disk.

    Returns:
        The upscaled image. It is assumed to be a PIL image, since
        ``.width``/``.height``/``.save`` are used on it — confirm against
        ``RealESRGAN.predict``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = RealESRGAN(device, scale=inner_scale)
    model_path = get_model_checkpoint(model_id, models_config_path)
    model.load_weights(model_path)

    # Close the input file promptly: convert() returns a fully loaded,
    # independent copy, so the source handle is no longer needed.
    with Image.open(input_path) as source:
        image = source.convert('RGB')

    output_image = model.predict(image)

    if outer_scale != inner_scale:
        output_image = _rescale(output_image, outer_scale / inner_scale)

    if output_path:
        output_image.save(output_path)
    return output_image


def main(argv=None):
    """Parse command-line arguments (*argv*, default ``sys.argv``) and run :func:`infer`."""
    parser = argparse.ArgumentParser(description="Super-resolution for anime images using Real-ESRGAN")
    parser.add_argument('--input_path', type=str, required=True, help="Path to the input image")
    parser.add_argument('--output_path', type=str, default=None, help="Path to save the output image")
    parser.add_argument('--model_id', type=str, required=True, help="Model ID for Real-ESRGAN")
    parser.add_argument('--models_config_path', type=str, required=True, help="Path to the models configuration YAML file")
    parser.add_argument('--batch_size', type=int, default=1, help="Batch size for inference (not used in this implementation)")
    parser.add_argument('--outer_scale', type=int, required=True, help="Outer scale for super-resolution")
    parser.add_argument('--inner_scale', type=int, default=4, help="Inner scale for the model")
    args = parser.parse_args(argv)

    # Non-positive scales are meaningless, and inner_scale == 0 would divide by zero.
    if args.outer_scale <= 0:
        parser.error("--outer_scale must be a positive integer")
    if args.inner_scale <= 0:
        parser.error("--inner_scale must be a positive integer")

    if not args.output_path:
        print("Warning: no --output_path given; the result will not be saved.", file=sys.stderr)

    infer(args.input_path, args.model_id, args.models_config_path,
          args.outer_scale, args.inner_scale, args.output_path)


if __name__ == "__main__":
    main()