기본적으로 Dataset을 구성할 때 파이토치의 torch.utils.data의 Dataset 클래스를 상속해서 만든다.
map-style dataset은 아래와 같이 3가지 메서드로 구성된다.
(map-style datasets은 getitem()과 len()을 구현하는 데이터셋으로 index(key)를 통해 데이터에 접근할 수 있는 형태이다.)
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self,):
pass
def __len__(self):
pass
def __getitem__(self, idx):
pass
__init__ 필요한 변수들을 선언한다. x_data와 y_data를 load하거나 파일목록을 load한다. 이미지를 처리할 transforms들을 Compose 해서 정의해둔다.
__len__ Dataset의 갯수를 반환한다.
__getitem__ idx 번째 데이터를 반환한다.
Iterable-style datasets은 __iter__() 기능을 구현하는 데이터셋이다.
데이터셋의 랜덤 읽기가 어렵거나 불가능할 경우에 적합하다. (stream data, real-time log 등)
Map-style dataset의 경우 인덱스를 통해 데이터에 접근할 수 있지만
Iterable-style은 next를 통해서 접근 하기 때문에 sampler를 사용하기 어려워 random shuffle을 원할 경우 임의로 미리 shuffle을 진행한 후 사용해야 한다.
from torch.utils.data import IterableDataset
class CustomDataset(IterableDataset):
def __init__(self, data_path):
self.data_path = data_path
def __iter__(self):
iter_csv = pd.read_csv(self.data_path, sep='\t', iterator=True, chunksize=1)
for line in iter_csv:
line = line['text'].item()
yield line
DataLoader
Dataloader는 모델 학습을 위해서 데이터를 Mini batch 단위로 제공해주는 역할이다.
DataLoader(dataset, # Dataset 인스턴스가 들어감
batch_size=1, # 배치 사이즈를 설정
shuffle=False, # 데이터를 섞어서 사용하겠는지를 설정
sampler=None, # sampler는 index를 컨트롤
batch_sampler=None, # 위와 비슷하므로 생략
num_workers=0, # 데이터를 불러올때 사용하는 서브 프로세스 개수
collate_fn=None, # map-style 데이터셋에서 sample list를 batch 단위로 바꾸기 위해 필요한 기능
pin_memory=False, # Tensor를 CUDA 고정 메모리에 할당
drop_last=False, # 마지막 batch를 사용 여부
timeout=0, # data를 불러오는데 제한시간
worker_init_fn=None # 어떤 worker를 불러올 것인가를 리스트로 전달
)
dataset 생성한 Dataset 인스턴스를 입력한다.
batch_size 배치 사이즈
shuffle 데이터를 섞어서 사용할지 여부 (default: False)
sampler 데이터의 index를 원하는 방식대로 조정하는 방법 index를 컨트롤하기 때문에 shuffle 파라미터는 False여야 한다. 불균형 데이터셋의 경우, 클래스의 비율에 맞게끔 데이터를 배치마다 제공해야 할 필요가 있는데 이럴 때 사용한다.
SequentialSampler : 항상 같은 순서
RandomSampler : 랜덤, replacemetn 여부 선택 가능, 개수 선택 가능
SubsetRandomSampler : 랜덤 리스트, 위와 두 조건 불가능
WeigthRandomSampler : 가중치에 따른 확률
BatchSampler : batch단위로 sampling 가능
DistributedSampler : 분산처리 (torch.nn.parallel.DistributedDataParallel과 함께 사용)
num_workers 데이터 로딩을 하기위해 몇 개의 CPU 프로세스를 사용할 것인지를 의미
collate_fn dataset이 variable length이면 collate_fn을 꼭 사용해주어야 한다.
dataloader_example = torch.utils.data.DataLoader(dataset_example,
batch_size=2,
collate_fn=my_collate_fn)
for d in dataloader_example:
print(d['X'], d['y'])
drop_last batch 단위로 불러오는 경우, batch_size에 따라 마지막 batch의 길이가 달라질 수 있다. drop_last를 True로 설정할 경우 마지막 batch를 사용하지 않는다.