Pytorch TensorDataset과 DataLoader 사용하는 방법은?!

2021. 3. 5. 06:08

이전 포스팅에서 Pytorch를 이용하여 블로그 텍스트 생성 모델을 만들어 보았다. 하지만, Input 데이터를 거의 다 직접 만들어주다 보니 사용하기도 힘들고 어려웠다. Pytorch에는 입력 데이터를 쉽게 처리하고, 배치 단위로 잘러서 학습할 수 있게 도와주는 모듈이 있다. 오늘은 TensorDataset과 DataLoader에 대해서 알아보도록 하겠다.

 

 

Pytorch TensorDataset & DataLoader



이미지는 폴더로 저장하고, 이미 관련 모듈을 이용해 쉽게 사용할 수 있다. text도 torchtext라는 모듈이 있지만, 한글은 왠지 적용하기가 어려울 것 같다. 그보다는 X,Y를 다 만든 후에 아래와 같이 TensorDataSet과 DataLoader를 이용하는 것이 쉬울 듯 하다.

먼저 학습이 필요한 텐서를 모아 TensorDataset을 만들어준다. 보통 X, Y를 넣지만, 그 외에 필요한 Tensor가 있으면 추가로 입력해도 된다. numpy 데이터 타입을 Float 텐서로 변경하고, TensorDataset으로 정의해 보았다.

 

x_train = np_data[:, 0:19, :]
y_train = np_data[:, 19, :]

x_train = torch.FloatTensor(x_train)
y_train = torch.FloatTensor(y_train)

ds = TensorDataset(x_train, y_train)

 

 

다음 DataLoader함수를 이용해, 배치크기나 데이터를 섞을지 여부 등을 결정해 준다.

 

dl = DataLoader(ds, batch_size = 1024, shuffle=True)

 

 

다음으로 for문을 이용해서 아래와 같이 학습시키면 쉽게 배치크기 단위로 학습이 가능하다.

 

for idx, (x, y) in enumerate(dl):
    # 파이토치 학습코드 작성

 

DataLoader에 대한 더 자세한 내용은 pytorch 공식 페이지에 잘 나와 있으니 참고하기 바란다.
( 참조: pytorch.org/docs/stable/data.html )

 

 

오늘은 이렇게 Pytorch TensorDataset과 DataLoader에 대해서 알아보았다. 파이토치에는 딥러닝을 쉽게 트레이닝할 수 있게 도와주는 모듈과 패키지들이 많이 있다. 다음 포스팅에서도 이러한 모듈이나 패키지들에 대해서 다뤄보도록 하겠다.

 

 

댓글()