하나의 사이클로 구조적 프루닝을 구현하는 안정성 기반 서브네트워크 탐색
초록
본 논문은 사전 학습‑프루닝‑미세조정의 다단계 과정을 하나의 학습 사이클로 통합한 OCSPruner 프레임워크를 제안한다. 그룹 ‑L2 정규화 기반의 그룹 살리언시와 점진적 구조적 스파시티 정규화를 이용해 초기 학습 단계에서 최적 서브네트워크를 탐색하고, 연속 에포크 간 서브네트워크 유사도를 Jaccard 지표로 측정해 안정적인 프루닝 시점을 자동으로 결정한다. CIFAR‑10/100, ImageNet에서 VGG, ResNet, MobileNet에 적용해 기존 방법 대비 동일 혹은 향상된 정확도와 1.3‑1.5배 빠른 학습 속도를 달성하였다.
상세 분석
OCSPruner는 구조적 프루닝을 한 번의 학습 사이클에 녹여내는 효율적인 파이프라인을 설계했다. 핵심 아이디어는 “프루닝 안정성 지표”를 도입해, 매 에포크마다 현재 파라미터 그룹을 L2‑norm 기반으로 정규화하고, 이를 전역적으로 정렬해 임시 프루닝 마스크를 만든 뒤, 연속된 두 에포크에서 얻은 마스크 간 Jaccard 유사도를 계산한다는 점이다. 이 유사도가 사전에 정의한 임계값(1‑ε) 이상으로 수렴하면 해당 에포크를 “안정적인 프루닝 시점(t*)”으로 판단하고, 그 시점에 최종 마스크를 고정한다.
프루닝 전 과정에 구조적 스파시티 정규화(Structured Sparsity Regularization)를 적용한다. 그룹 살리언시 S(g)는 각 파라미터 w∈g의 L2‑norm을 그룹 내 원소 수와 전체 그룹 수로 정규화한 값이며, 이는 서로 다른 크기의 그룹 간 공정한 비교를 가능하게 한다. 정규화 강도 λ_t는 에포크가 진행될수록 점진적으로 증가하도록 설계돼, 중요도가 낮은 그룹을 점차 0에 수렴시키면서 자연스럽게 스파시티를 학습한다.
알고리즘 흐름은 다음과 같다. (1) 무작위 초기화된 모델을 몇 에포크 학습한다. (2) 매 에포크마다 파라미터를 구조적 그룹으로 분할하고, Eq.(2)의 살리언시를 계산한다. (3) 목표 프루닝 비율 α에 맞춰 전역적으로 마스크를 생성하고, 임시 프루닝된 서브네트워크 M_t를 만든다. (4) Jaccard 기반 안정성 점수 J_t^avg를 구해 연속 r 에포크 평균을 업데이트한다. (5) J_t^avg 변화가 τ 이하이면 스파시티 학습 단계에 진입하고, 이후 J_t^avg가 1‑ε에 도달하면 t를 확정한다. (6) t에서 최종 마스크를 고정하고, 남은 에포크를 일반 학습으로 진행한다.
이 설계는 기존의 “프루닝‑사전학습‑미세조정” 파이프라인이 요구하던 다중 재학습 비용을 크게 절감한다. 특히, 프루닝 초기화를 무작위가 아닌 “학습 초기에 형성되는 안정적인 서브네트워크”에 기반함으로써 초기 프루닝이 초래하는 성능 저하를 방지한다. 또한, 그룹 기반 살리언시와 정규화가 결합돼 채널, 필터, 레이어 수준의 구조적 제약을 동시에 만족시키며, FLOPs, 메모리, 레이턴시 등 다양한 하드웨어 제약에 쉽게 확장 가능하다.
실험 결과는 VGG‑16, ResNet‑50, MobileNet‑V2 등 다양한 아키텍처에 대해 CIFAR‑10/100, ImageNet 데이터셋에서 검증되었다. ResNet‑50 기준 ImageNet에서 57 % FLOPs 감소와 1.38× 학습 가속을 달성하면서 Top‑1 75.49 %, Top‑5 92.63 %라는 SOTA 수준의 정확도를 기록했다. 동일한 프루닝 비율에서 기존 PaT, LAASP, OTO 등과 비교했을 때, 정확도 손실이 거의 없으며 학습 시간은 평균 30 % 이상 단축되었다.
한계점으로는 안정성 임계값 τ와 ε를 데이터·모델에 따라 튜닝해야 하는 점, 그리고 그룹 정의가 네트워크 구조에 따라 달라질 수 있어 자동화된 그룹화 방법이 추가 연구가 필요하다는 점을 들 수 있다. 그러나 전체적인 설계는 단일 사이클 내에서 프루닝과 학습을 동시에 수행함으로써 실시간·임베디드 환경에 적합한 경량화 전략을 제공한다는 점에서 큰 의의를 가진다.
댓글 및 학술 토론
Loading comments...
의견 남기기