Skip to content

Instantly share code, notes, and snippets.

Revisions

  1. davidefiocco revised this gist Mar 1, 2021. 1 changed file with 1715 additions and 119 deletions.
    1,834 changes: 1,715 additions & 119 deletions text-classification-in-pytorch-to-refactor-with-pytorch-lightning.ipynb
    1,715 additions, 119 deletions not shown because the diff is too large. Please use a local Git client to view these changes.
  2. davidefiocco revised this gist Feb 22, 2021. 1 changed file with 81 additions and 72 deletions.
    Original file line number Diff line number Diff line change
    @@ -22,12 +22,13 @@
    "colab": {
    "name": "Text classification in PyTorch to refactor with PyTorch lightning.ipynb",
    "provenance": [],
    "collapsed_sections": [],
    "include_colab_link": true
    },
    "accelerator": "GPU",
    "widgets": {
    "application/vnd.jupyter.widget-state+json": {
    "b308d20b6ed64d2197482ed1c3b78fb4": {
    "7a970585d344438cb8a7f6ad3549456c": {
    "model_module": "@jupyter-widgets/controls",
    "model_name": "HBoxModel",
    "state": {
    @@ -39,15 +40,15 @@
    "_view_count": null,
    "_view_module_version": "1.5.0",
    "box_style": "",
    "layout": "IPY_MODEL_bc014fea651046dc80a7482d1aeef913",
    "layout": "IPY_MODEL_7fc52ee4ac26451da2de282049d92df6",
    "_model_module": "@jupyter-widgets/controls",
    "children": [
    "IPY_MODEL_65a0c8a8032b4ceb83187d6816af3ccd",
    "IPY_MODEL_53c98ba3347d410c82b0fd71db70167d"
    "IPY_MODEL_13b55fdf91d24395bc33fb2295f27f7e",
    "IPY_MODEL_04dfbd8cde42447b8af2bbd939ed419d"
    ]
    }
    },
    "bc014fea651046dc80a7482d1aeef913": {
    "7fc52ee4ac26451da2de282049d92df6": {
    "model_module": "@jupyter-widgets/base",
    "model_name": "LayoutModel",
    "state": {
    @@ -98,12 +99,12 @@
    "left": null
    }
    },
    "65a0c8a8032b4ceb83187d6816af3ccd": {
    "13b55fdf91d24395bc33fb2295f27f7e": {
    "model_module": "@jupyter-widgets/controls",
    "model_name": "FloatProgressModel",
    "state": {
    "_view_name": "ProgressView",
    "style": "IPY_MODEL_2d3526fe06524cd79cc8859d532a8da1",
    "style": "IPY_MODEL_c3be6c3e533d436f94da6c2b4228d688",
    "_dom_classes": [],
    "description": "Epoch 0: 0%",
    "_model_name": "FloatProgressModel",
    @@ -118,15 +119,15 @@
    "min": 0,
    "description_tooltip": null,
    "_model_module": "@jupyter-widgets/controls",
    "layout": "IPY_MODEL_f6192130f3b04748aa3b8b7c6a14e290"
    "layout": "IPY_MODEL_fb98c0adf7584a27935a9cf19f1e981a"
    }
    },
    "53c98ba3347d410c82b0fd71db70167d": {
    "04dfbd8cde42447b8af2bbd939ed419d": {
    "model_module": "@jupyter-widgets/controls",
    "model_name": "HTMLModel",
    "state": {
    "_view_name": "HTMLView",
    "style": "IPY_MODEL_1a5e1382f32e4928b80118e1e1ffffea",
    "style": "IPY_MODEL_c2b1148532464811b11add3c5f453fc9",
    "_dom_classes": [],
    "description": "",
    "_model_name": "HTMLModel",
    @@ -138,10 +139,10 @@
    "_view_module_version": "1.5.0",
    "description_tooltip": null,
    "_model_module": "@jupyter-widgets/controls",
    "layout": "IPY_MODEL_e42749af77594d3f8f97414d79278c12"
    "layout": "IPY_MODEL_033287576bbd4089bc4c5537619352ba"
    }
    },
    "2d3526fe06524cd79cc8859d532a8da1": {
    "c3be6c3e533d436f94da6c2b4228d688": {
    "model_module": "@jupyter-widgets/controls",
    "model_name": "ProgressStyleModel",
    "state": {
    @@ -156,7 +157,7 @@
    "_model_module": "@jupyter-widgets/controls"
    }
    },
    "f6192130f3b04748aa3b8b7c6a14e290": {
    "fb98c0adf7584a27935a9cf19f1e981a": {
    "model_module": "@jupyter-widgets/base",
    "model_name": "LayoutModel",
    "state": {
    @@ -207,7 +208,7 @@
    "left": null
    }
    },
    "1a5e1382f32e4928b80118e1e1ffffea": {
    "c2b1148532464811b11add3c5f453fc9": {
    "model_module": "@jupyter-widgets/controls",
    "model_name": "DescriptionStyleModel",
    "state": {
    @@ -221,7 +222,7 @@
    "_model_module": "@jupyter-widgets/controls"
    }
    },
    "e42749af77594d3f8f97414d79278c12": {
    "033287576bbd4089bc4c5537619352ba": {
    "model_module": "@jupyter-widgets/base",
    "model_name": "LayoutModel",
    "state": {
    @@ -304,19 +305,43 @@
    "source": [
    "%matplotlib inline"
    ],
    "execution_count": 1,
    "execution_count": 2,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "RWuv_rZYRdXz"
    "id": "RWuv_rZYRdXz",
    "outputId": "131615fd-a1a6-4b76-ebb7-239e973c8250",
    "colab": {
    "base_uri": "https://localhost:8080/"
    }
    },
    "source": [
    "pip install torchtext pytorch-lightning --upgrade --quiet"
    ],
    "execution_count": 2,
    "outputs": []
    "execution_count": 3,
    "outputs": [
    {
    "output_type": "stream",
    "text": [
    "\u001b[K |████████████████████████████████| 7.0MB 3.8MB/s \n",
    "\u001b[K |████████████████████████████████| 819kB 43.8MB/s \n",
    "\u001b[K |████████████████████████████████| 776.8MB 23kB/s \n",
    "\u001b[K |████████████████████████████████| 102kB 10.4MB/s \n",
    "\u001b[K |████████████████████████████████| 276kB 43.5MB/s \n",
    "\u001b[K |████████████████████████████████| 829kB 40.0MB/s \n",
    "\u001b[K |████████████████████████████████| 1.3MB 41.1MB/s \n",
    "\u001b[K |████████████████████████████████| 296kB 37.9MB/s \n",
    "\u001b[K |████████████████████████████████| 143kB 45.4MB/s \n",
    "\u001b[?25h Building wheel for PyYAML (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
    " Building wheel for future (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
    " Building wheel for idna-ssl (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
    "\u001b[31mERROR: torchvision 0.8.1+cu101 has requirement torch==1.7.0, but you'll have torch 1.7.1 which is incompatible.\u001b[0m\n"
    ],
    "name": "stdout"
    }
    ]
    },
    {
    "cell_type": "code",
    @@ -325,11 +350,12 @@
    "colab": {
    "base_uri": "https://localhost:8080/"
    },
    "outputId": "99cb63e9-b430-4096-a1a8-8126e76a2437"
    "outputId": "a3831522-d7c7-45f8-83f3-f3937609944a"
    },
    "source": [
    "import torch\n",
    "import torchtext\n",
    "import torch.nn as nn\n",
    "from torchtext.datasets import text_classification\n",
    "NGRAMS = 2\n",
    "import os\n",
    @@ -339,47 +365,19 @@
    " root='./.data', ngrams=NGRAMS, vocab=None)\n",
    "BATCH_SIZE = 16"
    ],
    "execution_count": 3,
    "execution_count": 12,
    "outputs": [
    {
    "output_type": "stream",
    "text": [
    "120000lines [00:06, 17747.80lines/s]\n",
    "120000lines [00:15, 7709.20lines/s]\n",
    "7600lines [00:00, 7603.03lines/s]\n"
    "120000lines [00:08, 14629.23lines/s]\n",
    "120000lines [00:21, 5473.50lines/s]\n",
    "7600lines [00:01, 5971.63lines/s]\n"
    ],
    "name": "stderr"
    }
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "jEbcAxuuQ2ei"
    },
    "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "class TextSentiment(nn.Module):\n",
    " def __init__(self, vocab_size, embed_dim, num_class):\n",
    " super().__init__()\n",
    " self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)\n",
    " self.fc = nn.Linear(embed_dim, num_class)\n",
    " self.init_weights()\n",
    "\n",
    " def init_weights(self):\n",
    " initrange = 0.5\n",
    " self.embedding.weight.data.uniform_(-initrange, initrange)\n",
    " self.fc.weight.data.uniform_(-initrange, initrange)\n",
    " self.fc.bias.data.zero_()\n",
    "\n",
    " def forward(self, text, offsets):\n",
    " embedded = self.embedding(text, offsets)\n",
    " return self.fc(embedded)"
    ],
    "execution_count": 4,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    @@ -390,7 +388,7 @@
    "EMBED_DIM = 32\n",
    "NUM_CLASS = len(train_dataset.get_labels())"
    ],
    "execution_count": 5,
    "execution_count": 13,
    "outputs": []
    },
    {
    @@ -423,7 +421,7 @@
    " text = torch.cat(text)\n",
    " return text, offsets, label"
    ],
    "execution_count": 6,
    "execution_count": 14,
    "outputs": []
    },
    {
    @@ -451,7 +449,7 @@
    "sub_train_, sub_valid_ = \\\n",
    " random_split(train_dataset, [train_len, len(train_dataset) - train_len])"
    ],
    "execution_count": 7,
    "execution_count": 15,
    "outputs": []
    },
    {
    @@ -467,7 +465,7 @@
    "valid_dl = DataLoader(sub_valid_, batch_size=BATCH_SIZE, shuffle=True,\r\n",
    " collate_fn=generate_batch)"
    ],
    "execution_count": 8,
    "execution_count": 16,
    "outputs": []
    },
    {
    @@ -480,7 +478,7 @@
    "from torch import optim\r\n",
    "from torch.nn import CrossEntropyLoss"
    ],
    "execution_count": 9,
    "execution_count": 17,
    "outputs": []
    },
    {
    @@ -519,7 +517,7 @@
    "\r\n",
    " return loss"
    ],
    "execution_count": 10,
    "execution_count": 18,
    "outputs": []
    },
    {
    @@ -530,7 +528,7 @@
    "source": [
    "classifier = LitTextClassifier(NUM_CLASS)"
    ],
    "execution_count": 11,
    "execution_count": 19,
    "outputs": []
    },
    {
    @@ -541,23 +539,23 @@
    "base_uri": "https://localhost:8080/",
    "height": 826,
    "referenced_widgets": [
    "b308d20b6ed64d2197482ed1c3b78fb4",
    "bc014fea651046dc80a7482d1aeef913",
    "65a0c8a8032b4ceb83187d6816af3ccd",
    "53c98ba3347d410c82b0fd71db70167d",
    "2d3526fe06524cd79cc8859d532a8da1",
    "f6192130f3b04748aa3b8b7c6a14e290",
    "1a5e1382f32e4928b80118e1e1ffffea",
    "e42749af77594d3f8f97414d79278c12"
    "7a970585d344438cb8a7f6ad3549456c",
    "7fc52ee4ac26451da2de282049d92df6",
    "13b55fdf91d24395bc33fb2295f27f7e",
    "04dfbd8cde42447b8af2bbd939ed419d",
    "c3be6c3e533d436f94da6c2b4228d688",
    "fb98c0adf7584a27935a9cf19f1e981a",
    "c2b1148532464811b11add3c5f453fc9",
    "033287576bbd4089bc4c5537619352ba"
    ]
    },
    "outputId": "f6a4cb25-bd43-4ef8-f1b3-94debed883c2"
    "outputId": "bedb9d2d-2182-4ddd-c556-574bff395ad2"
    },
    "source": [
    "trainer = pl.Trainer(gpus=1, max_epochs=1)\r\n",
    "trainer.fit(classifier, train_dl)"
    ],
    "execution_count": 12,
    "execution_count": 20,
    "outputs": [
    {
    "output_type": "stream",
    @@ -593,7 +591,7 @@
    "output_type": "display_data",
    "data": {
    "application/vnd.jupyter.widget-view+json": {
    "model_id": "b308d20b6ed64d2197482ed1c3b78fb4",
    "model_id": "7a970585d344438cb8a7f6ad3549456c",
    "version_minor": 0,
    "version_major": 2
    },
    @@ -619,7 +617,7 @@
    "traceback": [
    "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
    "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
    "\u001b[0;32m<ipython-input-12-62647d0f905e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtrainer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTrainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgpus\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclassifier\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
    "\u001b[0;32m<ipython-input-20-62647d0f905e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtrainer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTrainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgpus\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclassifier\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloader, val_dataloaders, datamodule)\u001b[0m\n\u001b[1;32m 511\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 512\u001b[0m \u001b[0;31m# dispath `start_training` or `start_testing` or `start_predicting`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 513\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 514\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 515\u001b[0m \u001b[0;31m# plugin will finalized fitting (e.g. ddp_spawn will load trained model)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mdispatch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 551\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 552\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 553\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 554\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 555\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrain_or_test_or_predict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36mstart_training\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_testing\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    @@ -641,7 +639,7 @@
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/training_loop.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, split_batch, batch_idx, opt_idx, hiddens)\u001b[0m\n\u001b[1;32m 291\u001b[0m \u001b[0mmodel_ref\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mResult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 292\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"training_step\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 293\u001b[0;31m \u001b[0mtraining_step_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 294\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 295\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, args)\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprecision_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 157\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 122\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m<ipython-input-10-67aaa49cb76c>\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moffsets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moffsets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m<ipython-input-18-67aaa49cb76c>\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moffsets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moffsets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, weight, size_average, ignore_index, reduce, reduction)\u001b[0m\n\u001b[1;32m 955\u001b[0m def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,\n\u001b[1;32m 956\u001b[0m reduce=None, reduction: str = 'mean') -> None:\n\u001b[0;32m--> 957\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCrossEntropyLoss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize_average\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 958\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mignore_index\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 959\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, weight, size_average, reduce, reduction)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0m_WeightedLoss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_Loss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize_average\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'mean'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_WeightedLoss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize_average\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mregister_buffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'weight'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, size_average, reduce, reduction)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_Loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msize_average\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mreduce\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlegacy_get_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize_average\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    @@ -650,6 +648,17 @@
    ]
    }
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "GrqxQXadZ3vb"
    },
    "source": [
    ""
    ],
    "execution_count": null,
    "outputs": []
    }
    ]
    }
  3. davidefiocco revised this gist Feb 22, 2021. 1 changed file with 12 additions and 1 deletion.
    Original file line number Diff line number Diff line change
    @@ -21,7 +21,8 @@
    },
    "colab": {
    "name": "Text classification in PyTorch to refactor with PyTorch lightning.ipynb",
    "provenance": []
    "provenance": [],
    "include_colab_link": true
    },
    "accelerator": "GPU",
    "widgets": {
    @@ -275,6 +276,16 @@
    }
    },
    "cells": [
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "view-in-github",
    "colab_type": "text"
    },
    "source": [
    "<a href=\"https://colab.research.google.com/gist/davidefiocco/3b6c6b1e09c4f664b3a73e5bf24d1668/text-classification-in-pytorch-to-refactor-with-pytorch-lightning.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
    ]
    },
    {
    "cell_type": "markdown",
    "metadata": {
  4. davidefiocco created this gist Feb 22, 2021.
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,644 @@
    {
    "nbformat": 4,
    "nbformat_minor": 0,
    "metadata": {
    "kernelspec": {
    "display_name": "Python 3",
    "language": "python",
    "name": "python3"
    },
    "language_info": {
    "codemirror_mode": {
    "name": "ipython",
    "version": 3
    },
    "file_extension": ".py",
    "mimetype": "text/x-python",
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
    "version": "3.6.8"
    },
    "colab": {
    "name": "Text classification in PyTorch to refactor with PyTorch lightning.ipynb",
    "provenance": []
    },
    "accelerator": "GPU",
    "widgets": {
    "application/vnd.jupyter.widget-state+json": {
    "b308d20b6ed64d2197482ed1c3b78fb4": {
    "model_module": "@jupyter-widgets/controls",
    "model_name": "HBoxModel",
    "state": {
    "_view_name": "HBoxView",
    "_dom_classes": [],
    "_model_name": "HBoxModel",
    "_view_module": "@jupyter-widgets/controls",
    "_model_module_version": "1.5.0",
    "_view_count": null,
    "_view_module_version": "1.5.0",
    "box_style": "",
    "layout": "IPY_MODEL_bc014fea651046dc80a7482d1aeef913",
    "_model_module": "@jupyter-widgets/controls",
    "children": [
    "IPY_MODEL_65a0c8a8032b4ceb83187d6816af3ccd",
    "IPY_MODEL_53c98ba3347d410c82b0fd71db70167d"
    ]
    }
    },
    "bc014fea651046dc80a7482d1aeef913": {
    "model_module": "@jupyter-widgets/base",
    "model_name": "LayoutModel",
    "state": {
    "_view_name": "LayoutView",
    "grid_template_rows": null,
    "right": null,
    "justify_content": null,
    "_view_module": "@jupyter-widgets/base",
    "overflow": null,
    "_model_module_version": "1.2.0",
    "_view_count": null,
    "flex_flow": "row wrap",
    "width": "100%",
    "min_width": null,
    "border": null,
    "align_items": null,
    "bottom": null,
    "_model_module": "@jupyter-widgets/base",
    "top": null,
    "grid_column": null,
    "overflow_y": null,
    "overflow_x": null,
    "grid_auto_flow": null,
    "grid_area": null,
    "grid_template_columns": null,
    "flex": null,
    "_model_name": "LayoutModel",
    "justify_items": null,
    "grid_row": null,
    "max_height": null,
    "align_content": null,
    "visibility": null,
    "align_self": null,
    "height": null,
    "min_height": null,
    "padding": null,
    "grid_auto_rows": null,
    "grid_gap": null,
    "max_width": null,
    "order": null,
    "_view_module_version": "1.2.0",
    "grid_template_areas": null,
    "object_position": null,
    "object_fit": null,
    "grid_auto_columns": null,
    "margin": null,
    "display": "inline-flex",
    "left": null
    }
    },
    "65a0c8a8032b4ceb83187d6816af3ccd": {
    "model_module": "@jupyter-widgets/controls",
    "model_name": "FloatProgressModel",
    "state": {
    "_view_name": "ProgressView",
    "style": "IPY_MODEL_2d3526fe06524cd79cc8859d532a8da1",
    "_dom_classes": [],
    "description": "Epoch 0: 0%",
    "_model_name": "FloatProgressModel",
    "bar_style": "danger",
    "max": 7125,
    "_view_module": "@jupyter-widgets/controls",
    "_model_module_version": "1.5.0",
    "value": 0,
    "_view_count": null,
    "_view_module_version": "1.5.0",
    "orientation": "horizontal",
    "min": 0,
    "description_tooltip": null,
    "_model_module": "@jupyter-widgets/controls",
    "layout": "IPY_MODEL_f6192130f3b04748aa3b8b7c6a14e290"
    }
    },
    "53c98ba3347d410c82b0fd71db70167d": {
    "model_module": "@jupyter-widgets/controls",
    "model_name": "HTMLModel",
    "state": {
    "_view_name": "HTMLView",
    "style": "IPY_MODEL_1a5e1382f32e4928b80118e1e1ffffea",
    "_dom_classes": [],
    "description": "",
    "_model_name": "HTMLModel",
    "placeholder": "",
    "_view_module": "@jupyter-widgets/controls",
    "_model_module_version": "1.5.0",
    "value": " 0/7125 [00:00&lt;?, ?it/s]",
    "_view_count": null,
    "_view_module_version": "1.5.0",
    "description_tooltip": null,
    "_model_module": "@jupyter-widgets/controls",
    "layout": "IPY_MODEL_e42749af77594d3f8f97414d79278c12"
    }
    },
    "2d3526fe06524cd79cc8859d532a8da1": {
    "model_module": "@jupyter-widgets/controls",
    "model_name": "ProgressStyleModel",
    "state": {
    "_view_name": "StyleView",
    "_model_name": "ProgressStyleModel",
    "description_width": "initial",
    "_view_module": "@jupyter-widgets/base",
    "_model_module_version": "1.5.0",
    "_view_count": null,
    "_view_module_version": "1.2.0",
    "bar_color": null,
    "_model_module": "@jupyter-widgets/controls"
    }
    },
    "f6192130f3b04748aa3b8b7c6a14e290": {
    "model_module": "@jupyter-widgets/base",
    "model_name": "LayoutModel",
    "state": {
    "_view_name": "LayoutView",
    "grid_template_rows": null,
    "right": null,
    "justify_content": null,
    "_view_module": "@jupyter-widgets/base",
    "overflow": null,
    "_model_module_version": "1.2.0",
    "_view_count": null,
    "flex_flow": null,
    "width": null,
    "min_width": null,
    "border": null,
    "align_items": null,
    "bottom": null,
    "_model_module": "@jupyter-widgets/base",
    "top": null,
    "grid_column": null,
    "overflow_y": null,
    "overflow_x": null,
    "grid_auto_flow": null,
    "grid_area": null,
    "grid_template_columns": null,
    "flex": "2",
    "_model_name": "LayoutModel",
    "justify_items": null,
    "grid_row": null,
    "max_height": null,
    "align_content": null,
    "visibility": null,
    "align_self": null,
    "height": null,
    "min_height": null,
    "padding": null,
    "grid_auto_rows": null,
    "grid_gap": null,
    "max_width": null,
    "order": null,
    "_view_module_version": "1.2.0",
    "grid_template_areas": null,
    "object_position": null,
    "object_fit": null,
    "grid_auto_columns": null,
    "margin": null,
    "display": null,
    "left": null
    }
    },
    "1a5e1382f32e4928b80118e1e1ffffea": {
    "model_module": "@jupyter-widgets/controls",
    "model_name": "DescriptionStyleModel",
    "state": {
    "_view_name": "StyleView",
    "_model_name": "DescriptionStyleModel",
    "description_width": "",
    "_view_module": "@jupyter-widgets/base",
    "_model_module_version": "1.5.0",
    "_view_count": null,
    "_view_module_version": "1.2.0",
    "_model_module": "@jupyter-widgets/controls"
    }
    },
    "e42749af77594d3f8f97414d79278c12": {
    "model_module": "@jupyter-widgets/base",
    "model_name": "LayoutModel",
    "state": {
    "_view_name": "LayoutView",
    "grid_template_rows": null,
    "right": null,
    "justify_content": null,
    "_view_module": "@jupyter-widgets/base",
    "overflow": null,
    "_model_module_version": "1.2.0",
    "_view_count": null,
    "flex_flow": null,
    "width": null,
    "min_width": null,
    "border": null,
    "align_items": null,
    "bottom": null,
    "_model_module": "@jupyter-widgets/base",
    "top": null,
    "grid_column": null,
    "overflow_y": null,
    "overflow_x": null,
    "grid_auto_flow": null,
    "grid_area": null,
    "grid_template_columns": null,
    "flex": null,
    "_model_name": "LayoutModel",
    "justify_items": null,
    "grid_row": null,
    "max_height": null,
    "align_content": null,
    "visibility": null,
    "align_self": null,
    "height": null,
    "min_height": null,
    "padding": null,
    "grid_auto_rows": null,
    "grid_gap": null,
    "max_width": null,
    "order": null,
    "_view_module_version": "1.2.0",
    "grid_template_areas": null,
    "object_position": null,
    "object_fit": null,
    "grid_auto_columns": null,
    "margin": null,
    "display": null,
    "left": null
    }
    }
    }
    }
    },
    "cells": [
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "SVhHUPSWPCHG"
    },
    "source": [
    "The code below is a PyTorch text classifier obtained by getting code from https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html, attempting (clumsily) a PyTorch lightning refactor.\r\n",
    "How can I make this work with PyTorch lightning?"
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "8-un-rXKQ2ec"
    },
    "source": [
    "%matplotlib inline"
    ],
    "execution_count": 1,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "RWuv_rZYRdXz"
    },
    "source": [
    "pip install torchtext pytorch-lightning --upgrade --quiet"
    ],
    "execution_count": 2,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "vOAdk0qdQ2ef",
    "colab": {
    "base_uri": "https://localhost:8080/"
    },
    "outputId": "99cb63e9-b430-4096-a1a8-8126e76a2437"
    },
    "source": [
    "import torch\n",
    "import torchtext\n",
    "from torchtext.datasets import text_classification\n",
    "NGRAMS = 2\n",
    "import os\n",
    "if not os.path.isdir('./.data'):\n",
    "\tos.mkdir('./.data')\n",
    "train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](\n",
    " root='./.data', ngrams=NGRAMS, vocab=None)\n",
    "BATCH_SIZE = 16"
    ],
    "execution_count": 3,
    "outputs": [
    {
    "output_type": "stream",
    "text": [
    "120000lines [00:06, 17747.80lines/s]\n",
    "120000lines [00:15, 7709.20lines/s]\n",
    "7600lines [00:00, 7603.03lines/s]\n"
    ],
    "name": "stderr"
    }
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "jEbcAxuuQ2ei"
    },
    "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "class TextSentiment(nn.Module):\n",
    " def __init__(self, vocab_size, embed_dim, num_class):\n",
    " super().__init__()\n",
    " self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)\n",
    " self.fc = nn.Linear(embed_dim, num_class)\n",
    " self.init_weights()\n",
    "\n",
    " def init_weights(self):\n",
    " initrange = 0.5\n",
    " self.embedding.weight.data.uniform_(-initrange, initrange)\n",
    " self.fc.weight.data.uniform_(-initrange, initrange)\n",
    " self.fc.bias.data.zero_()\n",
    "\n",
    " def forward(self, text, offsets):\n",
    " embedded = self.embedding(text, offsets)\n",
    " return self.fc(embedded)"
    ],
    "execution_count": 4,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "SUOxiQOJQ2el"
    },
    "source": [
    "VOCAB_SIZE = len(train_dataset.get_vocab())\n",
    "EMBED_DIM = 32\n",
    "NUM_CLASS = len(train_dataset.get_labels())"
    ],
    "execution_count": 5,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "YJXial03Q2em"
    },
    "source": [
    "Functions used to generate batch\n",
    "--------------------------------\n",
    "\n",
    "\n"
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "36rTk-S5Q2en"
    },
    "source": [
    "def generate_batch(batch):\n",
    " label = torch.tensor([entry[0] for entry in batch])\n",
    " text = [entry[1] for entry in batch]\n",
    " offsets = [0] + [len(entry) for entry in text]\n",
    " # torch.Tensor.cumsum returns the cumulative sum\n",
    " # of elements in the dimension dim.\n",
    " # torch.Tensor([1.0, 2.0, 3.0]).cumsum(dim=0)\n",
    "\n",
    " offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)\n",
    " text = torch.cat(text)\n",
    " return text, offsets, label"
    ],
    "execution_count": 6,
    "outputs": []
    },
    {
    "cell_type": "markdown",
    "metadata": {
    "id": "kiaOrOBhQ2ep"
    },
    "source": [
    "Define functions to train the model and evaluate results.\n",
    "---------------------------------------------------------\n",
    "\n",
    "\n"
    ]
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "mMzZQ2e-Q2er"
    },
    "source": [
    "import time\n",
    "from torch.utils.data.dataset import random_split\n",
    "\n",
    "train_len = int(len(train_dataset) * 0.95)\n",
    "sub_train_, sub_valid_ = \\\n",
    " random_split(train_dataset, [train_len, len(train_dataset) - train_len])"
    ],
    "execution_count": 7,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "drSdUCLuSuhK"
    },
    "source": [
    "from torch.utils.data import DataLoader\r\n",
    "\r\n",
    "train_dl = DataLoader(sub_train_, batch_size=BATCH_SIZE, shuffle=True,\r\n",
    " collate_fn=generate_batch)\r\n",
    "valid_dl = DataLoader(sub_valid_, batch_size=BATCH_SIZE, shuffle=True,\r\n",
    " collate_fn=generate_batch)"
    ],
    "execution_count": 8,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "9-3usDU_TGY9"
    },
    "source": [
    "import pytorch_lightning as pl\r\n",
    "from torch import optim\r\n",
    "from torch.nn import CrossEntropyLoss"
    ],
    "execution_count": 9,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "qFUFjqu_S9Gp"
    },
    "source": [
    "class LitTextClassifier(pl.LightningModule):\r\n",
    " def __init__(self, num_class, criterion = CrossEntropyLoss):\r\n",
    " super().__init__()\r\n",
    " self.embedding = nn.EmbeddingBag(VOCAB_SIZE, EMBED_DIM, sparse=False)\r\n",
    " self.fc = nn.Linear(EMBED_DIM, num_class)\r\n",
    " self.init_weights()\r\n",
    " self.criterion = criterion\r\n",
    "\r\n",
    " def init_weights(self):\r\n",
    " initrange = 0.5\r\n",
    " self.embedding.weight.data.uniform_(-initrange, initrange)\r\n",
    " self.fc.weight.data.uniform_(-initrange, initrange)\r\n",
    " self.fc.bias.data.zero_()\r\n",
    "\r\n",
    " def forward(self, text, offsets):\r\n",
    " embedded = self.embedding(text, offsets)\r\n",
    " return self.fc(embedded)\r\n",
    "\r\n",
    " def configure_optimizers(self):\r\n",
    " optimizer = optim.SGD(self.parameters(), lr=4.0)\r\n",
    " return optimizer\r\n",
    "\r\n",
    " def training_step(self, batch, batch_idx):\r\n",
    " # I am messing up things here\r\n",
    " text, offsets, cls = batch\r\n",
    " output = self.forward(text, offsets)\r\n",
    " loss = self.criterion(output, cls)\r\n",
    "\r\n",
    " return loss"
    ],
    "execution_count": 10,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "TgxhOV1YUCoD"
    },
    "source": [
    "classifier = LitTextClassifier(NUM_CLASS)"
    ],
    "execution_count": 11,
    "outputs": []
    },
    {
    "cell_type": "code",
    "metadata": {
    "id": "BYgIuqqW7kQw",
    "colab": {
    "base_uri": "https://localhost:8080/",
    "height": 826,
    "referenced_widgets": [
    "b308d20b6ed64d2197482ed1c3b78fb4",
    "bc014fea651046dc80a7482d1aeef913",
    "65a0c8a8032b4ceb83187d6816af3ccd",
    "53c98ba3347d410c82b0fd71db70167d",
    "2d3526fe06524cd79cc8859d532a8da1",
    "f6192130f3b04748aa3b8b7c6a14e290",
    "1a5e1382f32e4928b80118e1e1ffffea",
    "e42749af77594d3f8f97414d79278c12"
    ]
    },
    "outputId": "f6a4cb25-bd43-4ef8-f1b3-94debed883c2"
    },
    "source": [
    "trainer = pl.Trainer(gpus=1, max_epochs=1)\r\n",
    "trainer.fit(classifier, train_dl)"
    ],
    "execution_count": 12,
    "outputs": [
    {
    "output_type": "stream",
    "text": [
    "GPU available: True, used: True\n",
    "INFO:lightning:GPU available: True, used: True\n",
    "TPU available: None, using: 0 TPU cores\n",
    "INFO:lightning:TPU available: None, using: 0 TPU cores\n",
    "\n",
    " | Name | Type | Params\n",
    "-------------------------------------------\n",
    "0 | embedding | EmbeddingBag | 41.9 M\n",
    "1 | fc | Linear | 132 \n",
    "-------------------------------------------\n",
    "41.9 M Trainable params\n",
    "0 Non-trainable params\n",
    "41.9 M Total params\n",
    "167.533 Total estimated model params size (MB)\n",
    "INFO:lightning:\n",
    " | Name | Type | Params\n",
    "-------------------------------------------\n",
    "0 | embedding | EmbeddingBag | 41.9 M\n",
    "1 | fc | Linear | 132 \n",
    "-------------------------------------------\n",
    "41.9 M Trainable params\n",
    "0 Non-trainable params\n",
    "41.9 M Total params\n",
    "167.533 Total estimated model params size (MB)\n"
    ],
    "name": "stderr"
    },
    {
    "output_type": "display_data",
    "data": {
    "application/vnd.jupyter.widget-view+json": {
    "model_id": "b308d20b6ed64d2197482ed1c3b78fb4",
    "version_minor": 0,
    "version_major": 2
    },
    "text/plain": [
    "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…"
    ]
    },
    "metadata": {
    "tags": []
    }
    },
    {
    "output_type": "stream",
    "text": [
    "\n"
    ],
    "name": "stdout"
    },
    {
    "output_type": "error",
    "ename": "RuntimeError",
    "evalue": "ignored",
    "traceback": [
    "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
    "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
    "\u001b[0;32m<ipython-input-12-62647d0f905e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtrainer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpl\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTrainer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgpus\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclassifier\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_dl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, model, train_dataloader, val_dataloaders, datamodule)\u001b[0m\n\u001b[1;32m 511\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 512\u001b[0m \u001b[0;31m# dispath `start_training` or `start_testing` or `start_predicting`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 513\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdispatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 514\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 515\u001b[0m \u001b[0;31m# plugin will finalized fitting (e.g. ddp_spawn will load trained model)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mdispatch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 551\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 552\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 553\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 554\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 555\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrain_or_test_or_predict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36mstart_training\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_testing\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py\u001b[0m in \u001b[0;36mstart_training\u001b[0;34m(self, trainer)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_training\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Trainer'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0;31m# double dispatch to initiate the training loop\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 112\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mstart_testing\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'Trainer'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py\u001b[0m in \u001b[0;36mrun_train\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 642\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"run_training_epoch\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 643\u001b[0m \u001b[0;31m# run train epoch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 644\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_training_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 645\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 646\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_steps\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_steps\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mglobal_step\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/training_loop.py\u001b[0m in \u001b[0;36mrun_training_epoch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 490\u001b[0m \u001b[0;31m# ------------------------------------\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 491\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"run_training_batch\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 492\u001b[0;31m \u001b[0mbatch_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_training_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdataloader_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 493\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 494\u001b[0m \u001b[0;31m# when returning -1 from train_step, we end epoch early\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/training_loop.py\u001b[0m in \u001b[0;36mrun_training_batch\u001b[0;34m(self, batch, batch_idx, dataloader_idx)\u001b[0m\n\u001b[1;32m 648\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 649\u001b[0m \u001b[0;31m# optimizer step\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 650\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_step_and_backward_closure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 651\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 652\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/training_loop.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m 432\u001b[0m \u001b[0mon_tpu\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_device_type\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mDeviceType\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTPU\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0m_TPU_AVAILABLE\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 433\u001b[0m \u001b[0musing_native_amp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0musing_native_amp\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 434\u001b[0;31m \u001b[0musing_lbfgs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mis_lbfgs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 435\u001b[0m )\n\u001b[1;32m 436\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/lightning.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)\u001b[0m\n\u001b[1;32m 1382\u001b[0m \u001b[0;31m# wraps into LightingOptimizer only for running step\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1383\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLightningOptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_to_lightning_optimizer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1384\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moptimizer_closure\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1385\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1386\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0moptimizer_zero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/optimizer.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure, *args, **kwargs)\u001b[0m\n\u001b[1;32m 217\u001b[0m \u001b[0mprofiler_name\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34mf\"optimizer_step_and_closure_{self._optimizer_idx}\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 218\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 219\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__optimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprofiler_name\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mprofiler_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 220\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_total_optimizer_step_calls\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 221\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/core/optimizer.py\u001b[0m in \u001b[0;36m__optimizer_step\u001b[0;34m(self, closure, profiler_name, **kwargs)\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprofiler_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_optimizer_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlambda_closure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_trainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_loop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautomatic_optimization\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, optimizer, opt_idx, lambda_closure, **kwargs)\u001b[0m\n\u001b[1;32m 276\u001b[0m )\n\u001b[1;32m 277\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmake_optimizer_step\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 278\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_optimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlambda_closure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 279\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprecision_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpost_optimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 280\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpost_optimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36mrun_optimizer_step\u001b[0;34m(self, optimizer, optimizer_idx, lambda_closure, **kwargs)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrun_optimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlambda_closure\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mCallable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlambda_closure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlambda_closure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0moptimizer_zero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcurrent_epoch\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py\u001b[0m in \u001b[0;36moptimizer_step\u001b[0;34m(self, optimizer, lambda_closure, **kwargs)\u001b[0m\n\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0moptimizer_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mOptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlambda_closure\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mCallable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 160\u001b[0;31m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mclosure\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlambda_closure\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/optim/sgd.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mclosure\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menable_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 86\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mclosure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 87\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mgroup\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparam_groups\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/training_loop.py\u001b[0m in \u001b[0;36mtrain_step_and_backward_closure\u001b[0;34m()\u001b[0m\n\u001b[1;32m 643\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrain_step_and_backward_closure\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 644\u001b[0m result = self.training_step_and_backward(\n\u001b[0;32m--> 645\u001b[0;31m \u001b[0msplit_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhiddens\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 646\u001b[0m )\n\u001b[1;32m 647\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/training_loop.py\u001b[0m in \u001b[0;36mtraining_step_and_backward\u001b[0;34m(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)\u001b[0m\n\u001b[1;32m 736\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"training_step_and_backward\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 737\u001b[0m \u001b[0;31m# lightning module hook\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 738\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msplit_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mopt_idx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhiddens\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 739\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_curr_step_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 740\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/training_loop.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, split_batch, batch_idx, opt_idx, hiddens)\u001b[0m\n\u001b[1;32m 291\u001b[0m \u001b[0mmodel_ref\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mResult\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 292\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofiler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprofile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"training_step\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 293\u001b[0;31m \u001b[0mtraining_step_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 294\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccelerator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 295\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/accelerators/accelerator.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, args)\u001b[0m\n\u001b[1;32m 155\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 156\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprecision_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_step_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 157\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_type_plugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 159\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 122\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlightning_module\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 124\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpost_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m<ipython-input-10-67aaa49cb76c>\u001b[0m in \u001b[0;36mtraining_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moffsets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtext\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moffsets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcriterion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, weight, size_average, ignore_index, reduce, reduction)\u001b[0m\n\u001b[1;32m 955\u001b[0m def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,\n\u001b[1;32m 956\u001b[0m reduce=None, reduction: str = 'mean') -> None:\n\u001b[0;32m--> 957\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mCrossEntropyLoss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize_average\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 958\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mignore_index\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mignore_index\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 959\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, weight, size_average, reduce, reduction)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0m_WeightedLoss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_Loss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msize_average\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'mean'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_WeightedLoss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize_average\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mregister_buffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'weight'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/loss.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, size_average, reduce, reduction)\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_Loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msize_average\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mreduce\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_Reduction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlegacy_get_string\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize_average\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreduction\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreduction\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/_reduction.py\u001b[0m in \u001b[0;36mlegacy_get_string\u001b[0;34m(size_average, reduce, emit_warning)\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0mreduce\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0msize_average\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 38\u001b[0m \u001b[0mret\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'mean'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
    "\u001b[0;31mRuntimeError\u001b[0m: Boolean value of Tensor with more than one value is ambiguous"
    ]
    }
    ]
    }
    ]
    }