{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "es-rnn-colab-nb-example.ipynb",
      "provenance": [],
      "collapsed_sections": [],
      "authorship_tag": "ABX9TyO3yaYBU+ZfWti+MxkoFxGY",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/florisrc/e9bc4cb054c6ad1c8f1ac43a9c21d09f/es-rnn-colab-nb-example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IBEgiTXP4u4o",
        "colab_type": "text"
      },
      "source": [
        "# ES-RNN Colab NB Example\n",
        "\n",
        "A GPU-enabled version of the hybrid ES-RNN model by Slawek Smyl that won the M4 time-series forecasting competition by a large margin, implemented here in a Google Colab environment. Our implementation and its results are discussed in detail in this [paper](https://arxiv.org/abs/1907.03329).\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FuudmEWFycm5",
        "colab_type": "text"
      },
      "source": [
        "## Get data and code\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ovCK3_O-xU5E",
        "colab_type": "code",
        "outputId": "572ca8da-cab3-4050-bf55-50ca703f2b3b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "# get data\n",
        "%cd /content\n",
        "!mkdir /content/m4_data \n",
        "%cd /content/m4_data\n",
        "!wget https://www.m4.unic.ac.cy/wp-content/uploads/2017/12/M4DataSet.zip\n",
        "!wget https://www.m4.unic.ac.cy/wp-content/uploads/2018/07/M-test-set.zip\n",
        "!wget https://github.com/M4Competition/M4-methods/raw/master/Dataset/M4-info.csv\n",
        "!mkdir ./Train && cd ./Train && unzip ../M4DataSet.zip && cd ..\n",
        "!mkdir ./Test && cd ./Test && unzip ../M-test-set.zip && cd ..\n",
        "\n",
        "# clone git repo\n",
        "%cd /content\n",
        "!git clone https://github.com/damitkwr/ESRNN-GPU.git\n",
        "\n",
        "# copy data to repo\n",
        "%cd /content/ESRNN-GPU/\n",
        "!mkdir ./data\n",
        "%cd data/\n",
        "!mkdir ./Train && cp /content/m4_data/Train/* ./Train/\n",
        "!mkdir ./Test && cp /content/m4_data/Test/* ./Test/\n",
        "!cp /content/m4_data/M4-info.csv ./info.csv\n",
        "%cd /content"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/content\n",
            "/content/m4_data\n",
            "--2020-04-20 10:14:51--  https://www.m4.unic.ac.cy/wp-content/uploads/2017/12/M4DataSet.zip\n",
            "Resolving www.m4.unic.ac.cy (www.m4.unic.ac.cy)... 35.177.142.35, 35.176.90.68\n",
            "Connecting to www.m4.unic.ac.cy (www.m4.unic.ac.cy)|35.177.142.35|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 66613994 (64M) [application/zip]\n",
            "Saving to: ‘M4DataSet.zip’\n",
            "\n",
            "M4DataSet.zip       100%[===================>]  63.53M  19.3MB/s    in 3.6s    \n",
            "\n",
            "2020-04-20 10:14:55 (17.9 MB/s) - ‘M4DataSet.zip’ saved [66613994/66613994]\n",
            "\n",
            "--2020-04-20 10:14:56--  https://www.m4.unic.ac.cy/wp-content/uploads/2018/07/M-test-set.zip\n",
            "Resolving www.m4.unic.ac.cy (www.m4.unic.ac.cy)... 35.177.142.35, 35.176.90.68\n",
            "Connecting to www.m4.unic.ac.cy (www.m4.unic.ac.cy)|35.177.142.35|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 3723045 (3.5M) [application/zip]\n",
            "Saving to: ‘M-test-set.zip’\n",
            "\n",
            "M-test-set.zip      100%[===================>]   3.55M  4.20MB/s    in 0.8s    \n",
            "\n",
            "2020-04-20 10:14:57 (4.20 MB/s) - ‘M-test-set.zip’ saved [3723045/3723045]\n",
            "\n",
            "--2020-04-20 10:14:58--  https://github.com/M4Competition/M4-methods/raw/master/Dataset/M4-info.csv\n",
            "Resolving github.com (github.com)... 192.30.255.112\n",
            "Connecting to github.com (github.com)|192.30.255.112|:443... connected.\n",
            "HTTP request sent, awaiting response... 301 Moved Permanently\n",
            "Location: https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/M4-info.csv [following]\n",
            "--2020-04-20 10:14:58--  https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/M4-info.csv\n",
            "Reusing existing connection to github.com:443.\n",
            "HTTP request sent, awaiting response... 302 Found\n",
            "Location: https://raw.githubusercontent.com/Mcompetitions/M4-methods/master/Dataset/M4-info.csv [following]\n",
            "--2020-04-20 10:14:59--  https://raw.githubusercontent.com/Mcompetitions/M4-methods/master/Dataset/M4-info.csv\n",
            "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n",
            "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 4335598 (4.1M) [text/plain]\n",
            "Saving to: ‘M4-info.csv’\n",
            "\n",
            "M4-info.csv         100%[===================>]   4.13M  --.-KB/s    in 0.1s    \n",
            "\n",
            "2020-04-20 10:14:59 (31.7 MB/s) - ‘M4-info.csv’ saved [4335598/4335598]\n",
            "\n",
            "Archive:  ../M4DataSet.zip\n",
            "  inflating: Daily-train.csv         \n",
            "  inflating: Hourly-train.csv        \n",
            "  inflating: Monthly-train.csv       \n",
            "  inflating: Quarterly-train.csv     \n",
            "  inflating: Weekly-train.csv        \n",
            "  inflating: Yearly-train.csv        \n",
            "Archive:  ../M-test-set.zip\n",
            "  inflating: Daily-test.csv          \n",
            "  inflating: Hourly-test.csv         \n",
            "  inflating: Monthly-test.csv        \n",
            "  inflating: Quarterly-test.csv      \n",
            "  inflating: Weekly-test.csv         \n",
            "  inflating: Yearly-test.csv         \n",
            "/content\n",
            "Cloning into 'ESRNN-GPU'...\n",
            "remote: Enumerating objects: 34, done.\u001b[K\n",
            "remote: Counting objects: 100% (34/34), done.\u001b[K\n",
            "remote: Compressing objects: 100% (25/25), done.\u001b[K\n",
            "remote: Total 521 (delta 18), reused 23 (delta 9), pack-reused 487\u001b[K\n",
            "Receiving objects: 100% (521/521), 76.96 MiB | 21.29 MiB/s, done.\n",
            "Resolving deltas: 100% (330/330), done.\n",
            "/content/ESRNN-GPU\n",
            "/content/ESRNN-GPU/data\n"
          ],
          "name": "stdout"
        }
      ]
    },
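    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Optionally, check that the download, unzip and copy steps produced every file the data loader expects. The file names are taken from the unzip output above (`<Frequency>-train.csv`, `<Frequency>-test.csv`) plus the copied `info.csv`. `missing_m4_files` is a hypothetical helper written for this check, not part of the ESRNN-GPU repo; an empty result means nothing is missing."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code"
      },
      "source": [
        "import os\n",
        "\n",
        "# Frequencies present in the M4 train/test archives (see the unzip output above).\n",
        "M4_FREQUENCIES = ['Yearly', 'Quarterly', 'Monthly', 'Weekly', 'Daily', 'Hourly']\n",
        "\n",
        "def missing_m4_files(data_dir):\n",
        "    # Hypothetical helper: return expected paths (relative to data_dir) that do not exist.\n",
        "    expected = [os.path.join('Train', f + '-train.csv') for f in M4_FREQUENCIES]\n",
        "    expected += [os.path.join('Test', f + '-test.csv') for f in M4_FREQUENCIES]\n",
        "    expected.append('info.csv')  # renamed copy of M4-info.csv\n",
        "    return [p for p in expected if not os.path.isfile(os.path.join(data_dir, p))]\n",
        "\n",
        "missing = missing_m4_files('/content/ESRNN-GPU/data')\n",
        "print('All M4 files present' if not missing else 'Missing: %s' % missing)"
      ],
      "execution_count": 0,
      "outputs": []
    },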
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BhYOwpDjyiqP",
        "colab_type": "text"
      },
      "source": [
        "## Create Colab environment with the correct library versions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c9l1BWaOxUuw",
        "colab_type": "code",
        "outputId": "c53f5d40-2f6e-4574-db11-4ebc2dae8029",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 268
        }
      },
      "source": [
        "# uninstall the preinstalled torch (-y answers the confirmation prompt)\n",
        "!pip uninstall -y torch\n",
        "!pip uninstall -y torch # run twice (recommendation from the PyTorch forums)\n",
        "\n",
        "# and re-install as 0.4.1\n",
        "from os import path\n",
        "from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag\n",
        "platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())\n",
        "\n",
        "accelerator = 'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'\n",
        "\n",
        "!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl torchvision\n",
        "\n",
        "# select TensorFlow 1.x (must run before tensorflow is imported)\n",
        "%tensorflow_version 1.x\n",
        "\n",
        "import torch\n",
        "import tensorflow as tf \n",
        "print(f'Torch version: {torch.__version__}')\n",
        "print(f'Tensorflow version: {tf.__version__}')\n",
        "print(f'Torch.cuda.is_available: {torch.cuda.is_available()}')"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Uninstalling torch-0.4.1:\n",
            "  Would remove:\n",
            "    /usr/local/lib/python3.6/dist-packages/torch-0.4.1.dist-info/*\n",
            "    /usr/local/lib/python3.6/dist-packages/torch/*\n",
            "Proceed (y/n)? y\n",
            "  Successfully uninstalled torch-0.4.1\n",
            "\u001b[33mWARNING: Skipping torch as it is not installed.\u001b[0m\n",
            "\u001b[K     |████████████████████████████████| 483.0MB 1.2MB/s \n",
            "\u001b[31mERROR: torchvision 0.5.0 has requirement torch==1.4.0, but you'll have torch 0.4.1 which is incompatible.\u001b[0m\n",
            "\u001b[31mERROR: fastai 1.0.60 has requirement torch>=1.0.0, but you'll have torch 0.4.1 which is incompatible.\u001b[0m\n",
            "\u001b[?25hTensorFlow 1.x selected.\n",
            "Torch version: 0.4.1\n",
            "Tensorflow version: 1.15.2\n",
            "Torch.cuda.is_available: True\n"
          ],
          "name": "stdout"
        }
      ]
    },
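    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Recent releases of the `wheel` package no longer include `wheel.pep425tags`, so the import in the cell above may fail on a newer runtime. The sketch below builds the same `platform` string from `sys.version_info` instead. It assumes CPython and the usual wheel tag naming (an `m` ABI suffix before Python 3.8); `cpython_wheel_tag` is a hypothetical helper. PyTorch 0.4.1 wheels were only published for older Python versions, so a tag computed on a newer interpreter may not match any download."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code"
      },
      "source": [
        "import sys\n",
        "\n",
        "def cpython_wheel_tag(major, minor):\n",
        "    # Sketch assuming CPython: interpreter tag 'cpXY'; the ABI tag adds the\n",
        "    # pymalloc 'm' suffix, which was dropped from ABI tags in Python 3.8.\n",
        "    impl = 'cp%d%d' % (major, minor)\n",
        "    abi = impl + ('m' if (major, minor) < (3, 8) else '')\n",
        "    return '%s-%s' % (impl, abi)\n",
        "\n",
        "platform = cpython_wheel_tag(sys.version_info.major, sys.version_info.minor)\n",
        "print(platform)  # e.g. on Python 3.6: cp36-cp36m"
      ],
      "execution_count": 0,
      "outputs": []
    },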
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9pVt5QPS0FN_",
        "colab_type": "text"
      },
      "source": [
        "## Check model configurations (optional)\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EYXHAyL55YGx",
        "colab_type": "code",
        "outputId": "4cf207f2-b5da-4440-dd02-dc59679de4ef",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 627
        }
      },
      "source": [
        "# move to project working directory\n",
        "%cd /content/ESRNN-GPU/\n",
        "\n",
        "# Check configuration\n",
        "import pprint\n",
        "from es_rnn.config import get_config\n",
        "\n",
        "config = get_config('Monthly')    # one of 'Quarterly', 'Monthly', 'Daily' or 'Yearly' (case-sensitive)\n",
        "pprint.pprint(config)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/content/ESRNN-GPU\n",
            "{'add_nl_layer': True,\n",
            " 'batch_size': 1024,\n",
            " 'c_state_penalty': 0,\n",
            " 'chop_val': 72,\n",
            " 'device': 'cuda',\n",
            " 'dilations': ((1, 3), (6, 12)),\n",
            " 'gradient_clipping': 20,\n",
            " 'input_size': 12,\n",
            " 'input_size_i': 12,\n",
            " 'learning_rate': 0.001,\n",
            " 'learning_rates': (10, 0.0001),\n",
            " 'level_variability_penalty': 50,\n",
            " 'lr_anneal_rate': 0.5,\n",
            " 'lr_anneal_step': 5,\n",
            " 'lr_ratio': 3.1622776601683795,\n",
            " 'lr_tolerance_multip': 1.005,\n",
            " 'min_epochs_before_changing_lrate': 2,\n",
            " 'min_learning_rate': 0.0001,\n",
            " 'num_of_categories': 6,\n",
            " 'num_of_train_epochs': 15,\n",
            " 'output_size': 18,\n",
            " 'output_size_i': 18,\n",
            " 'percentile': 50,\n",
            " 'print_output_stats': 3,\n",
            " 'print_train_batch_every': 5,\n",
            " 'prod': True,\n",
            " 'rnn_cell_type': 'LSTM',\n",
            " 'seasonality': 12,\n",
            " 'state_hsize': 50,\n",
            " 'tau': 0.5,\n",
            " 'training_percentile': 45,\n",
            " 'training_tau': 0.45,\n",
            " 'variable': 'Monthly'}\n"
          ],
          "name": "stdout"
        }
      ]
    },
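    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The `percentile`/`tau` and `training_percentile`/`training_tau` entries are quantile levels (`tau = percentile / 100`, computed at the end of `get_config`). They are the kind of parameter used by a quantile (pinball) loss: with a level below 0.5, an over-forecast costs more than an under-forecast of the same size. The plain-Python sketch below shows the standard pinball loss for illustration only; it is an assumption, not verified here, that the repo's trainer applies `training_tau` exactly this way."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code"
      },
      "source": [
        "def pinball_loss(y_pred, y_true, tau):\n",
        "    # Standard quantile (pinball) loss averaged over all points:\n",
        "    # under-forecasts cost tau * error, over-forecasts cost (1 - tau) * error.\n",
        "    losses = []\n",
        "    for p, y in zip(y_pred, y_true):\n",
        "        diff = y - p  # positive when the forecast is too low\n",
        "        losses.append(max(tau * diff, (tau - 1) * diff))\n",
        "    return sum(losses) / len(losses)\n",
        "\n",
        "example_tau = 0.45  # same value as training_tau in the config above\n",
        "print('over-forecast by 1: ', pinball_loss([3.0], [2.0], example_tau))\n",
        "print('under-forecast by 1:', pinball_loss([1.0], [2.0], example_tau))"
      ],
      "execution_count": 0,
      "outputs": []
    },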
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ehz5haXl9MKA",
        "colab_type": "text"
      },
      "source": [
        "## Edit model configurations (optional) "
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cERmd0848p61",
        "colab_type": "code",
        "outputId": "c1fff371-08ec-4ae7-ba1e-62ecf1b34fcc",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "# print config.py so its contents can be copied into the next cell and edited\n",
        "!cat /content/ESRNN-GPU/es_rnn/config.py"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "3baLvW4A9eD9",
        "colab_type": "code",
        "outputId": "70d614fa-87c0-41cf-a52e-76678a5fb3ff",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 35
        }
      },
      "source": [
        "%%writefile /content/ESRNN-GPU/es_rnn/config.py\n",
        "\n",
        "from math import sqrt\n",
        "\n",
        "import torch\n",
        "\n",
        "\n",
        "def get_config(interval):\n",
        "    config = {\n",
        "        'prod': True,\n",
        "        'device': (\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n",
        "        'percentile': 50,\n",
        "        'training_percentile': 45,\n",
        "        'add_nl_layer': True,\n",
        "        'rnn_cell_type': 'LSTM',\n",
        "        'learning_rate': 1e-3,\n",
        "        'learning_rates': ((10, 1e-4)),\n",
        "        'num_of_train_epochs': 5,\n",
        "        'num_of_categories': 6,  # in data provided\n",
        "        'batch_size': 1024,\n",
        "        'gradient_clipping': 20,\n",
        "        'c_state_penalty': 0,\n",
        "        'min_learning_rate': 0.0001,\n",
        "        'lr_ratio': sqrt(10),\n",
        "        'lr_tolerance_multip': 1.005,\n",
        "        'min_epochs_before_changing_lrate': 2,\n",
        "        'print_train_batch_every': 5,\n",
        "        'print_output_stats': 3,\n",
        "        'lr_anneal_rate': 0.5,\n",
        "        'lr_anneal_step': 5\n",
        "    }\n",
        "\n",
        "    if interval == 'Quarterly':\n",
        "        config.update({\n",
        "            'chop_val': 72,\n",
        "            'variable': \"Quarterly\",\n",
        "            'dilations': ((1, 2), (4, 8)),\n",
        "            'state_hsize': 40,\n",
        "            'seasonality': 4,\n",
        "            'input_size': 4,\n",
        "            'output_size': 8,\n",
        "            'level_variability_penalty': 80\n",
        "        })\n",
        "    elif interval == 'Monthly':\n",
        "        config.update({\n",
        "            #     RUNTIME PARAMETERS\n",
        "            'chop_val': 72,\n",
        "            'variable': \"Monthly\",\n",
        "            'dilations': ((1, 3), (6, 12)),\n",
        "            'state_hsize': 50,\n",
        "            'seasonality': 12,\n",
        "            'input_size': 12,\n",
        "            'output_size': 18,\n",
        "            'level_variability_penalty': 50\n",
        "        })\n",
        "    elif interval == 'Daily':\n",
        "        config.update({\n",
        "            #     RUNTIME PARAMETERS\n",
        "            'chop_val': 200,\n",
        "            'variable': \"Daily\",\n",
        "            'dilations': ((1, 7), (14, 28)),\n",
        "            'state_hsize': 50,\n",
        "            'seasonality': 7,\n",
        "            'input_size': 7,\n",
        "            'output_size': 14,\n",
        "            'level_variability_penalty': 50\n",
        "        })\n",
        "    elif interval == 'Yearly':\n",
        "\n",
        "        config.update({\n",
        "            #     RUNTIME PARAMETERS\n",
        "            'chop_val': 25,\n",
        "            'variable': \"Yearly\",\n",
        "            'dilations': ((1, 2), (2, 6)),\n",
        "            'state_hsize': 30,\n",
        "            'seasonality': 1,\n",
        "            'input_size': 4,\n",
        "            'output_size': 6,\n",
        "            'level_variability_penalty': 0\n",
        "        })\n",
        "    else:\n",
        "        raise ValueError(\"No config for interval %r; use 'Quarterly', 'Monthly', 'Daily' or 'Yearly'\" % interval)\n",
        "\n",
        "    config['input_size_i'] = config['input_size']\n",
        "    config['output_size_i'] = config['output_size']\n",
        "    config['tau'] = config['percentile'] / 100\n",
        "    config['training_tau'] = config['training_percentile'] / 100\n",
        "\n",
        "    if not config['prod']:\n",
        "        config['batch_size'] = 10\n",
        "        config['num_of_train_epochs'] = 15\n",
        "\n",
        "    return config"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Overwriting /content/ESRNN-GPU/es_rnn/config.py\n"
          ],
          "name": "stdout"
        }
      ]
    },
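    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As an alternative to rewriting `config.py`, individual settings can be overridden on the dictionary returned by `get_config`. The helper below is a hypothetical sketch, not part of the repo: it returns a copy of the config and rejects unknown keys, so a typo raises an error instead of silently doing nothing. It does not recompute derived entries such as `tau`, `training_tau`, `input_size_i` or `output_size_i`, so override those together with their sources if needed."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code"
      },
      "source": [
        "def with_overrides(config, **overrides):\n",
        "    # Hypothetical helper: return a copy of `config` with selected keys replaced.\n",
        "    unknown = sorted(set(overrides) - set(config))\n",
        "    if unknown:\n",
        "        raise KeyError('Unknown config keys: %s' % unknown)\n",
        "    updated = dict(config)  # shallow copy; the original dict is left unchanged\n",
        "    updated.update(overrides)\n",
        "    return updated\n",
        "\n",
        "# Stand-in dict for illustration; in this notebook you would pass get_config('Monthly').\n",
        "base = {'num_of_train_epochs': 15, 'batch_size': 1024}\n",
        "print(with_overrides(base, num_of_train_epochs=5))"
      ],
      "execution_count": 0,
      "outputs": []
    },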
    {
      "cell_type": "code",
      "metadata": {
        "id": "dD8_1psvxUka",
        "colab_type": "code",
        "outputId": "fd852628-9951-46ef-c791-294970a8e5d0",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "# move to project working directory\n",
        "%cd /content/ESRNN-GPU/\n",
        "\n",
        "import pandas as pd\n",
        "from torch.utils.data import DataLoader\n",
        "from es_rnn.data_loading import create_datasets, SeriesDataset\n",
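        "import importlib\n",
        "import es_rnn.config\n",
        "importlib.reload(es_rnn.config)  # re-import config.py so edits made with %%writefile take effect\n",
        "from es_rnn.config import get_config\n",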
        "from es_rnn.trainer import ESRNNTrainer\n",
        "from es_rnn.model import ESRNN\n",
        "import time\n",
        "\n",
        "print('loading config')\n",
        "config = get_config('Monthly')\n",
        "\n",
        "print('loading data')\n",
        "info = pd.read_csv('/content/ESRNN-GPU/data/info.csv')\n",
        "\n",
        "train_path = '/content/ESRNN-GPU/data/Train/%s-train.csv' % (config['variable'])\n",
        "test_path = '/content/ESRNN-GPU/data/Test/%s-test.csv' % (config['variable'])\n",
        "\n",
        "train, val, test = create_datasets(train_path, test_path, config['output_size'])\n",
        "\n",
        "dataset = SeriesDataset(train, val, test, info, config['variable'], config['chop_val'], config['device'])\n",
        "dataloader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)\n",
        "\n",
        "run_id = str(int(time.time()))\n",
        "model = ESRNN(num_series=len(dataset), config=config)\n",
        "tr = ESRNNTrainer(model, dataloader, run_id, config, ohe_headers=dataset.dataInfoCatHeaders)\n",
        "tr.train_epochs() "
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/content/ESRNN-GPU\n",
            "loading config\n",
            "loading data\n",
            "WARNING:tensorflow:From /content/ESRNN-GPU/utils/logger.py:9: The name tf.summary.FileWriter is deprecated. Please use tf.compat.v1.summary.FileWriter instead.\n",
            "\n",
            "Train_batch: 1\n",
            "WARNING:tensorflow:From /content/ESRNN-GPU/utils/logger.py:20: The name tf.Summary is deprecated. Please use tf.compat.v1.Summary instead.\n",
            "\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [1/15]   Loss: 23.5622\n",
            "WARNING:tensorflow:From /content/ESRNN-GPU/utils/logger.py:32: The name tf.HistogramProto is deprecated. Please use tf.compat.v1.HistogramProto instead.\n",
            "\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 9.967270851135254, 'Finance': 14.129257202148438, 'Industry': 14.062990188598633, 'Macro': 14.504400253295898, 'Micro': 12.656795501708984, 'Other': 13.857166290283203, 'Overall': 13.255921363830566, 'hold_out_loss': 9.428352355957031}\n",
            "Train_batch: 1\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [2/15]   Loss: 6.0620\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 5.617069244384766, 'Finance': 10.89091968536377, 'Industry': 10.572635650634766, 'Macro': 11.038947105407715, 'Micro': 8.760212898254395, 'Other': 10.539490699768066, 'Overall': 9.613520622253418, 'hold_out_loss': 6.416837692260742}\n",
            "Train_batch: 1\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [3/15]   Loss: 5.0237\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 5.255251884460449, 'Finance': 10.675373077392578, 'Industry': 10.28479290008545, 'Macro': 10.747631072998047, 'Micro': 8.443456649780273, 'Other': 10.35895824432373, 'Overall': 9.323665618896484, 'hold_out_loss': 5.998364448547363}\n",
            "Train_batch: 1\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [4/15]   Loss: 4.7732\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 5.173051834106445, 'Finance': 10.586264610290527, 'Industry': 10.207746505737305, 'Macro': 10.641335487365723, 'Micro': 8.287897109985352, 'Other': 10.417609214782715, 'Overall': 9.224047660827637, 'hold_out_loss': 5.846914291381836}\n",
            "Train_batch: 1\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [5/15]   Loss: 4.6542\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 5.164522171020508, 'Finance': 10.578543663024902, 'Industry': 10.187837600708008, 'Macro': 10.623577117919922, 'Micro': 8.26281452178955, 'Other': 10.45866870880127, 'Overall': 9.20844554901123, 'hold_out_loss': 5.779313087463379}\n",
            "Train_batch: 1\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [6/15]   Loss: 4.5945\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 5.154304027557373, 'Finance': 10.566814422607422, 'Industry': 10.174098014831543, 'Macro': 10.608725547790527, 'Micro': 8.236871719360352, 'Other': 10.460794448852539, 'Overall': 9.193401336669922, 'hold_out_loss': 5.757056713104248}\n",
            "Train_batch: 1\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [7/15]   Loss: 4.5661\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 5.155148506164551, 'Finance': 10.550554275512695, 'Industry': 10.173151016235352, 'Macro': 10.588346481323242, 'Micro': 8.188667297363281, 'Other': 10.482839584350586, 'Overall': 9.177229881286621, 'hold_out_loss': 5.744265556335449}\n",
            "Train_batch: 1\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [8/15]   Loss: 4.5431\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 5.154897212982178, 'Finance': 10.566178321838379, 'Industry': 10.165489196777344, 'Macro': 10.598137855529785, 'Micro': 8.2199068069458, 'Other': 10.474693298339844, 'Overall': 9.186197280883789, 'hold_out_loss': 5.735326766967773}\n",
            "Train_batch: 1\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [9/15]   Loss: 4.5236\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 5.148446083068848, 'Finance': 10.555113792419434, 'Industry': 10.15269947052002, 'Macro': 10.580646514892578, 'Micro': 8.18532943725586, 'Other': 10.47098445892334, 'Overall': 9.170007705688477, 'hold_out_loss': 5.723298072814941}\n",
            "Train_batch: 1\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [10/15]   Loss: 4.5057\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 5.152772426605225, 'Finance': 10.554600715637207, 'Industry': 10.157341957092285, 'Macro': 10.57533073425293, 'Micro': 8.166055679321289, 'Other': 10.48807430267334, 'Overall': 9.167271614074707, 'hold_out_loss': 5.719578266143799}\n",
            "Train_batch: 1\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [11/15]   Loss: 4.4916\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 5.1514434814453125, 'Finance': 10.558562278747559, 'Industry': 10.150896072387695, 'Macro': 10.575738906860352, 'Micro': 8.171943664550781, 'Other': 10.482748985290527, 'Overall': 9.167464256286621, 'hold_out_loss': 5.716712951660156}\n",
            "Train_batch: 1\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [12/15]   Loss: 4.4839\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 5.161124229431152, 'Finance': 10.557811737060547, 'Industry': 10.139664649963379, 'Macro': 10.573782920837402, 'Micro': 8.167551040649414, 'Other': 10.473078727722168, 'Overall': 9.164985656738281, 'hold_out_loss': 5.712883472442627}\n",
            "Train_batch: 1\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [13/15]   Loss: 4.4770\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 5.154560089111328, 'Finance': 10.553889274597168, 'Industry': 10.143238067626953, 'Macro': 10.56783676147461, 'Micro': 8.161402702331543, 'Other': 10.475221633911133, 'Overall': 9.161617279052734, 'hold_out_loss': 5.711756229400635}\n",
            "Train_batch: 1\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [14/15]   Loss: 4.4707\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 5.163309097290039, 'Finance': 10.553582191467285, 'Industry': 10.137224197387695, 'Macro': 10.565842628479004, 'Micro': 8.156747817993164, 'Other': 10.47332763671875, 'Overall': 9.16030216217041, 'hold_out_loss': 5.708845138549805}\n",
            "Train_batch: 1\n",
            "Train_batch: 2\n",
            "Train_batch: 3\n",
            "Train_batch: 4\n",
            "Train_batch: 5\n",
            "Train_batch: 6\n",
            "Train_batch: 7\n",
            "Train_batch: 8\n",
            "Train_batch: 9\n",
            "Train_batch: 10\n",
            "Train_batch: 11\n",
            "Train_batch: 12\n",
            "Train_batch: 13\n",
            "Train_batch: 14\n",
            "Train_batch: 15\n",
            "Train_batch: 16\n",
            "Train_batch: 17\n",
            "Train_batch: 18\n",
            "Train_batch: 19\n",
            "Train_batch: 20\n",
            "Train_batch: 21\n",
            "Train_batch: 22\n",
            "Train_batch: 23\n",
            "Train_batch: 24\n",
            "Train_batch: 25\n",
            "Train_batch: 26\n",
            "Train_batch: 27\n",
            "Train_batch: 28\n",
            "Train_batch: 29\n",
            "Train_batch: 30\n",
            "Train_batch: 31\n",
            "Train_batch: 32\n",
            "Train_batch: 33\n",
            "Train_batch: 34\n",
            "Train_batch: 35\n",
            "[TRAIN]  Epoch [15/15]   Loss: 4.4641\n",
            "Loss decreased, saving model!\n",
            "{'Demographic': 5.177516460418701, 'Finance': 10.559904098510742, 'Industry': 10.120597839355469, 'Macro': 10.571897506713867, 'Micro': 8.172297477722168, 'Other': 10.448819160461426, 'Overall': 9.163888931274414, 'hold_out_loss': 5.705763339996338}\n",
            "Total Training Mins: 64.15\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}