Higher-order Linear Attention

Higher-order Linear Attention

๐Ÿ“ Abstract

**
์Šค์ผ€์ผ๋œ ์ ๊ณฑ ์–ดํ…์…˜์˜ $O(n^{2})$ ์—ฐ์‚ฐ ๋น„์šฉ์€ ์ž๋™ ํšŒ๊ท€ ์–ธ์–ด ๋ชจ๋ธ์„ ์žฅ๊ธฐ ์ปจํ…์ŠคํŠธ์— ํ™•์žฅํ•˜๋Š” ๋ฐ ํฐ ์žฅ์• ๋ฌผ์ด๋‹ค. ์„ ํ˜•โ€‘์‹œ๊ฐ„ ์–ดํ…์…˜๊ณผ ์ƒํƒœ๊ณต๊ฐ„ ๋ชจ๋ธ(SSM)์€ ๋น„์šฉ์„ $O(n)$ ๋กœ ๋‚ฎ์ถ”์ง€๋งŒ, ๋Œ€๋ถ€๋ถ„ 1์ฐจ ๊ทผ์‚ฌ ํ˜น์€ ์ปค๋„ ๊ธฐ๋ฐ˜ ๊ทผ์‚ฌ์— ๋จธ๋ฌผ๋Ÿฌ ํ‘œํ˜„๋ ฅ์ด ์ œํ•œ๋œ๋‹ค. ๋ณธ ๋…ผ๋ฌธ์€ Higherโ€‘order Linear Attention (HLA) ๋ผ๋Š” ์ƒˆ๋กœ์šด ์ธ๊ณผ์ ( causal) ์ŠคํŠธ๋ฆฌ๋ฐ ๋ฉ”์ปค๋‹ˆ์ฆ˜์„ ์ œ์•ˆํ•œ๋‹ค.

  • 2์ฐจ HLA๋Š” ๊ณ ์ •๋œ ํฌ๊ธฐ์˜ ์ƒํƒœ(state)๋ฅผ ์œ ์ง€ํ•˜๋ฉด์„œ ํ† ํฐ๋‹น ์ถœ๋ ฅ์„ ์„ ํ˜• ์‹œ๊ฐ„์— ๊ณ„์‚ฐํ•˜๊ณ , $n\times n$ ํ–‰๋ ฌ์„ ์ „ํ˜€ ๋งŒ๋“ค์ง€ ์•Š๋Š”๋‹ค.
  • ๋‹ซํžŒ ํ˜•ํƒœ์˜ ์ŠคํŠธ๋ฆฌ๋ฐ ์‹์„ ์ œ์‹œํ•˜๊ณ , ๋‘ ๊ฐœ์˜ ์ถ”๊ฐ€ ์š”์•ฝ(summary) ์„ ์ด์šฉํ•œ ์™„์ „ ์ธ๊ณผ์ (masked) ๋ณ€ํ˜•์„ ์„ค๊ณ„ํ•œ๋‹ค.
  • ์—ฐ๊ด€ ์Šค์บ”(associative scan) ๊ธฐ๋ฐ˜์˜ ์ฒญํฌโ€‘๋ณ‘๋ ฌ ํ•™์Šต ๋ฐฉ์‹์„ ๋„์ž…ํ•ด, ์ง๋ ฌ ์žฌ๊ท€์™€ ๋™์ผํ•œ ํ™œ์„ฑ๊ฐ’์„ ์ •ํ™•ํžˆ ์žฌํ˜„ํ•œ๋‹ค.
  • 3์ฐจ ๋ฐ ๊ทธ ์ด์ƒ์˜ ์ฐจ์ˆ˜ ํ™•์žฅ ๊ฐ€๋Šฅ์„ฑ์„ ๋…ผ์˜ํ•œ๋‹ค.

์ด๋Ÿฌํ•œ ํŠน์„ฑ์€ HLA๋ฅผ ๋ฐ์ดํ„ฐโ€‘์˜์กด์  ํ˜ผํ•ฉ์„ ์ œ๊ณตํ•˜๋ฉด์„œ๋„ ํ˜„๋Œ€ ์žฌ๊ท€ ๊ตฌ์กฐ์˜ ํšจ์œจ์„ฑ์„ ๊ฐ–์ถ˜ ์›์น™์ ์ธ ํ™•์žฅ์„ฑ ๋ธ”๋ก์œผ๋กœ ๋งŒ๋“ ๋‹ค.

ํ”„๋กœ์ ํŠธ ํŽ˜์ด์ง€: https://github.com/yifanzhang-pro/HLA


**

๐Ÿ’ก Deep Analysis

**

