CS336 2강 — PyTorch와 자원 회계: 6ND와 메모리를 냅킨에 계산하기

CS336-LLM-From-Scratch 시리즈의 2단계입니다. 전체 지도는 CS336 커리큘럼에서 볼 수 있습니다. (1강 — 토크나이제이션에서 이어집니다.)

자원 회계 — 모델을 메모리(바이트)와 연산(FLOPs) 두 통화로 저울질하고, 둘 다 달러로 환산한다 자원 회계 — 냅킨 한 장 FLOPs · 바이트 = $ 메모리 B B 바이트 연산 FLOPs 16 B/param 6ND 효율 = 달러 바이트를 줄이면 ↓ FLOPs를 짜내면 ↑
이 글의 척추: 자원 회계는 모델을 두 통화 — 메모리(바이트)와 연산(FLOPs) — 로 저울에 단다. 받침에 새겨진 두 숫자가 이 강의의 북극성, 학습 메모리 ‘16바이트/파라미터’와 학습 연산 ‘6ND’다. 그리고 두 통화는 결국 하나의 단위, 달러로 환산된다.

1강이 “효율”이라는 코스의 주제를 선언했다면, 2강은 그 효율을 측정하는 법을 가르칩니다. 보통 우리는 모델을 짜고, 돌리고, 되는 대로 둡니다. 하지만 학습이 수억 달러로 번지는 순간 그 태도는 통하지 않습니다 — FLOPs와 메모리는 곧 달러이기 때문입니다. 이 강의(Percy Liang)는 텐서에서 시작해 학습 루프까지 PyTorch 프리미티브를 훑되, 매 단계에서 “이건 메모리를 얼마, 연산을 얼마 쓰는가”를 끈질기게 따집니다. 트랜스포머는 다루지 않습니다 — 선형 모델만으로도 핵심 회계는 전부 나옵니다.

한눈에 보기

자원 회계는 두 축으로 갈라집니다 — 메모리(무엇을 얼마나 저장하나)연산(FLOPs를 얼마나 쓰나). 이 둘을 알면 “이 학습이 며칠 걸리고, 이 GPU에 얼마나 큰 모델이 들어가나”라는 실전 질문에 냅킨 한 장으로 답할 수 있습니다.

flowchart TD
    T["텐서<br/>(원자: 값 개수 × 바이트)"]

    T --> MEM["메모리 회계"]
    T --> CMP["연산 회계"]

    MEM --> M1["부동소수점 타입<br/>FP32·BF16·FP8"]
    MEM --> M2["파라미터+그래디언트<br/>+옵티마이저+활성화"]

    CMP --> C1["행렬곱 = 2mnp"]
    CMP --> C2["순전파 2ND · 역전파 4ND"]

    M2 --> Q2["냅킨 ②<br/>이 GPU에 들어가는<br/>최대 모델 크기"]
    C2 --> SixND["학습 ≈ 6ND<br/>(6 × 파라미터 × 토큰)"]
    SixND --> Q1["냅킨 ①<br/>학습에 걸리는 시간"]
    SixND --> MFU["MFU<br/>실측 ÷ 약속 FLOP/s"]

이 강의가 가르치는 두 개의 숫자는 결국 이것입니다 — 학습 연산량 ≈ 6ND, 그리고 AdamW 학습 메모리 ≈ 16바이트/파라미터.

냅킨 계산: 두 개의 질문

강의는 두 질문으로 문을 엽니다. 지금은 답만 보고, 아래에서 직접 유도합니다.

  1. 70B 밀집(dense) 트랜스포머를 15T 토큰으로, H100 1,024장에서 학습하면 며칠 걸리나? → 총 FLOPs = 6 × 70e9 × 15e12. 이를 (H100 FLOP/s × MFU 0.5 × 1024장 × 하루 초)로 나누면 약 144일.
  2. H100 8장에서 AdamW로 (특별한 기교 없이) 학습 가능한 최대 모델은? → 메모리 8 × 80GB = 640GB파라미터당 16바이트로 나누면 약 40B 파라미터(활성화 무시한 거친 추정).

두 답 모두 곱셈·나눗셈 몇 번이 전부입니다. 핵심은 그 안에 들어가는 616이 어디서 오는가입니다.

메모리 회계

텐서와 바이트

딥러닝의 모든 것 — 파라미터·그래디언트·옵티마이저 상태·활성화·데이터 — 은 텐서에 담깁니다. 텐서의 메모리는 단순합니다.

