안녕하세요.
이번에는 tensorflow 2 기반으로 CNN 모델을 구현하는 내용을 다루도록 하겠습니다.
Tensorflow 2에서 DNN or CNN 모델을 구축하는 방식은 크게 2가지로 나눌 수 있습니다.
- Sequential API
- Functional API
먼저, sequential API에 대해 설명한 후, functional API를 설명하도록 하겠습니다.
1. Sequential API
- Sequential API는 tensorflow 2에서 뉴럴 네트워크를 가장 쉽게 구성할 수 있는 방식입니다.
- Sequential이라는 이름에 맞게 add 함수를 이용하면 layer가 순차대로 연결이 됩니다.
- A sequential model is appropriate for a plain stack of layers where each layer has exactly one input tensor and one output tensor.
- 즉, add 함수를 통해 각 layer들은 정확히 하나의 input값만을 받을 수 있으며, output또한 하나의 tensor 형태로만 출력이 가능합니다.
- 이러한 특징이 갖고 있는 단점 중 하나는 복잡한 CNN 모델을 구성하기 힘들다는 점입니다.
- 예를 들어, ResNet 같은 경우는 Residual block 을 구성하기 위해서는 두개의 input 값 (ex: F(x), x)을 받아야 하는데, Sequential API로 구성하는 경우 x값을 받을 수 없으니 ResNet 모델을 구현할 수 없게 됩니다.
- 위와 같은 이유로 DenseNet 또한 구현이 불가능 합니다.
from tensorflow.keras import layers
from tensorflow.keras import models
#Conv2D(채널 수, (Conv filter 크기), activation function, 입력 데이터 크기)
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu',
input_shape=(150, 150, 3)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Flatten())
model.add(layers.Dropout(0.5))
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
2. Functional API
- Functional API을 이용하면 다양한 input 값을 받을 수 있습니다.
- 즉, layer에 input 값을 따로 기재해줄 수 있다는 뜻이죠.
from tensorflow.keras import layers
from tensorflow import keras
input_shape = (150,150,3)
img_input = layers.Input(shape=input_shape)
output1 = layers.Conv2D(kernel_size=(3,3), filters=32, activation='relu')(img_input)
output2 = layers.MaxPooling2D((2,2))(output1)
output3 = layers.Conv2D(kernel_size=(3,3), filters=64, activation='relu')(output2)
output4 = layers.MaxPooling2D((2,2))(output3)
output5 = layers.Conv2D(kernel_size=(3,3), filters=128, activation='relu')(output4)
output6 = layers.MaxPooling2D((2,2))(output5)
output7 = layers.Conv2D(kernel_size=(3,3), filters=128, activation='relu')(output4)
output8 = layers.MaxPooling2D((2,2))(output7)
output9 = layers.Flatten()(output8)
output10 = layers.Dropout(0.5)(output9)
output11 = layers.Dense(512, activation='relu')(output10)
predictions = layers.Dense(2, activation='softmax')(output11)
model = keras.Model(inputs=img_input, outputs=predictions)
3. Functional API를 이용해 Residual block 구성하기
- BatchNormalization layer 추가
- skip connection 적용
input = X
#첫 번째 conv layer에 있는 residual block
block_1_output1 = layers.Conv2D(kernel_size=(3,3), filters=channel_num, padding='same', name=name + '0_conv')(input)
block_1_output2 = BatchNormalization(name=name + '0_bn')(block_1_output1)
block_1_output3 = Activation('relu', name=name + '0_relu')(block_1_output2)
block_1_output4 = Conv2D(kernel_size=(3, ), filters=channel_num, padding='same', name=name + '1_conv')(block_1_output3)
# Zero gamma - Last BN for each ResNet block, easier to train at the initial stage.
#block_1_output4 = F(X)
block_1_output4 = BatchNormalization(gamma_initializer='zeros', name=name + '1_bn')(block_1_output4)
#merge_data = X+F(X)
merge_data = add([block_1_output4, input], name=name + '1_add')
out = Activation('relu', name=name + '2_conv')(merge_data)
위의 코드는 아래 이미지의 original 버전이라고 생각하시면 됩니다. 위의 코드를 기반으로 나머지 구조들((b), (c), (d), (e))도 구현하실 수 있겠죠?
3. Model summary
앞서 CNN 모델을 작성했다면, 해당 모델에 대한 간단한 구조를 summary 함수를 통해 알아볼 수 있습니다.
model.summary()
위의 출력결과에서는 "Non-trainable params:0"으로 표현되어 있는데, 나중에 transfer learning or fine-tuning을 적용시킬 때는 특정 layer까지 freezing 시키는 경우도 빈번하므로 Non-trainable params가 0이 아닐 때도 생깁니다. 이 부분은 trasnfer learning을 다룰 때 설명하도록 하겠습니다.
'Tensorflow > 2.CNN' 카테고리의 다른 글
5.Pre-trained model 불러오기 (feat. Transfer Learning and .h5 파일) (0) | 2021.06.30 |
---|---|
4. 평가지표(Metrics ) visualization (0) | 2021.06.29 |
3.CNN 모델의 loss function 및 optimizer 설정 (0) | 2021.06.29 |
1.Data Load 및 Preprocessing (전처리) (0) | 2021.06.28 |
코드참고 사이트 (0) | 2021.06.28 |