공유 프리픽스 재사용을 통한 에이전트 LLM 효율적 트리 학습
초록
에이전트형 대형 언어 모델의 멀티턴 상호작용은 토큰 흐름이 트리 구조를 이루며 많은 공통 프리픽스를 포함한다. 기존 학습 파이프라인은 이러한 트리를 선형화해 각 브랜치를 독립적으로 처리하므로 전·후방 연산이 중복된다. 논문은 “Tree Training”이라는 프레임워크를 제안한다. 핵심은 Gradient Restoration 기법으로 공유 프리픽스에 대한 역전파 그래디언트를 올바르게 합산해 한 번만 계산하도록 하며, Tree Packing을 통해 메모리 제한 하에서도 높은 프리픽스 재사용률을 유지한다. 실험 결과, SFT와 RL 단계 모두에서 6.2배 가량 학습 속도가 향상되었으며 모델 품질은 유지된다.
상세 분석
본 논문은 에이전트형 LLM이 수행하는 복잡한 멀티턴 작업이 자연스럽게 트리 형태의 토큰 궤적을 만든다는 관찰에서 출발한다. 각 턴에서 도구 호출, 서브‑에이전트 호출, 생각 모드 전환 등으로 인해 동일한 초기 컨텍스트가 여러 갈래로 분기되며, 이러한 공통 프리픽스는 전방 연산에서는 캐시를 통해 재사용이 가능하지만 역방향 연산에서는 각 브랜치의 손실이 프리픽스에 역전파되기 때문에 기존 캐시 방식은 메모리와 연산 효율성을 크게 저해한다.
Gradient Restoration은 이 문제를 해결하기 위해 토큰 수준의 보정 항을 도입한다. 구체적으로, 공유 프리픽스 P에 대해 각 브랜치 i가 생성한 손실에 대한 그래디언트 dY₍i₎를 모두 합산한 뒤, P에 대한 최종 그래디언트를 Pᵀ·(∑₍i₎ dY₍i₎) 형태로 재구성한다. 이는 선형 변환 Y = X·W에 대한 역전파 식을 그대로 적용한 것으로, 프리픽스를 한 번만 전방 계산하고도 동일한 파라미터 업데이트를 보장한다.
또한, 트리 구조 데이터를 그대로 받아들일 수 있도록 학습 엔진을 재설계하고, 대규모 트리를 GPU 메모리 제한 내에 적재하기 위해 Tree Packing이라는 DFS 기반 분할 전략을 제시한다. 이 전략은 전체 트리를 가능한 한 큰 서브트리로 묶어 프리픽스 재사용 비율을 최적화하면서도 각 서브트리의 토큰 수가 메모리 한도를 초과하지 않도록 한다. 실험에서는 전체 토큰 수 83k인 트리를 60k 토큰 메모리 한도에 맞게 102k 토큰으로 압축해, 기존 평탄화 방식(164k 토큰) 대비 38% 정도의 메모리 절감과 6.2배 학습 속도 향상을 달성했다.
핵심 기여는 (1) 에이전트 LLM 학습에서 트리형 프리픽스 재사용 가능성을 최초로 정량화, (2) Gradient Restoration과 Tree Packing을 결합한 완전한 트리 학습 파이프라인 구축, (3) 다양한 밀집 모델 및 MoE 모델에 적용해 품질 저하 없이 대규모 속도 개선을 실증한 점이다. 특히 RL 단계에서도 동일한 프레임워크를 적용해 정책 업데이트와 가치 모델 학습 모두에서 효율성을 확보했으며, 이는 향후 에이전트형 LLM의 대규모 파인튜닝 및 지속 학습에 중요한 기반이 될 것으로 기대된다.
댓글 및 학술 토론
Loading comments...
의견 남기기