{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/kativenOG/530716747f32fb585780e847f1bf428d/newadvancedprogramming.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tU0dG551-T-q"
},
"source": [
"# **Overcoming catastrophic forgetting ER-GNN**\n",
" ***A continual graph learning project for the advanced topics in AI course.***"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ar54CWNf6C--",
"outputId": "d8379ddd-c99d-4558-f500-4c8e0c85b671"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://download.pytorch.org/whl/cu121\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.4.0+cu121)\n",
"Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.19.0+cu121)\n",
"Requirement already satisfied: torchaudio in /usr/local/lib/python3.10/dist-packages (2.4.0+cu121)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.15.4)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.2)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.26.4)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (9.4.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n",
"Looking in indexes: https://download.pytorch.org/whl/cpu\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.4.0+cu121)\n",
"Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.19.0+cu121)\n",
"Requirement already satisfied: torchaudio in /usr/local/lib/python3.10/dist-packages (2.4.0+cu121)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.15.4)\n",
"Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.2)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.3)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.6.1)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.26.4)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (9.4.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (2.1.5)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch) (1.3.0)\n",
"\n",
"2.4.0+cu121\n",
"12.1\n",
"nvcc: NVIDIA (R) Cuda compiler driver\n",
"Copyright (c) 2005-2023 NVIDIA Corporation\n",
"Built on Tue_Aug_15_22:02:13_PDT_2023\n",
"Cuda compilation tools, release 12.2, V12.2.140\n",
"Build cuda_12.2.r12.2/compiler.33191640_0\n",
"\n",
"Collecting torch_geometric\n",
" Downloading torch_geometric-2.5.3-py3-none-any.whl.metadata (64 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m64.2/64.2 kB\u001b[0m \u001b[31m1.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting icecream\n",
" Downloading icecream-2.1.3-py2.py3-none-any.whl.metadata (1.4 kB)\n",
"Collecting prettyprint\n",
" Downloading prettyprint-0.1.5.tar.gz (2.1 kB)\n",
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.66.5)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (1.26.4)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (1.13.1)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (2024.6.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (3.1.4)\n",
"Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (3.10.5)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (2.32.3)\n",
"Requirement already satisfied: pyparsing in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (3.1.4)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (1.3.2)\n",
"Requirement already satisfied: psutil>=5.8.0 in /usr/local/lib/python3.10/dist-packages (from torch_geometric) (5.9.5)\n",
"Collecting colorama>=0.3.9 (from icecream)\n",
" Downloading colorama-0.4.6-py2.py3-none-any.whl.metadata (17 kB)\n",
"Requirement already satisfied: pygments>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from icecream) (2.16.1)\n",
"Collecting executing>=0.3.1 (from icecream)\n",
" Downloading executing-2.1.0-py2.py3-none-any.whl.metadata (8.9 kB)\n",
"Collecting asttokens>=2.0.1 (from icecream)\n",
" Downloading asttokens-2.4.1-py2.py3-none-any.whl.metadata (5.2 kB)\n",
"Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from asttokens>=2.0.1->icecream) (1.16.0)\n",
"Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric) (2.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric) (1.3.1)\n",
"Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric) (24.2.0)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric) (1.4.1)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric) (6.0.5)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric) (1.9.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->torch_geometric) (4.0.3)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch_geometric) (2.1.5)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->torch_geometric) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->torch_geometric) (3.8)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->torch_geometric) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->torch_geometric) (2024.8.30)\n",
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->torch_geometric) (1.4.2)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->torch_geometric) (3.5.0)\n",
"Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m9.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading icecream-2.1.3-py2.py3-none-any.whl (8.4 kB)\n",
"Downloading asttokens-2.4.1-py2.py3-none-any.whl (27 kB)\n",
"Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)\n",
"Downloading executing-2.1.0-py2.py3-none-any.whl (25 kB)\n",
"Building wheels for collected packages: prettyprint\n",
" Building wheel for prettyprint (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for prettyprint: filename=prettyprint-0.1.5-py3-none-any.whl size=3027 sha256=e36101ff3764c9d5e9fcaa671b1cb4832ec16f9c71092ddfd490bcb2c886b266\n",
" Stored in directory: /root/.cache/pip/wheels/b2/d0/51/477413885481c635ab7c6400f96f47b8a0971bbc1241ff9c9f\n",
"Successfully built prettyprint\n",
"Installing collected packages: prettyprint, executing, colorama, asttokens, icecream, torch_geometric\n",
"Successfully installed asttokens-2.4.1 colorama-0.4.6 executing-2.1.0 icecream-2.1.3 prettyprint-0.1.5 torch_geometric-2.5.3\n",
"Looking in links: https://data.pyg.org/whl/torch-2.4.0+cu121.html\n",
"Collecting pyg_lib\n",
" Downloading https://data.pyg.org/whl/torch-2.4.0%2Bcu121/pyg_lib-0.4.0%2Bpt24cu121-cp310-cp310-linux_x86_64.whl (2.5 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.5/2.5 MB\u001b[0m \u001b[31m23.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting torch_scatter\n",
" Downloading https://data.pyg.org/whl/torch-2.4.0%2Bcu121/torch_scatter-2.1.2%2Bpt24cu121-cp310-cp310-linux_x86_64.whl (10.9 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.9/10.9 MB\u001b[0m \u001b[31m28.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting torch_sparse\n",
" Downloading https://data.pyg.org/whl/torch-2.4.0%2Bcu121/torch_sparse-0.6.18%2Bpt24cu121-cp310-cp310-linux_x86_64.whl (5.1 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.1/5.1 MB\u001b[0m \u001b[31m28.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting torch_cluster\n",
" Downloading https://data.pyg.org/whl/torch-2.4.0%2Bcu121/torch_cluster-1.6.3%2Bpt24cu121-cp310-cp310-linux_x86_64.whl (3.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m22.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting torch_spline_conv\n",
" Downloading https://data.pyg.org/whl/torch-2.4.0%2Bcu121/torch_spline_conv-1.2.2%2Bpt24cu121-cp310-cp310-linux_x86_64.whl (986 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m986.2/986.2 kB\u001b[0m \u001b[31m22.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from torch_sparse) (1.13.1)\n",
"Requirement already satisfied: numpy<2.3,>=1.22.4 in /usr/local/lib/python3.10/dist-packages (from scipy->torch_sparse) (1.26.4)\n",
"Installing collected packages: torch_spline_conv, torch_scatter, pyg_lib, torch_sparse, torch_cluster\n",
"Successfully installed pyg_lib-0.4.0+pt24cu121 torch_cluster-1.6.3+pt24cu121 torch_scatter-2.1.2+pt24cu121 torch_sparse-0.6.18+pt24cu121 torch_spline_conv-1.2.2+pt24cu121\n",
"Looking in links: https://data.pyg.org/whl/torch-2.4.0+cpu.html\n",
"Requirement already satisfied: pyg_lib in /usr/local/lib/python3.10/dist-packages (0.4.0+pt24cu121)\n",
"Requirement already satisfied: torch_scatter in /usr/local/lib/python3.10/dist-packages (2.1.2+pt24cu121)\n",
"Requirement already satisfied: torch_sparse in /usr/local/lib/python3.10/dist-packages (0.6.18+pt24cu121)\n",
"Requirement already satisfied: torch_cluster in /usr/local/lib/python3.10/dist-packages (1.6.3+pt24cu121)\n",
"Requirement already satisfied: torch_spline_conv in /usr/local/lib/python3.10/dist-packages (1.2.2+pt24cu121)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from torch_sparse) (1.13.1)\n",
"Requirement already satisfied: numpy<2.3,>=1.22.4 in /usr/local/lib/python3.10/dist-packages (from scipy->torch_sparse) (1.26.4)\n",
"Collecting mypy\n",
" Downloading mypy-1.11.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl.metadata (1.9 kB)\n",
"Requirement already satisfied: typing-extensions>=4.6.0 in /usr/local/lib/python3.10/dist-packages (from mypy) (4.12.2)\n",
"Collecting mypy-extensions>=1.0.0 (from mypy)\n",
" Downloading mypy_extensions-1.0.0-py3-none-any.whl.metadata (1.1 kB)\n",
"Requirement already satisfied: tomli>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from mypy) (2.0.1)\n",
"Downloading mypy-1.11.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl (12.5 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.5/12.5 MB\u001b[0m \u001b[31m19.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)\n",
"Installing collected packages: mypy-extensions, mypy\n",
"Successfully installed mypy-1.11.2 mypy-extensions-1.0.0\n",
"Collecting nbqa[toolchain]\n",
" Downloading nbqa-1.9.0-py3-none-any.whl.metadata (31 kB)\n",
"Collecting autopep8>=1.5 (from nbqa[toolchain])\n",
" Downloading autopep8-2.3.1-py2.py3-none-any.whl.metadata (16 kB)\n",
"Requirement already satisfied: ipython>=7.8.0 in /usr/local/lib/python3.10/dist-packages (from nbqa[toolchain]) (7.34.0)\n",
"Collecting tokenize-rt>=3.2.0 (from nbqa[toolchain])\n",
" Downloading tokenize_rt-6.0.0-py2.py3-none-any.whl.metadata (4.1 kB)\n",
"Requirement already satisfied: tomli in /usr/local/lib/python3.10/dist-packages (from nbqa[toolchain]) (2.0.1)\n",
"Collecting black (from nbqa[toolchain])\n",
" Downloading black-24.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl.metadata (78 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m78.2/78.2 kB\u001b[0m \u001b[31m2.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting blacken-docs (from nbqa[toolchain])\n",
" Downloading blacken_docs-1.18.0-py3-none-any.whl.metadata (8.5 kB)\n",
"Collecting flake8 (from nbqa[toolchain])\n",
" Downloading flake8-7.1.1-py2.py3-none-any.whl.metadata (3.8 kB)\n",
"Collecting isort (from nbqa[toolchain])\n",
" Downloading isort-5.13.2-py3-none-any.whl.metadata (12 kB)\n",
"Collecting jupytext (from nbqa[toolchain])\n",
" Downloading jupytext-1.16.4-py3-none-any.whl.metadata (13 kB)\n",
"Requirement already satisfied: mypy in /usr/local/lib/python3.10/dist-packages (from nbqa[toolchain]) (1.11.2)\n",
"Collecting pylint (from nbqa[toolchain])\n",
" Downloading pylint-3.2.7-py3-none-any.whl.metadata (12 kB)\n",
"Collecting pyupgrade (from nbqa[toolchain])\n",
" Downloading pyupgrade-3.17.0-py2.py3-none-any.whl.metadata (15 kB)\n",
"Collecting ruff (from nbqa[toolchain])\n",
" Downloading ruff-0.6.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (25 kB)\n",
"Collecting pycodestyle>=2.12.0 (from autopep8>=1.5->nbqa[toolchain])\n",
" Downloading pycodestyle-2.12.1-py2.py3-none-any.whl.metadata (4.5 kB)\n",
"Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.10/dist-packages (from ipython>=7.8.0->nbqa[toolchain]) (71.0.4)\n",
"Collecting jedi>=0.16 (from ipython>=7.8.0->nbqa[toolchain])\n",
" Using cached jedi-0.19.1-py2.py3-none-any.whl.metadata (22 kB)\n",
"Requirement already satisfied: decorator in /usr/local/lib/python3.10/dist-packages (from ipython>=7.8.0->nbqa[toolchain]) (4.4.2)\n",
"Requirement already satisfied: pickleshare in /usr/local/lib/python3.10/dist-packages (from ipython>=7.8.0->nbqa[toolchain]) (0.7.5)\n",
"Requirement already satisfied: traitlets>=4.2 in /usr/local/lib/python3.10/dist-packages (from ipython>=7.8.0->nbqa[toolchain]) (5.7.1)\n",
"Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from ipython>=7.8.0->nbqa[toolchain]) (3.0.47)\n",
"Requirement already satisfied: pygments in /usr/local/lib/python3.10/dist-packages (from ipython>=7.8.0->nbqa[toolchain]) (2.16.1)\n",
"Requirement already satisfied: backcall in /usr/local/lib/python3.10/dist-packages (from ipython>=7.8.0->nbqa[toolchain]) (0.2.0)\n",
"Requirement already satisfied: matplotlib-inline in /usr/local/lib/python3.10/dist-packages (from ipython>=7.8.0->nbqa[toolchain]) (0.1.7)\n",
"Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.10/dist-packages (from ipython>=7.8.0->nbqa[toolchain]) (4.9.0)\n",
"Requirement already satisfied: click>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from black->nbqa[toolchain]) (8.1.7)\n",
"Requirement already satisfied: mypy-extensions>=0.4.3 in /usr/local/lib/python3.10/dist-packages (from black->nbqa[toolchain]) (1.0.0)\n",
"Requirement already satisfied: packaging>=22.0 in /usr/local/lib/python3.10/dist-packages (from black->nbqa[toolchain]) (24.1)\n",
"Collecting pathspec>=0.9.0 (from black->nbqa[toolchain])\n",
" Downloading pathspec-0.12.1-py3-none-any.whl.metadata (21 kB)\n",
"Requirement already satisfied: platformdirs>=2 in /usr/local/lib/python3.10/dist-packages (from black->nbqa[toolchain]) (4.2.2)\n",
"Requirement already satisfied: typing-extensions>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from black->nbqa[toolchain]) (4.12.2)\n",
"Collecting mccabe<0.8.0,>=0.7.0 (from flake8->nbqa[toolchain])\n",
" Downloading mccabe-0.7.0-py2.py3-none-any.whl.metadata (5.0 kB)\n",
"Collecting pyflakes<3.3.0,>=3.2.0 (from flake8->nbqa[toolchain])\n",
" Downloading pyflakes-3.2.0-py2.py3-none-any.whl.metadata (3.5 kB)\n",
"Requirement already satisfied: markdown-it-py>=1.0 in /usr/local/lib/python3.10/dist-packages (from jupytext->nbqa[toolchain]) (3.0.0)\n",
"Requirement already satisfied: mdit-py-plugins in /usr/local/lib/python3.10/dist-packages (from jupytext->nbqa[toolchain]) (0.4.1)\n",
"Requirement already satisfied: nbformat in /usr/local/lib/python3.10/dist-packages (from jupytext->nbqa[toolchain]) (5.10.4)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from jupytext->nbqa[toolchain]) (6.0.2)\n",
"Collecting astroid<=3.3.0-dev0,>=3.2.4 (from pylint->nbqa[toolchain])\n",
" Downloading astroid-3.2.4-py3-none-any.whl.metadata (4.5 kB)\n",
"Requirement already satisfied: tomlkit>=0.10.1 in /usr/local/lib/python3.10/dist-packages (from pylint->nbqa[toolchain]) (0.13.2)\n",
"Collecting dill>=0.2 (from pylint->nbqa[toolchain])\n",
" Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n",
"Requirement already satisfied: parso<0.9.0,>=0.8.3 in /usr/local/lib/python3.10/dist-packages (from jedi>=0.16->ipython>=7.8.0->nbqa[toolchain]) (0.8.4)\n",
"Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=1.0->jupytext->nbqa[toolchain]) (0.1.2)\n",
"Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.10/dist-packages (from pexpect>4.3->ipython>=7.8.0->nbqa[toolchain]) (0.7.0)\n",
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.10/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->ipython>=7.8.0->nbqa[toolchain]) (0.2.13)\n",
"Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.10/dist-packages (from nbformat->jupytext->nbqa[toolchain]) (2.20.0)\n",
"Requirement already satisfied: jsonschema>=2.6 in /usr/local/lib/python3.10/dist-packages (from nbformat->jupytext->nbqa[toolchain]) (4.23.0)\n",
"Requirement already satisfied: jupyter-core!=5.0.*,>=4.12 in /usr/local/lib/python3.10/dist-packages (from nbformat->jupytext->nbqa[toolchain]) (5.7.2)\n",
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat->jupytext->nbqa[toolchain]) (24.2.0)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat->jupytext->nbqa[toolchain]) (2023.12.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat->jupytext->nbqa[toolchain]) (0.35.1)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=2.6->nbformat->jupytext->nbqa[toolchain]) (0.20.0)\n",
"Downloading autopep8-2.3.1-py2.py3-none-any.whl (45 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m45.7/45.7 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading tokenize_rt-6.0.0-py2.py3-none-any.whl (5.9 kB)\n",
"Downloading black-24.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl (1.8 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading blacken_docs-1.18.0-py3-none-any.whl (8.2 kB)\n",
"Downloading flake8-7.1.1-py2.py3-none-any.whl (57 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.7/57.7 kB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading isort-5.13.2-py3-none-any.whl (92 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m92.3/92.3 kB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading jupytext-1.16.4-py3-none-any.whl (153 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m153.5/153.5 kB\u001b[0m \u001b[31m14.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading nbqa-1.9.0-py3-none-any.whl (35 kB)\n",
"Downloading pylint-3.2.7-py3-none-any.whl (519 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m519.9/519.9 kB\u001b[0m \u001b[31m36.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading pyupgrade-3.17.0-py2.py3-none-any.whl (62 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.0/62.0 kB\u001b[0m \u001b[31m5.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading ruff-0.6.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.3 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.3/10.3 MB\u001b[0m \u001b[31m89.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading astroid-3.2.4-py3-none-any.whl (276 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m276.3/276.3 kB\u001b[0m \u001b[31m21.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hUsing cached jedi-0.19.1-py2.py3-none-any.whl (1.6 MB)\n",
"Downloading mccabe-0.7.0-py2.py3-none-any.whl (7.3 kB)\n",
"Downloading pathspec-0.12.1-py3-none-any.whl (31 kB)\n",
"Downloading pycodestyle-2.12.1-py2.py3-none-any.whl (31 kB)\n",
"Downloading pyflakes-3.2.0-py2.py3-none-any.whl (62 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hInstalling collected packages: tokenize-rt, ruff, pyflakes, pycodestyle, pathspec, mccabe, jedi, isort, dill, astroid, pyupgrade, pylint, flake8, black, autopep8, nbqa, blacken-docs, jupytext\n",
"Successfully installed astroid-3.2.4 autopep8-2.3.1 black-24.8.0 blacken-docs-1.18.0 dill-0.3.8 flake8-7.1.1 isort-5.13.2 jedi-0.19.1 jupytext-1.16.4 mccabe-0.7.0 nbqa-1.9.0 pathspec-0.12.1 pycodestyle-2.12.1 pyflakes-3.2.0 pylint-3.2.7 pyupgrade-3.17.0 ruff-0.6.4 tokenize-rt-6.0.0\n"
]
}
],
"source": [
"# Reinstall torch for cuda 1.22\n",
"# ! pip uninstall -y torch torchvision torchaudio\n",
"! pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121\n",
"! pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu\n",
"\n",
"# Check if the versions make sense\n",
"print()\n",
"! python -c \"import torch; print(torch.__version__)\"\n",
"! python -c \"import torch; print(torch.version.cuda)\"\n",
"! nvcc --version\n",
"print()\n",
"\n",
"# Install pytorch geometric for the correct version of torch and cuda (cpu if nothing else is working)\n",
"! pip install torch_geometric icecream prettyprint tqdm\n",
"! pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+cu121.html\n",
"! pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.4.0+cpu.html\n",
"\n",
"# Notebook types check\n",
"! python -m pip install mypy\n",
"! python -m pip install -U \"nbqa[toolchain]\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3xgpzveCK6rn"
},
"source": [
"## Connect to Drive to save the results"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UADp8tMfK0B_",
"outputId": "c9e93e92-9ac8-4c00-953b-2fce60b2eb7b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
],
"source": [
"import os\n",
"from google.colab import drive\n",
"drive.mount('/content/drive')\n",
"\n",
"# Create all the Directories needed by the project\n",
"! mkdir -p /content/drive/MyDrive/advProgProject\n",
"! mkdir -p /content/drive/MyDrive/advProgProject/data\n",
"! mkdir -p /content/drive/MyDrive/advProgProject/results"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R3J2nnBpjUFj"
},
"source": [
"## Static type checking of the whole notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "iw8WKinXjQ8S",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "44922d3c-c021-49ea-e7d9-f090f3e9b6d1"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cloning into '530716747f32fb585780e847f1bf428d'...\n",
"remote: Enumerating objects: 98, done.\u001b[K\n",
"remote: Counting objects: 100% (15/15), done.\u001b[K\n",
"remote: Compressing objects: 100% (15/15), done.\u001b[K\n",
"remote: Total 98 (delta 4), reused 0 (delta 0), pack-reused 83 (from 1)\u001b[K\n",
"Receiving objects: 100% (98/98), 322.66 KiB | 6.33 MiB/s, done.\n",
"Resolving deltas: 100% (35/35), done.\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_4:45: \u001b[1m\u001b[31merror:\u001b[m Incompatible types in assignment (expression has type \u001b[m\u001b[1m\"None\"\u001b[m, variable has type \u001b[m\u001b[1m\"Tensor\"\u001b[m) \u001b[m\u001b[33m[assignment]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_4:58: \u001b[1m\u001b[31merror:\u001b[m Incompatible types in assignment (expression has type \u001b[m\u001b[1m\"int\"\u001b[m, variable has type \u001b[m\u001b[1m\"str\"\u001b[m) \u001b[m\u001b[33m[assignment]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_4:66: \u001b[1m\u001b[31merror:\u001b[m Incompatible types in assignment (expression has type \u001b[m\u001b[1m\"ndarray[Any, Any]\"\u001b[m, variable has type \u001b[m\u001b[1m\"Tensor\"\u001b[m) \u001b[m\u001b[33m[assignment]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_4:68: \u001b[1m\u001b[31merror:\u001b[m Incompatible types in assignment (expression has type \u001b[m\u001b[1m\"ndarray[Any, Any]\"\u001b[m, variable has type \u001b[m\u001b[1m\"Tensor\"\u001b[m) \u001b[m\u001b[33m[assignment]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_4:70: \u001b[1m\u001b[31merror:\u001b[m Incompatible types in assignment (expression has type \u001b[m\u001b[1m\"ndarray[Any, Any]\"\u001b[m, variable has type \u001b[m\u001b[1m\"Tensor\"\u001b[m) \u001b[m\u001b[33m[assignment]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_4:72: \u001b[1m\u001b[31merror:\u001b[m Incompatible types in assignment (expression has type \u001b[m\u001b[1m\"ndarray[Any, Any]\"\u001b[m, variable has type \u001b[m\u001b[1m\"Tensor\"\u001b[m) \u001b[m\u001b[33m[assignment]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_4:75: \u001b[1m\u001b[31merror:\u001b[m Incompatible types in assignment (expression has type \u001b[m\u001b[1m\"ndarray[Any, Any]\"\u001b[m, variable has type \u001b[m\u001b[1m\"Tensor\"\u001b[m) \u001b[m\u001b[33m[assignment]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_4:90: \u001b[1m\u001b[31merror:\u001b[m Module \u001b[m\u001b[1m\"torch.optim.lr_scheduler\"\u001b[m is not valid as a type \u001b[m\u001b[33m[valid-type]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_4:90: \u001b[34mnote:\u001b[m Perhaps you meant to use a protocol matching the module structure?\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_4:92: \u001b[1m\u001b[31merror:\u001b[m Incompatible types in assignment (expression has type \u001b[m\u001b[1m\"Tensor\"\u001b[m, variable has type \u001b[m\u001b[1m\"ndarray[Any, Any]\"\u001b[m) \u001b[m\u001b[33m[assignment]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_4:97: \u001b[1m\u001b[31merror:\u001b[m \u001b[m\u001b[1m\"map\"\u001b[m expects 1 type argument, but 2 given \u001b[m\u001b[33m[type-arg]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_4:97: \u001b[1m\u001b[31merror:\u001b[m Incompatible types in assignment (expression has type \u001b[m\u001b[1m\"dict[Any, bool]\"\u001b[m, variable has type \u001b[m\u001b[1m\"map[Any]\"\u001b[m) \u001b[m\u001b[33m[assignment]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_4:153: \u001b[1m\u001b[31merror:\u001b[m Unsupported left operand type for != (torch.optim.lr_scheduler?) \u001b[m\u001b[33m[operator]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_4:154: \u001b[1m\u001b[31merror:\u001b[m torch.optim.lr_scheduler? has no attribute \u001b[m\u001b[1m\"step\"\u001b[m \u001b[m\u001b[33m[attr-defined]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_6:46: \u001b[1m\u001b[31merror:\u001b[m \u001b[m\u001b[1m\"map\"\u001b[m expects 1 type argument, but 2 given \u001b[m\u001b[33m[type-arg]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_6:46: \u001b[1m\u001b[31merror:\u001b[m Incompatible types in assignment (expression has type \u001b[m\u001b[1m\"dict[str, Tensor]\"\u001b[m, variable has type \u001b[m\u001b[1m\"map[Any]\"\u001b[m) \u001b[m\u001b[33m[assignment]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_6:57: \u001b[1m\u001b[31merror:\u001b[m Incompatible return value type (got \u001b[m\u001b[1m\"map[Any]\"\u001b[m, expected \u001b[m\u001b[1m\"dict[str, ndarray[Any, Any]]\"\u001b[m) \u001b[m\u001b[33m[return-value]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_8:10: \u001b[1m\u001b[31merror:\u001b[m Incompatible types in assignment (expression has type \u001b[m\u001b[1m\"dict[str, list[int]]\"\u001b[m, variable has type \u001b[m\u001b[1m\"dict[int, Any]\"\u001b[m) \u001b[m\u001b[33m[assignment]\u001b[m\n",
"cloned_gist/_/newadvancedprogramming.ipynb:cell_8:16: \u001b[1m\u001b[31merror:\u001b[m Argument 1 to \u001b[m\u001b[1m\"list\"\u001b[m has incompatible type \u001b[m\u001b[1m\"Sequence[object]\"\u001b[m; expected \u001b[m\u001b[1m\"Iterable[int]\"\u001b[m \u001b[m\u001b[33m[arg-type]\u001b[m\n",
"\u001b[1m\u001b[31mFound 18 errors in 1 file (checked 1 source file)\u001b[m\n"
]
}
],
"source": [
"! rm -rf cloned_gist && mkdir -p cloned_gist && cd cloned_gist && git clone https://gist.github.com/530716747f32fb585780e847f1bf428d.git && mv 530716747f32fb585780e847f1bf428d _\n",
"! nbqa mypy './cloned_gist/_/newadvancedprogramming.ipynb' --ignore-missing-imports"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "50IW_jvldOPn"
},
"source": [
"## Setup:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "xRRDj6nDdJ7D"
},
"outputs": [],
"source": [
"from time import perf_counter\n",
"import numpy as np\n",
"from os import listdir\n",
"from os.path import isfile, join\n",
"from tqdm.notebook import trange, tqdm\n",
"import os, random, json, pickle, collections\n",
"\n",
"import torch, torch_geometric\n",
"import torch.nn as nn\n",
"from torch.nn import Linear, ReLU, Dropout\n",
"from torch_geometric.datasets import Amazon\n",
"import torch.nn.functional as F\n",
"import torch_geometric.transforms as T\n",
"from torch_geometric.data import Data\n",
"from torch_geometric.utils import train_test_split_edges, remove_isolated_nodes, index_to_mask\n",
"from torch_geometric.nn import Sequential, GCNConv, GATConv, BatchNorm, global_mean_pool\n",
"from torch_geometric.loader import NeighborLoader\n",
"\n",
"from pprint import pprint\n",
"from icecream import ic\n",
"\n",
"from matplotlib import pyplot as plt\n",
"from sklearn.metrics import confusion_matrix, classification_report, ConfusionMatrixDisplay\n",
"\n",
"\n",
"# Set device (always GPU if available)\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"# Hyperparams\n",
"n_tasks = 5\n",
"n_classes = n_tasks*2\n",
"test_size = 0.10\n",
"validation_size = 0.05\n",
"nEpochs = 50\n",
"\n",
"# Ok now lets create an object that takes care of most of the stuff related to a single task\n",
"class TaskManager():\n",
" # Attributes\n",
" task: str\n",
" labels: list[int]\n",
" og_dataset: Data\n",
" nodes_mask: torch.Tensor\n",
" train_edges_mask: torch.Tensor\n",
" validation_edges_mask: torch.Tensor\n",
" test_edges_mask: torch.Tensor\n",
" missing_edges_mask: torch.Tensor = None\n",
" model: nn.Module\n",
" device: torch.device\n",
" # __slots__ = ('task', 'labels', 'og_dataset', 'nodes_mask', 'train_edges_mask', 'validation_edges_mask', 'test_edges_mask' 'missing_edges_mask', 'model', 'device')\n",
"\n",
" def __init__(self, task: str, labels: list[int], og_dataset: Data, edges_masks: dict[str, np.ndarray])->None:\n",
" self.task: str = task\n",
" self.labels : list[int] = labels\n",
" self.og_dataset: Data = og_dataset\n",
" self.setFromMap(edges_masks)\n",
"\n",
" def load(self, filename: str)->None:\n",
" splitted: list[str] = filename.split(\".pkl\")[0].split(\"_\")\n",
" self.task: int = int(splitted[1])\n",
" self.labels: list[int] = list(map(int, splitted[2:]))\n",
" self.loadFromMap(filename= filename)\n",
"\n",
" def setFromMap(self, edges_masks: dict[str, np.ndarray])->None:\n",
" for name, mask in edges_masks.items():\n",
" match name:\n",
" case \"nodes\":\n",
" self.nodes_mask = mask\n",
" case \"train\":\n",
" self.train_edges_mask = mask\n",
" case \"validation\":\n",
" self.validation_edges_mask = mask\n",
" case \"test\":\n",
" self.test_edges_mask = mask\n",
" case \"missing\":\n",
" self.missing_edges_mask = mask\n",
" case _:\n",
" raise Exception(\"Invalid map entry\")\n",
"\n",
" def loadFromMap(self, filename: str)->None:\n",
" with open(f\"{filename}\", \"rb\") as filee:\n",
" self.setFromMap(pickle.load(filee))\n",
"\n",
" def setModel(self, model: nn.Module):\n",
" self.model = model\n",
"\n",
" def setDevice(self, device: torch.device)->None:\n",
" self.device = device\n",
"\n",
" # NB For now we are doing full reharsal\n",
" def train(self, optimizer: torch.optim.Optimizer, loss_fn: nn.Module, scheduler: torch.optim.lr_scheduler=None, previous_task_managers: list[\"TaskManager\"] = [])->None:\n",
" # Mesh all the previous nodes mask\n",
" new_node_mask: torch.Tensor = self.nodes_mask\n",
" for tm in previous_task_managers:\n",
" new_node_mask = torch.logical_or(new_node_mask, tm.nodes_mask).to(device)\n",
"\n",
" # mask to index function\n",
" all_nodes_tensor: torch.Tensor = torch.nonzero(new_node_mask, as_tuple=False).view(-1).to(device)\n",
" # and get a hash map out of it\n",
" all_nodes: map[int, bool] = { val: True for val in all_nodes_tensor.cpu().numpy().tolist()}\n",
" # Also use a Vectorize function so we can check if an entry is part of the afromentioned map\n",
" # (we have to do this in numpy because there is no vectorize counterpart in pytorch)\n",
" vectorized_check_presence = np.vectorize(lambda x: x in all_nodes)\n",
"\n",
" # Mesh all the previous edges masks\n",
" all_tms: list[\"TaskManager\"] = list([self, *previous_task_managers])\n",
"\n",
"\n",
" # Get the missing edges between all of the task managers\n",
" # as a starting point, get a mask of all the legal edges with the current noded\n",
" _edge_mask_1: torch.Tensor = torch.from_numpy(vectorized_check_presence(self.og_dataset.edge_index[0,:].cpu().numpy().astype(int))).to(torch.bool).to(device)\n",
" _edge_mask_2: torch.Tensor = torch.from_numpy(vectorized_check_presence(self.og_dataset.edge_index[1,:].cpu().numpy().astype(int))).to(torch.bool).to(device)\n",
" allowed_edges: torch.Tensor = torch.logical_and(_edge_mask_1, _edge_mask_2)\n",
"\n",
" # then concat them with the other missing edges mask\n",
" new_missing_edges_mask: torch.Tensor = torch.zeros(self.og_dataset.edge_index.shape[1]).to(torch.bool).to(device)\n",
" for tm in all_tms:\n",
" if tm.missing_edges_mask is not None: # Get the missing edges\n",
" # They have to be part of the allowed edges list to be added\n",
" new_missing_edges_mask = torch.logical_or(new_missing_edges_mask, torch.logical_and(tm.missing_edges_mask, allowed_edges))\n",
"\n",
" # Get the new edges mask by concatting the task connected components ones with the new one for train\n",
" new_edges_mask: torch.Tensor = torch.clone(new_missing_edges_mask)\n",
" for tm in all_tms:\n",
" # Here for test sake we use the full size of the train, doing full reharsal\n",
" new_edges_mask = torch.logical_or(new_edges_mask, tm.train_edges_mask)\n",
"\n",
" # Neighbour DataLoader\n",
" new_x, new_y = self.og_dataset.x[new_node_mask], self.og_dataset.y[new_node_mask]\n",
" # First create a \"subgraph\" just reindexes the edges based on the provided subset of nodes\n",
" subgraph_edge_index, _ = torch_geometric.utils.subgraph(subset=new_node_mask, edge_index=self.og_dataset.edge_index[:, new_edges_mask], relabel_nodes=True)\n",
" new_data: Data = Data(x=new_x, edge_index=subgraph_edge_index, y=new_y)\n",
" train_loader = NeighborLoader(new_data,\n",
" input_nodes=new_node_mask,\n",
" num_neighbors=[30, 3], # Sample 30 neighbors for each node for 2 iterations\n",
" batch_size=128,\n",
" transform=T.NormalizeFeatures(),\n",
" subgraph_type='induced',\n",
" shuffle=True)\n",
" self.model.train()\n",
" for epoch in range(1, nEpochs):\n",
" total_loss, total_examples = 0, 0\n",
" # Missing validation here\n",
" pprint(train_loader) # DEBUG\n",
" for batch in train_loader:\n",
" batch_size = batch.batch_size\n",
" batch = batch.to(self.device)\n",
"\n",
" # Std optimization step\n",
" optimizer.zero_grad()\n",
" out = model(batch)\n",
"\n",
" loss = loss_fn(out, batch.y)\n",
" loss.backward()\n",
"\n",
" # Reduce learning rate if platouing\n",
" if scheduler != None:\n",
" scheduler.step(loss)\n",
" optimizer.step()\n",
"\n",
" # Save the values for fun\n",
" total_examples += batch_size\n",
" total_loss += float(loss) * batch_size\n",
"\n",
" print(f\"\\t-[TRAIN] In epoch: {epoch} we have an average training loss of {(total_loss/total_examples):.4f}\")\n",
"\n",
" @torch.no_grad()\n",
" def test(self, performance_file: str= \"\")->np.ndarray:\n",
" y: np.ndarray = np.empty((0,0))\n",
" y_pred: np.ndarray = np.empty((0,0))\n",
"\n",
" # Neighbour DataLoader\n",
" nodes_mask = torch.from_numpy(self.nodes_mask).to(torch.bool)\n",
" new_x, new_y = self.og_dataset.x[self.nodes_mask], self.og_dataset.y[self.nodes_mask]\n",
" # First create a \"subgraph\" just reindexes the edges based on the provided subset of nodes\n",
" subgraph_edge_index, _= torch_geometric.utils.subgraph(subset=nodes_mask, edge_index=self.og_dataset.edge_index[:, self.test_edges_mask], relabel_nodes=True)\n",
" new_data: Data = Data(x=new_x, edge_index=subgraph_edge_index, y=new_y)\n",
" # Then create a new NeighbourLoader by using the new edge index provided by the subgraph\n",
" test_loader = NeighborLoader(new_data,\n",
" input_nodes=self.nodes_mask,\n",
" num_neighbors=[30, 3],\n",
" batch_size=128,\n",
" transform=T.NormalizeFeatures(),\n",
" subgraph_type='induced',\n",
" shuffle= True)\n",
"\n",
" for batch in test_loader:\n",
" batch = batch.to(self.device)\n",
" optimizer.zero_grad()\n",
" y_pred = np.append(y_pred, torch.argmax(model(batch), dim=1).cpu().numpy())\n",
" y = np.append(y, batch.y.cpu().numpy())\n",
"\n",
" report: dict = classification_report(y, y_pred, zero_division = 0.0)\n",
" print(\"\\nClassification Report:\\n\", report)\n",
"\n",
" # If specified save the performance in the performance file\n",
" if performance_file != \"\":\n",
" other_info: list[dict]\n",
" # Fetch the previous data from file\n",
" with open(performance_file, \"r\") as filee:\n",
" other_info = json.loads(filee.read())\n",
" other_info.append(report)\n",
" # Override it with new data (a simple append operation)\n",
" with open(performance_file, \"w\") as filee:\n",
" filee.write(json.dumps(other_info))\n",
"\n",
" cm: np.ndarray = confusion_matrix(y_pred, y)\n",
" return cm"
]
},
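{
"cell_type": "markdown",
"metadata": {},
"source": [
"The node-membership check in `train` goes through NumPy (`np.vectorize` over a Python dict). The cell below is a minimal, self-contained sketch showing that the same edge filtering can be expressed with `torch.isin` (available since PyTorch 1.10); the tiny `edge_index` and node masks are made-up illustration data, not part of the project."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Toy graph with 5 nodes and 4 edges (illustration only)\n",
"toy_edge_index = torch.tensor([[0, 1, 2, 3],\n",
"                               [1, 2, 3, 4]])\n",
"# Boolean node masks for two hypothetical tasks\n",
"task_a_nodes = torch.tensor([True, True, False, False, False])\n",
"task_b_nodes = torch.tensor([False, False, True, True, False])\n",
"\n",
"# Mesh the node masks exactly like TaskManager.train does\n",
"merged_nodes = torch.logical_or(task_a_nodes, task_b_nodes)\n",
"allowed_ids = torch.nonzero(merged_nodes, as_tuple=False).view(-1)\n",
"\n",
"# Pure-torch alternative to the np.vectorize membership check:\n",
"# an edge is allowed iff both endpoints are in the merged node set\n",
"src_ok = torch.isin(toy_edge_index[0], allowed_ids)\n",
"dst_ok = torch.isin(toy_edge_index[1], allowed_ids)\n",
"allowed_edges = torch.logical_and(src_ok, dst_ok)\n",
"print(allowed_edges)  # expected: tensor([ True,  True,  True, False])\n"
]
},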
{
"cell_type": "markdown",
"metadata": {
"id": "1OYaOmqj878S"
},
"source": [
"## Generate tasks\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "k9r3KvNUwBRI",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "5484d971-86e4-46db-deea-fa3334f0a761"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Do you want to generate a set of tasks, even if you already have tasks saved on your drive? [Y/N]\n",
"y\n"
]
}
],
"source": [
"# IMPORTANT: Your drive project Directory, choose wisely\n",
"base_dir: str = \"/content/drive/MyDrive/advProgProject\"\n",
"save_dir: str = os.path.join(base_dir, \"data\")\n",
"want_to_generate = False\n",
"\n",
"if len([ 0 for f in listdir(save_dir) if isfile(join(save_dir, f)) and f.endswith(\".pkl\")]) >= 1:\n",
" print(\"Do you want to generate a set of tasks, even if you already have tasks saved on your drive? [Y/N]\")\n",
" decision = input()\n",
" if decision == \"y\" or decision == \"Y\" or decision == \"Yes\" or decision == \"yes\" or decision == \"YES\":\n",
" want_to_generate = True\n",
"else:\n",
" want_to_generate = True"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "uCzWuqDN6grL",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7b1abc00-be09-478b-bc2c-ab95a65faf55"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/torch_geometric/data/dataset.py:238: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):\n",
"/usr/local/lib/python3.10/dist-packages/torch_geometric/data/dataset.py:246: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):\n",
"/usr/local/lib/python3.10/dist-packages/torch_geometric/io/fs.py:215: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
" return torch.load(f, map_location)\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch.Size([13752, 767]) torch.Size([2, 491722]) torch.Size([13752])\n",
"\n",
"Class sizes :\n",
"Counter({4: 5158, 8: 2156, 1: 2142, 2: 1414, 7: 818, 3: 542, 6: 487, 0: 436, 5: 308, 9: 291})\n",
"\n",
"\n",
" - Task n.1 with labels [1, 6] number of nodes: 2629:\")\n",
" missing_edges_mask: 61932/491722\n",
" 45582/491722\n",
" \n",
"\n",
" - Task n.2 with labels [2, 0] number of nodes: 1850:\")\n",
" missing_edges_mask: 13394/491722\n",
" 44192/491722\n",
" \n",
"\n",
" - Task n.3 with labels [4, 7] number of nodes: 5976:\")\n",
" missing_edges_mask: 72496/491722\n",
" 227982/491722\n",
" \n",
"\n",
" - Task n.4 with labels [9, 3] number of nodes: 833:\")\n",
" missing_edges_mask: 19112/491722\n",
" 14228/491722\n",
" \n",
"\n",
" - Task n.5 with labels [5, 8] number of nodes: 2464:\")\n",
" missing_edges_mask: 46250/491722\n",
" 53146/491722\n",
" \n"
]
}
],
"source": [
"# Global data structure for data, training, metrics and memory generation!\n",
"task_managers: list[TaskManager] = []\n",
"\n",
"def split_mask(og_mask: torch.Tensor)-> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
" # first transform the mask to indeces for the edge_index datastructure:\n",
" og_indeces: torch.Tensor = torch.nonzero(og_mask, as_tuple=False).view(-1).to(device)\n",
"\n",
" indexes_size: int = og_indeces.shape[0]\n",
" # shuffle the indeces by masking with a random permutation:\n",
" shuffled_indeces: torch.Tensor = og_indeces[torch.randperm(indexes_size)]\n",
"\n",
" # generate the boolean masks based on slices of the randomly permutated list of indeces\n",
" n_edges:int = og_mask.shape[0]\n",
" train_index: int = int(indexes_size*(1 - validation_size - test_size))\n",
" test_index: int = train_index + int(indexes_size*test_size)\n",
" train_mask: torch.Tensor = torch_geometric.utils.index_to_mask(shuffled_indeces[:train_index], size=n_edges)\n",
" test_mask: torch.Tensor = torch_geometric.utils.index_to_mask(shuffled_indeces[train_index:test_index], size=n_edges)\n",
" validation_mask: torch.Tensor = torch_geometric.utils.index_to_mask(shuffled_indeces[test_index:], size=n_edges)\n",
"\n",
" return train_mask, test_mask, validation_mask\n",
"\n",
"def setTaskDataset(name: int, labels: list[int])-> dict[str, np.ndarray]:\n",
" dataset_copy: Data = dataset.clone().to(device)\n",
" nodes_mask = np.isin(dataset_copy.y.numpy(), labels)\n",
"\n",
" # Mask to indeces all nodes\n",
" all_nodes_list: np.ndarray = torch.nonzero(torch.from_numpy(nodes_mask).to(device), as_tuple=False).view(-1).cpu().numpy()\n",
" all_nodes = {val: True for val in all_nodes_list}\n",
" vectorized_check_presence = np.vectorize(lambda x: x in all_nodes)\n",
" edges_mask_1 = torch.from_numpy(vectorized_check_presence(dataset_copy.edge_index.cpu().numpy()[0,:])).to(device)\n",
" edges_mask_2 = torch.from_numpy(vectorized_check_presence(dataset_copy.edge_index.cpu().numpy()[1,:])).to(device)\n",
" edges_mask = torch.logical_and(edges_mask_1, edges_mask_2)\n",
"\n",
" # Missing edges mask that we are going to use later when meshing thogether multiple tasks\n",
" missing_edges_mask : torch.Tensor = torch.logical_xor(edges_mask_1, edges_mask_2).to(device)\n",
"\n",
" # Useful Info\n",
" new_edge_index: torch.Tensor = dataset_copy.edge_index[:,edges_mask]\n",
" print(f\"\"\"\n",
" - Task n.{name} with labels {labels} number of nodes: {collections.Counter(nodes_mask.tolist())[True]}:\")\n",
" missing_edges_mask: {torch.sum(missing_edges_mask)}/{missing_edges_mask.shape[0]}\n",
" {new_edge_index.shape[1]}/{dataset_copy.edge_index.shape[1]}\n",
" \"\"\")\n",
"\n",
"\n",
" # Generate the train/test/validation masks\n",
" train_mask, test_mask, validation_mask = split_mask(edges_mask)\n",
" t_nodes_mask = torch.from_numpy(nodes_mask).to(torch.bool)\n",
"\n",
" max_size = dataset.edge_index.shape[1]\n",
" return_map: map[str, torch.Tensor] = {\n",
" \"nodes\": t_nodes_mask.to(device),\n",
" \"train\": train_mask,\n",
" \"validation\": test_mask,\n",
" \"test\": validation_mask,\n",
" \"missing\": missing_edges_mask\n",
" }\n",
"\n",
" # Just pickle it for later (as a .pkl)\n",
" with open(f\"{save_dir}/{name}_{'_'.join(str(val) for val in labels)}.pkl\", \"wb\") as filee:\n",
" pickle.dump(return_map, filee)\n",
"\n",
" return return_map\n",
"\n",
"# Get the data,\n",
"# also lets store it on Google Drive so we dont have to redownload it every time\n",
"dataset = Amazon(save_dir, \"computers\", transform=T.ToUndirected(merge=True)).to(device)[0]\n",
"\n",
"# Info\n",
"print(dataset.x.shape, dataset.edge_index.shape, dataset.y.shape)\n",
"print(\"\\nClass sizes :\")\n",
"print(collections.Counter(dataset.y.numpy()))\n",
"print()\n",
"\n",
"if want_to_generate:\n",
" # First mercilessy remove all the previous .pkl files\n",
" for f in listdir(save_dir):\n",
" if isfile(join(save_dir, f)) and f.endswith(\".pkl\"):\n",
" os.remove(os.path.join(save_dir, f))\n",
"\n",
" # Then define a tasks with 2 different nodes classes (Class Incremental setting)\n",
" # - Generate the map:\n",
" classes = np.unique(dataset.y.numpy()).tolist()\n",
" random.shuffle(classes)\n",
" classes_per_task = len(classes) // n_tasks\n",
" previous, tasks = 0, {}\n",
" for i,new_index in enumerate(range(classes_per_task, len(classes)+2, classes_per_task)):\n",
" tasks[i] = classes[previous:new_index]\n",
" previous = new_index\n",
"\n",
" # Generate a dataset for each task\n",
" task_managers = [TaskManager(str(index), labels, dataset.to(device), setTaskDataset(index+1, labels)) for index, labels in tasks.items()]\n"
]
},
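{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check (a sketch on synthetic data, not part of the pipeline): `split_mask` should partition the selected entries into roughly 85% train, 10% test and 5% validation, with no overlap between the three masks."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Synthetic edge mask: 1000 entries, ~60% selected (illustration only)\n",
"toy_mask = torch.rand(1000) < 0.6\n",
"tr, te, va = split_mask(toy_mask)\n",
"\n",
"n = int(toy_mask.sum())\n",
"print(f\"selected={n} train={int(tr.sum())} test={int(te.sum())} validation={int(va.sum())}\")\n",
"# The three splits partition the selected entries exactly...\n",
"assert int(tr.sum()) + int(te.sum()) + int(va.sum()) == n\n",
"# ...and never overlap\n",
"assert not torch.logical_and(tr, te).any()\n",
"assert not torch.logical_and(tr, va).any()\n",
"assert not torch.logical_and(te, va).any()\n"
]
},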
{
"cell_type": "markdown",
"metadata": {
"id": "YdbxH7SsT-iB"
},
"source": [
"NB: we are losing a lot of information 3/4 edges of the graph lets keep the original dataset edge index and add them gradually"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "m50WslrcuT1d"
},
"source": [
"While we are at it lets just generate a small json for the class-pair splits (and its reader function), so we can reproduce the experiments:\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "KxM6lAzrpLGN"
},
"outputs": [],
"source": [
"if want_to_generate:\n",
" os.remove(f\"{save_dir}/class_splits.json\")\n",
" with open(f\"{save_dir}/class_splits.json\", \"w\") as filee:\n",
" filee.write(json.dumps(tasks))\n",
"\n",
"def read_task_file()-> dict[str, list[int]]:\n",
" filee = open(f\"{save_dir}/class_splits.json\", \"r\")\n",
" tasks = json.loads(filee.read())\n",
" filee.close()\n",
" return tasks"
]
},
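{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick usage check (illustrative only): the splits file can be read back with `read_task_file`. Note that JSON serialization turns the integer task keys into strings, which is why the reader is annotated as `dict[str, list[int]]`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Round-trip check of the class-pair splits (pprint was imported in the setup cell)\n",
"pprint(read_task_file())\n",
"# e.g. {'0': [1, 6], '1': [2, 0], ...} for the run shown above\n"
]
},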
{
"cell_type": "markdown",
"metadata": {
"id": "7DluM7JNaWlS"
},
"source": [
"### Load Old Task Split"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "am8M1xLqZxvD"
},
"outputs": [],
"source": [
"# Load each tasks from serial if we haven't just created them\n",
"def load_task_map(filename: str)->dict[str, np.ndarray]:\n",
" task_file = open(filename, \"rb\") # NB: b stands for binary mode (instead of text mode) for pickle\n",
" task_dataset: Data= pickle.load(task_file)\n",
" task_file.close()\n",
" return task_dataset\n",
"\n",
"if not want_to_generate:\n",
" tasks = read_task_file()\n",
" # list of tuples tuple[task,labels]\n",
" parsedFilesParams = list(map(lambda x: tuple([x.split(\".\")[0].split(\"_\")[0], list(map(int,x.split(\".\")[0].split(\"_\")[1:]))]) , [f for f in listdir(save_dir) if isfile(join(save_dir, f)) and f.endswith(\".pkl\")]))\n",
" # list of file paths: str\n",
" parsedFilesPaths = [f\"{save_dir}/{f}\" for f in listdir(save_dir) if isfile(join(save_dir, f)) and f.endswith(\".pkl\")]\n",
" # load the file contents into the Task_managers\n",
" task_managers = [TaskManager(str(vals[0]), list(vals[1]), dataset, load_task_map(path)) for vals, path in zip(parsedFilesParams, parsedFilesPaths)]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r1DWP8s_zp_E"
},
"source": [
"## Models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "OfFnAZb9-C_v"
},
"source": [
"### GCN (Graph Convolutional Network):"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "7px9zyGq-CBP"
},
"outputs": [],
"source": [
"class test_SAGE(nn.Module):\n",
" def __init__(self, in_channels: int, out_channels: int, hidden_dim: int)-> None:\n",
" super().__init__()\n",
"\n",
" self.layer1 = torch_geometric.nn.SAGEConv(in_channels, hidden_dim)\n",
" self.layer2 = torch_geometric.nn.SAGEConv(hidden_dim, hidden_dim)\n",
" self.layer3 = torch_geometric.nn.SAGEConv(hidden_dim, hidden_dim)\n",
" self.classifier = Linear(hidden_dim, out_channels)\n",
"\n",
" self.bn1 = nn.BatchNorm1d(hidden_dim)\n",
" self.bn2 = nn.BatchNorm1d(hidden_dim)\n",
" self.bn3 = nn.BatchNorm1d(hidden_dim)\n",
"\n",
" self.activation = torch.nn.ReLU(inplace=True)\n",
" self.pool = torch_geometric.nn.global_add_pool\n",
"\n",
" def forward(self, batched):\n",
" x ,edge_index, batch = batched['x'],batched['edge_index'],batched['n_id']\n",
"\n",
" # 1. Layer 1\n",
" x = self.layer1(x, edge_index)\n",
" x = self.bn1(x)\n",
" x = self.activation(x)\n",
" # 2. Layer 2\n",
" x = self.layer2(x, edge_index)\n",
" x = self.bn2(x)\n",
" x = self.activation(x)\n",
" # 3. Layer 3\n",
" x = self.layer3(x, edge_index)\n",
" x = self.bn3(x)\n",
" x = self.activation(x)\n",
" # 3. global pooling\n",
" # x = self.pool(x,batch) # [batch_size,hiddin_dim]\n",
" # 4. classifier\n",
" x = self.classifier(x)\n",
"\n",
" return F.sigmoid(x)"
]
},
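{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small shape sanity check for the model (a sketch with random data; the sizes are made up). It mimics the dict-style access `batched['x']` / `batched['edge_index']` / `batched['n_id']` that `forward` relies on, so a plain dict is enough here."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Fabricated mini-batch with the keys forward() reads (illustration only)\n",
"_in, _out, _hidden, _n = 16, 4, 32, 8\n",
"toy_batch = {\n",
"    'x': torch.randn(_n, _in),\n",
"    'edge_index': torch.randint(0, _n, (2, 20)),\n",
"    'n_id': torch.arange(_n),\n",
"}\n",
"toy_model = test_SAGE(_in, _out, _hidden)\n",
"toy_model.eval()  # keep batch-norm running stats untouched during the check\n",
"with torch.no_grad():\n",
"    out = toy_model(toy_batch)\n",
"print(out.shape)  # expected: torch.Size([8, 4])\n"
]
},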
{
"cell_type": "markdown",
"metadata": {
"id": "oEvMHa94_I6j"
},
"source": [
"### GAT (Graph Attention Network):"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nrf9k0VdwHru"
},
"source": [
"### Creating the model:\n",
"we create the model and assign it to all task_managers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GDEZmgU6Lpw1",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "d789b4d2-ae87-4830-e96d-73e00572719b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"UsingDevice: cpu\n",
"\n",
"Task: 0 Labels: 1 6\n",
"mid\n"
]
}
],
"source": [
"# Parameters we need when we create the neural net\n",
"in_channels = dataset.x.shape[1]\n",
"out_channels = n_tasks*2 # NOT SURE\n",
"\n",
"# Everything we need\n",
"model = test_SAGE( in_channels, out_channels, 256)\n",
"optimizer = torch.optim.Adam(model.parameters(), lr=0.001)\n",
"scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')\n",
"loss_fn = nn.CrossEntropyLoss()\n",
"\n",
"# Set the model for each task manager\n",
"for tm in task_managers: tm.setModel(model)\n",
"\n",
"print(f\"UsingDevice: {device}\")\n",
"for tm in task_managers: tm.setDevice(device)\n",
"# Move the model to GPU if available\n",
"model.to(device)\n",
"\n",
"# Test and Save Performances all the results in a task specific file and plot the confusion matrix\n",
"# create the dir if not present\n",
"result_dir = os.path.join(base_dir, \"results\")\n",
"! mkdir -p { result_dir} && cd {result_dir} && rm -rf $( ls )\n",
"\n",
"# Train-Test loop\n",
"for i, tm in enumerate(task_managers):\n",
" print(f\"\\nTask: {tm.task} Labels: {' '.join(str(lbl) for lbl in tm.labels)}\")\n",
" # Train on new task and use the previous taskManagers memories\n",
" tm.train(optimizer, loss_fn, scheduler, task_managers[:i])\n",
"\n",
" filename = os.path.join(result_dir, f\"task_{i+1}.performance.json\")\n",
" with open(filename, \"w+\") as filee: # Create the file\n",
" filee.write(json.dumps([]))\n",
"\n",
" # Print statistics for each task\n",
" cms: list[np.ndarray] = []\n",
" for cm in [ttm.test(filename) for ttm in task_managers[:i+1]]:\n",
" cms.append(cm)\n",
"\n",
" # Plot the confusion matrices\n",
" if i>0:\n",
" fig, axs = plt.subplots(1, i+1, figsize=(15, 15))\n",
" fig.tight_layout(pad=3.0)\n",
" for ttm, cm, ax in zip(task_managers[:i+1], cms, axs):\n",
" ax.title.set_text(f\"Task: {ttm.task} Labels: {' '.join([str(lbl) for lbl in ttm.labels])}\")\n",
" disp = ConfusionMatrixDisplay(confusion_matrix=cm)\n",
" disp.plot(cmap=plt.cm.Blues, ax=ax)\n",
" ax.grid(False)\n",
"\n",
" else: # just one\n",
" disp = ConfusionMatrixDisplay(confusion_matrix=cms[0])\n",
" disp.plot(cmap=plt.cm.Blues)\n",
" plt.gca().images[-1].colorbar.remove()\n",
" plt.title(f\"Task: {tm.task} Labels: {' '.join([str(lbl) for lbl in tm.labels])}\")\n",
"\n",
" fig.tight_layout(pad=3.0)\n",
" plt.show()\n"
]
}
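,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, a minimal sketch (assuming the `task_{i}.performance.json` naming used above) for reading back the per-task performance files that `TaskManager.test` appends to:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Each file holds a JSON list with one classification report per evaluation\n",
"for i in range(1, len(task_managers) + 1):\n",
"    path = os.path.join(result_dir, f\"task_{i}.performance.json\")\n",
"    if not os.path.isfile(path):\n",
"        continue\n",
"    with open(path, \"r\") as filee:\n",
"        reports = json.loads(filee.read())\n",
"    print(f\"task_{i}: {len(reports)} evaluation(s) recorded\")\n"
]
}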
],
"metadata": {
"colab": {
"provenance": [],
"authorship_tag": "ABX9TyPlJoqy+HB4iM5gwHj7xn+F",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}