Created July 2, 2020 07:05
-
-
Save douglasrizzo/81b9b2a190cf5ad3125e929df919e98d to your computer and use it in GitHub Desktop.
wandb_dict_forked_net.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "wandb_dict_forked_net.ipynb", | |
"provenance": [], | |
"authorship_tag": "ABX9TyNuid6KzwgJ8c3G6u45upsB", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/douglasrizzo/81b9b2a190cf5ad3125e929df919e98d/wandb_dict_forked_net.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "fA_qOdGNTHl1", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 1000 | |
}, | |
"outputId": "f8786ab7-6b8a-4e7d-bc31-f3dff4c564b3" | |
}, | |
"source": [ | |
"# Install the Weights and Biases (wandb) client into this kernel's environment;\n", | |
"# the recorded output below shows wandb 0.9.2 being installed.\n", | |
"%pip install wandb" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting wandb\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/00/8e/d43984196a0fa8ef961ae3dce91ada52ae7747fbf39d41f5743c27152d97/wandb-0.9.2-py2.py3-none-any.whl (1.4MB)\n", | |
"\u001b[K |████████████████████████████████| 1.4MB 2.8MB/s \n", | |
"\u001b[?25hRequirement already satisfied: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (1.12.0)\n", | |
"Collecting watchdog>=0.8.3\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/0e/06/121302598a4fc01aca942d937f4a2c33430b7181137b35758913a8db10ad/watchdog-0.10.3.tar.gz (94kB)\n", | |
"\u001b[K |████████████████████████████████| 102kB 9.2MB/s \n", | |
"\u001b[?25hRequirement already satisfied: Click>=7.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.1.2)\n", | |
"Collecting sentry-sdk>=0.4.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/2f/6b/939519d77c95a9b2c85b771e9dccbf9e69cb90016c7cd63887c26400dd7a/sentry_sdk-0.15.1-py2.py3-none-any.whl (105kB)\n", | |
"\u001b[K |████████████████████████████████| 112kB 15.4MB/s \n", | |
"\u001b[?25hCollecting gql==0.2.0\n", | |
" Downloading https://files.pythonhosted.org/packages/c4/6f/cf9a3056045518f06184e804bae89390eb706168349daa9dff8ac609962a/gql-0.2.0.tar.gz\n", | |
"Requirement already satisfied: python-dateutil>=2.6.1 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.8.1)\n", | |
"Requirement already satisfied: PyYAML>=3.10 in /usr/local/lib/python3.6/dist-packages (from wandb) (3.13)\n", | |
"Requirement already satisfied: nvidia-ml-py3>=7.352.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (7.352.0)\n", | |
"Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (5.4.8)\n", | |
"Collecting subprocess32>=3.5.3\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/32/c8/564be4d12629b912ea431f1a50eb8b3b9d00f1a0b1ceff17f266be190007/subprocess32-3.5.4.tar.gz (97kB)\n", | |
"\u001b[K |████████████████████████████████| 102kB 7.4MB/s \n", | |
"\u001b[?25hCollecting GitPython>=1.0.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/8c/f9/c315aa88e51fabdc08e91b333cfefb255aff04a2ee96d632c32cb19180c9/GitPython-3.1.3-py3-none-any.whl (451kB)\n", | |
"\u001b[K |████████████████████████████████| 460kB 16.4MB/s \n", | |
"\u001b[?25hCollecting docker-pycreds>=0.4.0\n", | |
" Downloading https://files.pythonhosted.org/packages/f5/e8/f6bd1eee09314e7e6dee49cbe2c5e22314ccdb38db16c9fc72d2fa80d054/docker_pycreds-0.4.0-py2.py3-none-any.whl\n", | |
"Collecting configparser>=3.8.1\n", | |
" Downloading https://files.pythonhosted.org/packages/4b/6b/01baa293090240cf0562cc5eccb69c6f5006282127f2b846fad011305c79/configparser-5.0.0-py3-none-any.whl\n", | |
"Collecting shortuuid>=0.5.0\n", | |
" Downloading https://files.pythonhosted.org/packages/25/a6/2ecc1daa6a304e7f1b216f0896b26156b78e7c38e1211e9b798b4716c53d/shortuuid-1.0.1-py3-none-any.whl\n", | |
"Requirement already satisfied: requests>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from wandb) (2.23.0)\n", | |
"Collecting pathtools>=0.1.1\n", | |
" Downloading https://files.pythonhosted.org/packages/e7/7f/470d6fcdf23f9f3518f6b0b76be9df16dcc8630ad409947f8be2eb0ed13a/pathtools-0.1.2.tar.gz\n", | |
"Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from sentry-sdk>=0.4.0->wandb) (2020.6.20)\n", | |
"Requirement already satisfied: urllib3>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from sentry-sdk>=0.4.0->wandb) (1.24.3)\n", | |
"Collecting graphql-core<2,>=0.5.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/b0/89/00ad5e07524d8c523b14d70c685e0299a8b0de6d0727e368c41b89b7ed0b/graphql-core-1.1.tar.gz (70kB)\n", | |
"\u001b[K |████████████████████████████████| 71kB 8.0MB/s \n", | |
"\u001b[?25hRequirement already satisfied: promise<3,>=2.0 in /usr/local/lib/python3.6/dist-packages (from gql==0.2.0->wandb) (2.3)\n", | |
"Collecting gitdb<5,>=4.0.1\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n", | |
"\u001b[K |████████████████████████████████| 71kB 8.2MB/s \n", | |
"\u001b[?25hRequirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (3.0.4)\n", | |
"Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.0.0->wandb) (2.9)\n", | |
"Collecting smmap<4,>=3.0.1\n", | |
" Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n", | |
"Building wheels for collected packages: watchdog, gql, subprocess32, pathtools, graphql-core\n", | |
" Building wheel for watchdog (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for watchdog: filename=watchdog-0.10.3-cp36-none-any.whl size=73870 sha256=68178c701148e2175ac6faf8c5909a881e107661898e6d457de37319b66841ff\n", | |
" Stored in directory: /root/.cache/pip/wheels/a8/1d/38/2c19bb311f67cc7b4d07a2ec5ea36ab1a0a0ea50db994a5bc7\n", | |
" Building wheel for gql (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for gql: filename=gql-0.2.0-cp36-none-any.whl size=7630 sha256=d71a26c76744d5bdffdcdebe8f71449eb39bcb3932a2928fb2a86d915fed0f00\n", | |
" Stored in directory: /root/.cache/pip/wheels/ce/0e/7b/58a8a5268655b3ad74feef5aa97946f0addafb3cbb6bd2da23\n", | |
" Building wheel for subprocess32 (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for subprocess32: filename=subprocess32-3.5.4-cp36-none-any.whl size=6489 sha256=29029465143372e3cb022d6356009abb6a9b65cf4d567475867070fa9f22b931\n", | |
" Stored in directory: /root/.cache/pip/wheels/68/39/1a/5e402bdfdf004af1786c8b853fd92f8c4a04f22aad179654d1\n", | |
" Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for pathtools: filename=pathtools-0.1.2-cp36-none-any.whl size=8784 sha256=55f43fa983289fd5e10c32146ac088a36e2020da8c04609919b8604ef443747b\n", | |
" Stored in directory: /root/.cache/pip/wheels/0b/04/79/c3b0c3a0266a3cb4376da31e5bfe8bba0c489246968a68e843\n", | |
" Building wheel for graphql-core (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for graphql-core: filename=graphql_core-1.1-cp36-none-any.whl size=104650 sha256=b8b4027b6eeabf627fd99fac272ce464f1cf80607423854d5541f9509582a3d8\n", | |
" Stored in directory: /root/.cache/pip/wheels/45/99/d7/c424029bb0fe910c63b68dbf2aa20d3283d023042521bcd7d5\n", | |
"Successfully built watchdog gql subprocess32 pathtools graphql-core\n", | |
"Installing collected packages: pathtools, watchdog, sentry-sdk, graphql-core, gql, subprocess32, smmap, gitdb, GitPython, docker-pycreds, configparser, shortuuid, wandb\n", | |
"Successfully installed GitPython-3.1.3 configparser-5.0.0 docker-pycreds-0.4.0 gitdb-4.0.5 gql-0.2.0 graphql-core-1.1 pathtools-0.1.2 sentry-sdk-0.15.1 shortuuid-1.0.1 smmap-3.0.4 subprocess32-3.5.4 wandb-0.9.2 watchdog-0.10.3\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2tFg7xkyR6CN", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 162 | |
}, | |
"outputId": "281c4e2e-d346-4917-860a-b61ea1aa5a0a" | |
}, | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"from torch import optim\n", | |
"\n", | |
"import wandb\n", | |
"\n", | |
"\n", | |
"def num_flat_features(x):\n", | |
"    \"\"\"Return the number of features per sample of a batched tensor.\n", | |
"\n", | |
"    Multiplies together every dimension of ``x`` except the first (batch)\n", | |
"    dimension, e.g. a tensor of size (N, 16, 6, 6) yields 16 * 6 * 6 = 576.\n", | |
"    \"\"\"\n", | |
"    size = x.size()[1:]  # all dimensions except the batch dimension\n", | |
"    num_features = 1\n", | |
"    for s in size:\n", | |
"        num_features *= s\n", | |
"    return num_features\n", | |
"\n", | |
"\n", | |
"class Net1(nn.Module):\n", | |
"    \"\"\"Small CNN (two conv layers + three linear layers) with one 10-unit output.\n", | |
"\n", | |
"    The 10 outputs are split into a dict of two 5-element tensors. Only the\n", | |
"    first sample of the batch is returned (see ``forward``).\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"        # 1 input image channel, 6 output channels, 3x3 square convolution\n", | |
"        # kernel\n", | |
"        self.conv1 = nn.Conv2d(1, 6, 3)\n", | |
"        self.conv2 = nn.Conv2d(6, 16, 3)\n", | |
"        # an affine operation: y = Wx + b\n", | |
"        # 16 channels x 6 x 6 spatial, which is what a 32x32 input reduces to\n", | |
"        # after two (3x3 conv -> 2x2 max-pool) stages\n", | |
"        self.fc1 = nn.Linear(16 * 6 * 6, 120)\n", | |
"        self.fc2 = nn.Linear(120, 84)\n", | |
"        self.fc3 = nn.Linear(84, 10)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Run the network and return ``{'a': out[0, 5:], 'b': out[0, :5]}``.\n", | |
"\n", | |
"        ``x`` is expected to be a batch such as (N, 1, 32, 32). Only batch\n", | |
"        element 0 is kept, so both dict values have shape (5,).\n", | |
"        \"\"\"\n", | |
"        # Max pooling over a (2, 2) window\n", | |
"        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n", | |
"        # If the size is a square you can only specify a single number\n", | |
"        x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n", | |
"        x = x.view(-1, num_flat_features(x))\n", | |
"        x = F.relu(self.fc1(x))\n", | |
"        x = F.relu(self.fc2(x))\n", | |
"        x = self.fc3(x)\n", | |
"\n", | |
"        x_dict = {'a': x[0, 5:], 'b': x[0, :5]}\n", | |
"        return x_dict\n", | |
"\n", | |
"\n", | |
"class Net2(nn.Module):\n", | |
"    \"\"\"CNN with a shared convolutional/linear trunk forked into two 5-unit heads.\n", | |
"\n", | |
"    ``forward`` returns ``{'a': ..., 'b': ...}``; each value is the output of\n", | |
"    its own linear head and has shape (N, 5) for a batch of N inputs.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self):\n", | |
"        super().__init__()\n", | |
"        # 1 input image channel, 6 output channels, 3x3 square convolution\n", | |
"        # kernel\n", | |
"        self.conv1 = nn.Conv2d(1, 6, 3)\n", | |
"        self.conv2 = nn.Conv2d(6, 16, 3)\n", | |
"        # an affine operation: y = Wx + b\n", | |
"        # 16 channels x 6 x 6 spatial, which is what a 32x32 input reduces to\n", | |
"        # after two (3x3 conv -> 2x2 max-pool) stages\n", | |
"        self.fc1 = nn.Linear(16 * 6 * 6, 120)\n", | |
"        self.fc2 = nn.Linear(120, 84)\n", | |
"        # two independent heads reading the same 84-dim shared features\n", | |
"        self.fc3_1 = nn.Linear(84, 5)\n", | |
"        self.fc3_2 = nn.Linear(84, 5)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Return ``{'a': head_1(features), 'b': head_2(features)}``, each (N, 5).\n", | |
"\n", | |
"        ``x`` is expected to be a batch such as (N, 1, 32, 32).\n", | |
"        \"\"\"\n", | |
"        # Max pooling over a (2, 2) window\n", | |
"        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n", | |
"        # If the size is a square you can only specify a single number\n", | |
"        x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n", | |
"        x = x.view(-1, num_flat_features(x))\n", | |
"        x = F.relu(self.fc1(x))\n", | |
"        x = F.relu(self.fc2(x))\n", | |
"\n", | |
"        return {'a': self.fc3_1(x), 'b': self.fc3_2(x)}\n", | |
"\n", | |
"\n", | |
"if __name__ == \"__main__\":\n", | |
"    wandb.init(project='my_test', group='conv_nets')\n", | |
"\n", | |
"    # Swap in Net1() to compare against the single-output network.\n", | |
"    # net = Net1()\n", | |
"    net = Net2()\n", | |
"    wandb.watch(net)\n", | |
"\n", | |
"    optimizer = optim.SGD(net.parameters(), lr=0.01)\n", | |
"    criterion = nn.MSELoss()\n", | |
"\n", | |
"    for _ in range(100):\n", | |
"        in_feats = torch.randn(1, 1, 32, 32)\n", | |
"        target = {'a': torch.randn(5), 'b': torch.randn(5)}\n", | |
"\n", | |
"        optimizer.zero_grad()  # zero the gradient buffers\n", | |
"        output = net(in_feats)\n", | |
"\n", | |
"        # Sum the per-key losses so that every output head contributes to the\n", | |
"        # gradient. Each target is reshaped to its output's shape ((1, 5) for\n", | |
"        # Net2, (5,) for Net1) so MSELoss does not silently broadcast.\n", | |
"        loss = sum(criterion(output[key], target[key].view_as(output[key]))\n", | |
"                   for key in output)\n", | |
"\n", | |
"        # log a plain Python float rather than a graph-attached tensor\n", | |
"        wandb.log({'loss': loss.item()})\n", | |
"        loss.backward()\n", | |
"        optimizer.step()  # Does the update\n" | |
], | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"application/javascript": [ | |
"\n", | |
" window._wandbApiKey = new Promise((resolve, reject) => {\n", | |
" function loadScript(url) {\n", | |
" return new Promise(function(resolve, reject) {\n", | |
" let newScript = document.createElement(\"script\");\n", | |
" newScript.onerror = reject;\n", | |
" newScript.onload = resolve;\n", | |
" document.body.appendChild(newScript);\n", | |
" newScript.src = url;\n", | |
" });\n", | |
" }\n", | |
" loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n", | |
" const iframe = document.createElement('iframe')\n", | |
" iframe.style.cssText = \"width:0;height:0;border:none\"\n", | |
" document.body.appendChild(iframe)\n", | |
" const handshake = new Postmate({\n", | |
" container: iframe,\n", | |
" url: 'https://app.wandb.ai/authorize'\n", | |
" });\n", | |
" const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n", | |
" handshake.then(function(child) {\n", | |
" child.on('authorize', data => {\n", | |
" clearTimeout(timeout)\n", | |
" resolve(data)\n", | |
" });\n", | |
" });\n", | |
" })\n", | |
" });\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.Javascript object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"\n", | |
" Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n", | |
" Project page: <a href=\"https://app.wandb.ai/tetamusha/my_test\" target=\"_blank\">https://app.wandb.ai/tetamusha/my_test</a><br/>\n", | |
" Run page: <a href=\"https://app.wandb.ai/tetamusha/my_test/runs/336leocg\" target=\"_blank\">https://app.wandb.ai/tetamusha/my_test/runs/336leocg</a><br/>\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py:432: UserWarning: Using a target size (torch.Size([5])) that is different to the input size (torch.Size([1, 5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.\n", | |
" return F.mse_loss(input, target, reduction=self.reduction)\n" | |
], | |
"name": "stderr" | |
} | |
] | |
} | |
] | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.