Created
July 4, 2016 10:01
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Implementing KL annealing in Keras\n",
"- [Generating Sentences from a Continuous Space](https://arxiv.org/pdf/1511.06349.pdf)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false,
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using Theano backend.\n",
"Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN 5005)\n"
]
}
],
"source": [
"from keras import backend as K"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Weight on the KL term; starts at 0 and is updated during training.\n",
"hp_lambda = K.variable(0)\n",
"\n",
"def original_loss(y_true, y_pred):\n",
"    loss = K.categorical_crossentropy(y_true, y_pred)\n",
"    # Dummy constant standing in for the real KL term in this demo.\n",
"    KL = K.variable([100])\n",
"    return loss + hp_lambda * KL"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"y_pred = K.placeholder(ndim=1)\n",
"y_true = K.placeholder(ndim=1)\n",
"f = K.function([y_true, y_pred], original_loss(y_true, y_pred))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"When lambda=0.0: [ 1.19209304e-07]\n"
]
}
],
"source": [
"K.set_value(hp_lambda, 0)\n",
"print(\"When lambda={}: {}\".format(K.get_value(hp_lambda), f([[0, 1, 0], [0, 1, 0]])))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"When lambda=1.0: [ 100.]\n"
]
}
],
"source": [
"K.set_value(hp_lambda, 1)\n",
"print(\"When lambda={}: {}\".format(K.get_value(hp_lambda), f([[0, 1, 0], [0, 1, 0]])))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Calling K.set_value once per epoch is enough to anneal the weight.\n",
"- There seem to be two ways to do this:\n",
"  1. Call set_value manually at every epoch.\n",
"  2. Write a callback that updates hp_lambda.\n",
"- For the second approach, LearningRateScheduler in [keras/callbacks.py](https://github.com/fchollet/keras/blob/master/keras/callbacks.py) looks like a useful reference.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Callbacks"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from keras import callbacks\n",
"\n",
"\n",
"class AnnealingCallback(callbacks.Callback):\n",
"    '''Anneal a backend variable (e.g. a Theano shared variable) once per epoch.\n",
"    # Arguments\n",
"        schedule: a function that takes an epoch index as input\n",
"            (integer, indexed from 0) and returns the new value\n",
"            for the variable as output (float).\n",
"        variable: the backend variable to update, e.g. hp_lambda.\n",
"    '''\n",
"    def __init__(self, schedule, variable):\n",
"        super(AnnealingCallback, self).__init__()\n",
"        self.schedule = schedule\n",
"        self.variable = variable\n",
"\n",
"    def on_epoch_begin(self, epoch, logs=None):\n",
"        value = self.schedule(epoch)\n",
"        assert isinstance(value, float), 'The output of the \"schedule\" function should be float.'\n",
"        K.set_value(self.variable, value)\n",
"\n",
"\n",
"def schedule(epoch):\n",
"    # Step schedule: weight 0.0 for epoch indices 0-5, then 10.0.\n",
"    if epoch > 5:\n",
"        return 10.0\n",
"    return 0.0\n",
"\n",
"annealing_callback = AnnealingCallback(schedule, hp_lambda)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- Not tested."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.11"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
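
The `schedule` function in the notebook switches the KL weight from 0.0 to 10.0 in a single step after epoch 5. If a gradual ramp is wanted instead, a schedule along the following lines could be passed to the callback in its place. This is only a sketch: the names `linear_schedule` and `sigmoid_schedule`, and the parameters `warmup_epochs`, `midpoint`, `steepness` and `max_weight`, are hypothetical and do not come from the notebook or from Keras. Both functions return a plain Python `float`, which is what the callback's assertion expects.

```python
import math


def linear_schedule(epoch, warmup_epochs=10, max_weight=1.0):
    """Ramp the KL weight linearly from 0 to max_weight over warmup_epochs.

    Hypothetical helper; `epoch` is the 0-based epoch index.
    """
    if warmup_epochs <= 0:
        return float(max_weight)
    # Clamp the progress fraction to [0, 1] so the weight saturates.
    fraction = min(max(epoch, 0) / float(warmup_epochs), 1.0)
    return float(max_weight * fraction)


def sigmoid_schedule(epoch, midpoint=10, steepness=0.5, max_weight=1.0):
    """Ramp the KL weight along a logistic curve centred on `midpoint`.

    Hypothetical helper; the default values are arbitrary assumptions.
    """
    return float(max_weight / (1.0 + math.exp(-steepness * (epoch - midpoint))))


if __name__ == "__main__":
    # Print the weight each schedule would assign at a few epochs.
    for epoch in (0, 5, 10, 20):
        print(epoch, linear_schedule(epoch), round(sigmoid_schedule(epoch), 4))
```

Either function takes the epoch index as its only required argument, so it can be handed to the callback wherever `schedule` is used. To change the defaults, wrapping it with `functools.partial` is one option.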