为了让训练得到的模型保存下来方便下次直接调用,我们需要将训练得到的神经网络模型持久化。在这篇文章中,针对基于LSTM的预测模型,完整的实现了模型训练、保存、恢复、测试等过程,主要包括:

  • 模型训练
  • 持久化为ckpt文件
  • 基于ckpt文件的测试
  • 从ckpt转为pb文件
  • 基于pb文件的测试

由于本篇博客重点关注内容为上述过程,所以对无关的代码做了省略处理。

1. 基本处理代码

导入一些基础包

import tensorflow as tf
import pandas as pd
import numpy as np
import os
import time
from os.path import join as pjoin
from tensorflow.python.framework import graph_util

定义命令行参数

# parameter settings
flags = tf.app.flags
FLAGS = tf.app.flags.FLAGS
flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.')
flags.DEFINE_integer('training_epoch', 2000, 'number of epochs to train.')
flags.DEFINE_integer('num_units', 64, 'hidden units of lstm.')
flags.DEFINE_integer('seq_len',12 , '  time length of inputs.')
flags.DEFINE_integer('pre_len', 3, 'time length of prediction.')
flags.DEFINE_integer('batch_size', 64, 'batch size.')

定义相关路径

ckpt_path = path+'/ckpt/'   # ckpt文件保存路径
pb_path = path+'/pb/model.pb' # pb文件保存路径
saved_model_path = path+'/saved_model/' # 基于saved_model方法得到的pb和variables文件保存路径

定义模型实现函数

这里我们使用的是一个基于LSTM的预测模型,具体实现代码省略。

def lstm_prediction(X, weights, biases):
    # 基于LSTM的预测函数,具体内容不详细写了
    ......
    output = tf.reshape(output, shape=[-1, num_nodes], name='ypred')   # (batch_size*pre_len, num_nodes)
    return output
2. 模型保存为ckpt文件

首先,在模型训练部分遵从Tensorflow所需的模式,不再赘述。

其次,在模型保存持久化部分,TensorFlow提供了一个非常简单的API来保存和还原一个神经网络模型。这个API就是 tf.train.Saver 类。使用 tf.train.saver() 保存模型时会产生多个文件,会把计算图的结构和图上参数取值分成了不同的文件存储。这种方式是在TensorFlow中是最常用的保存方式。
关于这部分,可以参考我之前的一篇博客:TensorFlow 模型保存和加载

def model_train_save_ckpt(train_X, train_Y, FLAGS, ckpt_path):
    """
    训练模型;定期将模型保存为ckpt文件形式
    param:
        train_X: 训练数据X,shape=[train_len, seq_len, num_nodes]
        train_Y: 训练数据Y,shape=[train_len, pre_len, num_nodes]
        FLAGS: 命令行参数
        ckpt_path: ckpt文件的路径
    """
    
    # get parameters from FLAGS
    seq_len = FLAGS.seq_len
    pre_len = FLAGS.pre_len
    batch_size = FLAGS.batch_size
    lr = FLAGS.learning_rate
    training_epoch = FLAGS.training_epoch
    num_units = FLAGS.num_units
    
    # placeholders
    inputX = tf.placeholder(tf.float32, shape=[None, seq_len, num_nodes], name='inputXX')
    labelY = tf.placeholder(tf.float32, shape=[None, pre_len, num_nodes], name='labelYY')
    
    # weight and bias for the output layer(fc) in lstm_prediction function
    weight_out = {'out': tf.Variable(tf.random_normal([num_units, pre_len], mean=1.0), name='weight_out')}
    bias_out  = {'out': tf.Variable(tf.random_normal([pre_len]), name='bias_out')}
    
    # call lstm prediction function and return y_pred
    y_pred = lstm_prediction(inputX, weight_out, bias_out) 
    
    # loss
    label = tf.reshape(labelY, [-1, num_nodes]) # labelY: [  , pre_len, num_nodes]
    loss = tf.reduce_mean(tf.nn.l2_loss(y_pred-label), name='loss')  # 此处定义了name='loss',后面可用get_tensor_by_name获取该Tensor
    # tf.add_to_collection(name='loss', value=loss)  
    # 这块如果用tf.add_to_collection,后面restore阶段,可用 loss=graph.get_collection('loss:0') 获取
    
    # error
    error = tf.sqrt(tf.reduce_mean(tf.square(y_pred-label)), name='error') 
    # tf.add_to_collection(name='error', value=error)
    
    # optimizer
    optimizer = tf.train.AdamOptimizer(lr).minimize(loss) 

    # saver 
    saver = tf.train.Saver(tf.global_variables())  

    # session 
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        
        batch_loss, batch_rmse = [], []
        time_start = time.time()
        
        for epoch in range(training_epoch):   
            for m in range(total_batch):
                mini_batch = train_X[m*batch_size : (m+1)*batch_size]
                mini_label = train_Y[m*batch_size : (m+1)*batch_size]
                
                _, loss1, rmse1, train_output = sess.run([optimizer, loss, error, y_pred], 
                                                         feed_dict={inputX: mini_batch, labelY: mini_label})    
                batch_loss.append(loss1)
                batch_rmse.append(rmse1)
            print('Iter:{}'.format(epoch), 'train_loss:{:.4}'.format(batch_loss[-1]), 'train_rmse:{:.4}'.format(batch_rmse[-1]))
            
            if (epoch % 500 == 0): # 训练时定期保存模型,防止由于断电等情况导致模型丢失
                saver.save(sess, ckpt_path, global_step = epoch)  
            
        time_end = time.time()
        print('Training time:{:.4}'.format(time_end-time_start),'s')    

