CS336 2강 — PyTorch와 자원 회계: 6ND와 메모리를 냅킨에 계산하기
CS336-LLM-From-Scratch 시리즈의 2단계입니다. 전체 지도는 CS336 커리큘럼에서 볼 수 있습니다. (1강 — 토크나이제이션에서 이어집니다.)
1강이 “효율”이라는 코스의 주제를 선언했다면, 2강은 그 효율을 측정하는 법을 가르칩니다. 보통 우리는 모델을 짜고, 돌리고, 되는 대로 둡니다. 하지만 학습이 수억 달러로 번지는 순간 그 태도는 통하지 않습니다 — FLOPs와 메모리는 곧 달러이기 때문입니다. 이 강의(Percy Liang)는 텐서에서 시작해 학습 루프까지 PyTorch 프리미티브를 훑되, 매 단계에서 “이건 메모리를 얼마, 연산을 얼마 쓰는가”를 끈질기게 따집니다. 트랜스포머는 다루지 않습니다 — 선형 모델만으로도 핵심 회계는 전부 나옵니다.
한눈에 보기
자원 회계는 두 축으로 갈라집니다 — 메모리(무엇을 얼마나 저장하나)와 연산(FLOPs를 얼마나 쓰나). 이 둘을 알면 “이 학습이 며칠 걸리고, 이 GPU에 얼마나 큰 모델이 들어가나”라는 실전 질문에 냅킨 한 장으로 답할 수 있습니다.
flowchart TD
T["텐서<br/>(원자: 값 개수 × 바이트)"]
T --> MEM["메모리 회계"]
T --> CMP["연산 회계"]
MEM --> M1["부동소수점 타입<br/>FP32·BF16·FP8"]
MEM --> M2["파라미터+그래디언트<br/>+옵티마이저+활성화"]
CMP --> C1["행렬곱 = 2mnp"]
CMP --> C2["순전파 2ND · 역전파 4ND"]
M2 --> Q2["냅킨 ②<br/>이 GPU에 들어가는<br/>최대 모델 크기"]
C2 --> SixND["학습 ≈ 6ND<br/>(6 × 파라미터 × 토큰)"]
SixND --> Q1["냅킨 ①<br/>학습에 걸리는 시간"]
SixND --> MFU["MFU<br/>실측 ÷ 약속 FLOP/s"]
이 강의가 가르치는 두 개의 숫자는 결국 이것입니다 — 학습 연산량 ≈ 6ND, 그리고 AdamW 학습 메모리 ≈ 16바이트/파라미터.
냅킨 계산: 두 개의 질문
강의는 두 질문으로 문을 엽니다. 지금은 답만 보고, 아래에서 직접 유도합니다.
- 70B 밀집(dense) 트랜스포머를 15T 토큰으로, H100 1,024장에서 학습하면 며칠 걸리나?
→ 총 FLOPs =
6 × 70e9 × 15e12. 이를 (H100 FLOP/s × MFU 0.5 × 1024장 × 하루 초)로 나누면 약 144일. - H100 8장에서 AdamW로 (특별한 기교 없이) 학습 가능한 최대 모델은?
→ 메모리
8 × 80GB = 640GB를 파라미터당 16바이트로 나누면 약 40B 파라미터(활성화 무시한 거친 추정).
두 답 모두 곱셈·나눗셈 몇 번이 전부입니다. 핵심은 그 안에 들어가는 6과 16이 어디서 오는가입니다.
메모리 회계
텐서와 바이트
딥러닝의 모든 것 — 파라미터·그래디언트·옵티마이저 상태·활성화·데이터 — 은 텐서에 담깁니다. 텐서의 메모리는 단순합니다.
메모리 = (원소 개수) × (원소당 바이트 수)
import torch
x = torch.zeros(4, 8) # 기본 dtype: float32
x.numel() # 32 (원소 개수)
x.element_size() # 4 (바이트) — float32 = 32비트 = 4바이트
x.numel() * x.element_size() # 128 바이트
행렬 하나가 작아 보여도, GPT-3의 한 FFN 행렬은 2.3 GB에 이릅니다. 그래서 원소당 바이트를 줄이는 것 — 부동소수점 타입 선택 — 이 첫 번째 효율 레버입니다.
부동소수점 타입
| 타입 | 비트 | 부호/지수/가수 | 바이트 | 특징 |
|---|---|---|---|---|
| FP32 | 32 | 1 / 8 / 23 | 4 | 기본값, “full precision”. 안전하지만 무겁다 |
| FP16 | 16 | 1 / 5 / 10 | 2 | half precision. dynamic range가 좁아 작은 값이 underflow(0으로) |
| BF16 | 16 | 1 / 8 / 7 | 2 | brain float(2018). FP32의 dynamic range + 낮은 해상도 — DL 연산의 사실상 표준 |
| FP8 | 8 | 1/4/3 또는 1/5/2 | 1 | 2022, H100~. 매우 거칠다. 안정성 주의 |
torch.tensor([1e-8], dtype=torch.float16) # → tensor([0.]) ← underflow!
torch.tensor([1e-8], dtype=torch.bfloat16) # → 0이 아님 ← dynamic range 보존
BF16의 통찰은 “딥러닝은 해상도보다 dynamic range가 중요하다”입니다. 지수 비트를 FP32만큼 유지하고 가수를 깎아, 같은 2바이트로 안정성을 챙깁니다. 다만 파라미터와 옵티마이저 상태는 FP32로 둬야 학습이 무너지지 않습니다 — BF16은 순전파처럼 잠깐 쓰고 버리는 전이용(transitory) 표현으로 봅니다(→ 혼합 정밀도, mixed precision).
파라미터당 16바이트 (AdamW)
학습 중 한 파라미터가 끌고 다니는 메모리를 세어 보면, 냅킨 ②의 16이 나옵니다(전부 FP32 기준).
| 항목 | 바이트/파라미터 |
|---|---|
| 파라미터(weights) | 4 |
| 그래디언트(gradients) | 4 |
| 옵티마이저 상태 m (1차 모멘트) | 4 |
| 옵티마이저 상태 v (2차 모멘트) | 4 |
| 합계 | 16 |
여기에 활성화(activations)가 더해지는데, 이건 배치 크기·시퀀스 길이에 따라 달라져 위 16에는 빠져 있습니다(그래서 냅킨 ②가 “거친 추정”인 이유).
텐서의 내부: 스토리지와 뷰
회계를 정확히 하려면 “언제 복사가 일어나는가”를 알아야 합니다. PyTorch 텐서는 수학적 객체가 아니라 할당된 메모리(storage)를 가리키는 포인터 + 메타데이터(stride)입니다.
x = torch.arange(6).reshape(2, 3) # [[0,1,2],[3,4,5]]
x.stride() # (3, 1) — 다음 행으로 가려면 3칸, 다음 열로 가려면 1칸
# 원소 [1,2](값 5)의 storage 오프셋 = 1*3 + 2*1 = 5
y = x[0] # 첫 행 — 복사 아님! 같은 storage를 보는 뷰(view)
x[0, 0] = 99 # x를 바꾸면 y도 바뀐다 (포인터 공유)
x.t().is_contiguous() # False — 전치는 storage를 건드리지 않고 stride만 바꾼다
x.t().contiguous() # 여기서 비로소 복사가 일어난다
슬라이싱·view·transpose는 뷰라서 메모리를 새로 쓰지 않습니다 — 공짜이니 마음껏 변수로 쪼개 가독성을 높이세요. 다만 contiguous·reshape는 복사를 일으킬 수 있으니 회계할 때 주의합니다.
실전 팁: 차원을
-1·-2인덱스로 다루면 버그가 잦습니다. 강의는 einops(einsum/rearrange/reduce)로 차원에 이름을 붙이길 권합니다 — “행렬곱에 좋은 부기(bookkeeping)를 더한 것”입니다.
연산 회계: 행렬곱은 2mnp
FLOPs vs FLOP/s
먼저 용어를 분리합니다. FLOPs(소문자 s)는 수행한 부동소수점 연산의 개수(연산량)이고, FLOP/s는 초당 연산 수(하드웨어 속도)입니다. 이 글은 속도를 항상 /s로 적습니다.
딥러닝 연산량은 거의 전부 행렬곱이 지배합니다. (m×n)·(n×p) 행렬곱의 FLOPs는:
행렬곱 FLOPs = 2 × m × n × p (곱셈 1 + 덧셈 1, 그래서 ×2)
# 선형 모델: 데이터 X(B×D) · 가중치 W(D×K) → (B×K)
def matmul_flops(B, D, K):
return 2 * B * D * K # 2 × (세 차원의 곱)
이 식을 머신러닝 언어로 옮기면 강의의 첫 통찰이 나옵니다. B는 데이터 포인트(토큰) 수, D×K는 파라미터 수이므로:
순전파 FLOPs ≈ 2 × (토큰 수 N) × (파라미터 수 D) = 2ND
(시퀀스 길이가 너무 길지 않다면 트랜스포머에도 대략 성립합니다.) 다른 연산들은 텐서 크기에 선형이라, 충분히 큰 모델에선 행렬곱 외에는 무시할 만합니다 — 냅킨 계산이 단순해지는 이유입니다.
MFU: 하드웨어를 얼마나 짜냈나
연산량(FLOPs)을 시간으로 나누면 실측 속도(FLOP/s)가 나오고, 이를 카탈로그상 약속 속도와 비교한 것이 MFU(Model FLOPs Utilization)입니다.
MFU = (모델에 유용한 실측 FLOP/s) ÷ (하드웨어 약속 FLOP/s)
def mfu(flops, seconds, promised_flops_per_s):
return (flops / seconds) / promised_flops_per_s
# 예: H100 dense BF16 ≈ 9.9e14 FLOP/s
mfu(flops=2*1024*51200*51200, seconds=0.03, promised_flops_per_s=9.9e14) # ≈ 0.18
0.5 이상이면 좋고, 0.05면 형편없는 수준입니다. 통신·오버헤드를 빼고 순수 연산만 보기 때문에 100%엔 닿지 못합니다. 주의할 함정 둘:
- 데이터 타입에 따라 약속 FLOP/s가 다르다. H100에서 FP32는 BF16/FP8보다 몇 배 느립니다. 빠른 행렬곱은 텐서 코어(Tensor Core) 위에서 돕니다(PyTorch·
torch.compile이 알아서 씀). - 희소성(sparsity) 별표. 카탈로그의 큰 숫자(예: H100 1979 TFLOP/s)는 구조적 2:4 희소성 가정입니다. 밀집 행렬에선 정확히 절반(~990 TFLOP/s)만 나옵니다 — 강의의 표현으로는 “마케팅 부서가 쓰는 숫자”.
역전파의 비용과 6ND
순전파만 셌습니다. 학습은 역전파(backward)가 더해집니다. 2층 선형망에서 한 가중치 W에 대해 역전파가 하는 일은 두 가지입니다.
- 가중치 그래디언트
dL/dW계산 → 행렬곱 한 번(2mnp) - 입력(활성화) 그래디언트
dL/dh계산 → 역전파를 더 흘려보내기 위해 또 한 번(2mnp)
flowchart LR
subgraph FWD["순전파 (× 2ND)"]
X["X"] -->|"·W₁"| H["h"] -->|"·W₂"| Y["ŷ"] --> L["loss"]
end
subgraph BWD["역전파 (× 4ND)"]
L2["dL"] -->|"dL/dW₂ : 2mnp"| W2g["W₂ grad"]
L2 -->|"dL/dh : 2mnp"| Hg["h grad"]
Hg -->|"계속 전파"| W1g["W₁ grad"]
end
L -.-> L2
가중치마다 2번의 행렬곱이 필요하므로:
역전파 FLOPs ≈ 4 × N × D = 4ND (순전파의 2배)
둘을 합치면 학습 1스텝의 총 연산량, 냅킨 ①의 6이 나옵니다.
학습 FLOPs ≈ 순전파(2ND) + 역전파(4ND) = 6 × 파라미터 × 토큰 = 6ND
이것이 강의 맨 앞 “총 FLOPs = 6 × 파라미터 × 토큰”의 정체입니다. (참고: GPT-3 ≈ 3.1e23 FLOPs, GPT-4 ≈ 2e25 FLOPs(추정).)
전체 메모리를 한 장에
이제 메모리도 전부 모읍니다. num_layers개의 D×D 선형층 + 헤드를 가진 모델이라면:
def total_memory_bytes(D, num_layers, B, bytes_per_value=4):
n_params = num_layers * D * D + D # 가중치
n_gradients = n_params # 그래디언트(파라미터와 동수)
n_optim = 2 * n_params # Adam: m, v 두 벌
n_activations = B * D * num_layers # 활성화(배치·시퀀스 의존)
total = n_params + n_gradients + n_optim + n_activations
return total * bytes_per_value
파라미터 + 그래디언트 + 옵티마이저 상태 + 활성화 — 이 네 가지가 학습 메모리의 전부입니다. 트랜스포머는 항이 더 많아질 뿐 형태는 같습니다.
활성화를 왜 저장할까요? i층의 그래디언트가 그 층의 활성화에 의존하기 때문입니다. 메모리가 빠듯하면 활성화 체크포인팅(activation checkpointing) — 저장 대신 역전파 때 재계산 — 으로 메모리를 연산과 맞바꿉니다(뒤 병렬화 강의에서 다룹니다).
냅킨 계산 다시 풀기
이제 두 질문을 직접 풉니다.
# ① 학습 시간 — 70B 모델, 15T 토큰, H100 1024장
total_flops = 6 * 70e9 * 15e12 # 6ND = 6.3e24
h100_bf16 = 9.9e14 # dense BF16 FLOP/s (희소성 별표 제거)
mfu = 0.5
flops_per_day = h100_bf16 * mfu * 1024 * 86400 # ≈ 4.38e22 FLOP/day
days = total_flops / flops_per_day # ≈ 144일
# ② 최대 모델 크기 — H100 8장, AdamW
hbm_bytes = 8 * 80e9 # 640 GB
bytes_per_param = 16 # 4+4+4+4 (FP32, 활성화 제외)
max_params = hbm_bytes / bytes_per_param # ≈ 40e9 = 40B
두 답 모두 앞에서 유도한 6ND와 16바이트/파라미터가 핵심입니다. 이것이 강의가 심으려는 마인드셋입니다 — 모델을 돌리기 전에 비용을 숫자로 먼저 안다.
모델·옵티마이저·학습 루프
회계 외에 PyTorch 프리미티브도 짚고 갑니다(과제 1의 토대).
- 초기화.
randn을 그대로 쓰면 출력이√(hidden_dim)에 비례해 커져 학습이 불안정해집니다.1/√(입력 차원)로 스케일하면 출력이N(0,1)근처로 안정됩니다 — 상수배까지 Xavier 초기화이며, 보통±3으로 truncate해 꼬리값을 막습니다. - 재현성. 초기화·드롭아웃·데이터 순서마다 랜덤 시드를 고정하세요. 소스별로 다른 시드를 주면 “초기화만 고정, 데이터는 변화” 같은 디버깅이 가능합니다.
- 데이터 로딩. 토큰은 정수 열이라 numpy 배열로 직렬화합니다. Llama 데이터(2.8 TB)를 통째로 못 올리니
np.memmap으로 필요할 때만 읽어 배치를 샘플링합니다. - 옵티마이저. SGD → 모멘텀 → AdaGrad → RMSProp → Adam(모멘텀 + RMSProp, 2014)으로 이어집니다.
Optimizer를 상속해step()에서 파라미터별state(예:m,v)를 갱신합니다. - 체크포인팅. 학습은 길고 언젠가 크래시합니다. 모델 + 옵티마이저 상태 + 현재 스텝을 주기적으로 저장하세요.
# 전형적인 학습 루프 — 회계의 대상이 된 그 루프
for step in range(num_steps):
x, y = get_batch(data, batch_size)
loss = model(x).mse_loss(y) # 순전파 (2ND)
loss.backward() # 역전파 (4ND)
optimizer.step() # 파라미터 갱신 (옵티마이저 상태 사용)
optimizer.zero_grad()
성능·복잡도 노트
- 행렬곱이 연산을 지배한다.
2mnp만 세면 모델 연산의 대부분이 잡힙니다. 행렬이 너무 작으면 다른 연산이 비중을 키우지만, 그건 하드웨어를 못 쓰는 나쁜 영역입니다. - 6ND와 16바이트가 두 개의 북극성. 학습 연산은
6 × 파라미터 × 토큰, AdamW 학습 메모리는16바이트 × 파라미터(+활성화). 이 둘로 시간·비용·하드웨어 한계를 즉석에서 추정합니다. - MFU로 낭비를 잡는다. 0.5 미만이면 어딘가 새고 있는 것. 단, 약속 FLOP/s는 데이터 타입·희소성 별표에 휘둘리니 항상 직접 벤치마크하세요.
- 정밀도는 시스템과 모델 설계가 만나는 지점. FP32는 안전하되 무겁고, BF16/FP8은 빠르되 불안정합니다. 추론(inference)에서는 학습보다 훨씬 공격적인 양자화가 통합니다.
요약
- 2강의 주제는 자원 회계 — 모델을 느낌이 아니라 메모리와 FLOPs라는 숫자로 다루는 마인드셋입니다.
- 메모리: 텐서 = 값 개수 × 바이트. 타입(FP32/BF16/FP8)이 첫 레버이고, 학습 메모리는 파라미터 + 그래디언트 + 옵티마이저 + 활성화, AdamW 기준 16바이트/파라미터(+활성화).
- 연산: 행렬곱 =
2mnp. 순전파2ND+ 역전파4ND= 학습 6ND. 효율은 MFU(실측 ÷ 약속 FLOP/s, 0.5↑ 양호)로 측정합니다. - 이 6과 16 두 숫자로, “며칠 걸리나·얼마나 큰 모델이 들어가나”를 냅킨 한 장에 답합니다.
- 텐서 뷰(공짜) vs
contiguous(복사),1/√d초기화,memmap데이터 로딩, 모델+옵티마이저+스텝 체크포인팅이 과제 1의 실전 토대입니다.
다음 학습 (Next Learning)
- 3단계: 아키텍처와 하이퍼파라미터 — Pre-norm·RMSNorm·SwiGLU·RoPE 등 현대 트랜스포머의 표준 설계 (상세 포스트 작성 예정)
- CS336 1강 — 개요와 토크나이제이션 — 코스의 “효율” 주제가 시작된 곳
- CS336 커리큘럼 — 전체 17단계 지도와 진행 현황