Automatic Mixed prcision이란, torch에서 몇가지 연산은 torch.float32(FP32, single-precision) 타입을 쓰고 어떤 연산은 torch.float16(FP16, half, half-precision)를 쓴다. 당연하게도 float16이 연산속도가 빠르지만 비정확하다. 그리고 타입이 다른 행렬끼리의 연산은 시간이 오래걸린다.
그래서 Mixed precision learning은 필요에 따라 FP16 연산 혹은 FP32 연산을 혼합하여 모델 학습을 진행하는 것으로 단순히 FP32 연산만을 사용하여 모델 학습하는 것에 비해 메모리 사용 감소, 에너지 사용 감소, 계산 속도 향상의 장점이 있다.
torch.cuda.amp.GradScaler( )
FP16 데이터 타입으로 gradient를 저장할 시, 작은 크기를 가지는 gradients는 zero가 될 가능성이 있다(underflow 현상).
이 경우 weight update가 제대로 되지 않아 네트워크 학습이 수렴하지 않을 수 있기 때문에 loss를 어떤 수(scale factor)만큼 곱해 크게 만드는 loss scaling 기법이다.
이를 통해 backward pass에서 gradient를 계산할 때, 그 값을 scaling 하여, 작은 gradient를 큰 수로 만들어 underflow를 방지할 수 있다. 다만 gradient가 실제 값보다 scale factor만큼 곱해진 값을 가지므로 weight update 시에는 scale factor만큼 나눠주는 unscale이 필수이다.
GradScalar 내의 scale factor는 미리 정해진 수이다. Scale factor는 gradient 계산 시에만 이용되고 weight update시에는 unscale되기에 적절한 알려진 수이기만 하면 된다. 다만 scale factor가 너무 클 경우 gradient를 FP16으로 표현할 수 있는 수보다 크게 scale할 수 있다(이를 overflow라고 함). overflow가 발생하면 gradient가 매우 이상한 값(inf, NaN 등)을 가지게 되므로 학습을 한번에 diverge하게 할 수 있다. 따라서 overflow가 발생하면 다음 두 가지 방법으로 학습을 안정화한다.
1. Scaled gradients가 inf, NaN이 되면, step() 함수는 skip되어 해당 gradients는 weight update에 사용되지 않고 버린다.
2. Scale factor가 크다고 판단되어 update() 함수를 통해 더 작은 수로 교체된다. 미리 정해진 backoff_factor 만큼 곱해져 scale factor의 크기는 감소된다.
다만 scale factor가 작아지기만 하면 작은 gradient가 underflow되는 현상이 발생할 수 있다. 따라서 growth_interval 만큼의 iteration동안 overflow가 발생하지 않는다면 를 growth_factor 곱하여 scale factor의 크기를 키운다.
torch.cuda.amp.autocast( )
autocast는 context manager로서, autocast가 선언된 코드 영역에서는 mixed precision 연산이 진행된다. 이 영역 내에서 연산들은 FP16(BF16) or FP32 중 autocast가 선택한 data type으로 연산이 되는데, 따로 타입 변환을 위한 함수를 호출할 필요 없이 영역에서 실행되는 것만으로 기준에 따라 데이터 타입이 변환된다.
autocast는 딥러닝 네트워크 학습 시 forward pass(loss를 계산하는 것까지)에서만 선언되어야 한다. Backward pass는 forward pass에서 선택된 data type으로 맞춰져서 실행된다.
autocast는 thread local이기에, 여러 thread에서 학습 실행 시, 모든 thread에서 각각 autocast를 선언해줘야 한다. Multi GPUs를 사용하거나 multiple nodes를 사용할 때 주의해야 한다.
[FP16으로 변환되어 연산 목록]
__matmul__, addbmm, addmm, addmv, addr, baddbmm, bmm, chain_matmul, multi_dot, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d, GRUCell, linear, LSTMCell, matmul, mm, mv, prelu, RNNCell
# AMP : loss scale을 위한 GradScaler 생성
scaler = torch.cuda.amp.GradScaler()
### 학습 시작
for epoch in epochs:
for input, target in data:
optimizer.zero_grad()
# AMP : Forward pass 진행
# AMP : autocast를 통한 자동 FP32 -> FP16 변환 (가능한 연산에 한하여)
with autocast():
output = model(input)
loss = loss_fn(output, target)
# AMP : scaled loss를 이용해 backward 진행 (gradient 모두 같은 scale factor로 scale됨)
# AMP : backward pass는 autocast 영역 내에 진행될 필요 없음
# AMP : forward pass에서 사용된 같은 data type으로 backward pass는 실행됨
scaler.scale(loss).backward()
# AMP : scaler.step은 가장 먼저 unscale(grad를 scale factor만큼 나눠기)
# AMP : weight update 실시, 단 만약 grad 중 infs or NaNs이 있으면 step 스킵됨
scaler.step(optimizer)
# AMP : scale factor 업데이트
scaler.update()
reference: https://computing-jhson.tistory.com/m/37