생각거품: 잠재공간에서 병렬 사고를 구현하는 비지도 방법
초록
Thoughtbubbles는 토큰별로 잔차 스트림을 복제·삭제하는 “포크” 메커니즘을 도입해, 어려운 토큰에 대해 잠재공간에서 병렬 연산을 자동으로 생성한다. 언어 모델링 손실만으로 사전학습이 가능하며, 150M‑1.9B 규모에서 기존 디코더 LM보다 낮은 토큰 예산으로 퍼플렉시티와 제로샷 성능을 향상시킨다.
상세 분석
본 논문은 트랜스포머의 고정된 연산 예산을 넘어서는 적응형 병렬 연산을 잠재공간에서 구현하는 새로운 아키텍처인 Thoughtbubbles를 제안한다. 핵심 아이디어는 “포크(fork)” 연산으로, 각 레이어 사이에 삽입된 포크 레이어가 잔차 스트림마다 두 개의 스코어(keep와 fork)를 계산하고, 누적 스코어(p_cum)를 통해 상위 k개의 스트림만을 유지·복제한다. 이때 스코어는 시그모이드와 로그-스페이스 연산으로 안정성을 확보하며, 원본 토큰은 항상 keep 스코어가 1로 강제되어 최소 하나의 스트림은 보존된다. 포크된 스트림은 레이어별 학습된 포크 임베딩(v′)을 더해 원본과 구분되며, 포크 횟수에 비례해 RoPE 위치 임베딩을 부분적으로 회전시켜 위치 정보를 유지한다.
포크 판단 후에는 누적 스코어 벡터 P(k)를 이용해 어텐션과 잔차 업데이트를 스코어에 따라 감쇠(attenuate)한다. 구체적으로 어텐션 스코어에 log P(k)ᵀ를 더해 낮은 스코어를 가진 스트림이 어텐션에 기여하지 못하도록 하고, V값에도 P(k)를 원소별 곱한다. 이후 MLP와 레이어 정규화에서도 동일하게 스코어를 곱해, 모델이 “삭제될” 스트림에 대한 업데이트를 스스로 억제하도록 학습한다.
출력 단계에서는 각 스트림을 독립적으로 디코딩하고, 누적 스코어를 가중치로 사용해 확률 분포를 가중 평균한다. 원칙적인 구현은 로그-섬-익스프(log‑sum‑exp) 트릭을 사용해 수치적 안정성을 확보하지만, 대규모 모델에서는 연산 비용을 줄이기 위해 스트림 자체를 가중 평균한 후 소프트맥스를 적용하는 근사 방식을 채택한다.
학습은 기존 디코더‑전용 LM과 동일하게 크로스 엔트로피 손실만을 사용하며, 포크 레이어는 초기 몇 개 블록 뒤에 삽입해 충분한 컨텍스트를 확보한다. 실험에서는 1.9B 모델을 40B 토큰(절반 예산)으로 학습했을 때, 퍼플렉시티와 GSM8K, MMLU 등 제로샷 벤치마크에서 기존 비포크 모델보다 일관된 개선을 보였다. 또한 150M‑772M 규모에서도 스코어가 높은 토큰(불확실도 높은 영역)에서 포크가 집중되는 현상을 관찰, 모델이 자체적으로 연산 자원을 할당한다는 해석이 가능하다.
이러한 설계는 사전학습 단계부터 적응형 연산을 학습하게 함으로써, 추론 시에도 별도의 체인‑오브‑쓰앗 프롬프트 없이 내부적으로 병렬 “생각”을 수행한다는 점에서 기존 CoT 기반 방법과 근본적으로 차별화된다.
댓글 및 학술 토론
Loading comments...
의견 남기기