danhtran2mind's picture
Upload 82 files
5e1b2e8 verified
import argparse
from PIL import Image
import numpy as np
import torch
import os
import sys
import yaml
from huggingface_hub import hf_hub_download
import cv2
# Add this script's own directory (expected to be the project root) to the Python path so the local `inference` package imported below can be found
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__))))
from inference.real_esrgan_inference import RealESRGAN
def get_model_checkpoint(model_id, models_config_path):
    """Return the local checkpoint path for ``model_id``, downloading it if missing.

    The models config is a YAML file whose top level is a list of entries. The
    matching entry must provide ``model_id``, ``filename`` and ``local_dir``.
    When ``<local_dir>/<filename>`` does not exist yet, the file is fetched from
    the Hugging Face Hub repository named by the entry's ``model_id``.

    Args:
        model_id: Identifier matched against each entry's ``model_id`` key.
        models_config_path: Path to the models configuration YAML file.

    Returns:
        ``os.path.join(local_dir, filename)`` for the matching entry.

    Raises:
        SystemExit: With status 1 if the YAML cannot be parsed, is not a list,
            or contains no entry for ``model_id``.
        OSError: If the config file cannot be opened (propagated unchanged).
    """
    try:
        with open(models_config_path, 'r', encoding='utf-8') as file:
            config_list = yaml.safe_load(file)
    except yaml.YAMLError as e:
        print(f"Error loading YAML: {e}", file=sys.stderr)
        sys.exit(1)

    # An empty file loads as None; a mapping or scalar cannot be searched for entries.
    if not isinstance(config_list, list):
        print(f"Error: expected a list of model entries in '{models_config_path}'.",
              file=sys.stderr)
        sys.exit(1)

    # Skip malformed entries (non-dicts or ones without 'model_id') instead of raising.
    model_config = next(
        (item for item in config_list
         if isinstance(item, dict) and item.get('model_id') == model_id),
        None,
    )
    if model_config is None:
        # Report the ID that was actually requested (previously hard-coded to one model).
        print(f"Error: Model ID '{model_id}' not found in configuration.", file=sys.stderr)
        sys.exit(1)

    model_path = os.path.join(model_config["local_dir"], model_config["filename"])
    if not os.path.exists(model_path):
        hf_hub_download(repo_id=model_config["model_id"],
                        filename=model_config["filename"],
                        local_dir=model_config["local_dir"],
                        local_dir_use_symlinks=False)
        print('Weights downloaded to:', model_path)
    return model_path
def _resize_to_outer_scale(image, outer_scale, inner_scale):
    """Resample the network output by ``outer_scale / inner_scale`` using OpenCV.

    Uses bicubic interpolation when enlarging and area interpolation when
    shrinking. Returns ``image`` unchanged when the two scales are equal.
    """
    if outer_scale == inner_scale:
        return image
    factor = outer_scale / inner_scale
    # Multiply before dividing and round instead of truncating, so float error
    # (e.g. 599.9999...) does not drop a pixel; never shrink below 1 pixel.
    new_width = max(1, round(image.width * outer_scale / inner_scale))
    new_height = max(1, round(image.height * outer_scale / inner_scale))
    interpolation = cv2.INTER_CUBIC if factor > 1 else cv2.INTER_AREA
    resized = cv2.resize(np.array(image), (new_width, new_height), interpolation=interpolation)
    return Image.fromarray(resized)


def infer(input_path, model_id, models_config_path, outer_scale, inner_scale=4, output_path=None):
    """Upscale one image with Real-ESRGAN and optionally save the result.

    The network is constructed with ``scale=inner_scale``. If ``outer_scale``
    differs, the network output is then resampled by
    ``outer_scale / inner_scale`` with OpenCV to reach the requested size.

    Args:
        input_path: Path of the image to upscale (converted to RGB).
        model_id: Model identifier looked up in the models config.
        models_config_path: Path to the models configuration YAML file.
        outer_scale: Final scale requested by the caller; must be positive.
        inner_scale: Scale the network is built with.
        output_path: If truthy, the result is saved there.

    Returns:
        The upscaled image (as returned by ``model.predict``, presumably a PIL
        image, or a PIL image produced after resampling).

    Raises:
        ValueError: If ``outer_scale`` is not positive.
    """
    # Fail fast, before loading weights; cv2.resize would otherwise crash later.
    if outer_scale <= 0:
        raise ValueError(f"outer_scale must be positive, got {outer_scale}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = RealESRGAN(device, scale=inner_scale)
    model_path = get_model_checkpoint(model_id, models_config_path)
    model.load_weights(model_path)

    # Close the source file promptly; convert() returns an independent, loaded copy.
    with Image.open(input_path) as source:
        image = source.convert('RGB')

    output_image = model.predict(image)
    output_image = _resize_to_outer_scale(output_image, outer_scale, inner_scale)

    if output_path:
        output_image.save(output_path)
    return output_image
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Super-resolution for anime images using Real-ESRGAN")
    parser.add_argument('--input_path', type=str, required=True, help="Path to the input image")
    parser.add_argument('--output_path', type=str, default=None,
                        help="Path to save the output image (default: <input>_out.png next to the input)")
    parser.add_argument('--model_id', type=str, required=True, help="Model ID for Real-ESRGAN")
    parser.add_argument('--models_config_path', type=str, required=True, help="Path to the models configuration YAML file")
    # Kept for command-line compatibility; infer() does not use it.
    parser.add_argument('--batch_size', type=int, default=1, help="Batch size for inference (not used in this implementation)")
    parser.add_argument('--outer_scale', type=int, required=True, help="Outer scale for super-resolution")
    parser.add_argument('--inner_scale', type=int, default=4, help="Inner scale for the model")
    args = parser.parse_args()

    # Without an output path the upscaled image would be computed and then
    # discarded, so default to "<input without extension>_out.png".
    output_path = args.output_path or os.path.splitext(args.input_path)[0] + '_out.png'

    infer(args.input_path, args.model_id, args.models_config_path,
          args.outer_scale, args.inner_scale, output_path)
    print('Output saved to:', output_path)