일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
- 신의 시간술
- 김프
- 제시 리버모어
- 파이어족 자산증식
- 에드워드 소프
- 데이비드 라이언
- 파이어족 자산
- 데이빗 라이언
- GIT
- eclipse
- AWS
- 파이어족
- 자산배분
- 퀀터스 하지 마세요
- 아웃풋 트레이닝
- mark minervini
- 2%룰
- 파이어족 포트폴리오
- 마크미너비니
- 퀀트 트레이딩
- H는 통계를 모른다.
- 통계적 유의성
- 마크 미너비니
- 파이어족 저축
- 니콜라스 다바스
- 이클립스
- python
- tensorflow
- 추세추종 2%룰
- 연금저축계좌
- Today
- Total
머신러닝과 기술적 분석
Tensorflow 구현 Pattern 본문
Tensorflow는 현 시점에서 가장 popular한 deep learning framework이고, google 에서 지원하는 만큼 앞으로도 이러한 상황이 지속될 것 같다.
pytorch가 tensorflow에 비해서 매우 simple한 구조라고 하는데, 아직까지는 tensorflow에 비해 community가 크지 않아서 개인적으로는 tensorflow를 계속사용하려고 한다. 라이브러리 또 배우는게 매우 귀찮다
내가 Tensorflow를 사용할 때는 training loop 이나 evaluation operation 같은 경우 다른 task에서 사용한 code를 copy & paste 하는 일이 매우 빈번했다. 그래서 이렇게 반복되는 code와 그렇지 않은 code를 분리해서 tensorflow base code를 구현해 보았다.
앞으로 tensorflow를 사용하는 project에서는 tensorflow base code를 sub-tree로 fetch 받아서 같은 작업의 반복을 줄여보려고 한다. 제발 그러고 싶다.
Table of contents
1. Tensorflow에서 Project마다 중복되는 code
먼저 Project마다 중복되는 (그래서 copy & paste 해야하는) code가 무엇인지를 파악해보았다.
웹에서 자료도 찾아 보았는데 https://wookayin.github.io/TensorFlowKR-2017-talk-bestpractice/ko/#1 가 매우 큰 도움이 되었다.
기본적으로 Project마다 ~항상 copy & paste 하고 있는~ 중복되는 code는 train함수와 evaluation함수이다.
그렇다면 중복되지 않는 code는 어느 부분일까? 당연히 graph내에서 operation(최종적으로 inference_op를 만드는 부분)을 구현하는 code 이다.
좀 더 formal 하게 정리해보면 다음과 같이 적어 볼수 있다.
- graph를 생성하는 부분 (여기는 project 마다 다른 부분이다.)
- session에서 graph를 실행하는 부분 (여기는 project 마다 같은 부분이다.)
- inference() : inference_op를 실행
- train() : train_op를 실행
- evaluation() : evaluation_op를 실행
2. 구조
중복되는 code를 파악했으니 이제는 구조를 잡아볼 차례이다.
code 전체 : https://github.com/penny4860/tensorflow-train-images/blob/master/src/net/base.py
2.1. graph를 생성하는 부분 : _Model class
여기는 Project마다 매번 새로 구현해야하는 code이다.
그래서 base class를 구현해놓고 project마다 base class를 상속받아서 사용하는 구조를 잡았다.
class _Model(object):
__metaclass__ = ABCMeta
def __init__(self):
# Placeholders
self.X = self._create_input_placeholder()
self.Y = self._create_output_placeholder()
self.is_training = self._create_is_train_placeholder()
# basic operations
self.inference_op = self._create_inference_op()
self.loss_op = self._create_loss_op()
self.accuracy_op = self._create_accuracy_op()
# summary operations
self.train_summary_op = self._create_train_summary_op()
## 이하생략
새로운 Project에서는 위 class를 상속받고, _create_inference_op()
등의 abstract method를 정의한 다음 사용하면 될 것이다.
2.2. session 에서 graph를 실행하는 부분 : inference(), train(), evaluation()
여기에 해당하는 code는 function으로 만들었다. 함수들의 interface만 간단하게 살펴보자.
2.2.1. train()
def train(model, X_train, y_train, X_val, y_val, batch_size=100, n_epoches=5, ckpt=None, random_seed=None):
- _Model class 를 상속받아서 구현한 class instance를 parameter로 전달받아 session내부에서 실행하는 함수이다.
- 함수내부에서 session instance를 생성해서 training operation을 실행한다.
- training 중에 loss, accuracy에 대한 summary도 한다.
2.2.2. evaluate()
def evaluate(model, images, labels, session=None, ckpt=None, batch_size=100):
- session을 parameter로 전달받을 수 있다.
- training loop에서 evaluation이 필요할 때는 session을 전달하고,
- 그렇지 않은 경우에는 함수내부에서 session을 생성한다.
3. Client code 구조
#1. MnistCnn class 구현
class MnistCnn(_Model):
## 이하생략
#2. instance 생성
model = MnistCnn()
#3. training
train(model, train_images, train_labels, valid_images, valid_labels, ckpt='ckpts/cnn')
#4. 학습된 model을 평가
evaluate(model, test_images, test_labels, ckpt='ckpts')
4. 정리
session을 실행하는 부분과 graph를 생성하는 부분을 분리시켜서 구조를 잡았다.
tensorboard로 training 과정을 logging하는 부분도 고려해서 구현해놨는데, 이걸로 노가다가 좀 줄었으면 좋겠다.
transfer learning이나 dataset을 thread & queue로 입력받는 부분은 고려하지 않았는데 지금 마음가짐으로는 하나씩 추가해볼 생각이다.
'Tensorflow' 카테고리의 다른 글
[Tensorflow] Checkpoint file에 저장되어있는 Variable Tensor를 알아보는 방법 (0) | 2017.08.19 |
---|---|
Tensorflow 에서 VGG16을 사용하는 방법 (0) | 2017.08.17 |
Tensorflow 에서 random seed 사용방법 (0) | 2017.08.16 |
Tensorflow에서 scope/name 조합으로 variable 가져오기 (0) | 2017.08.16 |
Tensorflow 에서 model 을 저장, 로드하는 방법 (0) | 2017.08.16 |