HemaAM commited on
Commit
fa1bc1f
·
1 Parent(s): 6868d32

Deleting this file as an updated file is uploaded

Browse files
Files changed (1) hide show
  1. yolo3.py +0 -181
yolo3.py DELETED
@@ -1,181 +0,0 @@
1
- """Implementation of YOLOv3 architecture."""
2
- from typing import Any, List
3
-
4
- import torch
5
- import torch.nn as nn
6
-
7
- """
8
- Information about architecture config:
9
- Tuple is structured by (filters, kernel_size, stride)
10
- Every conv is a same convolution.
11
- List is structured by "B" indicating a residual block followed by the number of repeats
12
- "S" is for scale prediction block and computing the yolo loss
13
- "U" is for upsampling the feature map and concatenating with a previous layer
14
- """
15
- config = [
16
- (32, 3, 1),
17
- (64, 3, 2),
18
- ["B", 1],
19
- (128, 3, 2),
20
- ["B", 2],
21
- (256, 3, 2),
22
- ["B", 8],
23
- (512, 3, 2),
24
- ["B", 8],
25
- (1024, 3, 2),
26
- ["B", 4], # To this point is Darknet-53
27
- (512, 1, 1),
28
- (1024, 3, 1),
29
- "S",
30
- (256, 1, 1),
31
- "U",
32
- (256, 1, 1),
33
- (512, 3, 1),
34
- "S",
35
- (128, 1, 1),
36
- "U",
37
- (128, 1, 1),
38
- (256, 3, 1),
39
- "S",
40
- ]
41
-
42
-
43
- class CNNBlock(nn.Module):
44
- def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
45
- super().__init__()
46
- self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
47
- self.bn = nn.BatchNorm2d(out_channels)
48
- self.leaky = nn.LeakyReLU(0.1)
49
- self.use_bn_act = bn_act
50
-
51
- def forward(self, x):
52
- if self.use_bn_act:
53
- return self.leaky(self.bn(self.conv(x)))
54
- else:
55
- return self.conv(x)
56
-
57
-
58
- class ResidualBlock(nn.Module):
59
- def __init__(self, channels, use_residual=True, num_repeats=1):
60
- super().__init__()
61
- self.layers = nn.ModuleList()
62
- for repeat in range(num_repeats):
63
- self.layers += [
64
- nn.Sequential(
65
- CNNBlock(channels, channels // 2, kernel_size=1),
66
- CNNBlock(channels // 2, channels, kernel_size=3, padding=1),
67
- )
68
- ]
69
-
70
- self.use_residual = use_residual
71
- self.num_repeats = num_repeats
72
-
73
- def forward(self, x):
74
- for layer in self.layers:
75
- if self.use_residual:
76
- x = x + layer(x)
77
- else:
78
- x = layer(x)
79
-
80
- return x
81
-
82
-
83
- class ScalePrediction(nn.Module):
84
- def __init__(self, in_channels, num_classes):
85
- super().__init__()
86
- self.pred = nn.Sequential(
87
- CNNBlock(in_channels, 2 * in_channels, kernel_size=3, padding=1),
88
- CNNBlock(2 * in_channels, (num_classes + 5) * 3, bn_act=False, kernel_size=1),
89
- )
90
- self.num_classes = num_classes
91
-
92
- def forward(self, x):
93
- return (
94
- self.pred(x)
95
- .reshape(x.shape[0], 3, self.num_classes + 5, x.shape[2], x.shape[3])
96
- .permute(0, 1, 3, 4, 2)
97
- )
98
-
99
-
100
- class YOLOv3(nn.Module):
101
- def __init__(self, load_config: List[Any] = config, in_channels=3, num_classes=80):
102
- super().__init__()
103
- self.load_config = load_config
104
- self.num_classes = num_classes
105
- self.in_channels = in_channels
106
- self.layers = self._create_conv_layers()
107
-
108
- def forward(self, x):
109
- outputs = [] # for each scale
110
- route_connections = []
111
- for layer in self.layers:
112
- if isinstance(layer, ScalePrediction):
113
- outputs.append(layer(x))
114
- continue
115
-
116
- x = layer(x)
117
-
118
- if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
119
- route_connections.append(x)
120
-
121
- elif isinstance(layer, nn.Upsample):
122
- x = torch.cat([x, route_connections[-1]], dim=1)
123
- route_connections.pop()
124
-
125
- return outputs
126
-
127
- def _create_conv_layers(self):
128
- layers = nn.ModuleList()
129
- in_channels = self.in_channels
130
-
131
- for module in self.load_config:
132
- if isinstance(module, tuple):
133
- out_channels, kernel_size, stride = module
134
- layers.append(
135
- CNNBlock(
136
- in_channels,
137
- out_channels,
138
- kernel_size=kernel_size,
139
- stride=stride,
140
- padding=1 if kernel_size == 3 else 0,
141
- )
142
- )
143
- in_channels = out_channels
144
-
145
- elif isinstance(module, list):
146
- num_repeats = module[1]
147
- layers.append(
148
- ResidualBlock(
149
- in_channels,
150
- num_repeats=num_repeats,
151
- )
152
- )
153
-
154
- elif isinstance(module, str):
155
- if module == "S":
156
- layers += [
157
- ResidualBlock(in_channels, use_residual=False, num_repeats=1),
158
- CNNBlock(in_channels, in_channels // 2, kernel_size=1),
159
- ScalePrediction(in_channels // 2, num_classes=self.num_classes),
160
- ]
161
- in_channels = in_channels // 2
162
-
163
- elif module == "U":
164
- layers.append(
165
- nn.Upsample(scale_factor=2),
166
- )
167
- in_channels = in_channels * 3
168
-
169
- return layers
170
-
171
-
172
- if __name__ == "__main__":
173
- num_classes = 20
174
- IMAGE_SIZE = 416
175
- model = YOLOv3(load_config=config, num_classes=num_classes)
176
- x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
177
- out = model(x)
178
- assert out[0].shape == (2, 3, IMAGE_SIZE // 32, IMAGE_SIZE // 32, num_classes + 5)
179
- assert out[1].shape == (2, 3, IMAGE_SIZE // 16, IMAGE_SIZE // 16, num_classes + 5)
180
- assert out[2].shape == (2, 3, IMAGE_SIZE // 8, IMAGE_SIZE // 8, num_classes + 5)
181
- print("Success!")