核心代码如下:

# 定义saver
saver = tf.train.Saver(tf.global_variables()) 
# 调用train.Saver的save方法
saver.save(sess, ckpt_path, global_step = epoch) 
    

在上述代码中,通过saver.save() 函数将TensorFlow模型保存到了 ckpt_path路径下。TensorFlow模型一般会保存在后缀为 .ckpt 的文件中,虽然上面的程序只指定了一个文件路径,但是这个文件目录下面会出现三个文件。这是因为TensorFlow会将计算图的结构和图上参数取值分开保存。
在这里插入图片描述
下面解释一下上述几个文件的作用:

  • checkpoint:检查点文件,文件保存了一个目录下所有模型文件列表。
  • model.ckpt.data:二进制文件,保存了TensorFlow程序中每一个变量的取值,包括所有weights,biases,gradient and all the other variables。
  • model.ckpt.index:保存了TensorFlow程序中变量的索引
  • model.ckpt.meta:保存了TensorFlow计算图的结构,包括variables operations,collections等等。(该文件可以被 tf.train.import_meta_graph 加载到当前默认的图来使用
3. 基于ckpt文件的测试

加载模型的代码基本上和保存模型的代码是一样的。在加载模型的程序中先定义了测试图test_graph,再定义计算图上所有运算,并声明了一个 tf.train.Saver类。

不同之处在于,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过get_tensor_by_name将已经保存的模型加载出来。

def model_test_ckpt(test_X, test_Y, ckpt_path): 
    """
    基于ckpt文件,对测试数据进行测试(预测)
    param:
        test_X: 测试数据X,shape=[test_len*seq_len, num_nodes]
        test_Y: 测试数据Y,shape=[test_len*pre_len, num_nodes]
        ckpt_path: ckpt文件的路径
    """
    time_start = time.time()
    
    model_path = tf.train.get_checkpoint_state(ckpt_path).model_checkpoint_path   
    test_graph = tf.Graph() # 定义test_graph
    
    with test_graph.as_default():
        saver = tf.train.import_meta_graph(pjoin(f'{model_path}.meta'))
    
    # 基于ckpt的测试
    with tf.Session(graph=test_graph) as test_sess:
        # 从原图中恢复出基本信息
        saver.restore(test_sess, tf.train.latest_checkpoint(ckpt_path)) 
        print('Load model succeed!')
        
        # 通过get_tensor_by_name,定义输入张量
        inputXX = test_sess.graph.get_tensor_by_name('inputXX:0')
        labelYY = test_sess.graph.get_tensor_by_name('labelYY:0')
        
        # 通过get_tensor_by_name,定义输出张量        
        y_pred = test_sess.graph.get_tensor_by_name('ypred:0')
    
        # 通过get_tensor_by_name,定义输出张量 
        loss = test_sess.graph.get_tensor_by_name('loss:0')
        error = test_sess.graph.get_tensor_by_name('error:0')
        #  loss = test_graph.get_collection('loss')[0]
        #  error = test_graph.get_collection('error')[0] 
        
        loss2, rmse2, test_output = test_sess.run([loss, error, y_pred],
                                                  feed_dict = {inputXX: test_X, labelYY: test_Y})
        
        test_label = np.reshape(test_Y,[-1,num_nodes])
        
        # test results evaluation
        test_rmse, test_mae, test_acc = evaluation(test_label1, test_output1)
        print('test_rmse:{:.4}'.format(test_rmse), 'test_mae:{:.4}'.format(test_mae), 'test_acc:{:.4}'.format(test_acc))
        
        time_end = time.time()
        print('Testing time:{:.4}'.format(time_end-time_start),'s')

在上述测试代码中,首先定义test_graph,再通过get_tensor_by_name获取训练图中已经定义好的节点。
如果不希望重复定义图上的运算,也可以直接加载已经持久化的图,以下代码给出一个简单实现样例。

graph = tf.get_default_graph().as_graph_def()  # 获得默认的图
y_pred = test_sess.graph.get_tensor_by_name('ypred:0')) # 获取图中的ypred
4. 模型保存为pb文件

