Skip to content

Instantly share code, notes, and snippets.

@dkohlsdorf
Created March 15, 2020 19:30
Show Gist options
  • Save dkohlsdorf/3dc5c4d139cd706eb7fdb89f7a5c0a3f to your computer and use it in GitHub Desktop.
Save dkohlsdorf/3dc5c4d139cd706eb7fdb89f7a5c0a3f to your computer and use it in GitHub Desktop.
traverse agglomerative clustering sklearn
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"import pickle as pkl\n",
"import numpy as np\n",
"import itertools\n",
"from collections import namedtuple"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# From the Documentation\n",
"```\n",
"children_ : array-like of shape (n_samples-1, 2)\n",
"The children of each non-leaf node. \n",
"Values less than n_samples correspond to leaves of the tree which are the original samples. \n",
"A node i greater than or equal to n_samples is a non-leaf node and has children children_[i - n_samples].\n",
"Alternatively at the i-th iteration, children[i][0] and children[i][1] are merged to form node n_samples + i\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/daniel.kohlsdorf/anaconda3/lib/python3.7/site-packages/sklearn/base.py:318: UserWarning: Trying to unpickle estimator AgglomerativeClustering from version 0.21.3 when using version 0.22.2.post1. This might lead to breaking code or invalid results. Use at your own risk.\n",
" UserWarning)\n"
]
}
],
"source": [
"# Load the fitted clustering model and keep its merge table and leaf count.\n",
"# NOTE(review): pickle.load can execute arbitrary code -- only load model files you trust.\n",
"# The context manager closes the file handle, which the bare open() call leaked.\n",
"with open('models/v2_lstm_v5/agg.pkl', 'rb') as model_file:\n",
"    agg = pkl.load(model_file)\n",
"tree = agg.children_\n",
"n_samples = agg.n_leaves_"
]
},
{
"cell_type": "code",
"execution_count": 61,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"13897\n",
"11274\n"
]
}
],
"source": [
"# Row k of the merge table creates internal node n_samples + k (see the documentation quoted above).\n",
"trees = {\n",
"    node_id: {'left': pair[0], 'right': pair[1]}\n",
"    for node_id, pair in zip(itertools.count(n_samples), tree)\n",
"}\n",
"print(len(trees))\n",
"print(agg.n_clusters_)"
]
},
{
"cell_type": "code",
"execution_count": 108,
"metadata": {},
"outputs": [],
"source": [
"class Node(namedtuple(\"Node\", \"i l r\")):\n",
"    \"\"\"Node of a binary merge tree.\n",
"\n",
"    Fields:\n",
"        i: id of this node.\n",
"        l, r: left and right child; either another Node or any non-Node\n",
"            value, which is treated as a leaf.\n",
"    \"\"\"\n",
"\n",
"    def depth(self):\n",
"        \"\"\"Return the number of edges from this node to its deepest Node descendant.\n",
"\n",
"        Non-Node children add nothing, so a node with two leaf children has depth 0.\n",
"        The traversal uses an explicit stack instead of recursion because a\n",
"        chain-shaped merge tree can be deeper than Python's recursion limit.\n",
"        \"\"\"\n",
"        deepest = 0\n",
"        stack = [(self, 0)]\n",
"        while stack:\n",
"            node, level = stack.pop()\n",
"            deepest = max(deepest, level)\n",
"            for child in (node.l, node.r):\n",
"                if isinstance(child, Node):\n",
"                    stack.append((child, level + 1))\n",
"        return deepest\n",
" \n",
" \n",
"def iterate_tree(node, trees, n_leafs, closed):\n",
"    \"\"\"Build the nested Node tree rooted at internal node id `node`.\n",
"\n",
"    Args:\n",
"        node: id of the root; must be a key of `trees`.\n",
"        trees: dict mapping internal node id -> {'left': id, 'right': id}.\n",
"        n_leafs: ids below this value are leaves and are stored as plain ids.\n",
"        closed: set that receives the id of every internal node visited.\n",
"\n",
"    Returns:\n",
"        Node(node, left, right) with internal children expanded into Nodes.\n",
"\n",
"    Raises:\n",
"        KeyError: if an internal id is missing from `trees`.\n",
"    \"\"\"\n",
"    # Post-order walk with an explicit stack: the recursive version raised\n",
"    # RecursionError on chain-shaped trees deeper than the recursion limit.\n",
"    built = {}  # internal id -> finished Node\n",
"    stack = [node]\n",
"    while stack:\n",
"        current = stack[-1]\n",
"        closed.add(current)\n",
"        left = trees[current]['left']\n",
"        right = trees[current]['right']\n",
"        pending = [\n",
"            child for child in (left, right)\n",
"            if child >= n_leafs and child not in built\n",
"        ]\n",
"        if pending:\n",
"            # Children must exist before the immutable parent tuple can be created.\n",
"            stack.extend(pending)\n",
"            continue\n",
"        stack.pop()\n",
"        built[current] = Node(\n",
"            current,\n",
"            built[left] if left >= n_leafs else left,\n",
"            built[right] if right >= n_leafs else right,\n",
"        )\n",
"    return built[node]"
]
},
{
"cell_type": "code",
"execution_count": 109,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
".\n"
]
},
{
"data": {
"text/plain": [
"444"
]
},
"execution_count": 109,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Extract one nested tree per root until every internal node has been visited.\n",
"closed = set()\n",
"clusters = []\n",
"openl = list(trees)\n",
"while openl:\n",
"    print('.')\n",
"    # Merged nodes get larger ids than their children, so the largest\n",
"    # unvisited id is the root of a tree that has not been extracted yet.\n",
"    root_id = max(openl)\n",
"    clusters.append(iterate_tree(root_id, trees, n_samples, closed))\n",
"    openl = [node_id for node_id in trees if node_id not in closed]\n",
"assert len(closed) == len(trees)\n",
"clusters[0].depth()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment