File size: 1,372 Bytes
8eb4303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import torch.nn as nn
import torch.nn.functional as F


class HighResoEncoder(nn.Module):
    def __init__(self, 
                 in_dim=5, # 3 for rgb and 2 for coordinate
                 out_dim=96, 
                 ):
        super().__init__()
        self.first = nn.Conv2d(in_channels=in_dim, out_channels=64, kernel_size=7, stride=2, padding=3)
        self.activation = nn.LeakyReLU(negative_slope=0.01)

        self.conv_layers = nn.Sequential(*[
            nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(negative_slope=0.01),
        ])

        self.final = nn.Conv2d(in_channels=96, out_channels=out_dim, kernel_size=3, stride=1, padding=1)
    
    def forward(self, x):
        """
        x: [B, C=5, 256, 256]
        return: [B, C=96, 256, 256]
        """
        h = self.first(x)
        h = self.conv_layers(h)
        h = self.final(h)
        return h