Created
March 15, 2020 19:28
-
-
Save dkohlsdorf/16d044e85c385401dbd7e5a8326a708a to your computer and use it in GitHub Desktop.
Recursive Auto Encoder With Nodes
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import tensorflow as tf\n", | |
"import numpy as np\n", | |
"import os\n", | |
"import pandas as pd\n", | |
"import pickle as pkl" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 95, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def merge_encoder(n_in):\n", | |
" a = tf.keras.layers.Input(n_in)\n", | |
" b = tf.keras.layers.Input(n_in)\n", | |
" c = tf.keras.layers.Concatenate()([a,b])\n", | |
" h = tf.keras.layers.Dense(n_in, activation='relu')(c)\n", | |
" o = tf.keras.layers.Dense(n_in * 2)(h)\n", | |
" merge = tf.keras.models.Model(inputs=[a, b], outputs=[h, c, o])\n", | |
" merge.summary()\n", | |
" return merge\n", | |
"\n", | |
"class Node:\n", | |
" \n", | |
" def __init__(self, i, embedding, score, payload, l = None, r = None):\n", | |
" self.i = i\n", | |
" self.score = score\n", | |
" self.embedding = embedding\n", | |
" self.left = l\n", | |
" self.right = r\n", | |
" self.payload = payload\n", | |
" \n", | |
" def print(self, offset=\"\"):\n", | |
" print(\"{} {} {} {}\".format(offset, self.i, self.score, np.mean(self.embeding)))\n", | |
" if self.left is not None and self.right is not None:\n", | |
" self.left.print(offset + \"\\t\")\n", | |
" self.right.print(offset + \"\\t\")\n", | |
"\n", | |
" def merge(self, other, merger):\n", | |
" merged = merger([self.embedding, other.embedding])\n", | |
" h = merged[0]\n", | |
" c = merged[1]\n", | |
" y = merged[2]\n", | |
" #score = tf.nn.l2_loss(y - c) + self.score + other.score\n", | |
" score = tf.nn.softmax_cross_entropy_with_logits(c, y) + self.score + other.score\n", | |
" return Node(-1, h, score, self, other)\n", | |
"\n", | |
"def ts2leafs(df):\n", | |
" sequence = []\n", | |
" for i, row in df.iterrows():\n", | |
" node = Node(i, row['token'], tf.constant(0.0), row)\n", | |
" sequence.append(node)\n", | |
" return sequence\n", | |
"\n", | |
"def merge(x, m):\n", | |
" while len(x) > 1: \n", | |
" min_loss = float('inf')\n", | |
" min_node = None\n", | |
" min_i = 0\n", | |
" min_j = 0\n", | |
" for i in range(len(x)):\n", | |
" for j in range(len(x)):\n", | |
" if i < j:\n", | |
" node = x[i].merge(x[j], m)\n", | |
" if node.score < min_loss:\n", | |
" min_node = node\n", | |
" min_loss = node.score\n", | |
" min_i = i\n", | |
" min_j = j\n", | |
" print(\"Merge: {} {}\".format(min_i, min_j))\n", | |
" x[min_i] = min_node\n", | |
" x = [x[idx] for idx in range(0, len(x)) if idx != min_j]\n", | |
" return x[0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Merge: 10 65\n", | |
"Merge: 10 14\n", | |
"Merge: 9 85\n", | |
"Merge: 76 101\n", | |
"Merge: 6 95\n", | |
"Merge: 10 85\n", | |
"Merge: 19 31\n", | |
"Merge: 37 54\n", | |
"Merge: 13 52\n", | |
"Merge: 4 35\n", | |
"Merge: 47 73\n", | |
"Merge: 25 53\n", | |
"Merge: 2 93\n", | |
"Merge: 16 96\n", | |
"Merge: 3 73\n", | |
"Merge: 11 26\n", | |
"Merge: 14 89\n", | |
"Merge: 64 78\n", | |
"Merge: 43 50\n", | |
"Merge: 30 78\n", | |
"Merge: 5 53\n", | |
"Merge: 40 73\n", | |
"Merge: 21 72\n", | |
"Merge: 70 74\n", | |
"Merge: 6 29\n", | |
"Merge: 9 58\n", | |
"Merge: 37 55\n", | |
"Merge: 78 83\n", | |
"Merge: 35 38\n", | |
"Merge: 52 71\n", | |
"Merge: 24 33\n", | |
"Merge: 62 74\n", | |
"Merge: 28 30\n", | |
"Merge: 49 58\n", | |
"Merge: 38 43\n", | |
"Merge: 41 45\n", | |
"Merge: 66 74\n", | |
"Merge: 30 57\n", | |
"Merge: 55 62\n", | |
"Merge: 27 56\n", | |
"Merge: 7 17\n", | |
"Merge: 67 69\n", | |
"Merge: 1 22\n", | |
"Merge: 12 29\n", | |
"Merge: 40 42\n", | |
"Merge: 48 54\n", | |
"Merge: 60 61\n", | |
"Merge: 2 17\n", | |
"Merge: 18 55\n", | |
"Merge: 44 52\n", | |
"Merge: 46 60\n", | |
"Merge: 4 41\n", | |
"Merge: 22 23\n", | |
"Merge: 16 51\n", | |
"Merge: 39 43\n", | |
"Merge: 45 50\n", | |
"Merge: 8 29\n", | |
"Merge: 17 34\n", | |
"Merge: 20 49\n", | |
"Merge: 0 13\n", | |
"Merge: 10 34\n", | |
"Merge: 3 14\n", | |
"Merge: 11 29\n", | |
"Merge: 25 34\n", | |
"Merge: 27 34\n", | |
"Merge: 9 30\n", | |
"Merge: 13 17\n", | |
"Merge: 29 36\n", | |
"Merge: 5 25\n", | |
"Merge: 27 36\n", | |
"Merge: 18 36\n", | |
"Merge: 22 38\n", | |
"Merge: 21 26\n", | |
"Merge: 20 28\n", | |
"Merge: 6 29\n", | |
"Merge: 16 34\n", | |
"Merge: 12 33\n", | |
"Merge: 30 32\n", | |
"Merge: 7 32\n", | |
"Merge: 4 29\n", | |
"Merge: 1 2\n", | |
"Merge: 22 29\n", | |
"Merge: 7 9\n", | |
"Merge: 0 15\n", | |
"Merge: 16 25\n", | |
"Merge: 12 13\n", | |
"Merge: 20 21\n", | |
"Merge: 8 11\n", | |
"Merge: 2 9\n", | |
"Merge: 4 12\n", | |
"Merge: 18 19\n", | |
"Merge: 13 15\n", | |
"Merge: 5 14\n", | |
"Merge: 9 11\n", | |
"Merge: 6 16\n", | |
"Merge: 3 13\n", | |
"Merge: 1 7\n", | |
"Merge: 0 10\n", | |
"Merge: 9 11\n", | |
"Merge: 4 7\n", | |
"Merge: 2 10\n", | |
"Merge: 5 9\n", | |
"Merge: 6 7\n", | |
"Merge: 1 3\n", | |
"Merge: 0 6\n", | |
"Merge: 2 3\n", | |
"Merge: 3 4\n", | |
"Merge: 0 1\n", | |
"Merge: 1 2\n", | |
"Merge: 0 1\n", | |
"done merging: [1285.6406]\n", | |
"Epoch: 5\n", | |
"Merge: 10 65\n", | |
"Merge: 10 14\n", | |
"Merge: 9 85\n", | |
"Merge: 76 101\n", | |
"Merge: 6 95\n", | |
"Merge: 10 85\n", | |
"Merge: 19 31\n", | |
"Merge: 37 54\n", | |
"Merge: 13 52\n", | |
"Merge: 4 35\n", | |
"Merge: 47 73\n", | |
"Merge: 25 53\n", | |
"Merge: 2 93\n", | |
"Merge: 16 96\n", | |
"Merge: 3 73\n", | |
"Merge: 11 26\n", | |
"Merge: 14 89\n", | |
"Merge: 64 78\n", | |
"Merge: 43 50\n", | |
"Merge: 30 78\n", | |
"Merge: 5 53\n", | |
"Merge: 40 73\n", | |
"Merge: 6 29\n", | |
"Merge: 21 71\n", | |
"Merge: 69 73\n", | |
"Merge: 9 58\n", | |
"Merge: 37 55\n", | |
"Merge: 78 83\n", | |
"Merge: 35 38\n", | |
"Merge: 52 71\n", | |
"Merge: 24 33\n", | |
"Merge: 62 74\n", | |
"Merge: 28 30\n", | |
"Merge: 49 58\n", | |
"Merge: 38 43\n", | |
"Merge: 41 45\n", | |
"Merge: 66 74\n", | |
"Merge: 30 57\n", | |
"Merge: 55 62\n", | |
"Merge: 27 56\n", | |
"Merge: 7 17\n", | |
"Merge: 67 69\n", | |
"Merge: 1 22\n", | |
"Merge: 12 29\n", | |
"Merge: 40 42\n", | |
"Merge: 48 54\n", | |
"Merge: 60 61\n", | |
"Merge: 2 17\n", | |
"Merge: 18 55\n", | |
"Merge: 44 52\n", | |
"Merge: 46 60\n", | |
"Merge: 4 41\n", | |
"Merge: 22 23\n", | |
"Merge: 16 51\n", | |
"Merge: 39 43\n", | |
"Merge: 45 50\n", | |
"Merge: 8 29\n", | |
"Merge: 17 34\n", | |
"Merge: 20 49\n", | |
"Merge: 0 13\n", | |
"Merge: 10 34\n", | |
"Merge: 3 14\n", | |
"Merge: 11 29\n", | |
"Merge: 25 34\n", | |
"Merge: 27 34\n", | |
"Merge: 9 30\n", | |
"Merge: 13 17\n", | |
"Merge: 29 36\n", | |
"Merge: 5 25\n", | |
"Merge: 27 36\n", | |
"Merge: 18 36\n", | |
"Merge: 22 38\n", | |
"Merge: 21 26\n" | |
] | |
} | |
], | |
"source": [ | |
"df = pd.read_csv('models/v2_lstm_v5/seq_clustering_log_06281101C.csv', names=[\"start\", \"stop\", \"file\", \"cluster\"], header=None)\n", | |
"tokens = dict([(c, i) for i, c in enumerate(sorted(list(set(df['cluster']))))])\n", | |
"bits = int(np.ceil(np.log(len(tokens)) / np.log(2)))\n", | |
"for c, i in tokens.items():\n", | |
" tokens[c] = np.float32([int(c) for c in np.binary_repr(i, width = bits)]).reshape(1, bits)\n", | |
"df['token'] = df['cluster'].apply(lambda x : tokens[x])\n", | |
"\n", | |
"m = merge_encoder(bits)\n", | |
"optimizer = tf.keras.optimizers.Adam()\n", | |
"x = ts2leafs(df)\n", | |
"\n", | |
"print(\"Start Merging\")\n", | |
"node = None\n", | |
"for epoch in range(0, 25):\n", | |
" with tf.GradientTape(watch_accessed_variables=True) as tape:\n", | |
" print(\"Epoch: {}\".format(epoch))\n", | |
" tape.watch(m.variables) \n", | |
" node = merge(x, m)\n", | |
" print(\"done merging: {}\".format(node.score))\n", | |
" g = tape.gradient(node.score, m.variables)\n", | |
" optimizer.apply_gradients(zip(g, m.variables))\n", | |
" pkl.dump(node, open('epoch_{}_merged_{}.pkl'.format(epoch, \"seq_clustering_log_06281101C\"), \"wb\"))\n", | |
"m.save('dolphin_merger.h5')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment