File size: 1,812 Bytes
a6a4c31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn as nn
import torch.optim as optim
import coremltools as ct

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import torchvision.transforms.functional as TF

# Prefer Apple's Metal (MPS) backend, but fall back to CPU so the script still
# runs on machines without MPS instead of failing at the first `.to(device)`.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
class UPSC(nn.Module):
    """Sub-pixel convolutional upscaler (conv layers followed by PixelShuffle).

    Args:
        scale_factor: Integer upscale factor applied to height and width.
            Defaults to 3, the value the original hard-coded.
        num_channels: Number of image channels in and out (3 for RGB).

    The layer order (and therefore the ``state_dict`` keys ``model.0``,
    ``model.2``, ``model.4``) is identical for every parameter choice, so
    checkpoints saved from the default configuration still load.
    """

    def __init__(self, scale_factor: int = 3, num_channels: int = 3):
        super().__init__()
        if scale_factor < 1:
            raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")
        self.scale_factor = scale_factor
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=num_channels, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            # Emit scale_factor**2 sub-pixel planes per output channel ...
            nn.Conv2d(
                in_channels=32,
                out_channels=num_channels * scale_factor ** 2,
                kernel_size=3,
                padding=1,
            ),
            # ... which PixelShuffle rearranges into a scale_factor-times
            # larger spatial grid: (N, C*r^2, H, W) -> (N, C, H*r, W*r).
            nn.PixelShuffle(scale_factor),
        )

    def forward(self, x):
        """Upscale a batch ``x`` of shape (N, num_channels, H, W).

        Returns a tensor of shape (N, num_channels, H * scale_factor,
        W * scale_factor). Padding keeps H and W unchanged through the convs.
        """
        return self.model(x)

# Load the trained weights. `map_location` remaps tensors saved on another
# device (e.g. a CUDA training box) onto the device selected here.
model = UPSC().to(device)
model.load_state_dict(torch.load("upscaling.pth", map_location=device, weights_only=True))
model.eval()

img = Image.open("test.png").convert("RGB")

# Resize to the fixed input size; the exported Core ML model is built with
# this exact shape.
# NOTE(review): a fixed (256, 256) resize ignores the source aspect ratio;
# presumably this matches the training crops -- confirm.
transform = transforms.Compose([
    transforms.Resize((256, 256)),  # match training input size
    transforms.ToTensor(),  # float values in [0, 1], layout (C, H, W)
])

lr_tensor = transform(img).unsqueeze(0).to(device)

with torch.no_grad():
    sr_tensor = model(lr_tensor)

# Remove batch dimension and convert to PIL
sr_image = TF.to_pil_image(sr_tensor.squeeze(0).clamp(0, 1).cpu())
sr_image.save("upscaled_output_5.jpg")

# Trace on CPU for export. MPS tensors cannot be turned into NumPy arrays
# directly, and the Core ML converter needs host-side weights. `.cpu()` moves
# the module in place.
model = model.cpu()
example_input = lr_tensor.cpu()
with torch.no_grad():
    traced_model = torch.jit.trace(model, example_input)

mlmodel = ct.convert(
    traced_model,
    # Core ML passes image inputs as raw 0-255 pixel values, but the network
    # was fed ToTensor() output in [0, 1]; scale=1/255 restores that range.
    inputs=[ct.ImageType(name="input", shape=example_input.shape, scale=1 / 255.0)],
    # Pin the format explicitly so the save extension below is always valid.
    convert_to="mlprogram",
    compute_units=ct.ComputeUnit.ALL,  # Use ANE if available
)

# ML Program models must be saved as an .mlpackage; coremltools raises if an
# ML Program is saved with a .mlmodel extension.
mlmodel.save("upscaling.mlpackage")