티스토리 뷰
Dataset이란?
Dataset은 데이터 항목을 나타내는 추상 class이다. Pytorch에서 DataLoader를 사용하려면, Dataset class를 상속받아 최소 두개의 method인 __len__(), __getitem__()를 override해야한다.
* 데이터를 load하고 batch로 구성할 때 DataLoaser, Dataset을 사용하면, 전처리 및 로딩 과정이 효율적이고 간편해진다.
Dataset object를 구성할 때 필요한 method
# __init__() : 데이터셋의 초기화 매서드 - 데이터 경로, train, test set 정의 등
# __len__() : 데이터셋 크기를 반환함 - DataLoader에서 배치 처리를 위해 길이를 반환하는 것이 필요함
# __getitem__() : 주어진 인덱스에 해당하는 데이터, input data, label을 반환함
예시
- 일반적으로 사용하는 형태
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index], self.labels[index]
- 실제 사용하는 다양한 형태
from torch.utils.data.dataset import Dataset
class Denoising(Dataset):
def __init__(self):
# 원본이미지
self.mnist = MNIST('./data/', train=True, download=True, transform=ToTensor())
# 노이즈 추가한 데이터 담을 리스트
self.data = []
# 노이즈 입히기
for i in range(len(self.mnist)):
noisy_input = gaussian_noise(self.mnist.data[i])
input_tensor = torch.tensor(noisy_input)
self.data.append(torch.unsqueeze(input_tensor, dim=0))
# 데이터 개수 반환하는 함수
def __len__(self):
return len(self.data)
# 데이터 불러오는 함수(+전처리)
def __getitem__(self, i):
data = self.data[i]
# 원본 이미지도 0~1로 값을 맞춰줌 (min-max정규화, mnist 이미지 데이터에서 min값이 0임)
label = self.mnist.data[i]/255
return data, label
from torch.utils.data.dataset import Dataset
class Pets(Dataset):
def __init__(self,
path_to_img,
path_to_anno,
train=True,
transforms=None,
input_size=(128,128)):
# 정답과 입력 이미지를 이름순으로 정렬
self.images = sorted(glob.glob(path_to_img+'/*.jpg'))
self.annotations = sorted(glob.glob(path_to_anno+'/*png'))
# 데이터셋을 학습, 평가로 나눔 (8:2)
self.X_train = self.images[:int(0.8*len(self.images))]
self.X_test = self.images[int(0.8*len(self.images)):]
self.Y_train = self.annotations[:int(0.8*len(self.annotations))]
self.Y_test = self.annotations[int(0.8*len(self.annotations)):]
self.train = train # 학습용 or 평가용
self.transforms = transforms # 데이터 증강
self.input_size = input_size # 입력 이미지 크기
# 데이터 개수 반환 함수
def __len__(self):
if self.train:
return len(self.X_train)
else:
return len(self.X_test)
# 정답 label 변환 함수
def preprocess_mask(self, mask):
mask = mask.resize(self.input_size)
mask = np.array(mask).astype(np.float32)
mask[mask != 2.0] = 1.0
mask[mask == 2.0] = 0.0
mask = torch.tensor(mask)
return mask
# 데이터를 불러오는 함수(+이미지 전처리)
def __getitem__(self, i): # i번째 데이터와 정답을 반환
if self.train:
X_train = Image.open(self.X_train[i])
X_train = self.transforms(X_train)
Y_train = Image.open(self.Y_train[i])
Y_train = self.preprocess_mask(Y_train)
return X_train, Y_train
else:
X_test = Image.open(self.X_test[i])
X_test = self.transforms(X_test)
Y_test = Image.open(self.Y_test[i])
Y_test = self.preprocess_mask(Y_test)
return X_test, Y_test
DataLoader 사용하기
DataLoader를 이용해 Dataset을 불러와서 아래와 같은 방식으로 train 한다.
trainset = Denoising()
train_loader = DataLoader(trainset, batch_size=32) # 한번에 이미지 32장 사용
for epoch in range(20):
iterator = tqdm.tqdm(train_loader)
for data, label in iterator:
optim.zero_grad()
pred = model(data.to(device))
loss = nn.MSELoss()(torch.squeeze(pred), label.to(device))
loss.backward()
optim.step()
iterator.set_description(f'epoch{epoch+1} loss{loss.item()}')
'개발' 카테고리의 다른 글
[PyTorch] 이미지 Classification - CNN (1) CIFAR10 dataset (0) | 2023.02.20 |
---|---|
[Snowflake] data insert 하는 방법 (1) snowlight worksheet SQL (0) | 2023.01.06 |
[Pytorch] torchtext로 텍스트 classification (0) | 2022.07.13 |
[AutoML] auto-sklearn classification example (0) | 2022.06.26 |
댓글