Created
December 27, 2017 10:10
-
-
Save hrsma2i/0c9149d5d344a04d9124a73a8f487fba to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# chainerによるLSTMを用いた時系列予測" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"[**Brains Consulting, Inc.**](https://www.brains-consulting.co.jp/) でインターンをさせていただいている情報系のM1です。\n", | |
"2017年7月から9月にかけて、インターン業務として、LSTM を用いた時系列予測を [Chainer](https://chainer.org/) で実装してきました。\n", | |
"最終的なゴールは、複数商品の需要予測に適用可能な深層学習モデルを構築することですが、その準備として、単一商品の需要予測について検証しました。\n", | |
"業務においては、大手食品メーカ様の需要量実データを用いましたが、この記事では、Web上の公開データセットに置き換えて、その成果を報告したいと思います。\n", | |
"\n", | |
"当記事では、chainer **1.24.0** と古い version を使っていますので、ご注意ください。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# データセット準備\n", | |
"\n", | |
"[International airline passengers: monthly totals in thousands. Jan 49 – Dec 60 — Dataset — DataMarket](https://datamarket.com/data/set/22u3/international-airline-passengers-monthly-totals-in-thousands-jan-49-dec-60#!ds=22u3&display=line)\n", | |
"\n", | |
"Export タブを押して、カンマ(,)区切りの csv 形式で DL してください。\n", | |
"\n", | |
"<img src='download_data.png' width=300 align='left'>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"csv の中身は、先頭10行を抜き出すと、以下のとおりです。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"```\n", | |
"\"Month\",\"International airline passengers: monthly totals in thousands. Jan 49 ? Dec 60\"\n", | |
"\"1949-01\",112\n", | |
"\"1949-02\",118\n", | |
"\"1949-03\",132\n", | |
"\"1949-04\",129\n", | |
"\"1949-05\",121\n", | |
"\"1949-06\",135\n", | |
"\"1949-07\",148\n", | |
"\"1949-08\",148\n", | |
"\"1949-09\",136\n", | |
"```" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"ファイル末尾3行は不要なので、エディタなどで削除します。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"```\n", | |
"\"1960-06\",535\n", | |
"\"1960-07\",622\n", | |
"\"1960-08\",606\n", | |
"\"1960-09\",508\n", | |
"\"1960-10\",461\n", | |
"\"1960-11\",390\n", | |
"\"1960-12\",432\n", | |
"\n", | |
"International airline passengers: monthly totals in thousands. Jan 49 ? Dec 60\n", | |
"```" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"↓(末尾3行削除)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"```\n", | |
"\"1960-06\",535\n", | |
"\"1960-07\",622\n", | |
"\"1960-08\",606\n", | |
"\"1960-09\",508\n", | |
"\"1960-10\",461\n", | |
"\"1960-11\",390\n", | |
"\"1960-12\",432\n", | |
"```" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"ダウンロードしたファイルを読み込み、時系列グラフを表示します。以下のコードは、jupyter notebook利用を前提としています。別の環境で実行する際には、適宜書き換えてください。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd\n", | |
"import matplotlib.pyplot as plt\n", | |
"# jupyter notebook用\n", | |
"%matplotlib inline \n", | |
"\n", | |
"df = pd.read_csv('international-airline-passengers.csv')\n", | |
"\n", | |
"series = df.iloc[:,1].values\n", | |
"\n", | |
"plt.figure(figsize=(15,10))\n", | |
"plt.grid()\n", | |
"\n", | |
"plt.plot(series)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"<img src='raw_series.png' width=1000 align='left'>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 前処理\n", | |
"\n", | |
"以下の前処理を施すと、良い予測結果が得られました。\n", | |
"\n", | |
"- **階差 (differencing)**\n", | |
"- **正規化 (normalization)**\n", | |
"\n", | |
"参考記事 [LSTMにsin波を覚えてもらう(chainer trainerの速習) - Qiita](https://qiita.com/chachay/items/052406176c55dd5b9a6a) にある\n", | |
"sin 関数などでは、特に前処理は必要ありませんが、今回の時系列データでは、前処理を施さないと、良い予測結果が得られませんでした。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 階差\n", | |
"\n", | |
"まず、階差をとります。\n", | |
"階差をとる目的は、時系列の[トレンド](http://www.simafore.com/blog/bid/205420/Time-series-forecasting-understanding-trend-and-seasonality)を除くためです。\n", | |
"本来は、深層学習の枠組みでトレンドを扱えることが望ましいのですが、今回は前処理として階差をとることで、トレンドの少ない時系列データに変換しました。\n", | |
"\n", | |
"階差の定義を記述します。\n", | |
"時系列データの長さを $T$ で表します。 \n", | |
"時系列データの添え字(時刻)を $t$ ($t = 0$,..., $T - 1$)とします。\n", | |
"時系列データを $X(t)$ で表します。\n", | |
"このとき、時系列データ $X(t)$ の階差時系列 $D(t)$ は、\n", | |
"\n", | |
"$D(t) = X(t+1) - X(t)$ ($t = 0$,..., $T - 2$)\n", | |
"\n", | |
"として定義されます。\n", | |
"今回の時系列データでは、長さ\n", | |
"\n", | |
"$T = 144$\n", | |
"\n", | |
"で、各時刻の値は、\n", | |
"\n", | |
"$X(0),...,X(143)$\n", | |
"\n", | |
"で表します。\n", | |
"階差時系列は、長さが1つ少ない $T - 1 = 143$ 個の値\n", | |
"\n", | |
"$D(0),...,D(142)$\n", | |
"\n", | |
"になります。\n", | |
"時系列 $X$ に階差を施してできた時系列 $D$ のグラフを表示します。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"\n", | |
"def difference(series):\n", | |
" diffed = series[1:] - series[:-1]\n", | |
" return diffed\n", | |
"\n", | |
"diffed = difference(series)\n", | |
"\n", | |
"plt.figure(figsize=(15,10))\n", | |
"plt.grid()\n", | |
"\n", | |
"plt.plot(diffed)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"<img src='diffed.png' width=500 align='left'>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 教師ありデータに変換\n", | |
"\n", | |
"教師あり学習を行うので、**入力**用データと**ラベル**用データを作成します。\n", | |
"\n", | |
"入力\n", | |
"\n", | |
"- $D(0),..., D(141)$ の $T - 2 = 142$ 個\n", | |
"- 最後の時刻 $D(142)$ 以外\n", | |
"\n", | |
"ラベル\n", | |
"\n", | |
"- $D(1),..., D(142)$ の $T - 2 = 142$ 個\n", | |
"- 最初の時刻 $D(0)$ 以外\n", | |
"- 入力に対して、1時刻先の値。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def supervise(series):\n", | |
" X = series[:-1]\n", | |
" y = series[1:]\n", | |
" return X, y\n", | |
" \n", | |
"X, y = supervise(diffed)\n", | |
"\n", | |
"plt.figure(figsize=(15,10))\n", | |
"plt.grid()\n", | |
"\n", | |
"plt.plot(X, label='input')\n", | |
"plt.plot(y, label='label')\n", | |
"plt.legend()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"<img src='supervised.png' width=1000 align='left'>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## train, validation data に分ける\n", | |
"\n", | |
"今回は、train:val=7:3に分けます。\n", | |
"\n", | |
"- train の長さ 99 \n", | |
" - $X: D(0), ..., D(98)$\n", | |
" - $y: D(1), ..., D(99)$\n", | |
"- val の長さ 43\n", | |
" - $X: D(99), ..., D(141)$ \n", | |
" - $y: D(100), ..., D(142)$ \n", | |
"\n", | |
"時系列データは、順序に意味があるので\n", | |
"**シャッフルしない(shuffle=False)**ように設定する必要があります。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.model_selection import train_test_split\n", | |
"\n", | |
"X_train, X_val, y_train, y_val = train_test_split(X, y,\n", | |
" test_size=0.3,\n", | |
" shuffle=False)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 正規化\n", | |
"\n", | |
"LSTM は内部で $tanh$ を使っているため、正規化する必要があります。\n", | |
"実際に、正規化しないとうまく予測できませんでした。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from sklearn.preprocessing import MinMaxScaler\n", | |
"\n", | |
"def scale(X_train, X_val, y_train, y_val):\n", | |
" # change type\n", | |
" X_train = X_train.astype(np.float32)\n", | |
" X_val = X_val.astype(np.float32)\n", | |
" y_train = y_train.astype(np.float32)\n", | |
" y_val = y_val.astype(np.float32)\n", | |
"\n", | |
" # scale inputs\n", | |
" sclr = MinMaxScaler()\n", | |
" X_train = sclr.fit_transform(X_train)\n", | |
" X_val = sclr.transform(X_val)\n", | |
" \n", | |
" # scale labels\n", | |
" ysclr = MinMaxScaler()\n", | |
" y_train = ysclr.fit_transform(y_train)\n", | |
" y_val = ysclr.transform(y_val)\n", | |
"\n", | |
" return X_train, X_val, y_train, y_val, sclr, ysclr" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"注意として、スケーリングに必要なパラメタは\n", | |
"**train data のみ** から計算します。\n", | |
"`sclr.fit_transform(X_train)` のところです。\n", | |
"\n", | |
"本来、 validation data は学習時に利用できないものと想定します。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# RNNの定義\n", | |
"\n", | |
"今回は、LSTMを用います。\n", | |
"RNN(系列データを扱える deep learning のモデル)の一種です。\n", | |
"詳しくは、当記事末尾の参考サイトを参照ください。\n", | |
"\n", | |
"model architecture や データの与え方は以下の図のとおりです。\n", | |
"\n", | |
"<img src='LSTM.png' width=500>\n", | |
"\n", | |
"LSTM の loop 部分を展開した図が以下になります。\n", | |
"\n", | |
"<img src='LSTM_deployed.png' width=500>\n", | |
"\n", | |
"各時刻で、次の時刻の値を予測し、\n", | |
"各時刻ごとに MSE(mean squred error) をとり、\n", | |
"時系列の長さ(100)回、これを繰り返し、\n", | |
"全時刻のMSEをSUMでまとめて、時系列長( $\\# D_{train} = 99$ ) で割った値\n", | |
"を loss function としました。\n", | |
"\n", | |
"FC は fully connected = 全結合層です。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from chainer import Chain\n", | |
"import chainer.links as L\n", | |
"\n", | |
"class RNN(Chain):\n", | |
" def __init__(self, units):\n", | |
" \"\"\"\n", | |
" units (tuple): e.g. (4, 5, 3)\n", | |
" - 1層目のLSTM: 4つのneuron\n", | |
" - 2層目のLSTM: 5つのneuron\n", | |
" - 3層目のLSTM: 3つのneuron\n", | |
" \"\"\"\n", | |
" super(RNN, self).__init__()\n", | |
" \n", | |
" n_in = 1 # features\n", | |
" n_out= 1\n", | |
" \n", | |
" lstms = [('lstm{}'.format(l), L.LSTM(None, n_unit))\n", | |
" for l, n_unit in enumerate(units)]\n", | |
" self.lstms = lstms\n", | |
" for name, lstm in lstms:\n", | |
" self.add_link(name, lstm)\n", | |
" \n", | |
" self.add_link('fc', L.Linear(units[-1], n_out))\n", | |
" \n", | |
" \n", | |
" def __call__(self, x):\n", | |
" \"\"\"\n", | |
" # Param\n", | |
" - x (Variable: (S, F))\n", | |
" S: samples\n", | |
" F: features\n", | |
" \n", | |
" # Return\n", | |
" - (Variable: (S, 1))\n", | |
" \"\"\"\n", | |
" h = x\n", | |
" for name, lstm in self.lstms:\n", | |
" h = lstm(h)\n", | |
" return self.fc(h)\n", | |
" \n", | |
" def reset_state(self):\n", | |
" for name, lstm in self.lstms:\n", | |
" lstm.reset_state()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"コンストラクタ( `__init__` )の外でlayerを追加する場合は、`add_link` 関数で追加する必要があります。\n", | |
"なお、 chainer v3 だと、 `with self.init_scope():` 内でいけるそうです([Chainerにおけるグラフ構造をループで書いてみる。 - のんびりしているエンジニアの日記](http://nonbiri-tereka.hatenablog.com/entry/2016/02/26/001608) 参照)。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"**reset_state** について\n", | |
"\n", | |
"1つの時系列を読み込み、ネットワークの重みを1回 update したら、\n", | |
"次のepochに移り、もう1度、その時系列を読み直します。\n", | |
"\n", | |
"再び **時系列を始めから読み込むときに、LSTMの前の層から受け取る情報を初期状態に戻す** 必要があります。\n", | |
"それをおこなうのが **reset_state** です。\n", | |
"\n", | |
"このあたりは、\n", | |
"[stateful と stateless な LSTM](https://stackoverflow.com/questions/39681046/keras-stateful-vs-stateless-lstms)\n", | |
"で挙動が違うので気をつけてください。\n", | |
"今回は stateful で、任意のタイミングで reset_state しています。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# loss function の定義\n", | |
"\n", | |
"先程、説明したlossを実装します。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import chainer.links as L\n", | |
"import chainer.functions as F\n", | |
"\n", | |
"class LossSumMSEOverTime(L.Classifier):\n", | |
" def __init__(self, predictor):\n", | |
" super(LossSumMSEOverTime, self).__init__(predictor, lossfun=F.mean_squared_error)\n", | |
" \n", | |
" def __call__(self, X_STF, y_STF):\n", | |
" \"\"\"\n", | |
" # Param\n", | |
" - X_STF (Variable: (S, T, F))\n", | |
" - y_STF (Variable: (S, T, F))\n", | |
" S: samples\n", | |
" T: time_steps\n", | |
" F: features\n", | |
" \n", | |
" # Return\n", | |
" - loss (Variable: (1, ))\n", | |
" \"\"\"\n", | |
" # 時間 T で loop させるため、Tを先頭の軸にする\n", | |
" X_TSF = X_STF.transpose(1,0,2)\n", | |
" y_TSF = y_STF.transpose(1,0,2)\n", | |
" seq_len = X_TSF.shape[0]\n", | |
" \n", | |
" # 各時刻についてlossをとり、最終的なlossに足していく\n", | |
" loss = 0\n", | |
" for t in range(seq_len):\n", | |
" pred = self.predictor(X_TSF[t])\n", | |
" obs = y_TSF[t]\n", | |
" loss += self.lossfun(pred, obs)\n", | |
" # loss の大きさが時系列長に依存してしまうので、時系列長で割る\n", | |
" loss /= seq_len\n", | |
" \n", | |
" # reporter に loss の値を渡す\n", | |
" reporter.report({'loss': loss}, self)\n", | |
" \n", | |
" return loss" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"`L.Classifier` は loss の report など、\n", | |
"便利な機能が備わってる loss function です。\n", | |
"これを override します。\n", | |
"\n", | |
"Classifier となってはいるものの、\n", | |
"**引数で任意の loss function に変えられる**ので、\n", | |
"MSEを渡してやれば、今回のような **回帰にも使えます** 。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Updater の定義\n", | |
"\n", | |
"updater もオリジナルのを用意します。\n", | |
"標準的な `StanadardUpdater` を override します。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from chainer import training\n", | |
"from chainer import Variable, reporter\n", | |
"\n", | |
"class UpdaterRNN(training.StandardUpdater):\n", | |
" def __init__(self, itr_train, optimizer, device=-1):\n", | |
" super(UpdaterRNN, self).__init__(itr_train, optimizer, device=device)\n", | |
" \n", | |
" # overrided\n", | |
" def update_core(self):\n", | |
" itr_train = self.get_iterator('main')\n", | |
" optimizer = self.get_optimizer('main')\n", | |
" \n", | |
" batch = itr_train.__next__()\n", | |
" X_STF, y_STF = chainer.dataset.concat_examples(batch, self.device)\n", | |
" \n", | |
" optimizer.target.zerograds()\n", | |
" optimizer.target.predictor.reset_state()\n", | |
" loss = optimizer.target(Variable(X_STF), Variable(y_STF))\n", | |
" \n", | |
" loss.backward()\n", | |
" optimizer.update()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"`update_core`が学習の1stepにあたり、\n", | |
"この関数が1回呼び出されると、パラメータが1回更新されます。\n", | |
"\n", | |
"`itr_train` は train 用の Iterator で、\n", | |
"各 iteration でデータセットから1つの batch をモデルに渡してくれます。\n", | |
"あとで、updater インスタンス化するときに iterator を渡してあげます。\n", | |
"\n", | |
"ちなみに、 **updater の中で入力X, ラベルyを変形(transpose)するのはオススメしません。**\n", | |
"理由として、 train 中の **evaluation 時は updater を介さず**、\n", | |
"データがモデル(with loss)に渡されるからです。\n", | |
"つまり、 train と evaluation 時で、 \n", | |
"model に渡すデータの形が変わってしまいエラーが起きてしまいます。\n", | |
"\n", | |
"なので、今回は、loss function 内で変形するようにしました。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 学習\n", | |
"それでは、各オブジェクトを生成して、学習させていきます。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import chainer\n", | |
"from chainer.optimizers import RMSprop\n", | |
"from chainer.iterators import SerialIterator\n", | |
"from chainer.training import extensions\n", | |
"\n", | |
"# model\n", | |
"units = (5, 4, 3)\n", | |
"model = LossSumMSEOverTime(RNN(units))\n", | |
"\n", | |
"# optimizer\n", | |
"optimizer = RMSprop()\n", | |
"optimizer.setup(model)\n", | |
"\n", | |
"# dataset (Datasetオブジェクトじゃなくて、list(zip())でも可)\n", | |
"df = pd.read_csv('international-airline-passengers.csv')\n", | |
"# 1ではなく1:とするのは、shapeを(144,)ではなく(144,1)とするため\n", | |
"series = df.iloc[:,1:].values \n", | |
"diffed = difference(series)\n", | |
"X, y = supervise(diffed)\n", | |
"X_train, X_val, y_train, y_val = train_test_split(X, y,\n", | |
" test_size=0.3,\n", | |
" shuffle=False)\n", | |
"X_train, X_val, y_train, y_val, sclr, ysclr = scale(X_train, X_val, y_train, y_val)\n", | |
"# change type\n", | |
"X_train = X_train.astype(np.float32)\n", | |
"X_val = X_val.astype(np.float32)\n", | |
"y_train = y_train.astype(np.float32)\n", | |
"y_val = y_val.astype(np.float32)\n", | |
"# change shape\n", | |
"X_train = X_train[np.newaxis, :, :]\n", | |
"X_val = X_val[np.newaxis, :, :]\n", | |
"y_train = y_train[np.newaxis, :, :]\n", | |
"y_val = y_val[np.newaxis, :, :]\n", | |
"ds_train = list(zip(X_train, y_train))\n", | |
"ds_val = list(zip(X_val , y_val ))\n", | |
"\n", | |
"# iterator\n", | |
"itr_train = SerialIterator(ds_train, batch_size=1, shuffle=False)\n", | |
"itr_val = SerialIterator(ds_val , batch_size=1, shuffle=False, repeat=False)\n", | |
"\n", | |
"# updater\n", | |
"updater = UpdaterRNN(itr_train, optimizer)\n", | |
"\n", | |
"# trainer\n", | |
"trainer = training.Trainer(updater, (1000, 'epoch'), out='results')\n", | |
"# evaluation\n", | |
"eval_model = model.copy()\n", | |
"eval_rnn = eval_model.predictor\n", | |
"trainer.extend(extensions.Evaluator(\n", | |
" itr_val, eval_model, device=-1,\n", | |
" eval_hook=lambda _: eval_rnn.reset_state()))\n", | |
"# other extensions\n", | |
"trainer.extend(extensions.LogReport())\n", | |
"trainer.extend(extensions.snapshot_object(model.predictor, \n", | |
" filename='model_epoch-{.updater.epoch}'))\n", | |
"trainer.extend(extensions.PrintReport(\n", | |
" ['epoch','main/loss','validation/main/loss']\n", | |
" ))\n", | |
"\n", | |
"trainer.run()" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"extension は以下のサイトがよくまとまっています。\n", | |
"\n", | |
"- [勤労感謝の日なのでChainerの勤労(Training)に感謝してextensionsを全部試した話 - EnsekiTT Blog](http://ensekitt.hatenablog.com/entry/2016/11/24/012539)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 学習曲線のplot\n", | |
"\n", | |
"学習を実行すると、LogReport extension で、 json 形式の学習 log ファイルが`./results` に保存されます。これを読み込んで可視化します。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"log = pd.read_json('results/log')\n", | |
"log.plot(y=['main/loss', 'validation/main/loss'],\n", | |
" figsize=(15,10),\n", | |
" grid=True)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"<img src='loss.png' width=1000 align='left'>\n", | |
"\n", | |
"train と validation の loss に大きな違いがあると思いますが、\n", | |
"これは、扱ってるデータが時系列で、データセットを**シャッフルしていないため**、\n", | |
"トレンドや変動のスケールが変わるような時系列データだと、\n", | |
"**train, validation によってスケールに偏りが生じる** からです。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 予測\n", | |
"\n", | |
"予測の方法には2種類あります。\n", | |
"\n", | |
"- ①観測値 $D(t)$ を用いる方法(e.g. $\\hat{D}(t+1) = RNN(D(t);h_{t-1})$)\n", | |
"- ②予測値 $\\hat{D}(t)$ を用いる方法(e.g. $\\hat{D}(t+1) = RNN(\\hat{D}(t);h_{t-1})$)\n", | |
"\n", | |
"※ $h_{t-1}$ : 前の隠れ層の状態\n", | |
"\n", | |
"それぞれの方法で予測してみます。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 学習パラメタの読み込み\n", | |
"\n", | |
"validation loss が最も良かった epoch の重みを採用します。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"from chainer import serializers\n", | |
"\n", | |
"best_idx = log['validation/main/loss'].argmin()\n", | |
"best_epoch = int(log['epoch'].ix[best_idx])\n", | |
"\n", | |
"units = (5, 4, 3)\n", | |
"model = RNN(units)\n", | |
"weight_file = os.path.join('results', 'model_epoch-{}'.format(best_epoch))\n", | |
"serializers.load_npz(weight_file, model)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"先ほどまでの `model` は **RNN + loss** でしたが、\n", | |
"上のコードでは **RNNだけ** なので注意です。\n", | |
"\n", | |
"なお、重みの読み込みは以下のページを参考にしました。\n", | |
"\n", | |
"- [Chainerのモデルのセーブとロード - 無限グミ](http://toua20001.hatenablog.com/entry/2016/11/15/203332)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## ①観測値を使って予測" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model.reset_state()\n", | |
"\n", | |
"n_train = X_train.shape[1]\n", | |
"n_val = X_val.shape[1]\n", | |
"\n", | |
"X = np.concatenate((X_train, X_val), axis=1)[0]\n", | |
"obs = np.concatenate((y_train, y_val), axis=1)[0]\n", | |
"\n", | |
"# prediction\n", | |
"pred = []\n", | |
"for X_t in X:\n", | |
" p_t = model(X_t.reshape(-1,1)).data[0]\n", | |
" pred.append(p_t)\n", | |
"\n", | |
"plt.figure(figsize=(15,10))\n", | |
"\n", | |
"plt.plot(obs, label='obs')\n", | |
"plt.plot(pred, label='pred')\n", | |
"\n", | |
"plt.grid()\n", | |
"plt.legend()\n", | |
"plt.axvline(n_train, color='r')\n", | |
"\n", | |
"plt.savefig('pred1.png')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"<img src='pred1.png' width=1000 align='left'>\n", | |
"\n", | |
"赤い線より左側が train, 右側が validation に対する予測です。\n", | |
"\n", | |
"$$\\hat{D}(1),...,\\hat{D}(99),\\hat{D}(100),...,\\hat{D}(142)$$\n", | |
"\n", | |
"を予測しています。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## ②予測値を使って予測" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"model.reset_state()\n", | |
"\n", | |
"# train data に関しては先ほどと同じく、観測値を使って予測し、\n", | |
"# 隠れ層の状態を作る。\n", | |
"pred = []\n", | |
"for X_t in X_train[0]:\n", | |
" p_t = model(X_t.reshape(-1,1)).data[0]\n", | |
" pred.append(p_t)\n", | |
" \n", | |
"# valdiation data に対する予測\n", | |
"p_t = X_val[0,0]\n", | |
"n_pred = n_val\n", | |
"for t in range(n_pred):\n", | |
" p_t = model(p_t.reshape(-1,1)).data[0]\n", | |
" pred.append(p_t)\n", | |
"\n", | |
"plt.figure(figsize=(15,10))\n", | |
"\n", | |
"plt.plot(obs, label='obs')\n", | |
"plt.plot(pred, label='pred')\n", | |
"\n", | |
"plt.grid()\n", | |
"plt.legend()\n", | |
"plt.axvline(n_train, color='r')\n", | |
"\n", | |
"plt.savefig('pred2.png')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"<img src='pred2.png' width=1000 align='left'>\n", | |
"\n", | |
"train に関しては、先ほどの①と同じ。\n", | |
"validation に関しては、先程より誤差が若干、大きくなっています。\n", | |
"\n", | |
"①では、validation data の観測値を使って予測していたので、\n", | |
"validation data の個数と同じ時刻分だけしか予測できませんでしたが、\n", | |
"②では 任意個、 `n_pred` 個だけ、未来の時刻を予測できます。\n", | |
"今回は、①と同じく validation data の個数と同じにしました。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 後処理\n", | |
"このままだと階差・正規化したままの時系列なので、\n", | |
"これを、もとの時系列と比較できるように逆変換します。\n", | |
"なお、予測値は②を使います。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 正規化を戻す" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"obs_unscale = ysclr.inverse_transform(obs)\n", | |
"pred_unscale = ysclr.inverse_transform(pred)\n", | |
"\n", | |
"plt.figure(figsize=(15,10))\n", | |
"\n", | |
"plt.plot(obs_unscale, label='obs')\n", | |
"plt.plot(pred_unscale, label='pred')\n", | |
"\n", | |
"plt.grid()\n", | |
"plt.legend()\n", | |
"plt.axvline(n_train, color='r')\n", | |
"\n", | |
"plt.savefig('unscale.png')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"<img src='unscale.png' width=1000 align='left'>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## 階差を戻す\n", | |
"階差時系列の定義は、\n", | |
"\n", | |
"$$D(t) = X(t+1) - X(t)$$\n", | |
"\n", | |
"ですから、\n", | |
"\n", | |
"$$X(t+1) = D(t) + X(t)$$\n", | |
"\n", | |
"です。\n", | |
"今、\n", | |
"\n", | |
"$$\\hat{D}(1),...,\\hat{D}(99),\\hat{D}(100),...,\\hat{D}(142)$$\n", | |
"\n", | |
"の142個を予測したので、これに、もとの時系列\n", | |
"\n", | |
"$$X(1),...,X(142)$$\n", | |
"\n", | |
"を加算して、\n", | |
"\n", | |
"$$\\hat{X}(2), ..., \\hat{X}(143)$$\n", | |
"\n", | |
"にします。\n", | |
"\n", | |
"ただし、ここで注意があります。\n", | |
"train data に関する予測、\n", | |
"\n", | |
"$$\\hat{X}(2), ..., \\hat{X}(100)$$\n", | |
"\n", | |
"までは、手元にある、\n", | |
"\n", | |
"$$X(1),...,X(99)$$\n", | |
"\n", | |
"を使って出せますが、\n", | |
"validation data は学習時には手に入っていないと想定するので、\n", | |
"validation の予測、\n", | |
"\n", | |
"$$\\hat{X}(101), ..., \\hat{X}(143)$$\n", | |
"\n", | |
"については、\n", | |
"\n", | |
"$$\n", | |
"\\hat{X}(101) = \\hat{D}(100) + \\hat{X}(100) \\\\\n", | |
"\\hat{X}(102) = \\hat{D}(101) + \\hat{X}(101) \\\\\n", | |
"\\vdots \\\\\n", | |
"\\hat{X}(143) = \\hat{D}(142) + \\hat{X}(142) \\\\\n", | |
"$$\n", | |
"\n", | |
"というように、予測値を足し合わせていきます。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"obs_undiff = series[2:]\n", | |
"\n", | |
"pred_train = pred_unscale[:n_train] + series[1:1+n_train]\n", | |
"\n", | |
"pred_val = []\n", | |
"X_t = series[n_train+1]\n", | |
"for D_t in pred_unscale[n_train:]:\n", | |
" X_t = D_t + X_t\n", | |
" pred_val.append(X_t)\n", | |
"pred_undiff = np.concatenate((pred_train,\n", | |
" pred_val), axis=0)\n", | |
"\n", | |
"plt.figure(figsize=(15,10))\n", | |
"\n", | |
"plt.plot(obs_undiff, label='obs')\n", | |
"plt.plot(pred_undiff, label='pred')\n", | |
"\n", | |
"plt.grid()\n", | |
"plt.legend()\n", | |
"plt.axvline(n_train, color='r')\n", | |
"\n", | |
"plt.savefig('undiff.png')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"<img src='undiff.png' width=1000 align='left'>\n", | |
"\n", | |
"時刻が進むについて大きくなる変動については、うまく学習できていないようです。更なる工夫が必要です。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# まとめ\n", | |
"\n", | |
"- LSTM 用いた時系列予測を chainer で実装しました。\n", | |
"- 前処理として、階差、正規化を施すと、予測精度が高くなりました。\n", | |
"- 観測値と予測値の2種類の予測方法を試した結果、観測値を使ったほうが予測精度が高くなります。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 参考文献・サイト\n", | |
"\n", | |
"## LSTM\n", | |
"\n", | |
"### 理論\n", | |
"- [LSTMネットワークの概要 - Qiita](https://qiita.com/KojiOhki/items/89cd7b69a8a6239d67ca)\n", | |
"- [Fei-Fei Li & Justin Johnson & Serena Yeung Lecture 10: Recurrent Neural Networks](http://cs231n.stanford.edu/slides/2017/cs231n_2017_lecture10.pdf)\n", | |
"- [ニューラルネット勉強会(LSTM編)](http://isw3.naist.jp/~neubig/student/2015/seitaro-s/161025neuralnet_study_LSTM.pdf)\n", | |
"- [わかるLSTM ~ 最近の動向と共に - Qiita](https://qiita.com/t_Signull/items/21b82be280b46f467d1b)\n", | |
"\n", | |
"### 実践(kerasのコードつき)\n", | |
"- [Time Series Forecasting with the Long Short-Term Memory Network in Python - Machine Learning Mastery](https://machinelearningmastery.com/time-series-forecasting-long-short-term-memory-network-python/) \n", | |
"- [Mini-Course on Long Short-Term Memory Recurrent Neural Networks with Keras - Machine Learning Mastery](https://machinelearningmastery.com/long-short-term-memory-recurrent-neural-networks-mini-course/)\n", | |
"\n", | |
"\n", | |
"## chainer\n", | |
"\n", | |
"- [Chainer: ビギナー向けチュートリアル Vol.1 - Qiita](https://qiita.com/mitmul/items/eccf4e0a84cb784ba84a) chainerは**学習を抽象化するクラス**間の関係が初心者にはとっつきにくいです。それらの**関係図**がわかりやすくまとまっています。\n", | |
" - [Chainer v3 ビギナー向けチュートリアル - Qiita](https://qiita.com/mitmul/items/1e35fba085eb07a92560)\n", | |
"- [LSTMにsin波を覚えてもらう(chainer trainerの速習) - Qiita](https://qiita.com/chachay/items/052406176c55dd5b9a6a) 実装する上で最も参考にさせていただきました。\n", | |
"- [Chainerにおけるグラフ構造をループで書いてみる。 - のんびりしているエンジニアの日記](http://nonbiri-tereka.hatenablog.com/entry/2016/02/26/001608) layer追加方法が、参考になりました。\n", | |
"- [Chainerのモデルのセーブとロード - 無限グミ](http://toua20001.hatenablog.com/entry/2016/11/15/203332)\n", | |
"- [勤労感謝の日なのでChainerの勤労(Training)に感謝してextensionsを全部試した話 - EnsekiTT Blog](http://ensekitt.hatenablog.com/entry/2016/11/24/012539)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment