머신러닝과 기술적 분석

Tensorflow에서 scope/name 조합으로 variable 가져오기 본문

Tensorflow

Tensorflow에서 scope/name 조합으로 variable 가져오기

BetterToday 2017. 8. 16. 23:36
728x90

처음 Tensorflow을 접했을 때 디버깅이 힘들었던 이유가 사실은 지금도 힘들다 graph/session이 분리되어있어서 어떤 변수(Tensor)의 value를 찍어보기가 힘들었던 점이다.

실행 중에 value를 검사하는 것(print로 찍든, 디버거로 보든)이야말로 모든 디버깅의 기본이라고 할 수 있다.
본 post에서는 관심있는 Variable 객체들을 scope/name/collection의 조합으로 모으는 방법을 알아보려고 한다.

1. tf.contrib.framework.get_variables()

Variable 객체들을 모으는 Tensorflow에서 제공하는 몇 가지 Api가 있다. 그 중에서 tf.contrib.framework.get_variables()의 사용법이 가장 편한 것 같다.

scope/suffix/collection 의 조합으로 variable들을 가져올 수 있다.

# 이렇게 하면 graph에 있는 모든 variable을 가져올 수 있다.
variables = tf.contrib.framework.get_variables(scope=None, suffix=None, collection=tf.GraphKeys.GLOBAL_VARIABLES)
for variable in variables:
    print(variable.name)

# 이렇게 하면 graph의 모든 scope에서 name을 "weights"로 지정한 모든 variable들을 가져올 수 있다.
variables = tf.contrib.framework.get_variables(scope=None, suffix="weights", collection=tf.GraphKeys.GLOBAL_VARIABLES)
for variable in variables:
    print(variable.name)

2. graph code 구현할 때 scope/name을 잘 지정하자.

결국 Layer 마다 scope를 잘 지정하고 Layer 안에서 생성하는 Variable 객제에 name property를 잘 지정하자.

이러면 scope/suffix(name property로 지정된 string) 조합으로 쉽게 Variables를 가져올 수 있다.
“`

728x90
반응형
Comments