AIupscaling / AIupscale_train.py
Aaron Vattay
Model release
a6a4c31
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import os
from PIL import Image
from torch.utils.data import Dataset,dataloader
from torchvision import transforms
from torch.utils.data import DataLoader
class UPSC(nn.Module):
    """Small convolutional super-resolution network with sub-pixel upsampling.

    Two feature-extraction convolutions (padded so height/width are preserved)
    are followed by a convolution that emits ``num_channels * scale_factor**2``
    channels. ``nn.PixelShuffle`` then rearranges those channels into an image
    ``scale_factor`` times larger in each spatial dimension.

    Args:
        scale_factor: Integer spatial upscaling factor. Defaults to 3, which
            reproduces the original hard-coded network.
        num_channels: Number of image channels in and out. Defaults to 3.

    A batch ``(N, C, H, W)`` maps to ``(N, C, H * scale_factor, W * scale_factor)``.
    """

    def __init__(self, scale_factor=3, num_channels=3):
        super().__init__()
        if not isinstance(scale_factor, int) or scale_factor < 1:
            raise ValueError(f"scale_factor must be a positive int, got {scale_factor!r}")
        if not isinstance(num_channels, int) or num_channels < 1:
            raise ValueError(f"num_channels must be a positive int, got {num_channels!r}")
        self.scale_factor = scale_factor
        self.num_channels = num_channels
        # Layer order is unchanged from the original so state_dict keys
        # ("model.0.weight", "model.2.weight", "model.4.weight", ...) still
        # match checkpoints saved by earlier versions of this script.
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=num_channels, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            # PixelShuffle(r) needs r^2 channels per output channel.
            nn.Conv2d(
                in_channels=32,
                out_channels=num_channels * scale_factor ** 2,
                kernel_size=3,
                padding=1,
            ),
            # Rearranges (C*r^2, H, W) into (C, H*r, W*r).
            nn.PixelShuffle(scale_factor),
        )

    def forward(self, x):
        """Return the upscaled version of the image batch ``x``."""
        return self.model(x)
class PairedSuperResolutionDataset(Dataset):
    """Dataset of (low-resolution, high-resolution) image pairs from two folders.

    Each folder's regular, non-hidden files are sorted by name and paired by
    position: the i-th LR file is matched with the i-th HR file. Both images
    are converted to RGB, resized to a fixed size and returned as tensors.

    Args:
        lr_dir: Directory containing the low-resolution images.
        hr_dir: Directory containing the high-resolution images.
        lr_size: Size passed to ``transforms.Resize`` for LR images.
        hr_size: Size passed to ``transforms.Resize`` for HR images.

    Raises:
        ValueError: If the two directories contain different numbers of files,
            since positional pairing would then be misaligned.

    NOTE: the default sizes (64 -> 256) are a 4:1 ratio, while ``UPSC`` upscales
    by a factor of 3. When training that model, pass sizes in a 3:1 ratio
    (e.g. ``lr_size=(256, 256), hr_size=(768, 768)``) so the output and target
    shapes match.
    """

    def __init__(self, lr_dir, hr_dir, lr_size=(64, 64), hr_size=(256, 256)):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.lr_files = self._list_image_files(lr_dir)
        self.hr_files = self._list_image_files(hr_dir)
        # Pairing is positional, so a count mismatch would silently pair the
        # wrong images (or raise IndexError mid-epoch). Fail fast instead.
        if len(self.lr_files) != len(self.hr_files):
            raise ValueError(
                f"LR/HR file count mismatch: {len(self.lr_files)} files in "
                f"{lr_dir!r} vs {len(self.hr_files)} files in {hr_dir!r}"
            )
        self.transform_lr = transforms.Compose([
            transforms.Resize(lr_size),
            transforms.ToTensor(),
        ])
        self.transform_hr = transforms.Compose([
            transforms.Resize(hr_size),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _list_image_files(directory):
        """Return the sorted names of regular, non-hidden files in ``directory``."""
        # Hidden entries (e.g. macOS ".DS_Store") and subdirectories are not
        # images; they would make Image.open fail or shift the pairing.
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path``, convert it to RGB, and close the file."""
        # convert() returns a new, fully loaded image, so the source file
        # handle can be released immediately by the context manager.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        """Return the number of LR/HR pairs."""
        return len(self.lr_files)

    def __getitem__(self, idx):
        """Return the ``(lr_tensor, hr_tensor)`` pair at position ``idx``."""
        lr_path = os.path.join(self.lr_dir, self.lr_files[idx])
        hr_path = os.path.join(self.hr_dir, self.hr_files[idx])
        lr_tensor = self.transform_lr(self._load_rgb(lr_path))
        hr_tensor = self.transform_hr(self._load_rgb(hr_path))
        return lr_tensor, hr_tensor
# --- Training configuration ------------------------------------------------
# NOTE(review): machine-specific absolute paths; adjust for your environment.
lr_dir = '/Users/aaronvattay/Documents/DF2K_train_LR_bicubic/X3'
hr_dir = '/Users/aaronvattay/Documents/DF2K_train_HR'
batch_size = 16
num_epochs = 10
learning_rate = 1e-4

# Create dataset and dataloader. The LR/HR sizes keep a 3:1 ratio so the
# model's PixelShuffle(3) output has the same shape as the HR target.
dataset = PairedSuperResolutionDataset(
    lr_dir=lr_dir,
    hr_dir=hr_dir,
    lr_size=(256, 256),
    hr_size=(768, 768),
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Device configuration: prefer Apple's MPS backend (the original target),
# then CUDA, and fall back to CPU so the script also runs on other machines.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Initialize model, loss, and optimizer
model = UPSC().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Resume from a previous checkpoint if one exists. load_state_dict copies
# values into the existing parameters, so the optimizer created above still
# tracks them. weights_only=True restricts what torch.load will unpickle.
if os.path.exists("upscaling.pth"):
    model.load_state_dict(torch.load("upscaling.pth", map_location=device, weights_only=True))

# Set the model to training mode
model.train()
if __name__ == "__main__":
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_samples = 0
        for lr_imgs, hr_imgs in dataloader:
            # Move images to device
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
            # Forward pass: the model produces the upscaled images
            outputs = model(lr_imgs)
            loss = criterion(outputs, hr_imgs)
            # Backpropagation and optimization
            optimizer.zero_grad()  # Clear gradients from the previous step
            loss.backward()        # Backpropagate the loss
            optimizer.step()       # Update weights
            # MSELoss averages within the batch. Weight by batch size so the
            # epoch figure is a true per-sample mean even when the final
            # batch is smaller than batch_size.
            batch_len = lr_imgs.size(0)
            epoch_loss += loss.item() * batch_len
            num_samples += batch_len
        # An empty dataset yields no batches; report NaN instead of raising
        # ZeroDivisionError.
        avg_loss = epoch_loss / num_samples if num_samples else float("nan")
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.6f}")
    # Save the trained weights for later inference / resuming
    torch.save(model.state_dict(), "upscaling.pth")