在上述过程中,我们将TensorFlow模型保存为 ckpt 格式的模型文件,但是这种保存方式有几个缺点:

  • 这种模型文件是依赖 TensorFlow 的,只能在其框架下使用。
  • 在恢复模型之前还需要再定义一遍网络结构,然后才能把变量的值恢复到网络中。
  • 保存模型文件的时候会产生多个文件,它将变量的取值和计算图结构分成了不同的文件存储。
  • 使用 tf.train.Saver 默认保存和加载了TensorFlow计算图上定义的所有变量,但是有时可能只需要保存或者加载部分变量。
    (比如: (1)在测试或者离线预测时,只需要知道如何从神经网络的输出层经过前向传播计算得到输出层即可,而不需要类似于变量初始化,模型保存等辅助接点的信息。(2) 再比如,可能有一个之前训练好的五层神经网络模型,现在想尝试一个六层神经网络,那么可以将前面五层神经网络中的参数直接加载到新的模型,而仅仅将最后一层神经网络重新训练。)

谷歌推荐的保存模型的方式是保存模型为 PB 文件,它具有语言独立性,可独立运行,封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。另外的好处是保存为 PB 文件时候,模型的变量都会变成固定的,导致模型的大小会大大减小,适合在手机端运行。

那么模型保存为pb文件的形式也可以分为两种:直接保存为pb文件、从ckpt转pb文件。

4.1 直接保存为pb文件
4.1.1 tf.graph_util.convert_variables_to_constants

代码主体部分和直接保存为ckpt文件类似,主要区别部分在于:

  • convert_variables_to_constants 将图中的变量转化为常量,固化模型结构
  • 写入序列化的pb文件
