안녕하세요.

이번 글에서는 pytorch를 이용해 UNet 모델을 구현한 code를 설명할 예정입니다.

 

다양한 딥러닝 기반 segmentation 모델이 있지만, UNet 모델이 가장 기본이 되기 때문에 다루었습니다.

 

소개해 드릴 UNet pytorch 코드는 아래 영상을 기반으로 리뷰했으니 아래 영상도 참고해주세요!

https://www.youtube.com/watch?v=sSxdQq9CCx0 

 

 

최종코드는 제일 아래에 있으니 참고해주세요!

※대부분 PPT 슬라이드에 설명한 내용을 이미지로 만들어 업로드했기 때문에 글씨가 잘 안보일 수 도 있습니다. 그래서 PPT파일을 따로 첨부하도록 하겠습니다.

 

 

Unet pytorch implementation.pptx
2.52MB

 

 

 

0. UNet() 함수 호출

 

Pytorch에서 UNet 모델을 불러오는 코드는 아래 한 줄로 가능합니다.

 

model = UNet().to(device)

 

위의 코드를 실행시키면 구현해 놓은 UNet class가 로드 됩니다.

 

그림1

 

그럼 구현해 놓은 UNet class를 살펴보도록 하겠습니다.

 

 

 

 

 

1. Contracting Path 구현하기

그림2

 

 

그림3

 

 

 

 

 

2. Expansive Path 구현하기

그림4

 

 

 

그림5

 

 

 

 

 

3. Concatenation 구현하기

그림6

 

 

 

그림7

 

 

 

 

 

 

 

4. 최종코드

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()

        def CBR2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True):
            layers = []
            layers += [nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 bias=bias)]
            layers += [nn.BatchNorm2d(num_features=out_channels)]
            layers += [nn.ReLU()]

            cbr = nn.Sequential(*layers)

            return cbr

        # Contracting path
        self.enc1_1 = CBR2d(in_channels=1, out_channels=64)
        self.enc1_2 = CBR2d(in_channels=64, out_channels=64)

        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.enc2_1 = CBR2d(in_channels=64, out_channels=128)
        self.enc2_2 = CBR2d(in_channels=128, out_channels=128)

        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.enc3_1 = CBR2d(in_channels=128, out_channels=256)
        self.enc3_2 = CBR2d(in_channels=256, out_channels=256)

        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.enc4_1 = CBR2d(in_channels=256, out_channels=512)
        self.enc4_2 = CBR2d(in_channels=512, out_channels=512)

        self.pool4 = nn.MaxPool2d(kernel_size=2)

        self.enc5_1 = CBR2d(in_channels=512, out_channels=1024)

        # Expansive path
        self.dec5_1 = CBR2d(in_channels=1024, out_channels=512)

        self.unpool4 = nn.ConvTranspose2d(in_channels=512, out_channels=512,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec4_2 = CBR2d(in_channels=2 * 512, out_channels=512)
        self.dec4_1 = CBR2d(in_channels=512, out_channels=256)

        self.unpool3 = nn.ConvTranspose2d(in_channels=256, out_channels=256,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec3_2 = CBR2d(in_channels=2 * 256, out_channels=256)
        self.dec3_1 = CBR2d(in_channels=256, out_channels=128)

        self.unpool2 = nn.ConvTranspose2d(in_channels=128, out_channels=128,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec2_2 = CBR2d(in_channels=2 * 128, out_channels=128)
        self.dec2_1 = CBR2d(in_channels=128, out_channels=64)

        self.unpool1 = nn.ConvTranspose2d(in_channels=64, out_channels=64,
                                          kernel_size=2, stride=2, padding=0, bias=True)

        self.dec1_2 = CBR2d(in_channels=2 * 64, out_channels=64)
        self.dec1_1 = CBR2d(in_channels=64, out_channels=64)

        self.fc = nn.Conv2d(in_channels=64, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)

    def forward(self, x):
        enc1_1 = self.enc1_1(x)
        enc1_2 = self.enc1_2(enc1_1)
        pool1 = self.pool1(enc1_2)

        enc2_1 = self.enc2_1(pool1)
        enc2_2 = self.enc2_2(enc2_1)
        pool2 = self.pool2(enc2_2)

        enc3_1 = self.enc3_1(pool2)
        enc3_2 = self.enc3_2(enc3_1)
        pool3 = self.pool3(enc3_2)
        # print(pool3.size())
        enc4_1 = self.enc4_1(pool3)
        enc4_2 = self.enc4_2(enc4_1)
        pool4 = self.pool4(enc4_2)

        enc5_1 = self.enc5_1(pool4)

        dec5_1 = self.dec5_1(enc5_1)

        unpool4 = self.unpool4(dec5_1)
        cat4 = torch.cat((unpool4, enc4_2), dim=1)
        dec4_2 = self.dec4_2(cat4)
        dec4_1 = self.dec4_1(dec4_2)

        unpool3 = self.unpool3(dec4_1)
        cat3 = torch.cat((unpool3, enc3_2), dim=1)
        dec3_2 = self.dec3_2(cat3)
        dec3_1 = self.dec3_1(dec3_2)

        unpool2 = self.unpool2(dec3_1)
        cat2 = torch.cat((unpool2, enc2_2), dim=1)
        dec2_2 = self.dec2_2(cat2)
        dec2_1 = self.dec2_1(dec2_2)

        unpool1 = self.unpool1(dec2_1)
        cat1 = torch.cat((unpool1, enc1_2), dim=1)
        dec1_2 = self.dec1_2(cat1)
        dec1_1 = self.dec1_1(dec1_2)

        x = self.fc(dec1_1)

        return x

 

 

 

지금까지 UNet을 Pytorch로 구현한 code에 대해서 설명해봤습니다.

다음 글에서는 Pretrained model을 불러와 transfer learning을 적용시키는 코드에 대해 설명하도록 하겠습니다.

 

 

 

[Reference Site]

https://toitoitoi79.tistory.com/97

 

U-net: Convolutional Networks for Biomedical Image Segmentation Pytorch 구현

U-net은 바이오 기술에 사용되는 segmentation 논문입니다. sliding window 방식을 사용하는 CNN 구조와 달리 검증된 patch는 넘기기 때문에 보다 빠른 처리가 가능한 구조 입니다. 해당 포스팅은 구현에 포

toitoitoi79.tistory.com

 

+ Recent posts