CS336 6강 — 커널과 Triton: 측정하고, 퓨즈하라
CS336-LLM-From-Scratch 시리즈의 6단계입니다. 전체 지도는 CS336 커리큘럼에서 볼 수 있습니다. (5강 — GPU에서 이어집니다.)
5강이 “GPU가 어떻게 동작하나”였다면, 6강(Tatsunori Hashimoto)은 “그 위에서 빠른 코드를 어떻게 짜나“입니다. 실전 강의입니다 — 벤치마킹·프로파일링으로 병목을 찾고, GELU 커널을 다섯 가지 방식으로 써 보며 커널 퓨전(kernel fusion)의 효과를 직접 잽니다. 한 문장으로 요약하면 — 추측하지 말고 측정하라, 그리고 퓨즈하라.
한눈에 보기
성능 작업의 흐름은 단순합니다. 거친 도구(벤치마크)로 “느리다”를 확인하고, 정밀한 도구(프로파일러)로 “어디가 느린지”를 찾은 뒤, 그 지점을 퓨즈합니다. 퓨전을 구현하는 길은 넷 — 그 트레이드오프가 이 글의 핵심입니다.
flowchart TD
Bench["① 벤치마크<br/>(거친 측정: 느린가?)"] --> Prof["② 프로파일<br/>(정밀 측정: 어디가?)"]
Prof --> Fuse{"③ 퓨즈할 방법"}
Fuse --> TC["torch.compile<br/>(자동 · 기본값)"]
Fuse --> TR["Triton<br/>(Python · 블록 관점)"]
Fuse --> CU["CUDA C++<br/>(저수준 · 스레드 관점)"]
Fuse --> PT["PyTorch 내장 커널<br/>(이미 퓨즈됨)"]
핵심 교훈은 처음부터 하나입니다 — “여기가 병목일 것 같아”라며 세 시간 최적화하지 말 것. 프로파일러를 켜면 실제 병목이 보이고, 거기에 노력을 쏟으면 됩니다.
먼저 측정하라: 벤치마킹
벤치마킹은 함수의 벽시계 시간(wall-clock time)을 재는 일입니다. 그런데 GPU에선 두 가지를 빠뜨리면 엉터리 숫자(예: 거대한 행렬곱이 “즉시” 끝남)가 나옵니다.
def benchmark(run, warmup=3, trials=10):
for _ in range(warmup): # ① 워밍업: 첫 실행의 컴파일·초기화를 측정에서 제외
run()
torch.cuda.synchronize() # ② CPU/GPU 상태 맞추기 (둘은 비동기로 따로 돈다)
times = []
for _ in range(trials):
t0 = time.time()
run()
torch.cuda.synchronize() # GPU가 끝날 때까지 기다린 뒤에 시간 기록
times.append(time.time() - t0)
return sum(times) / len(times) # 여러 번 재서 평균(발열 등 변동 흡수)
- 워밍업(warm-up). PyTorch 코드를 처음 실행하면 그 자리에서 머신 코드를 컴파일하고 초기화합니다. 그 시작 비용이 아니라 정상 상태(steady state) 속도를 재야 합니다.
torch.cuda.synchronize(). CPU와 GPU는 독립적인 연산 장치라 따로 돕니다. CPU가 CUDA 커널을 GPU에 던져 놓고는 기다리지 않고 앞서 달려갑니다. 동기화하지 않으면 GPU 실행 시간을 재지 못합니다.
행렬곱은 크기에 슈퍼리니어(super-linear)로 느려지지만, 작은 크기에선 커널 런치·CPU→GPU 전송 같은 상수 오버헤드 때문에 시간이 잘 안 줄어듭니다.
어디가 느린가: 프로파일링
벤치마크는 “느리다”만 알려줄 뿐, 어디서 시간을 쓰는지는 모릅니다. 그건 프로파일러의 몫입니다. PyTorch에서 a + b 한 줄을 호출해도, 빙산 아래에는 이런 일들이 벌어집니다.
a + b (Python)
└─ ATen 래퍼 (C++ 인터페이스)
└─ vectorized_elementwise_kernel ← 실제 덧셈 (GPU)
└─ cudaLaunchKernel ← CPU가 명령을 GPU로 전송
└─ cudaDeviceSynchronize ← GPU 완료를 기다림
행렬곱은 CUTLASS/cuBLAS의 특정 타일 크기 커널로 디스패치되고, 크기마다 다른 커널이 불립니다(그래서 torch.compile이 마이크로벤치마크로 최적 커널을 골라 ~10%를 공짜로 줍니다). cdist처럼 복합 연산은 mm(78%)·pow·sum으로 분해돼, 어디를 최적화할지 한눈에 보입니다.
CPU와 GPU는 따로 논다
NVIDIA Nsight Systems 같은 본격 프로파일러로 보면, CPU와 GPU가 별개의 타임라인으로 돕니다. CPU는 커널을 GPU 큐에 던져 놓고 앞서 달려갑니다 — GPU가 1층을 돌 때 CPU는 이미 9층을 큐에 넣고 있습니다(큐 깊이까지).
flowchart LR
subgraph CPU["CPU (앞서 달림)"]
C0["layer 0"] --> C1["layer 1"] --> C9["… layer 9"]
end
subgraph GPU["GPU (큐를 소비)"]
G0["layer 0"] --> G1["layer 1 (실행 중)"]
end
C0 -.커널 큐잉.-> G0
C1 -.큐잉.-> G1
이 비동기성 덕분에 Python의 느림이 병목이 안 됩니다 — CPU는 그저 커널을 큐에 넣으면 되니까요. 단, print(loss)처럼 GPU 결과를 CPU가 기다려야 하는 코드를 끼우면 동기화가 강제돼, 심하면 CPU 병목이 생깁니다. 프로파일러 없이는 보이지 않는 함정입니다.
커널 퓨전: GELU를 다섯 가지로
이제 핵심입니다. 5강의 퓨전(데이터를 글로벌 메모리로 왕복시키지 않고 연산 유닛에 둔 채 처리)을 GELU로 실험합니다. 같은 GELU를 다섯 가지로 구현해 큰 입력에서 시간을 잽니다.
| 구현 | 시간 | GPU 커널 |
|---|---|---|
수동 PyTorch (0.5*x*(1+tanh(...))) |
8.1 ms | 여러 개 (곱 ×3·tanh·add…) |
PyTorch 내장 (F.gelu) |
1.1 ms | 단일 퓨즈드 |
| CUDA C++ (직접 작성) | 1.8 ms | 단일 |
| Triton (직접 작성) | 1.85 ms | 단일 |
| torch.compile | 1.47 ms | 단일 (Triton 자동 생성) |
수동 PyTorch가 8배 느린 이유는 명확합니다 — x³, tanh, 상수곱 하나하나가 별도 CUDA 커널이라, 매번 글로벌 메모리를 왕복합니다. 나머지는 모두 단일 퓨즈드 커널이라 빠릅니다.
CUDA C++: 스레드 관점
가장 저수준. 각 스레드가 원소 하나를 맡고, 자기 전역 좌표를 직접 계산합니다.
// 커널(GPU): 각 스레드가 원소 하나를 처리
__global__ void gelu_kernel(const float* in, float* out, int n) {
int i = blockIdx.x * blockDim.x + threadIdx.x; // 블록 시작 + 블록 내 오프셋 = 전역 좌표
if (i < n) { // 경계 검사 (마지막 블록의 초과 스레드 보호)
float x = in[i];
out[i] = 0.5f * x * (1.0f + tanhf(0.7978845608f * (x + 0.044715f * x*x*x)));
}
}
// 래퍼(CPU): x가 CUDA 텐서·contiguous인지 assert → empty_like로 y 할당
// → grid = ceil(n / block_size) → gelu_kernel<<<grid, block_size>>>(x, y, n)
__global__이 CUDA 커널임을 표시합니다. 래퍼는 입력이 연속(contiguous) 메모리인지 확인하고(인덱싱 산술이 그걸 전제), 출력은 zeros가 아니라 empty_like로 잡아(어차피 덮어쓰니) 한 번의 초기화를 아낍니다. 결과는 8.1 → 1.8 ms — C 코드가 그리 어렵지 않은데도 큰 이득입니다.
Triton: 블록 관점
매번 C++로 내려가는 건 번거롭습니다. Triton(OpenAI, 2021)은 GPU 프로그래밍을 Python으로 끌어올립니다. 핵심 차이 — 스레드가 아니라 블록 관점으로 짜고, 코얼레싱·공유 메모리·스레드 관리를 Triton이 자동 처리합니다(SM 간 스케줄링만 수동).
import triton
import triton.language as tl
@triton.jit
def gelu_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
pid = tl.program_id(0) # 스레드가 아니라 '블록' id
offsets = pid * BLOCK + tl.arange(0, BLOCK) # 좌표가 단일 값이 아니라 '벡터'다
mask = offsets < n # 경계 밖 원소 마스킹
x = tl.load(x_ptr + offsets, mask=mask) # 코얼레싱은 Triton이 알아서
y = 0.5 * x * (1 + tl.math.tanh(0.7978845608 * (x + 0.044715 * x*x*x)))
tl.store(y_ptr + offsets, y, mask=mask)
스레드별 오프셋이 아니라 오프셋 벡터(block_start + arange(BLOCK))로 블록 전체를 한 번에 다룹니다. CUDA와 거의 같은 속도(1.85 ms)지만 Python으로 짜고 디버깅하기 훨씬 쉽고, 컴파일된 PTX(거의 기계어)를 들여다보면 각 스레드가 ld.global로 4개 값을 한 번에 읽는(코얼레싱) 모습이 보입니다.
torch.compile: 그냥 맡기기
대부분의 경우 직접 커널을 쓸 필요조차 없습니다. torch.compile은 평범한 PyTorch 코드를 받아 자동으로 퓨전합니다 — 내부적으로 Triton을 생성하며, 우리가 손으로 짠 것보다 살짝 더 최적화돼 1.47 ms가 나옵니다.
리덕션: Triton softmax
지금까지는 원소별(elementwise) 연산이라 쉬웠습니다. softmax는 행 전체를 더하는 리덕션(reduction)이 들어갑니다. 영리한 블록 설계 — 블록 하나 = 행 하나(grid = 행 수, block_size = 열 수의 다음 2의 거듭제곱). 한 행이 SM에 통째로 들어가면, SM 안에서 행을 더하고 나누면 끝입니다.
@triton.jit
def softmax_kernel(x_ptr, y_ptr, n_cols, stride, BLOCK: tl.constexpr):
row = tl.program_id(0) # 블록 하나 = 행 하나
cols = tl.arange(0, BLOCK)
mask = cols < n_cols
x = tl.load(x_ptr + row * stride + cols, mask=mask, other=-float("inf"))
x = x - tl.max(x, axis=0) # 수치 안정
e = tl.exp(x)
y = e / tl.sum(e, axis=0) # 행 정규화
tl.store(y_ptr + row * stride + cols, y, mask=mask)
연산이 SM에 깔끔히 들어가면, Triton 코드는 보통의 Python처럼 보입니다 — 약간의 load/store와 블록 좌표 계산만 더해질 뿐입니다.
언제 무엇을 쓰나
- 기본값은
torch.compile. 현대 JIT 컴파일러는 연산자 퓨전과 행렬곱 최적화(모양을 알면 최적 커널 선택)를 아주 잘합니다. 웬만하면 이걸 넘기 어렵습니다. - Triton은 비자명한 부분에. 컴파일러가 못 찾는 최적화 — 예컨대 Flash Attention(2·3) 같은 — 이 필요할 때 꺼냅니다. 모델의 모든 부분에 CUDA 커널을 짜는 건 시간 낭비입니다.
- CUDA C++는 최후의 수단. Triton이 못 주는 하드웨어 수준 제어(예: H100 전용 Flash Attention 3 최적화)가 필요할 때.
성능·복잡도 노트
- 측정이 1순위. “여기가 병목 같다”는 직관은 자주 틀립니다. 벤치마크(거친)와 프로파일러(정밀)로 실제 병목을 찾고 거기에만 노력을 쏟으세요.
- 퓨전이 핵심 무기. 여러 작은 연산을 한 커널로 합쳐 글로벌 메모리 왕복을 없애는 것 — GELU에서 8.1 ms → ~1.1–1.85 ms.
- 벤치마킹 함정 둘. 워밍업과
cuda.synchronize를 빠뜨리면 숫자가 거짓말을 합니다. - CPU/GPU 비동기. CPU가 커널을 큐에 넣고 앞서 달리는 덕에 Python이 병목이 안 됩니다. 단
print(loss)류의 동기화는 그 이점을 깹니다.
요약
- 6강은 GPU 코드를 빠르게 만드는 실전 — 측정하고, 퓨즈하라.
- 벤치마킹: 벽시계 시간. 워밍업 +
torch.cuda.synchronize()필수(CPU/GPU는 비동기). - 프로파일링: 빙산 아래(ATen → 커널 → 런치 → 싱크)를 본다. Nsight로 CPU가 GPU보다 앞서 달리는 모습까지.
print(loss)는 동기화를 강제. - GELU 다섯 가지: 수동 PyTorch(8.1 ms) ≫ 내장·CUDA·Triton·torch.compile(~1.1–1.85 ms). 차이는 커널 퓨전.
- Triton: Python으로, 블록 관점, 코얼레싱·공유 메모리 자동. 리덕션(softmax)은 블록=행으로.
- 선택: 기본
torch.compile→ 비자명하면 Triton(Flash Attention) → 최후에 CUDA C++.
다음 학습 (Next Learning)
- 7단계: 병렬화 1 — 데이터 병렬 — 한 GPU를 넘어 여러 장치로, 집합 통신과 ZeRO/FSDP (상세 포스트 작성 예정)
- CS336 5강 — GPU: 병목은 연산이 아니라 메모리다 — 퓨전·타일링·Flash Attention의 토대
- CS336 커리큘럼 — 전체 17단계 지도와 진행 현황