def model_train_save_pb(train_X, train_Y, FLAGS, pb_path):
    """
    训练模型;定期将模型保存为pb文件形式
    param:
        train_X: 训练数据X,shape=[train_len, seq_len, num_nodes]
        train_Y: 训练数据Y,shape=[train_len, pre_len, num_nodes]
        FLAGS: 命令行参数
        pb_path: pb文件的路径
    """
    
    # get parameters from FLAGS
    seq_len = FLAGS.seq_len
    pre_len = FLAGS.pre_len
    batch_size = FLAGS.batch_size
    lr = FLAGS.learning_rate
    training_epoch = FLAGS.training_epoch
    num_units = FLAGS.num_units
    
    # placeholders
    inputX = tf.placeholder(tf.float32, shape=[None, seq_len, num_nodes], name='inputXX')
    labelY = tf.placeholder(tf.float32, shape=[None, pre_len, num_nodes], name='labelYY')
    
    # weight and bias for the output layer(fc) in lstm_prediction function
    weight_out = {'out': tf.Variable(tf.random_normal([num_units, pre_len], mean=1.0), name='weight_out')}
    bias_out  = {'out': tf.Variable(tf.random_normal([pre_len]), name='bias_out')}
    
    # call lstm prediction function and return y_pred
    y_pred = lstm_prediction(inputX, weight_out, bias_out) 
    
    # loss
    label = tf.reshape(labelY, [-1, num_nodes]) # labelY: [  , pre_len, num_nodes]
    loss = tf.reduce_mean(tf.nn.l2_loss(y_pred-label), name='loss')  # 此处定义了name='loss',后面可用get_tensor_by_name获取该Tensor
    # tf.add_to_collection(name='loss', value=loss)  
    # 这块如果用tf.add_to_collection,后面restore阶段,可用 loss=graph.get_collection('loss:0') 获取
    
    # error
    error = tf.sqrt(tf.reduce_mean(tf.square(y_pred-label)), name='error') 
    # tf.add_to_collection(name='error', value=error)
    
    # optimizer
    optimizer = tf.train.AdamOptimizer(lr).minimize(loss) 

    # saver 
    saver = tf.train.Saver(tf.global_variables())  

    # session 
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        
        batch_loss, batch_rmse = [], []
        time_start = time.time()
        
        for epoch in range(training_epoch):   
            for m in range(total_batch):
                mini_batch = train_X[m*batch_size : (m+1)*batch_size]
                mini_label = train_Y[m*batch_size : (m+1)*batch_size]
                
                _, loss1, rmse1, train_output = sess.run([optimizer, loss, error, y_pred], 
                                                         feed_dict={inputX: mini_batch, labelY: mini_label})    
                batch_loss.append(loss1)
                batch_rmse.append(rmse1)
            print('Iter:{}'.format(epoch), 'train_loss:{:.4}'.format(batch_loss[-1]), 'train_rmse:{:.4}'.format(batch_rmse[-1]))
            
        # convert_variables_to_constants 将图中的变量转化为常量,固化模型结构    
        output_graph = graph_util.convert_variables_to_constants(
            sess=sess, 
            input_graph_def=sess.graph_def, 
            output_node_names=['ypred']
        )
        # 写入序列化的pb文件
        with tf.gfile.GFile(pb_path, 'wb') as f:
            f.write(output_graph.SerializeToString())  
            
        time_end = time.time()
        print('Training time:{:.4}'.format(time_end-time_start),'s')    

核心代码如下:

# convert_variables_to_constants 将图中的变量转化为常量,固化模型结构    
output_graph = graph_util.convert_variables_to_constants(
    sess=sess, 
    input_graph_def=sess.graph_def, 
    output_node_names=['ypred']
)
# 写入序列化的pb文件
with tf.gfile.GFile(pb_path, 'wb') as f:
    f.write(output_graph.SerializeToString())  
4.1.2 tf.saved_model.builder.SavedModelBuilder

saved_model模块主要用于TensorFlow Serving,详细内容可以参考我的另外一篇博客:https://blog.csdn.net/u012856866/article/details/107915516

  • savedmodel文件保存
# 构造SavedModelBuilder对象
builder=tf.saved_model.builder.SavedModelBuilder(saved_model_path)
# 导入graph的信息以及变量
builder.add_meta_graph_and_variables(
    sess=sess,
    tags=['lstm_saved_model']
    )
# 将模型序列化到指定目录底下
builder.save()

保存好以后到saved_model_dir目录下,会有一个saved_model.pb文件以及variables文件夹。顾名思义,variables保存所有变量,saved_model.pb用于保存模型结构等信息。

  • savemodel加载
meta_graph_def = tf.saved_model.loader.load(
    sess, 
    ['lstm_saved_model'], 
    saved_model_path
)

第一个参数就是当前的session,第二个参数是在保存的时候定义的meta graph的标签,标签一致才能找到对应的meta graph。第三个参数就是模型保存的目录。load完以后,也是从sess对应的graph中获取需要的tensor来inference。

4.2 从ckpt转为pb文件

将CKPT转换成 PB格式的文件可以通过 convert_variables_to_constants 函数实现,主要过程如下:

  1. 通过传入 CKPT模型的路径得到模型的图和变量数据
  2. 通过 import_meta_graph 导入模型中的图
  3. 通过saver.restore 从模型中恢复图中各个变量的数据
  4. 通过 graph_util.convert_variables_to_constants 将模型持久化
