博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow学习系列(三):保存/恢复和混合多个模型
阅读量:2439 次
发布时间:2019-05-10

本文共 4382 字,大约阅读时间需要 14 分钟。

这篇教程是翻译写的TensorFlow教程,作者已经授权翻译,这是。


目录



在学习这篇博客之前,我希望你已经掌握了Tensorflow基本的操作。如果没有,你可以阅读这篇。

为什么要学习模型的保存和恢复呢?因为这对于避免数据的混乱无序是至关重要的,特别是在你代码中的不同图。

如何保存和加载模型

saver类

在不同的会话中,当需要将数据在硬盘上面进行保存时,那么我们就可以使用Saver这个类。这个Saver构造类允许你去控制3个目标:

  • 目标(The target):这个参数设置目标。在分布式架构的情况下,我们可以指定要计算哪个TF服务器或者“目标”。
  • 图(The graph):这个参数设置保存的图。保存你希望会话处理的图。对于初学者来说,这里有一件棘手的事情就是在Tensorflow中总是有一个默认的图,并且你所有的操作都是在这个图中首先进行。所有,你总是在“默认图范围”内。
  • 配置(The config):这个参数设置配置。你可以使用 ConfigProto 参数来进行配置Tensorflow。,查看更多信息。

Saver类可以处理你的图中元数据和变量数据的保存和恢复。而我们唯一需要做的是,告诉Saver类我们需要保存哪个图和哪些变量。

在默认情况下,Saver类能处理默认图中包含的所有变量。但是,你也可以去创建很多的Saver类,去保存你想要的任何子图。

import tensorflow as tf# First, you design your mathematical operations# We are the default graph scope# Let's design a variablev1 = tf.Variable(1. , name="v1")v2 = tf.Variable(2. , name="v2")# Let's design an operationa = tf.add(v1, v2)# Let's create a Saver object# By default, the Saver handles every Variables related to the default graphall_saver = tf.train.Saver() # But you can precise which vars you want to save under which namev2_saver = tf.train.Saver({
"v2": v2}) # By default the Session handles the default graph and all its included variableswith tf.Session() as sess: # Init v and v2 sess.run(tf.global_variables_initializer()) # Now v1 holds the value 1.0 and v2 holds the value 2.0 # We can now save all those values all_saver.save(sess, 'data.chkp') # or saves only v2 v2_saver.save(sess, 'data-v2.chkp')

当你运行了上面的程序之后,如果你去看文件夹,那么你会发现文件夹中存在了七个文件(如下)。在接下来的博客中,我会详细解释这些文件的意义。目前你只需要知道,模型的权重是保存在 .chkp 文件中,模型的图是保存在 .chkp.meta 文件中。

├── checkpoint├── data-v2.chkp.data-00000-of-00001├── data-v2.chkp.index├── data-v2.chkp.meta├── data.chkp.data-00000-of-00001├── data.chkp.index├── data.chkp.meta

恢复操作和其它元数据

我想分享的最后一个信息是,Saver将保存与图有关联的任何元数据。这就意味着,当我们恢复一个模型的时候,我们还同时恢复了所有与图相关的变量、操作和集合。

当我们恢复一个元模型(restore a meta checkpoint)时,实际上我们执行的操作是将恢复的图载入到当前的默认图中。所有当你完成模型恢复之后,你可以在默认图中访问载入的任何内容,比如一个张量,一个操作或者集合。

import tensorflow as tf# Let's laod a previous meta graph in the current graph in use: usually the default graph# This actions returns a Saversaver = tf.train.import_meta_graph('results/model.ckpt-1000.meta')# We can now access the default graph where all our metadata has been loadedgraph = tf.get_default_graph()# Finally we can retrieve tensors, operations, etc.global_step_tensor = graph.get_tensor_by_name('loss/global_step:0')train_op = graph.get_operation_by_name('loss/train_op')hyperparameters = tf.get_collection('hyperparameters')

恢复权重

请记住,在实际的环境中,真实的权重只能存在于一个会话中。也就是说,restore 这个操作必须在一个会话中启动,然后将数据权重导入到图中。理解恢复操作的最好方法是将它简单的看做是一种数据初始化操作。

