{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "es-rnn-colab-nb-example.ipynb", "provenance": [], "collapsed_sections": [], "authorship_tag": "ABX9TyO3yaYBU+ZfWti+MxkoFxGY", "include_colab_link": true }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "<a href=\"https://colab.research.google.com/gist/florisrc/e9bc4cb054c6ad1c8f1ac43a9c21d09f/es-rnn-colab-nb-example.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" ] }, { "cell_type": "markdown", "metadata": { "id": "IBEgiTXP4u4o", "colab_type": "text" }, "source": [ "# ES-RNN Colab NB Example\n", "\n", "A GPU-enabled version of the hybrid ES-RNN model by Slawek et al that won the M4 time-series forecasting competition by a large margin, here implemented in a Google Colab environment. The details of our implementation and the results are discussed in detail on this [paper](https://arxiv.org/abs/1907.03329).\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "FuudmEWFycm5", "colab_type": "text" }, "source": [ "## Get data and code:\n" ] }, { "cell_type": "code", "metadata": { "id": "ovCK3_O-xU5E", "colab_type": "code", "outputId": "572ca8da-cab3-4050-bf55-50ca703f2b3b", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 } }, "source": [ "# get data\n", "%cd /content\n", "!mkdir /content/m4_data \n", "%cd /content/m4_data\n", "!wget https://www.m4.unic.ac.cy/wp-content/uploads/2017/12/M4DataSet.zip\n", "!wget https://www.m4.unic.ac.cy/wp-content/uploads/2018/07/M-test-set.zip\n", "!wget https://github.com/M4Competition/M4-methods/raw/master/Dataset/M4-info.csv\n", "!mkdir ./Train && cd ./Train && unzip ../M4DataSet.zip && cd ..\n", "!mkdir ./Test && cd ./Test && unzip ../M-test-set.zip && cd ..\n", "\n", "# clone git repo\n", "%cd /content\n", "!git clone https://github.com/damitkwr/ESRNN-GPU.git\n", "\n", "# copy data to repo\n", "%cd /content/ESRNN-GPU/\n", "!mkdir ./data\n", "%cd data/\n", "!mkdir ./Train && cp /content/m4_data/Train/* ./Train/\n", "!mkdir ./Test && cp /content/m4_data/Test/* ./Test/\n", "!cp /content/m4_data/M4-info.csv ./info.csv\n", "!cd ../.." ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "/content\n", "/content/m4_data\n", "--2020-04-20 10:14:51-- https://www.m4.unic.ac.cy/wp-content/uploads/2017/12/M4DataSet.zip\n", "Resolving www.m4.unic.ac.cy (www.m4.unic.ac.cy)... 35.177.142.35, 35.176.90.68\n", "Connecting to www.m4.unic.ac.cy (www.m4.unic.ac.cy)|35.177.142.35|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 66613994 (64M) [application/zip]\n", "Saving to: ‘M4DataSet.zip’\n", "\n", "M4DataSet.zip 100%[===================>] 63.53M 19.3MB/s in 3.6s \n", "\n", "2020-04-20 10:14:55 (17.9 MB/s) - ‘M4DataSet.zip’ saved [66613994/66613994]\n", "\n", "--2020-04-20 10:14:56-- https://www.m4.unic.ac.cy/wp-content/uploads/2018/07/M-test-set.zip\n", "Resolving www.m4.unic.ac.cy (www.m4.unic.ac.cy)... 35.177.142.35, 35.176.90.68\n", "Connecting to www.m4.unic.ac.cy (www.m4.unic.ac.cy)|35.177.142.35|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: 3723045 (3.5M) [application/zip]\n", "Saving to: ‘M-test-set.zip’\n", "\n", "M-test-set.zip 100%[===================>] 3.55M 4.20MB/s in 0.8s \n", "\n", "2020-04-20 10:14:57 (4.20 MB/s) - ‘M-test-set.zip’ saved [3723045/3723045]\n", "\n", "--2020-04-20 10:14:58-- https://github.com/M4Competition/M4-methods/raw/master/Dataset/M4-info.csv\n", "Resolving github.com (github.com)... 192.30.255.112\n", "Connecting to github.com (github.com)|192.30.255.112|:443... connected.\n", "HTTP request sent, awaiting response... 301 Moved Permanently\n", "Location: https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/M4-info.csv [following]\n", "--2020-04-20 10:14:58-- https://github.com/Mcompetitions/M4-methods/raw/master/Dataset/M4-info.csv\n", "Reusing existing connection to github.com:443.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://raw.githubusercontent.com/Mcompetitions/M4-methods/master/Dataset/M4-info.csv [following]\n", "--2020-04-20 10:14:59-- https://raw.githubusercontent.com/Mcompetitions/M4-methods/master/Dataset/M4-info.csv\n", "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...\n", "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n", "HTTP request sent, awaiting response... 200 OK\n", "Length: 4335598 (4.1M) [text/plain]\n", "Saving to: ‘M4-info.csv’\n", "\n", "M4-info.csv 100%[===================>] 4.13M --.-KB/s in 0.1s \n", "\n", "2020-04-20 10:14:59 (31.7 MB/s) - ‘M4-info.csv’ saved [4335598/4335598]\n", "\n", "Archive: ../M4DataSet.zip\n", " inflating: Daily-train.csv \n", " inflating: Hourly-train.csv \n", " inflating: Monthly-train.csv \n", " inflating: Quarterly-train.csv \n", " inflating: Weekly-train.csv \n", " inflating: Yearly-train.csv \n", "Archive: ../M-test-set.zip\n", " inflating: Daily-test.csv \n", " inflating: Hourly-test.csv \n", " inflating: Monthly-test.csv \n", " inflating: Quarterly-test.csv \n", " inflating: Weekly-test.csv \n", " inflating: Yearly-test.csv \n", "/content\n", "Cloning into 'ESRNN-GPU'...\n", "remote: Enumerating objects: 34, done.\u001b[K\n", "remote: Counting objects: 100% (34/34), done.\u001b[K\n", "remote: Compressing objects: 100% (25/25), done.\u001b[K\n", "remote: Total 521 (delta 18), reused 23 (delta 9), pack-reused 487\u001b[K\n", "Receiving objects: 100% (521/521), 76.96 MiB | 21.29 MiB/s, done.\n", "Resolving deltas: 100% (330/330), done.\n", "/content/ESRNN-GPU\n", "/content/ESRNN-GPU/data\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "BhYOwpDjyiqP", "colab_type": "text" }, "source": [ "## Create colab environment with correct library versions" ] }, { "cell_type": "code", "metadata": { "id": "c9l1BWaOxUuw", "colab_type": "code", "outputId": "c53f5d40-2f6e-4574-db11-4ebc2dae8029", "colab": { "base_uri": "https://localhost:8080/", "height": 268 } }, "source": [ "# uninstall torch \n", "!pip uninstall torch\n", "!pip uninstall torch # run twice (recommendation pytorch forums)\n", "\n", "# and re-install as 0.4.1\n", "from os import path\n", "from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag\n", "platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())\n", "\n", "accelerator = 'cu80' if path.exists('/opt/bin/nvidia-smi') else 'cpu'\n", "\n", "!pip install -q http://download.pytorch.org/whl/{accelerator}/torch-0.4.1-{platform}-linux_x86_64.whl torchvision\n", "\n", "# tensorflow version 1 
\n", "%tensorflow_version 1.x\n", "\n", "import torch\n", "import tensorflow as tf \n", "print(f'Torch version: {torch.__version__}')\n", "print(f'Tensorflow version: {tf.__version__}')\n", "print(f'Torch.cuda.is_available: {torch.cuda.is_available()}')" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Uninstalling torch-0.4.1:\n", " Would remove:\n", " /usr/local/lib/python3.6/dist-packages/torch-0.4.1.dist-info/*\n", " /usr/local/lib/python3.6/dist-packages/torch/*\n", "Proceed (y/n)? y\n", " Successfully uninstalled torch-0.4.1\n", "\u001b[33mWARNING: Skipping torch as it is not installed.\u001b[0m\n", "\u001b[K |████████████████████████████████| 483.0MB 1.2MB/s \n", "\u001b[31mERROR: torchvision 0.5.0 has requirement torch==1.4.0, but you'll have torch 0.4.1 which is incompatible.\u001b[0m\n", "\u001b[31mERROR: fastai 1.0.60 has requirement torch>=1.0.0, but you'll have torch 0.4.1 which is incompatible.\u001b[0m\n", "\u001b[?25hTensorFlow 1.x selected.\n", "Torch version: 0.4.1\n", "Tensorflow version: 1.15.2\n", "Torch.cuda.is_available: True\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "9pVt5QPS0FN_", "colab_type": "text" }, "source": [ "## Check model configurations (optional)\n" ] }, { "cell_type": "code", "metadata": { "id": "EYXHAyL55YGx", "colab_type": "code", "outputId": "4cf207f2-b5da-4440-dd02-dc59679de4ef", "colab": { "base_uri": "https://localhost:8080/", "height": 627 } }, "source": [ "# move to project working directory\n", "%cd /content/ESRNN-GPU/\n", "\n", "# Check configuration\n", "import pprint\n", "from es_rnn.config import get_config\n", "\n", "config = get_config('Monthly') # can be quarterly, monthly, daily or yearly. \n", "pprint.pprint(config)" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "/content/ESRNN-GPU\n", "{'add_nl_layer': True,\n", " 'batch_size': 1024,\n", " 'c_state_penalty': 0,\n", " 'chop_val': 72,\n", " 'device': 'cuda',\n", " 'dilations': ((1, 3), (6, 12)),\n", " 'gradient_clipping': 20,\n", " 'input_size': 12,\n", " 'input_size_i': 12,\n", " 'learning_rate': 0.001,\n", " 'learning_rates': (10, 0.0001),\n", " 'level_variability_penalty': 50,\n", " 'lr_anneal_rate': 0.5,\n", " 'lr_anneal_step': 5,\n", " 'lr_ratio': 3.1622776601683795,\n", " 'lr_tolerance_multip': 1.005,\n", " 'min_epochs_before_changing_lrate': 2,\n", " 'min_learning_rate': 0.0001,\n", " 'num_of_categories': 6,\n", " 'num_of_train_epochs': 15,\n", " 'output_size': 18,\n", " 'output_size_i': 18,\n", " 'percentile': 50,\n", " 'print_output_stats': 3,\n", " 'print_train_batch_every': 5,\n", " 'prod': True,\n", " 'rnn_cell_type': 'LSTM',\n", " 'seasonality': 12,\n", " 'state_hsize': 50,\n", " 'tau': 0.5,\n", " 'training_percentile': 45,\n", " 'training_tau': 0.45,\n", " 'variable': 'Monthly'}\n" ], "name": "stdout" } ] }, { "cell_type": "markdown", "metadata": { "id": "Ehz5haXl9MKA", "colab_type": "text" }, "source": [ "## Edit model configurations (optional) " ] }, { "cell_type": "code", "metadata": { "id": "cERmd0848p61", "colab_type": "code", "outputId": "c1fff371-08ec-4ae7-ba1e-62ecf1b34fcc", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "# print config.py and copy code to clipboard \n", "!cat /es_rnn/config.py" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "cat: /es_rnn/config.py: No such file or directory\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "3baLvW4A9eD9", "colab_type": "code", "outputId": 
"70d614fa-87c0-41cf-a52e-76678a5fb3ff", "colab": { "base_uri": "https://localhost:8080/", "height": 35 } }, "source": [ "%%writefile /content/ESRNN-GPU/es_rnn/config.py\n", "\n", "from math import sqrt\n", "\n", "import torch\n", "\n", "\n", "def get_config(interval):\n", " config = {\n", " 'prod': True,\n", " 'device': (\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n", " 'percentile': 50,\n", " 'training_percentile': 45,\n", " 'add_nl_layer': True,\n", " 'rnn_cell_type': 'LSTM',\n", " 'learning_rate': 1e-3,\n", " 'learning_rates': ((10, 1e-4)),\n", " 'num_of_train_epochs': 5,\n", " 'num_of_categories': 6, # in data provided\n", " 'batch_size': 1024,\n", " 'gradient_clipping': 20,\n", " 'c_state_penalty': 0,\n", " 'min_learning_rate': 0.0001,\n", " 'lr_ratio': sqrt(10),\n", " 'lr_tolerance_multip': 1.005,\n", " 'min_epochs_before_changing_lrate': 2,\n", " 'print_train_batch_every': 5,\n", " 'print_output_stats': 3,\n", " 'lr_anneal_rate': 0.5,\n", " 'lr_anneal_step': 5\n", " }\n", "\n", " if interval == 'Quarterly':\n", " config.update({\n", " 'chop_val': 72,\n", " 'variable': \"Quarterly\",\n", " 'dilations': ((1, 2), (4, 8)),\n", " 'state_hsize': 40,\n", " 'seasonality': 4,\n", " 'input_size': 4,\n", " 'output_size': 8,\n", " 'level_variability_penalty': 80\n", " })\n", " elif interval == 'Monthly':\n", " config.update({\n", " # RUNTIME PARAMETERS\n", " 'chop_val': 72,\n", " 'variable': \"Monthly\",\n", " 'dilations': ((1, 3), (6, 12)),\n", " 'state_hsize': 50,\n", " 'seasonality': 12,\n", " 'input_size': 12,\n", " 'output_size': 18,\n", " 'level_variability_penalty': 50\n", " })\n", " elif interval == 'Daily':\n", " config.update({\n", " # RUNTIME PARAMETERS\n", " 'chop_val': 200,\n", " 'variable': \"Daily\",\n", " 'dilations': ((1, 7), (14, 28)),\n", " 'state_hsize': 50,\n", " 'seasonality': 7,\n", " 'input_size': 7,\n", " 'output_size': 14,\n", " 'level_variability_penalty': 50\n", " })\n", " elif interval == 'Yearly':\n", "\n", " config.update({\n", " # RUNTIME PARAMETERS\n", " 'chop_val': 25,\n", " 'variable': \"Yearly\",\n", " 'dilations': ((1, 2), (2, 6)),\n", " 'state_hsize': 30,\n", " 'seasonality': 1,\n", " 'input_size': 4,\n", " 'output_size': 6,\n", " 'level_variability_penalty': 0\n", " })\n", " else:\n", " print(\"I don't have that config. 
:(\")\n", "\n", " config['input_size_i'] = config['input_size']\n", " config['output_size_i'] = config['output_size']\n", " config['tau'] = config['percentile'] / 100\n", " config['training_tau'] = config['training_percentile'] / 100\n", "\n", " if not config['prod']:\n", " config['batch_size'] = 10\n", " config['num_of_train_epochs'] = 15\n", "\n", " return config" ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "Overwriting /content/ESRNN-GPU/es_rnn/config.py\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "dD8_1psvxUka", "colab_type": "code", "outputId": "fd852628-9951-46ef-c791-294970a8e5d0", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 } }, "source": [ "# move to project working directory\n", "%cd /content/ESRNN-GPU/\n", "\n", "import pandas as pd\n", "from torch.utils.data import DataLoader\n", "from es_rnn.data_loading import create_datasets, SeriesDataset\n", "from es_rnn.config import get_config\n", "from es_rnn.trainer import ESRNNTrainer\n", "from es_rnn.model import ESRNN\n", "import time\n", "\n", "print('loading config')\n", "config = get_config('Monthly')\n", "\n", "print('loading data')\n", "info = pd.read_csv('/content/ESRNN-GPU/data/info.csv')\n", "\n", "train_path = '/content/ESRNN-GPU/data/Train/%s-train.csv' % (config['variable'])\n", "test_path = '/content/ESRNN-GPU/data/Test/%s-test.csv' % (config['variable'])\n", "\n", "train, val, test = create_datasets(train_path, test_path, config['output_size'])\n", "\n", "dataset = SeriesDataset(train, val, test, info, config['variable'], config['chop_val'], config['device'])\n", "dataloader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)\n", "\n", "run_id = str(int(time.time()))\n", "model = ESRNN(num_series=len(dataset), config=config)\n", "tr = ESRNNTrainer(model, dataloader, run_id, config, ohe_headers=dataset.dataInfoCatHeaders)\n", "tr.train_epochs() " ], "execution_count": 0, "outputs": [ { "output_type": "stream", "text": [ "/content/ESRNN-GPU\n", "loading config\n", "loading data\n", "WARNING:tensorflow:From /content/ESRNN-GPU/utils/logger.py:9: The name tf.summary.FileWriter is deprecated. Please use tf.compat.v1.summary.FileWriter instead.\n", "\n", "Train_batch: 1\n", "WARNING:tensorflow:From /content/ESRNN-GPU/utils/logger.py:20: The name tf.Summary is deprecated. Please use tf.compat.v1.Summary instead.\n", "\n", "Train_batch: 2\n", "Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", "Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", "Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [1/15] Loss: 23.5622\n", "WARNING:tensorflow:From /content/ESRNN-GPU/utils/logger.py:32: The name tf.HistogramProto is deprecated. 
Please use tf.compat.v1.HistogramProto instead.\n", "\n", "Loss decreased, saving model!\n", "{'Demographic': 9.967270851135254, 'Finance': 14.129257202148438, 'Industry': 14.062990188598633, 'Macro': 14.504400253295898, 'Micro': 12.656795501708984, 'Other': 13.857166290283203, 'Overall': 13.255921363830566, 'hold_out_loss': 9.428352355957031}\n", "Train_batch: 1\n", "Train_batch: 2\n", "Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", "Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", "Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [2/15] Loss: 6.0620\n", "Loss decreased, saving model!\n", "{'Demographic': 5.617069244384766, 'Finance': 10.89091968536377, 'Industry': 10.572635650634766, 'Macro': 11.038947105407715, 'Micro': 8.760212898254395, 'Other': 10.539490699768066, 'Overall': 9.613520622253418, 'hold_out_loss': 6.416837692260742}\n", "Train_batch: 1\n", "Train_batch: 2\n", "Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", "Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", "Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [3/15] Loss: 5.0237\n", "Loss decreased, saving model!\n", "{'Demographic': 5.255251884460449, 'Finance': 10.675373077392578, 'Industry': 10.28479290008545, 'Macro': 10.747631072998047, 'Micro': 8.443456649780273, 'Other': 10.35895824432373, 'Overall': 9.323665618896484, 'hold_out_loss': 5.998364448547363}\n", "Train_batch: 1\n", "Train_batch: 2\n", "Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", "Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", "Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [4/15] Loss: 4.7732\n", "Loss decreased, saving model!\n", "{'Demographic': 5.173051834106445, 'Finance': 10.586264610290527, 'Industry': 10.207746505737305, 'Macro': 10.641335487365723, 'Micro': 8.287897109985352, 'Other': 10.417609214782715, 'Overall': 9.224047660827637, 'hold_out_loss': 5.846914291381836}\n", "Train_batch: 1\n", "Train_batch: 2\n", 
"Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", "Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", "Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [5/15] Loss: 4.6542\n", "Loss decreased, saving model!\n", "{'Demographic': 5.164522171020508, 'Finance': 10.578543663024902, 'Industry': 10.187837600708008, 'Macro': 10.623577117919922, 'Micro': 8.26281452178955, 'Other': 10.45866870880127, 'Overall': 9.20844554901123, 'hold_out_loss': 5.779313087463379}\n", "Train_batch: 1\n", "Train_batch: 2\n", "Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", "Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", "Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [6/15] Loss: 4.5945\n", "Loss decreased, saving model!\n", "{'Demographic': 5.154304027557373, 'Finance': 10.566814422607422, 'Industry': 10.174098014831543, 'Macro': 10.608725547790527, 'Micro': 8.236871719360352, 'Other': 10.460794448852539, 'Overall': 9.193401336669922, 'hold_out_loss': 5.757056713104248}\n", "Train_batch: 1\n", "Train_batch: 2\n", "Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", "Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", "Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [7/15] Loss: 4.5661\n", "Loss decreased, saving model!\n", "{'Demographic': 5.155148506164551, 'Finance': 10.550554275512695, 'Industry': 10.173151016235352, 'Macro': 10.588346481323242, 'Micro': 8.188667297363281, 'Other': 10.482839584350586, 'Overall': 9.177229881286621, 'hold_out_loss': 5.744265556335449}\n", "Train_batch: 1\n", "Train_batch: 2\n", "Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", "Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", 
"Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [8/15] Loss: 4.5431\n", "Loss decreased, saving model!\n", "{'Demographic': 5.154897212982178, 'Finance': 10.566178321838379, 'Industry': 10.165489196777344, 'Macro': 10.598137855529785, 'Micro': 8.2199068069458, 'Other': 10.474693298339844, 'Overall': 9.186197280883789, 'hold_out_loss': 5.735326766967773}\n", "Train_batch: 1\n", "Train_batch: 2\n", "Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", "Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", "Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [9/15] Loss: 4.5236\n", "Loss decreased, saving model!\n", "{'Demographic': 5.148446083068848, 'Finance': 10.555113792419434, 'Industry': 10.15269947052002, 'Macro': 10.580646514892578, 'Micro': 8.18532943725586, 'Other': 10.47098445892334, 'Overall': 9.170007705688477, 'hold_out_loss': 5.723298072814941}\n", "Train_batch: 1\n", "Train_batch: 2\n", "Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", "Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", "Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [10/15] Loss: 4.5057\n", "Loss decreased, saving model!\n", "{'Demographic': 5.152772426605225, 'Finance': 10.554600715637207, 'Industry': 10.157341957092285, 'Macro': 10.57533073425293, 'Micro': 8.166055679321289, 'Other': 10.48807430267334, 'Overall': 9.167271614074707, 'hold_out_loss': 5.719578266143799}\n", "Train_batch: 1\n", "Train_batch: 2\n", "Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", "Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", "Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [11/15] Loss: 4.4916\n", "Loss decreased, saving model!\n", "{'Demographic': 
5.1514434814453125, 'Finance': 10.558562278747559, 'Industry': 10.150896072387695, 'Macro': 10.575738906860352, 'Micro': 8.171943664550781, 'Other': 10.482748985290527, 'Overall': 9.167464256286621, 'hold_out_loss': 5.716712951660156}\n", "Train_batch: 1\n", "Train_batch: 2\n", "Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", "Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", "Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [12/15] Loss: 4.4839\n", "Loss decreased, saving model!\n", "{'Demographic': 5.161124229431152, 'Finance': 10.557811737060547, 'Industry': 10.139664649963379, 'Macro': 10.573782920837402, 'Micro': 8.167551040649414, 'Other': 10.473078727722168, 'Overall': 9.164985656738281, 'hold_out_loss': 5.712883472442627}\n", "Train_batch: 1\n", "Train_batch: 2\n", "Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", "Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", "Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [13/15] Loss: 4.4770\n", "Loss decreased, saving model!\n", "{'Demographic': 5.154560089111328, 'Finance': 10.553889274597168, 'Industry': 10.143238067626953, 'Macro': 10.56783676147461, 'Micro': 8.161402702331543, 'Other': 10.475221633911133, 'Overall': 9.161617279052734, 'hold_out_loss': 5.711756229400635}\n", "Train_batch: 1\n", "Train_batch: 2\n", "Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", "Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", "Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [14/15] Loss: 4.4707\n", "Loss decreased, saving model!\n", "{'Demographic': 5.163309097290039, 'Finance': 10.553582191467285, 'Industry': 10.137224197387695, 'Macro': 10.565842628479004, 'Micro': 8.156747817993164, 'Other': 10.47332763671875, 'Overall': 9.16030216217041, 'hold_out_loss': 5.708845138549805}\n", "Train_batch: 1\n", "Train_batch: 2\n", "Train_batch: 3\n", "Train_batch: 4\n", "Train_batch: 5\n", "Train_batch: 6\n", "Train_batch: 7\n", 
"Train_batch: 8\n", "Train_batch: 9\n", "Train_batch: 10\n", "Train_batch: 11\n", "Train_batch: 12\n", "Train_batch: 13\n", "Train_batch: 14\n", "Train_batch: 15\n", "Train_batch: 16\n", "Train_batch: 17\n", "Train_batch: 18\n", "Train_batch: 19\n", "Train_batch: 20\n", "Train_batch: 21\n", "Train_batch: 22\n", "Train_batch: 23\n", "Train_batch: 24\n", "Train_batch: 25\n", "Train_batch: 26\n", "Train_batch: 27\n", "Train_batch: 28\n", "Train_batch: 29\n", "Train_batch: 30\n", "Train_batch: 31\n", "Train_batch: 32\n", "Train_batch: 33\n", "Train_batch: 34\n", "Train_batch: 35\n", "[TRAIN] Epoch [15/15] Loss: 4.4641\n", "Loss decreased, saving model!\n", "{'Demographic': 5.177516460418701, 'Finance': 10.559904098510742, 'Industry': 10.120597839355469, 'Macro': 10.571897506713867, 'Micro': 8.172297477722168, 'Other': 10.448819160461426, 'Overall': 9.163888931274414, 'hold_out_loss': 5.705763339996338}\n", "Total Training Mins: 64.15\n" ], "name": "stdout" } ] }, { "cell_type": "code", "metadata": { "id": "Nh2_90Fm-nPX", "colab_type": "code", "colab": {} }, "source": [ "" ], "execution_count": 0, "outputs": [] } ] }