안녕하세요. 

이번 글에서는 transfer learning과 fine-tuning을 하는 tensorflow code를 살펴보도록 하겠습니다.

 

 

 

1. Transfer Learning

아래 코드를 실행시키고 자세히 보면 prediction part(→Flatten, Dense layer 부분이 빠져있는걸 확인할 수 있습니다)

from tensorflow.keras.applications import VGG16

vgg_base = VGG16(weights='imagenet',
                  include_top=False,
                  input_shape=(150, 150, 3))
vgg_base.summary()

 

Prediction layers 분은 굳이 Functional API로 구현할 필요는 없기 때문에 sqeuntial로 prediction layer 부분을 추가해 줍니다.

from tensorflow.keras import models
from tensorflow.keras import layers

model = models.Sequential()
model.add(vgg_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
model.summary()

 

위와 같이 model을 구성했다면 이전에 배운대로 아래 코드를 수행시켜 학습을 실행시킵니다.

model.compile(optimizer=optimizers.RMSprop(lr=2e-5),
              loss='binary_crossentropy',
              metrics=['acc'])

history = model.fit(train_features, train_labels,
                    epochs=30,
                    batch_size=20,
                    validation_data=(validation_features, validation_labels))

 

 

 

2. Fine tuning

위와 같이 transfer learning을 적용시킨 후 특정 layer들만 학습시키려면 우선 아래와 같이 Convolution base 부분을 아래 코드와 같이 freezing 시킵니다.

for layer in vgg_base.layers:
	layer.trainable = False

 

 

보통 fine tuning을 할 때, 끝단에 위치한 bottom conv layer (=deeper layer) 부분만 미세하게 학습시켜주는 것이 일반적입니다. 왜냐하면, 어떠한 이미지분류 문제이든 해당 객체의 edge(←top layer의 conv filter가 뽑는 feature)는 거의 비슷할 테니까요. 중간 layer에 있는 conv filter가 보통 texture feature를 뽑아내는데, 만약 ImageNet에서 학습한 이미지의 texture와 내가 분류하려는 새로운 이미지들의 texture가 많이 다르다는 판단을 하게 된다면 이 부분도 미세하게 학습시켜주어야 합니다. 이러한 방법은 아래와 같은 코드로 진행이 됩니다.

for layer in vgg_base.layers[15:]:
	layer.trainable = True

 

 

[정리]

지금까지 배운 pre-trained model을 다운받아 transfer learning 및 fine-tuning을 적용시키는 코드를 작성해보도록 하겠습니다.

from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras.applications import VGG16
vgg_base = VGG16(weights='imagenet',
                  include_top=False,
                  input_shape=(150, 150, 3))
vgg_base.summary()
for layer in vgg_base.layers:
	layer.trainable = False
for layer in vgg_base.layers[15:]:
	layer.trainable = True
model = models.Sequential()
model.add(vgg_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(2, activation='softmax'))
model.compile(optimizer=optimizers.RMSprop(lr=2e-5),
              loss='categorical_crossentropy',
              metrics=['acc'])

history = model.fit(train_features, train_labels,
                    epochs=30,
                    batch_size=20,
                    validation_data=(validation_features, validation_labels))

 

 

3. Model 저장 및 불러오기

앞선 작업을 통해 학습시킨 CNN 모델을 따로 저장할 수 도 있고, 나중에 데이터셋이 더 마련되면 해당 CNN 모델을 불러와 다시 re-training 시킬 수 있습니다.

 

tensorflow에서 CNN 모델을 저장시키거나 불러올 때 사용되는 CNN 모델 파일 format은 TensorFlow SavedModel 형식 or Keras H5 형식 or checkpoint 형식으로 크게 세 가지가 있습니다.

 

 

3-1. h5 파일 형태로 저장 또는 불러오기

tensorflow에서는 학습시킨 모델을 h5 형태의 파일로 저장시킬 수 있습니다. 본래 h5 형태의 파일은 keras에서 지원하던 파일형식인데 keras가 tensorflow backend를 사용할 때 tf.keras 로 h5 형태의 파일을 저장 및 불러오는 것을 가능하게 했습니다.

# Calling `save('my_model.h5')` creates a h5 file `my_model.h5`.
model.save("my_h5_model.h5")

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_h5_model.h5")

 

H5 파일은 Hierarchical Data Format (HDF) 5의 축약표현입니다. 보통 과학 데이터의 다차원 배열이 포함되기 때문에 다양한 과학분야(ex: 물리학, 공학, 딥러닝, ... 등)에서 주로 사용되는 파일형태입니다.

 

 

3-2. SavedModel 형태로 저장 또는 불러오기

# Calling `save('my_model')` creates a SavedModel folder `my_model`.
model.save("my_model")

# It can be used to reconstruct the model identically.
reconstructed_model = keras.models.load_model("my_model")

 

아래링크에서 SavedModel과 h5파일을 사용할 때 어떠한 차이가 있는지 기술해놓았으니 참고하시면 좋을 것 같습니다.

https://www.tensorflow.org/guide/keras/save_and_serialize?hl=ko 

 

Keras 모델 저장 및 로드  |  TensorFlow Core

소개 Keras 모델은 다중 구성 요소로 이루어집니다. 모델에 포함된 레이어 및 레이어의 연결 방법을 지정하는 아키텍처 또는 구성 가중치 값의 집합("모델의 상태") 옵티마이저(모델을 컴파일하여

www.tensorflow.org

 

 

 

3-3. Checkpoint

기존 tensorflow에서는 모델을 저장시키거나 불러올 때, SavedModel 방식 또는 Checkpoint 방식을 사용했습니다. Checkpoint에 대한 설명은 추후에 더 자세히 다루도록 하겠습니다. 아래링크에서 checkpoint에 관한 내용을 다루고 있으니 참고하시면 좋을 것 같습니다.

 

https://www.tensorflow.org/guide/checkpoint?hl=ko 

 

체크포인트 훈련하기  |  TensorFlow Core

Note: 이 문서는 텐서플로 커뮤니티에서 번역했습니다. 커뮤니티 번역 활동의 특성상 정확한 번역과 최신 내용을 반영하기 위해 노력함에도 불구하고 공식 영문 문서의 내용과 일치하지 않을 수

www.tensorflow.org

 

 

 

+ Recent posts