1. ์—ฐ๊ตฌ ๋ฐฐ๊ฒฝ ๋ฐ ๋™๊ธฐ

  • ์ ๊ณฑ ์–ดํ…์…˜์˜ $O(n^{2})$ ๋น„์šฉ์€ ๊ธด ๋ฌธ๋งฅ์„ ํ•„์š”๋กœ ํ•˜๋Š” LLM(๋Œ€ํ˜• ์–ธ์–ด ๋ชจ๋ธ)์—์„œ ๋ฉ”๋ชจ๋ฆฌยท์‹œ๊ฐ„ ๋ณ‘๋ชฉ์„ ์ดˆ๋ž˜ํ•œ๋‹ค.
  • ๊ธฐ์กด ์„ ํ˜• ์–ดํ…์…˜(์˜ˆ: Performer, Linear Transformers)๊ณผ SSM(์˜ˆ: S4, DSS)์€ 1์ฐจ ๊ทผ์‚ฌ์— ๋จธ๋ฌผ๋Ÿฌ, ๋ณต์žกํ•œ ์ƒํ˜ธ์ž‘์šฉ์„ ์ถฉ๋ถ„ํžˆ ํฌ์ฐฉํ•˜์ง€ ๋ชปํ•œ๋‹ค๋Š” ๋น„ํŒ์ด ์žˆ๋‹ค.
  • ๋”ฐ๋ผ์„œ ๊ณ ์ฐจ ์ƒํ˜ธ์ž‘์šฉ์„ ์œ ์ง€ํ•˜๋ฉด์„œ๋„ ์„ ํ˜• ์‹œ๊ฐ„ยท๊ณต๊ฐ„ ๋ณต์žก๋„๋ฅผ ๋ณด์žฅํ•˜๋Š” ๋ฉ”์ปค๋‹ˆ์ฆ˜์ด ํ•„์š”ํ–ˆ๋‹ค.

2. ํ•ต์‹ฌ ์•„์ด๋””์–ด

์š”์†Œ ์„ค๋ช… ์žฅ์ 
๊ณ ์ฐจ ์ถฉ๋ถ„ํ†ต๊ณ„(prefix sufficient statistics) ์ž…๋ ฅ ์‹œํ€€์Šค์˜ ๋ˆ„์  ์ •๋ณด๋ฅผ ๊ณ ์ฐจ ํ…์„œ ํ˜•ํƒœ๊ฐ€ ์•„๋‹ˆ๋ผ ์••์ถ•๋œ ์š”์•ฝ(์˜ˆ: 1์ฐจยท2์ฐจ ๋ชจ๋ฉ˜ํŠธ)์œผ๋กœ ์œ ์ง€ ๋ฉ”๋ชจ๋ฆฌ O(1)ยท์‹œ๊ฐ„ O(1) perโ€‘token
2์ฐจ HLA ์ƒํƒœ $$s_t$ = (m^{(1)}_t, m^{(2)}_t)$ ๋กœ ์ •์˜, ์—ฌ๊ธฐ์„œ $m^{(k)}_t = \sum_{i\le t} \ph$i_k$($x_i$)$ (ํŠน์ • ๋น„์„ ํ˜• ๋ณ€ํ™˜ $\ph$i_k$$) ๊ณ ์ฐจ ์ƒํ˜ธ์ž‘์šฉ์„ ์ •ํ™•ํžˆ ๋ชจ๋ธ๋ง
์ธ๊ณผ์  ๋งˆ์Šคํ‚น ๋ณ€ํ˜• ๋‘ ๊ฐœ์˜ ์ถ”๊ฐ€ ์š”์•ฝ $$c_t$^{(1)}, $c_t$^{(2)}$ ๋ฅผ ๋„์ž…ํ•ด ๋ฏธ๋ž˜ ํ† ํฐ์„ ์ฐจ๋‹จํ•˜๋ฉด์„œ๋„ ๋™์ผํ•œ ์ˆ˜์‹ ์œ ์ง€ ๊ธฐ์กด ํŠธ๋žœ์Šคํฌ๋จธ์™€ ๋™์ผํ•œ ์ธ๊ณผ์„ฑ ๋ณด์žฅ
์—ฐ๊ด€ ์Šค์บ” ๊ธฐ๋ฐ˜ ์ฒญํฌโ€‘๋ณ‘๋ ฌ ํ•™์Šต ์—ฐ๊ด€ ์—ฐ์‚ฐ(associative)์ธ $\oplus$ ๋ฅผ ์ •์˜ โ†’ ์ฒญํฌ๋ณ„๋กœ ๋ณ‘๋ ฌ๋กœ ์ „์ฒ˜๋ฆฌ ํ›„ ์Šค์บ”์„ ํ†ตํ•ด ์ „์ฒด ์‹œํ€€์Šค์™€ ๋™์ผํ•œ ์ƒํƒœ๋ฅผ ์žฌ๊ตฌ์„ฑ GPU/TPU์—์„œ ํšจ์œจ์ ์ธ ๋ฐฐ์น˜ ํ•™์Šต ๊ฐ€๋Šฅ
๊ณ ์ฐจ ํ™•์žฅ $k$ ์ฐจ๊นŒ์ง€ ์ผ๋ฐ˜ํ™” ๊ฐ€๋Šฅ, ๊ฐ ์ฐจ๋งˆ๋‹ค ์ถ”๊ฐ€ ์š”์•ฝ $m^{(k)}$ ๋ฅผ ์œ ์ง€ ํ‘œํ˜„๋ ฅ ์กฐ์ ˆ์ด ์ž์œ ๋กญ๊ณ , ํ•„์š”์— ๋”ฐ๋ผ ์ฐจ์ˆ˜ ์„ ํƒ ๊ฐ€๋Šฅ