메모리 = (원소 개수) × (원소당 바이트 수)

import torch
x = torch.zeros(4, 8)        # 기본 dtype: float32
x.numel()                    # 32 (원소 개수)
x.element_size()             # 4 (바이트) — float32 = 32비트 = 4바이트
x.numel() * x.element_size() # 128 바이트

행렬 하나가 작아 보여도, GPT-3의 한 FFN 행렬은 2.3 GB에 이릅니다. 그래서 원소당 바이트를 줄이는 것 — 부동소수점 타입 선택 — 이 첫 번째 효율 레버입니다.

부동소수점 타입

타입 비트 부호/지수/가수 바이트 특징
FP32 32 1 / 8 / 23 4 기본값, “full precision”. 안전하지만 무겁다
FP16 16 1 / 5 / 10 2 half precision. dynamic range가 좁아 작은 값이 underflow(0으로)
BF16 16 1 / 8 / 7 2 brain float(2018). FP32의 dynamic range + 낮은 해상도 — DL 연산의 사실상 표준
FP8 8 1/4/3 또는 1/5/2 1 2022, H100~. 매우 거칠다. 안정성 주의
torch.tensor([1e-8], dtype=torch.float16)   # → tensor([0.])  ← underflow!
torch.tensor([1e-8], dtype=torch.bfloat16)  # → 0이 아님       ← dynamic range 보존
FP32 / FP16 / BF16 / FP8 비트 배치 — 부호 1 · 지수(dynamic range) · 가수(해상도) 부호 1비트 지수 — dynamic range 가수 — 해상도 FP32 4바이트 지수 8 가수 23 FP16 2바이트 지수 5 가수 10 지수 좁음 → underflow BF16 2바이트 지수 8 가수 7 지수는 FP32와 동일 → range 보존 FP8 1바이트 지수 4 가수 3 E4M3 예 (매우 거칠다) 지수 비트 ↑ = 표현 가능한 범위(dynamic range) ↑ — 작은 값이 0으로 underflow하지 않는다 가수 비트 ↑ = 두 수 사이의 해상도(정밀도) ↑ — 같은 범위 안을 더 촘촘히 쪼갠다
같은 척도 위의 네 타입: 막대 길이가 곧 비트 수다. 핵심은 BF16이 FP32의 지수 8비트를 그대로 물려받는다는 점(점선) — FP16보다 가수(해상도)는 떨어지지만 dynamic range는 FP32만큼 넓어, 작은 값이 0으로 underflow하지 않는다. 딥러닝이 해상도보다 range를 택하는 이유가 한눈에 보인다.

BF16의 통찰은 “딥러닝은 해상도보다 dynamic range가 중요하다”입니다. 지수 비트를 FP32만큼 유지하고 가수를 깎아, 같은 2바이트로 안정성을 챙깁니다. 다만 파라미터와 옵티마이저 상태는 FP32로 둬야 학습이 무너지지 않습니다 — BF16은 순전파처럼 잠깐 쓰고 버리는 전이용(transitory) 표현으로 봅니다(→ 혼합 정밀도, mixed precision).

파라미터당 16바이트 (AdamW)

학습 중 한 파라미터가 끌고 다니는 메모리를 세어 보면, 냅킨 ②의 16이 나옵니다(전부 FP32 기준).

항목 바이트/파라미터
파라미터(weights) 4
그래디언트(gradients) 4
옵티마이저 상태 m (1차 모멘트) 4
옵티마이저 상태 v (2차 모멘트) 4
합계 16

여기에 활성화(activations)가 더해지는데, 이건 배치 크기·시퀀스 길이에 따라 달라져 위 16에는 빠져 있습니다(그래서 냅킨 ②가 “거친 추정”인 이유).

파라미터당 16바이트 = 파라미터 4 + 그래디언트 4 + m 4 + v 4 (FP32). 활성화는 별도 가변. 0 4 8 12 16 파라미터 4 그래디언트 4 옵티마이저 m 4 옵티마이저 v 4 = 16 바이트 / 파라미터 (모두 FP32) 활성화 가변 배치 · 시퀀스 길이 의존 → 16에 포함되지 않음
학습 메모리의 두 번째 북극성: 파라미터 하나가 끌고 다니는 4가지 — 파라미터·그래디언트·옵티마이저 m·옵티마이저 v — 가 각 4바이트씩 쌓여 16바이트/파라미터가 된다(모두 FP32). 활성화는 배치·시퀀스 길이에 따라 늘었다 줄었다 하는 별도 가변 블록이라 이 16에는 빠져 있다(그래서 냅킨 ②가 ‘거친 추정’인 이유).

