File size: 3,363 Bytes
41e894a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
inputSize = [224 224 3];

% Resize the training data to the network input size before estimating
% anchor boxes, so the anchors are expressed in the same resized frame the
% detector will be trained on.
preprocessedTrainingData = transform(trainingData, @(data)preprocessData(data,inputSize));
numAnchors = 3;
anchorBoxes = estimateAnchorBoxes(preprocessedTrainingData,numAnchors)

% Pretrained backbone and the layer from which features are extracted.
featureExtractionNetwork = resnet50;
featureLayer = "activation_40_relu";

% NOTE(review): assumes vehicleDataset has exactly one non-class column
% (e.g. image file names) plus one box column per class — confirm.
numClasses = width(vehicleDataset)-1;

lgraph = fasterRCNNLayers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer);

% Attach the random-flip augmentation (augmentData, defined at the end of this file).
augmentedTrainingData = transform(trainingData,@augmentData);

% Preview augmentation: read the first sample four times, resetting the
% datastore after each read, so every tile shows the same image with an
% independently drawn random augmentation.
augmentedData = cell(4,1);
for k = 1:4
    data = read(augmentedTrainingData);
    augmentedData{k} = insertShape(data{1},"rectangle",data{2});
    reset(augmentedTrainingData);
end
figure
montage(augmentedData,BorderSize=10)

% Final pipelines: augment (training only), then resize to the network input size.
trainingData = transform(augmentedTrainingData,@(data)preprocessData(data,inputSize));
validationData = transform(validationData,@(data)preprocessData(data,inputSize));

% Sanity-check one preprocessed sample by drawing its boxes (upscaled 2x for viewing).
data = read(trainingData);

I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,"rectangle",bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)

%% Train Faster R-CNN

options = trainingOptions("sgdm",...
    MaxEpochs=10,...
    MiniBatchSize=2,...
    InitialLearnRate=1e-3,...
    CheckpointPath=tempdir,...
    ValidationData=validationData);

% NOTE(review): doTraining is not assigned in this section — confirm it is
% set (true/false) before this point.
if doTraining
    % Train the Faster R-CNN detector.
    % * Adjust NegativeOverlapRange and PositiveOverlapRange to ensure
    %   that training samples tightly overlap with ground truth.
    [detector, info] = trainFasterRCNNObjectDetector(trainingData,lgraph,options, ...
        NegativeOverlapRange=[0 0.3], ...
        PositiveOverlapRange=[0.6 1]);
else
    % Load pretrained detector for the example.
    pretrained = load("fasterRCNNResNet50EndToEndVehicleExample.mat");
    detector = pretrained.detector;
end

% Quick qualitative check on a single test image, resized to the network input size.
I = imread(testDataTbl.imageFilename{3});
I = imresize(I,inputSize(1:2));
[bboxes,scores] = detect(detector,I);

I = insertObjectAnnotation(I,"rectangle",bboxes,scores);
figure
imshow(I)

% Run the detector over the whole test set, resized like the training data.
testData = transform(testData,@(data)preprocessData(data,inputSize));

% NOTE(review): detections scoring below Threshold are discarded, which
% truncates the precision/recall curve; lower it if full-range AP is wanted.
detectionResults = detect(detector,testData,...
    Threshold=0.2,...
    MiniBatchSize=4);

% Precision/recall for the first (presumably only) class.
classID = 1;
metrics = evaluateObjectDetection(detectionResults,testData);
precision = metrics.ClassMetrics.Precision{classID};
recall = metrics.ClassMetrics.Recall{classID};

figure
plot(recall,precision)
xlabel("Recall")
ylabel("Precision")
grid on
title(sprintf("Average Precision = %.2f", metrics.ClassMetrics.mAP(classID)))

function data = augmentData(data)
% AUGMENTDATA Randomly flip an image and its bounding boxes horizontally.
%   data - cell array sample; data{1} is the image and data{2} its boxes.
%   Returns the same cell array with data{1} warped by a random horizontal
%   reflection and data{2} warped by the same transform and output view.
tform = randomAffine2d("XReflection",true);
sz = size(data{1});
rout = affineOutputView(sz,tform);
data{1} = imwarp(data{1},tform,"OutputView",rout);

% Sanitize boxes, if needed. This helper function is attached as a
% supporting file. Open the example in MATLAB to open this function.
data{2} = helperSanitizeBoxes(data{2});

% Warp boxes with the same transform and output view used for the image.
data{2} = bboxwarp(data{2},tform,rout);
end

function data = preprocessData(data,targetSize)
% PREPROCESSDATA Resize one image/box sample to a fixed spatial size.
%   data       - cell array; data{1} is the image, data{2} its boxes.
%   targetSize - desired size; only the first two elements are used.
%   Returns data with the image resized to targetSize(1:2) and the boxes
%   scaled by the same per-dimension factors.
outSize = targetSize(1:2);
imgSize = size(data{1},[1 2]);   % measured before the image is resized

% Resample the image.
data{1} = imresize(data{1},outSize);

% Sanitize boxes, if needed. helperSanitizeBoxes is a supporting file
% shipped with the example; it is not defined in this file.
boxes = helperSanitizeBoxes(data{2});

% Scale boxes by the same factors applied to the image.
data{2} = bboxresize(boxes,outSize./imgSize);
end