import os

import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

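# UPSC is a small ESPCN-style upscaler (sub-pixel convolution, Shi et al. 2016):
# a few convolutions extract features at low resolution, then PixelShuffle
# rearranges channels into a 3x larger RGB image, keeping the convolutions
# cheap compared with upscale-first designs.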
class UPSC(nn.Module):
    def __init__(self):
        super(UPSC, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            # 3 RGB channels * 3^2 = 27 output channels for PixelShuffle to regroup
            nn.Conv2d(in_channels=32, out_channels=3 * 3 * 3, kernel_size=3, padding=1),
            nn.PixelShuffle(3)
        )

    def forward(self, x):
        return self.model(x)


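# Pairs LR/HR images by sorted filename; assumes both directories hold the
# same images in corresponding order.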
class PairedSuperResolutionDataset(Dataset):
    def __init__(self, lr_dir, hr_dir, lr_size=(64, 64), hr_size=(256, 256)):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))

        self.transform_lr = transforms.Compose([
            transforms.Resize(lr_size),
            transforms.ToTensor()
        ])

        self.transform_hr = transforms.Compose([
            transforms.Resize(hr_size),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.lr_files)

    def __getitem__(self, idx):
        lr_path = os.path.join(self.lr_dir, self.lr_files[idx])
        hr_path = os.path.join(self.hr_dir, self.hr_files[idx])

        lr_img = Image.open(lr_path).convert("RGB")
        hr_img = Image.open(hr_path).convert("RGB")

        lr_tensor = self.transform_lr(lr_img)
        hr_tensor = self.transform_hr(hr_img)

        return lr_tensor, hr_tensor


# Paths and training hyperparameters
lr_dir = '/Users/aaronvattay/Documents/DF2K_train_LR_bicubic/X3'
hr_dir = '/Users/aaronvattay/Documents/DF2K_train_HR'
batch_size = 16
num_epochs = 10
learning_rate = 1e-4

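# LR at 256x256 and HR at 768x768 match the network's fixed 3x upscale factor;
# note the class defaults (64x64 -> 256x256) imply 4x and would not match.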
dataset = PairedSuperResolutionDataset(
    lr_dir=lr_dir,
    hr_dir=hr_dir,
    lr_size=(256, 256),
    hr_size=(768, 768)
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Use Apple's MPS backend when available, otherwise fall back to the CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = UPSC().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Resume from an existing checkpoint if one is present
if os.path.exists("upscaling.pth"):
    model.load_state_dict(torch.load("upscaling.pth", map_location=device, weights_only=True))

model.train()

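# The __main__ guard is the standard idiom so that importing this module
# (e.g. from a spawned DataLoader worker process) does not re-run training.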
if __name__ == "__main__":
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for lr_imgs, hr_imgs in dataloader:
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)

            # Forward pass: predict HR images from LR inputs
            outputs = model(lr_imgs)
            loss = criterion(outputs, hr_imgs)

            # Backward pass and parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.6f}")

    # Persist the trained weights for later resumption or inference
    torch.save(model.state_dict(), "upscaling.pth")
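

# A minimal inference sketch, not part of the original script: loads the
# trained weights and upscales a single image. "input.png" / "output.png"
# are hypothetical paths used only for illustration.
def upscale_image(in_path, out_path, weights="upscaling.pth"):
    net = UPSC().to(device)
    net.load_state_dict(torch.load(weights, map_location=device, weights_only=True))
    net.eval()
    img = Image.open(in_path).convert("RGB")
    x = transforms.ToTensor()(img).unsqueeze(0).to(device)  # add batch dimension
    with torch.no_grad():
        y = net(x).squeeze(0).clamp(0, 1).cpu()  # back to a CHW tensor in [0, 1]
    transforms.ToPILImage()(y).save(out_path)

# Example usage:
# upscale_image("input.png", "output.png")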