대규모 언어 모델 훈련을 위한 확장 가능한 곡률 측정법

대규모 언어 모델 훈련을 위한 확장 가능한 곡률 측정법
안내: 본 포스트의 한글 요약 및 분석 리포트는 AI 기술을 통해 자동 생성되었습니다. 정보의 정확성을 위해 하단의 [원본 논문 뷰어] 또는 ArXiv 원문을 반드시 참조하시기 바랍니다.

초록

본 논문은 LLM 훈련 중 손실 곡률을 효율적으로 추정할 수 있는 “critical sharpness(λ_c)”를 제안한다. 10회 이하의 순전파만으로 η_c(손실이 증가하기 시작하는 최소 학습률)를 찾고, λ_c=2/η_c 로 정의한다. 이 측정값은 Hessian sharpness와 이론적으로 연관되며, progressive sharpening 및 Edge of Stability 현상을 대규모(최대 7 B 파라미터) 모델에서도 정확히 포착한다. 또한, pre‑training과 fine‑tuning 사이의 곡률 차이를 나타내는 “relative critical sharpness(λ_c^{1→2})”를 도입해 데이터 믹싱 전략을 제시한다.

상세 분석

논문은 먼저 손실 함수 L(θ)의 2차 근사를 이용해 “critical learning rate η_c”를 정의한다. η_c는 현재 업데이트 방향 Δθ에 대해 손실이 처음으로 증가하는 최소 학습률이며, 이를 찾기 위해 지수 탐색과 이진 탐색을 결합한 두 단계 라인서치를 사용한다. 이 과정은 순전파만 필요하므로 대규모 분산 학습 파이프라인에 그대로 적용 가능하고, 평균 56번의 forward pass로 η_c를 추정한다. 정의된 λ_c=2/η_c는 기존 Hessian sharpness λ_H^max와의 관계를 이론적으로 분석한다. 2차 근사 하에서 손실 증가 조건은 η>2/λ_dir이며, 여기서 λ_dir=ΔθᵀHΔθ / Δθᵀg는 “directional sharpness”이다. 논문은 λ_dir가 Hessian 고유값들의 가중합이며, 가중치는 gradient가 각 고유벡터와 얼마나 정렬되는가에 따라 결정된다고 증명한다(결과 2.1). 따라서 gradient가 가장 큰 고유벡터와 완전히 정렬될 경우 λ_dir=λ_H^max가 되고, λ_c와 λ_H^max가 일치한다. Adaptive optimizer(예: Adam)의 경우에도 pre‑conditioned Hessian와 pre‑conditioned gradient를 이용해 동일한 형태의 관계(결과 2.2)를 얻는다. 실험에서는 CIFAR‑10 MLP와 다양한 배치 크기에서 λ_c와 λ_dir가 거의 일치함을 보여주며, Hessian sharpness와는 초기 구간에서 차이를 보이지만 전체적인 “progressive sharpening”과 “Edge of Stability(EoS)” 현상을 동일하게 포착한다. 대규모 LLM(OLMo‑2, 0.3B7B) 실험에서는 λ_c가 학습 전반에 걸쳐 지속적으로 증가하고, 일정 시점에서 2/η (학습률)와 교차하면서 EoS에 도달함을 확인한다. 특히, 학습률 스케줄이 변할 때 λ_c는 학습률 변화를 따라 움직이며, 이는 기존 Hessian 기반 분석과 일치한다. 논문의 또 다른 핵심 기여는 “relative critical sharpness λ_c^{1→2}”이다. 이는 사전 학습 손실 L₁과 미세 조정 손실 L₂ 사이에서 현재 업데이트 방향이 L₁에 대해 얼마나 날카로운지를 측정한다. 실험에서는 사전 학습 데이터 비중을 높일수록 λ_c^{1→2}가 감소해 모델이 사전 학습 베이스에 머무르게 되고, 이는 수학 문제(GSM8K) 성능 향상에 기여한다. 반대로, 베이스를 유지하면 일반 추론(MMLU) 성능이 향상된다. 따라서 λ_c^{1→2}는 데이터 믹싱 비율을 조절해 특정 downstream task에 최적화된 훈련 전략을 설계하는 실용적인 지표가 된다. 전체적으로 논문은 Hessian 계산의 고비용을 피하면서도 곡률 정보를 충분히 제공하는 측정법을 제시하고, 이를 통해 대규모 LLM 훈련의 안정성 진단, 학습 스케줄 설계, 데이터 구성 최적화 등에 직접 활용할 수 있음을 입증한다.


댓글 및 학술 토론

Loading comments...

의견 남기기