본문 바로가기
Python/PyTorch

[PyTorch] Torchvision과 utils.data

by _sweep 2022. 2. 23.

이수안컴퓨터연구소파이토치(PyTorch) 기초 영상을 보고 정리한 내용입니다.

 

 

Torchvision

Torchvision은 대표적으로 전처리할 때 사용하는 메서드인 transforms를 제공하는 패키지이다.
이외에도 일반적으로 클래스를 따로 만들어 전처리 단계를 진행하며 다양한 전처리 기술을 제공한다.

 

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5,), std=(0.5,))])

 

예를 들어 DataLoader의 인자로 들어갈 transform을 미리 정의할 수 있고 transforms.Compose()를 통해 리스트 안의 순서대로 전처리를 진행할 수 있다.

이때 ToTensor()를 하는 이유는 torchvision이 PIL Image 형태로만 입력을 받기 때문에 데이터 처리를 위해서 텐서의 형태로 변환해 주기 위함이다.

 

 

utils.data

Dataset에는 MNIST, CIFAR10 등 다양한 데이터셋이 존재한다.
DataLoader, Dataset을 통해 batch_size, train 여부, transform 등을 인자로 넣어 데이터를 어떻게 로드할 것인지 정해줄 수 있다.

 

import torch
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as transforms

trainset = torchvision.datasets.MNIST(root='/content/',
                                      train=True,
                                      download=True,
                                      transform=transform
                                      )

testset = torchvision.datasets.MNIST(root='/content/',
                                      train=False,
                                      download=True,
                                      transform=transform
                                      )

 

torch의 utils.data에서 MNIST data set을 가져오는 작업을 수행하고 있다.

학습 여부를 기준으로 trainset과 testset을 구분했다.

 

train_loader = DataLoader(trainset, batch_size=8, shuffle=True, num_workers=2)
test_loader = DataLoader(testset, batch_size=8, shuffle=False, num_workers=2)

dataiter = iter(train_loader)
images, labels = dataiter.next()
print(images.shape)
print(labels.shape)

# output
# torch.Size([8, 1, 28, 28])
# torch.Size([8])

 

DataLoader로 trainset과 testset의 값을 각각 가져온 뒤 데이터를 하나 확인해보면 위와 같은 값을 얻는다.

 

import matplotlib.pyplot as plt
plt.style.use('seaborn-white')

torch_image = torch.squeeze(images[0])
image = torch_image.numpy()
label = labels[0].numpy()

plt.title(label)
plt.imshow(image, 'gray')
plt.show()

 

이를 확인해 보기 위해 matplotlib을 적용하면 다음의 결과를 얻을 수 있었다.

 

 

 



 

'Python > PyTorch' 카테고리의 다른 글

[PyTorch] nn과 nn.functional  (0) 2022.02.23
[PyTorch] CUDA와 자동미분  (0) 2022.02.23
[PyTorch] 텐서의 연산과 조작  (0) 2022.02.22
[PyTorch] 파이토치와 텐서  (0) 2022.02.22

댓글