Toccata in Nowhere.

Keras 定义每次训练/学习迭代 Epoch 后的操作

2020.07.03

在对神经网络进行训练时,经常需要获取每次训练后的的神经网络的初步测试结果的对比,以下介绍每次训练结束后调用函数进行操作的方法。

依赖

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import callbacks

基本

首先继承 keras.callbacks.EarlyStopping类,重构on_epoch_end函数:

class earlyStopAndDraw(keras.callbacks.EarlyStopping):
		def on_epoch_end(self, epoch, logs=None):
		
		# operation here

之后生成类对象 early_stop:

early_stop = earlyStopAndDraw(patience=4, monitor='loss')

在model.fit里使用这个类对象:

train_history = model.fit(train_data, label_data, callbacks=[early_stop], batch_size=4, epochs=50, validation_split=0.2, validation_freq=1, shuffle=True)

能做什么?

# operation here处可进行需要的操作。例如保存模型 model.save('model_file.h5');或使用数据测试每次训练结果,对于测试数据 testRawData例如:


plt.figure()
plt.subplot(1,2,1)
plt.imshow(tf.reshape(testRawData,(input_width, input_width)))
plt.subplot(1,2,2)
test_data = tf.expand_dims(tf.reshape(testRawData,(input_width, input_width)), -1)
test_data = tf.expand_dims(test_data, 0)
predict_data = auto_encoder.predict(test_data)
plt.imshow(np.squeeze(predict_data))
plt.show()
plt.savefig(file_name)
plt.close()