Machine Learning and Technical Analysis

How to Save and Load a Model in TensorFlow


BetterToday 2017. 8. 16. 23:31

This is sample code for saving and loading a model in TensorFlow.

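Give every Variable an explicit name when you create it. tf.train.Saver stores each variable in the checkpoint under its name and, on restore, matches checkpoint entries to variables by name. With explicit names, the restore succeeds even when the variables are created in a different order, as in the second half of the example below.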
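To see why explicit names matter, here is a small sketch. It assumes TF 2.x running the TF1 graph API through `tf.compat.v1` (an assumption; the post itself uses the TF 1.x top-level API). Without a `name` argument, TensorFlow numbers variables in creation order, so swapping the order swaps which variable receives which name.

```python
import tensorflow as tf

# Assumption: TF 2.x, using the TF1 graph API through tf.compat.v1.
v1 = tf.compat.v1
v1.disable_eager_execution()

# Unnamed variables are numbered in creation order.
v1.reset_default_graph()
weights = v1.Variable(v1.zeros([2, 3]))
bias = v1.Variable(v1.zeros([3]))
print(weights.name, bias.name)  # Variable:0 Variable_1:0

# In a fresh graph with the creation order swapped, bias now gets the
# name that the [2, 3] weights had in the first graph.
v1.reset_default_graph()
bias = v1.Variable(v1.zeros([3]))
weights = v1.Variable(v1.zeros([2, 3]))
print(bias.name, weights.name)  # Variable:0 Variable_1:0

# Explicit names do not depend on creation order.
v1.reset_default_graph()
bias = v1.Variable(v1.zeros([3]), name='bias_0')
weights = v1.Variable(v1.zeros([2, 3]), name='weights_0')
print(bias.name, weights.name)  # bias_0:0 weights_0:0
```

Because the Saver matches by name, the unnamed case would try to load the saved [2, 3] weights into the [3] bias.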

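# TensorFlow 1.x API. In TF 2.x, the same calls are available under
# tf.compat.v1 after calling tf.compat.v1.disable_eager_execution().
import tensorflow as tf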

tf.reset_default_graph()

save_file = './model.ckpt'

# Two Tensor Variables: weights and bias
weights = tf.Variable(tf.truncated_normal([2, 3]), name='weights_0')
bias = tf.Variable(tf.truncated_normal([3]), name='bias_0')

saver = tf.train.Saver()

# Print the name of Weights and Bias
print('Save Weights: {}'.format(weights.name))
print('Save Bias: {}'.format(bias.name))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

# Remove the previous weights and bias
tf.reset_default_graph()

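# Recreate the two Variables in the opposite order (bias first).
# Their names match the checkpoint, so the creation order does not matter.
bias = tf.Variable(tf.truncated_normal([3]), name='bias_0')
weights = tf.Variable(tf.truncated_normal([2, 3]), name='weights_0')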

saver = tf.train.Saver()

# Print the name of Weights and Bias
print('Load Weights: {}'.format(weights.name))
print('Load Bias: {}'.format(bias.name))

with tf.Session() as sess:
    # Load the weights and bias - No Error
    saver.restore(sess, save_file)

print('Loaded Weights and Bias successfully.')
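The example only prints a success message. A minimal sketch that also checks the values survive the round trip follows. It assumes TF 2.x with the TF1 graph API through `tf.compat.v1`, and it writes to a temporary directory instead of `./model.ckpt`. Both are assumptions made for illustration. It records the values before saving and compares them with the values restored into variables created in reverse order.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Assumption: TF 2.x, using the TF1 graph API through tf.compat.v1.
v1 = tf.compat.v1
v1.disable_eager_execution()

save_file = os.path.join(tempfile.mkdtemp(), 'model.ckpt')

# Save: weights first, then bias.
v1.reset_default_graph()
weights = v1.Variable(v1.truncated_normal([2, 3]), name='weights_0')
bias = v1.Variable(v1.truncated_normal([3]), name='bias_0')
saver = v1.train.Saver()
with v1.Session() as sess:
    sess.run(v1.global_variables_initializer())
    saved_w, saved_b = sess.run([weights, bias])
    saver.save(sess, save_file)

# Restore into a fresh graph, creating the variables in the opposite order.
v1.reset_default_graph()
bias = v1.Variable(v1.truncated_normal([3]), name='bias_0')
weights = v1.Variable(v1.truncated_normal([2, 3]), name='weights_0')
saver = v1.train.Saver()
with v1.Session() as sess:
    # restore() assigns the saved values, so no initializer is run here.
    saver.restore(sess, save_file)
    loaded_w, loaded_b = sess.run([weights, bias])

print(np.allclose(saved_w, loaded_w), np.allclose(saved_b, loaded_b))  # True True
```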