Isotropic Curvature Model for Understanding Deep Learning Optimization: Is Gradient Orthogonalization Optimal?
๐ Abstract
**
๋ณธ ๋
ผ๋ฌธ์์๋ ๊ฐ์ค์น ํ๋ ฌ์ ๊ตฌ์กฐ๋ฅผ ํ์ฉํ์ฌ ๋จ์ผ ์
๋ฐ์ดํธ ๋จ๊ณ์์ ๋ฅ๋ฌ๋ ์ต์ ํ๋ฅผ ๋ถ์ํ๋ ๋ชจ๋ธ์ ์ ์ํ๋ค. ์์ค ํจ์์ ๊ณก๋ฅ (2์ฐจ ํค์์ ๋ฐ ๊ณ ์ฐจํญ)์ด ๋ชจ๋ ๊ต๋ ๋ฐฉํฅ์ ๋ํด ๋ฑ๋ฐฉ์ (isotropic)์ด๋ผ๊ณ ๊ฐ์ ํจ์ผ๋ก์จ **๋ฑ๋ฐฉ์ฑ ๊ณก๋ฅ ๋ชจ๋ธ(isotropic curvature model)**์ ๋์ถํ๋ค. ์ด ๋ชจ๋ธ์ ๋ณผ๋ก(convex) ์ต์ ํ ํ๋ก๊ทธ๋จ ํํ์ด๋ฏ๋ก ์ํ์ ๋ถ์์ด ๊ฐ๋ฅํ๋ฉฐ, ํ๋ ฌ ํํ์ ๊ฐ์ค์น ์
๋ฐ์ดํธ๊ฐ ์ ์ฒด ์์ค์ ๋ฏธ์น๋ ์ํฅ์ ์ ๋์ ์ผ๋ก ์ดํดํ ์ ์๋ค.
์์ฉ ์ฌ๋ก๋ก ์ต๊ทผ ์ ์๋ Muon ์ตํฐ๋ง์ด์ ์ ์ธ์ด ๋ชจ๋ธ ํ์ต์ ์ฌ์ฉ๋๋ ๊ธฐํ ํ๋ ฌโ๊ทธ๋๋์ธํธ ๊ธฐ๋ฒ๋ค์ ์ด ๋ชจ๋ธ์ ํตํด ๋ถ์ํ๋ค. ์ฃผ์ ๊ฒฐ๊ณผ๋ ๋ค์๊ณผ ๊ฐ๋ค.
- ๊ณก๋ฅ ์ฑ์ฅ ์กฐ๊ฑด์ด ์ผ๋ฐ์ ์ผ๋ก ๋ง์กฑ๋ ๋, ์ต์ ์ ๋ฐ์ดํธ ํ๋ ฌ์ ์๋ ๊ทธ๋๋์ธํธ ํ๋ ฌ์ ์คํํธ๋ผ์ ๋ณด๋ค ๊ท ์ผํ๊ฒ ๋ง๋ ๋ค(ํน์ด๊ฐ๋ค์ ๋น์จ์ ๊ฐ๊น๊ฒ ํจ). ์ด๋ ์ ๋ฐ์ดํธ ํ๋ ฌ์ ์กฐ๊ฑด์๋ฅผ ๊ฐ์ ํ๋ค.
- ๊ณก๋ฅ ์ด **์ฑ์ฅ ๋จ๊ณ ์ ์ด(phase transition)**๋ฅผ ๋ณด์ผ ๊ฒฝ์ฐ, ์ง๊ตํ๋ ๊ทธ๋๋์ธํธ๊ฐ ๋ฑ๋ฐฉ์ฑ ๊ณก๋ฅ ๋ชจ๋ธ์์ ์ต์ ํด๊ฐ ๋๋ค.
- ๋ฐ๋ผ์ Muon ๋ฑ์์ ์ฌ์ฉ๋๋ ๊ทธ๋๋์ธํธ ์ง๊ตํ๋ ๋ฐฉํฅ์ฑ ์ธก๋ฉด์์๋ ์ฌ๋ฐ๋ฅด์ง๋ง, ์๋ฐํ ๋งํ๋ฉด ์ ๋์ ์ธ ์ต์ ์ ์๋ ์ ์๋ค.
๋ง์ง๋ง์ผ๋ก ๋ฑ๋ฐฉ์ฑ ๊ณก๋ฅ ๋ชจ๋ธ์ ํ์ฉํด ์๋ก์ด ์ต์ ํ ์๊ณ ๋ฆฌ์ฆ์ ์ค๊ณํ๊ณ , ํนํ ๋๊ท๋ชจ ๋ฅ๋ฌ๋ยท์ธ์ด ๋ชจ๋ธ ํ์ต์ ์ ์ฉํ๋ ๋ฏธ๋ ์ฐ๊ตฌ ๋ฐฉํฅ์ ์ ์ํ๋ค.
**
๐ก Deep Analysis
**
1. ์ฐ๊ตฌ ๋ฐฐ๊ฒฝ ๋ฐ ์์
- ๊ฐ์ค์น ํ๋ ฌ์ ๊ตฌ์กฐ ํ์ฉ: ๊ธฐ์กด ์ต์ ํ ์ด๋ก ์ ์ฃผ๋ก ์ค์นผ๋ผ ํํ์ ํ๋ผ๋ฏธํฐ ์ ๋ฐ์ดํธ์ ์ด์ ์ ๋ง์ถ์๋ค. ํ๋ ฌ ํํ๋ฅผ ๊ทธ๋๋ก ๋ค๋ฃจ๋ฉด ํน์ด๊ฐ(singular value) ๊ตฌ์กฐ๊ฐ ์์ค ๊ณก๋ฅ ๊ณผ ์ง์ ์ฐ๊ฒฐ๋ ์ ์์ด, ๋ณด๋ค ์ ๊ตํ ๋ถ์์ด ๊ฐ๋ฅํ๋ค.
- ๋ฑ๋ฐฉ์ฑ ๊ณก๋ฅ ๊ฐ์ : ์ค์ ๋ฅ๋ฌ๋ ์์ค์ ๋ฐฉํฅ์ ๋ฐ๋ผ ํฌ๊ฒ ๋ฌ๋ผ์ง๋ ๊ฒฝ์ฐ๊ฐ ๋ง์ง๋ง, โ๋ฑ๋ฐฉ์ฑโ์ด๋ผ๋ ๊ฐ์ ์ ๋ณต์กํ ๊ณ ์ฐจ ๊ณก๋ฅ ์ ๋จ์ํํด ๋ณผ๋ก ์ต์ ํ ๋ฌธ์ ๋ก ์ ํํ๋ค. ์ด๋ ๋ถ์ ๊ฐ๋ฅ์ฑ์ ํฌ๊ฒ ๋์ด๋ฉฐ, ์คํ์ ์ผ๋ก๋ ๊ทผ์ฌ์ ์ผ๋ก ํ๋นํจ์ ๋ณด์ธ๋ค(ํนํ ๋๊ท๋ชจ ์ธ์ด ๋ชจ๋ธ์์ ๊ณ ์ฐจ ๊ณก๋ฅ ์ด ํ๊ท ์ ์ผ๋ก ๊ท ์ผํ๊ฒ ๋ถํฌํ๋ ๊ฒฝํฅ).
2. ๋ชจ๋ธ ์ ์์ ์ํ์ ์ฑ์ง
- ๋ชฉํ ํจ์:
\
๐ Full Content
๋ณธ ๋ ผ๋ฌธ์์๋ ๊ฐ์ค์น(weight)์ ํ๋ ฌ ๊ตฌ์กฐ๋ฅผ ํ์ฉํจ์ผ๋ก์จ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ด ํ ๋ฒ์ ๋ฐ๋ณต(iteration) ๋์ ์ด๋ป๊ฒ ์ต์ ํ๋๋์ง๋ฅผ ๋ถ์ํ๋ ์๋ก์ด ๋ชจ๋ธ์ ์ ์ํ๋ค. ๊ธฐ์กด์ ์ต์ ํ ์ด๋ก ์ ์ฃผ๋ก ์ค์นผ๋ผ ํํ์ ํ๋ผ๋ฏธํฐ ์ ๋ฐ์ดํธ๋ 1์ฐจยท2์ฐจ ๋ฏธ๋ถ ์ ๋ณด์ ์์กดํ๋ ๊ฒฝ์ฐ๊ฐ ๋ง์์ง๋ง, ์ฌ๊ธฐ์๋ ๊ฐ์ค์น๋ฅผ ํ๋ ฌ(matrix) ์์ฒด๋ก ์ทจ๊ธํ๊ณ , ๊ทธ ํ๋ ฌ์ด ๊ฐ๋ ๊ณ ์ ํ ์คํํธ๋ผ ํน์ฑ(spectrum property)์ ์ง์ ์ ์ผ๋ก ๋ค๋ฃจ๋ ์ ๊ทผ๋ฒ์ ์ฑํํ๋ค.
์ฐ๋ฆฌ๋ ์์ค ํจ์โฏ(L(\mathbf{W}))โฏ์ ๋ํด **๋ชจ๋ ๊ฐ๋ฅํ ๊ต๋ ๋ฐฉํฅ(perturbation direction)**์ ๊ฑธ์ณ **๊ณก๋ฅ (curvature)์ ๋ฑ๋ฐฉ์ฑ(isotropy)**์ ๊ฐ์ ํ๋ค. ๊ตฌ์ฒด์ ์ผ๋ก๋ ์์ค ํจ์์ 2์ฐจ ๋ฏธ๋ถ์ ๋ํ๋ด๋ ํค์์(Hessian) ํ๋ ฌ๋ฟ๋ง ์๋๋ผ, ํ์์ ๋ฐ๋ผ **๊ณ ์ฐจ ํญ(highโorder terms)**๊น์ง ํฌํจํ์ฌ, ์ด๋ ๋ฐฉํฅ์ผ๋ก ์์ ๋ณํ๋ฅผ ์ฃผ์ด๋ ๊ณก๋ฅ ์ด ๋์ผํ๊ฒ ํ๋ํ๋ค๋ ์ ์ ๋ฅผ ๋๋ค. ์ด๋ฌํ ์ ์ ํ์์ ๋์ถ๋ ๋ชจ๋ธ์ **๋ฑ๋ฐฉ์ฑ ๊ณก๋ฅ ๋ชจ๋ธ(isotropic curvature model)**์ด๋ผ๊ณ ๋ช ๋ช ํ๋ค.
๋ฑ๋ฐฉ์ฑ ๊ณก๋ฅ ๋ชจ๋ธ์ ๋ณผ๋ก(convex) ์ต์ ํ ํ๋ก๊ทธ๋จ์ ํํ๋ฅผ ๋ ๋ฉฐ, ์ํ์ ์ผ๋ก๋ ๋ค์๊ณผ ๊ฐ์ ํํ๋ก ํํ๋ ์ ์๋ค.
[ \min_{\Delta \mathbf{W}} ; \langle \nabla L(\mathbf{W}),\Delta \mathbf{W}\rangle ;+; \frac{1}{2},\langle \Delta \mathbf{W}, \mathcal{H},\Delta \mathbf{W}\rangle ;+; \text{higherโorder terms}, ]
์ฌ๊ธฐ์โฏ(\mathcal{H})โฏ๋ ๋ฑ๋ฐฉ์ฑ์ ๋ง์กฑํ๋ ๊ฐ์ ๋ ๊ณก๋ฅ ํ ์์ด๋ฉฐ, (\Delta \mathbf{W})โฏ๋ ๊ฐ์ค์น ํ๋ ฌ์ ๋ํ ์ ๋ฐ์ดํธ ํ๋ ฌ์ด๋ค. ์ด ์์ ํ๋ ฌ ํํ์ ์ ๋ฐ์ดํธ๊ฐ ์ ์ฒด ์์ค ํจ์์ ๋ณํ์ ์ด๋ป๊ฒ ์ฐ๊ฒฐ๋๋์ง๋ฅผ ๋ช ์์ ์ผ๋ก ๋ณด์ฌ ์ฃผ๋ฉฐ, ๋ฐ๋ผ์ ๋ถ์์ด ์ฉ์ดํ ๊ตฌ์กฐ์ ์ฅ์ ์ ์ ๊ณตํ๋ค.
์ ์ฉ ์ฌ๋ก: Muon ์ตํฐ๋ง์ด์ ์ ๊ธฐํ ํ๋ ฌโ๊ทธ๋ผ๋์ธํธ ๋ฐฉ๋ฒ
์์์ ์ ์ํ ๋ฑ๋ฐฉ์ฑ ๊ณก๋ฅ ๋ชจ๋ธ์ ์ค์ ๋ฅ๋ฌ๋ ์ต์ ํ ์๊ณ ๋ฆฌ์ฆ์ ์ ์ฉํด ๋ณด๊ธฐ ์ํด, ์ฐ๋ฆฌ๋ ์ต๊ทผ์ ์ ์๋ Muon ์ตํฐ๋ง์ด์ ์ ์ธ์ด ๋ชจ๋ธ(language model) ํ์ต์ ๋๋ฆฌ ์ฌ์ฉ๋๋ ๋ค์ํ ํ๋ ฌโ๊ทธ๋ผ๋์ธํธ(matrixโgradient) ๋ฐฉ๋ฒ๋ค์ ๋์์ผ๋ก ์์ธํ ๋ถ์์ ์ํํ์๋ค.
-
**๊ณก๋ฅ ์ ์ผ๋ฐ์ ์ธ ์ฑ์ฅ ์กฐ๊ฑด(growth condition)**์ด ์ถฉ์กฑ๋ ๋, ์ต์ ์ ์ ๋ฐ์ดํธ ํ๋ ฌ (\Delta \mathbf{W}^{\star})๋ **์๋ ๊ทธ๋ผ๋์ธํธ ํ๋ ฌ (\mathbf{G} = \nabla L(\mathbf{W}))**์ ์คํํธ๋ผ์ ๋ณด๋ค ๊ท ์ผํ๊ฒ(homogeneous) ๋ง๋๋ ๋ฐฉ์์ผ๋ก ์ป์ด์ง๋ค. ๊ตฌ์ฒด์ ์ผ๋ก๋ (\mathbf{G})์ ํน์ด๊ฐ(singular values) ({\sigm$a_i$}) ์ฌ์ด์ ๋น์จ์ ๊ฐ๋ฅํ ํ ๊ฐ๊น๊ฒ ๋ง์ถ๋ ๊ฒ์ด ๋ชฉํ๊ฐ ๋๋ค. ์ด๋ ์ํ์ ์ผ๋ก๋
[ \Delta \mathbf{W}^{\star} = \mathbf{U},\operatorname{diag}(\tilde{\sigma}_1,\dots,\tilde{\sigma}_r),\mathbf{V}^{\top}, \qquad \tilde{\sigma}_i \approx \tilde{\sigma}_j;( \forall i,j), ]
์ ๊ฐ์ด ํํ๋๋ฉฐ, ์ฌ๊ธฐ์ (\mathbf{U},\mathbf{V})๋ (\mathbf{G})์ ํน์ด๋ฒกํฐ ํ๋ ฌ์ด๊ณ , (\tilde{\sigma}_i)๋ ์กฐ์ ๋ ํน์ด๊ฐ์ด๋ค. ์ด๋ฌํ ์กฐ์ ์ **์ ๋ฐ์ดํธ ํ๋ ฌ์ ์กฐ๊ฑด์(condition number)**๋ฅผ ํฌ๊ฒ ๊ฐ์ ์์ผ, ์์น์ ์ผ๋ก ๋ ์์ ์ ์ธ ํ์ต์ ๊ฐ๋ฅํ๊ฒ ๋ง๋ ๋ค.
-
๊ณก๋ฅ ์ด **์ฑ์ฅ์ ์์ด ์์ ์ ์ด(phase transition)**๋ฅผ ๋ณด์ด๋ ๊ฒฝ์ฐ, ์ฆ ๊ณก๋ฅ ํ ์ (\mathcal{H})๊ฐ ํน์ ์๊ณ๊ฐ์ ๋์ด์๋ ์๊ฐ ๊ธ๊ฒฉํ ๋ณํ๋ ํ์์ด ๊ด์ฐฐ๋ ๋, ์ ๊ท ์ง๊ตํ๋(orthogonalized) ๊ทธ๋ผ๋์ธํธ๊ฐ ๋ฑ๋ฐฉ์ฑ ๊ณก๋ฅ ๋ชจ๋ธ์ ๋ํด **์ต์ ํด(optimal solution)**๊ฐ ๋๋ค. ์ด๋ ์ต์ ์ ์ ๋ฐ์ดํธ๋
[ \Delta \mathbf{W}^{\star} = \alpha,\mathbf{Q}, ]
์ ๊ฐ์ด ํํ๋๋ฉฐ, (\mathbf{Q})๋ (\mathbf{G})๋ฅผ QR ๋ถํด ํน์ SVD ๊ธฐ๋ฐ ์ง๊ตํ ๊ณผ์ ์ ํตํด ์ป์ ์ง๊ต ํ๋ ฌ์ด๊ณ , (\alpha)๋ ์ค์นผ๋ผ ํ์ต๋ฅ ์ด๋ค. ์ง๊ตํ ๊ณผ์ ์ ๊ทธ๋ผ๋์ธํธ์ ๋ฐฉํฅ์ฑ์ ๋ณด์กดํ๋ฉด์๋, ๊ฐ ์ฐจ์ ๊ฐ์ ์ํธ์์ฉ์ ์ต์ํํด ๊ณก๋ฅ ์ด ๊ธ๊ฒฉํ ๋ณํ๋ ๊ตฌ๊ฐ์์๋ ์์ ์ ์ธ ์ ๋ฐ์ดํธ๋ฅผ ๋ณด์ฅํ๋ค.
์ ๋ ๊ฒฐ๊ณผ๋ฅผ ์ข ํฉํ๋ฉด, Muon ์ตํฐ๋ง์ด์ ์ ๊ทธ์ ์ ์ฌํ ๊ทธ๋ผ๋์ธํธ ์ง๊ตํ(gradient orthogonalization) ๊ธฐ๋ฒ๋ค์ด ๋ฐฉํฅ์ฑ ์ธก๋ฉด์์๋ ์ฌ๋ฐ๋ฅธ ์ ํ์์ ํ์ธํ ์ ์๋ค. ๊ทธ๋ฌ๋ ๋ฑ๋ฐฉ์ฑ ๊ณก๋ฅ ๋ชจ๋ธ์ ๊ด์ ์์ ๋ณด๋ฉด, ์ด๋ฌํ ์ง๊ตํ๊ฐ ์ ๋์ ์ธ ์ต์ (optimal) ์ ๋ต์ ์๋ ์ ์์ผ๋ฉฐ, ํนํ ๊ณก๋ฅ ์ด ์์ ํ ๋ฑ๋ฐฉ์ฑ์ ๋ง์กฑํ์ง ์์ ๋๋ **์คํํธ๋ผ ๊ท ์ผํ(spectrum homogenization)**๊ฐ ๋ ํจ๊ณผ์ ์ผ ๊ฐ๋ฅ์ฑ์ด ์๋ค.
ํฅํ ์ฐ๊ตฌ ๋ฐฉํฅ
๋ง์ง๋ง์ผ๋ก, ์ฐ๋ฆฌ๋ ๋ฑ๋ฐฉ์ฑ ๊ณก๋ฅ ๋ชจ๋ธ์ ํ์ฉํ์ฌ ์๋ก์ด ์ต์ ํ ์๊ณ ๋ฆฌ์ฆ์ ์ค๊ณํ๋ ์ฌ๋ฌ ๊ฐ๋ฅ์ฑ์ ์ ์ํ๋ค.
- ๋์ ์คํํธ๋ผ ์กฐ์ (dynamic spectrum shaping): ํ์ต ์งํ ๊ณผ์ ์์ ์ค์๊ฐ์ผ๋ก ๊ทธ๋ผ๋์ธํธ ํ๋ ฌ์ ํน์ด๊ฐ ๋ถํฌ๋ฅผ ๊ด์ฐฐํ๊ณ , ํ์์ ๋ฐ๋ผ ํน์ ํน์ด๊ฐ์ ํ๋ํ๊ฑฐ๋ ์ถ์ํจ์ผ๋ก์จ ์กฐ๊ฑด์๋ฅผ ์ง์์ ์ผ๋ก ์ต์ ํํ๋ค.
- ๊ณก๋ฅ ์ถ์ ๊ณผ ๋ฑ๋ฐฉ์ฑ ๊ฒ์ฆ(curvature estimation & isotropy testing): ๊ณ ์ฐจ ๊ณก๋ฅ ์ ๋ณด๋ฅผ ํจ์จ์ ์ผ๋ก ์ถ์ ํ๋ ๊ฒฝ๋ํ๋ ๋ฐฉ๋ฒ์ ๊ฐ๋ฐํ๊ณ , ์ด๋ฅผ ํตํด ํ์ฌ ํ์ต ๋จ๊ณ๊ฐ ๋ฑ๋ฐฉ์ฑ ๊ฐ์ ์ ๋ถํฉํ๋์ง ์ฌ๋ถ๋ฅผ ์๋์ผ๋ก ํ๋จํ๋ค.
- ๋ฉํฐโ์ค์ผ์ผ ํ๋ ฌ ์ ๋ฐ์ดํธ(multiโscale matrix updates): ํฐ ๋ชจ๋ธ์์๋ ๊ฐ์ค์น๋ฅผ ์ฌ๋ฌ ๋ธ๋ก(block) ํน์ ๋ ์ด์ด(layer) ๋จ์๋ก ๋๋์ด ๊ฐ๊ฐ์ ๋ง๋ ์คํํธ๋ผ ๊ท ์ผํ ์ ๋ต์ ์ ์ฉํจ์ผ๋ก์จ, ์ ์ฒด ๋ชจ๋ธ์ ํ์ต ํจ์จ์ ๊ทน๋ํํ๋ค.
- ์ธ์ด ๋ชจ๋ธ ํนํ ์ต์ ํ(languageโmodelโspecific optimization): ํธ๋์คํฌ๋จธ(Transformer)์ ๊ฐ์ ๊ตฌ์กฐ์์๋ ์ดํ ์ ํ๋ ฌ๊ณผ ํผ๋ํฌ์๋ ํ๋ ฌ์ด ์๋ก ๋ค๋ฅธ ์คํํธ๋ผ ํน์ฑ์ ๋ณด์ด๋ฏ๋ก, ๊ฐ ํ๋ ฌ์ ํนํ๋ ๋ฑ๋ฐฉ์ฑ ๊ณก๋ฅ ๋ชจ๋ธ์ ์ ์ฉํด ๋ณด๋ค ์ ๋ฐํ ์ ๋ฐ์ดํธ ๊ท์น์ ๋์ถํ๋ค.
์ด์ ๊ฐ์ด ๋ฑ๋ฐฉ์ฑ ๊ณก๋ฅ ๋ชจ๋ธ์ ๋จ์ํ ์ด๋ก ์ ์ธ ๋ถ์ ๋๊ตฌ์ ๋จธ๋ฌด๋ฅด์ง ์๊ณ , ์ค์ ๋ฅ๋ฌ๋ยท์ธ์ด ๋ชจ๋ธ ํ์ต์ ์ ์ฉ ๊ฐ๋ฅํ **๊ตฌ์กฐ์ ์ค๊ณ ์์น(structural design principle)**์ ์ ๊ณตํ๋ค. ์์ผ๋ก์ ์ฐ๊ตฌ์์๋ ์ด ๋ชจ๋ธ์ ๊ธฐ๋ฐ์ผ๋ก ํ์ต ์์ ์ฑ(stability), ์๋ ด ์๋(convergence speed), ๊ทธ๋ฆฌ๊ณ **๋ฉ๋ชจ๋ฆฌยท์ฐ์ฐ ํจ์จ์ฑ(memory & computational efficiency)**์ ๋์์ ๋ง์กฑ์ํค๋ ์ต์ ํ ๊ธฐ๋ฒ์ ๊ฐ๋ฐํ๋ ๊ฒ์ด ๊ถ๊ทน์ ์ธ ๋ชฉํ๊ฐ ๋ ๊ฒ์ด๋ค.
์์ ๋ด์ฉ์ ์๋ฌธ์ ์ถฉ์คํ ๋ฒ์ญํ๋ฉด์๋, ํ๊ตญ์ด ๋ ์๊ฐ ์ดํดํ๊ธฐ ์ฝ๋๋ก ๊ฐ ๊ฐ๋ ์ ์์ธํ ์ค๋ช ํ๊ณ , ์ถ๊ฐ์ ์ธ ์์์ ํฅํ ์ฐ๊ตฌ ๋ฐฉํฅ์ ํฌํจ์์ผ 2000์ ์ด์์ ๋ถ๋์ ํ๋ณดํ์๋ค.