텐서의 내부: 스토리지와 뷰

회계를 정확히 하려면 “언제 복사가 일어나는가”를 알아야 합니다. PyTorch 텐서는 수학적 객체가 아니라 할당된 메모리(storage)를 가리키는 포인터 + 메타데이터(stride)입니다.

x = torch.arange(6).reshape(2, 3)   # [[0,1,2],[3,4,5]]
x.stride()        # (3, 1) — 다음 행으로 가려면 3칸, 다음 열로 가려면 1칸
# 원소 [1,2](값 5)의 storage 오프셋 = 1*3 + 2*1 = 5

y = x[0]          # 첫 행 — 복사 아님! 같은 storage를 보는 뷰(view)
x[0, 0] = 99      # x를 바꾸면 y도 바뀐다 (포인터 공유)

x.t().is_contiguous()   # False — 전치는 storage를 건드리지 않고 stride만 바꾼다
x.t().contiguous()      # 여기서 비로소 복사가 일어난다

슬라이싱·view·transpose라서 메모리를 새로 쓰지 않습니다 — 공짜이니 마음껏 변수로 쪼개 가독성을 높이세요. 다만 contiguous·reshape복사를 일으킬 수 있으니 회계할 때 주의합니다.

실전 팁: 차원을 -1·-2 인덱스로 다루면 버그가 잦습니다. 강의는 einops(einsum/rearrange/reduce)로 차원에 이름을 붙이길 권합니다 — “행렬곱에 좋은 부기(bookkeeping)를 더한 것”입니다.

연산 회계: 행렬곱은 2mnp

FLOPs vs FLOP/s

먼저 용어를 분리합니다. FLOPs(소문자 s)는 수행한 부동소수점 연산의 개수(연산량)이고, FLOP/s초당 연산 수(하드웨어 속도)입니다. 이 글은 속도를 항상 /s로 적습니다.

딥러닝 연산량은 거의 전부 행렬곱이 지배합니다. (m×n)·(n×p) 행렬곱의 FLOPs는:

행렬곱 FLOPs = 2 × m × n × p (곱셈 1 + 덧셈 1, 그래서 ×2)

# 선형 모델: 데이터 X(B×D) · 가중치 W(D×K) → (B×K)
def matmul_flops(B, D, K):
    return 2 * B * D * K        # 2 × (세 차원의 곱)

이 식을 머신러닝 언어로 옮기면 강의의 첫 통찰이 나옵니다. B는 데이터 포인트(토큰) 수, D×K는 파라미터 수이므로:

순전파 FLOPs ≈ 2 × (토큰 수 N) × (파라미터 수 D) = 2ND

(시퀀스 길이가 너무 길지 않다면 트랜스포머에도 대략 성립합니다.) 다른 연산들은 텐서 크기에 선형이라, 충분히 큰 모델에선 행렬곱 외에는 무시할 만합니다 — 냅킨 계산이 단순해지는 이유입니다.

MFU: 하드웨어를 얼마나 짜냈나

연산량(FLOPs)을 시간으로 나누면 실측 속도(FLOP/s)가 나오고, 이를 카탈로그상 약속 속도와 비교한 것이 MFU(Model FLOPs Utilization)입니다.

MFU = (모델에 유용한 실측 FLOP/s) ÷ (하드웨어 약속 FLOP/s)

def mfu(flops, seconds, promised_flops_per_s):
    return (flops / seconds) / promised_flops_per_s

# 예: H100 dense BF16 ≈ 9.9e14 FLOP/s
mfu(flops=2*1024*51200*51200, seconds=0.03, promised_flops_per_s=9.9e14)  # ≈ 0.18

