Spaces:
Build error
Build error
| import numpy as np | |
| import torch | |
| import torch.backends.cudnn as cudnn | |
| from PIL import Image | |
| import cv2 | |
| from nets.esrgan import Generator | |
| from utils.utils import cvtColor, preprocess_input | |
| class ESRGAN(object): | |
| #-----------------------------------------# | |
| # 注意修改model_path | |
| #-----------------------------------------# | |
| _defaults = { | |
| #-----------------------------------------------# | |
| # model_path指向logs文件夹下的权值文件 | |
| #-----------------------------------------------# | |
| "model_path" : 'model_data/Generator_ESRGAN.pth', | |
| #-----------------------------------------------# | |
| # 上采样的倍数,和训练时一样 | |
| #-----------------------------------------------# | |
| "scale_factor" : 8, | |
| #-------------------------------# | |
| # 是否使用Cuda | |
| # 没有GPU可以设置成False | |
| #-------------------------------# | |
| "cuda" : False, | |
| } | |
| #---------------------------------------------------# | |
| # 初始化SRGAN | |
| #---------------------------------------------------# | |
| def __init__(self, **kwargs): | |
| self.__dict__.update(self._defaults) | |
| for name, value in kwargs.items(): | |
| setattr(self, name, value) | |
| self.generate() | |
| def generate(self): | |
| self.net = Generator(self.scale_factor) | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| self.net.load_state_dict(torch.load(self.model_path, map_location=device)) | |
| self.net = self.net.eval() | |
| print('{} model, and classes loaded.'.format(self.model_path)) | |
| if self.cuda: | |
| self.net = torch.nn.DataParallel(self.net) | |
| cudnn.benchmark = True | |
| self.net = self.net.cuda() | |
| def generate_1x1_image(self, image): | |
| #---------------------------------------------------------# | |
| # 在这里将图像转换成RGB图像,防止灰度图在预测时报错。 | |
| # 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB | |
| #---------------------------------------------------------# | |
| image = cvtColor(image) | |
| #---------------------------------------------------------# | |
| # 添加上batch_size维度,并进行归一化 | |
| #---------------------------------------------------------# | |
| image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image, dtype=np.float32), [0.5,0.5,0.5], [0.5,0.5,0.5]), [2,0,1]), 0) | |
| with torch.no_grad(): | |
| image_data = torch.from_numpy(image_data).type(torch.FloatTensor) | |
| if self.cuda: | |
| image_data = image_data.cuda() | |
| #---------------------------------------------------------# | |
| # 将图像输入网络当中进行预测! | |
| #---------------------------------------------------------# | |
| hr_image = self.net(image_data)[0] | |
| #---------------------------------------------------------# | |
| # 将归一化的结果再转成rgb格式 | |
| #---------------------------------------------------------# | |
| hr_image = (hr_image.cpu().data.numpy().transpose(1, 2, 0) * 0.5 + 0.5) | |
| hr_image = (hr_image-np.min(hr_image))/(np.max(hr_image)-np.min(hr_image)) * 255 | |
| hr_image = Image.fromarray(np.uint8(hr_image)) | |
| return hr_image | |