pytorch 오토인코더 쉽게 구현하는 방법은?!

2021. 3. 19. 07:05

파이토치는 torchvision을 이용하면, 익히 알고있는 CNN알고리즘을 불러와 사용할 수 있다. 이 방법에 익숙해지다보니 autoencoder도 동일한 방식으로 학습시킬 수 있는 방법을 찾게 되었다. 하지만, 아쉽게도 torchvision에는 미리 탑재된 오토인코더 알고리즘은 없었다. 오늘은 파이토치에서 오토인코더 알고리즘을 쉽게 구현하는 방법에 대해서 알아보도록 하겠다.

 

 

pytorch 오토인코더



필자는 CNN에 기반한 오토인코더 알고리즘을 학습시키고 싶어, 관련 내용을 찾아보았다. 그리고 Resnet을 기반으로 이미 설계된 알고리즘에 학습시킬 수 있는 방법을 알게 됐다.

Restnet기반 AutoEncoder 알고리즘에 학습시키기 위해서는 pytorch-lighting과 pytorch-lighting-bolts 패키지가 필요하다. 파이토치 라이트닝은 파이토치를 래핑한 패키지로, 파이토치로 더 쉽고 편리하게 데이터를 학습시키도록 도와주는 패키지이다.


파이토치 라이트닝 볼츠는 이미 설계된 모델과 사전 훈련된 모델을 쉽게 사용할 수 있게 도와주는 패키지이다. 여기에는 Resnet기반 오토인코더 모델도 포함되어 있다.

 

 

코랩(colab)에 패키지를 설치했는데, 버전이 맞지 않는지 계속 에러가 발생한다. pytorch, pytorch lightning, lightning bolts, torchvision, torchtext 등 관련있는 패키지가 많이 있다. 최종적으로 아래와 같이 설치해서 잘 동작하는 것을 확인했다.

 

!pip install lightning-bolts
!pip install torchtext==0.6.0

 

 

인풋은 코랩(colab)에 샘플로 들어있는 MINST 파일을 사용하기로 하였다. 코랩(colab)에서 MNIST파일을 사용하는 방법은 이전 포스팅에서 다루었으니 참고하기 바란다.
( 참조: 파이토치, colab에서 MNIST파일 dataloader로 불러오기 )

 


파이토치 라이트닝 볼치에는 AutoEncoder와 Variational AutoEncoder 2개의 알고리즘이 정의돼 있다. 아래와 같이 AutoEncoder 알고리즘을 불러올 수 있다.

 

from pl_bolts.models.autoencoders import AE
import torch.nn as nn

 

resnet18과 resnet101중에 원하는 모델을 선택해 학습시킬 수 있다.

 

model = AE(input_height=32, enc_type="resnet18")

 

파이토치 라이트닝 볼트에서 제공하는 autoencoder 모델은 인풋이 RGB칼러로 들어가야 한다. MNIST는 그레이스케일이므로, auto-encoder 모델을 1개의 색깔 채널만 인풋으로 받는 것으로 수정하였다.

 

model.encoder.conv1 = nn.Conv2d(1, 64, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)
model.decoder.conv1 = nn.Conv2d(64, 1, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=False)

 

이후에는 pytorch를 학습 프로세스에 따라 모형을 학습시킬 수 있다. pytorch lightning의 Trainer를 이용하면, 쉽게 학습도 가능하다. 아래아 같이 fit함수를 호출하여 학습하였다.

 

from pytorch_lightning import Trainer

trainer = Trainer(gpus=1, max_epochs=1)
trainer.fit(model, train_loader)

 

 

Recommendation 포스팅

 

 

다음으로 오토인코더 모형이 어떻게 이미지를 표현하는지 출력해보았다. 먼저 모델을 평가모드로 변경하였다.

 

model.eval()

 

다음 원본 이미지를 출력해 보았다.

 

images, labels = next(iter(test_loader))
grid_img = torchvision.utils.make_grid(images, nrow=10)

import matplotlib.pyplot as plt
plt.imshow(grid_img.permute(1, 2, 0))

 

MNIST 원본 이미지

 

 

다음 모형에 의해 변형된 이미지를 출력해보았다.

 

predicted_images = model(images)
grid_img = torchvision.utils.make_grid(predicted_images, nrow=10)

import matplotlib.pyplot as plt
plt.imshow(grid_img.permute(1, 2, 0))

 

MNIST AUTO-ENCODER 결과

 

 

오늘은 이렇게 CNN기반의 오토인코더 알고리즘을 쉽게 학습하는 방법에 대해서 알아보았다. EPOCH을 1번 밖에 하지 않았지만, 이미지가 단순해서 그런지 예상보다 결과가 잘 나온 것 같다. 위 오토인코더 모형은 가로, 세로 크기가 동일해야 한다는 단점이 있다. 가로, 세로 크기가 다른 이미지를 학습시키는 방법도 추후에 다뤄보도록 하겠다.

댓글()