@tnarihi
Created April 21, 2015 01:35
Saving the weights of VGG-16 in Matlab format
{"nbformat_minor": 0, "cells": [{"execution_count": 1, "cell_type": "code", "source": "import caffe", "outputs": [], "metadata": {"collapsed": true, "trusted": true}}, {"execution_count": 2, "cell_type": "code", "source": "path_prototxt = '/home/narihira/caffe/models/211839e770f7b538e2d8/VGG_ILSVRC_16_layers_deploy.prototxt'\npath_caffemodel = '/home/narihira/caffe/models/211839e770f7b538e2d8/VGG_ILSVRC_16_layers.caffemodel'\nnet = caffe.Net(path_prototxt, path_caffemodel, caffe.TEST)", "outputs": [], "metadata": {"collapsed": true, "trusted": true}}, {"source": "Parameters are stored in `net.params` as a Python dictionary.", "cell_type": "markdown", "metadata": {}}, {"execution_count": 4, "cell_type": "code", "source": "print net.params.keys()", "outputs": [{"output_type": "stream", "name": "stdout", "text": "['conv1_1', 'conv1_2', 'conv2_1', 'conv2_2', 'conv3_1', 'conv3_2', 'conv3_3', 'conv4_1', 'conv4_2', 'conv4_3', 'conv5_1', 'conv5_2', 'conv5_3', 'fc6', 'fc7', 'fc8']\n"}], "metadata": {"scrolled": true, "collapsed": false, "trusted": true}}, {"source": "Usually, the first element of each value in the `params` dictionary holds the connection weights and the second holds the biases.\nThe weights of convolutional kernels are stored with shape `(num_out, num_in, h, w)`.", "cell_type": "markdown", "metadata": {}}, {"execution_count": 8, "cell_type": "code", "source": "print net.params['conv1_1'][0].shape # 0: weights\nprint net.params['conv1_1'][1].shape # 1: biases", "outputs": [{"output_type": "stream", "name": "stdout", "text": "(64, 3, 3, 3)\n(64,)\n"}], "metadata": {"collapsed": false, "trusted": true}}, {"source": "Here, the ndarray objects are collected into a dictionary `dict_weights`.", "cell_type": "markdown", "metadata": {}}, {"execution_count": 16, "cell_type": "code", "source": "dict_weights = {}\nfor key in net.params.keys():\n    weight = net.params[key][0] # This is a caffe.Blob object\n    bias = net.params[key][1]\n    dict_weights[key + '_w'] = weight.data # The ndarray object is accessible via `Blob.data`.\n    dict_weights[key + '_b'] = bias.data", "outputs": [], "metadata": {"collapsed": false, "trusted": true}}, {"source": "You can save the ndarray objects in Matlab format using `scipy.io.savemat`.", "cell_type": "markdown", "metadata": {}}, {"execution_count": 17, "cell_type": "code", "source": "from scipy.io import savemat\nsavemat('weights.mat', dict_weights)", "outputs": [], "metadata": {"collapsed": false, "trusted": true}}, {"source": "You can use the weights from Matlab/Octave as follows.\n\n```Matlab\nload weights;\nsize(conv1_1_w)\n```\n\nThis will show `[64 3 3 3]`.", "cell_type": "markdown", "metadata": {}}, {"execution_count": null, "cell_type": "code", "source": "", "outputs": [], "metadata": {"collapsed": true, "trusted": true}}], "nbformat": 4, "metadata": {"kernelspec": {"display_name": "Python 2", "name": "python2", "language": "python"}, "language_info": {"mimetype": "text/x-python", "nbconvert_exporter": "python", "version": "2.7.9", "name": "python", "file_extension": ".py", "pygments_lexer": "ipython2", "codemirror_mode": {"version": 2, "name": "ipython"}}}}
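The notebook says the first element is "usually" the weights and the second the biases, so a layer without a bias term would make `net.params[key][1]` fail. A minimal sketch of the flattening step that tolerates such layers is given here. The helper name `collect_weights` and the stand-in data are hypothetical and not part of the original notebook, and the sketch assumes each value in `net.params` supports `len()` and indexing.

```python
from types import SimpleNamespace


def collect_weights(params):
    """Flatten {layer: [blob, ...]} into {layer_w: data, layer_b: data}.

    `params` is assumed to behave like `net.params`: each value is a
    sequence of objects exposing `.data`, weights first, then
    (optionally) biases.
    """
    flat = {}
    for name, blobs in params.items():
        flat[name + '_w'] = blobs[0].data
        if len(blobs) > 1:  # a layer without a bias term has one blob
            flat[name + '_b'] = blobs[1].data
    return flat


# Stand-in objects for illustration only; real code would pass `net.params`.
fake_params = {
    'conv1_1': [SimpleNamespace(data='W1'), SimpleNamespace(data='b1')],
    'no_bias': [SimpleNamespace(data='W2')],
}
flat = collect_weights(fake_params)
print(sorted(flat))
```

The resulting dictionary can be passed to `savemat` in the same way as `dict_weights` in the notebook.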