|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import coremltools as ct |
|
|
|
from torch.utils.data import Dataset, DataLoader |
|
from torchvision import transforms |
|
import os |
|
from PIL import Image |
|
import torchvision.transforms.functional as TF |
|
|
|
# Prefer Apple's Metal (MPS) backend when present, but fall back to the CPU so
# the script still runs on machines without it (a hard-coded "mps" device only
# fails later, at the first .to(device) call).
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
class UPSC(nn.Module):
    """Small convolutional super-resolution network with a fixed 3x upscale.

    Two feature-extraction convolutions (with ReLU) are followed by a
    convolution producing ``3 * 3**2`` channels, which ``nn.PixelShuffle(3)``
    rearranges into a 3-channel output three times larger in height and width.
    """

    def __init__(self):
        super().__init__()
        factor = 3    # spatial upscale factor handed to PixelShuffle
        channels = 3  # input/output channel count (RGB)

        # The order of these layers fixes the state_dict keys
        # (model.0, model.2, model.4), which saved checkpoints depend on.
        layers = [
            nn.Conv2d(channels, 64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, channels * factor ** 2, kernel_size=3, padding=1),
            nn.PixelShuffle(factor),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map an (N, 3, H, W) tensor to an (N, 3, 3H, 3W) tensor."""
        return self.model(x)
|
|
|
model = UPSC().to(device)
# map_location remaps the stored tensors onto the device chosen for this run,
# so a checkpoint saved on a different device (e.g. MPS vs. CPU) still loads.
model.load_state_dict(
    torch.load("upscaling.pth", map_location=device, weights_only=True)
)
# Switch to inference mode before running or tracing the model.
model.eval()
|
|
|
# Load the low-resolution input. The `with` block closes the file handle that
# Image.open() keeps open; convert() returns a new, fully loaded image, so
# `img` remains usable after the file is closed.
with Image.open("test.png") as src:
    img = src.convert("RGB")

# Resize to a fixed 256x256 (aspect ratio is not preserved) and convert to a
# float tensor scaled to [0, 1].
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Add a batch dimension -> (1, 3, 256, 256) and move it to the model's device.
lr_tensor = transform(img).unsqueeze(0).to(device)
|
|
|
# Run the super-resolution model once and save the result as a preview image.
with torch.no_grad():
    sr_tensor = model(lr_tensor)

# The last layer is an unbounded convolution, so clamp to [0, 1] (the range
# to_pil_image assumes for float tensors) before converting.
sr_image = TF.to_pil_image(sr_tensor.squeeze(0).clamp(0, 1).cpu())
sr_image.save("upscaled_output_5.jpg")

# --- Core ML export ----------------------------------------------------------
# Trace a CPU copy of the model with a CPU example input: coremltools converts
# the traced weights to NumPy, which MPS tensors do not support directly.
model = model.to("cpu").eval()
example_input = lr_tensor.to("cpu")
with torch.no_grad():
    traced_model = torch.jit.trace(model, example_input)

mlmodel = ct.convert(
    traced_model,
    inputs=[
        ct.ImageType(
            name="input",
            shape=tuple(example_input.shape),
            # Core ML feeds raw 0-255 pixel values into an ImageType input,
            # while this network was run on ToTensor() output in [0, 1];
            # rescale so the converted model sees the same value range.
            scale=1 / 255.0,
        )
    ],
    # Pin the NeuralNetwork backend. Since coremltools 7 the default is
    # "mlprogram", which can only be saved as a .mlpackage, so saving it to
    # the .mlmodel path below would raise.
    convert_to="neuralnetwork",
    compute_units=ct.ComputeUnit.ALL,
)

mlmodel.save("upscaling.mlmodel")
|
|