머신러닝과 기술적 분석

Tensorflow 구현 Pattern 본문

Tensorflow

Tensorflow 구현 Pattern

BetterToday 2017. 8. 16. 23:33
728x90

Tensorflow는 현 시점에서 가장 popular한 deep learning framework이고, google 에서 지원하는 만큼 앞으로도 이러한 상황이 지속될 것 같다.

pytorch가 tensorflow에 비해서 매우 simple한 구조라고 하는데, 아직까지는 tensorflow에 비해 community가 크지 않아서 개인적으로는 tensorflow를 계속사용하려고 한다. 라이브러리 또 배우는게 매우 귀찮다

내가 Tensorflow를 사용할 때는 training loop 이나 evaluation operation 같은 경우 다른 task에서 사용한 code를 copy & paste 하는 일이 매우 빈번했다. 그래서 이렇게 반복되는 code와 그렇지 않은 code를 분리해서 tensorflow base code를 구현해 보았다.

앞으로 tensorflow를 사용하는 project에서는 tensorflow base code를 sub-tree로 fetch 받아서 같은 작업의 반복을 줄여보려고 한다. 제발 그러고 싶다.

Table of contents

  1. Tensorflow에서 Project마다 중복되는 code
  2. 구조
  3. Client Code
  4. 정리

1. Tensorflow에서 Project마다 중복되는 code

먼저 Project마다 중복되는 (그래서 copy & paste 해야하는) code가 무엇인지를 파악해보았다.

웹에서 자료도 찾아 보았는데 https://wookayin.github.io/TensorFlowKR-2017-talk-bestpractice/ko/#1 가 매우 큰 도움이 되었다.

기본적으로 Project마다 ~항상 copy & paste 하고 있는~ 중복되는 code는 train함수와 evaluation함수이다.

그렇다면 중복되지 않는 code는 어느 부분일까? 당연히 graph내에서 operation(최종적으로 inference_op를 만드는 부분)을 구현하는 code 이다.

좀 더 formal 하게 정리해보면 다음과 같이 적어 볼수 있다.

  • graph를 생성하는 부분 (여기는 project 마다 다른 부분이다.)
  • session에서 graph를 실행하는 부분 (여기는 project 마다 같은 부분이다.)
    • inference() : inference_op를 실행
    • train() : train_op를 실행
    • evaluation() : evaluation_op를 실행

2. 구조

중복되는 code를 파악했으니 이제는 구조를 잡아볼 차례이다.

code 전체 : https://github.com/penny4860/tensorflow-train-images/blob/master/src/net/base.py

2.1. graph를 생성하는 부분 : _Model class

여기는 Project마다 매번 새로 구현해야하는 code이다.

그래서 base class를 구현해놓고 project마다 base class를 상속받아서 사용하는 구조를 잡았다.

class _Model(object):

    __metaclass__ = ABCMeta

    def __init__(self):
        # Placeholders
        self.X = self._create_input_placeholder()
        self.Y = self._create_output_placeholder()
        self.is_training = self._create_is_train_placeholder()

        # basic operations
        self.inference_op = self._create_inference_op()
        self.loss_op = self._create_loss_op()
        self.accuracy_op = self._create_accuracy_op()

        # summary operations        
        self.train_summary_op = self._create_train_summary_op()

    ## 이하생략        

새로운 Project에서는 위 class를 상속받고, _create_inference_op() 등의 abstract method를 정의한 다음 사용하면 될 것이다.

2.2. session 에서 graph를 실행하는 부분 : inference(), train(), evaluation()

여기에 해당하는 code는 function으로 만들었다. 함수들의 interface만 간단하게 살펴보자.

2.2.1. train()

def train(model, X_train, y_train, X_val, y_val, batch_size=100, n_epoches=5, ckpt=None, random_seed=None):

  • _Model class 를 상속받아서 구현한 class instance를 parameter로 전달받아 session내부에서 실행하는 함수이다.
  • 함수내부에서 session instance를 생성해서 training operation을 실행한다.
  • training 중에 loss, accuracy에 대한 summary도 한다.

2.2.2. evaluate()

def evaluate(model, images, labels, session=None, ckpt=None, batch_size=100):

  • session을 parameter로 전달받을 수 있다.
    • training loop에서 evaluation이 필요할 때는 session을 전달하고,
    • 그렇지 않은 경우에는 함수내부에서 session을 생성한다.

3. Client code 구조

#1. MnistCnn class 구현
class MnistCnn(_Model):
    ## 이하생략

#2. instance 생성
model = MnistCnn()

#3. training
train(model, train_images, train_labels, valid_images, valid_labels, ckpt='ckpts/cnn')

#4. 학습된 model을 평가
evaluate(model, test_images, test_labels, ckpt='ckpts')

4. 정리

session을 실행하는 부분과 graph를 생성하는 부분을 분리시켜서 구조를 잡았다.

tensorboard로 training 과정을 logging하는 부분도 고려해서 구현해놨는데, 이걸로 노가다가 좀 줄었으면 좋겠다.

transfer learning이나 dataset을 thread & queue로 입력받는 부분은 고려하지 않았는데 지금 마음가짐으로는 하나씩 추가해볼 생각이다.

728x90
반응형
Comments