% OD-RCNN / Faster R-CNN DN
% Author: antitheft159
% Commit: "Create Faster R-CNN DN" (41e894a, verified)
% Network input size [height width channels]; every image and its boxes are
% resized to this before anchor estimation and training.
inputSize = [224 224 3];

% Estimate anchor boxes on data already resized to inputSize, so the anchor
% sizes match what the network will actually see.
% NOTE(review): trainingData is expected to be created earlier in the script.
preprocessedTrainingData = transform(trainingData, @(data)preprocessData(data,inputSize));
numAnchors = 3;
anchorBoxes = estimateAnchorBoxes(preprocessedTrainingData,numAnchors)

% ResNet-50 backbone; features are taken from the "activation_40_relu" layer.
featureExtractionNetwork = resnet50;
featureLayer = "activation_40_relu";

% NOTE(review): assumes vehicleDataset has one image-filename column plus one
% column per object class — confirm against where vehicleDataset is built.
numClasses = width(vehicleDataset)-1;

lgraph = fasterRCNNLayers(inputSize,numClasses,anchorBoxes,featureExtractionNetwork,featureLayer);
% Preview the random augmentation. The datastore is reset after every read,
% so each iteration re-reads the same sample and the montage shows four
% different random augmentations of it, with its boxes drawn on.
augmentedTrainingData = transform(trainingData,@augmentData);
augmentedData = cell(4,1);
for k = 1:4
    data = read(augmentedTrainingData);
    augmentedData{k} = insertShape(data{1},"rectangle",data{2});
    reset(augmentedTrainingData);
end
figure
montage(augmentedData,BorderSize=10)

% Training data: augmentation followed by resizing to inputSize.
% Validation data: resizing only (no augmentation).
trainingData = transform(augmentedTrainingData,@(data)preprocessData(data,inputSize));
validationData = transform(validationData,@(data)preprocessData(data,inputSize));

% Sanity check: read one preprocessed training sample and show it with its
% boxes, upscaled 2x for visibility.
data = read(trainingData);
I = data{1};
bbox = data{2};
annotatedImage = insertShape(I,"rectangle",bbox);
annotatedImage = imresize(annotatedImage,2);
figure
imshow(annotatedImage)
% Train Faster R-CNN
% SGDM with a small mini-batch; checkpoints are written to tempdir so a long
% training run can be resumed, and validationData is evaluated during training.
options = trainingOptions("sgdm",...
    MaxEpochs=10,...
    MiniBatchSize=2,...
    InitialLearnRate=1e-3,...
    CheckpointPath=tempdir,...
    ValidationData=validationData);

% doTraining is expected to be set earlier in the script. If it was never
% defined, fall back to loading the pretrained detector instead of erroring.
if ~exist("doTraining","var")
    doTraining = false;
end

if doTraining
    % Train the Faster R-CNN detector.
    % * Adjust NegativeOverlapRange and PositiveOverlapRange to ensure
    %   that training samples tightly overlap with ground truth.
    [detector, info] = trainFasterRCNNObjectDetector(trainingData,lgraph,options, ...
        NegativeOverlapRange=[0 0.3], ...
        PositiveOverlapRange=[0.6 1]);
else
    % Load pretrained detector for the example.
    pretrained = load("fasterRCNNResNet50EndToEndVehicleExample.mat");
    detector = pretrained.detector;
end
% Quick qualitative check: run the detector on one test image (resized to
% the network input size) and overlay the boxes with their scores.
% NOTE(review): testDataTbl is expected to be defined earlier in the script.
I = imread(testDataTbl.imageFilename{3});
I = imresize(I,inputSize(1:2));
[bboxes,scores] = detect(detector,I);
I = insertObjectAnnotation(I,"rectangle",bboxes,scores);
figure
imshow(I)

% Quantitative evaluation on the whole test set. Images and boxes are resized
% the same way as in training; a low score threshold (0.2) keeps more
% detections so the precision/recall curve covers a wider recall range.
testData = transform(testData,@(data)preprocessData(data,inputSize));
detectionResults = detect(detector,testData,...
    Threshold=0.2,...
    MiniBatchSize=4);

% Plot the precision/recall curve for the first class.
classID = 1;
metrics = evaluateObjectDetection(detectionResults,testData);
precision = metrics.ClassMetrics.Precision{classID};
recall = metrics.ClassMetrics.Recall{classID};
figure
plot(recall,precision)
xlabel("Recall")
ylabel("Precision")
grid on
title(sprintf("Average Precision = %.2f", metrics.ClassMetrics.mAP(classID)))
function data = augmentData(data)
% AUGMENTDATA Randomly flip an image and its bounding boxes horizontally.
%   data - cell array read from the training datastore: data{1} is the
%          image, data{2} holds its bounding boxes. Other cells (e.g.
%          labels, if present) are passed through unchanged.
%   Returns the same cell array with data{1} and data{2} replaced by their
%   transformed versions.

% Random affine transform whose only randomness is a horizontal reflection.
tform = randomAffine2d("XReflection",true);

% Output spatial reference derived from the input image size, shared by the
% image warp and the box warp so both stay in the same coordinate frame.
sz = size(data{1});
rout = affineOutputView(sz,tform);
data{1} = imwarp(data{1},tform,"OutputView",rout);

% Sanitize boxes, if needed. This helper function is attached as a
% supporting file. Open the example in MATLAB to open this function.
data{2} = helperSanitizeBoxes(data{2});

% Warp boxes with the same transform and output view used for the image.
data{2} = bboxwarp(data{2},tform,rout);
end
function data = preprocessData(data,targetSize)
% PREPROCESSDATA Resize an image and its bounding boxes to targetSize.
%   data       - cell array read from a datastore: data{1} is the image,
%                data{2} its bounding boxes. Other cells are left untouched.
%   targetSize - desired size; only the first two elements (rows, columns)
%                are used, so a [H W C] network input size can be passed.
%   Returns the cell array with the resized image and rescaled boxes.
outHW = targetSize(1:2);
img = data{1};

% Per-axis scale factors [rows cols], taken from the image size before it
% is overwritten by the resized version.
boxScale = outHW ./ size(img,[1 2]);
data{1} = imresize(img,outHW);

% Clean up the boxes with helperSanitizeBoxes (a supporting file from the
% MATLAB example, not defined here), then rescale them to match the image.
cleanBoxes = helperSanitizeBoxes(data{2});
data{2} = bboxresize(cleanBoxes,boxScale);
end