MATLAB 的深度学习工具箱Deep Learning Toolbox的学习与使用对熟悉MATLAB的用户较为容易,尤其可以使用图形界面直观地看到训练的过程,但相应的GUI中的数值并不能直接导出,本文介绍如何在训练时保存训练误差相关信息。
例子
以一个带全连接层的神经网络线性回归拟合函数 $ y = x $ 为例:
clc, clear
%% 数据集创建
dataLen = 100;
dataIn = rand([1,dataLen])*100;
dataOut = dataIn + 2*(rand([1,dataLen])-0.5) * 5;
%% 网络结构声明
layers = [
sequenceInputLayer([1])
fullyConnectedLayer(5)
reluLayer
fullyConnectedLayer(1)
regressionLayer
];
%% 训练参数
options = trainingOptions('sgdm', ...
'InitialLearnRate',1e-4, ...
'MaxEpochs',80, ...
'Shuffle','every-epoch', ...
'Verbose',false, ...
'Plots','training-progress','ExecutionEnvironment','cpu');
%% 使用 trainInfo 保存训练信息
[net, trainInfo] = trainNetwork(dataIn,dataOut,layers,options);
plot(trainInfo.TrainingLoss)
以上代码训练结束后会绘制训练的 Loss 下降曲线。
对于使用回归层regressionLayer
的神经网络(对应MSE目标函数),返回的trainInfo
为 包含 TrainingLoss
、TrainingRMES
、BaseLearningRate
的 MATLAB Struct
;长度为Epochs数的double数组,可以使用下标访问,例如trainInfo.TrainingLoss
。