NLP 07 - A Look at the Transformer Model from 'Attention Is All You Need'

BoostCamp AI Tech

NLP

Natural Language Processing

Attention

Self-Attention

Transformer

Positional Encoding

Multi-Headed Attention

Block Based Model

Encoder

Decoder

Masked Self-Attention

02/18/2021


๋ณธ ์ •๋ฆฌ ๋‚ด์šฉ์€ Naver BoostCamp AI Tech์˜ edwith์—์„œ ํ•™์Šตํ•œ ๋‚ด์šฉ์„ ์ •๋ฆฌํ•œ ๊ฒƒ์ž…๋‹ˆ๋‹ค.
์‚ฌ์‹ค๊ณผ ๋‹ค๋ฅธ ๋ถ€๋ถ„์ด ์žˆ๊ฑฐ๋‚˜, ์ˆ˜์ •์ด ํ•„์š”ํ•œ ์‚ฌํ•ญ์€ ๋Œ“๊ธ€๋กœ ๋‚จ๊ฒจ์ฃผ์„ธ์š”.


Transformer

๊ธฐ์กด์—๋Š” Attention ๋ชจ๋“ˆ์ด RNN์ด๋‚˜ CNN ๋ชจ๋“ˆ์˜ Add-on ๋ชจ๋“ˆ๋กœ ์‚ฌ์šฉ๋˜์–ด์™”๋‹ค.

However, the 2017 paper 'Attention Is All You Need' removed RNNs and CNNs entirely and built the Transformer, a model that takes sequence data as input and output using attention alone.

Limitations of existing RNN models

RNN์€ ์ •๋ณด๊ฐ€ ์—ฌ๋Ÿฌ time step์„ ๊ฑฐ์น˜๋ฉด์„œ, ๋ฉ€๋ฆฌ์žˆ๋Š” ์ •๋ณด๊ฐ€ ์œ ์‹ค/๋ณ€์งˆ ๋˜๋Š” long-term dependecy ๋ฌธ์ œ๊ฐ€ ์žˆ์—ˆ๋‹ค.

Bi-Directional RNNs

๊ธฐ์กด RNN์€ ๋‹จ๋ฐฉํ–ฅ์œผ๋กœ ์ •๋ณด๋ฅผ ์ „๋‹ฌํ•˜๊ธฐ ๋•Œ๋ฌธ์— ์–ด๋А์ •๋„ ๋จผ ๊ฑฐ๋ฆฌ์˜ ์ •๋ณด๋ฅผ ๋ฐ˜์˜ํ•˜๊ธฐ ์–ด๋ ค์› ๋Š”๋ฐ, ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ณ ์ž ์–‘๋ฐฉํ–ฅ์œผ๋กœ RNN์„ ๋ณ‘๋ ฌ ๊ตฌ์„ฑํ•œ Bi-Directional RNN์ด ๋‚˜์˜ค๊ฒŒ ๋˜์—ˆ๋‹ค.(Forward RNN, Backward RNN)

bidirectional-rnn

์ด ๋•Œ ๋™์ผํ•œ time step์—์„œ์˜ hidden state vector htfh^f_t์™€ htbh^b_t๋Š” concat๋˜์–ด ์ธ์ฝ”๋”ฉ ๋ฒกํ„ฐ๋ฅผ ๊ตฌ์„ฑํ•œ๋‹ค.

Transformer

transformer_architecture

์ธ์ฝ”๋” ๊ตฌ์กฐ

transformer-diagram

  1. Each of the three input vectors (I, go, home), in its turn, is linearly transformed by $W^Q$ to serve as the query vector; it is dotted with the key vectors obtained by linearly transforming all input vectors (including itself) with $W^K$, producing a new vector.
    • In the figure, [3.8, -0.2, 5.9]
    • This checks the similarity between vectors, defining the relationships among the input words.
  2. This vector then passes through a softmax, which converts it into a weight vector of probabilities.
  3. The input vectors are linearly transformed once more by $W^V$ into value vectors, and a weighted average of these values (with weights summing to 1) forms the attention output vector, which becomes the encoding vector for that input.
    • This is the step that takes the weighted average based on the computed similarities.

์ด ๊ณผ์ •์„ ํ†ตํ•ด ๋‚˜์˜จ ์ธ์ฝ”๋”ฉ ๋ฒกํ„ฐ ๊ฐ’์œผ๋กœ ์—ฌ๋Ÿฌ input ์ค‘ ์–ด๋–ค input์— ์–ด๋А ์ •๋„ ๋น„์œจ๋กœ ์ง‘์ค‘(attention)ํ•ด์•ผ ํ•  ์ง€ ์•Œ ์ˆ˜ ์žˆ๊ฒŒ ๋œ๋‹ค.

Note that if the input vectors were used as-is, without the linear transformations (W), the dot product with itself would naturally be the largest; the resulting h would then put most of its attention on itself, and the information it carries would lose much of its usefulness.

Because this Transformer structure measures similarity and distributes attention regardless of how far apart the information sits, it can be said to fundamentally solve the long-term dependency problem of RNN structures.

Taking it apart

Transformer ๋ชจ๋ธ์—์„œ ํ•ต์‹ฌ์  ์—ญํ• ์„ ํ•˜๋Š” ์„ธ ๋ฒกํ„ฐ, Query(์ดํ•˜ Q)์™€ Key(์ดํ•˜ K), Value(์ดํ•˜ V)์— ๋Œ€ํ•ด์„œ ์‚ดํŽด๋ณด์ž.

Output์€ V๋“ค์˜ ๊ฐ€์ค‘ํ‰๊ท ์ธ๋ฐ, ์ด ๊ฐ€์ค‘ํ‰๊ท ์€ ๊ฒฐ๊ตญ Q์™€ K์˜ ๋‚ด์ ์œผ๋กœ ๊ตฌ์„ฑ๋œ๋‹ค.

  • Since Q and K must be dot-producted, they must have the same dimension ($d_k$).
  • The dimension of Q and K need not match the dimension of V ($d_v$), because V is only scaled and averaged with the weights.

$$A(q,K,V) = \sum_i\frac{\exp(q\cdot k_i)}{\sum_j\exp(q\cdot k_j)}v_i$$

  • Attention ๋ชจ๋“ˆ์˜ input์€ ํ•˜๋‚˜์˜ query ๋ฒกํ„ฐ, ๋ชจ๋“  Key(๋ฅผ concatํ•œ) ๋ฒกํ„ฐ์™€ ๋ชจ๋“  Value(๋ฅผ concatํ•œ) ๋ฒกํ„ฐ๊ฐ€ ๋œ๋‹ค.
  • ๋ถ„์ž๋Š” i๋ฒˆ์งธ key์™€ query ์‚ฌ์ด์˜ ์œ ์‚ฌ๋„, ์ฆ‰ ๋‘ ๋‹จ์–ด๊ฐ„์˜ ์œ ์‚ฌ๋„๊ฐ€ ๋œ๋‹ค. ์—ฌ๊ธฐ์— ํ•ด๋‹น ๋‹จ์–ด์˜ value ๋ฒกํ„ฐ๋ฅผ ๊ณฑํ•œ๋‹ค.
  • ๋ถ„๋ชจ๋Š” ๋ชจ๋“  key์— ๋Œ€ํ•œ ์œ ์‚ฌ๋„์˜ ์ดํ•ฉ์ด ๋œ๋‹ค.
  • ๋”ฐ๋ผ์„œ, (ํ•ด๋‹น ๋‹จ์–ด์˜ ์œ ์‚ฌ๋„ / ์ „์ฒด ๋‹จ์–ด์˜ ์œ ์‚ฌ๋„ ์ดํ•ฉ) ํ˜•ํƒœ๊ฐ€ ๋˜๋ฏ€๋กœ ๊ฐ€์ค‘ํ‰๊ท ์„ ๊ตฌ์„ฑํ•˜๊ฒŒ ๋œ๋‹ค. ์ด ๋•Œ ์ถœ๋ ฅ ๋ฒกํ„ฐ๋Š” value ๋ฒกํ„ฐ์˜ ํฌ๊ธฐ๊ฐ€ ๋  ๊ฒƒ์ด๋‹ค.

Now let's stack several queries into a matrix Q and express the whole computation at once.

$$A(Q,K,V) = \text{softmax}(QK^T)V$$

์ด๋ฅผ ํ–‰๋ ฌ์˜ ํ˜•ํƒœ๋กœ ํ‘œํ˜„ํ•˜๋ฉด ๋‹ค์Œ ์ด๋ฏธ์ง€์™€ ๊ฐ™๋‹ค.

qkv

  • $\vert Q\vert$ : the number of queries
  • $d_k$ : the dimension of a query
  • For the dot product of Q and K, K must be transposed. In the resulting $|Q|\times|K|$ matrix, the i-th row holds the similarities between the i-th query and the input vectors. Once this is computed, a row-wise softmax converts it into weights.
  • Multiplying this weight matrix by V gives a $|Q|\times d_v$ output matrix (the same shape as Q when $d_v = d_k$); its i-th row is the attention output for the i-th row (query) of the input Q. The sketch below writes these shapes out.
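
A minimal NumPy sketch of the matrix form with the shapes annotated (toy sizes, assumed purely for illustration):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

n_q, n_k, d_k, d_v = 4, 6, 8, 8               # toy sizes
Q = np.random.randn(n_q, d_k)                 # |Q| x d_k
K = np.random.randn(n_k, d_k)                 # |K| x d_k
V = np.random.randn(n_k, d_v)                 # |K| x d_v

scores  = Q @ K.T                             # |Q| x |K|: row i = similarities of query i
weights = softmax(scores, axis=-1)            # row-wise softmax, each row sums to 1
output  = weights @ V                         # |Q| x d_v: row i = attention output of query i

print(scores.shape, weights.shape, output.shape)  # (4, 6) (4, 6) (4, 8)
```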

Scaled Dot-Product Attention

์œ„์˜ ๊ฒฝ์šฐ๋Š” 2์ฐจ์› ํ–‰๋ ฌ์„ ๊ฐ€์ •ํ•˜๊ณ  ์ˆ˜ํ–‰ํ–ˆ์ง€๋งŒ, ์‹ค์ œ๋กœ ์—ฐ์‚ฐ์„ ์ˆ˜ํ–‰ํ•  ๋•Œ์—๋Š” q์™€ k๊ฐ€ n์ฐจ์›์ผ ์ˆ˜ ์žˆ๋‹ค.

์ด๋•Œ, dimension์ด ์ปค์ง€๋ฉด ์ปค์งˆ์ˆ˜๋ก q์™€ k์˜ ๋‚ด์ ๊ฐ’์˜ ์ ๋ฆผํ˜„์ƒ์ด ์‹ฌํ•ด์ ธ์„œ, attention์ด ํšจ์œจ์ ์œผ๋กœ ํ•™์Šต๋˜์ง€ ๋ชปํ•œ๋‹ค.

  • ์ด๋Š” ์„œ๋กœ ๋…๋ฆฝ์ธ random variable๋ผ๋ฆฌ ๊ณฑํ–ˆ์„ ๋•Œ ๋ถ„์‚ฐ์ด 1์ด๋˜๊ณ , ์ด ๊ณฑ๋“ค์„ ์„œ๋กœ ๋”ํ–ˆ์„ ๋•Œ ๋ถ„์‚ฐ์ด ๊ณ ์Šค๋ž€ํžˆ ๋”ํ•ด์ง€๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค. ๋‚ด์  ์ˆ˜ํ–‰ ๊ณผ์ •์—์„œ๋Š” ์ฐจ์›์ด ํด์ˆ˜๋ก random variable ๊ฐ’๋ผ๋ฆฌ ๋” ๋งŽ์ด ๊ณฑํ•˜๊ณ  ๋”ํ•ด์ง€๊ฒŒ ๋˜๋Š”๋ฐ, ์ด ๊ณผ์ •์—์„œ ๋ถ„์‚ฐ์ด ๋„ˆ๋ฌด ์ปค์ง„๋‹ค. ์ด๋Š” ๊ณง ๋‚ด์ ๊ฐ’์˜ ์ฐจ์ด๊ฐ€ ํฌ๊ฒŒ ๋‚˜๋Š” ๊ฒƒ์„ ์˜๋ฏธํ•œ๋‹ค.
  • ์ด๋ฅผ ์ง๊ด€์ ์œผ๋กœ ์ƒ๊ฐํ•ด๋ณด์ž. 2์ฐจ์› ํ‰๋ฉด์ƒ์— ๋–จ์–ด์ ธ์žˆ๋˜ ์—ฌ๋Ÿฌ ์ ๋“ค์ด, 3์ฐจ์› ๊ณต๊ฐ„์œผ๋กœ ์ด๋™ํ•˜๋ฉด, ์ ๋“ค๊ฐ„์˜ ๊ฑฐ๋ฆฌ๋Š” ์–ด๋–ป๊ฒŒ ๋ ๊นŒ? ๋†’์ด(z)๋ผ๋Š” ์ƒˆ๋กœ์šด ์ถ•์ด ์ƒ๊ฒผ์œผ๋ฏ€๋กœ, ๋‹น์—ฐํžˆ ์ ๋“ค๊ฐ„์˜ ๊ฑฐ๋ฆฌ๋Š” ํ‰๋ฉด์ƒ๋ณด๋‹ค ๋Œ€์ฒด๋กœ ๋ฉ€์–ด์ง€๊ณ , ์ตœ์„ ์˜ ๊ฒฝ์šฐ(์„ธ ์ ๋“ค์ด ๊ฐ™์€ ํ‰๋ฉด์— ์žˆ์„ ๋•Œ)์—์•ผ ๋™์ผํ•œ ๊ฑฐ๋ฆฌ๋ฅผ ๊ฐ€์ง€๊ฒŒ ๋œ๋‹ค.
  • ๋”ฐ๋ผ์„œ, ์ฐจ์›์ด ๋Š˜์–ด๋‚œ๋‹ค๋Š” ๊ฒƒ์€ ๋Œ€์ฒด๋กœ ๊ณต๊ฐ„์ƒ์˜ ์ ์ขŒํ‘œ๋กœ ํ‘œํ˜„๋˜๋Š” vector๊ฐ„์˜ ๊ฑฐ๋ฆฌ๊ฐ€ sparseํ•ด ์ง„๋‹ค๋Š” ๊ฒƒ, ์ฆ‰ ๋ถ„์‚ฐ์ด ๋Š˜์–ด๋‚œ๋‹ค๋Š” ๊ฒƒ์„ ์ง๊ด€์ ์œผ๋กœ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋‹ค.

$$A(Q,K,V) = \text{softmax}\Bigg(\frac{QK^T}{\sqrt{d_k}}\Bigg)V$$

To fix this problem and keep the variance constant, there is a technique of dividing the dot products of Q and K by the square root of the dimension, $\sqrt{d_k}$. Dividing by $\sqrt{d_k}$ divides the variance by $d_k$, so no matter the dimension, the scale of the scores is kept at a variance of 1. The small simulation below illustrates this.
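
A small simulation (assuming independent standard-normal entries for q and k) showing that the standard deviation of the dot product grows like $\sqrt{d_k}$ and that dividing by $\sqrt{d_k}$ brings it back to roughly 1:

```python
import numpy as np

rng = np.random.default_rng(0)

for d_k in (2, 64, 512):
    q = rng.standard_normal((10000, d_k))      # entries ~ N(0, 1), independent
    k = rng.standard_normal((10000, d_k))
    dots = (q * k).sum(axis=1)                 # 10000 sampled dot products
    print(d_k, dots.std(), (dots / np.sqrt(d_k)).std())
    # std of q . k grows like sqrt(d_k); after scaling it stays close to 1
```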

Multi-Head Attention

Multi-Head Attention์€ ๊ธฐ์กด์˜ Attention ๋ชจ๋“ˆ์„ ์ข€ ๋” ์œ ์šฉํ•˜๊ฒŒ ํ™•์žฅํ•œ ๋ชจ๋“ˆ์ด๋‹ค.

mha

$$\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1,\dots,\text{head}_n)W^O,\quad \text{where head}_i = \text{Attention}(QW_i^Q,KW_i^K,VW_i^V)$$

Multi-head attention์€ ์—ฌ๋Ÿฌ๊ฐœ์˜ attention ๋ชจ๋“ˆ์„ ๋™์‹œ์— ์‚ฌ์šฉํ•œ๋‹ค. ์ด ๋•Œ ๊ฐ attention ๋ชจ๋“ˆ์˜ ์„ ํ˜•๋ณ€ํ™˜ ํŒŒ๋ผ๋ฏธํ„ฐ WiW_i(head)๋Š” ๋ชจ๋“ˆ๋งˆ๋‹ค ๊ฐ๊ฐ ๋‹ค๋ฅด๋‹ค. ์ด ๊ฐ๊ธฐ ๋‹ค๋ฅธ version์˜ ๋ชจ๋“ˆ๋“ค์„ ์ด์šฉํ•˜์—ฌ ๋‚ธ output๋“ค์„ concat(ร—i)\times i)ํ•œ ๋’ค, WOW^O๋กœ ์„ ํ˜•๋ณ€ํ™˜ํ•˜์—ฌ ํ•˜๋‚˜์˜ output์„ ๋งŒ๋“œ๋Š” ํ˜•ํƒœ์ด๋‹ค.

The reason for using multi-head attention is that, even for the same input sentence, the words that need emphasis can differ depending on what we are after.

  • For example, take the sentence "I am going to eat dinner."
  • Sometimes we need to focus on 'I' because the fact that *I* ate is what matters; other times we need to focus on 'dinner' because the fact that *dinner* was eaten is what matters.

attention์˜ ์—ฐ์‚ฐ๋Ÿ‰

  • $n$ : sequence length
  • $d$ : query / key dimension
  • $k$ : kernel size of the convolution operation
  • $r$ : neighborhood size in restricted self-attention

์œ„์™€ ๊ฐ™์„ ๋•Œ, ๊ธฐ์กด layer๋“ค์˜ ์—ฐ์‚ฐ๋Ÿ‰์€ ๋‹ค์Œ๊ณผ ๊ฐ™๋‹ค.

calc

  • Complexity per Layer
    • This means the total computational complexity per layer, i.e., the complexity of the total amount of computation. The original paper seems to mean time complexity (in terms of the amount of computation), but interpreting it as space complexity is not a big stretch either; it roughly fits if we assume a single device is doing the processing.
    • Self-attention
      • $Q \times K^T = (n \times d) \times (d \times n)$
      • Q and K are multiplied, so the length $n$ is squared, and this must be computed for every dimension, so we multiply by $d$: $O(n^2\cdot d)$
    • Recurrent
      • There are $n$ time steps, and at every time step the $(d\times d)$ matrix $W_{hh}$ is multiplied in: $O(n\cdot d^2)$. Here the dimension $d$ of $W_{hh}$ is the size of the hidden state vector, a hyperparameter we can set directly.
  • Sequential Operations
    • This indicates how quickly the computation can be finished, assuming the total amount of computation can be processed at once by parallelizing over an unlimited number of GPUs.
    • Self-attention
      • As the sequence length $n$ grows, the amount of computation grows quadratically.
      • This is because all the dot products between Q and K must be computed and stored, so self-attention requires far more memory than a typical recurrent model.
      • However, GPUs are specialized for parallelizing exactly this kind of matrix computation, so given enough GPUs everything can be computed in parallel, and the sequential (time) complexity is $O(1)$.
    • Recurrent
      • The previous time step's $h_{t-1}$ must be available as input before $h_t$ can be computed, so the sequential complexity is unavoidably $O(n)$.
  • Maximum Path Length - the path length between two words
    • Self-attention
      • The similarity between any two words is computed directly in a single matrix operation, so it is $O(1)$.
    • Recurrent
      • For a word a to reach a word b some distance away, the signal must pass through the recurrent cells one at a time, so it is $O(n)$.

Block-Based Model

transformer-encoder

์ฒ˜์Œ์— ์‹œ์ž‘ํ•ด Multi-Head Attention์œผ๋กœ ๊ฐ€๋Š” ์„ธ๊ฐœ์˜ ํ™”์‚ดํ‘œ๋Š” ๊ฐ๊ฐ Q,K,V๋ฅผ ์˜๋ฏธํ•œ๋‹ค. ๊ฐ head๋งˆ๋‹ค ๋“ค์–ด๊ฐ€๊ฒŒ ๋œ๋‹ค.

๊ทธ ์—ฐ์‚ฐ ์ดํ›„์— ์ง„ํ–‰๋˜๋Š” Add&Norm ๊ตฌ๊ฐ„์€ ๋ฌด์—‡์ผ๊นŒ?

  • Add - Residual Connection
    • An effective technique from computer vision for building deeper layers while resolving gradient vanishing.
    • The given input vector is added as-is to the encoding output of Multi-Head Attention to form a new output, so that during training the Multi-Head Attention only needs to learn the residual, i.e., the information that differs between the input vector and the target vector.
    • For this addition to work, the Multi-Head Attention output must keep exactly the same dimensions as the input vector.
  • Norm - Normalization
    • The normalization generally used in neural networks consists of making (mean, variance) = (0, 1) and then applying a linear (affine) transformation that can inject any desired mean and variance.
    • Batch Norm
      • Subtract the mean from each element and divide by the standard deviation → (mean, variance) = (0, 1)
      • Apply an affine transformation to produce the desired mean and variance.
        • e.g. $y = 2x + 3$ → (mean, standard deviation) = (3, 2)
        • Here 2 and 3 are parameters that are optimized during training.
    • Layer Norm
      • Performed in the same way as Batch Norm, but on the matrix of stacked word vectors it is applied per layer (per word vector) rather than across the batch.
      • The affine transformation is applied per node across the layers (if the normalization was column-wise, the affine transformation is row-wise).
      • There are some differences from Batch Norm, but in broad terms the effect of stabilizing training is the same; the small demo below shows which axis each one normalizes over.
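
A small PyTorch demo of the difference in which axis gets normalized (toy tensor; `nn.BatchNorm1d` expects the feature axis second, hence the transpose):

```python
import torch
import torch.nn as nn

x = torch.randn(4, 7, 16)   # (batch, sequence length, features) toy tensor

# Layer Norm: statistics are computed per position, over the feature dimension
ln = nn.LayerNorm(16)
y_ln = ln(x)
print(y_ln[0, 0].mean().item(), y_ln[0, 0].std(unbiased=False).item())  # ~0, ~1 per word vector

# Batch Norm: statistics are computed per feature, over the batch (and sequence) dimension
bn = nn.BatchNorm1d(16)
y_bn = bn(x.transpose(1, 2)).transpose(1, 2)   # BatchNorm1d wants (batch, features, length)
print(y_bn[..., 0].mean().item(), y_bn[..., 0].std(unbiased=False).item())  # ~0, ~1 per feature
```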

Add&Norm ๊ตฌ๊ฐ„์„ ๊ฑฐ์น˜๊ณ  ๋‚˜์˜จ output์€ ๋‹ค์‹œ fully connected layer(Feed Forward)์— ํ†ต๊ณผ์‹œ์ผœ Word์˜ ์ธ์ฝ”๋”ฉ ๋ฒกํ„ฐ๋ฅผ ๋ณ€ํ™˜ํ•œ๋‹ค. ์ดํ›„ ๋‹ค์‹œ Add&Norm์„ ํ•œ๋ฒˆ ๋” ์ˆ˜ํ–‰ํ•˜๋Š” ๊ฒƒ๊นŒ์ง€๋ฅผ ๋์œผ๋กœ Transformer์˜ (self-attention ๋ชจ๋“ˆ์„ ํฌํ•จํ•œ) Block Based Model์ด ์™„์„ฑ๋œ๋‹ค.

Positional Encoding

Unlike an RNN, when we encode with the self-attention-based block model, word order is not considered, so the output is identical even if the order of the input words changes. This is because the process of computing the similarity between K and Q and deriving the weighted sum over V (since the weights have passed through a softmax, this weighted sum is itself a weighted average) takes no account of the sequence order.

Positional Encoding์€ ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ์œ„ํ•ด ๋ฒกํ„ฐ ๋‚ด์˜ ํŠน์ • ์›์†Œ์— ํ•ด๋‹น word์˜ ์ˆœ์„œ๋ฅผ ์•Œ์•„๋ณผ ์ˆ˜ ์žˆ๋Š” , ๋งˆ์น˜ ์ง€๋ฌธ๊ณผ๋„ ๊ฐ™์€ uniqueํ•œ ๊ฐ’์„ ์ถ”๊ฐ€ํ•˜์—ฌ sequence๋ฅผ ๊ณ ๋ คํ•˜๋Š” ๊ฒƒ์„ ๋งํ•œ๋‹ค.

  • Here, the unique value combines the outputs of several periodic functions, since a periodic function's output changes depending on the position x of its input.
  • However, with only a single periodic function there would be different positions sharing the same value, so the outputs of several different periodic functions are all combined.

When these special values are added before encoding, the output changes whenever the order of the input words changes, so the model becomes able to distinguish word order. A minimal sketch of the sinusoidal encoding used in the paper follows below.
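
A minimal NumPy sketch of the sinusoidal positional encoding used in the paper (sin on even indices, cos on odd indices):

```python
import numpy as np

def positional_encoding(max_len: int, d_model: int) -> np.ndarray:
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)), PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))."""
    pos = np.arange(max_len)[:, None]                       # (max_len, 1)
    i = np.arange(0, d_model, 2)[None, :]                   # even dimension indices
    angle = pos / np.power(10000.0, i / d_model)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angle)                             # periodic functions with
    pe[:, 1::2] = np.cos(angle)                             # a different frequency per dimension
    return pe

pe = positional_encoding(max_len=50, d_model=16)
# Each row is a unique "fingerprint" added to the word embedding at that position.
print(pe.shape)                    # (50, 16)
print(np.allclose(pe[3], pe[7]))   # False: different positions get different codes
```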

Learning Rate Scheduler

๊ธฐ์กด์˜ ๋ชจ๋ธ์—์„œ ํ•™์Šต๋ฅ (learning rate)๋Š” ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ๋กœ, ํ•™์Šต ๋‚ด๋‚ด ๊ณ ์ •๋˜์–ด์žˆ๋Š” ๊ฐ’์ด์—ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ํ•™์Šต์˜ ๊ณผ์ •๋™์•ˆ ํšจ์œจ์ ์ธ ํ•™์Šต๋ฅ ์€ ๊ณ„์† ๋ฐ”๋€Œ๊ธฐ ๋งˆ๋ จ์ด๋ฏ€๋กœ, ์ด๋ฅผ ํ•™์Šต ๊ณผ์ • ๋‚ด์—์„œ ํšจ๊ณผ์ ์œผ๋กœ ๋ฐ”๊พธ์–ด ์ค„ ์ˆ˜ ์žˆ๋Š” ๋ฐฉ์‹์œผ๋กœ Learning Rate Scheduler๊ฐ€ ๋‚˜์˜ค๊ฒŒ ๋˜์—ˆ๋‹ค.

๋””์ฝ”๋” ๊ตฌ์กฐ

transformer-decoder

Outputs๊ฐ€ ๋””์ฝ”๋”์˜ ์ž…๋ ฅ์œผ๋กœ ๋“ค์–ด์˜ฌ ๋•Œ, ๊ธฐ์กด์˜ ground truth ๋ฌธ์žฅ์—์„œ ์•ž์ชฝ์—๋Š” <SOS> ํ† ํฐ์„ ๋ถ™์—ฌ ๋“ค์–ด์˜ค๋ฏ€๋กœ, ํ•œ์นธ ๋ฐ€๋ฆฐ(shfited right) ํ˜•ํƒœ๋กœ ๋“ค์–ด์˜ค๊ฒŒ ๋œ๋‹ค.

๋””์ฝ”๋”์—์„œ Attention ๋ชจ๋“ˆ์„ ํฌํ•จํ•œ ํ•œ์ฐจ๋ก€์˜ ๊ณผ์ •์„ ๊ฑฐ์นœ ํ›„ ๋‹ค์Œ Multi-Head Attention์œผ๋กœ ๊ฐˆ ๋•Œ, ๋””์ฝ”๋”์˜ hidden state vector๋ฅผ ์ž…๋ ฅ Q๋กœ ๋„˜๊ฒจ์ค€๋‹ค. ๊ทธ๋Ÿฐ๋ฐ, ๋‚˜๋จธ์ง€ K์™€ V ์ž…๋ ฅ์€ ์™ธ๋ถ€, ์ฆ‰ ์ธ์ฝ”๋”์˜ ์ตœ์ข… ์ถœ๋ ฅ์œผ๋กœ๋ถ€ํ„ฐ ์˜จ๋‹ค. ์ฆ‰, ์ด ๋ถ€๋ถ„์€ ๋””์ฝ”๋”์˜ hidden state vector๋ฅผ ๊ธฐ์ค€, ์ฆ‰ Q๋กœ ํ•ด์„œ ์ธ์ฝ”๋”์˜ hidden state vector K, V ๋ฅผ ๊ฐ€์ค‘ํ•˜์—ฌ ๊ฐ€์ ธ์˜ค๋Š”, ์ธ์ฝ”๋”์™€ ๋””์ฝ”๋”๊ฐ„์˜ Attention ๋ชจ๋“ˆ์ด ๋œ๋‹ค.

์ด ํ›„ ์ด๋ฏธ์ง€์— ๋‚˜์˜จ ๋Œ€๋กœ์˜ ์—ฐ์‚ฐ์„ ๊ฑฐ์น˜๋‹ค๊ฐ€, ๋””์ฝ”๋”์˜ ์ตœ์ข… output ๊ฐ’์ด Linear Layer์™€ Softmax๋ฅผ ๊ฑฐ์ณ ํ™•๋ฅ ๋ถ„ํฌ์˜ ํ˜•ํƒœ๋กœ ์ถœ๋ ฅ๋œ๋‹ค. ์ด ๊ฐ’์€ Softmax-with-loss ์†์‹คํ•จ์ˆ˜๋ฅผ ํ†ตํ•ด ํ•™์Šต๋œ๋‹ค.

Masked Self-Attention

Self-Attention ๋ชจ๋ธ์—์„œ, ์ž„์˜์˜ ๋‹จ์–ด a๋Š” Q์™€ K์˜ ๋‚ด์ ์„ ํ†ตํ•ด ์ž์‹ ๊ณผ ๋ชจ๋“  ๋‹จ์–ด๋“ค์˜ ๊ด€๊ณ„๋ฅผ ๋‹ค ์•Œ์ˆ˜ ์žˆ๋‹ค. ์ด ๋•Œ, ํ•™์Šต ๋‹น์‹œ์—๋Š” ๋ฐฐ์น˜ ํ”„๋กœ์„ธ์‹ฑ์„ ์œ„ํ•ด a ๋’ค์˜ ๋‹จ์–ด๋“ค๊นŒ์ง€ ๋ชจ๋‘ ๊ณ ๋ คํ•˜๋„๋ก ํ•™์Šต์ด ์ง„ํ–‰๋˜๋‚˜, ์‚ฌ์‹ค ์‹ค์ œ ๋””์ฝ”๋”ฉ ์ƒํ™ฉ์„ ๊ณ ๋ คํ•œ๋‹ค๋ฉด a ๋’ค์˜ ๋‹จ์–ด๋ฅผ ์•Œ์•„์„œ๋Š” ์•ˆ๋œ๋‹ค. ์ด๋Š” ๋’ค์˜ ๋‹จ์–ด๋ฅผ ์ถ”๋ก ํ•ด์•ผ ํ•˜๋Š” ์ƒํ™ฉ์—์„œ ๋’ค์— ์–ด๋–ค ๋‹จ์–ด๊ฐ€ ์žˆ๋Š”์ง€ ๋ฏธ๋ฆฌ ์•Œ๊ณ ์žˆ๋Š”, ์ผ์ข…์˜ cheating ์ƒํ™ฉ์ด๊ธฐ ๋•Œ๋ฌธ์ด๋‹ค. ์ด๋Ÿฌ๋ฉด ๋‹น์—ฐํžˆ ํ•™์Šต์€ ์ œ๋Œ€๋กœ ๋˜์ง€ ์•Š๊ฒŒ ๋œ๋‹ค.

๋””์ฝ”๋” ๊ณผ์ •์˜ ์ด๋ฏธ์ง€ ์ค‘ Masked Self-attention์€ ์ด๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•œ ๋ฐฉ๋ฒ•์œผ๋กœ, ๊ธฐ์กด์˜ attention ๋ชจ๋“ˆ์—์„œ Q, K ๋‚ด์ ๊ณผ softmax๋ฅผ ํ†ต๊ณผํ•œ ๊ฐ’์—์„œ ํ˜„์žฌ ๋‹จ์–ด a์˜ ๋’ค์— ์žˆ๋Š” ๋‹จ์–ด๋“ค์„ key ๊ฐ’์œผ๋กœ ๊ณ„์‚ฐ๋œ ์…€๋“ค์„ ๋ชจ๋‘ ์‚ญ์ œํ•œ๋‹ค. Mask ๋ผ๋Š” ๋‹จ์–ด๋Š” ์ด์ฒ˜๋Ÿผ ๋’ค์ชฝ์˜ ์ •๋ณด๋ฅผ ๊ฐ€๋ฆฐ๋‹ค(mask)๋Š” ์˜๋ฏธ๋‹ค.

masked-self-attention

์œ„์˜ ์ด๋ฏธ์ง€๋Š” [I go home โ†’ ๋‚˜๋Š” ์ง‘์— ๊ฐ„๋‹ค] ๋ผ๋Š” ๋ฒˆ์—ญ์„ ์ˆ˜ํ–‰ํ•˜๋Š” ์‚ฌ๋ก€์ธ๋ฐ, Q,K์˜ ๋‚ด์ ์„ ํ†ตํ•ด ์–ป์€ ์ •์‚ฌ๊ฐ ํ–‰๋ ฌ์„ ํ‘œํ˜„ํ•˜๊ณ  ์žˆ๋‹ค. ์ด ๋•Œ ์ฃผ๋Œ€๊ฐ์„  ์œ„์ชฝ์˜ ๊ฐ’๋“ค์€ query๋ณด๋‹ค key๊ฐ€ ๋’ค์ชฝ์˜ ๋‹จ์–ด๋“ค์ธ ๊ฒฝ์šฐ๋กœ, ์ด ์…€๋“ค์˜ ์ •๋ณด๋ฅผ ๊ทธ๋Œ€๋กœ ๋‘” ์ฑ„๋กœ ํ•™์Šต์‹œํ‚ค์ง€ ๋ชปํ•˜๋„๋ก ํ•ด๋‹น ๊ฐ’๋“ค์„ 0์œผ๋กœ ๋Œ€์ฒดํ•œ๋‹ค. ๊ทธ ์ดํ›„, ๋‚จ์€ ์ฃผ๋Œ€๊ฐ์„  ์ดํ•˜์˜ ์…€๋“ค๋งŒ ๊ฐ€์ง€๊ณ , row๋‹จ์œ„๋กœ ์ดํ•ฉ์ด 1์ด ๋˜๋„๋ก normalize ํ•œ ์ •๋ณด๋ฅผ ์ตœ์ข… output์œผ๋กœ ๋‚ด๋ณด๋‚ธ๋‹ค.


Reference

The Illustrated Transformer - Korean translation

Group Normalization

Transformer: All you need is Attention (explanation/summary/notes)

๐Ÿ“ƒ Review of "Attention Is All You Need"