def freeze_graph_ckpt_to_pb(ckpt_path, pb_path):
    """
    将前面保存的ckpt文件转化为pb文件
    param:
        ckpt_path: ckpt文件的路径
        pb_path: pb文件的路径
    """
    model_path = tf.train.get_checkpoint_state(ckpt_path).model_checkpoint_path   
    saver = tf.train.import_meta_graph(pjoin(f'{model_path}.meta')) # 恢复图并得到数据
    
    # input_graph = tf.get_default_graph().as_graph_def() 
    graph = tf.get_default_graph()  # 获得默认的图(由saver=tf.train.import_meta_graph恢复出的图)
    input_graph = graph.as_graph_def()  # 返回一个序列化的图代表当前的图
    
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_path))
        print('Load model succeed!')
         
        # convert_variables_to_constants 将图中的变量及其取值转化为常量,固化模型结构
        output_graph = graph_util.convert_variables_to_constants(
            sess=sess,
            input_graph_def=input_graph,
            output_node_names=['ypred']  # 需要保存的节点名称,该节点名称必须是元模型中存在的节点
        )
        
        # 将导出的模型存入文件
        with tf.gfile.GFile(pb_path, 'wb') as f:
            f.write(output_graph.SerializeToString())
        
        print("%d ops in the final graph." % len(output_graph.node)) 

补充说明:

  1. 函数 freeze_graph中,最重要的就是要确定“指定输出的节点名称”,这个节点名称必须是原模型中存在的节点,对于 freeze 操作,我们需要定义输出节点的名字。因为网络其实是比较复杂的,定义了输出节点的名字,那么freeze操作的时候就只把输出该节点所需要的子图都固化下来,其他无关的就舍弃掉。因为我们 freeze 模型的目的是接下来做预测,所以 output_node_names 一般是网络模型最后一层输出的节点名称,或者说我们预测的目标。
  2. 在保存的时候,通过 convert_variables_to_constants 函数来指定需要固化的节点名称output_node_names。注意节点名称与张量名称的区别。比如:“input:0 是张量的名称”,而“input” 表示的是节点的名称。
  3. 源码中通过 graph=tf.get_default_graph() 获得默认的图,这个图就是由 saver=tf.train.import_meta_graph(input_checkpoint + ‘.meta’, clear_devices=True) 恢复的图,因此必须先执行 tf.train.import_meta_graph,再执行 tf.get_default_graph()。

注:如果不知道网络节点名称,或者说不想去模型中找节点名称,那么我们可以在加载完模型的图数据之后,可以输出图中的节点信息查看一下模型的输入输出节点:

for op in tf.get_default_graph().get_operations():
    print(op.name, op.values())
5. 基于pb文件的测试

针对上面 4.1.14.2 节中,利用 convert_variables_to_constants 固化得到的pb模型文件,可直接使用下面代码进行相应的测试:

def freeze_graph_test_pb(pb_path):
    """
    基于固化的pb文件,对测试数据进行测试(预测)
    param:
        pb_path: pb文件的路径
    """
    with tf.Graph().as_default():
        output_graph = tf.GraphDef()
        with open(pb_path, 'rb') as f:
            output_graph.ParseFromString(f.read())
            tf.import_graph_def(output_graph, name='')
            print('Load model succeed!')
        
        with tf.Session() as test_sess:
            test_sess.run(tf.global_variables_initializer())
            
            # 通过get_tensor_by_name,定义输入张量
            inputXX = test_sess.graph.get_tensor_by_name('inputXX:0')

            # 通过get_tensor_by_name,定义输出张量       
            y_pred = test_sess.graph.get_tensor_by_name('ypred:0')
            
            # 加载测试数据
            test_X = pd.read_csv('./test_X_file.csv', header=None)          
            test_X = np.array(test_X)
            test_X = np.reshape(test_X, [-1, seq_len, num_nodes])
            
            # feed输入数据到模型中,得到输出(即预测结果)
            test_output = test_sess.run([y_pred], feed_dict = {inputXX: test_X})
            
            # 将预测结果存储成CSV文件
            test_output = np.array(test_output[0])  # 注:取test_output[0] 
            test_output_file = pd.DataFrame(test_output)
            test_output_file.to_csv('./test_output.csv', index=None, header=None)
            
            print('test with pb finished!')