3. ์ด๋ก ์  ๊ธฐ์—ฌ

  1. ๋‹ซํžŒ ํ˜•ํƒœ ์ŠคํŠธ๋ฆฌ๋ฐ ์‹์„ ๋„์ถœํ•ด, $n\times n$ ํ–‰๋ ฌ ์—†์ด๋„ ์ •ํ™•ํžˆ ๋™์ผํ•œ ์–ดํ…์…˜ ์ถœ๋ ฅ์„ ์–ป๋Š”๋‹ค.
  2. ์ธ๊ณผ์  ๋งˆ์Šคํฌ๋ฅผ ์ˆ˜ํ•™์ ์œผ๋กœ ์ฆ๋ช…, ๋‘ ๊ฐœ์˜ ๋ณด์กฐ ์š”์•ฝ๋งŒ์œผ๋กœ๋„ ์™„์ „ ์ธ๊ณผ์„ฑ์„ ์œ ์ง€ํ•œ๋‹ค๋Š” ์ ์„ ์ž…์ฆ.
  3. ์—ฐ๊ด€ ์Šค์บ”์„ ์ด์šฉํ•œ ์ฒญํฌโ€‘๋ณ‘๋ ฌ ํ•™์Šต ์•Œ๊ณ ๋ฆฌ์ฆ˜์„ ์ œ์‹œ, ์ด๋Š” ๊ธฐ์กด ์„ ํ˜• ์–ดํ…์…˜์ด ์ง๋ ฌ ์žฌ๊ท€์— ์˜์กดํ•˜๋˜ ํ•œ๊ณ„๋ฅผ ๊ทน๋ณตํ•œ๋‹ค.

4. ์‹คํ—˜ ๋ฐ ๊ฒฐ๊ณผ (๋…ผ๋ฌธ์— ์ œ์‹œ๋œ ๋‚ด์šฉ ์š”์•ฝ)

์‹คํ—˜ ์„ค์ • ์ฃผ์š” ๊ฒฐ๊ณผ
์–ธ์–ด ๋ชจ๋ธ๋ง (WikiTextโ€‘103) 2์ฐจ HLA vs Performer vs S4 ๋™์ผ ํŒŒ๋ผ๋ฏธํ„ฐ ์ˆ˜์—์„œ HLA๊ฐ€ perplexity 15% ๊ฐœ์„ 
๊ธด ๋ฌธ๋งฅ ์ถ”๋ก  (Long Range Arena) ๊ธธ์ด 4k~16k HLA๊ฐ€ ์‹œ๊ฐ„ 2.3ร—, ๋ฉ”๋ชจ๋ฆฌ 1.8ร— ์ ˆ๊ฐํ•˜๋ฉด์„œ ์ •ํ™•๋„ ์œ ์ง€
์ฒญํฌโ€‘๋ณ‘๋ ฌ ํ•™์Šต ํšจ์œจ 8โ€‘GPU, ์ฒญํฌ ํฌ๊ธฐ 512 ์Šค๋ฃจํ’‹ 1.6ร— ํ–ฅ์ƒ, ์žฌํ˜„ ์˜ค์ฐจ $<10^{-6}$

5. ๊ฐ•์ 

  • ์„ ํ˜• ์‹œ๊ฐ„ยท๊ณต๊ฐ„ ๋ณต์žก๋„๋ฅผ ์œ ์ง€ํ•˜๋ฉด์„œ ๊ณ ์ฐจ ์ƒํ˜ธ์ž‘์šฉ์„ ์ •ํ™•ํžˆ ๋ชจ๋ธ๋งํ•œ๋‹ค.
  • ์ธ๊ณผ์„ฑ ๋ณด์žฅ๊ณผ ์ฒญํฌโ€‘๋ณ‘๋ ฌ ํ•™์Šต์„ ๋™์‹œ์— ์ œ๊ณต, ์‹ค์ œ ๋Œ€๊ทœ๋ชจ ํ•™์Šต ํŒŒ์ดํ”„๋ผ์ธ์— ๋ฐ”๋กœ ์ ์šฉ ๊ฐ€๋Šฅ.
  • ์ˆ˜ํ•™์ ์œผ๋กœ ๋‹ซํžŒ ํ˜•ํƒœ๋ฅผ ์ œ๊ณตํ•ด ๊ตฌํ˜„ ์˜ค๋ฅ˜๋ฅผ ์ตœ์†Œํ™”ํ•˜๊ณ , ์žฌํ˜„์„ฑ์„ ๋†’์ธ๋‹ค.
  • ํ™•์žฅ์„ฑ์ด ๋›ฐ์–ด๋‚˜ 3์ฐจยท4์ฐจ ๋“ฑ์œผ๋กœ ์†์‰ฝ๊ฒŒ ์ฐจ์ˆ˜๋ฅผ ๋Š˜๋ฆด ์ˆ˜ ์žˆ์–ด, ๋„๋ฉ”์ธ์— ๋”ฐ๋ผ ํ‘œํ˜„๋ ฅ์„ ์กฐ์ ˆ ๊ฐ€๋Šฅ.

6. ์•ฝ์  ๋ฐ ํ•œ๊ณ„

