Gosula commited on
Commit
81eba8e
·
1 Parent(s): 3a88fe8

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +141 -0
model.py CHANGED
@@ -163,6 +163,147 @@ class YOLOv3(LightningModule):
163
 
164
  return layers
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  if __name__ == "__main__":
168
  num_classes = 20
 
163
 
164
  return layers
165
 
166
+ class YoloVersion3(LightningModule):
167
+ def __init__(self):
168
+ super(YoloVersion3, self).__init__( )
169
+ self.save_hyperparameters()
170
+ # Set our init args as class attributes
171
+ self.learning_rate=config.LEARNING_RATE
172
+ #self.config=config
173
+
174
+ self.num_classes=config.NUM_CLASSES
175
+ self.train_csv=config.DATASET + "/train.csv"
176
+ self.test_csv=config.DATASET + "/test.csv"
177
+
178
+ self.loss_fn= YoloLoss()
179
+ self.scaler = amp.GradScaler()
180
+ #self.train_transform_function= config.train_transforms
181
+ #self.in_channels = 3
182
+ self.model= YOLOv3(num_classes=config.NUM_CLASSES).to(config.DEVICE)
183
+ self.scaled_anchors = (
184
+ torch.tensor(config.ANCHORS) * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)).to(config.DEVICE)
185
+ #self.register_buffer("scaled_anchors", self.scaled_anchors)
186
+ self.training_step_outputs = []
187
+
188
+ def forward(self, x):
189
+ return self.model(x)
190
+
191
+ def training_step(self, batch, batch_idx):
192
+ x, y = batch
193
+ y0, y1, y2 = (
194
+ y[0],
195
+ y[1],
196
+ y[2],
197
+ )
198
+ out = self(x)
199
+ loss = (
200
+ self.loss_fn(out[0], y0, self.scaled_anchors[0])
201
+ + self.loss_fn(out[1], y1, self.scaled_anchors[1])
202
+ + self.loss_fn(out[2], y2, self.scaled_anchors[2])
203
+ )
204
+ self.log("train_loss", loss, on_epoch=True, prog_bar=True, logger=True) # Logging the training loss for visualization
205
+ self.training_step_outputs.append(loss)
206
+ return loss
207
+
208
+ def on_train_epoch_end(self):
209
+
210
+ print(f"\nCurrently epoch {self.current_epoch}")
211
+ train_epoch_average = torch.stack(self.training_step_outputs).mean()
212
+ self.training_step_outputs.clear()
213
+ print(f"Train loss {train_epoch_average}")
214
+ print("On Train Eval loader:")
215
+ print("On Train loader:")
216
+ class_accuracy, no_obj_accuracy, obj_accuracy = check_class_accuracy(self.model, self.train_loader, threshold=config.CONF_THRESHOLD)
217
+ self.log("class_accuracy", class_accuracy, on_epoch=True, prog_bar=True, logger=True)
218
+ self.log("no_obj_accuracy", no_obj_accuracy, on_epoch=True, prog_bar=True, logger=True)
219
+ self.log("obj_accuracy", obj_accuracy, on_epoch=True, prog_bar=True, logger=True)
220
+
221
+ if (self.current_epoch>0) and ((self.current_epoch+1) % 6 == 0): # for every 10 epochs we are plotting
222
+ plot_couple_examples(self.model, self.test_loader, 0.6, 0.5, self.scaled_anchors)
223
+
224
+ if (self.current_epoch>0) and (self.current_epoch+1 == self.trainer.max_epochs ): #map calculation across last epoch
225
+ check_class_accuracy(self.model, self.test_loader, threshold=config.CONF_THRESHOLD)
226
+ pred_boxes, true_boxes = get_evaluation_bboxes(
227
+ self.test_loader,
228
+ self.model,
229
+ iou_threshold=config.NMS_IOU_THRESH,
230
+ anchors=config.ANCHORS,
231
+ threshold=config.CONF_THRESHOLD,
232
+ )
233
+ mapval = mean_average_precision(
234
+ pred_boxes,
235
+ true_boxes,
236
+ iou_threshold=config.MAP_IOU_THRESH,
237
+ box_format="midpoint",
238
+ num_classes=config.NUM_CLASSES,
239
+ )
240
+ print(f"MAP: {mapval.item()}")
241
+
242
+ self.log("MAP", mapval.item(), on_epoch=True, prog_bar=True, logger=True)
243
+
244
+
245
+
246
+ def configure_optimizers(self):
247
+ optimizer = optim.Adam(
248
+ self.parameters(),
249
+ lr=config.LEARNING_RATE,
250
+ weight_decay=config.WEIGHT_DECAY,
251
+ )
252
+
253
+ self.trainer.fit_loop.setup_data()
254
+ dataloader = self.trainer.train_dataloader
255
+
256
+ EPOCHS = config.NUM_EPOCHS # 40 % of number of epochs
257
+ lr_scheduler = OneCycleLR(
258
+ optimizer,
259
+ max_lr=1E-3,
260
+ steps_per_epoch=len(dataloader),
261
+ epochs=EPOCHS,
262
+ pct_start=5/EPOCHS,
263
+ div_factor=100,
264
+ three_phase=False,
265
+ final_div_factor=100,
266
+ anneal_strategy='linear'
267
+ )
268
+
269
+ scheduler = {"scheduler": lr_scheduler, "interval" : "step"}
270
+
271
+ return [optimizer]
272
+
273
+ def setup(self, stage=None):
274
+ self.train_loader, self.test_loader, self.train_eval_loader = get_loaders(
275
+ train_csv_path=self.train_csv,
276
+ test_csv_path=self.test_csv,
277
+ )
278
+
279
+ def train_dataloader(self):
280
+ return self.train_loader
281
+
282
+ def val_dataloader(self):
283
+ return self.train_eval_loader
284
+
285
+ def test_dataloader(self):
286
+ return self.test_loader
287
+ # if __name__ == "__main__":
288
+
289
+ # model = YoloVersion3()
290
+
291
+ # checkpoint = ModelCheckpoint(filename='last_epoch', save_last=True)
292
+ # lr_rate_monitor = LearningRateMonitor(logging_interval="epoch")
293
+ # trainer = pl.Trainer(
294
+ # max_epochs=config.NUM_EPOCHS,
295
+ # deterministic=True,
296
+ # logger=True,
297
+ # default_root_dir="/content/drive/MyDrive/sunandini/Checkpoint/",
298
+ # callbacks=[lr_rate_monitor],
299
+ # enable_model_summary=False,
300
+ # log_every_n_steps=1,
301
+ # precision="16-mixed"
302
+ # )
303
+ # print("---- Training Started ---- Sunandini ----")
304
+ # trainer.fit(model)
305
+ # torch.save(model.state_dict(), 'YOLOv3.pth')
306
+
307
 
308
  if __name__ == "__main__":
309
  num_classes = 20