주의 인덱스 모델의 베이즈 최적 학습
초록
본 논문은 토큰 임베딩 간의 쌍별 bilinear 상호작용을 고차원에서 모델링하는 ‘Attention‑Indexed Model (AIM)’을 제안하고, 키·쿼리 행렬의 폭이 임베딩 차원과 동등한 실용적인 설정에서도 베이즈 최적 일반화 오차를 정확히 계산한다. 통계역학·랜덤 행렬 이론을 이용해 샘플 복잡도, 시퀀스 길이, 폭 비율에 따른 급격한 학습 전이 현상을 밝히며, 이를 달성하는 Approximate Message Passing (AMP) 알고리즘과 실제 Gradient Descent가 최적 성능에 근접함을 실험적으로 입증한다.
상세 분석
AIM은 기존 multi‑index 모델을 확장해, 각 토큰 a의 d‑차원 임베딩 x_a와 학습 가능한 대칭 행렬 S^{ℓ} (ℓ=1…L) 사이의 bilinear 형태 h^{ℓ}{ab}=x_a^{⊤}S^{ℓ}x_b−δ{ab}Tr S^{ℓ}/√d 로 정의된 L개의 ‘attention index’를 도입한다. 이때 S^{ℓ}= (1/√{r_ℓ d}) W_ℓW_ℓ^{⊤} 로 표현하면, W_ℓ∈ℝ^{d×r_ℓ}는 실제 트랜스포머의 key·query 행렬에 해당한다. 논문은 r_ℓ가 d와 같은 차수(ρ_ℓ=r_ℓ/d=Θ(1))로 스케일링되는 ‘extensive‑width’ 상황을 다루어, 기존 연구가 제한했던 좁은 폭 가정에서 벗어난다.
고차원 극한(d→∞, n/d²=α=Θ(1))에서 입력은 i.i.d. Gaussian이며, S^{ℓ}는 회전 불변 분포 P_S 를 따른다. 베이즈 최적 추정은 사후 평균으로 정의되며, 핵심은 두 종류의 겹침 행렬 q_{ℓk}= (1/d)E_{post}
댓글 및 학술 토론
Loading comments...
의견 남기기