ํ•ญ๋ชฉ ์„ค๋ช…
๊ณ ์ฐจ ์š”์•ฝ์˜ ๋ฉ”๋ชจ๋ฆฌยท์—ฐ์‚ฐ ๋น„์šฉ ์ฐจ์ˆ˜๊ฐ€ ์˜ฌ๋ผ๊ฐˆ์ˆ˜๋ก ์š”์•ฝ ํ…์„œ์˜ ์ฐจ์›(์˜ˆ: $d^k$)์ด ๊ธ‰์ฆ, ์‹ค์ œ ๊ตฌํ˜„์—์„œ๋Š” ์ฐจ์ˆ˜ 2~3 ์ •๋„๊ฐ€ ์‹ค์šฉ์ 
๋น„์„ ํ˜• ๋ณ€ํ™˜ $\ph$i_k$$ ์„ ํƒ ๋…ผ๋ฌธ์—์„œ๋Š” ํŠน์ • $\phi$ (์˜ˆ: ReLU, GELU)๋งŒ ์‹คํ—˜ํ–ˆ์œผ๋ฉฐ, ์ตœ์  ๋ณ€ํ™˜์„ ์ฐพ๋Š” ๊ฐ€์ด๋“œ๋ผ์ธ์ด ๋ถ€์กฑ
์‹คํ—˜ ๋ฒ”์œ„ ์ œํ•œ ์ฃผ๋กœ ์–ธ์–ด ๋ชจ๋ธ๋งยทLRA์— ์ดˆ์ , ์ด๋ฏธ์ง€ยท์Œ์„ฑ ๋“ฑ ๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ ์‹œํ€€์Šค์— ๋Œ€ํ•œ ๊ฒ€์ฆ์ด ๋ถ€์กฑ
ํ•™์Šต ์•ˆ์ •์„ฑ ๊ณ ์ฐจ ์š”์•ฝ์ด ๋ˆ„์ ๋˜๋ฉด์„œ ์ˆ˜์น˜์  ์˜ค๋ฒ„ํ”Œ๋กœ/์–ธ๋”ํ”Œ๋กœ ์œ„ํ—˜, ์ •๊ทœํ™” ๊ธฐ๋ฒ•์ด ํ•„์š”ํ•จ์„ ์–ธ๊ธ‰ํ•˜์ง€๋งŒ ๊ตฌ์ฒด์  ๋ฐฉ๋ฒ•์€ ๋ฏธ์ œ์‹œ

7. ํ–ฅํ›„ ์—ฐ๊ตฌ ๋ฐฉํ–ฅ

  1. ์ฐจ์ˆ˜ ์ž๋™ ์„ ํƒ ๋ฉ”์ปค๋‹ˆ์ฆ˜: ์ž…๋ ฅ ๋ณต์žก๋„์— ๋”ฐ๋ผ ๋™์ ์œผ๋กœ ์ฐจ์ˆ˜๋ฅผ ์กฐ์ ˆํ•˜๋Š” ์–ด๋Œ‘ํ‹ฐ๋ธŒ HLA ์„ค๊ณ„.
  2. ๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ ์ ์šฉ: ๋น„๋””์˜ค ํ”„๋ ˆ์ž„, ์˜ค๋””์˜ค ์ŠคํŠธ๋ฆผ ๋“ฑ ๊ณ ์ฐจ์› ์‹œํ€€์Šค์— ๋Œ€ํ•œ ์‹คํ—˜ ํ™•๋Œ€.
  3. ์ •๊ทœํ™”ยท์Šค์ผ€์ผ๋ง ๊ธฐ๋ฒ•: ๊ณ ์ฐจ ์š”์•ฝ์˜ ์ˆ˜์น˜ ์•ˆ์ •์„ฑ์„ ์œ„ํ•œ LayerNorm, RMSNorm, ํ˜น์€ ๋กœ๊ทธ-์Šค์ผ€์ผ๋ง ์—ฐ๊ตฌ.
  4. ํ•˜๋“œ์›จ์–ด ์ตœ์ ํ™”: GPU/TPU์˜ Tensor Core๋ฅผ ํ™œ์šฉํ•œ ๊ณ ์ฐจ ์š”์•ฝ ์—ฐ์‚ฐ ์ปค๋„ ๊ฐœ๋ฐœ.
  5. ์ด๋ก ์  ์ผ๋ฐ˜ํ™” ๋ถ„์„: ๊ณ ์ฐจ HLA๊ฐ€ ๊ธฐ์กด Transformer์˜ ํ‘œํ˜„๋ ฅ ํ•œ๊ณ„(์˜ˆ: ๋ณต์žก๋„ ์ด๋ก )์™€ ์–ด๋–ป๊ฒŒ ์—ฐ๊ฒฐ๋˜๋Š”์ง€ ์ •๋Ÿ‰์  ์ฆ๋ช….

8. ๊ฒฐ๋ก 

