import math

import torch.nn as nn
from torch.nn import functional as F


def is_square_of_two(num):
    """Return True if ``num`` is a positive power of two (1, 2, 4, 8, ...)."""
    if num <= 0:
        return False
    return num & (num - 1) == 0
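

# An illustrative check (not part of the original file): despite its name,
# is_square_of_two is a power-of-two test.
def _demo_is_square_of_two():
    assert is_square_of_two(64)        # 64 = 2 ** 6
    assert not is_square_of_two(48)    # not a power of two
    assert not is_square_of_two(0)     # non-positive values are rejected
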

class CnnEncoder(nn.Module):
    """
    Simple CNN encoder that encodes a 64x64 image into an embedding vector.
    """
    def __init__(self, embedding_size, activation_function='relu'):
        super().__init__()
        self.act_fn = getattr(F, activation_function)
        self.embedding_size = embedding_size
        self.fc = nn.Linear(1024, self.embedding_size)
        self.conv1 = nn.Conv2d(3, 32, 4, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 4, stride=2)
        self.conv4 = nn.Conv2d(128, 256, 4, stride=2)
        # Kept in a plain list for reference; named `convs` so it does not
        # shadow nn.Module.modules().
        self.convs = [self.conv1, self.conv2, self.conv3, self.conv4]

    def forward(self, observation):
        batch_size = observation.shape[0]
        hidden = self.act_fn(self.conv1(observation))
        hidden = self.act_fn(self.conv2(hidden))
        hidden = self.act_fn(self.conv3(hidden))
        hidden = self.act_fn(self.conv4(hidden))
        # A 64x64 input ends up as a 256x2x2 feature map, i.e. 1024 features.
        hidden = self.fc(hidden.view(batch_size, 1024))
        return hidden
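

# A minimal usage sketch (not part of the original module), assuming a batch of
# 64x64 RGB observations and an embedding size of 256; the helper name is illustrative.
def _demo_cnn_encoder():
    import torch
    encoder = CnnEncoder(embedding_size=256)
    observations = torch.randn(4, 3, 64, 64)   # batch of 4 RGB 64x64 images
    embedding = encoder(observations)
    assert embedding.shape == (4, 256)
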

class CnnDecoder(nn.Module):
    """
    Simple CNN decoder that decodes an embedding vector back to a 64x64 image.
    """
    def __init__(self, embedding_size, activation_function='relu'):
        super().__init__()
        self.act_fn = getattr(F, activation_function)
        self.embedding_size = embedding_size
        self.fc = nn.Linear(embedding_size, 128)
        self.conv1 = nn.ConvTranspose2d(128, 128, 5, stride=2)
        self.conv2 = nn.ConvTranspose2d(128, 64, 5, stride=2)
        self.conv3 = nn.ConvTranspose2d(64, 32, 6, stride=2)
        self.conv4 = nn.ConvTranspose2d(32, 3, 6, stride=2)
        # Kept in a plain list for reference; named `convs` so it does not
        # shadow nn.Module.modules().
        self.convs = [self.conv1, self.conv2, self.conv3, self.conv4]

    def forward(self, embedding):
        batch_size = embedding.shape[0]
        hidden = self.fc(embedding)
        # Reshape to a 128x1x1 feature map, then upsample 1 -> 5 -> 13 -> 30 -> 64.
        hidden = hidden.view(batch_size, 128, 1, 1)
        hidden = self.act_fn(self.conv1(hidden))
        hidden = self.act_fn(self.conv2(hidden))
        hidden = self.act_fn(self.conv3(hidden))
        observation = self.conv4(hidden)
        return observation
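

# A minimal round-trip sketch (not part of the original module): encode a batch
# of 64x64 RGB images and decode the embedding back to image space; the helper
# name and the sizes used are illustrative assumptions.
def _demo_cnn_autoencoder():
    import torch
    encoder = CnnEncoder(embedding_size=256)
    decoder = CnnDecoder(embedding_size=256)
    observations = torch.randn(4, 3, 64, 64)
    reconstruction = decoder(encoder(observations))
    assert reconstruction.shape == (4, 3, 64, 64)
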

class FullyConvEncoder(nn.Module):
    """
    Simple fully convolutional encoder, with a 2D (CxHxW) input and a 2D output.
    """
    def __init__(self,
                 input_shape=(3, 64, 64),
                 embedding_shape=(8, 16, 16),
                 activation_function='relu',
                 init_channels=16,
                 ):
        super().__init__()
        assert len(input_shape) == 3, "input_shape must be a tuple of length 3"
        assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
        assert input_shape[1] == input_shape[2] and is_square_of_two(input_shape[1]), "input spatial dims must be equal and a power of two"
        assert embedding_shape[1] == embedding_shape[2], "embedding spatial dims must be equal"
        assert input_shape[1] % embedding_shape[1] == 0, "input_shape must be divisible by embedding_shape"
        assert is_square_of_two(init_channels), "init_channels must be a power of two"
        # One stride-2 layer per factor-of-two spatial reduction, plus an
        # initial stride-1 layer.
        depth = int(math.log2(input_shape[1] // embedding_shape[1])) + 1
        channels_per_layer = [init_channels * (2 ** i) for i in range(depth)]
        self.act_fn = getattr(F, activation_function)
        self.downs = nn.ModuleList([])
        self.downs.append(nn.Conv2d(input_shape[0], channels_per_layer[0], kernel_size=3, stride=1, padding=1))
        for i in range(1, depth):
            self.downs.append(nn.Conv2d(channels_per_layer[i-1], channels_per_layer[i],
                                        kernel_size=3, stride=2, padding=1))
        # Bottleneck layer
        self.downs.append(nn.Conv2d(channels_per_layer[-1], embedding_shape[0], kernel_size=1, stride=1, padding=0))

    def forward(self, observation):
        hidden = observation
        for layer in self.downs:
            hidden = self.act_fn(layer(hidden))
        return hidden
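

# A minimal shape-check sketch (not part of the original module), assuming the
# default 3x64x64 input and 8x16x16 embedding configuration.
def _demo_fully_conv_encoder():
    import torch
    encoder = FullyConvEncoder(input_shape=(3, 64, 64), embedding_shape=(8, 16, 16))
    observations = torch.randn(4, 3, 64, 64)
    embedding = encoder(observations)
    assert embedding.shape == (4, 8, 16, 16)
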

class FullyConvDecoder(nn.Module):
    """
    Simple fully convolutional decoder, with a 2D (CxHxW) input and a 2D output.
    """
    def __init__(self,
                 embedding_shape=(8, 16, 16),
                 output_shape=(3, 64, 64),
                 activation_function='relu',
                 init_channels=16,
                 ):
        super().__init__()
        assert len(embedding_shape) == 3, "embedding_shape must be a tuple of length 3"
        assert len(output_shape) == 3, "output_shape must be a tuple of length 3"
        assert output_shape[1] == output_shape[2] and is_square_of_two(output_shape[1]), "output spatial dims must be equal and a power of two"
        assert embedding_shape[1] == embedding_shape[2], "embedding spatial dims must be equal"
        assert output_shape[1] % embedding_shape[1] == 0, "output_shape must be divisible by embedding_shape"
        assert is_square_of_two(init_channels), "init_channels must be a power of two"
        # Mirror of the encoder: one stride-2 layer per factor-of-two upsampling,
        # plus a 1x1 layer out of the bottleneck and a stride-1 output layer.
        depth = int(math.log2(output_shape[1] // embedding_shape[1])) + 1
        channels_per_layer = [init_channels * (2 ** i) for i in range(depth)]
        self.act_fn = getattr(F, activation_function)
        self.ups = nn.ModuleList([])
        self.ups.append(nn.ConvTranspose2d(embedding_shape[0], channels_per_layer[-1],
                                           kernel_size=1, stride=1, padding=0))
        for i in range(1, depth):
            self.ups.append(nn.ConvTranspose2d(channels_per_layer[-i], channels_per_layer[-i-1],
                                               kernel_size=3, stride=2, padding=1, output_padding=1))
        self.output_layer = nn.ConvTranspose2d(channels_per_layer[0], output_shape[0],
                                               kernel_size=3, stride=1, padding=1)

    def forward(self, embedding):
        hidden = embedding
        for layer in self.ups:
            hidden = self.act_fn(layer(hidden))
        return self.output_layer(hidden)
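

# A minimal round-trip sketch (not part of the original module): push a batch
# through FullyConvEncoder and FullyConvDecoder and confirm the shapes line up;
# the helper name and the batch size are illustrative assumptions.
def _demo_fully_conv_autoencoder():
    import torch
    encoder = FullyConvEncoder(input_shape=(3, 64, 64), embedding_shape=(8, 16, 16))
    decoder = FullyConvDecoder(embedding_shape=(8, 16, 16), output_shape=(3, 64, 64))
    observations = torch.randn(4, 3, 64, 64)
    reconstruction = decoder(encoder(observations))
    assert reconstruction.shape == (4, 3, 64, 64)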