Model Parallel
- 모델의 일부를 각각 다른 device에 할당하는 방식
- 이 때 한 device에서 연산하는 동안 다른 device가 유휴상태가 되는 문제가 발생한다.
class ModelParallelResNet(ResNet):
def __init__(self, *args, **kwargs):
super(ModelParallelResNtet, self).__init_(Bottleneck, [3,4,5,6], num_classes=num_classes)
self.seq1 = nn.Sequential(
...
).to('cuda:0') # 첫번째 모델을 cuda0에 할당
self.seq2 = nn.Sequential(
...
).to('cuda:1') # 두번째 모델을 cuda1에 할당
self.fc.to('cuda:1')
def forward(self, x):
x = self.seq2(self.seq1(x).to('cuda:1')) # 두 모델을 연결
return self.fc(x.view(x.size(0), -1))
Data Parallel
- 데이터를 나눠 GPU에 할당 후 결과의 평균을 취하는 방법
- Mini batch 연산을 여러 GPU에서 동시에 수행한다고 볼 수 있다.
PyTorch에서는 2가지 방식을 제공
- DataParallel
단순히 데이터를 분배 후 평균을 취함
→ GPU 사용 불균형 문제가 발생한다.
parallel_model = torch.nn.DataParallel(model)
...
preds = parallel_model(inputs)
loss = loss_function(preds, labels)
loss.mean().backward() # GPU 개수 만큼 나눠 loss의 평균을 구한 뒤 gradient 계산
optimizer.step()
...
- DistributedDataParallel
각 CPU마다 process를 생성하여 개별 GPU에 할당
→ 기본적으로 DataParallel와 같은 개념이지만 개별적으로 연산을 해서 평균을 낸다.
# Distributed Data Parallel을 위한 Train Loader 세팅
train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
suffle = False
pin_memory = True
num_workers = 3 # Tip: 보통 GPU 개수의 4배
trainloader = torch.utils.data.DataLoader(train_data, batch_size=20, shuffle=False, pin_memory=pin_memory, num_workers=num_workers, sampler=train_sampler)
def main():
n_gpus = torch.cuda.device_count()
torch.multiprocessing.spawn(main_worker, nprocs=n_gpus, args=(n_gpus, ))
def main_worker(gpu, n_gpus):
image_size = 224
batch_size = 512
num_worker = 8
epoch = ...
batch_size = int(batch_size / n_gpus)
num_worker = int(num_worker / n_gpus)
# 멀티프로세싱 통신 규약 정의
torch.distributed.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:2568', world_size=n_gpus, rank=gpu)
model = Model
torch.cuda.set_device(gpu)
model = model.cuda(gpu)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
# cf. Python의 멀티프로세싱 코드
from multiprocessing import Pool
def f(x):
return x ** x
if __name__ == '__main__':
with Pool(5) as p:
print(p.map(f, [1,2,3]))
부스트캠프 AI Tech 교육 자료를 참고하였습니다.