출력 임베딩 중심화로 안정적인 대형 언어 모델 사전학습
초록
본 논문은 대형 언어 모델(LLM) 사전학습 시 학습률이 높아지는 말단 단계에서 발생하는 출력 로짓 발산 문제를 출력 임베딩의 기하학적 특성에서 분석한다. 기존의 z‑loss는 증상만 완화시키는 반면, 저자들은 출력 임베딩을 평균 0으로 중심화하는 µ‑centering과 이를 정규화 형태로 구현한 µ‑loss를 제안한다. 이 방법은 로짓 발산을 이론적으로 억제하고, 실험적으로 z‑loss 대비 학습 안정성, 학습률 민감도, 하이퍼파라미터 튜닝 부담을 크게 개선함을 보인다.
상세 분석
논문은 먼저 대형 언어 모델의 출력 단계가 “입력 토큰 → 출력 임베딩 → 선형 변환 → 로짓”이라는 구조임을 재조명한다. 여기서 출력 임베딩 행렬 (E\in\mathbb{R}^{V\times d}) (V는 어휘 크기, d는 차원)의 평균 벡터 (\mu = \frac{1}{V}\sum_{i=1}^{V}E_i)가 0이 아니면, 선형 변환 후 로짓에 (\mathbf{W}\mu)라는 상수 편향이 추가된다. 학습이 진행될수록 (\mu)는 점차 커지거나 불균형하게 변동하며, 특히 학습률이 크게 설정된 말단 단계에서 (\mathbf{W}\mu)가 급격히 증폭돼 로짓 값이 전체 어휘에 대해 동일하게 상승하거나 하강한다. 이는 softmax의 분포를 왜곡시켜 “출력 로짓 발산” 현상을 초래한다. 기존의 z‑loss는 로짓의 L2 노름을 직접 억제해 증상을 완화하지만, (\mu) 자체가 비대칭인 근본 원인을 해결하지 못한다.
저자들은 (\mu)를 0으로 강제하는 두 가지 접근법을 제시한다. 첫 번째는 매 미니배치 혹은 전체 파라미터 업데이트 직후에 (E \leftarrow E - \mu)를 수행하는 deterministic µ‑centering이다. 이는 임베딩 공간을 원점에 맞추어 (\mathbf{W}\mu)를 완전히 제거한다. 두 번째는 µ‑loss라는 정규화 항 (\lambda|\mu|_2^2)를 손실에 추가하는 방법이다. µ‑loss는 학습 과정에서 (\mu)가 0에 가까워지도록 유도하면서도 기존 손실과 동시에 최적화되므로, 구현이 간단하고 기존 파이프라인에 바로 삽입 가능하다.
이론적 증명에서는 (\mu=0)일 때 로짓은 (\mathbf{W}E)만으로 결정되며, (\mathbf{W}\mu)에 의한 편향이 사라져 로짓 발산이 수학적으로 억제됨을 보인다. 또한, µ‑loss는 (\lambda) 하이퍼파라미터에 대해 완만한 민감도를 보이는데, 이는 (|\mu|_2) 자체가 이미 작은 값으로 수렴하기 때문에 과도한 정규화가 필요 없음을 의미한다. 실험에서는 GPT‑style 모델(1.3B 파라미터)과 다양한 데이터셋에서 학습률을 2배~4배 확대했을 때, µ‑centering과 µ‑loss가 모두 수렴을 유지하고, z‑loss는 발산하거나 급격한 손실 진동을 보였다. 특히 µ‑loss는 학습 초반에 별도 튜닝 없이 (\lambda=1e-4) 정도만으로도 안정적인 학습을 달성했다.
결과적으로, 출력 임베딩 중심화는 로짓 발산의 근본 원인을 제거함으로써 대형 모델의 고학습률 사전학습을 가능하게 하고, 기존 z‑loss 대비 하이퍼파라미터 관리 부담을 크게 낮춘다.
댓글 및 학술 토론
Loading comments...
의견 남기기