Created
March 15, 2020 19:30
-
-
Save dkohlsdorf/3dc5c4d139cd706eb7fdb89f7a5c0a3f to your computer and use it in GitHub Desktop.
traverse agglomerative clustering sklearn
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 79, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pickle as pkl\n", | |
"import numpy as np\n", | |
"import itertools\n", | |
"from collections import namedtuple" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# From the Documentation\n", | |
"```\n", | |
"children_ : array-like of shape (n_samples-1, 2)\n", | |
"The children of each non-leaf node. \n", | |
"Values less than n_samples correspond to leaves of the tree which are the original samples. \n", | |
"A node i greater than or equal to n_samples is a non-leaf node and has children children_[i - n_samples].\n", | |
"Alternatively at the i-th iteration, children[i][0] and children[i][1] are merged to form node n_samples + i\n", | |
"```" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 34, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/daniel.kohlsdorf/anaconda3/lib/python3.7/site-packages/sklearn/base.py:318: UserWarning: Trying to unpickle estimator AgglomerativeClustering from version 0.21.3 when using version 0.22.2.post1. This might lead to breaking code or invalid results. Use at your own risk.\n", | |
" UserWarning)\n" | |
] | |
} | |
], | |
"source": [ | |
"# SECURITY: unpickling executes arbitrary code -- only load model files from a trusted source.\n", | |
"# The context manager closes the file handle even if unpickling fails.\n", | |
"with open('models/v2_lstm_v5/agg.pkl', 'rb') as model_file:\n", | |
"    agg = pkl.load(model_file)\n", | |
"# children_[i] holds the two nodes merged at step i (see the documentation excerpt above).\n", | |
"tree = agg.children_\n", | |
"n_samples = agg.n_leaves_" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 61, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"13897\n", | |
"11274\n" | |
] | |
} | |
], | |
"source": [ | |
"# Row k of `tree` describes the merge that creates node id n_samples + k,\n", | |
"# so numbering the rows from n_samples yields the id of every non-leaf node.\n", | |
"trees = {\n", | |
"    node_id: {'left': pair[0], 'right': pair[1]}\n", | |
"    for node_id, pair in enumerate(tree, start=n_samples)\n", | |
"}\n", | |
"print(len(trees))\n", | |
"print(agg.n_clusters_)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 108, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class Node(namedtuple(\"Node\", \"i l r\")):\n", | |
"    \"\"\"Internal (non-leaf) node of the merge tree.\n", | |
"\n", | |
"    Fields:\n", | |
"        i: id of this node.\n", | |
"        l: left child, either a nested Node or a plain leaf id.\n", | |
"        r: right child, either a nested Node or a plain leaf id.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def depth(self):\n", | |
"        \"\"\"Return the number of edges on the longest path down to a nested Node.\n", | |
"\n", | |
"        Children that are not Node instances (leaf ids) do not add a level,\n", | |
"        so a node whose two children are leaves has depth 0.\n", | |
"\n", | |
"        The walk uses an explicit stack instead of recursion: a merge tree\n", | |
"        can degenerate into a chain that is deeper than Python's recursion\n", | |
"        limit, which would make a recursive version raise RecursionError.\n", | |
"        \"\"\"\n", | |
"        deepest = 0\n", | |
"        pending = [(self, 0)]  # (node, number of edges from self to node)\n", | |
"        while pending:\n", | |
"            node, level = pending.pop()\n", | |
"            if level > deepest:\n", | |
"                deepest = level\n", | |
"            for child in (node.l, node.r):\n", | |
"                if isinstance(child, Node):\n", | |
"                    pending.append((child, level + 1))\n", | |
"        return deepest\n", | |
" \n", | |
" \n", | |
"def iterate_tree(node, trees, n_leafs, closed):\n", | |
"    \"\"\"Build the nested Node structure rooted at `node`.\n", | |
"\n", | |
"    Args:\n", | |
"        node: id of the root; must be a key of `trees`.\n", | |
"        trees: dict mapping a node id to {'left': child_id, 'right': child_id}.\n", | |
"        n_leafs: child ids below this value are leaves and are kept as plain\n", | |
"            ids; ids greater than or equal to it are looked up in `trees`.\n", | |
"        closed: set that receives the id of every node visited (mutated in place).\n", | |
"\n", | |
"    Returns:\n", | |
"        Node(node, left, right), where each child is a nested Node or a leaf id.\n", | |
"\n", | |
"    Raises:\n", | |
"        KeyError: if a visited non-leaf id is missing from `trees`.\n", | |
"\n", | |
"    The tree is built with an iterative post-order walk rather than recursion,\n", | |
"    so very deep (chain-like) merge trees do not hit Python's recursion limit.\n", | |
"    \"\"\"\n", | |
"    built = {}  # node id -> finished Node for that subtree\n", | |
"    stack = [node]\n", | |
"    while stack:\n", | |
"        current = stack[-1]\n", | |
"        if current in built:\n", | |
"            # Pushed more than once (shared child); the subtree already exists.\n", | |
"            stack.pop()\n", | |
"            continue\n", | |
"        closed.add(current)\n", | |
"        left = trees[current]['left']\n", | |
"        right = trees[current]['right']\n", | |
"        waiting = [child for child in (left, right)\n", | |
"                   if child >= n_leafs and child not in built]\n", | |
"        if waiting:\n", | |
"            # Children must be finished first; revisit `current` afterwards.\n", | |
"            stack.extend(waiting)\n", | |
"            continue\n", | |
"        stack.pop()\n", | |
"        built[current] = Node(\n", | |
"            current,\n", | |
"            built[left] if left >= n_leafs else left,\n", | |
"            built[right] if right >= n_leafs else right,\n", | |
"        )\n", | |
"    return built[node]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 109, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
".\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"444" | |
] | |
}, | |
"execution_count": 109, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Repeatedly take the largest node id that has not been visited yet and\n", | |
"# build the subtree below it; `closed` collects every id reached so far.\n", | |
"closed = set()\n", | |
"openl = list(trees)\n", | |
"clusters = []\n", | |
"while openl:\n", | |
"    print('.')\n", | |
"    max_node = max(openl)\n", | |
"    cluster = iterate_tree(max_node, trees, n_samples, closed)\n", | |
"    clusters.append(cluster)\n", | |
"    openl = [node_id for node_id in trees if node_id not in closed]\n", | |
"# Sanity check: as many ids were visited as there are merge nodes.\n", | |
"assert len(closed) == len(trees)\n", | |
"clusters[0].depth()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment