N-Beats (2019)
๐ Oreshkin, Boris N., et al. โN-BEATS: Neural basis expansion analysis for interpretable time series forecasting.โ arXiv preprint arXiv:1905.10437 (2019).
๋ค์ด๊ฐ๋ฉฐ
2018๋ ์๊ณ์ด ๋ชจ๋ธ ๊ฒฝ์ง๋ํ์ธ M4 Competition์ด ์ด๋ ธ์๋๋ฐ์. ํด๋น ๋ํ์์ ์ฌ๋ฐ๊ฒ ์ฌ๊ธธ๋งํ ์ ์ ๋ฐ๋ก ์์ ML ๋ชจ๋ธ๋ค์ ์ฑ์ ์ ๋๋ค. ์ด 60๊ฐ์ ํ์์ ์ฌ์ฏ ๊ฐ์ ํ๋ง์ด ์์ ML ๋ชจ๋ธ์ ์ ์ถํ์๊ณ , ํด๋น ํ๋ค์ ์์ ์ค ๊ฐ์ฅ ๋์ ์ฑ์ ์ 23์์์ต๋๋ค. ์ฐธ๊ณ ๋ก ํด๋น ๋ํ์์ ์ฐ์น์ ์ฐจ์งํ ๋ชจ๋ธ์ ํต๊ณ์ ์ธ ๋ฐฉ๋ฒ๋ก ๊ณผ ML ๋ฐฉ๋ฒ๋ก ์ ์์ ES-RNN (Exponential Smoothing Recurrent Neural Network) ์ด์์ต๋๋ค.
๋ณธ ๋ ผ๋ฌธ์ ์ด๋ฐ ์ํฉ์์ ํด๋น ๋ํ์ ์ ์ถ๋ ๋ชจ๋ธ๋ณด๋ค ์ฑ๋ฅ์ด ์ข์ ์์ ML ๊ธฐ๋ฐ์ ๋ชจ๋ธ์ธ N-Beats๋ฅผ ์ ์ํฉ๋๋ค. ์ด ๋ ผ๋ฌธ์์ ์ ์ํ๊ณ ์๋ ๋ชจ๋ธ์ ์ฅ์ ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- Deep Neural Architecture
- ๊ธฐ์กด ์ฌ๋ฌ ๋ฐ์ดํฐ(M3, M4,
TOURISM
๋ฑ)์ ๋ํด ํต๊ณ์ ์ ๊ทผ๋ฒ๋ณด๋ค ์ข์ ์ฑ๋ฅ์ ๋ณด์ด๋ ์์ DL ๊ธฐ๋ฐ์ ๋ชจ๋ธ์ ๋๋ค.
- ๊ธฐ์กด ์ฌ๋ฌ ๋ฐ์ดํฐ(M3, M4,
- ํด์ ๊ฐ๋ฅํ ์๊ณ์ด ๋ฅ๋ฌ๋ ๋ชจ๋ธ
- ๊ณ์ ์ฑ-์ถ์ธ ์์ค์ ์ ๊ทผ ๋ฐฉ์๊ณผ ๊ฐ์ ์ ํต์ ์ธ ๋ถํด ๊ธฐ๋ฒ๊ณผ ๋น์ทํ ๋ฐฉ์์ผ๋ก ๋ชจ๋ธ์ ํด์ ๊ฐ๋ฅํ ์ํคํ ์ฒ๋ก ์ค๊ณํ๋ ๊ฒ์ด ๊ฐ๋ฅํฉ๋๋ค.
๋ณธ ๋ชจ๋ธ์ ๋ํด์ ์์๋ณด๊ธฐ ์ ์ ๋ช ๊ฐ์ง ํ๊ธฐ๋ฒ(notation)์ ์ง๊ณ ๋์ด๊ฐ๊ฒ ์ต๋๋ค. ์ด์ฐ์ ์๊ฐ์ ๋ํ ๋จ๋ณ๋ ์์ธก ๋ฌธ์ ์ ๋ํด์ ๋ค์์ ์ ์ํฉ๋๋ค.
- H ๊ธธ์ด๋งํผ์ ์์ธก ๋ฒ์
- y=[yT+1,โฏ,yT+H]โRH
- T ๊ธธ์ด๋งํผ์ ๊ณผ๊ฑฐ ์ด๋ ฅ
- [y1,โฏ,yT]โRT
- ๊ธธ์ด tโคT์ lookback window
- x=[yTโt+1,โฏ,yT]โRt
- y๋ฅผ ์์ธกํ ๊ฐ หy
N-Beats
N-Beats์ ์ํคํ ์ฒ๋ฅผ ์ค๊ณํ ๋ ๋ค์์ ํฌ์ธํธ๋ฅผ ์ค์ํ๊ฒ ์ฌ๊ฒผ๋ค๊ณ ํฉ๋๋ค.
- ๊ธฐ๋ณธ ์ํคํ ์ฒ๋ ๋จ์ํ๊ณ ์ผ๋ฐ์ ์ด๋ ๋์ ํํ๋ ฅ์ ๊ฐ๊ณ ์์ด์ผ ํจ
- ์๊ณ์ด์ ํนํ๋ ํผ์ฒ ์์ง๋์ด๋ง์ด๋ ์ ๋ ฅ๊ฐ ์ค์ผ์ผ๋ง ๋ฑ์ ์์กดํ์ง ์๋ ์ํคํ ์ฒ์ฌ์ผ ํจ
- ์ฌ๋์ด ๊ฒฐ๊ณผ๋ฅผ ํด์ํ ์ ์๋๋ก ํ์ฅ ๊ฐ๋ฅํ ์ํคํ ์ฒ์ฌ์ผ ํจ
์ด๋ฐ ํฌ์ธํธ๋ฅผ ํฌํจํ๊ณ ์๋ N-Beats์ ์ํคํ ์ฒ๋ ์๋ ๋ค์ด์ด๊ทธ๋จ๊ณผ ๊ฐ์ต๋๋ค.
Basic Block
The detailed architecture of a basic block.
๊ธฐ๋ณธ ๋ธ๋ก์ ํํ๋ ์ ์ด๋ฏธ์ง์ ๊ฐ์ต๋๋ค. ์ด๋ฐ ๊ธฐ๋ณธ ๋ธ๋ก์ ์ฌ๋ฌ ๊ฐ ์์ ํ๋์ ์คํ์ ๋ง๋๋๋ฐ, ์ผ๋ฐ์ ์ธ ์ค๋ช ์ ์ํด โ ๋ฒ์งธ ๋ธ๋ก์ ๋ํด ๋ค๋ฃจ๊ฒ ์ต๋๋ค.
โ ๋ฒ์งธ ๋ธ๋ก์ ๋ํ์ฌ ํด๋น ๋ธ๋ก์ ์ ๋ ฅ ๋ฒกํฐ์ธ xโ ์ด ์์ต๋๋ค. ๋ง์ฝ โ=1 ์ด๋ผ๋ฉด ๋งจ ์ฒ์ ๋ธ๋ก์ด๋ฏ๋ก xโ ์ ๋ชจ๋ธ์ ์ ๋ ฅ ๋ฒกํฐ์ ๊ฐ์์ง๋๋ค. ๊ทธ๋ฆฌ๊ณ ์ฐ๋ฆฌ๊ฐ ์์ธกํ ๋ฒ์์ ๊ธธ์ด๋ฅผ H ๋ผ๊ณ ํ๋ฉด ์ต์ด ๋ธ๋ก์ ์ ๋ ฅ๊ฐ์ด ๋๋ ๋ฒกํฐ์ ๊ธธ์ด๋ ๋ณดํต 2H ์์ 7H ๋ก ์ค์ ํฉ๋๋ค. ์ฆ ์์ธกํ๋ ํ์์คํฌํ ๊ธธ์ด์ ๋ ๋ฐฐ์์ ์ผ๊ณฑ ๋ฐฐ์ ๋ฐ์ดํฐ๋ฅผ ์ ๋ ฅ๊ฐ์ผ๋ก ์ฌ์ฉํฉ๋๋ค.
ํ์ง๋ง ๋ค๋ฅธ ๊ฒฝ์ฐ์๋ ๋ชจ๋ ์ด์ ๋ธ๋ก์ residual output์ ์ ๋ ฅ์ผ๋ก ๋ฐ์ต๋๋ค. ๊ทธ๋ฆฌ๊ณ ๋ ๊ฐ์ ์์ํ ๋ฒกํฐ หxโ, หyโ ๊ฐ ์์ต๋๋ค. ๊ฐ๊ฐ์ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- หxโ : Backcast ์์ธก
- ์ ๋ ฅ ๋ฒกํฐ์ ๋ํ ์์ธก์ ๋๋ค.
- หyโ : Forecast ์์ธก
- ์ค์ ๋ก ์์ธกํ ๋ฒ์์ ๋ํ ์์ธก์ ๋๋ค.
์ด๋ฐ ์ค์ ์๋ ๊ธฐ๋ณธ ๋ธ๋ก์ ๋ค ๊ฐ์ FC ๋ ์ด์ด๋ฅผ ๊ฑฐ์ณ์ ๋ ๊ฐ์ ๋ถ๊ธฐ๋ก ๋๋ ์ง๋๋ฐ ๊ฐ ๋ถ๊ธฐ์์ backcast์ forecast์ ๋ํ ์์ธก ๊ณ์ ฮธbโ ์ ฮธfโ๋ฅผ ์ป๊ฒ ๋ฉ๋๋ค. ์ฌ๊ธฐ๊น์ง๋ฅผ ์์์ผ๋ก ๋ํ๋ด๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- hโ,1=FCโ,1(xโ)
- hโ,2=FCโ,2(hโ,1)
- hโ,3=FCโ,3(hโ,2)
- hโ,4=FCโ,4(hโ,3)
- ฮธbโ=Linearbโ(hโ,4)
- ฮธfโ=Linearfโ(hโ,4)
์ฌ๊ธฐ์ FC ๋ fully connected layer์ ReLU๋ก ๊ตฌ์ฑํฉ๋๋ค.
hโ,1=ReLU(Wโ,1xโ+bโ,1)๋ง์ง๋ง์ผ๋ก ๊ธฐ์ ๋ ์ด์ด(basis layer) gbโ์ gfโ๋ฅผ ๊ฑฐ์ณ ๋ค์์ ๊ณ์ฐํฉ๋๋ค.
^yโ=gfโ(ฮธfโ)=dim(ฮธfโ)โi=1ฮธfโ,ivfi,^xโ=gbโ(ฮธbโ)=dim(ฮธbโ)โi=1ฮธbโ,ivbi์ด๋ ๊ธฐ์ ๋ ์ด์ด๋ ํ์ต ๊ฐ๋ฅํ ํ๋ผ๋ฏธํฐ๋ก ์ค์ ํ ์๋ ์๊ณ ํน์ ํจ์ ํํ๋ก ์ค์ ํ ์๋ ์์ต๋๋ค.
Doubly Residual Stacking
์ผ๋ฐ์ ์ธ residual connection์ ์ ๋ ฅ๊ฐ์ ๋ช ๊ฐ์ ๋ ์ด์ด๋ฅผ ๊ฑด๋ ๋ฐ์ด ๋ํ๋ ๋ฐฉ์์ ์ฌ์ฉํฉ๋๋ค. ์ด๋ ๊ฒ ํ๋ฉด ๋ ๊น์ ๊ตฌ์กฐ๋ฅผ ์ ํ์ตํ๋ ์ด์ ์ด ์์ต๋๋ค. ํ์ง๋ง ํด์ ๊ฐ๋ฅํ ์ํคํ ์ฒ๋ฅผ ๊ตฌ์ฑํ๋ ๊ฒฝ์ฐ์๋ ๋์์ด ๋์ง ๋ชปํฉ๋๋ค. ๊ทธ๋์ ์ ์๋ ๊ธฐ์กด ๊ฐ์ ๋ํ๋ ๋์ ๋นผ๋ ๋ฐฉ์์ ์ฑ์ฉํ์ต๋๋ค. Backcast์์ ๋ธ๋ก์ ์ ๋ ฅ ๋ฒกํฐ์ ํ์ฌ ๋ธ๋ก์ backcast๋ฅผ ๋บ residual์ ๋ค์ ๋ธ๋ก์ผ๋ก ๋๊ฒจ์ฃผ๋ ๋ฐฉ์์ ๋๋ค.
xโ=xโโ1โหxโโ1Forecast๋ residual connection ์์ด ๋งค ๋ธ๋ก์ forecast๋ฅผ ๋ํฉ๋๋ค.
หy=โโหyโ์ด ๊ตฌ์กฐ๋ฅผ ํตํด ์ป๋ ํจ๊ณผ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
- ์ด์ ๋ธ๋ก์ด ์ ๋ ฅ ๋ฒกํฐ์ ์ผ๋ถ ์๊ทธ๋ หxโโ1 ์ ์ ๊ฑฐํ์ฌ ๋ธ๋ก์ ์์ธก ์์ ์ ์ฝ๊ฒ ๋ง๋ค์ด์ค๋๋ค.
- Backcast์ residual connection ๊ตฌ์กฐ๋ก ์ธํด ๊ทธ๋ผ๋์ธํธ๊ฐ ๋ ์ ํ๋ฌ ์ญ์ ํ๋ฅผ ์ฉ์ดํ๊ฒ ํฉ๋๋ค.
- Forecast์ summation connection ๊ตฌ์กฐ๋ ๊ณ์ธต์ ๋ถํด(hierarchical decompostion) ๋ฅผ ๊ฐ๋ฅํ๊ฒ ํฉ๋๋ค.
- gbโ์ gfโ๋ก ์ธํด ๊ฐ์ ๋๋ ์๋์ ์ธ ๊ตฌ์กฐ๋ forecast์ ๊ณ์ธต์ ๋ถํด๊ฐ ๋ชจ๋ธ์ ํด์์ ๊ฐ๋ฅํ๊ฒ ํ๋ ์ค์ํ ์๋ฏธ๋ฅผ ๊ฐ์ง๋๋ค.
์ ๋ด์ฉ๊น์ง๋ฅผ ํ๋์ ์คํ์ผ๋ก ๊ตฌ์ฑํด์ stack residual์ ๋ค์ ์คํ์ผ๋ก, ๊ฐ ์คํ์ stack forecast๋ ๋ชจ๋ ํฉํด์ global forecast๋ฅผ ์ป๋ ๋ฐฉ์์
๋๋ค. ์ด๋ ํ์ต์ MSE๋ฅผ ์์ค ํจ์๋ก ํด์ ์งํํฉ๋๋ค.
Interpretability
N-Beats๋ gbโ์ gfโ๋ฅผ ์ค์ ํ๋ ๋ฐฉ๋ฒ์ ๋ฐ๋ผ ๋ ๊ฐ์ ์ํคํ ์ฒ๋ก ๋๋ฉ๋๋ค. ์ง๊ธ๊น์ง ๋ค๋ฃฌ ์ผ๋ฐ์ ์ธ ์ํคํ ์ฒ(Generic architecture)๋ ์๊ณ์ด์ ํนํ๋ ์ง์์ ์์กดํ์ง ์์ต๋๋ค. ํ์ง๋ง ์ด์ ๋ถํฐ ์ค๋ช ํ ํด์ ๊ฐ๋ฅํ ์ํคํ ์ฒ(interpretable architecture) ๋ ํด์๋ ฅ์ ์ํด์ ์ ๋ ํธํฅ(inductive bias)๋ฅผ ์ถ๊ฐํ์ต๋๋ค. ์ด๋ ์๊ณ์ด์ ๋ํ ์ ๋ณด๊ฐ ๋ค์ด๊ฐ์ฃ .
์ผ๋ฐ์ ์ธ ์ํคํ ์ฒ๋ gbโ์ gfโ๋ฅผ ์ด์ ๋ ์ด์ด ์์ํ์ linear projection์ผ๋ก ์ค์ ํฉ๋๋ค.
หyโ=Vfโฮธfโ+bfโหxโ=Vbโฮธbโ+bbโ์ด๋ Vfโ ๋ Hรdim(ฮธfโ) ์ ์ฐจ์์ ๊ฐ์ง๋๋ค.
ํด์ ๊ฐ๋ฅํ ์ํคํ ์ฒ๋ gbโ์ gfโ ๋ฅผ ์ด๋ป๊ฒ ์ค์ ํ๋๋์ ๋ฐ๋ผ ์ถ์ธ ๋ชจ๋ธ(trend model) ๊ณผ ๊ณ์ ์ฑ ๋ชจ๋ธ(seasonality model) ๋ก ๋๋ฉ๋๋ค.
Trend model
์ถ์ธ์ ์ผ๋ฐ์ ์ธ ํน์ฑ์ด๋ผ๊ณ ํ๋ฉด ๋จ์กฐ์ฆ๊ฐ ๋๋ ๋จ์กฐ๊ฐ์ํ๋ ํํ๋ฅผ ๊ฐ๊ฑฐ๋ ์ฒ์ฒํ ๋ณํํ๋ ํํ๋ฅผ ๊ฐ๋๋ค๋ ์ ์
๋๋ค. ์ด๋ฐ ํน์ฑ์ ๋ํ๋ด๊ธฐ ์ํด์ ์ ์๋ ์์ ์ฐจ์์ ๋คํญํจ์ ํํ๋ฅผ ์ฐจ์ฉํ์ต๋๋ค.
ํ๋ ฌ์์ผ๋ก ๋ํ๋ด๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
หytrs,โ=Tฮธfs,โwhere T=[1,t,โฏ,tp]์ดํด๋ฅผ ๋๊ธฐ ์ํด ์ ๊ทธ๋ฆผ์ ์ฐธ๊ณ ํ์๋ฉด gb ์ gf ๋ฅผ ํน์ ํ๋ ฌ ํํ๋ก ์ค์ ํฉ๋๋ค. ๊ฐ ํ์ backcast ๋๋ forecast์ time step์ ๋ํ๋ด๊ณ ๊ฐ ์ด์ ๋คํญํจ์์ ์ฐจ์๋งํผ์ผ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค. p๋ฅผ ์ ๋นํ๊ฒ ์๊ฒ ์ค์ ํ๋ฉด หytrs,โ์ ์ถ์ธ๋ฅผ ๋ฐ๋ผ๊ฐ๊ฒ ๋ฉ๋๋ค.
Seasonality Model
๊ณ์ ์ฑ์ ๊ท์น์ ์ด๊ณ ์ฃผ๊ธฐ์ ์ด๋ฉฐ ๋ฐ๋ณต์ ์ธ ๋ณ๋์ด ์์ต๋๋ค. ์ด๋ฐ ํน์ฑ์ ๋ํ๋ด๊ธฐ ์ํด ์ฃผ๊ธฐ ํจ์๋ฅผ ์ฐจ์ฉํ๋๋ฐ์. ๊ฐ์ฅ ์ ์ ํ ์ ํ์ ์ฌ๋ฌ๋ชจ๋ก ํธ๋ฆฌ์ ๊ธ์์ ๋๋ค.
หys,โ=โH/2โ1โโi=0ฮธfs,โ,icos(2ฯit)+ฮธfs,โ,i+โH/2โsin(2ฯit)ํ๋ ฌ์์ผ๋ก ๋ํ๋ด๋ฉด ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
หyseass,โ=Sfs,โwhere S=[1,cos(2ฯt),โฏ,cos(2ฯโH/2โ1โt),sin(2ฯt),โฏ,sin(2ฯโH/2โ1โt)]๋ง์ง๋ง์ผ๋ก ์ถ์ธ ๋ชจ๋ธ๊ณผ ๊ณ์ ์ฑ ๋ชจ๋ธ์ ์๋ ๊ทธ๋ฆผ์ฒ๋ผ ๋ถ์ฌ์ฃผ๋ฉด ๋ฉ๋๋ค. ๊ฐ ์ถ์ธ ๋ธ๋ก๊ณผ ๊ณ์ ์ฑ ๋ธ๋ก์ผ๋ก ์คํ์ ๊ตฌ์ฑํ๋ฉฐ, ๊ฐ ๋ธ๋ก์ ์ผ๋ฐ์ ์ธ ์ํคํ ์ฒ์ ๋์ผํ๊ฒ residual connection์ ํ์ฉํฉ๋๋ค.
Implementation
N-Beats๋ฅผ ๊ตฌํํ ์ฝ๋๋ ๋ค์ ์ ์ฅ์์์ ํ์ธํ์ค ์ ์์ต๋๋ค. ์ฌ๋ฌ ๋ฐ์ดํฐ์ ์ ๋ํ ์คํ ์ฝ๋๋ ํฌํจํ๊ณ ์์ผ๋ฉฐ, ๋ ผ๋ฌธ๊ณผ ๋์ผํ ๊ฒฐ๊ณผ๋ฅผ ์ป๊ธฐ ์ํ ๊ฐ์ด๋ ์ญ์ ์๋ก๋์ด ์์ต๋๋ค. ๋จ, ์ผ๋ฐ์ ์ธ ์ํคํ ์ฒ๋ ๊ตฌํํ๊ธฐ ์ฝ์ง๋ง ํด์ ๊ฐ๋ฅํ ์ํคํ ์ฒ๋ ๊ตฌํ๋ ๊น๋ค๋กญ๊ณ ํด์๋ ์ฝ์ง ์์ต๋๋ค.