0.5 이상이면 좋고, 0.05면 형편없는 수준입니다. 통신·오버헤드를 빼고 순수 연산만 보기 때문에 100%엔 닿지 못합니다. 주의할 함정 둘:

  • 데이터 타입에 따라 약속 FLOP/s가 다르다. H100에서 FP32는 BF16/FP8보다 몇 배 느립니다. 빠른 행렬곱은 텐서 코어(Tensor Core) 위에서 돕니다(PyTorch·torch.compile이 알아서 씀).
  • 희소성(sparsity) 별표. 카탈로그의 큰 숫자(예: H100 1979 TFLOP/s)는 구조적 2:4 희소성 가정입니다. 밀집 행렬에선 정확히 절반(~990 TFLOP/s)만 나옵니다 — 강의의 표현으로는 “마케팅 부서가 쓰는 숫자”.

역전파의 비용과 6ND

순전파만 셌습니다. 학습은 역전파(backward)가 더해집니다. 2층 선형망에서 한 가중치 W에 대해 역전파가 하는 일은 두 가지입니다.

  1. 가중치 그래디언트 dL/dW 계산 → 행렬곱 한 번(2mnp)
  2. 입력(활성화) 그래디언트 dL/dh 계산 → 역전파를 더 흘려보내기 위해 또 한 번(2mnp)
flowchart LR
    subgraph FWD["순전파 (× 2ND)"]
        X["X"] -->|"·W₁"| H["h"] -->|"·W₂"| Y["ŷ"] --> L["loss"]
    end
    subgraph BWD["역전파 (× 4ND)"]
        L2["dL"] -->|"dL/dW₂ : 2mnp"| W2g["W₂ grad"]
        L2 -->|"dL/dh : 2mnp"| Hg["h grad"]
        Hg -->|"계속 전파"| W1g["W₁ grad"]
    end
    L -.-> L2

가중치마다 2번의 행렬곱이 필요하므로:

역전파 FLOPs ≈ 4 × N × D = 4ND (순전파의 2배)

둘을 합치면 학습 1스텝의 총 연산량, 냅킨 ①의 6이 나옵니다.

학습 FLOPs ≈ 순전파(2ND) + 역전파(4ND) = 6 × 파라미터 × 토큰 = 6ND

이것이 강의 맨 앞 “총 FLOPs = 6 × 파라미터 × 토큰”의 정체입니다. (참고: GPT-3 ≈ 3.1e23 FLOPs, GPT-4 ≈ 2e25 FLOPs(추정).)

전체 메모리를 한 장에

이제 메모리도 전부 모읍니다. num_layers개의 D×D 선형층 + 헤드를 가진 모델이라면:

def total_memory_bytes(D, num_layers, B, bytes_per_value=4):
    n_params     = num_layers * D * D + D          # 가중치
    n_gradients  = n_params                         # 그래디언트(파라미터와 동수)
    n_optim      = 2 * n_params                     # Adam: m, v 두 벌
    n_activations = B * D * num_layers              # 활성화(배치·시퀀스 의존)
    total = n_params + n_gradients + n_optim + n_activations
    return total * bytes_per_value

파라미터 + 그래디언트 + 옵티마이저 상태 + 활성화 — 이 네 가지가 학습 메모리의 전부입니다. 트랜스포머는 항이 더 많아질 뿐 형태는 같습니다.

활성화를 저장할까요? i층의 그래디언트가 그 층의 활성화에 의존하기 때문입니다. 메모리가 빠듯하면 활성화 체크포인팅(activation checkpointing) — 저장 대신 역전파 때 재계산 — 으로 메모리를 연산과 맞바꿉니다(뒤 병렬화 강의에서 다룹니다).

냅킨 계산 다시 풀기

이제 두 질문을 직접 풉니다.

# ① 학습 시간 — 70B 모델, 15T 토큰, H100 1024장
total_flops   = 6 * 70e9 * 15e12                 # 6ND = 6.3e24
h100_bf16     = 9.9e14                            # dense BF16 FLOP/s (희소성 별표 제거)
mfu           = 0.5
flops_per_day = h100_bf16 * mfu * 1024 * 86400    # ≈ 4.38e22 FLOP/day
days = total_flops / flops_per_day                # ≈ 144일

# ② 최대 모델 크기 — H100 8장, AdamW
hbm_bytes        = 8 * 80e9                        # 640 GB
bytes_per_param  = 16                              # 4+4+4+4 (FP32, 활성화 제외)
max_params = hbm_bytes / bytes_per_param           # ≈ 40e9 = 40B

두 답 모두 앞에서 유도한 6ND16바이트/파라미터가 핵심입니다. 이것이 강의가 심으려는 마인드셋입니다 — 모델을 돌리기 전에 비용을 숫자로 먼저 안다.

