NLP 07 - 'Attention is all you need'์ ๋์จ Transformer ๋ชจ๋ธ ์์๋ณด๊ธฐ
BoostCamp AI Tech
NLP
Natural Language Processing
Attention
Self-Attention
Transformer
Positive Encoding
Multi-Headed Attention
Block Based Model
Encoder
Decoder
Masked Self-Attention
02/18/2021
๋ณธ ์ ๋ฆฌ ๋ด์ฉ์ Naver BoostCamp AI Tech์ edwith์์ ํ์ตํ ๋ด์ฉ์ ์ ๋ฆฌํ ๊ฒ์
๋๋ค.
์ฌ์ค๊ณผ ๋ค๋ฅธ ๋ถ๋ถ์ด ์๊ฑฐ๋, ์์ ์ด ํ์ํ ์ฌํญ์ ๋๊ธ๋ก ๋จ๊ฒจ์ฃผ์ธ์.
Transformer
๊ธฐ์กด์๋ Attention ๋ชจ๋์ด RNN์ด๋ CNN ๋ชจ๋์ Add-on ๋ชจ๋๋ก ์ฌ์ฉ๋์ด์๋ค.
๊ทธ๋ฌ๋ 2017๋ ๋ฐํ๋ ๋ ผ๋ฌธ 'Attention is all you need'๋ ๊ธฐ์กด์ RNN๊ณผ CNN์ ๋ชจ๋ ๊ฑท์ด๋ด๊ณ , ์ค๋ก์ง Attention๋ง์ผ๋ก ๊ตฌ์ถํ์ฌ ์ํ์ค ๋ฐ์ดํฐ๋ฅผ ์ ์ถ๋ ฅํ ์ ์๋ Transformer ๋ชจ๋ธ์ ๊ตฌ์ถํ์๋ค.
๊ธฐ์กด RNN ๋ชจ๋ธ์ ํ๊ณ
RNN์ ์ ๋ณด๊ฐ ์ฌ๋ฌ time step์ ๊ฑฐ์น๋ฉด์, ๋ฉ๋ฆฌ์๋ ์ ๋ณด๊ฐ ์ ์ค/๋ณ์ง ๋๋ long-term dependecy ๋ฌธ์ ๊ฐ ์์๋ค.
Bi-Directional RNNs
๊ธฐ์กด RNN์ ๋จ๋ฐฉํฅ์ผ๋ก ์ ๋ณด๋ฅผ ์ ๋ฌํ๊ธฐ ๋๋ฌธ์ ์ด๋์ ๋ ๋จผ ๊ฑฐ๋ฆฌ์ ์ ๋ณด๋ฅผ ๋ฐ์ํ๊ธฐ ์ด๋ ค์ ๋๋ฐ, ์ด๋ฅผ ํด๊ฒฐํ๊ณ ์ ์๋ฐฉํฅ์ผ๋ก RNN์ ๋ณ๋ ฌ ๊ตฌ์ฑํ Bi-Directional RNN
์ด ๋์ค๊ฒ ๋์๋ค.(Forward RNN, Backward RNN)
์ด ๋ ๋์ผํ time step์์์ hidden state vector ์ ๋ concat๋์ด ์ธ์ฝ๋ฉ ๋ฒกํฐ๋ฅผ ๊ตฌ์ฑํ๋ค.
Transformer
์ธ์ฝ๋ ๊ตฌ์กฐ
- Input์ผ๋ก ์ฃผ์ด์ง 3๊ฐ์ ๋ฒกํฐ(I, go, home)๋ ๊ฐ๊ฐ ๋ณธ์ธ์ ์ฐจ๋ก์์ ์ ํ๋ณํ()๋์ด query ๋ฒกํฐ๋ก ๊ธฐ๋ฅํ๊ณ , ์๊ธฐ ์์ ์ ํฌํจํ input ๋ฒกํฐ๋ค์ ์ ํ๋ณํ()ํ key ๋ฒกํฐ๋ค๊ณผ ๋ด์ ํ์ฌ ์๋ก์ด ๋ฒกํฐ๋ฅผ ๋ง๋ค์ด๋ธ๋ค.
- ์ด๋ฏธ์ง์์ [3.8,-0.2,5.9]
- ์ด๋ ๊ฐ ๋ฒกํฐ๊ฐ์ ์ ์ฌ๋๋ฅผ ๊ฒ์ฌํ์ฌ, input๋ ๋จ์ด๋ค๊ฐ์ ๊ด๊ณ๋ฅผ ์ ์ํ๋ค.
- ์ดํ, ์ด ๋ฒกํฐ๋ softmax๋ฅผ ๊ฑฐ์ณ ํ๋ฅ ๋ก ๋ณํํ์ฌ ๊ฐ์ค์น ๋ฒกํฐ๋ฅผ ๋ง๋ ๋ค.
- ๋ค์ ๊ธฐ์กด์ input ๋ฒกํฐ๋ค์ ์ ํ๋ณํ()ํ value ๋ฒกํฐ์ ๊ฐ์คํ๊ท ๋ด์ด ์ดํฉ์ด 1์ด ๋๋ ๋ฒกํฐ(attention output vector)๋ฅผ ๊ตฌ์ฑํ๊ณ , ์ด๊ฒ์ด ๊ณง ํด๋น input ๋ฒกํฐ์ ์ธ์ฝ๋ฉ ๋ฒกํฐ๊ฐ ๋๋ค.
- ๊ตฌํด์ง ์ ์ฌ๋๋ฅผ ๋ฐํ์ผ๋ก ๊ฐ์คํ๊ท ์ ๋ด๋ ๊ณผ์ ์ด๋ค.
์ด ๊ณผ์ ์ ํตํด ๋์จ ์ธ์ฝ๋ฉ ๋ฒกํฐ ๊ฐ์ผ๋ก ์ฌ๋ฌ input ์ค ์ด๋ค input์ ์ด๋ ์ ๋ ๋น์จ๋ก ์ง์ค(attention)ํด์ผ ํ ์ง ์ ์ ์๊ฒ ๋๋ค.
์ด ๋, ์ ํ๋ณํ(W)์ ๊ฑฐ์น์ง ์๊ณ input ๋ฒกํฐ๋ฅผ ๊ทธ๋๋ก ์ฌ์ฉํ๋ฉด, ๋น์ฐํ๊ฒ๋ ์๊ธฐ ์์ ๊ณผ์ ๋ด์ ์ด ๊ฐ์ฅ ์ปค์ง๊ธฐ ๋๋ฌธ์, ๊ฒฐ๊ณผ๋ก ๋์จ h์์๋ ์๊ธฐ์์ ์ ๋ํ attention์ด ๊ฐ์ฅ ์ปค์ ธ ์ ๋ณด์ ์ ํจ์ฑ์ด ๋ถ์กฑํด์ง๋ ๋ฌธ์ ๊ฐ ์๊ธธ ์ ์๋ค.
์ด๋ฌํ transformer์ ๊ตฌ์กฐ๋ ์ ๋ณด๊ฐ ์์นํ ๊ฑฐ๋ฆฌ์ ๊ด๊ณ์์ด ์ ์ฌ๋๋ฅผ ์ธก์ ํ์ฌ attention์ ๋ถ๋ฐฐํจ์ผ๋ก์จ ๊ธฐ์กด์ RNN๊ตฌ์กฐ๊ฐ ๊ฐ์ง๋ Long-term Dependency ๋ฌธ์ ๋ฅผ ๊ทผ๋ณธ์ ์ผ๋ก ํด๊ฒฐํ๋ค๊ณ ํ ์ ์๋ค.
๋ฏ์ด๋ณด๊ธฐ
Transformer ๋ชจ๋ธ์์ ํต์ฌ์ ์ญํ ์ ํ๋ ์ธ ๋ฒกํฐ, Query
(์ดํ Q)์ Key
(์ดํ K), Value
(์ดํ V)์ ๋ํด์ ์ดํด๋ณด์.
Output์ V๋ค์ ๊ฐ์คํ๊ท ์ธ๋ฐ, ์ด ๊ฐ์คํ๊ท ์ ๊ฒฐ๊ตญ Q์ K์ ๋ด์ ์ผ๋ก ๊ตฌ์ฑ๋๋ค.
- ์ด ๋, Q์ K๋ ๋ด์ ๊ฐ๋ฅํด์ผํ๋ฏ๋ก ๋ฐ๋์ ๊ฐ์ ์ฐจ์์ด์ด์ผ ํ๋ค. ()
- Q,K์ ์ฐจ์๊ณผ V์ ์ฐจ์()์ ๊ฐ์ ํ์๊ฐ ์๋ค. V๋ ๊ฒฐ๊ตญ ์์๋ฐฐํด์ ๊ฐ์คํ๊ท ๋ผ ๊ฒ์ด๊ธฐ ๋๋ฌธ์ด๋ค.
- Attention ๋ชจ๋์ input์ ํ๋์ query ๋ฒกํฐ, ๋ชจ๋ Key(๋ฅผ concatํ) ๋ฒกํฐ์ ๋ชจ๋ Value(๋ฅผ concatํ) ๋ฒกํฐ๊ฐ ๋๋ค.
- ๋ถ์๋ i๋ฒ์งธ key์ query ์ฌ์ด์ ์ ์ฌ๋, ์ฆ ๋ ๋จ์ด๊ฐ์ ์ ์ฌ๋๊ฐ ๋๋ค. ์ฌ๊ธฐ์ ํด๋น ๋จ์ด์ value ๋ฒกํฐ๋ฅผ ๊ณฑํ๋ค.
- ๋ถ๋ชจ๋ ๋ชจ๋ key์ ๋ํ ์ ์ฌ๋์ ์ดํฉ์ด ๋๋ค.
- ๋ฐ๋ผ์, (ํด๋น ๋จ์ด์ ์ ์ฌ๋ / ์ ์ฒด ๋จ์ด์ ์ ์ฌ๋ ์ดํฉ) ํํ๊ฐ ๋๋ฏ๋ก ๊ฐ์คํ๊ท ์ ๊ตฌ์ฑํ๊ฒ ๋๋ค. ์ด ๋ ์ถ๋ ฅ ๋ฒกํฐ๋ value ๋ฒกํฐ์ ํฌ๊ธฐ๊ฐ ๋ ๊ฒ์ด๋ค.
๊ทธ๋ ๋ค๋ฉด query๋ฅผ ์ฌ๋ฌ๊ฐ ์์ ํ๋ ฌ Q๋ฅผ ๋ง๋ ๋ค ํ๋ฒ์ ํํํด๋ณด๋๋ก ํ์.
์ด๋ฅผ ํ๋ ฌ์ ํํ๋ก ํํํ๋ฉด ๋ค์ ์ด๋ฏธ์ง์ ๊ฐ๋ค.
- : ์ฟผ๋ฆฌ์ ๊ฐ์
- : ์ฟผ๋ฆฌ์ ์ฐจ์
- Q์ K์ ๋ด์ ์ ์ํด์๋ K๊ฐ transpose๋์ด์ผ ํ๋ค. ๋ด์ ์ ๊ฒฐ๊ณผ๋ก ๋์ค๋ ํ๋ ฌ์์ i๋ฒ์งธ row๋ i๋ฒ์งธ query์ ๋ํ input ๋ฒกํฐ๋ค์ ์ ์ฌ๋๋ฅผ ๋ํ๋ด๋ row๊ฐ ๋๋ค. ์ด ์ฐ์ฐ์ด ๋๋๋ฉด softmax๋ฅผ ์ด์ฉํ์ฌ ๊ฐ์ค์น ๋ฒกํฐ๋ก ๋ณํ๋๋ค.
- ๊ฐ์ค์น ๋ฒกํฐ์ V ๋ฒกํฐ์์ ๋ด์ ์ ํตํด ๋์จ ์ถ๋ ฅ ํ๋ ฌ์ ๋ค์ ๊ธฐ์กด์ Q์ ๋์ผํ ํํ๋ฅผ ์ด๋ฃจ๋ฉฐ, ์ถ๋ ฅ ํ๋ ฌ์ i๋ฒ์งธ row๋ input Q์ i๋ฒ์งธ row(query)์ ๋ํ attention์ output์ด ๋๋ค.
Scaled Dot-Product Attention
์์ ๊ฒฝ์ฐ๋ 2์ฐจ์ ํ๋ ฌ์ ๊ฐ์ ํ๊ณ ์ํํ์ง๋ง, ์ค์ ๋ก ์ฐ์ฐ์ ์ํํ ๋์๋ q์ k๊ฐ n์ฐจ์์ผ ์ ์๋ค.
์ด๋, dimension์ด ์ปค์ง๋ฉด ์ปค์ง์๋ก q์ k์ ๋ด์ ๊ฐ์ ์ ๋ฆผํ์์ด ์ฌํด์ ธ์, attention์ด ํจ์จ์ ์ผ๋ก ํ์ต๋์ง ๋ชปํ๋ค.
- ์ด๋ ์๋ก ๋ ๋ฆฝ์ธ random variable๋ผ๋ฆฌ ๊ณฑํ์ ๋ ๋ถ์ฐ์ด 1์ด๋๊ณ , ์ด ๊ณฑ๋ค์ ์๋ก ๋ํ์ ๋ ๋ถ์ฐ์ด ๊ณ ์ค๋ํ ๋ํด์ง๊ธฐ ๋๋ฌธ์ด๋ค. ๋ด์ ์ํ ๊ณผ์ ์์๋ ์ฐจ์์ด ํด์๋ก random variable ๊ฐ๋ผ๋ฆฌ ๋ ๋ง์ด ๊ณฑํ๊ณ ๋ํด์ง๊ฒ ๋๋๋ฐ, ์ด ๊ณผ์ ์์ ๋ถ์ฐ์ด ๋๋ฌด ์ปค์ง๋ค. ์ด๋ ๊ณง ๋ด์ ๊ฐ์ ์ฐจ์ด๊ฐ ํฌ๊ฒ ๋๋ ๊ฒ์ ์๋ฏธํ๋ค.
- ์ด๋ฅผ ์ง๊ด์ ์ผ๋ก ์๊ฐํด๋ณด์. 2์ฐจ์ ํ๋ฉด์์ ๋จ์ด์ ธ์๋ ์ฌ๋ฌ ์ ๋ค์ด, 3์ฐจ์ ๊ณต๊ฐ์ผ๋ก ์ด๋ํ๋ฉด, ์ ๋ค๊ฐ์ ๊ฑฐ๋ฆฌ๋ ์ด๋ป๊ฒ ๋ ๊น? ๋์ด(z)๋ผ๋ ์๋ก์ด ์ถ์ด ์๊ฒผ์ผ๋ฏ๋ก, ๋น์ฐํ ์ ๋ค๊ฐ์ ๊ฑฐ๋ฆฌ๋ ํ๋ฉด์๋ณด๋ค ๋์ฒด๋ก ๋ฉ์ด์ง๊ณ , ์ต์ ์ ๊ฒฝ์ฐ(์ธ ์ ๋ค์ด ๊ฐ์ ํ๋ฉด์ ์์ ๋)์์ผ ๋์ผํ ๊ฑฐ๋ฆฌ๋ฅผ ๊ฐ์ง๊ฒ ๋๋ค.
- ๋ฐ๋ผ์, ์ฐจ์์ด ๋์ด๋๋ค๋ ๊ฒ์ ๋์ฒด๋ก ๊ณต๊ฐ์์ ์ ์ขํ๋ก ํํ๋๋ vector๊ฐ์ ๊ฑฐ๋ฆฌ๊ฐ sparseํด ์ง๋ค๋ ๊ฒ, ์ฆ ๋ถ์ฐ์ด ๋์ด๋๋ค๋ ๊ฒ์ ์ง๊ด์ ์ผ๋ก ์ดํดํ ์ ์๋ค.
์ด๋ฐ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด, ๋ถ์ฐ์ ์ผ์ ํ๊ฒ ์ ์งํ๊ธฐ ์ํ ๋ฐฉ๋ฒ์ผ๋ก Q์ K์ ๋ด์ ์ ์ฐจ์()์ ์ ๊ณฑ๊ทผ์ผ๋ก ๋๋์ด์ฃผ๋ ํ ํฌ๋์ด ์๋ค. ์ด ๊ฒฝ์ฐ ๋ถ์ฐ์ด ๋ถ์ฐ์ ์ ๊ณฑ๋ฐฐ๋งํผ ์ถ์๋๊ธฐ ๋๋ฌธ์, ์ฐจ์์ด ๋ช์ฐจ์์ด๋ ๋ถ์ฐ์ ํญ์ 1๋ก ์ผ์ ํ๊ฒ Scalingํ ์ ์๋ค.
Multi-Head Attention
Multi-Head Attention
์ ๊ธฐ์กด์ Attention ๋ชจ๋์ ์ข ๋ ์ ์ฉํ๊ฒ ํ์ฅํ ๋ชจ๋์ด๋ค.
Multi-head attention์ ์ฌ๋ฌ๊ฐ์ attention ๋ชจ๋์ ๋์์ ์ฌ์ฉํ๋ค. ์ด ๋ ๊ฐ attention ๋ชจ๋์ ์ ํ๋ณํ ํ๋ผ๋ฏธํฐ (head)๋ ๋ชจ๋๋ง๋ค ๊ฐ๊ฐ ๋ค๋ฅด๋ค. ์ด ๊ฐ๊ธฐ ๋ค๋ฅธ version์ ๋ชจ๋๋ค์ ์ด์ฉํ์ฌ ๋ธ output๋ค์ concat(ํ ๋ค, ๋ก ์ ํ๋ณํํ์ฌ ํ๋์ output์ ๋ง๋๋ ํํ์ด๋ค.
์ด๋ฐ Multi-head attention์ ์ฌ์ฉํ๋ ์ด์ ๋, ๋์ผํ ์ ๋ ฅ๋ฌธ ๊ธฐ์ค์ผ๋ก๋ ํ์์ ๋ฐ๋ผ ์ค์ ์ ๋์ด์ผ ํ ๋จ์ด๋ค์ด ๋ค๋ฅผ ๊ฒฝ์ฐ๊ฐ ์๊ธฐ ๋๋ฌธ์ด๋ค.
- ๊ฐ๋ น, I am going to eat dinner๋ผ๋ ๋ฌธ์ฅ์ด ์๋ค๊ณ ํ์.
- ์ด๋ค ๋๋ '๋ด๊ฐ' ๋จน์๋ค๋ ์ฌ์ค์ ์ฃผ๋ชฉํด์ผ ํด์ 'I'์ ์ง์คํด์ผ ํ ์๋ ์๊ณ , ์ด๋ค ๋๋ '์ ๋ '์ ๋จน์๋ค๋ ์ฌ์ค์ ์ฃผ๋ชฉํ๊ธฐ ์ํด 'dinner'์ ์ง์คํด์ผ ํ ์๋ ์๋ค.
attention์ ์ฐ์ฐ๋
- : sequnce ๊ธธ์ด
- : query / key ์ฐจ์
- : ํฉ์ฑ๊ณฑ ์ฐ์ฐ์ ์ปค๋ ์ฌ์ด์ฆ
- : restricted self-attention์ ์ด์ ์ฌ์ด์ฆ
์์ ๊ฐ์ ๋, ๊ธฐ์กด layer๋ค์ ์ฐ์ฐ๋์ ๋ค์๊ณผ ๊ฐ๋ค.
Complexity per Layer
- Total Computational Complexity per Layer๋ฅผ ์๋ฏธํ๋ ๊ฒ์ผ๋ก, ์ด ์ฐ์ฐ์์ ๋ณต์ก๋๋ฅผ ์๋ฏธํ๋ค. ๋ ผ๋ฌธ ์๋ฌธ์์๋ (์ฐ์ฐ์์ ๋ฐ๋ฅธ) ์๊ฐ๋ณต์ก๋๋ฅผ ์๋ฏธํ ๋ฏ ํ๋ฐ, ๊ณต๊ฐ๋ณต์ก๋๋ก ํด์ํด๋ ํฐ ๋ฌด๋ฆฌ๋ ์๋ ๋ฏํ๋ค. ์ฐ์ฐ์ ์ฒ๋ฆฌํ๋ ๋๋ฐ์ด์ค๊ฐ 1๊ฐ์ธ ์ํฉ์ ๊ฐ์ ํ๋ค๊ณ ์๊ฐํ๋ฉด ์ผ์ถ ๋ค์ด๋ง๋๋ค.
- Self-attention
- Q์ K๋ฅผ ๋ด์ ํ๋ฏ๋ก ๊ฐ ๊ธธ์ด ์ ์ ๊ณฑํ๊ณ , ์ด๋ฅผ ๋ชจ๋ dimension๋ง๋ค ๊ณ์ฐํด์ผ ํ๋ฏ๋ก ๋ฅผ ๊ณฑํด์ผํ๋ค. -
- Recurrent
- time step์ ๊ฐ์๊ฐ ์ด๊ณ , ๋งค time step๋ง๋ค ํฌ๊ธฐ์ ๋ฅผ ๊ณฑํ๋ค. - ์ด ๋ ์ dimension ๋ hidden state vector์ ํฌ๊ธฐ๋ก, ํ์ดํผํ๋ผ๋ฏธํฐ์ด๋ฏ๋ก ์ง์ ์ ํด์ค ์ ์๋ค.
Sequential Operations
- ํด๋น ์ฐ์ฐ์ ์ผ๋ง์ ์๊ฐ๋ด์ ๋๋ผ ์ ์๋๊ฐ๋ฅผ ๋ํ๋ธ ๊ฒ์ผ๋ก, ์ฐ์ฐ์ ์์ฒด๋ ๋ฌดํํ ๋ง์ GPU์ ๋ณ๋ ฌ์ฐ์ฐ์ผ๋ก ํ๋ฒ์ ์ฒ๋ฆฌํ ์ ์์์ ๊ฐ์ ํ๋ค.
- Self-attention
- ์ํ์ค์ ๊ธธ์ด ์ด ๊ธธ์ด์ง์๋ก ์ง์๋ฐฐ๋ก ์ฐ์ฐ๋ณต์ก๋๊ฐ ๋์ด๋๋ค.
- ์ด๋ ๋ชจ๋ Q์ K์ ๋ด์ ๊ฐ์ ๋ชจ๋ ์ ์ฅํ๊ณ ์์ด์ผํ๊ธฐ ๋๋ฌธ์ด๋ค.๋ฐ๋ผ์ ์ผ๋ฐ์ ์ธ Recurrent๋ณด๋ค ํจ์ฌ ๋ง์ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ํ์๋ก ํ๊ฒ ๋๋ค.
- ๊ทธ๋ฌ๋ GPU๋ ์ด๋ฐ ํํ์ ํ๋ ฌ ์ฐ์ฐ ๋ณ๋ ฌํ์ ํนํ๋์ด ์๊ณ , ๋ฐ๋ผ์ ๋ฌดํํ ๋ง์ GPU๋ฅผ ๊ฐ์ง๊ณ ๋ง ์๊ธฐ๋ง ํ๋ค๋ฉด ์ด๋ฅผ ๋ณ๋ ฌํํ์ฌ ๊ณ์ฐํ ์ ์์ผ๋ฏ๋ก, ์๊ฐ ๋ณต์ก๋๋ ์ด ๋๋ค.
- Recurrent
- ์ด์ time step์ ์ด ์ ๊ณต๋์ด์ผ ๊ทธ๊ฒ์ input์ผ๋ก ๋ค์ ๋ฅผ ๊ณ์ฐํ ์ ์๊ธฐ ๋๋ฌธ์, ๋ถ๊ฐํผํ๊ฒ ์๊ฐ ๋ณต์ก๋๋ ์ด ๋๋ค.
Maximum Path Length
- ๋ ๋จ์ด ๊ฐ์ ๊ฒฝ๋ก ๊ฑฐ๋ฆฌ- Self-attention
- ๋ ๋จ์ด๊ฐ์ ์ ์ฌ๋๋ฅผ ๊ตฌํ ๋, ํ๋ ฌ ์ฐ์ฐ์ผ๋ก ๋ฐ๋ก ๊ณฑํ ์ ์์ผ๋ฏ๋ก ์ด๋ค.
- Recurrent
- ์ด๋ค ๋จ์ด a๊ฐ ์ด๋์ ๋ ๋จ์ด์ง ๋จ์ด b์ ๋๋ฌํ๊ธฐ๊น์ง recurrent cell์ ํ๋์ฉ ํต๊ณผํด์ผํ๊ธฐ ๋๋ฌธ์ ์ด ๋๋ค.
- Self-attention
Block-Based Model
์ฒ์์ ์์ํด Multi-Head Attention์ผ๋ก ๊ฐ๋ ์ธ๊ฐ์ ํ์ดํ๋ ๊ฐ๊ฐ Q,K,V๋ฅผ ์๋ฏธํ๋ค. ๊ฐ head๋ง๋ค ๋ค์ด๊ฐ๊ฒ ๋๋ค.
๊ทธ ์ฐ์ฐ ์ดํ์ ์งํ๋๋ Add&Norm ๊ตฌ๊ฐ์ ๋ฌด์์ผ๊น?
- Add -
Residual Connection
- CV ์ชฝ์์ ๊น์ ๋ ์ด์ด๋ฅผ ๋ง๋ค ๋ graident vanishing์ ํด๊ฒฐํ๋ฉด์ ๋ ๊น์ ์ธต์ ์๋๋ก ํ๋ ํจ๊ณผ์ ์ธ ๋ชจ๋ธ์ด๋ค.
- ์ฃผ์ด์ง input vector๋ฅผ Multi-Head Attention์ encoding output์ ๊ทธ๋๋ก ๋ํ์ฌ ์๋ก์ด output์ ๋ง๋ค์ด์, ํ์ต์์ Multi-Head Attention์ด ์ ๋ ฅ ๋ฒกํฐ ๋๋น ์ ๋ต ๋ฒกํฐ์์ '์ฐจ์ด๋๋ ์ ๋ณด'๋ง ํ์ตํ๋๋ก ํ ์ ์๋ค.
- ์ด ๋, Multi-Head Attention output๊ณผ input ๋ฒกํฐ์ ํฌ๊ธฐ๊ฐ ์์ ํ ๋์ผํ๋๋ก ์ ์งํด์ผ ๋ํ ์ ์๋ค.
- Norm -
Normalization
- ์ผ๋ฐ์ ์ผ๋ก ์ ๊ฒฝ๋ง์์ ์ฌ์ฉ๋๋ normalization์, (ํ๊ท ,๋ถ์ฐ)์ (0,1)๋ก ๋ง๋ ๋ค, ์ํ๋ ํ๊ท ๊ณผ ๋ถ์ฐ์ ์ฃผ์
ํ ์ ์๋๋ก ํ๋
์ ํ๋ณํ(affine transformation)
์ผ๋ก ์ด๋ฃจ์ด์ง๋ค. Batch Norm
- ๊ฐ ์์์ ํ๊ท ์ ๋นผ๊ณ , ํ์คํธ์ฐจ๋ก ๋๋๋ค. โ (ํ๊ท ,๋ถ์ฐ)==(0,1)
- affine transforamtionํ์ฌ ์ํ๋ ํ๊ท ๊ณผ ๋ถ์ฐ์ผ๋ก ๋ง๋ ๋ค.
- ex )
- ์ด ๋ 2์ 3์ Optimization ๊ณผ์ ์์ ์ต์ ํ ๋์์ธ ํ๋ผ๋ฏธํฐ๊ฐ ๋๋ค.
Layer Norm
- Batch Norm๊ณผ ๋ง์ฐฌ๊ฐ์ง ๋ฐฉ๋ฒ์ผ๋ก ์ํํ๋, ์ฌ๋ฌ layer๊ฐ ๋ถ์ด์๋ ํ๋ ฌ์ ๋์์ผ๋ก, ํ layer๋ง๋ค ์ํํ๋ค.
- affine transformation์ ๊ฐ layer์ ๋์ผํ node ๊ธฐ์ค์ผ๋ก ์ํํ๋ค.(normalization์ด column ๋จ์์๋ค๋ฉด affine transformation์ row ๋ณ)
- Batch Norm๊ณผ๋ ์ผ๋ถ ์ฐจ์ด์ ์ด ์์ง๋ง, ํฐ ํ์์ ํ์ต์ ์์ ํํ๋ค๋ ์ ์ ๋์ผํ๋ค.
- ์ผ๋ฐ์ ์ผ๋ก ์ ๊ฒฝ๋ง์์ ์ฌ์ฉ๋๋ normalization์, (ํ๊ท ,๋ถ์ฐ)์ (0,1)๋ก ๋ง๋ ๋ค, ์ํ๋ ํ๊ท ๊ณผ ๋ถ์ฐ์ ์ฃผ์
ํ ์ ์๋๋ก ํ๋
Add&Norm ๊ตฌ๊ฐ์ ๊ฑฐ์น๊ณ ๋์จ output์ ๋ค์ fully connected layer(Feed Forward)์ ํต๊ณผ์์ผ Word์ ์ธ์ฝ๋ฉ ๋ฒกํฐ๋ฅผ ๋ณํํ๋ค. ์ดํ ๋ค์ Add&Norm์ ํ๋ฒ ๋ ์ํํ๋ ๊ฒ๊น์ง๋ฅผ ๋์ผ๋ก Transformer์ (self-attention ๋ชจ๋์ ํฌํจํ) Block Based Model
์ด ์์ฑ๋๋ค.
Positional Encoding
RNN๊ณผ ๋ฌ๋ฆฌ self-attention ๋ชจ๋ ๊ธฐ๋ฐ์ Block Based Model๋ก ์ธ์ฝ๋ฉํ๋ ๊ฒฝ์ฐ, ์์๋ฅผ ๊ณ ๋ คํ์ง ์๊ธฐ ๋๋ฌธ์ input ๋จ์ด๋ค์ ์์๊ฐ ๋ฐ๋์ด๋ output์ ๋์ผํ๋ค. ์ด๋ K์ Q๊ฐ์ ์ ์ฌ๋๋ฅผ ๊ตฌํ๊ณ V๋ก ๊ฐ์ค์น๋ฅผ ๊ตฌํด ๊ฐ์คํฉ(์ด๋ softmax๋ฅผ ํต๊ณผํ ๊ฐ์ด๋ฏ๋ก ๊ฐ์คํฉ ์์ฒด๊ฐ ๊ฐ์คํ๊ท ์ด๋ค)์ ๋์ถํ๋ ๊ณผ์ ์์, sequence๋ฅผ ๊ณ ๋ คํ์ง ์๊ธฐ ๋๋ฌธ์ด๋ค.
Positional Encoding
์ ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ์ํด ๋ฒกํฐ ๋ด์ ํน์ ์์์ ํด๋น word์ ์์๋ฅผ ์์๋ณผ ์ ์๋ , ๋ง์น ์ง๋ฌธ๊ณผ๋ ๊ฐ์ uniqueํ ๊ฐ์ ์ถ๊ฐํ์ฌ sequence๋ฅผ ๊ณ ๋ คํ๋ ๊ฒ์ ๋งํ๋ค.
- ์ด ๋, uniqueํ ๊ฐ์ ์ฌ๋ฌ ์ฃผ๊ธฐํจ์์ ์ถ๋ ฅ ํจ์๊ฐ์ ํฉ์ณ ์ฌ์ฉํ๋ค. ์ฃผ๊ธฐํจ์๋ ์ ๋ ฅ๊ฐ x์ ์์น์ ๋ฐ๋ผ ์ถ๋ ฅ๊ฐ์ด ๋ณํ๊ธฐ ๋๋ฌธ์ด๋ค.
- ๋จ, ํ๋์ ์ฃผ๊ธฐํจ์๋ง ์ฌ์ฉํ๋ฉด ๋์ผํ ํจ์๊ฐ์ ๊ฐ์ง๋ ๊ตฌ๊ฐ์ด ์๊ธฐ๋ฏ๋ก, ์๋ก ๋ค๋ฅธ ์ฌ๋ฌ ์ฃผ๊ธฐํจ์์ ์ถ๋ ฅ๊ฐ๋ค์ ๋ชจ๋ ํฉ์ณ์ ์ฌ์ฉํ๋ค.
์ด๋ ๊ฒ ํน์ํ ๊ฐ์ ์ถ๊ฐํ์ฌ ์ธ์ฝ๋ฉํ๊ฒ ๋๋ฉด, input ๋จ์ด์ ์์๊ฐ ๋ฐ๋์์ ๋ output ๊ฐ๋ ๋ฌ๋ผ์ง๊ฒ ๋์ด ์์๋ฅผ ๊ตฌ๋ณํ ์ ์๋ ๋ชจ๋ธ์ด ๋๋ค.
Learning Rate Scheduler
๊ธฐ์กด์ ๋ชจ๋ธ์์ ํ์ต๋ฅ (learning rate)๋ ํ์ดํผํ๋ผ๋ฏธํฐ๋ก, ํ์ต ๋ด๋ด ๊ณ ์ ๋์ด์๋ ๊ฐ์ด์๋ค. ๊ทธ๋ฌ๋ ํ์ต์ ๊ณผ์ ๋์ ํจ์จ์ ์ธ ํ์ต๋ฅ ์ ๊ณ์ ๋ฐ๋๊ธฐ ๋ง๋ จ์ด๋ฏ๋ก, ์ด๋ฅผ ํ์ต ๊ณผ์ ๋ด์์ ํจ๊ณผ์ ์ผ๋ก ๋ฐ๊พธ์ด ์ค ์ ์๋ ๋ฐฉ์์ผ๋ก Learning Rate Scheduler
๊ฐ ๋์ค๊ฒ ๋์๋ค.
๋์ฝ๋ ๊ตฌ์กฐ
Outputs๊ฐ ๋์ฝ๋์ ์
๋ ฅ์ผ๋ก ๋ค์ด์ฌ ๋, ๊ธฐ์กด์ ground truth ๋ฌธ์ฅ์์ ์์ชฝ์๋ <SOS>
ํ ํฐ์ ๋ถ์ฌ ๋ค์ด์ค๋ฏ๋ก, ํ์นธ ๋ฐ๋ฆฐ(shfited right) ํํ๋ก ๋ค์ด์ค๊ฒ ๋๋ค.
๋์ฝ๋์์ Attention ๋ชจ๋์ ํฌํจํ ํ์ฐจ๋ก์ ๊ณผ์ ์ ๊ฑฐ์น ํ ๋ค์ Multi-Head Attention์ผ๋ก ๊ฐ ๋, ๋์ฝ๋์ hidden state vector๋ฅผ ์ ๋ ฅ Q๋ก ๋๊ฒจ์ค๋ค. ๊ทธ๋ฐ๋ฐ, ๋๋จธ์ง K์ V ์ ๋ ฅ์ ์ธ๋ถ, ์ฆ ์ธ์ฝ๋์ ์ต์ข ์ถ๋ ฅ์ผ๋ก๋ถํฐ ์จ๋ค. ์ฆ, ์ด ๋ถ๋ถ์ ๋์ฝ๋์ hidden state vector๋ฅผ ๊ธฐ์ค, ์ฆ Q๋ก ํด์ ์ธ์ฝ๋์ hidden state vector K, V ๋ฅผ ๊ฐ์คํ์ฌ ๊ฐ์ ธ์ค๋, ์ธ์ฝ๋์ ๋์ฝ๋๊ฐ์ Attention ๋ชจ๋์ด ๋๋ค.
์ด ํ ์ด๋ฏธ์ง์ ๋์จ ๋๋ก์ ์ฐ์ฐ์ ๊ฑฐ์น๋ค๊ฐ, ๋์ฝ๋์ ์ต์ข output ๊ฐ์ด Linear Layer์ Softmax๋ฅผ ๊ฑฐ์ณ ํ๋ฅ ๋ถํฌ์ ํํ๋ก ์ถ๋ ฅ๋๋ค. ์ด ๊ฐ์ Softmax-with-loss ์์คํจ์๋ฅผ ํตํด ํ์ต๋๋ค.
Masked Self-Attention
Self-Attention ๋ชจ๋ธ์์, ์์์ ๋จ์ด a๋ Q์ K์ ๋ด์ ์ ํตํด ์์ ๊ณผ ๋ชจ๋ ๋จ์ด๋ค์ ๊ด๊ณ๋ฅผ ๋ค ์์ ์๋ค. ์ด ๋, ํ์ต ๋น์์๋ ๋ฐฐ์น ํ๋ก์ธ์ฑ์ ์ํด a ๋ค์ ๋จ์ด๋ค๊น์ง ๋ชจ๋ ๊ณ ๋ คํ๋๋ก ํ์ต์ด ์งํ๋๋, ์ฌ์ค ์ค์ ๋์ฝ๋ฉ ์ํฉ์ ๊ณ ๋ คํ๋ค๋ฉด a ๋ค์ ๋จ์ด๋ฅผ ์์์๋ ์๋๋ค. ์ด๋ ๋ค์ ๋จ์ด๋ฅผ ์ถ๋ก ํด์ผ ํ๋ ์ํฉ์์ ๋ค์ ์ด๋ค ๋จ์ด๊ฐ ์๋์ง ๋ฏธ๋ฆฌ ์๊ณ ์๋, ์ผ์ข ์ cheating ์ํฉ์ด๊ธฐ ๋๋ฌธ์ด๋ค. ์ด๋ฌ๋ฉด ๋น์ฐํ ํ์ต์ ์ ๋๋ก ๋์ง ์๊ฒ ๋๋ค.
๋์ฝ๋ ๊ณผ์ ์ ์ด๋ฏธ์ง ์ค Masked Self-attention
์ ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํ ๋ฐฉ๋ฒ์ผ๋ก, ๊ธฐ์กด์ attention ๋ชจ๋์์ Q, K ๋ด์ ๊ณผ softmax๋ฅผ ํต๊ณผํ ๊ฐ์์ ํ์ฌ ๋จ์ด a์ ๋ค์ ์๋ ๋จ์ด๋ค์ key ๊ฐ์ผ๋ก ๊ณ์ฐ๋ ์
๋ค์ ๋ชจ๋ ์ญ์ ํ๋ค. Mask
๋ผ๋ ๋จ์ด๋ ์ด์ฒ๋ผ ๋ค์ชฝ์ ์ ๋ณด๋ฅผ ๊ฐ๋ฆฐ๋ค(mask)๋ ์๋ฏธ๋ค.
์์ ์ด๋ฏธ์ง๋ [I go home โ ๋๋ ์ง์ ๊ฐ๋ค] ๋ผ๋ ๋ฒ์ญ์ ์ํํ๋ ์ฌ๋ก์ธ๋ฐ, Q,K์ ๋ด์ ์ ํตํด ์ป์ ์ ์ฌ๊ฐ ํ๋ ฌ์ ํํํ๊ณ ์๋ค. ์ด ๋ ์ฃผ๋๊ฐ์ ์์ชฝ์ ๊ฐ๋ค์ query๋ณด๋ค key๊ฐ ๋ค์ชฝ์ ๋จ์ด๋ค์ธ ๊ฒฝ์ฐ๋ก, ์ด ์ ๋ค์ ์ ๋ณด๋ฅผ ๊ทธ๋๋ก ๋ ์ฑ๋ก ํ์ต์ํค์ง ๋ชปํ๋๋ก ํด๋น ๊ฐ๋ค์ 0์ผ๋ก ๋์ฒดํ๋ค. ๊ทธ ์ดํ, ๋จ์ ์ฃผ๋๊ฐ์ ์ดํ์ ์ ๋ค๋ง ๊ฐ์ง๊ณ , row๋จ์๋ก ์ดํฉ์ด 1์ด ๋๋๋ก normalize ํ ์ ๋ณด๋ฅผ ์ต์ข output์ผ๋ก ๋ด๋ณด๋ธ๋ค.
Reference
The Illustrated Transformer - ํ๊ตญ๋ฒ์ญ
Transformer: All you need is Attention (์ค๋ช /์์ฝ/์ ๋ฆฌ)