Tensorflow 구현 Pattern
Tensorflow는 현 시점에서 가장 popular한 deep learning framework이고, google 에서 지원하는 만큼 앞으로도 이러한 상황이 지속될 것 같다.
pytorch가 tensorflow에 비해서 매우 simple한 구조라고 하는데, 아직까지는 tensorflow에 비해 community가 크지 않아서 개인적으로는 tensorflow를 계속사용하려고 한다. 라이브러리 또 배우는게 매우 귀찮다
내가 Tensorflow를 사용할 때는 training loop 이나 evaluation operation 같은 경우 다른 task에서 사용한 code를 copy & paste 하는 일이 매우 빈번했다. 그래서 이렇게 반복되는 code와 그렇지 않은 code를 분리해서 tensorflow base code를 구현해 보았다.
앞으로 tensorflow를 사용하는 project에서는 tensorflow base code를 sub-tree로 fetch 받아서 같은 작업의 반복을 줄여보려고 한다. 제발 그러고 싶다.
Table of contents
1. Tensorflow에서 Project마다 중복되는 code
먼저 Project마다 중복되는 (그래서 copy & paste 해야하는) code가 무엇인지를 파악해보았다.
웹에서 자료도 찾아 보았는데 https://wookayin.github.io/TensorFlowKR-2017-talk-bestpractice/ko/#1 가 매우 큰 도움이 되었다.
기본적으로 Project마다 ~항상 copy & paste 하고 있는~ 중복되는 code는 train함수와 evaluation함수이다.
그렇다면 중복되지 않는 code는 어느 부분일까? 당연히 graph내에서 operation(최종적으로 inference_op를 만드는 부분)을 구현하는 code 이다.
좀 더 formal 하게 정리해보면 다음과 같이 적어 볼수 있다.
- graph를 생성하는 부분 (여기는 project 마다 다른 부분이다.)
- session에서 graph를 실행하는 부분 (여기는 project 마다 같은 부분이다.)
- inference() : inference_op를 실행
- train() : train_op를 실행
- evaluation() : evaluation_op를 실행
2. 구조
중복되는 code를 파악했으니 이제는 구조를 잡아볼 차례이다.
code 전체 : https://github.com/penny4860/tensorflow-train-images/blob/master/src/net/base.py
2.1. graph를 생성하는 부분 : _Model class
여기는 Project마다 매번 새로 구현해야하는 code이다.
그래서 base class를 구현해놓고 project마다 base class를 상속받아서 사용하는 구조를 잡았다.
class _Model(object):
__metaclass__ = ABCMeta
def __init__(self):
# Placeholders
self.X = self._create_input_placeholder()
self.Y = self._create_output_placeholder()
self.is_training = self._create_is_train_placeholder()
# basic operations
self.inference_op = self._create_inference_op()
self.loss_op = self._create_loss_op()
self.accuracy_op = self._create_accuracy_op()
# summary operations
self.train_summary_op = self._create_train_summary_op()
## 이하생략
새로운 Project에서는 위 class를 상속받고, _create_inference_op()
등의 abstract method를 정의한 다음 사용하면 될 것이다.
2.2. session 에서 graph를 실행하는 부분 : inference(), train(), evaluation()
여기에 해당하는 code는 function으로 만들었다. 함수들의 interface만 간단하게 살펴보자.
2.2.1. train()
def train(model, X_train, y_train, X_val, y_val, batch_size=100, n_epoches=5, ckpt=None, random_seed=None):
- _Model class 를 상속받아서 구현한 class instance를 parameter로 전달받아 session내부에서 실행하는 함수이다.
- 함수내부에서 session instance를 생성해서 training operation을 실행한다.
- training 중에 loss, accuracy에 대한 summary도 한다.
2.2.2. evaluate()
def evaluate(model, images, labels, session=None, ckpt=None, batch_size=100):
- session을 parameter로 전달받을 수 있다.
- training loop에서 evaluation이 필요할 때는 session을 전달하고,
- 그렇지 않은 경우에는 함수내부에서 session을 생성한다.
3. Client code 구조
#1. MnistCnn class 구현
class MnistCnn(_Model):
## 이하생략
#2. instance 생성
model = MnistCnn()
#3. training
train(model, train_images, train_labels, valid_images, valid_labels, ckpt='ckpts/cnn')
#4. 학습된 model을 평가
evaluate(model, test_images, test_labels, ckpt='ckpts')
4. 정리
session을 실행하는 부분과 graph를 생성하는 부분을 분리시켜서 구조를 잡았다.
tensorboard로 training 과정을 logging하는 부분도 고려해서 구현해놨는데, 이걸로 노가다가 좀 줄었으면 좋겠다.
transfer learning이나 dataset을 thread & queue로 입력받는 부분은 고려하지 않았는데 지금 마음가짐으로는 하나씩 추가해볼 생각이다.