효율적인 확산 트랜스포머를 위한 조각별 희소 어텐션
초록
PISA는 사전 학습된 Diffusion Transformer의 전체 어텐션을 유지하면서도 블록 단위의 테일러 전개를 이용해 비핵심 키‑밸류 블록을 효율적으로 근사한다. 정확히 계산해야 할 핵심 블록은 그대로 유지하고, 나머지는 0차·1차 테일러 근사와 전역 1차 보정을 결합해 서브-제곱 복잡도로 처리한다. 실험 결과 Wan2.1‑14B, Hunyuan‑Video, FLUX 등에서 기존 희소 어텐션 대비 1.9~2.6배 빠르면서 품질 저하를 최소화한다.
상세 분석
본 논문은 Diffusion Transformer(DiT)의 어텐션 연산이 입력 토큰 수 L에 대해 O(L²) 복잡도를 갖는 점을 출발점으로 삼는다. 기존 블록 희소 어텐션은 중요한 키‑밸류 블록만 선택하고 나머지는 완전히 버리는 “keep‑or‑drop” 방식을 사용한다. 이는 고희소도(예: 80~90% 블록 삭제)에서 중요한 컨텍스트 손실을 초래해 이미지·비디오 품질이 급격히 떨어지는 한계를 가지고 있다. 저자들은 사전 학습된 모델의 어텐션 스코어를 분석한 결과, 비핵심 블록의 사전 소프트맥스 점수(QKᵀ)가 대칭적인 베르누이 형태의 부정적 평균을 보이며, 이들 스코어는 평균을 중심으로 한 테일러 전개에 매우 적합하다는 “분포적 안정성”을 발견한다.
이 통찰을 바탕으로 PISA(Piecewise Sparse Attention)는 “정확‑또는‑근사(exact‑or‑approximate)” 전략을 도입한다. 먼저 전체 토큰을 B×B 크기의 블록으로 나누고, 각 쿼리 블록에 대해 중요도 기반 Top‑K 선택을 수행해 S_i(핵심 블록)와 U_i(비핵심 블록)로 구분한다. 핵심 블록은 기존 블록 희소 어텐션과 동일하게 정확히 계산한다. 비핵심 블록은 두 단계 근사를 적용한다. 1) 0차 테일러 근사는 블록 평균 키 ⎯k_j와 블록 전체 값 ⎯v_j를 사용해 exp(q_t·⎯k_j)·⎯v_j 형태로 계산한다. 2) 1차 테일러 보정은 블록 내부 편차 (k_{j,n}−⎯k_j)·v_{j,n} 를 행렬 H_j로 집계하고, 이를 전역 평균 ⎯H 로 대체한다. 전역 보정 계수 β_t는 각 블록의 스케일 exp(q_t·⎯k_j) 평균을 사용해 추정한다. 이렇게 하면 메모리 접근이 최소화되고, 실제 연산은 GEMM 기반 매트릭스 곱과 한 번의 전역 H 로딩만으로 수행된다.
이론적으로는 정규화 분모 D_t와 분자 N_t 모두에 동일한 근사식을 삽입함으로써 소프트맥스의 가중치 합 규칙을 보존한다. 저자는 정리 3.1을 통해 전역 1차 보정이 도입된 경우 근사 오차가 tail fraction ρ_t와 H_j 편차 M에 비례함을 증명한다. ρ_t는 비핵심 블록이 차지하는 확률 질량으로, 비핵심 블록이 대부분을 차지하더라도 exp 값이 작아 ρ_t가 작게 유지된다. 따라서 전체 오차는 실질적으로 무시 수준이다.
시스템 구현 측면에서 PISA는 세 단계 파이프라인을 제시한다. (1) Prepare Phase: 블록 평균 Q, K와 값 합 ⎯V, 전역 H 를 한 번 스캔해 사전 계산한다. (2) Mask Phase: Top‑K 선택으로 S_i 를 결정하고 마스크 M을 생성한다. (3) Fused Attention Kernel: GPU 커널 내부에서 선택된 블록은 Phase‑1(정확) 경로, 비선택 블록은 Phase‑2(0차)와 Phase‑3(전역 1차) 경로를 동적으로 전환한다. 이 설계는 메모리 대역폭을 최소화하고, 기존 FlashAttention의 온라인 소프트맥스 흐름에 자연스럽게 통합된다.
실험에서는 Wan2.1‑14B 비디오 생성 모델과 Hunyuan‑Video, 그리고 텍스트‑투‑이미지 FLUX 모델에 적용해 각각 2.14×, 2.57×, 1.2×의 속도 향상을 달성했다. 품질 지표(PSNR, LPIPS)에서는 기존 희소 어텐션(SparseAttn) 대비 크게 개선되었으며, 특히 높은 희소도(r=85%)에서도 구조적 손실이 거의 없었다. 또한 사전 학습된 가중치를 그대로 재사용할 수 있어 추가 파인튜닝이 필요 없다는 실용적 장점도 강조한다.
요약하면, PISA는 비핵심 블록의 통계적 특성을 활용해 근사 연산을 설계하고, 이를 정규화 과정에 자연스럽게 녹여내어 “속도 vs 품질” 트레이드오프를 크게 완화한다. 기존 희소 어텐션이 갖는 “핵심 정보 손실” 문제를 근본적으로 해결하면서도, 복잡도는 O(L·B) 수준으로 유지한다는 점에서 Diffusion Transformer 기반 생성 모델의 실시간·고해상도 적용에 중요한 전진을 제공한다.
댓글 및 학술 토론
Loading comments...
의견 남기기