svjack's picture
Upload SPIGA with huggingface_hub
8f09a1b
raw
history blame contribute delete
825 Bytes
from torchvision import transforms
import numpy as np
from PIL import Image
import cv2
from spiga.data.loaders.transforms import TargetCrop, ToOpencv, AddModel3D
def get_transformers(data_config):
    """Assemble the preprocessing pipeline applied to each sample.

    The stages run in this order: ``Opencv2Pil`` -> ``TargetCrop`` ->
    ``ToOpencv`` -> ``NormalizeAndPermute``.

    Args:
        data_config: Configuration object; its ``image_size`` and
            ``target_dist`` attributes are forwarded to ``TargetCrop``.

    Returns:
        A ``transforms.Compose`` instance chaining the four stages.
    """
    # Stages are instantiated in the same order in which they are applied.
    to_pil = Opencv2Pil()
    # NOTE(review): TargetCrop is imported from spiga; presumably it crops
    # around the target at the configured size/distance -- confirm there.
    crop = TargetCrop(data_config.image_size, data_config.target_dist)
    to_opencv = ToOpencv()
    normalize = NormalizeAndPermute()
    return transforms.Compose([to_pil, crop, to_opencv, normalize])
class NormalizeAndPermute:
    """Callable transform that rescales and reorders ``sample['image']``."""

    def __call__(self, sample):
        """Replace ``sample['image']`` with a float, axis-permuted copy.

        The image is converted to a float array, its axes are reordered
        with the permutation ``(2, 0, 1)`` (last axis moved to the front),
        and every value is divided by 255.

        Args:
            sample: Mapping holding an array-like under the key ``'image'``.
                Presumably a height x width x channel 8-bit image, which
                would give a channel-first result in [0, 1] -- TODO confirm
                against the upstream transforms.

        Returns:
            The same ``sample`` object, modified in place.
        """
        pixels = np.array(sample['image'], dtype=float)
        # The permutation requires exactly three axes.
        reordered = pixels.transpose(2, 0, 1)
        sample['image'] = reordered / 255
        return sample
class Opencv2Pil:
    """Callable transform turning ``sample['image']`` into a PIL image."""

    def __call__(self, sample):
        """Convert ``sample['image']`` with ``COLOR_BGR2RGB`` and wrap it in PIL.

        Args:
            sample: Mapping holding an image accepted by ``cv2.cvtColor``
                under the key ``'image'``.

        Returns:
            The same ``sample`` object, with ``'image'`` replaced by the
            result of ``Image.fromarray`` on the colour-converted array.
        """
        bgr = sample['image']
        # Swap the channel order before handing the array to PIL.
        sample['image'] = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        return sample