티스토리 뷰

Dataset이란?

Dataset은 데이터 항목을 나타내는 추상 class이다. Pytorch에서 DataLoader를 사용하려면, Dataset class를 상속받아 최소 두개의 method인 __len__(), __getitem__()를 override해야한다. 

 

* 데이터를 load하고 batch로 구성할 때 DataLoaser, Dataset을 사용하면, 전처리 및 로딩 과정이 효율적이고 간편해진다. 

 

 

Dataset object를 구성할 때 필요한 method

# __init__() : 데이터셋의 초기화 매서드 - 데이터 경로, train, test set 정의 등

# __len__() : 데이터셋 크기를 반환함 - DataLoader에서 배치 처리를 위해 길이를 반환하는 것이 필요함

# __getitem__() : 주어진 인덱스에 해당하는 데이터, input data, label을 반환함 

 

예시

- 일반적으로 사용하는 형태

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.labels[index]

 

- 실제 사용하는 다양한 형태

from torch.utils.data.dataset import Dataset

class Denoising(Dataset):
    def __init__(self):
        # 원본이미지
        self.mnist = MNIST('./data/', train=True, download=True, transform=ToTensor())
        # 노이즈 추가한 데이터 담을 리스트
        self.data = [] 
        
        # 노이즈 입히기
        for i in range(len(self.mnist)):
            noisy_input = gaussian_noise(self.mnist.data[i])
            input_tensor = torch.tensor(noisy_input) 
            self.data.append(torch.unsqueeze(input_tensor, dim=0)) 
        
    # 데이터 개수 반환하는 함수
    def __len__(self):
        return len(self.data)

    # 데이터 불러오는 함수(+전처리)
    def __getitem__(self, i):
        data = self.data[i]

        # 원본 이미지도 0~1로 값을 맞춰줌 (min-max정규화, mnist 이미지 데이터에서 min값이 0임)
        label = self.mnist.data[i]/255

        return data, label
from torch.utils.data.dataset import Dataset

class Pets(Dataset):
    def __init__(self,
                 path_to_img,
                 path_to_anno,
                 train=True,
                 transforms=None,
                 input_size=(128,128)):
    
        # 정답과 입력 이미지를 이름순으로 정렬
        self.images = sorted(glob.glob(path_to_img+'/*.jpg'))
        self.annotations = sorted(glob.glob(path_to_anno+'/*png'))
        
        # 데이터셋을 학습, 평가로 나눔 (8:2)
        self.X_train = self.images[:int(0.8*len(self.images))]
        self.X_test = self.images[int(0.8*len(self.images)):]
        self.Y_train = self.annotations[:int(0.8*len(self.annotations))]
        self.Y_test = self.annotations[int(0.8*len(self.annotations)):]

        self.train = train # 학습용 or 평가용
        self.transforms = transforms # 데이터 증강
        self.input_size = input_size # 입력 이미지 크기
        
    # 데이터 개수 반환 함수
    def __len__(self):
        if self.train:
            return len(self.X_train)
        else:
            return len(self.X_test)
    
    # 정답 label 변환 함수
    def preprocess_mask(self, mask):
        mask = mask.resize(self.input_size)
        mask = np.array(mask).astype(np.float32) 
        mask[mask != 2.0] = 1.0
        mask[mask == 2.0] = 0.0
        mask = torch.tensor(mask)
        return mask
    
    # 데이터를 불러오는 함수(+이미지 전처리)
    def __getitem__(self, i): # i번째 데이터와 정답을 반환
        if self.train:
            X_train = Image.open(self.X_train[i])
            X_train = self.transforms(X_train)
            Y_train = Image.open(self.Y_train[i])
            Y_train = self.preprocess_mask(Y_train)
            return X_train, Y_train
        else:
            X_test = Image.open(self.X_test[i])
            X_test = self.transforms(X_test)
            Y_test = Image.open(self.Y_test[i])
            Y_test = self.preprocess_mask(Y_test)
            return X_test, Y_test

 

DataLoader 사용하기

DataLoader를 이용해 Dataset을 불러와서 아래와 같은 방식으로 train 한다. 

trainset = Denoising()
train_loader = DataLoader(trainset, batch_size=32) # 한번에 이미지 32장 사용

for epoch in range(20):
    iterator = tqdm.tqdm(train_loader)
    
    for data, label in iterator:
        optim.zero_grad()
        pred = model(data.to(device))
        
        loss = nn.MSELoss()(torch.squeeze(pred), label.to(device))
        loss.backward()
        optim.step()
        iterator.set_description(f'epoch{epoch+1} loss{loss.item()}')
댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2025/05   »
1 2 3
4 5 6 7 8 9 10
11 12 13 14 15 16 17
18 19 20 21 22 23 24
25 26 27 28 29 30 31
글 보관함