안녕하세요.
이번 글에서는 pytorch를 이용해서 대표적인 CNN 모델인 ResNet을 implementation 하는데 필요한 코드를 line by line으로 설명해보려고 합니다.
ResNet을 구현할 줄 아시면 전통적인 CNN 모델들은 자유롭게 구현하는데 어려움이 없을거라 생각됩니다.
우선 pytorch에서 resnet 모델을 불러오는 코드는 아래 한 줄로 가능합니다.
model = resnet50().to(device)
그렇다면 resnet50() 이라는 함수가 어떤 과정을 통해 실행되는지 살펴봐야겠죠?
지금부터 이 과정을 순차대로 살펴보도록 하겠습니다.
※ 최종코드는 제일 아래에 있으니 참고해주세요!
※ 대부분 PPT 슬라이드에 설명한 내용을 이미지로 만들어 업로드했기 때문에 글씨가 잘 안보일 수도 있습니다. 그래서 PPT파일을 따로 첨부 하도록 하겠습니다.
0. ResNet() 함수 호출
- 먼저 resnet50()을 호출하면 ResNet(BottleNeck, [3,4,6,3]) 함수를 호출하게됩니다.
- ResNet 함수 내부를 대략적으로 살펴보면 ResNet50 구조를 파악할 수 있습니다.
1. (BottleNeck 적용 전) 첫 번째 conv layer
- ResNet 함수에서 첫 번째 conv layer 부터 살펴보도록 하겠습니다.
2. 두 번째 Conv layer
- 두 번째 Conv layer 부터 bottleneck이 적용됩니다. 앞서 노란색 영역인 첫 번째 conv layer를 지나면, 아래 빨간색 영역의 첫 번째 bottleneck 연산이 진행됩니다.
- 우선 첫 번째 bottleneck을 간단히 도식화하면 아래와 같이 나타낼 수 있습니다.
- Bottleneck이 포함된 conv layer를 생성하기 위해 make_layer 함수가 실행되야 하는데, make_layer 함수에 작성된 python 기본 문법들을 먼저 설명하겠습니다.
- 연산자를 이용한 리스트 생성
- for in 반복문 (with 리스트)
- 리스트 인자 함수
- Sequential 함수
- sequential 함수 설명
2-1. 두 번째 Conv layer에서 첫 번째 BottleNeck 적용 (make_layer(), BottleNeck()=block() 함수 호출)
2-2. 두 번째 Conv layer에서 두 번째 BottleNeck 적용 (make_layer(), BottleNeck()=block() 함수 호출)
2-3. 두 번째 Conv layer에서 세 번째 BottleNeck 적용 (make_layer(), BottleNeck()=block() 함수 호출)
3. 세 번째 Conv layer
3-1. 세 번째 Conv layer에서 첫 번째 BottleNeck 적용 (make_layer(), BottleNeck()=block() 함수 호출 + Down_sampling)
- 여기서 부터는 첫 번째 bottleNeck에 shortcut (for skip connection) 적용을 위해 down_sampling이 된다는 점을 알아두시면 좋을 것 같습니다.
- Down_sampling은 conv filter의 stride를 2로 설정함으로써 진행이 됩니다.
3-2. 세 번째 Conv layer에서 두 번째 BottleNeck 적용 (make_layer(), BottleNeck()=block() 함수 호출)
3-3. 세 번째 Conv layer에서 세 번째 BottleNeck 적용 (make_layer(), BottleNeck()=block() 함수 호출)
- block 함수 부분은 이전과 설명이 동일 하므로 이제부터는 생략하겠습니다.
3-4. 세 번째 Conv layer에서 세 번째 BottleNeck 적용 (make_layer(), BottleNeck()=block() 함수 호출)
4, 5. 네 번째 Conv layer, 다섯 번째 Conv layer
- 여기서부터는 위에서 설명한 내용의 반복이라 make_layer, block 함수 실행과정은 생략하도록 하겠습니다.
6. Average pooling, FC layer, Softmax
7. Weight initialization
(↓↓↓ 가중치 초기화 관련 API ↓↓↓)
https://pytorch.org/docs/stable/nn.init.html
8. Model Show
- 앞서 작성한 코드가 올바로 작성됐는지 해당 모델 구조를 들여다보는 세 가지 방법에 대해서 알아보겠습니다.
8-1. model.modules()
8-2. model.named_parameters()
8-3. summary()
9. 최종 코드
# model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
class ResNet(nn.Module):
def __init__(self, block, num_block, num_classes=10, init_weights=True):
super().__init__()
self.in_channels=64
self.conv1 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
# weights inittialization
if init_weights:
self._initialize_weights()
def _make_layer(self, block, out_channels, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
ith_block = 1
for stride in strides:
layers.append(block(self.in_channels, out_channels, stride, ith_block))
self.in_channels = out_channels * block.expansion
ith_block = ith_block+1
return nn.Sequential(*layers)
def forward(self,x):
output = self.conv1(x)
output = self.conv2_x(output)
x = self.conv3_x(output)
x = self.conv4_x(x)
x = self.conv5_x(x)
x = self.avg_pool(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
# define weight initialization function
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
def resnet50():
return ResNet(BottleNeck, [3,4,6,3])
class BottleNeck(nn.Module):
expansion = 4
def __init__(self, in_channels, out_channels, stride=1, ith_block=1):
super().__init__()
self.residual_function = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels * BottleNeck.expansion, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(out_channels * BottleNeck.expansion),
)
self.shortcut = nn.Sequential()
if stride == 1 and ith_block == 1: #첫 번째 block에서의 shortcut (or identity) 을 적용해주기 위해서는 channel 조정필요
self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels * BottleNeck.expansion, kernel_size=1, stride=1),
nn.BatchNorm2d(out_channels*BottleNeck.expansion))
if stride != 1 or in_channels != out_channels * BottleNeck.expansion: #feature size_downsampling
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels*BottleNeck.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels*BottleNeck.expansion)
)
self.relu = nn.ReLU()
def forward(self, x):
x = self.residual_function(x) + self.shortcut(x)
x = self.relu(x)
return x
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = resnet50().to(device)
x = torch.randn(3, 3, 224, 224).to(device)
output = model(x)
print(output.size())
for name, param in model.named_parameters():
print(name, param.size())
for m in model.modules():
print(m)
#ResNet50 모델 summary
summary(model, (3, 224, 224), device=device.type)
지금까지 ResNet50을 pytorch로 구현한 code에 대해서 설명해봤습니다.
다음 글에서는 Pretrained model를 불러드려와 transfer learning을 적용시키는 코드에 대해 설명하도록 하겠습니다.
'Pytorch > 2.CNN' 카테고리의 다른 글
5.Loss function, Optimizer, Learning rate policy (0) | 2021.07.27 |
---|---|
4. Transfer Learning (Feat. pre-trained model) (0) | 2021.07.27 |
2. Data preprocessing (Feat. Augmentation) (0) | 2021.07.27 |
1. Data Load (Feat. CUDA) (0) | 2021.07.27 |
코드 참고 사이트 (0) | 2021.07.02 |