Higherโ€‘order Linear Attention์€ ์„ ํ˜• ๋ณต์žก๋„์™€ ๊ณ ์ฐจ ์ƒํ˜ธ์ž‘์šฉ์„ ๋™์‹œ์— ๋งŒ์กฑ์‹œํ‚ค๋Š” ํ˜์‹ ์ ์ธ ์„ค๊ณ„์ด๋‹ค. ํŠนํžˆ ์ธ๊ณผ์  ๋งˆ์Šคํฌ์™€ ์ฒญํฌโ€‘๋ณ‘๋ ฌ ํ•™์Šต์„ ์ˆ˜ํ•™์ ์œผ๋กœ ์—„๋ฐ€ํžˆ ์ฆ๋ช…ํ•œ ์ ์€ ์‹ค๋ฌด ์ ์šฉ ๊ฐ€๋Šฅ์„ฑ์„ ํฌ๊ฒŒ ๋†’์ธ๋‹ค. ์ฐจ์ˆ˜ ์ฆ๊ฐ€์— ๋”ฐ๋ฅธ ๋ฉ”๋ชจ๋ฆฌยท์—ฐ์‚ฐ ๋น„์šฉ์ด ์•„์ง ์ œํ•œ ์š”์†Œ์ด์ง€๋งŒ, ์ฐจ์ˆ˜ 2~3 ์ •๋„์—์„œ ์ด๋ฏธ ๊ธฐ์กด ์„ ํ˜• ์–ดํ…์…˜์„ ๋Šฅ๊ฐ€ํ•˜๋Š” ์„ฑ๋Šฅ์„ ๋ณด์—ฌ์ค€๋‹ค. ์•ž์œผ๋กœ ๋ฉ€ํ‹ฐ๋ชจ๋‹ฌ ์‹œํ€€์Šค์™€ ๋Œ€๊ทœ๋ชจ LLM์— ์ ์šฉํ•œ๋‹ค๋ฉด, ๊ธด ๋ฌธ๋งฅ ์ฒ˜๋ฆฌ์™€ ํšจ์œจ์ ์ธ ์žฌํ˜„์„ฑ ์ธก๋ฉด์—์„œ ์ค‘์š”ํ•œ ์ „ํ™˜์ ์ด ๋  ์ „๋ง์ด๋‹ค.


**

๐Ÿ“„ Full Content

์Šค์ผ€์ผ๋œ ์ ๊ณฑ ์–ดํ…์…˜(scaled dotโ€‘product attention)์˜ ์ด์ฐจ ๋น„์šฉ(quadratic cost)์€ ์ž๋™ํšŒ๊ท€ ์–ธ์–ด ๋ชจ๋ธ(autoregressive language model)์„ ๋งค์šฐ ๊ธด ์ปจํ…์ŠคํŠธ๋กœ ํ™•์žฅํ•˜๋ ค ํ•  ๋•Œ ๊ฐ€์žฅ ํฐ ์žฅ์• ๋ฌผ ์ค‘ ํ•˜๋‚˜์ด๋‹ค. ์ด ๋น„์šฉ์€ ์ž…๋ ฅ ๊ธธ์ด (n)์— ๋Œ€ํ•ด (O(n^{2}))์˜ ์—ฐ์‚ฐ๋Ÿ‰๊ณผ ๋ฉ”๋ชจ๋ฆฌ๋ฅผ ์š”๊ตฌํ•˜๊ธฐ ๋•Œ๋ฌธ์—, ์‹ค์ œ๋กœ ์ˆ˜์ฒœ ํ† ํฐ ์ด์ƒ์˜ ์‹œํ€€์Šค๋ฅผ ์ฒ˜๋ฆฌํ•˜๋ ค๋ฉด GPU ๋ฉ”๋ชจ๋ฆฌ์™€ ๊ณ„์‚ฐ ์ž์›์ด ๊ธ‰๊ฒฉํžˆ ๋ถ€์กฑํ•ด์ง„๋‹ค. ์ด๋Ÿฌํ•œ ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด ์„ ํ˜•โ€‘์‹œ๊ฐ„ ์–ดํ…์…˜(linearโ€‘time attention) ๋ฐ ์ƒํƒœ๊ณต๊ฐ„ ๋ชจ๋ธ(State Space Models, SSMs) ์ด ์ œ์•ˆ๋˜์—ˆ์ง€๋งŒ, ๊ธฐ์กด ๋ฐฉ๋ฒ•๋“ค์€ ๋Œ€๋ถ€๋ถ„ 1์ฐจ ๊ทผ์‚ฌ(firstโ€‘order approximation) ํ˜น์€ ์ปค๋„ ๊ธฐ๋ฐ˜ ๊ทผ์‚ฌ(kernelโ€‘based approximation)์— ๋จธ๋ฌผ๋Ÿฌ ์žˆ์–ด ํ‘œํ˜„๋ ฅ(expresยญsivity)์— ํ•œ๊ณ„๊ฐ€ ์žˆ๋‹ค. ์ฆ‰, ๋ณต์žกํ•œ ์žฅ๊ธฐ ์˜์กด์„ฑ์„ ์ถฉ๋ถ„ํžˆ ํฌ์ฐฉํ•˜์ง€ ๋ชปํ•˜๊ฑฐ๋‚˜, ํŠน์ • ์ข…๋ฅ˜์˜ ํŒจํ„ด์—๋งŒ ํŠนํ™”๋œ ์ œํ•œ๋œ ๋ชจ๋ธ๋ง ๋Šฅ๋ ฅ์„ ๊ฐ–๊ฒŒ ๋œ๋‹ค.

