프리즘: 스펙트럼‑적응 다항식과 랜덤 스케치로 가속하는 행렬 함수 계산
초록
**
프리즘(PRISM)은 행렬 함수(제곱근, 역제곱근, 정규직교화 등)를 GPU 친화적인 반복 알고리즘에 적용하기 위해, 현재 반복 단계의 스펙트럼을 랜덤 스케치 기반 최소제곱 문제로 빠르게 추정하고 그에 맞는 다항식 근사를 자동으로 업데이트한다. 사전 스펙트럼 정보가 필요 없으며, 뉴턴‑슐츠 계열의 반복을 크게 가속한다. 실험에서는 Shampoo와 Muon 최적화기에 통합했을 때 학습 속도가 현저히 향상되었다.
**
상세 분석
**
프리즘은 기존의 행렬 함수 계산 방법이 갖는 두 가지 근본적인 한계를 해결한다. 첫째, 전통적인 다항식 가속 기법은 사전에 최소·최대 특이값 ℓ, u 를 알아야 최적의 다항식 계수를 설계할 수 있었지만, 실제 딥러닝 파이프라인에서는 이러한 정보를 얻는 비용이 원래 문제와 맞먹는다. 프리즘은 “분포‑프리(distribution‑free)” 접근을 채택해, 매 반복마다 현재 추정 행렬 Xₖ 의 잔차 Rₖ = ξ(Xₖ, A) 에 대해 랜덤 스케치 S 를 적용해 저차원 공간에서 α 를 최소화한다. 이 과정은 O(n² log n) 의 연산량으로, GEMM이 차지하는 O(n³) 비용에 비해 무시할 수준이다.
둘째, 기존 가속 기법은 고정된 다항식 p_d (ξ) 을 사용해 초기 수렴 속도는 개선되지만, 스펙트럼이 변하면 급격히 성능이 저하된다. 프리즘은 다항식 g_d(ξ; α) = f_d‑1(ξ) + α ξ^d 형태를 도입해, 매 단계마다 최적 αₖ 를 찾음으로써 현재 스펙트럼에 가장 근접한 근사함수를 만든다. 이 적응 과정은 잔차 ‖Rₖ₊₁‖_F 를 직접 최소화하므로, 수렴률이 이론적으로 기존 뉴턴‑슐츠와 동등하거나 더 빠르다.
프리즘 메타알고리즘은 크게 두 파트로 구성된다. Part I에서는 목표 행렬 함수 T(A) 를 x·f(ξ) 형태로 재구성하고, 스칼라 테일러 전개 f_d 를 이용해 기본 반복 Xₖ₊₁ = Xₖ f_d(Rₖ) 를 만든다. Part II에서는 (4) 다항식 피팅 단계와 (5) 스케치 단계로 확장한다. 여기서 스케치 행렬 S 는 서브샘플링, 랜덤 프로젝션, 혹은 레버리지 스코어 기반 선택 등 다양한 RandNLA 기법을 사용할 수 있다. 논문은 특히 행렬 부호, 제곱근, 정규직교화에 대해 구체적인 파라미터 선택과 수렴 분석을 제공한다.
실험 결과는 두 가지 차원에서 의미 있다. 첫째, 스펙트럼이 넓게 퍼진 경우(예: Marchenko‑Pastur 분포)와 극단적인 작은 특이값을 가진 경우(예: 사전 학습된 모델) 모두에서 프리즘은 기존 PolarExpress·CANS 등 최적화된 다항식 방법보다 일정 수준 이상의 GPU 시간 가속을 달성한다. 둘째, 실제 딥러닝 훈련에 Shampoo(제곱근 기반 프리컨디셔너)와 Muon(정규직교화 기반 모멘텀)에 적용했을 때, 전체 에포크당 학습 시간이 평균 12‑18% 감소했으며, 최종 모델 정확도에는 전혀 영향을 주지 않았다.
프리즘의 한계도 언급한다. 스케치 차원 s 를 너무 작게 잡으면 αₖ 추정이 부정확해져 수렴이 느려질 수 있다. 또한, 현재 구현은 대칭 잔차 Rₖ 가 필요하므로, 비대칭 행렬에 대한 직접 적용은 추가 변형이 필요하다. 향후 연구에서는 비대칭 케이스, 다중 GPU 분산 환경, 그리고 다항식 대신 라플라스 변환 기반의 랜덤 근사 등으로 확장할 여지가 있다.
**
댓글 및 학술 토론
Loading comments...
의견 남기기