这里主要讨论,如何save或restore通过Estimators构建的model。
Tensorflow提供两种model格式:
checkpoints:依赖创建model的code
SavedModel:与创建model的code无关
保存部分训练的models
Estimator会自动将下面内容写入硬盘:
checkpoints:训练过程中创建的model版本
event files:TensorBoard用于创建可视化的信息
保存model
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
当使用Estimator训练模型时
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
当第一次调用train时,checkpoints和其他文件会被添加到model_dir
目录
查看model_dir
目录内容
$ ls -1 models/iris
checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta
目录中显示Estimator在steps 1(开始训练)和200步(训练结束)的checkpoint。
默认的checkpoint目录
如果不指定model_dir
,Estimator会将checkpoint文件写入临时目录
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3)
print(classifier.model_dir)
Checkpoint频率
默认Estimator保存checkpoints会依据下面的规则:
每10分钟保存一次checkpoint
训练开始和训练结束时保存checkpoint
保留目录中最近的5次checkpoint
可以修改默认的schedule,比如每20分钟保存一次checkpoint,保存最近10次的checkpoint
my_checkpointing_config = tf.estimator.RunConfig(
save_checkpoints_secs = 20*60, # Save checkpoints every 20 minutes.
keep_checkpoint_max = 10, # Retain the 10 most recent checkpoints.
)
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris',
config=my_checkpointing_config)
恢复模型
我们知道,在Estimator在第一次调用train时,tensorflow会保存checkpoint。
而接下来对Estimator的调用,无论是train,eval还是predict,都会让Estimator通过model_fn()
构建模型图,Estimator会使用最近保存的checkpoint中的weights来初始化模型。
避免bad restoration
假设DNNClassifier Estimator包含两层隐藏层,每层有10个节点:
classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
在训练之后,此时更改神经网络的结构,将每层隐藏层由10个节点变为20个节点
classifier2 = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[20, 20], # Change the number of neurons in the model.
n_classes=3,
model_dir='models/iris')
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
此时checkpoint与classifier2中的model并不兼容,因此会报错。