with tf.Session() as sess:    # To initialize values with saved data    saver.restore(sess, 'results/model.ckpt-1000-00000-of-00001')    print(sess.run(global_step_tensor)) # returns 1000

在新图中导入预训练模型

至此,你应该已经明白了如何去保存和恢复一个模型。然而,我们还可以使用一些技巧去帮助你更快的保存和恢复一个模型。比如:

  • 一个图的输出能成为另一个图的输入吗?

答案是确定的。但是目前我的做法是先将第一个图进行保存,然后在另一个图中进行恢复。但是这种方案感觉很笨重,我不知道是否有更好的方法。

但是这种方法确实能工作,除非你想要去重新训练第一个图。在这种情况下,你需要将输入的梯度重新输入到第一张图中的特定的训练步骤中。我想你已经被这种复杂的方案给逼疯了把。:-)

  • 我可以在一个图中混合不同的图吗?

答案当然是肯定的,但是你必须非常小心命名空间。这种方法有一点好处是,简化了一切。比如,你可以预加载一个VGG-19模型。然后访问图中的任何节点,并执行你自己的后续操作,从而训练一整个完整的模型。

如果你只想微调你自己的节点,那么你可以在你想要的地方中断梯度。

import tensorflow as tf# Load the VGG-16 model in the default graphvgg_saver = tf.train.import_meta_graph(dir + '/vgg/results/vgg-16.meta')# Access the graphvgg_graph = tf.get_default_graph()# Retrieve VGG inputsself.x_plh = vgg_graph.get_tensor_by_name('input:0')# Choose which node you want to connect your own graphoutput_conv =vgg_graph.get_tensor_by_name('conv1_2:0')# output_conv =vgg_graph.get_tensor_by_name('conv2_2:0')# output_conv =vgg_graph.get_tensor_by_name('conv3_3:0')# output_conv =vgg_graph.get_tensor_by_name('conv4_3:0')# output_conv =vgg_graph.get_tensor_by_name('conv5_3:0')# Stop the gradient for fine-tuningoutput_conv_sg = tf.stop_gradient(output_conv) # It's an identity function# Build further operationsoutput_conv_shape = output_conv_sg.get_shape().as_list()W1 = tf.get_variable('W1', shape=[1, 1, output_conv_shape[3], 32], initializer=tf.random_normal_initializer(stddev=1e-1))b1 = tf.get_variable('b1', shape=[32], initializer=tf.constant_initializer(0.1))z1 = tf.nn.conv2d(output_conv_sg, W1, strides=[1, 1, 1, 1], padding='SAME') + b1a = tf.nn.relu(z1)

References:


如果觉得内容有用,帮助多多分享哦 :)

长按或者扫描如下二维码,关注 “CoderPai” 微信号(coderpai)。添加底部的 coderpai 小助手,添加小助手时,请备注 “算法” 二字,小助手会拉你进算法群。如果你想进入 AI 实战群,那么请备注 “AI”,小助手会拉你进AI实战群。

你可能感兴趣的文章
Solaris硬盘分区简介(转)
查看>>
gcc编译器小知识FAQ(转)
查看>>
Linux下多线程编程与信号处理易疏忽的一个例子(转)
查看>>
流氓和木马结合 强行关闭你的防火墙(转)
查看>>
SUSE一纸诉状控告SCO 捍卫知识产权(转)
查看>>
debian下编译2.6.13.2内核的步骤及感受(转)
查看>>
预装正版的市场意义(转)
查看>>
创建小于16M XFree86迷你Linux系统(转)
查看>>
shell中常用的工具(转)
查看>>
使用MySQL内建复制功能来最佳化可用性(转)
查看>>
一个比较vista的vista主题for rf5.0fb(转)
查看>>
推荐一款 Linux 上比较漂亮的字体(转)
查看>>
在Linux中添加新的系统调用(转)
查看>>
Fedora Core 5.0 安装教程{下载}(转)
查看>>
把ACCESS的数据导入到Mysql中(转)
查看>>
shell里边子函数与主函数的实例(转)
查看>>
Linux中MAXIMA符号运算软件的简介(转)
查看>>
银行选择Linux 则无法回避高成本(转)
查看>>
上网聊天需要防范的几大威胁(转)
查看>>
[分享]后门清除完全篇(转)
查看>>