제때 멈추는 토큰 확산 언어 모델 조기 종료
초록
확산 언어 모델에서 토큰별로 예측이 안정화되는 시점을 실시간으로 감지해, 필요 없는 디노이징 단계를 건너뛰는 훈련‑무료 기법 JOT을 제안한다. 토큰 신뢰도(상위 두 확률 비율)와 주변 토큰의 해제 정도를 결합한 동적 임계값을 사용해 각 위치를 독립적으로 얼리 스톱하고, 실험에서 GSM8K·MMLU·HellaSwag·HumanEval 등 4개 벤치마크에 걸쳐 5배~20배 속도 향상을 달성하면서 정확도 손실을 2% 이하로 억제한다.
상세 분석
JOT은 확산 언어 모델(DLM)의 디코딩 과정에서 토큰별 수렴 시점을 정밀하게 추정하는 메커니즘이다. 핵심은 두 단계로 구성된다. 첫 번째는 신뢰도 지표로, 각 마스크된 위치 i에서 모델이 출력한 로짓을 소프트맥스해 얻은 확률 분포 π_i에서 가장 큰 확률 p_i1과 두 번째로 큰 확률 p_i2의 비율 r_i = p_i1/(p_i2+ε)를 계산한다. 이 비율은 상위 후보 간의 확신 정도를 직접 반영하므로, r_i가 클수록 해당 토큰을 조기에 확정해도 오류 위험이 낮다. 두 번째는 공간적 완화이다. 이미 해제된 토큰 주변은 풍부한 양방향 문맥을 제공하므로, 해당 위치의 신뢰도 요구치를 낮출 필요가 있다. 이를 위해 반경 D 내 마스크된 토큰들의 거리 가중합 w_i = Σ_{|i−j|≤D} γ^{|i−j|}를 구하고, 정규화된 가중치 ϕ_i = min(1, w_i / w_max)으로 변환한다. 여기서 γ∈(0,1)은 거리 감쇠율이며, w_max은 이론적 최대 가중치이다.
이후 동적 임계값 τ_i를 τ_max와 τ_min 사이에서 선형 보간한다: τ_i = τ_max − (τ_max−τ_min)·ϕ_i. 즉, ϕ_i가 클수록(주변에 해제된 토큰이 많을수록) τ_i가 낮아져 r_i가 작아도 조기 종료가 허용된다. 디코딩 루프에서 각 단계마다 현재 마스크된 집합 M_n을 확인하고, r_i ≥ τ_i인 토큰을 최종값으로 고정한다. 고정된 토큰은 이후 단계에서 모델 입력에서 마스크 토큰으로 대체되지 않으며, 남은 토큰들은 기존 전통적인 전이 스케줄(예: confidence‑based top‑k unmasking)에 따라 진행된다.
알고리즘 자체는 모델의 순전파 비용 외에 1‑D 컨볼루션을 이용한 가중치 계산만 추가되므로, 연산 오버헤드가 무시할 수준이다. 또한 JOT은 기존의 전이 스케줄을 대체하지 않고 보완하므로, KV‑caching이나 Fast‑dLLM 같은 단계당 지연 감소 기법과도 자연스럽게 결합될 수 있다.
실험에서는 Dream‑7B‑Instruct와 LLaDA‑8B‑Instruct 두 대형 DLM을 대상으로 GSM8K(수학), MMLU(다중 과제), HellaSwag(상식), HumanEval(코드) 네 가지 벤치마크에서 평가하였다. JOT은 Prophet(상위‑2 확신 차이 기반)과 KLASS(KL‑divergence 기반)보다 높은 속도 향상을 보였으며, 특히 HumanEval에서는 19.6배 가속화에도 불구하고 정확도 손실이 0.6%에 불과했다. Ablation 연구에서는 τ_max/τ_min 설정과 γ, D 파라미터가 성능에 미치는 영향을 분석했으며, 공간적 완화 없이 순수 r_i 기반 임계값만 사용할 경우 속도는 비슷하지만 정확도 저하가 더 크게 나타났다.
한계점으로는 현재 1‑차원 거리 기반 가중치만 사용해 문맥 구조를 단순화했으며, 복잡한 문장 구조나 장문에서는 더 정교한 그래프 기반 혹은 트리 구조의 문맥 전파가 필요할 수 있다. 또한 토큰 수준의 조기 종료가 과도하게 일어나면 일관성(예: 코드 블록) 손상이 발생할 위험이 있어, 작업 특성에 맞는 임계값 튜닝이 요구된다. 그럼에도 불구하고 JOT은 훈련‑무료, 하이퍼파라미터가 비교적 직관적이며, 다양한 DLM에 바로 적용 가능한 실용적인 효율화 도구로 평가된다.
댓글 및 학술 토론
Loading comments...
의견 남기기