시각 프롬프트를 활용한 재귀 하이퍼네트워크 기반 구조적 프루닝 PASS

시각 프롬프트를 활용한 재귀 하이퍼네트워크 기반 구조적 프루닝 PASS
안내: 본 포스트의 한글 요약 및 분석 리포트는 AI 기술을 통해 자동 생성되었습니다. 정보의 정확성을 위해 하단의 [원본 논문 뷰어] 또는 ArXiv 원문을 반드시 참조하시기 바랍니다.

초록

PASS는 시각 프롬프트와 레이어 가중치 통계를 입력으로 받아, LSTM 기반 재귀 하이퍼네트워크가 각 레이어의 채널 마스크를 순차적으로 생성하도록 설계된 구조적 프루닝 프레임워크이다. 채널 의존성을 고려하면서 시각 프롬프트가 제공하는 데이터‑중심 정보를 활용해 높은 정확도와 FLOPs 감소를 동시에 달성한다. CIFAR‑10/100, Tiny‑ImageNet, Food101 등 6개 데이터셋과 ResNet·VGG·ViT·Swin 등 7개 모델에서 기존 프루닝 기법 대비 1‑3% 정확도 향상 또는 0.35× 이상의 속도 향상을 기록하였다.

상세 분석

PASS는 기존 구조적 프루닝이 주로 모델‑중심적인 가중치 통계(예: L1/L2 노름, BN 스케일)만을 활용해 채널 중요도를 추정하고, 레이어 간 의존성을 무시하는 한계를 극복하고자 한다. 핵심 아이디어는 ‘시각 프롬프트’라는 외부 입력을 통해 데이터‑중심적인 신호를 모델에 주입함으로써, 채널 선택 과정에 추가적인 힌트를 제공하는 것이다. 이를 위해 저자들은 두 단계의 모듈을 설계한다. 첫 번째는 3‑layer CNN으로 구성된 프롬프트 인코더(g ω)이며, 입력 이미지에 더해지는 프롬프트 V를 임베딩 벡터로 변환해 LSTM의 초기 hidden state로 사용한다. 두 번째는 LSTM 기반 하이퍼네트워크(θ)로, 이전 레이어의 마스크 M(i‑1)와 현재 레이어의 가중치 통계 eW(i) 를 결합해 현재 레이어의 마스크 M(i)를 예측한다. 여기서 eW(i)=M(i‑1)⊗W(i) 로, 이미 프루닝된 입력 채널 정보를 반영한다. LSTM은 순차적 의존성을 자연스럽게 모델링하므로, 한 레이어에서 제거된 채널이 다음 레이어에 미치는 영향을 학습한다. 마스크 생성 과정은 두 단계로 이루어진다. (1) LSTM 출력 임베딩을 레이어별 선형 변환기로 매핑해 채널별 중요 점수를 얻고, (2) Straight‑Through Estimator를 이용해 이 점수를 이진 마스크로 양자화한다. 전체 네트워크는 손실 L(Φ_{bW}(x+V), y) 를 최소화하도록 공동 최적화되며, 여기서 Φ_{bW}는 프루닝된 가중치 bW를 가진 원본 CNN이다. 프루닝 단계와 이후 파인튜닝 단계에서 모두 시각 프롬프트 V를 학습 파라미터로 포함시켜, 프루닝 후에도 프롬프트가 모델 성능을 보정하도록 설계했다. 실험에서는 ResNet‑18/34/50, VGG‑16, ResNeXt‑50, ViT‑B/16, Swin‑T 등 다양한 아키텍처와 CIFAR‑10/100, Tiny‑ImageNet, Food101, DTD, StanfordCars 등 6개 데이터셋을 대상으로 비교하였다. PASS는 동일 FLOPs 수준에서 기존 방법 대비 1‑3% 정확도 향상을 보였으며, 80% 정확도 목표에서는 0.35배 이상의 속도 향상을 달성했다. 특히, 학습된 마스크와 하이퍼네트워크가 다른 데이터셋이나 모델에 전이될 때도 성능 저하가 적어, 마스크와 프롬프트가 일반화 가능한 구조적 정보를 담고 있음을 확인했다. 전체적으로 PASS는 데이터‑중심 프롬프트와 모델‑중심 가중치 통계를 효과적으로 결합하고, 레이어 간 의존성을 재귀적으로 학습함으로써 기존 프루닝 기법보다 더 정교하고 효율적인 채널 선택을 가능하게 한다.


댓글 및 학술 토론

Loading comments...

의견 남기기