# Created: October 11, 2019 06:24
# Source: GitHub gist Niranjankumar-c/1beb9f8260f1b209719c99d0258a17d7
# (can be saved to your computer and used in GitHub Desktop).
# Description: plotting weights
# Note (from GitHub): this file may contain hidden or bidirectional Unicode text
# that could be interpreted or compiled differently than it appears. To review it,
# open the file in an editor that reveals hidden Unicode characters. See GitHub's
# documentation on bidirectional Unicode characters to learn more.
def plot_weights(model, layer_num, single_channel=True, collated=False):
    """Visualize the filter weights of one layer in ``model.features``.

    Args:
        model: Object exposing an indexable ``features`` container
            (e.g. an ``nn.Sequential``); presumably a torchvision-style CNN
            such as AlexNet -- confirm against callers.
        layer_num: Index into ``model.features`` of the layer to visualize.
        single_channel: If True, delegate to one of the single-channel
            plotting helpers. If False, plot with
            ``plot_filters_multi_channel``, which is only attempted when
            ``weight.shape[1] == 3`` (the input-channel axis of a Conv2d
            weight).
        collated: Only consulted when ``single_channel`` is True; selects
            ``plot_filters_single_channel_big`` instead of
            ``plot_filters_single_channel``.

    Returns nothing. Unsupported layers or channel counts print a message
    instead of raising, so a notebook cell keeps running.
    """
    layer = model.features[layer_num]

    # Guard clause: only nn.Conv2d layers are visualized; anything else is
    # reported with a printed message rather than an exception.
    if not isinstance(layer, nn.Conv2d):
        print("Can only visualize layers which are convolutional")
        return

    # Reuse the layer fetched above instead of indexing model.features again;
    # .data yields the underlying tensor without autograd history.
    weight_tensor = layer.weight.data

    if single_channel:
        if collated:
            plot_filters_single_channel_big(weight_tensor)
        else:
            plot_filters_single_channel(weight_tensor)
    elif weight_tensor.shape[1] == 3:
        # Exactly three input channels can be rendered as one colour image per filter.
        plot_filters_multi_channel(weight_tensor)
    else:
        print("Can only plot weights with three channels with single channel = False")
# Visualize the weights of AlexNet's first convolution layer (model.features[0])
# as multi-channel images. NOTE(review): assumes `alexnet` and the
# plot_filters_* helpers are defined earlier in the script/notebook -- confirm.
plot_weights(alexnet, 0, single_channel=False)
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.