Toccata in Nowhere.

TensorFlow 保存 / 读取训练过的神经网络文件

2020.07.07

在训练完成神经网络后或训练过程中,往往需要保存模型,以备应用。本文介绍保存神经网络 h5 文件与读取神经网络文件的方法。

依赖

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import models

写入文件

对于 keras.Model()对象,调用

model.save('fileName.h5') 

即可保存模型权值与权值。即保存整个模型结构。可以在下文里直接调用。

读取文件

对于已经保存的 nerual_network.h5 模型文件,可以如下调用:

model_file = 'nerual_network.h5'
model = keras.models.load_model(model_file)

即可恢复成为一个keras.Model()对象。

之后可以通过传入 tf.tensor 对神经网络模型进行测试,得到 Predicted 的结果:

test_data = tf.convert_to_tensor(<np.array>)        # convert to tensor
test_data = tf.expand_dims(test_data, -1)           # expand dimension
test_data = tf.image.resize(test_data, (512, 512))  # resample data
test_data = tf.expand_dims(test_data, 0)            # expand dimension 
predicted_data = model.predict(test_data)           # predict by NN
predicted_data = np.squeeze(predicted_data)         # squeeze dimentsion