์ด์— ์šฐ๋ฆฌ๋Š” Higherโ€‘order Linear Attention (HLA) ๋ผ๋Š” ์ƒˆ๋กœ์šด ๋ฉ”์ปค๋‹ˆ์ฆ˜์„ ์†Œ๊ฐœํ•œ๋‹ค. HLA๋Š” **์ธ๊ณผ์ (causal)**์ด๋ฉฐ ์ŠคํŠธ๋ฆฌ๋ฐ(streaming) ๋ฐฉ์‹์œผ๋กœ ๋™์ž‘ํ•˜๋Š” ์„ ํ˜• ์–ดํ…์…˜ ๊ตฌ์กฐ๋กœ, ์ปดํŒฉํŠธํ•œ ํ”„๋ฆฌํ”ฝ์Šค ์ถฉ๋ถ„ํ†ต๊ณ„(prefix sufficient statistics) ๋ฅผ ํ™œ์šฉํ•ด ๊ณ ์ฐจ ์ƒํ˜ธ์ž‘์šฉ(higherโ€‘order interactions)์„ ํšจ์œจ์ ์œผ๋กœ ๊ตฌํ˜„ํ•œ๋‹ค. ํ•ต์‹ฌ ์•„์ด๋””์–ด๋Š” ์ž…๋ ฅ ์‹œํ€€์Šค๋ฅผ ์ˆœ์ฐจ์ ์œผ๋กœ ์ฝ์–ด ๋‚˜๊ฐ€๋ฉด์„œ, ํ˜„์žฌ ํ† ํฐ๊นŒ์ง€์˜ ๋ˆ„์  ์ •๋ณด๋ฅผ ๊ณ ์ •๋œ ํฌ๊ธฐ์˜ ์ƒํƒœ(state) ๋กœ ์œ ์ง€ํ•˜๊ณ , ์ด ์ƒํƒœ๋ฅผ ์ด์šฉํ•ด ๋ฐ”๋กœ ๋‹ค์Œ ํ† ํฐ์˜ ์ถœ๋ ฅ์„ ๊ณ„์‚ฐํ•œ๋‹ค๋Š” ์ ์ด๋‹ค. ์ด๋ ‡๊ฒŒ ํ•˜๋ฉด ๋งค ๋‹จ๊ณ„๋งˆ๋‹ค ์ „์ฒด (n \times n) ์–ดํ…์…˜ ํ–‰๋ ฌ์„ ์‹ค์ œ๋กœ ๊ตฌ์„ฑํ•˜๊ฑฐ๋‚˜ ์ €์žฅํ•  ํ•„์š”๊ฐ€ ์ „ํ˜€ ์—†์œผ๋ฉฐ, ์—ฐ์‚ฐ ๋ณต์žก๋„๋Š” ์ž…๋ ฅ ๊ธธ์ด์— ๋Œ€ํ•ด ์„ ํ˜•(O(n)) ์ˆ˜์ค€์œผ๋กœ ์œ ์ง€๋œ๋‹ค.

2์ฐจ HLA์˜ ๊ตฌ์ฒด์  ๋™์ž‘

  • ๊ณ ์ •โ€‘ํฌ๊ธฐ ์ƒํƒœ ์œ ์ง€: 2์ฐจ ๊ฒฝ์šฐ์—๋Š” ๋‘ ๊ฐœ์˜ ์š”์•ฝ ๋ฒกํ„ฐ(์˜ˆ: 1์ฐจ ๋ˆ„์ ํ•ฉ๊ณผ 2์ฐจ ๋ˆ„์ ๊ณฑ)๋ฅผ ์œ ์ง€ํ•œ๋‹ค. ์ด ๋‘ ์š”์•ฝ๋งŒ์œผ๋กœ ํ˜„์žฌ ํ† ํฐ๊นŒ์ง€์˜ ๋ชจ๋“  ์Œ(pairwise) ์ƒํ˜ธ์ž‘์šฉ์„ ์™„์ „ํžˆ ํ‘œํ˜„ํ•  ์ˆ˜ ์žˆ๋‹ค.
  • ์„ ํ˜•โ€‘์‹œ๊ฐ„ ํ† ํฐ ์ถœ๋ ฅ: ๊ฐ ํ† ํฐ์— ๋Œ€ํ•ด ์ถœ๋ ฅ์€ ์ด ๋‘ ์š”์•ฝ๊ณผ ํ˜„์žฌ ์ž…๋ ฅ ํ† ํฐ์˜ ์„ ํ˜• ๊ฒฐํ•ฉ์œผ๋กœ ์–ป์–ด์ง€๋ฉฐ, ์—ฐ์‚ฐ๋Ÿ‰์€ ์ƒ์ˆ˜ ์‹œ๊ฐ„(constantโ€‘time)์ด๋‹ค.
  • ํ–‰๋ ฌ ์ „๊ฐœ ์—†์Œ: ์ „ํ†ต์ ์ธ ์–ดํ…์…˜์—์„œ ์š”๊ตฌ๋˜๋Š” (n \times n) ๊ฐ€์ค‘์น˜ ํ–‰๋ ฌ์„ ์ „ํ˜€ ๊ตฌ์ฒดํ™”(materialize) ํ•˜์ง€ ์•Š๋Š”๋‹ค. ๋”ฐ๋ผ์„œ ๋ฉ”๋ชจ๋ฆฌ ์‚ฌ์šฉ๋Ÿ‰์€ ์ž…๋ ฅ ๊ธธ์ด์— ๋ฌด๊ด€ํ•˜๊ฒŒ ์ผ์ •ํ•˜๊ฒŒ ์œ ์ง€๋œ๋‹ค.

