모듈형 그래디언트 수술로 다중 도메인 RL 방해 극복
초록
본 논문은 대규모 추론 모델(LRM)을 수학, 일반 채팅, 명령 수행 등 이질적인 도메인에 동시에 강화학습(RL)으로 튜닝할 때 발생하는 교차 도메인 간섭을 분석한다. 순차 RL과 혼합 RL 모두 행동·그래디언트 수준에서 큰 충돌을 일으키며, 이를 해결하기 위해 트랜스포머 내부 모듈 단위에서 그래디언트를 정제하는 “모듈형 그래디언트 수술(MGS)”을 제안한다. Llama‑3.1‑8B와 Qwen‑2.5‑7B에 적용한 결과, 평균 4.34.5점(1117% 상대) 향상을 달성했으며, 장기 학습에서도 효과가 유지된다.
상세 분석
본 연구는 대규모 언어 모델을 다중 도메인 추론 작업에 적용하기 위한 강화학습(RL) 전략을 체계적으로 검증한다. 먼저 두 가지 전통적 접근법, 즉 **순차 RL(Sequential RL)**과 **혼합 RL(Mixed RL)**을 실험한다. 순차 RL은 한 도메인을 먼저 학습한 뒤 다른 도메인으로 전이하는 방식인데, 실험 결과 ‘망각(Forgetting)’과 ‘경직(Rigidity)’이라는 두 가지 형태의 모드 간섭이 발생한다. 특히 채팅 도메인에서 수학 도메인으로 전이할 경우, 기존에 학습된 채팅 능력이 급격히 감소하고, 반대로 수학 → 채팅 전이는 상대적으로 완만한 성능 저하를 보인다. 이는 각 도메인의 최적화 목표가 서로 경쟁하면서 파라미터 공간을 서로 다른 방향으로 끌어당기기 때문이다.
혼합 RL은 배치 내에 서로 다른 도메인의 샘플을 동시에 섞어 학습함으로써, 한 번에 여러 도메인의 그래디언트를 적용한다. 그러나 이 경우에도 **그래디언트 충돌(Gradient Conflict)**이 빈번히 발생한다. 특히 수학과 채팅을 1:1 비율로 섞었을 때, 두 도메인 모두 단일 도메인 전문가 수준에 도달하지 못한다. 비율을 극단적으로 조정해도 한쪽 도메인의 성능은 향상되지만, 다른 쪽은 크게 손실된다. 이는 전체 파라미터가 모든 도메인에 대해 동시에 최적화되기 어려운 구조적 한계를 드러낸다.
이러한 문제를 해결하기 위해 저자들은 **모듈형 그래디언트 수술(MGS)**을 제안한다. 트랜스포머는 다층의 MLP, Self‑Attention, Feed‑Forward 등으로 구성된 모듈 구조를 가지고 있으며, 각 모듈은 특정 기능(예: 토큰 관계 파악, 논리 흐름 유지 등)에 특화된다. MGS는 각 모듈별로 도메인별 그래디언트를 수집한 뒤, **프로젝션 기반 그래디언트 수술(Gradient Surgery)**을 적용한다. 구체적으로, 두 도메인 간의 그래디언트가 내적이 음수인 경우(즉, 서로 반대 방향으로 작용) 이를 정규화하여 양의 성분만 남기고, 충돌을 최소화한다. 이렇게 하면 각 모듈은 자신에게 가장 유익한 업데이트만을 받아, 도메인 간 간섭을 모듈 수준에서 억제한다.
실험에서는 Llama‑3.1‑8B와 Qwen‑2.5‑7B 두 모델에 MGS를 적용했으며, 수학, 일반 채팅, 명령 수행 세 도메인에서 모두 평균 4.3~4.5점(16.6%·11.1% 상대) 향상을 기록했다. 특히 장기 학습(2배 에포크)에서도 성능 향상이 지속되었으며, 추가적인 도메인(코드, 창의적 글쓰기)에도 일반화가 확인되었다. 계산 비용 측면에서는 기존 FSDP(Full‑State‑Data‑Parallel) 프레임워크와 결합했을 때 오버헤드가 거의 없으며, 기존 글로벌 그래디언트 수술(Global Gradient Surgery) 대비 효율성이 크게 개선되었다.
핵심 인사이트는 다음과 같다. 1) 다중 도메인 RL에서 발생하는 간섭은 모듈 단위에 국한되는 경우가 많으며, 전체 파라미터에 균등하게 적용할 필요가 없다. 2) 그래디언트 충돌을 **양방향(negative inner product)**으로 정의하고 이를 정제하면, 각 도메인의 특화된 학습 신호를 보존하면서도 상호 간섭을 최소화할 수 있다. 3) MGS는 순차 RL의 ‘모드 간섭’과 혼합 RL의 ‘그래디언트 충돌’ 두 문제를 동시에 해결하는 통합적 접근법으로, 향후 다양한 멀티태스크 RL 시나리오에 적용 가능성이 높다.
댓글 및 학술 토론
Loading comments...
의견 남기기