Toccata in Nowhere.

神经网络loss不下降的可能原因 —— 输入/输出中含有 NaN 值

2021.04.07

在进行一次对新数据的迁移学习测试时,意外的出现网络无法对新数据特征进行学习,输出预测偏移越来越远的情况。通过替换数据集、对CAE网络互换输入输出的尝试后,甚至出现了Loss 变为NaN的情况,因此锁定问题出自数据集中存在NaN值。

原因

因输入中含有一部分NaN值,而NaN与任何值的运算的结果都是NaN,因此 NaN 在网络学习迭代的过程会因梯度计算会逐渐传播至网络的权重中,而NaN权重的神经元连接无法进一步进行学些更新,导致网络训练陷入僵局。如果输出中含有NaN,则Loss的计算结果也会变为NaN,同样无法进行进一步训练。

解决

可使用 numpy.isnan 对输入/输出矩阵进行判断,并将 NaN 值处归零:

x_train[np.isnan(x_train)] = 0

参考

PyTorch: NaNs in input data breaking gradients