"""Train a simple 3x super-resolution CNN (ESPCN-style sub-pixel convolution).

The model upscales RGB images by a factor of 3 using three convolutions
followed by nn.PixelShuffle. Training pairs are read from separate
low-resolution / high-resolution directories whose sorted file listings
are assumed to correspond one-to-one.
"""

import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset

# Upscaling factor: the last conv emits SCALE**2 * 3 channels, which
# PixelShuffle folds into a SCALE-times-larger spatial grid.
SCALE = 3


class UPSC(nn.Module):
    """Sub-pixel convolution upscaler: (N, 3, H, W) -> (N, 3, 3H, 3W)."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            # This convolution outputs SCALE^2 * 3 channels so that
            # PixelShuffle can rearrange them into spatial dimensions.
            nn.Conv2d(in_channels=32, out_channels=3 * SCALE * SCALE,
                      kernel_size=3, padding=1),
            nn.PixelShuffle(SCALE),
        )

    def forward(self, x):
        """Upscale a batch of RGB images by SCALE in each spatial dim."""
        return self.model(x)


class PairedSuperResolutionDataset(Dataset):
    """Paired LR/HR image dataset.

    Images in the two directories are matched by sorted filename order,
    so both directories must contain the same number of correspondingly
    named files (verified at construction).

    Args:
        lr_dir: directory of low-resolution images.
        hr_dir: directory of high-resolution images.
        lr_size: (H, W) the LR images are resized to.
        hr_size: (H, W) the HR images are resized to. Must be
            SCALE * lr_size to match the model output; the old default
            of (256, 256) was a 4x pair and crashed the MSE loss.
    """

    def __init__(self, lr_dir, hr_dir, lr_size=(64, 64), hr_size=(192, 192)):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))
        # Pairing is positional: mismatched counts would silently pair
        # the wrong images, so fail fast instead.
        if len(self.lr_files) != len(self.hr_files):
            raise ValueError(
                f"LR/HR file counts differ: "
                f"{len(self.lr_files)} vs {len(self.hr_files)}"
            )
        self.transform_lr = transforms.Compose([
            transforms.Resize(lr_size),
            transforms.ToTensor(),
        ])
        self.transform_hr = transforms.Compose([
            transforms.Resize(hr_size),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.lr_files)

    def __getitem__(self, idx):
        """Return the (lr_tensor, hr_tensor) pair at position ``idx``."""
        lr_path = os.path.join(self.lr_dir, self.lr_files[idx])
        hr_path = os.path.join(self.hr_dir, self.hr_files[idx])

        lr_img = Image.open(lr_path).convert("RGB")
        hr_img = Image.open(hr_path).convert("RGB")

        lr_tensor = self.transform_lr(lr_img)
        hr_tensor = self.transform_hr(hr_img)
        return lr_tensor, hr_tensor


# Training configuration.
lr_dir = '/Users/aaronvattay/Documents/DF2K_train_LR_bicubic/X3'
hr_dir = '/Users/aaronvattay/Documents/DF2K_train_HR'
batch_size = 16
num_epochs = 10
learning_rate = 1e-4

CHECKPOINT_PATH = "upscaling.pth"


def _select_device():
    """Pick the best available device instead of hard-coding "mps"."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def main():
    """Build the dataset/model, run the training loop, checkpoint each epoch."""
    # Building the dataset here (not at import time) keeps `import` of
    # this module side-effect free and avoids crashing on machines that
    # lack the data directories.
    dataset = PairedSuperResolutionDataset(
        lr_dir=lr_dir,
        hr_dir=hr_dir,
        lr_size=(256, 256),
        hr_size=(768, 768),  # 3x the LR size, matching SCALE
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = _select_device()

    # Initialize model, loss, and optimizer.
    model = UPSC().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Resume from a previous checkpoint if one exists.
    if os.path.exists(CHECKPOINT_PATH):
        model.load_state_dict(
            torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
        )

    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for lr_imgs, hr_imgs in dataloader:
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)

            # Forward pass: model produces the upscaled images.
            outputs = model(lr_imgs)
            loss = criterion(outputs, hr_imgs)

            # Backpropagation and optimization.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.6f}")

        # Checkpoint after every epoch (the original saved only once at
        # the very end, losing all progress on interruption).
        torch.save(model.state_dict(), CHECKPOINT_PATH)


if __name__ == "__main__":
    main()