Created April 21, 2015 01:35
Gist: tnarihi/d866adc3446147c8acad
Saving the weights of VGG-16 in MATLAB format
{"nbformat_minor": 0, "cells": [{"execution_count": 1, "cell_type": "code", "source": "import caffe", "outputs": [], "metadata": {"collapsed": true, "trusted": true}}, {"execution_count": 2, "cell_type": "code", "source": "path_prototxt = '/home/narihira/caffe/models/211839e770f7b538e2d8/VGG_ILSVRC_16_layers_deploy.prototxt'\npath_caffemodel = '/home/narihira/caffe/models/211839e770f7b538e2d8/VGG_ILSVRC_16_layers.caffemodel'\nnet = caffe.Net(path_prototxt, path_caffemodel, caffe.TEST)", "outputs": [], "metadata": {"collapsed": true, "trusted": true}}, {"source": "Parameters are stored in `net.params`, a Python dictionary keyed by layer name.", "cell_type": "markdown", "metadata": {}}, {"execution_count": 4, "cell_type": "code", "source": "print(list(net.params.keys()))", "outputs": [{"output_type": "stream", "name": "stdout", "text": "['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv4_1', 'conv4_2', 'conv4_3', 'conv5_1', 'conv5_2', 'conv5_3', 'fc6', 'fc7', 'fc8']\n"}], "metadata": {"scrolled": true, "collapsed": false, "trusted": true}}, {"source": "Each value in `net.params` is a list of blobs. Usually the first blob (index 0) holds the connection weights and the second (index 1) holds the biases. Convolution weights are stored with shape `(num_out, num_in, h, w)`.", "cell_type": "markdown", "metadata": {}}, {"execution_count": 8, "cell_type": "code", "source": "print(net.params['conv1_1'][0].shape)  # 0: weights\nprint(net.params['conv1_1'][1].shape)  # 1: biases", "outputs": [{"output_type": "stream", "name": "stdout", "text": "(64, 3, 3, 3)\n(64,)\n"}], "metadata": {"collapsed": false, "trusted": true}}, {"source": "Next, store the weight and bias arrays in a dictionary `dict_weights`, using the keys `<layer>_w` and `<layer>_b`.", "cell_type": "markdown", "metadata": {}}, {"execution_count": 16, "cell_type": "code", "source": "dict_weights = {}\nfor key in net.params.keys():\n    weight = net.params[key][0]  # a caffe.Blob object\n    bias = net.params[key][1]\n    dict_weights[key + '_w'] = weight.data  # the ndarray is accessible via `Blob.data`\n    dict_weights[key + '_b'] = bias.data", "outputs": [], "metadata": {"collapsed": false, "trusted": true}}, {"source": "`scipy.io.savemat` writes the arrays to a MATLAB `.mat` file; each dictionary key becomes a variable name.", "cell_type": "markdown", "metadata": {}}, {"execution_count": 17, "cell_type": "code", "source": "from scipy.io import savemat\nsavemat('weights.mat', dict_weights)", "outputs": [], "metadata": {"collapsed": false, "trusted": true}}, {"source": "The weights can then be used from MATLAB/Octave as follows.\n\n```Matlab\nload weights;\nsize(conv1_1_w)\n```\n\nThis shows `[64 3 3 3]`, the same shape as in Python.", "cell_type": "markdown", "metadata": {}}, {"execution_count": null, "cell_type": "code", "source": "", "outputs": [], "metadata": {"collapsed": true, "trusted": true}}], "nbformat": 4, "metadata": {"kernelspec": {"display_name": "Python 2", "name": "python2", "language": "python"}, "language_info": {"mimetype": "text/x-python", "nbconvert_exporter": "python", "version": "2.7.9", "name": "python", "file_extension": ".py", "pygments_lexer": "ipython2", "codemirror_mode": {"version": 2, "name": "ipython"}}}}
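The notebook's loop assumes every layer has exactly two blobs, which holds for VGG-16. For other networks, the following is a hedged Python 3 sketch. `collect_params` is a hypothetical helper, not part of pycaffe. It skips the bias when a layer has only one blob (as with a convolution defined with `bias_term: false`) and copies each array so the dictionary does not share memory with the network. It runs without Caffe: `SimpleNamespace` objects stand in for `caffe.Blob`, and the layer name `conv_nobias` is made up for the demo.

```python
from collections import OrderedDict
from types import SimpleNamespace

import numpy as np


def collect_params(params):
    """Flatten a pycaffe-style ``net.params`` mapping into ``{name: ndarray}``.

    Hypothetical helper. Assumes each value is a sequence of blob-like objects
    exposing ``.data``, with weights at index 0 and (optionally) biases at index 1.
    """
    arrays = OrderedDict()
    for layer, blobs in params.items():
        # Copy so the saved arrays do not alias memory owned by the network.
        arrays[layer + '_w'] = np.array(blobs[0].data, copy=True)
        if len(blobs) > 1:  # bias-free layers have only a weight blob
            arrays[layer + '_b'] = np.array(blobs[1].data, copy=True)
    return arrays


# Stand-in for net.params: SimpleNamespace mimics caffe.Blob's .data attribute.
# 'conv_nobias' is a made-up layer without a bias blob.
fake_params = OrderedDict([
    ('conv1_1', [SimpleNamespace(data=np.zeros((64, 3, 3, 3), np.float32)),
                 SimpleNamespace(data=np.zeros(64, np.float32))]),
    ('conv_nobias', [SimpleNamespace(data=np.ones((8, 64, 1, 1), np.float32))]),
])

dict_weights = collect_params(fake_params)
for name, arr in dict_weights.items():
    print(name, arr.shape, arr.dtype)
```

With a real network, `savemat('weights.mat', collect_params(net.params))` should produce the same variables as the notebook's loop, minus any `_b` entries for bias-free layers.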