모델·옵티마이저·학습 루프

회계 외에 PyTorch 프리미티브도 짚고 갑니다(과제 1의 토대).

  • 초기화. randn을 그대로 쓰면 출력이 √(hidden_dim)에 비례해 커져 학습이 불안정해집니다. 1/√(입력 차원)로 스케일하면 출력이 N(0,1) 근처로 안정됩니다 — 상수배까지 Xavier 초기화이며, 보통 ±3으로 truncate해 꼬리값을 막습니다.
  • 재현성. 초기화·드롭아웃·데이터 순서마다 랜덤 시드를 고정하세요. 소스별로 다른 시드를 주면 “초기화만 고정, 데이터는 변화” 같은 디버깅이 가능합니다.
  • 데이터 로딩. 토큰은 정수 열이라 numpy 배열로 직렬화합니다. Llama 데이터(2.8 TB)를 통째로 못 올리니 np.memmap으로 필요할 때만 읽어 배치를 샘플링합니다.
  • 옵티마이저. SGD → 모멘텀 → AdaGrad → RMSProp → Adam(모멘텀 + RMSProp, 2014)으로 이어집니다. Optimizer를 상속해 step()에서 파라미터별 state(예: m, v)를 갱신합니다.
  • 체크포인팅. 학습은 길고 언젠가 크래시합니다. 모델 + 옵티마이저 상태 + 현재 스텝을 주기적으로 저장하세요.
# 전형적인 학습 루프 — 회계의 대상이 된 그 루프
for step in range(num_steps):
    x, y = get_batch(data, batch_size)
    loss = model(x).mse_loss(y)   # 순전파 (2ND)
    loss.backward()                # 역전파 (4ND)
    optimizer.step()               # 파라미터 갱신 (옵티마이저 상태 사용)
    optimizer.zero_grad()

성능·복잡도 노트

  • 행렬곱이 연산을 지배한다. 2mnp만 세면 모델 연산의 대부분이 잡힙니다. 행렬이 너무 작으면 다른 연산이 비중을 키우지만, 그건 하드웨어를 못 쓰는 나쁜 영역입니다.
  • 6ND와 16바이트가 두 개의 북극성. 학습 연산은 6 × 파라미터 × 토큰, AdamW 학습 메모리는 16바이트 × 파라미터(+활성화). 이 둘로 시간·비용·하드웨어 한계를 즉석에서 추정합니다.
  • MFU로 낭비를 잡는다. 0.5 미만이면 어딘가 새고 있는 것. 단, 약속 FLOP/s는 데이터 타입·희소성 별표에 휘둘리니 항상 직접 벤치마크하세요.
  • 정밀도는 시스템과 모델 설계가 만나는 지점. FP32는 안전하되 무겁고, BF16/FP8은 빠르되 불안정합니다. 추론(inference)에서는 학습보다 훨씬 공격적인 양자화가 통합니다.

요약

  • 2강의 주제는 자원 회계 — 모델을 느낌이 아니라 메모리와 FLOPs라는 숫자로 다루는 마인드셋입니다.
  • 메모리: 텐서 = 값 개수 × 바이트. 타입(FP32/BF16/FP8)이 첫 레버이고, 학습 메모리는 파라미터 + 그래디언트 + 옵티마이저 + 활성화, AdamW 기준 16바이트/파라미터(+활성화).
  • 연산: 행렬곱 = 2mnp. 순전파 2ND + 역전파 4ND = 학습 6ND. 효율은 MFU(실측 ÷ 약속 FLOP/s, 0.5↑ 양호)로 측정합니다.
  • 616 두 숫자로, “며칠 걸리나·얼마나 큰 모델이 들어가나”를 냅킨 한 장에 답합니다.
  • 텐서 뷰(공짜) vs contiguous(복사), 1/√d 초기화, memmap 데이터 로딩, 모델+옵티마이저+스텝 체크포인팅이 과제 1의 실전 토대입니다.

다음 학습 (Next Learning)

  • 3단계: 아키텍처와 하이퍼파라미터 — Pre-norm·RMSNorm·SwiGLU·RoPE 등 현대 트랜스포머의 표준 설계 (상세 포스트 작성 예정)
  • CS336 1강 — 개요와 토크나이제이션 — 코스의 “효율” 주제가 시작된 곳
  • CS336 커리큘럼 — 전체 17단계 지도와 진행 현황