SOCKET: 긴 문맥을 위한 부드러운 충돌 커널 기반 희소 어텐션
초록
SOCKET은 전통적인 하드 LSH 대신 소프트 LSH를 이용해 토큰 간 유사도를 연속적인 충돌 확률로 변환하고, 이를 토대로 토큰을 순위 매겨 희소 어텐션을 구현한다. 키와 쿼리를 각각 하나의 버킷에 할당하되, 쿼리는 모든 버킷에 확률적으로 분포시켜 점수를 산출한다. 이 점수와 값 벡터의 L2 노름을 곱해 top‑k 토큰을 선택하고, 선택된 토큰에 대해 정확한 어텐션을 수행한다. CUDA 기반 스코어링 커널과 FlashDecode Triton 백엔드를 결합해 기존 FlashAttention 대비 1.5배 높은 처리량을 달성했으며, 다양한 모델·벤치마크에서 기존 희소 어텐션 기법들을 능가한다.
상세 분석
SOCKET은 긴 컨텍스트 추론에서 어텐션 연산이 차지하는 비용을 크게 낮추기 위해, LSH를 “후보 생성”이 아닌 “스코어링” 단계로 재정의한다. 기존 하드 LSH는 임의의 랜덤 프로젝션 후 부호 함수를 적용해 이진 버킷을 만들고, 쿼리와 키가 동일 버킷에 들어가면 충돌(1) 아니면 비충돌(0)으로 점수를 매긴다. 이 방식은 충돌 여부가 이산적이기 때문에 미세한 유사도 차이를 구분하지 못하고, top‑k 토큰을 정확히 랭킹하기에 부적합하다. SOCKET은 이를 보완하기 위해 “소프트 버킷 확률”을 도입한다. 각 해시 테이블마다 쿼리 벡터를 정규화·tanh 변환한 뒤, 미리 정의된 코너 벡터와 내적을 취해 로그잇을 만든다. 이를 softmax(온도 τ)로 정규화하면, 쿼리가 각 버킷에 할당될 확률 분포 p⁽ℓ⁾(r|q)가 얻어진다. 키는 여전히 하나의 버킷에 고정되지만, 쿼리의 확률 질량이 해당 버킷에 집중될수록 해당 키는 높은 s_soft(k_j,q)=∑_ℓ p⁽ℓ⁾(b⁽ℓ⁾_j|q) 점수를 받는다. 이 점수에 값 벡터의 L2 노름 ‖v_j‖²를 가중하면 최종 선택 기준 b_w_j=∑_ℓ p⁽ℓ⁾(b⁽ℓ⁾_j|q)·‖v_j‖²가 된다. 이렇게 하면 전체 키를 읽지 않고도 “soft count”와 ‖v‖²만으로 중요한 토큰을 추정할 수 있어 메모리 트래픽이 크게 감소한다.
이론적으로 SOCKET은 softmax 커널에 근접한 커널을 근사한다는 정리(논문 Theorem 3)를 제공한다. 즉, 소프트 LSH 점수는 exp(b_w_j) 형태의 가중치를 만들고, 이를 정규화하면 원래 어텐션의 확률 분포와 높은 상관관계를 가진다. 따라서 선택된 top‑k 토큰이 실제 어텐션 질량을 크게 차지함을 보장한다.
시스템 구현 측면에서는 키 해시를 사전 단계(pre‑fill)에서 한 번만 수행하고, GPU 메모리에 버킷 ID와 정규화된 ‖v‖²만 저장한다. 디코딩 시에는 쿼리당 L개의 soft bucket 확률을 계산하고, 모든 키에 대해 b_w_j를 한 번에 구한 뒤 top‑k를 추출한다. 이 과정은 CUDA 커널 하나로 구현돼 메모리 접근을 최소화하고, 이후 선택된 토큰에 대해서는 FlashDecode Triton 백엔드가 제공하는 고성능 희소 어텐션 연산을 적용한다. 실험 결과, Llama‑3.1‑8B‑Instruct와 Qwen3‑8B 모델을 32K~128K 토큰 길이에서 테스트했을 때, 기존 하드 LSH 기반 방법, k‑means 기반 PQCache, 그리고 최신 희소 어텐션 기법(Quest, Double‑Sparsity 등)보다 정확도(NDCG, Jaccard)와 처리량 모두 우수했다. 특히 FlashAttention 대비 1.5×의 스루풋 향상을 기록했으며, 메모리 사용량도 KV 캐시 외에 몇 비트 정도만 추가로 요구한다.
요약하면, SOCKET은 (1) 데이터‑의존적 클러스터링 없이 랜덤 프로젝션만으로 빠른 인덱스 구축, (2) 소프트 충돌 확률을 이용한 안정적인 토큰 순위 매김, (3) 이론적 커널 근사 보장, (4) CUDA‑최적화와 Triton‑기반 백엔드 결합을 통한 실시간 추론 가속이라는 네 가지 핵심 장점을 제공한다.
댓글 및 학술 토론
Loading comments...
의견 남기기