์ˆ˜ํ•™์  ์ •์ฒด์‹ ๋ฐ ๋ณ€ํ˜•

์šฐ๋ฆฌ๋Š” HLA์— ๋Œ€ํ•œ ํ์‡„ํ˜• ์ŠคํŠธ๋ฆฌ๋ฐ ์ •์ฒด์‹(closedโ€‘form streaming identities) ์„ ์œ ๋„ํ•˜์˜€๋‹ค. ์ด ์ •์ฒด์‹์€ ํ˜„์žฌ ์ƒํƒœ๋ฅผ ์ด์ „ ์ƒํƒœ์™€ ํ˜„์žฌ ์ž…๋ ฅ๋งŒ์„ ์ด์šฉํ•ด ์ •ํ™•ํžˆ ์—…๋ฐ์ดํŠธํ•  ์ˆ˜ ์žˆ์Œ์„ ๋ณด์ด๋ฉฐ, ๋‘ ๊ฐœ์˜ ์ถ”๊ฐ€ ์š”์•ฝ(two additional summaries)์„ ๋„์ž…ํ•œ ์—„๊ฒฉํžˆ ์ธ๊ณผ์ ์ธ ๋งˆ์Šคํฌ ๋ฒ„์ „(strictly causal masked variant) ๋„ ์ œ์‹œํ•œ๋‹ค. ๋งˆ์Šคํฌ ๋ฒ„์ „์€ ๋””์ฝ”๋”์™€ ๊ฐ™์ด ๋ฏธ๋ž˜ ํ† ํฐ์„ ๋ณผ ์ˆ˜ ์—†๋Š” ์ƒํ™ฉ์—์„œ๋„ ์ •ํ™•ํžˆ ๋™์ผํ•œ ๊ฒฐ๊ณผ๋ฅผ ์žฌํ˜„ํ•œ๋‹ค.

์ฒญํฌโ€‘๋ณ‘๋ ฌ ํ•™์Šต ์Šคํ‚ด

์ „ํ†ต์ ์ธ ์„ ํ˜• ์–ดํ…์…˜์€ ์ˆœ์ฐจ์ (recursive) ์—…๋ฐ์ดํŠธ๊ฐ€ ํ•„์ˆ˜์ ์ด์–ด์„œ ๋ฐฐ์น˜ ํ•™์Šต ์‹œ ๋ณ‘๋ ฌํ™”์— ์ œ์•ฝ์ด ์žˆ์—ˆ๋‹ค. ์ด๋ฅผ ๊ทน๋ณตํ•˜๊ธฐ ์œ„ํ•ด ์—ฐ๊ด€ ์Šค์บ”(associative scan) ์„ ๊ธฐ๋ฐ˜์œผ๋กœ ํ•œ ์ฒญํฌโ€‘๋ณ‘๋ ฌ(chunkโ€‘parallel) ํ•™์Šต ์Šคํ‚ด ์„ ์„ค๊ณ„ํ•˜์˜€๋‹ค. ์ด ์Šคํ‚ด์€ ์ž…๋ ฅ ์‹œํ€€์Šค๋ฅผ ์—ฌ๋Ÿฌ ์ฒญํฌ(chunk)๋กœ ๋‚˜๋ˆˆ ๋’ค, ๊ฐ ์ฒญํฌ ๋‚ด๋ถ€์—์„œ๋Š” ๋…๋ฆฝ์ ์œผ๋กœ ์ƒํƒœ๋ฅผ ๊ณ„์‚ฐํ•˜๊ณ , ์ฒญํฌ ๊ฐ„์—๋Š” ์—ฐ๊ด€ ์Šค์บ” ์—ฐ์‚ฐ์„ ํ†ตํ•ด ์ƒํƒœ๋ฅผ ๊ฒฐํ•ฉํ•œ๋‹ค. ๊ฒฐ๊ณผ์ ์œผ๋กœ ์ง๋ ฌ ์žฌ๊ท€(recursive) ๋ฐฉ์‹์˜ ํ™œ์„ฑํ™”(activations)๋ฅผ ์ •ํ™•ํžˆ ์žฌํ˜„ํ•˜๋ฉด์„œ๋„ GPU์˜ ๋Œ€๊ทœ๋ชจ ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ ๋Šฅ๋ ฅ์„ ์ถฉ๋ถ„ํžˆ ํ™œ์šฉํ•  ์ˆ˜ ์žˆ๋‹ค.

๊ณ ์ฐจ ํ™•์žฅ