补充说明:

  1. 与ckpt预测不同的是,pb文件已经固化了网络模型结构,因此,即使不知道原训练模型(train)的源码,我们也可以恢复网络图,并进行预测。恢复模型非常简单,只需要从读取的序列化数据中导入网络结构即可:
tf.import_graph_def(output_graph, name="")
  1. 但是必须知道原网络模型的输入和输出的节点名称(当然了,传递数据时,是通过输入输出的张量来完成的)。由于LSTM模型的输入有1个节点,因此这里需要定义输入的张量名称,它对应的网络结构的输入张量:
inputXX = test_sess.graph.get_tensor_by_name('inputXX:0')

定义输出张量:

y_pred = test_sess.graph.get_tensor_by_name('ypred:0')
  1. 预测时,需要 feed输入数据
test_output = test_sess.run([y_pred], feed_dict = {inputXX: test_X})
附1:模型持久化完整例子

参考博客:tensorflow三种加载模型的方法和三种模型保存文件(.ckpt,.pb, SavedModel)

下面代码是实现上面三种保存模式的小例子,可以粘贴复制把相关的代码注释掉,运行一下看看结果,能加深理解:

import tensorflow as tf
with tf.Session() as sess:
  #搭建网络
  x=tf.placeholder(tf.float32,name='x')
  y=tf.placeholder(tf.float32,name='y')
  b=tf.Variable(1.,name='b')
  xy=tf.multiply(x,y)
  op=tf.add(xy,b,name='op')
  sess.run(tf.global_variables_initializer())
  print(sess.run(op,feed_dict={x:2,y:3}))

  #ckpt保存
  saver=tf.train.Saver()
  saver.save(sess,'D:/pycharm files/111/ckpt/model_ck')

  #pb保存
  constant_graph=tf.graph_util.convert_variables_to_constants(sess,sess.graph_def,['op'])
  with tf.gfile.FastGFile('D:/pycharm files/111/pb/model.pb','wb') as f:
  f.write(constant_graph.SerializeToString())

  #savedmodel文件保存
  builder=tf.saved_model.builder.SavedModelBuilder('D:/pycharm files/111/savemodel')
  builder.add_meta_graph_and_variables(sess,['cpu_server_1'])
  builder.save()

  print('over')


  #ckpt加载
  saver=tf.train.import_meta_graph('D:/pycharm files/111/ckpt/model_ck.meta')
  saver.restore(sess,tf.train.latest_checkpoint('D:/pycharm files/111/ckpt'))

  #pb加载
  with tf.gfile.FastGFile('D:/pycharm files/111/pb/model.pb','rb') as f:
    graph_def=tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def,name='')

  #savemodel加载
  tf.saved_model.loader.load(sess, ['cpu_server_1'], 'D:/pycharm files/111/savemodel')

  #测试模型加载是否成功
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  op = sess.graph.get_tensor_by_name('op:0')
  ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
  print(ret)

在这里插入图片描述在这里插入图片描述在这里插入图片描述

附2:tensorflow打印模型文件节点

tensorflow打印pb模型的所有节点:

from tensorflow.python.framework import tensor_util
from google.protobuf import text_format 
import tensorflow as tf 
from tensorflow.python.platform import gfile 
from tensorflow.python.framework import tensor_util
 
pb_path = './model.pb'
 
with tf.Session() as sess:
    with gfile.FastGFile(pb_path,'rb') as f:
        graph_def = tf.GraphDef()
 
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def,name='')
        for i,n in enumerate(graph_def.node):
            print("Name of the node -%s"%n.name)

tensorflow打印ckpt模型的所有节点:

from tensorflow.python import pywrap_tensorflow
checkpoint_path = './_checkpoint/hed.ckpt-130'
 
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name:",key)

【参考博客】:

Logo

学大模型,用大模型上飞桨星河社区!每天8点V100G算力免费领!免费领取ERNIE 4.0 100w Token >>>

更多推荐