在训练完成神经网络后或训练过程中,往往需要保存模型,以备应用。本文介绍保存神经网络 h5
文件与读取神经网络文件的方法。
依赖
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import models
写入文件
对于 keras.Model()
对象,调用
model.save('fileName.h5')
即可保存模型权值与权值。即保存整个模型结构。可以在下文里直接调用。
读取文件
对于已经保存的 nerual_network.h5
模型文件,可以如下调用:
model_file = 'nerual_network.h5'
model = keras.models.load_model(model_file)
即可恢复成为一个keras.Model()
对象。
之后可以通过传入 tf.tensor
对神经网络模型进行测试,得到 Predicted 的结果:
test_data = tf.convert_to_tensor(<np.array>) # convert to tensor
test_data = tf.expand_dims(test_data, -1) # expand dimension
test_data = tf.image.resize(test_data, (512, 512)) # resample data
test_data = tf.expand_dims(test_data, 0) # expand dimension
predicted_data = model.predict(test_data) # predict by NN
predicted_data = np.squeeze(predicted_data) # squeeze dimentsion