자기 증류를 통한 다중 토큰 예측 가속화
초록
본 논문은 사전 학습된 단일 토큰 예측 언어 모델을 별도 구조 변경 없이 온라인 자기 증류 손실을 이용해 다중 토큰을 한 번에 예측하도록 변환한다. 교사 모델의 확률을 KL‑다이버전스로 활용해 학생 모델이 일관된 토큰 시퀀스를 생성하도록 학습시키며, 추론 시에는 신뢰도 기반 적응 전략을 적용한다. GSM8K 테스트에서 평균 3배 이상의 디코딩 속도를 달성하면서 정확도는 5% 이하로 감소한다.
상세 분석
이 연구는 기존의 스페큘레이티브 디코딩과 달리 별도의 스펙큘레이터 모델이나 복잡한 파이프라인을 요구하지 않는다. 핵심 아이디어는 “학생 강제(student‑forced)” 방식으로, 학생 모델이 한 번의 포워드 패스에서 k개의 토큰을 argmax 로 결정하고, 이를 교사인 기존 NTP(Next‑Token Prediction) 모델에 입력해 전체 시퀀스의 로그우도를 계산한다. 그 로그우도를 학생 모델이 출력하는 다중 토큰 분포와 KL‑다이버전스로 최소화함으로써, 학생은 교사가 부여한 높은 확률을 받는 토큰 조합을 스스로 학습한다.
이 접근법은 두 가지 중요한 장점을 가진다. 첫째, 온‑정책(on‑policy) 학습이므로 학생이 실제 생성하는 토큰 시퀀스에 직접 피드백을 제공한다. 이는 전통적인 오프라인 교차 엔트로피 손실이 각 위치의 주변 분포만을 학습하고 토큰 간 상관관계를 무시하는 문제를 극복한다. 둘째, 교사의 출력은 deterministic argmax 로 고정되므로 학습 과정에서 교사와 학생 사이의 불확실성이 최소화되어 안정적인 신호를 제공한다.
또한 논문은 “hard teacher” 변형을 제안한다. 교사의 확률을 delta‑분포(즉, argmax 결과만)로 제한하면 학생의 엔트로피가 자연스럽게 감소하고, 최종적으로는 교사와 동일한 결정 경로를 따르게 된다. 실험에서는 초기 학생‑교사 파라미터를 동일한 사전 학습 체크포인트로 설정해 학습 초기에 손실이 0에 가깝게 시작하도록 설계하였다.
추론 단계에서는 학생 모델의 소프트맥스 샘플링 대신 argmax 를 기본으로 사용한다. 그러나 모든 토큰이 완전히 확정되지 않을 경우를 대비해, 각 토큰의 confidence score 를 계산하고 사전 정의된 임계값(예: 90%) 이상인 경우에만 다중 토큰을 그대로 받아들인다. 이렇게 하면 “쉬운” 토큰은 한 번에 묶어 빠르게 처리하고, 불확실한 토큰은 기존 단일 토큰 디코딩으로 전환해 정확도를 유지한다.
실험 결과는 GSM8K 수학 문제 풀이에서 평균 3.04개의 토큰을 한 번에 생성하며, 전체 디코딩 속도가 2×~5×까지 향상됨을 보여준다. 정확도는 단일 토큰 디코딩 대비 5% 미만 감소했으며, 이는 스페큘레이티브 디코딩이 요구하는 별도 검증 모델 없이도 실용적인 속도‑정확도 균형을 달성한다는 점에서 의미가 크다.
댓글 및 학술 토론
Loading comments...
의견 남기기