import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset

class UPSC(nn.Module):
    """Simple 3x super-resolution network: two feature-extraction convolutions
    followed by a sub-pixel (PixelShuffle) upsampling layer."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            # This convolution outputs scale_factor^2 * num_channels feature maps:
            # 3 (RGB) * 3^2 (scale 3) = 27 channels.
            nn.Conv2d(in_channels=32, out_channels=3 * 3 * 3, kernel_size=3, padding=1),
            # PixelShuffle rearranges those channels into spatial dimensions,
            # tripling height and width.
            nn.PixelShuffle(3)
        )

    def forward(self, x):
        return self.model(x)
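
# Shape sanity check (illustrative, not part of the training script):
# PixelShuffle(3) rearranges (N, 27, H, W) into (N, 3, 3H, 3W), so a
# (1, 3, 64, 64) input should come out as (1, 3, 192, 192):
#
#   with torch.no_grad():
#       out = UPSC()(torch.randn(1, 3, 64, 64))
#       assert out.shape == (1, 3, 192, 192)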
    


class PairedSuperResolutionDataset(Dataset):
    """Loads paired low-resolution / high-resolution images, matched by sorted
    filename order across the two directories."""

    # Default sizes keep the 3x ratio expected by the UPSC model above.
    def __init__(self, lr_dir, hr_dir, lr_size=(64, 64), hr_size=(192, 192)):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.lr_files = sorted(os.listdir(lr_dir))
        self.hr_files = sorted(os.listdir(hr_dir))
        assert len(self.lr_files) == len(self.hr_files), \
            "LR and HR directories must contain the same number of images"

        self.transform_lr = transforms.Compose([
            transforms.Resize(lr_size),
            transforms.ToTensor()
        ])

        self.transform_hr = transforms.Compose([
            transforms.Resize(hr_size),
            transforms.ToTensor()
        ])

    def __len__(self):
        return len(self.lr_files)

    def __getitem__(self, idx):
        lr_path = os.path.join(self.lr_dir, self.lr_files[idx])
        hr_path = os.path.join(self.hr_dir, self.hr_files[idx])

        lr_img = Image.open(lr_path).convert("RGB")
        hr_img = Image.open(hr_path).convert("RGB")

        lr_tensor = self.transform_lr(lr_img)
        hr_tensor = self.transform_hr(hr_img)

        return lr_tensor, hr_tensor


lr_dir = '/Users/aaronvattay/Documents/DF2K_train_LR_bicubic/X3'
hr_dir = '/Users/aaronvattay/Documents/DF2K_train_HR'
batch_size = 16
num_epochs = 10
learning_rate = 1e-4


# Create dataset and dataloader
dataset = PairedSuperResolutionDataset(
    lr_dir=lr_dir,
    hr_dir=hr_dir,
    lr_size=(256,256),
    hr_size=(768,768)
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
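# Optional (illustrative): for larger datasets, parallel data loading can speed
# up training, e.g.
#   DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)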

# Device configuration: prefer Apple's MPS backend when available, else fall back to CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Initialize model, loss, and optimizer
model = UPSC().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Load the model state if available
if os.path.exists("upscaling.pth"):
    model.load_state_dict(torch.load("upscaling.pth", map_location=device, weights_only=True))
# Set the model to training mode
model.train()

if __name__ == "__main__":
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for lr_imgs, hr_imgs in dataloader:
            # Move images to device
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)

            # Forward pass: the model produces the upscaled images
            outputs = model(lr_imgs)
            loss = criterion(outputs, hr_imgs)

            # Backpropagation and optimization
            optimizer.zero_grad()    # Clear gradients accumulated from the previous step
            loss.backward()          # Backpropagate the loss
            optimizer.step()         # Update weights

            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.6f}")

    # Optionally, save the trained model for later inference
    torch.save(model.state_dict(), "upscaling.pth")
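
# Illustrative inference sketch (not part of the original script): load the saved
# weights and upscale a single image. The input path and output filename below are
# assumptions for the example.
#
#   model = UPSC().to(device)
#   model.load_state_dict(torch.load("upscaling.pth", map_location=device, weights_only=True))
#   model.eval()
#   img = transforms.ToTensor()(Image.open("input.png").convert("RGB")).unsqueeze(0).to(device)
#   with torch.no_grad():
#       sr = model(img).clamp(0, 1).squeeze(0).cpu()
#   transforms.ToPILImage()(sr).save("output_x3.png")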