CS336 7강 — 병렬화 1: 데이터 병렬과 ZeRO/FSDP
CS336-LLM-From-Scratch 시리즈의 7단계입니다. 전체 지도는 CS336 커리큘럼에서 볼 수 있습니다. (6강 — 커널과 Triton에서 이어집니다.)
지금까지는 GPU 한 장을 짜내는 이야기였습니다. 7강(Tatsunori Hashimoto)부터는 여러 대로 넘어갑니다. 큰 모델은 한 GPU에 안 들어가고(메모리), 빨리 학습하려면 여러 서버를 동시에 써야 하니까(연산) — 이제 연산 단위는 GPU가 아니라 데이터센터입니다. 이 강의는 병렬화 전체(데이터·모델·활성화)를 다루지만, 이 글은 데이터 병렬 계열(집합 통신 → 데이터 병렬 → ZeRO/FSDP)에 집중합니다. 텐서·파이프라인 병렬은 8강에서 이어집니다.
한눈에 보기
데이터 병렬의 발목을 잡는 건 메모리입니다. 모든 GPU가 파라미터·그래디언트·옵티마이저 상태를 통째로 복제하니 — 파라미터당 16바이트가 GPU마다 그대로 반복됩니다. ZeRO는 이 중복을 단계적으로 걷어내, 같은 통신량으로 메모리를 GPU 수만큼 나눕니다.
flowchart TD
DDP["DDP / ZeRO-0<br/>모든 GPU가 전부 복제<br/>16 B/param"] -->|"옵티마이저 상태 분할"| Z1["ZeRO-1<br/>(통신 그대로 · 사실상 공짜)"]
Z1 -->|"+ 그래디언트 분할"| Z2["ZeRO-2"]
Z2 -->|"+ 파라미터 분할"| Z3["ZeRO-3 = FSDP<br/>모든 것을 GPU 수로 나눔"]
DDP -.->|"7.5B·64 GPU 예시"| M0["120 GB"]
Z1 -.-> M1["31.4 GB"]
Z2 -.-> M2["16.6 GB"]
Z3 -.-> M3["1.9 GB"]
목표는 둘 — 선형 메모리 스케일링(GPU를 늘리면 더 큰 모델)과 선형 연산 스케일링(GPU를 늘리면 더 빠르게). 그리고 이 알고리즘들의 성능은 결국 집합 통신을 몇 번 부르나로 따집니다.
왜 여러 대인가
두 자원이 한 GPU의 한계를 만듭니다.
- 연산. 가장 빠른 슈퍼컴퓨터는 엑사플롭스 단위입니다. 지금 당장 강력한 모델을 학습하려면 GPU 곡선이 몇 년 더 오르길 기다릴 수 없으니, 다기계 병렬에 기댑니다.
- 메모리. 수십억 파라미터는 한 GPU에 안 들어갑니다. GPU 메모리도 늘지만 모델만큼 빠르진 않습니다.
네트워킹 계층
병렬화 전략은 하드웨어 통신 계층에서 출발합니다. 한 노드(서버) 안의 8개 GPU는 NVLink/NVSwitch로 아주 빠르게 연결됩니다. 하지만 다른 노드의 GPU와 통신하려면 InfiniBand 스위치를 거치는데, 레인당 약 8배 느립니다. 또 약 256 GPU까지는 all-to-all로 빠르게 통신되지만, 그 너머는 leaf·spine 스위치로 더 느려집니다. 이 위계가 “무엇을 어디에 병렬화할지”를 결정합니다.
집합 통신 (collective communication)
병렬 알고리즘은 몇 가지 집합 통신 프리미티브로 조립됩니다.
| 연산 | 하는 일 | 대략 비용 |
|---|---|---|
| all-reduce | 모두의 값을 합쳐 모두에게 복사 | ~2× 데이터 |
| broadcast | 한 랭크 → 모두 | ~1× |
| reduce | 모두 → 합쳐서 한 랭크 | ~1× |
| all-gather | 각 랭크의 조각을 모두에게 | ~1× |
| reduce-scatter | 합치되 각 조각을 한 랭크씩에게(부분 all-reduce) | ~1× |
이 강의 전체를 관통하는 핵심 항등식 하나:
all-reduce = reduce-scatter + all-gather (대역폭 비용이 같다)
데이터 병렬에서 그래디언트를 합치는 자연스러운 연산은 all-reduce지만, 이를 reduce-scatter와 all-gather 두 단계로 쪼개도 비용이 같습니다. 그 사이에 약간의 계산을 끼워 넣을 수 있다는 점이 — 곧 보겠지만 — ZeRO의 마법입니다.
데이터 병렬
가장 단순한 출발점. 파라미터는 GPU마다 복제하고, 배치를 쪼개 각 GPU가 다른 부분을 맡습니다. 각 GPU가 자기 몫(B/M개)의 그래디언트를 계산한 뒤, all-reduce로 그래디언트를 합치고, 파라미터를 갱신합니다.
각 GPU: 배치의 1/M 받기 → 순전파·역전파 → 로컬 그래디언트
↓ all-reduce (그래디언트 합산, ~2× 파라미터 통신)
모든 GPU: 동일한 평균 그래디언트 → 각자 파라미터 갱신(동일 결과)
- 연산 스케일링: 좋음. 각 GPU가 마이크로배치를 받아 연산을 채웁니다.
- 통신: 배치당 ~2× 파라미터. 배치가 크면 이 동기화 비용을 가립니다.
- 메모리 스케일링: 끔찍함. 모든 GPU가 파라미터·옵티마이저 상태를 통째로 복제합니다. 손도 못 댄 셈.
파라미터당 16바이트
메모리가 왜 문제인지 숫자로 봅시다. 혼합 정밀도 학습에서 파라미터 하나가 끌고 다니는 메모리는 16바이트입니다.
| 항목 | 바이트/파라미터 |
|---|---|
| 파라미터 (BF16) | 2 |
| 그래디언트 (BF16) | 2 |
| 마스터 가중치 (FP32, 누적용) | 4 |
Adam 1차 모멘트 m (FP32) |
4 |
Adam 2차 모멘트 v (FP32) |
4 |
| 합계 | 16 |
순수 파라미터는 2바이트면 되는데 8배가 붙고, 그 대부분이 옵티마이저 상태(마스터+m+v = 12바이트)입니다. (2강의 “16바이트/파라미터”가 전부 FP32 기준이었다면, 여기선 마스터 가중치를 둔 혼합 정밀도 분해 — 합계는 같은 16.) 그리고 이 메모리가 모든 GPU에 그대로 복제되니, 7.5B 모델을 64 GPU에 올리면 GPU당 ~120GB가 됩니다. 끔찍합니다.
핵심 통찰: 데이터 병렬을 하려면 파라미터·그래디언트는 복제해야 할 것 같지만, 옵티마이저 상태까지 모든 GPU에 둘 필요는 없습니다.
ZeRO: 중복을 걷어내다
ZeRO(Zero Redundancy Optimizer)는 복제된 것들을 단계적으로 분할(shard)합니다.
Stage 1 — 옵티마이저 상태 분할
모든 GPU가 파라미터·그래디언트는 전부 갖되, 옵티마이저 상태(m·v)는 자기 몫만 갖고 자기 몫의 파라미터만 갱신합니다.
① 각 GPU: 자기 데이터의 full 그래디언트 계산
② reduce-scatter: 그래디언트를 합쳐 — 각 GPU가 '자기 담당 파라미터'의 합산 그래디언트를 받음
③ 각 GPU: 자기 담당 파라미터만 Adam 갱신 (옵티마이저 상태가 거기 있으니까)
④ all-gather: 갱신된 파라미터를 모두에게 복사
여기서 항등식이 빛납니다 — ②reduce-scatter + ④all-gather = all-reduce이므로, 통신 비용은 나이브 데이터 병렬과 똑같습니다(2× 파라미터). 그런데 옵티마이저 상태는 GPU 수만큼 나뉩니다. 사실상 공짜로 메모리를 아낍니다(120 → 31.4GB). 안 할 이유가 없습니다.
Stage 2 — + 그래디언트 분할
그래디언트도 분할합니다. 복잡함이 하나 늘어요 — full 그래디언트 벡터를 통째로 만들면 OOM이 나니, 역전파가 한 층의 그래디언트를 계산하는 즉시 담당 GPU로 보내고 나머지는 버립니다. 통신 총량은 그대로(2× 파라미터)지만 층별 동기화 오버헤드가 약간 붙습니다(→ 16.6GB).
Stage 3 = FSDP — + 파라미터 분할
모든 것(파라미터·그래디언트·옵티마이저 상태)을 분할합니다. FSDP(Fully Sharded Data Parallel)가 바로 ZeRO Stage 3입니다. 어떤 GPU도 전체 파라미터를 갖지 않으므로, 계산 그래프를 따라가며 필요할 때 파라미터를 요청(all-gather)합니다.
각 층마다: all-gather로 그 층 파라미터 모으기 → 순전파 → 즉시 파라미터 버리기
역전파: all-gather로 모으기 → 그래디언트 계산 → reduce-scatter로 갱신 → 버리기
통신 비용은 3× 파라미터로 늘지만(2× → 3×), 놀랍게도 오버헤드가 작습니다. 비결은 통신과 연산의 겹치기(overlap) — 지금 층을 계산하는 동안 다음 층 파라미터를 미리 가져옵니다(prefetch). 6강에서 본 “CPU가 앞서 달리며 큐에 넣기”가 여기서 재현됩니다. 그래서 8×A100(80GB) 한 노드에서 baseline ~6B → ZeRO-3로 ~50B 모델까지 올릴 수 있습니다.
FSDP가 이렇게 인기인 또 다른 이유: 아키텍처를 몰라도 됩니다. 데이터 병렬은 모델 내부를 들여다볼 필요 없이 임의의 신경망을 감싸 병렬화할 수 있습니다.
배치 크기라는 자원
데이터 병렬엔 결정적 한계가 있습니다 — 배치 크기보다 더 병렬화할 수 없습니다(GPU당 최소 1개 예시). 게다가 배치를 무작정 키우면 임계 배치 크기(critical batch size) 너머로는 최적화 효율이 빠르게 떨어집니다(작을 땐 그래디언트 노이즈 감소가 값지지만, 어느 순간 변수는 노이즈가 아니라 스텝 수가 됩니다). 그래서 배치 크기는 유한한 자원이고, 데이터 병렬·파이프라인 병렬 등에 나눠 쓰는 대상입니다 — 이 관점은 8강에서 중요해집니다.
데이터 병렬의 한계 → 모델 병렬
ZeRO Stage 1·2는 메모리를 못 줄이고, Stage 3는 좋지만 느려질 수 있으며, 결정적으로 활성화(activation) 메모리를 줄이지 못합니다. 모델을 통째로 쪼개 GPU마다 다른 부분이 살게 하면 활성화 메모리도 함께 줄어듭니다 — 그게 모델 병렬(텐서·파이프라인)이고, 8강에서 다룹니다. 모델 병렬은 파라미터를 옮기지 않고 활성화를 주고받는다는 점이 데이터 병렬과 결정적으로 다릅니다.
성능·복잡도 노트
- 새 단위는 데이터센터. 목표는 선형 메모리·연산 스케일링. 성능 분석은 집합 통신 횟수로 환원됩니다.
- 항등식이 ZeRO를 만든다. all-reduce = reduce-scatter + all-gather. 같은 대역폭으로 그 사이에 분할 갱신을 끼워 메모리를 GPU 수로 나눕니다.
- 옵티마이저 상태가 메모리를 지배. 16바이트 중 12바이트가 마스터+m+v. ZeRO-1은 이를 거의 공짜로 분할 — 무조건 켜세요.
- 겹치기가 FSDP를 살린다. 파라미터를 계속 주고받는데도 빠른 건, 통신을 연산 뒤에 숨기는 prefetch 덕분입니다.
요약
- 큰 모델은 한 GPU를 넘는다 — 연산·메모리 모두. 새 단위는 데이터센터, 통신 계층은 NVLink(노드 내) ≫ InfiniBand(노드 간), 임계점 ~256 GPU.
- 집합 통신으로 조립한다. 핵심 항등식 all-reduce = reduce-scatter + all-gather.
- 데이터 병렬: 배치 쪼개기 + 그래디언트 all-reduce. 연산✓ 통신 가림 가능, 그러나 메모리✗(파라미터당 16바이트를 전 GPU가 복제, 대부분 옵티마이저 상태).
- ZeRO: 1(옵티마이저 상태, 거의 공짜) → 2(+그래디언트) → 3=FSDP(+파라미터, 통신 3× 그러나 겹치기로 가림). 메모리를 GPU 수로 나눈다.
- 배치 크기는 유한한 자원이고, 데이터 병렬은 활성화 메모리를 못 줄인다 → 모델 병렬(8강)로.
다음 학습 (Next Learning)
- CS336 8강 — 병렬화 2: 텐서·파이프라인 병렬과 3D 병렬화 — 모델을 쪼개 활성화만 주고받기, 그리고 3D 병렬화
- CS336 6강 — 커널과 Triton — FSDP의 통신·연산 겹치기를 이해하는 토대
- CS336 커리큘럼 — 전체 17단계 지도와 진행 현황