플래시어텐션‑4: 비대칭 GPU 스케일링을 위한 알고리즘·커널 파이프라인 공동 설계

플래시어텐션‑4: 비대칭 GPU 스케일링을 위한 알고리즘·커널 파이프라인 공동 설계
안내: 본 포스트의 한글 요약 및 분석 리포트는 AI 기술을 통해 자동 생성되었습니다. 정보의 정확성을 위해 하단의 [원본 논문 뷰어] 또는 ArXiv 원문을 반드시 참조하시기 바랍니다.

초록

플래시어텐션‑4는 NVIDIA Blackwell(B200/GB200) GPU의 텐서코어는 두 배로 빨라지고 메모리·지수 연산은 그대로인 비대칭 하드웨어 특성을 고려해, 비동기 MMA와 대형 타일, 소프트웨어 기반 지수 근사, 텐서 메모리 활용, 2‑CTA MMA 모드 등을 결합한 새로운 파이프라인을 제시한다. BF16 기준 1.3×(cuDNN)·2.7×(Triton) 가속, 71% 활용도에 1613 TFLOPs/s를 달성했으며, CuTe‑DSL 기반 구현으로 컴파일 시간을 20‑30배 단축한다.

상세 분석

본 논문은 최신 Blackwell 아키텍처가 보여주는 “비대칭 하드웨어 스케일링” 현상을 정량적으로 분석하고, 그 결과를 바탕으로 어텐션 연산의 병목이 행렬곱(MMA)에서 공유 메모리와 지수 연산으로 이동했음을 밝힌다. Blackwell SM당 256 KB 텐서 메모리(TMEM)와 128×128(또는 256) 크기의 MMA 타일은 레지스터 압박을 크게 완화하고, MMA 결과를 직접 TMEM에 비동기적으로 기록하게 함으로써 연산‑메모리 겹침을 극대화한다. 저자들은 이를 활용해 두 가지 핵심 설계를 제시한다.

  1. 완전 비동기 파이프라인 및 대형 타일: 기존 FlashAttention‑3가 64×128 타일에 의존했지만, Blackwell에서는 128×128(또는 256) 타일을 사용해 한 번에 더 많은 Q·K·V 데이터를 처리한다. 두 개의 워프그룹이 각각 “생산자”와 “소비자” 역할을 수행해, 하나는 MMA를 수행하고 다른 하나는 소프트맥스와 지수 연산을 동시에 진행한다. 이렇게 하면 MMA가 TMEM에 기록되는 동안 공유 메모리에서 행렬‑벡터 연산과 정규화가 진행돼, 연산 유휴 시간을 최소화한다.

  2. 소프트웨어 기반 지수 근사와 조건부 리스케일링: Blackwell의 MUFU는 초당 16개의 지수 연산만 지원하므로, 저자들은 FMA 유닛을 이용한 다항식 근사를 도입해 지수 함수를 가속한다. 또한, 행별 최대값을 뺀 후 소프트맥스를 계산하는 과정에서 필요 없는 리스케일링을 조건부로 건너뛰는 로직을 삽입해 지수 연산량을 평균 30% 이상 감소시켰다.

  3. TMEM 활용 및 2‑CTA MMA 모드: 역전파 단계에서 dQ·dK·dV를 계산할 때 기존 구현은 공유 메모리와 원자적 add에 크게 의존했지만, FlashAttention‑4는 중간 결과를 TMEM에 저장해 공유 메모리 트래픽을 절반 수준으로 줄인다. 2‑CTA 모드에서는 두 CTA가 협력해 하나의 256×128 MMA를 수행하면서 B 타일을 절반씩 로드하므로, 공유 메모리 사용량과 대역폭 요구가 크게 감소한다. 이와 동시에 원자적 감소 연산을 절반으로 줄여 역전파 성능을 크게 끌어올렸다.

  4. 스케줄링·레지스터 최적화: Blackwell의 레지스터 파일은 256개 제한이지만, TMEM에 누적값을 저장함으로써 레지스터 압박을 완화하고, 각 CTA가 필요로 하는 레지스터 수를 30% 이하로 감소시켰다. 또한, 워프‑스페셜라이즈된 스케줄러를 설계해 MMA, 지수, 공유 메모리 복사를 동시 진행하도록 하여 전체 파이프라인 효율을 71%까지 끌어올렸다.

  5. CuTe‑DSL 기반 구현: 기존 C++ 템플릿 기반 커널은 컴파일 시간이 수십 분에 달했지만, 저자들은 Python에 내장된 CuTe‑DSL을 사용해 커널을 선언형으로 기술했다. DSL은 텐서 코어 명령어와 TMEM 할당을 추상화하면서도 저수준 최적화를 유지한다. 결과적으로 전체 프로젝트의 컴파일 시간이 20‑30배 빨라졌으며, 연구자들이 새로운 어텐션 변형을 빠르게 프로토타이핑할 수 있게 되었다.

실험 결과는 B200 GPU에서 BF16 기준 1.3×(cuDNN 9.13)·2.7×(Triton) 가속을 보이며, 1613 TFLOPs/s(이론 최대 2270 TFLOPs)의 71% 활용도를 달성했다. 특히 시퀀스 길이가 8K 이상인 경우, 기존 FlashAttention‑3 대비 1.8‑2.2×의 속도 향상을 기록했다. 전체 코드는 MIT‑style 라이선스로 공개돼, 향후 다른 라이브러리와의 통합이 용이하도록 설계되었다.


댓글 및 학술 토론

Loading comments...

의견 남기기