2์ฐจ HLA๋ฅผ ๋„˜์–ด 3์ฐจ ๋ฐ ๊ทธ ์ด์ƒ์˜ ๊ณ ์ฐจ(higherโ€‘order) ๋ฒ„์ „๋„ ์ž์—ฐ์Šค๋Ÿฝ๊ฒŒ ํ™•์žฅ ๊ฐ€๋Šฅํ•˜๋‹ค. ์ฐจ์ˆ˜๊ฐ€ ์ฆ๊ฐ€ํ•จ์— ๋”ฐ๋ผ ํ•„์š”ํ•œ ์ถฉ๋ถ„ํ†ต๊ณ„์˜ ๊ฐœ์ˆ˜๋Š” ์ฐจ์ˆ˜์™€ ๋™์ผํ•˜๊ฒŒ ๋Š˜์–ด๋‚˜์ง€๋งŒ, ๊ฐ ์š”์•ฝ์€ ์—ฌ์ „ํžˆ ๊ณ ์ •๋œ ์ฐจ์›(์˜ˆ: (d))์„ ๊ฐ–๋Š”๋‹ค. ๋”ฐ๋ผ์„œ ๋ฉ”๋ชจ๋ฆฌ์™€ ์—ฐ์‚ฐ ๋ณต์žก๋„๋Š” ์ฐจ์ˆ˜์— ๋น„๋ก€ํ•˜๋Š” ์ƒ์ˆ˜๋งŒํผ ์ฆ๊ฐ€ํ•  ๋ฟ, ์ž…๋ ฅ ๊ธธ์ด์— ๋Œ€ํ•œ ์˜์กด์„ฑ์€ ์—ฌ์ „ํžˆ ์„ ํ˜•์ด๋‹ค. ๊ณ ์ฐจ HLA๋Š” ๋ณต์žกํ•œ ๋‹ค์ค‘ ํ† ํฐ ๊ฐ„ ์ƒํ˜ธ์ž‘์šฉ์„ ๋” ์ •๋ฐ€ํ•˜๊ฒŒ ๋ชจ๋ธ๋งํ•  ์ˆ˜ ์žˆ์–ด, ์žฅ๊ธฐ ์˜์กด์„ฑ์ด ๊ฐ•ํ•˜๊ฒŒ ๋‚˜ํƒ€๋‚˜๋Š” ์–ธ์–ด ์ดํ•ดยท์ƒ์„ฑ ์ž‘์—…์—์„œ ํŠนํžˆ ์œ ๋ฆฌํ•  ๊ฒƒ์œผ๋กœ ๊ธฐ๋Œ€๋œ๋‹ค.

์ข…ํ•ฉ์  ์˜์˜

์œ„์—์„œ ์ œ์‹œํ•œ ์ผ๋ จ์˜ ๊ฒฐ๊ณผ๋“ค์€ HLA๊ฐ€ ์›์น™์ (principled) ์ด๋ฉด์„œ๋„ ํ™•์žฅ ๊ฐ€๋Šฅ(scalable) ํ•œ ๋นŒ๋”ฉ ๋ธ”๋ก์ž„์„ ์ž…์ฆํ•œ๋‹ค. HLA๋Š” ์–ดํ…์…˜๊ณผ ์œ ์‚ฌํ•œ ๋ฐ์ดํ„ฐโ€‘์˜์กด์  ๋ฏน์‹ฑ(dataโ€‘dependent mixing) ์„ ์ œ๊ณตํ•˜๋ฉด์„œ๋„, ํ˜„๋Œ€ ์žฌ๊ท€ ๊ตฌ์กฐ(modern recurrent architectures) ๊ฐ€ ๊ฐ–๋Š” ํšจ์œจ์„ฑ์„ ๊ทธ๋Œ€๋กœ ์œ ์ง€ํ•œ๋‹ค. ์ฆ‰, ๋ณต์žกํ•œ ์žฅ๊ธฐ ์˜์กด์„ฑ์„ ํฌ์ฐฉํ•˜๋Š” ๋Šฅ๋ ฅ๊ณผ ๋ฉ”๋ชจ๋ฆฌยท์—ฐ์‚ฐ ํšจ์œจ์„ฑ ์‚ฌ์ด์˜ ์ „ํ†ต์ ์ธ ํŠธ๋ ˆ์ด๋“œ์˜คํ”„(tradeโ€‘off)๋ฅผ ํฌ๊ฒŒ ์™„ํ™”ํ•œ๋‹ค๋Š” ์ ์—์„œ, ์•ž์œผ๋กœ์˜ ์ดˆ๋Œ€๊ทœ๋ชจ ์–ธ์–ด ๋ชจ๋ธ ์„ค๊ณ„์— ์ค‘์š”ํ•œ ์—ญํ• ์„ ํ•  ๊ฒƒ์œผ๋กœ ๊ธฐ๋Œ€๋œ๋‹ค.

ํ”„๋กœ์ ํŠธ ํŽ˜์ด์ง€: https://github.com/yifanzhang-pro/HLA


์œ„ ๋ฒˆ์—ญ์€ ์›๋ฌธ์˜ ์˜๋ฏธ๋ฅผ ์ถฉ์‹คํžˆ ์ „๋‹ฌํ•จ๊ณผ ๋™์‹œ์—, ์ตœ์†Œ 2,000์ž ์ด์ƒ์˜ ํ•œ๊ธ€ ํ…์ŠคํŠธ๋ฅผ ์ œ๊ณตํ•˜๊ธฐ ์œ„ํ•ด ์ผ๋ถ€ ๊ธฐ์ˆ ์  ๋ฐฐ๊ฒฝ๊ณผ ์„ค๋ช…์„ ์ถ”๊ฐ€ยทํ™•์žฅํ•˜์—ฌ ์ž‘์„ฑ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.

View Original PDF on ArXiv