Created
May 30, 2020 01:44
-
-
Save refraction-ray/dc22288a0d9e22bb263e59487ad8f5ea to your computer and use it in GitHub Desktop.
einsum_symbol.ipynb
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "einsum_symbol.ipynb", | |
"provenance": [], | |
"collapsed_sections": [], | |
"authorship_tag": "ABX9TyNRnzWO/H0U0kPQ3nLBDKsC", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/refraction-ray/dc22288a0d9e22bb263e59487ad8f5ea/einsum_symbol.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "figBFYiINilo", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 154 | |
}, | |
"outputId": "7a4cbe58-b951-4c93-f135-1501f070b8bc" | |
}, | |
"source": [ | |
"!pip install qop\n", | |
"!pip install opt_einsum" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Collecting qop\n", | |
" Downloading https://files.pythonhosted.org/packages/d3/56/8182fa684d0dc249ced5c3a219f9ae4b34e898dc8dca504ccba3cc848111/qop-0.0.2-py3-none-any.whl\n", | |
"Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from qop) (1.18.4)\n", | |
"Installing collected packages: qop\n", | |
"Successfully installed qop-0.0.2\n", | |
"Requirement already satisfied: opt_einsum in /usr/local/lib/python3.6/dist-packages (3.2.1)\n", | |
"Requirement already satisfied: numpy>=1.7 in /usr/local/lib/python3.6/dist-packages (from opt_einsum) (1.18.4)\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "aoJP-IYTNnmk", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"from qop.symbol import Symbols\n", | |
"from qop.base import simplify\n", | |
"from opt_einsum import contract\n", | |
"import numpy as np" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Xi5utt9yNsF6", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"a,b,c = Symbols(\"abc\")" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "x89q05U_NwWJ", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "0bdcf524-8ad4-4aa7-b92c-1e2505265804" | |
}, | |
"source": [ | |
"# a, b, c are custom symbolic objects implementing __add__, __mul__ and __pow__\n", | |
"((a*a+3*b-c**2)**2).simplify()" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"1.0*a*a*a*a + 6.0*a*a*b + -2.0*a*a*c*c + 9.0*b*b + -6.0*b*c*c + 1.0*c*c*c*c" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 4 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "dmabx2_6Nw7S", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 50 | |
}, | |
"outputId": "db7a05a4-9bea-4602-b329-aacfcc1790fd" | |
}, | |
"source": [ | |
"# some NumPy utilities work fine with symbols (they end up in dtype=object arrays)\n", | |
"simplify(a*np.ones([2,2]))" | |
], | |
"execution_count": 7, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[1.0*a, 1.0*a],\n", | |
" [1.0*a, 1.0*a]], dtype=object)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 7 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "vvlADgQPNzIy", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 50 | |
}, | |
"outputId": "87cf4641-4eab-4cab-c4ca-d88a32a3db7c" | |
}, | |
"source": [ | |
"simplify(np.array([[a,b],[b,a]])@np.array([[1,2],[c,c]]))" | |
], | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[1.0*a + 1.0*b*c, 2.0*a + 1.0*b*c],\n", | |
" [1.0*b + 1.0*a*c, 2.0*b + 1.0*a*c]], dtype=object)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 8 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "Iz3rbMF_N4Xr", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 50 | |
}, | |
"outputId": "9dd11605-93d0-4dc8-e90a-67ac3de16b0b" | |
}, | |
"source": [ | |
"## but einsum does not work as well: it fails whenever opt_einsum dispatches a step to backend.einsum instead of tensordot\n", | |
"## successful case\n", | |
"simplify(contract(\"ij,jk->ik\", np.array([[a,b],[1.1,2]]), np.array([[c+1,c],[c,0.]])))" | |
], | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"array([[1.0*a*c + 1.0*a + 1.0*b*c, 1.0*a*c],\n", | |
" [3.1*c + 1.1*I, 1.1*c]], dtype=object)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 10 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "3Q-eP7CzOM3u", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 391 | |
}, | |
"outputId": "4b6bc573-efc2-47a1-8dd7-365586460484" | |
}, | |
"source": [ | |
"# the above example succeeds because it only calls tensordot, which handles symbolic objects at the Python level\n", | |
"# if no BLAS-style (tensordot) contraction is involved, opt_einsum falls back to numpy.einsum, which calls c_einsum; that is implemented purely in C and cannot handle Python duck-typed objects, hence the TypeError below\n", | |
"contract(\"ij->\", np.array([[a,a]]))" | |
], | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "error", | |
"ename": "TypeError", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-11-8d2ca37c8f33>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# the above example succeed because it only call tensordot which supports symbolic objects in python level\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# and if there is no blas contraction in the middle, the fallback call to c_einsum is invoked which is purely in C and cannot handle python duck types\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mcontract\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"ij->\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/opt_einsum/contract.py\u001b[0m in \u001b[0;36mcontract\u001b[0;34m(*operands, **kwargs)\u001b[0m\n\u001b[1;32m 481\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mContractExpression\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontraction_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconstants_dict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0meinsum_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 482\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 483\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_core_contract\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moperands\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcontraction_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0meinsum_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 484\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 485\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/opt_einsum/contract.py\u001b[0m in \u001b[0;36m_core_contract\u001b[0;34m(operands, contraction_list, backend, evaluate_constants, **einsum_kwargs)\u001b[0m\n\u001b[1;32m 565\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 566\u001b[0m \u001b[0;31m# Do the contraction\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 567\u001b[0;31m \u001b[0mnew_view\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_einsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meinsum_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtmp_operands\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0meinsum_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 568\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 569\u001b[0m \u001b[0;31m# Append new items and dereference what we can\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/opt_einsum/sharing.py\u001b[0m in \u001b[0;36mcached_einsum\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcached_einsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mcurrently_sharing\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 151\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0meinsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 152\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[0;31m# hash modulo commutativity by computing a canonical ordering and names\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/opt_einsum/contract.py\u001b[0m in \u001b[0;36m_einsum\u001b[0;34m(*operands, **kwargs)\u001b[0m\n\u001b[1;32m 338\u001b[0m \u001b[0meinsum_str\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert_to_valid_einsum_chars\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meinsum_str\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 340\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meinsum_str\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0moperands\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 341\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 342\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m<__array_function__ internals>\u001b[0m in \u001b[0;36meinsum\u001b[0;34m(*args, **kwargs)\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/numpy/core/einsumfunc.py\u001b[0m in \u001b[0;36meinsum\u001b[0;34m(*operands, **kwargs)\u001b[0m\n\u001b[1;32m 1354\u001b[0m \u001b[0;31m# If no optimization, run pure einsum\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1355\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0moptimize_arg\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1356\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mc_einsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0moperands\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1357\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1358\u001b[0m \u001b[0mvalid_einsum_kwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'out'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'dtype'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'order'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'casting'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mTypeError\u001b[0m: invalid data type for einsum" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "37gbOpwIPM89", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment