FP8 기반 오자키‑II 스킴으로 구현하는 고정밀 DGEMM 가속화
초록
본 논문은 최신 GPU에서 감소된 INT8 성능을 보완하기 위해 FP8 연산 유닛을 활용한 Ozaki‑II 스킴 기반 FP64 행렬 곱셈 에뮬레이션 방법을 제안한다. Karatsuba 확장과 모듈러 축소 기법을 결합한 하이브리드 방식을 도입해 필요한 FP8 매트릭스 곱셈 횟수를 크게 줄이고, 13~14개의 모듈러만으로도 FP64 수준의 정확도를 달성한다. 또한 성능 모델, 메모리 사용량 분석 및 실험 결과를 제공하고, NVIDIA·AMD GPU용 오픈소스 라이브러리를 공개한다.
상세 분석
Ozaki‑II 스킴은 정수 행렬을 여러 개의 서로소 모듈러(p₁…p_N)로 나누어 CRT(중국 나머지 정리)를 이용해 정확한 결과를 복원하는 방식이다. 기존 INT8 기반 구현은 pₗ≤256 범위에서 정수 행렬을 그대로 INT8 형식으로 표현하고, INT8‑MMA가 INT32 누산을 수행함으로써 연산 오차 없이 진행할 수 있었다. 그러나 FP8(E4M3) 포맷은 표현 가능한 정수 범위가 –16~16에 제한되므로 pₗ≤32만 사용 가능하고, 이 경우 P/2<2⁴⁷에 머물러 FP64를 에뮬레이션하기엔 동적 범위가 부족했다.
이를 극복하기 위해 저자들은 Karatsuba 알고리즘을 차용해 각 모듈러 행렬을 두 개의 FP8 서브행렬(A′(1)ₗ, A′(2)ₗ 등)으로 분해하고, s=16이라는 스케일링 상수를 도입했다. 이렇게 하면 서브행렬의 원소 절댓값이 2⁴ 이하가 되어 FP8에 정확히 표현될 수 있다. Karatsuba 재구성을 통해 A′ₗ·B′ₗ을 세 번의 FP8 매트릭스 곱셈(C′(1)ₗ, C′(2)ₗ, C′(3)ₗ)과 몇 개의 선형 조합으로 복원한다. 이 과정에서 k≤2¹⁶이면 곱셈 결과가 24비트 이하이므로 FP32 누산에서도 반올림 오차가 발생하지 않는다.
Karatsuba만 사용하면 모듈러 개수가 N≥13이어야 P/2>2¹¹⁵≈2⁵³+⁵³을 만족해 FP64 수준의 정확도를 얻을 수 있다. 저자들은 추가로 “모듈러 축소 없이 Karatsuba” 방식을 도입한다. s²=pₗ인 제곱 모듈러에 대해 s²·A′(1)ₗ·B′(1)ₗ 항이 0(mod pₗ)임을 이용해, 세 번의 곱셈(A′(1)ₗ·B′(2)ₗ, A′(2)ₗ·B′(1)ₗ, A′(2)ₗ·B′(2)ₗ)만으로 C′ₗ을 계산한다. 제곱 모듈러는 1089, 1024, 961 등으로 선택하고, 나머지 비제곱 모듈러는 기존 Karatsuba 방식을 적용한다. 이 하이브리드 전략은 전체 FP8 매트릭스 곱셈 횟수를 N≥12에서 충분히 커버하도록 감소시켜, INT8 기반 Ozaki‑II(14 모듈러) 대비 약 15%~20% 연산량 절감 효과를 얻는다.
스케일링 벡터 µ, ν의 선택은 기존 방식과 동일하게 “fast mode”와 “accurate mode”를 제공한다. 정확도 모드에서는 FP8 매트릭스 곱셈을 이용해 상한을 직접 계산하고, µ′, ν′를 2⁷·ufp(max|a|) 형태로 정의해 2·µ_i·w̄_ij·ν_j < P 조건을 만족시킨다.
성능 모델링에서는 FP8‑MMA의 TFLOP/s와 메모리 대역폭을 고려해 연산량 대비 기대 속도를 예측하고, INT8‑MMA와 비교해 FP8‑MMA가 제공하는 높은 연산 집약성을 정량화한다. 메모리 사용량 분석에서는 각 모듈러별 서브행렬이 추가되는 오버헤드를 상세히 제시하고, 하이브리드 방식이 메모리 풋프린트를 10%~15% 절감함을 보인다.
마지막으로 저자들은 CUDA와 ROCm 양쪽을 지원하는 오픈소스 라이브러리를 공개했으며, 동일한 툴체인 하에서 비트 단위 재현성을 보장한다. 실험에서는 DGEMM 벤치마크와 실제 과학 응용(예: 전자 구조 계산)에서 FP8‑Ozaki‑II가 FP64 직접 연산 대비 2.5×~3.2× 속도 향상을 달성하면서도 평균 상대 오차 1e‑13 이하의 정확성을 유지함을 입증한다.
댓글 및 학술 토론
Loading comments...
의견 남기기