使用Tensorflow怎么对ckpt文件中的tensor进行读取-创新互联

本篇文章给大家分享的是有关使用Tensorflow怎么对ckpt文件中的tensor进行读取,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。

网站建设哪家好,找成都创新互联!专注于网页设计、网站建设、微信开发、重庆小程序开发、集团企业网站建设等服务项目。为回馈新老客户创新互联还提供了惠东免费建站欢迎大家使用!

在使用pre-train model时候,我们需要restore variables from checkpoint files.

经常出现在checkpoint 中找不到”Tensor name not found”.

这时候需要查看一下ckpt中到底有哪些变量

import os
from tensorflow.python import pywrap_tensorflow

checkpoint_path = os.path.join(model_dir, "model.ckpt")
# Read data from checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print tensor name and values
for key in var_to_shape_map:
  print("tensor_name: ", key)
  print(reader.get_tensor(key))

可以显示ckpt中的tensor名字和值,当然也可以用pycharm调试。

补充:tensorflow中读取模型中保存的值, tf.train.NewCheckpointReader

使用tf.trian.NewCheckpointReader(model_dir)

一个标准的模型文件有一下文件, model_dir就是MyModel(没有后缀)

checkpoint
Model.meta
Model.data-00000-of-00001
Model.index
import tensorflow as tf
import pprint # 使用pprint 提高打印的可读性
NewCheck =tf.train.NewCheckpointReader("model")

打印模型中的所有变量

print("debug_string:\n")
pprint.pprint(NewCheck.debug_string().decode("utf-8"))

使用Tensorflow怎么对ckpt文件中的tensor进行读取

其中有3个字段, 分别是名字, 数据类型, shape

获取变量中的值

print("get_tensor:\n")
pprint.pprint(NewCheck.get_tensor("D/conv2d/bias"))

使用Tensorflow怎么对ckpt文件中的tensor进行读取

print("get_variable_to_dtype_map\n")
pprint.pprint(NewCheck.get_variable_to_dtype_map())
print("get_variable_to_shape_map\n")
pprint.pprint(NewCheck.get_variable_to_shape_map())

以上就是使用Tensorflow怎么对ckpt文件中的tensor进行读取,小编相信有部分知识点可能是我们日常工作会见到或用到的。希望你能通过这篇文章学到更多知识。更多详情敬请关注创新互联行业资讯频道。


文章标题:使用Tensorflow怎么对ckpt文件中的tensor进行读取-创新互联
标题来源:http://bzwzjz.com/article/dehosj.html

其他资讯

Copyright © 2007-2020 广东宝晨空调科技有限公司 All Rights Reserved 粤ICP备2022107769号
友情链接: 成都商城网站建设 app网站建设 网站制作 成都网站制作 网站建设改版 成都网站建设 网站制作报价 成都做网站建设公司 成都网站建设 教育网站设计方案 定制网站制作 成都品牌网站设计 响应式网站设计方案 成都网站设计公司 成都网站建设 网站建设 成都定制网站建设 盐亭网站设计 泸州网站建设 四川成都网站建设 自适应网站设计 成都网站设计