Created
March 21, 2017 09:20
-
-
Save yamaguchiyuto/eb2431dee0f77269bd93be19ab357a0e to your computer and use it in GitHub Desktop.
CTM experiments
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from jtm import JTM # Joint Topic Model\n", | |
"from ctm import CTM # Correspondence Topic Model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# データ生成\n", | |
"vocab = {'computer':0, 'banana':1, 'ipad':2, 'orange':3, 'windows':4}\n", | |
"categories = {'FOODS':0, 'TECH':1}\n", | |
"\n", | |
"Vw = len(vocab)\n", | |
"Vx = len(categories)\n", | |
"\n", | |
"W = [] # words\n", | |
"X = [] # categories\n", | |
"\n", | |
"W.append(map(lambda v:vocab[v], ['computer', 'ipad', 'ipad']))\n", | |
"X.append(map(lambda c:categories[c], ['TECH']))\n", | |
"W.append(map(lambda v:vocab[v], ['ipad', 'ipad']))\n", | |
"X.append(map(lambda c:categories[c], ['TECH']))\n", | |
"W.append(map(lambda v:vocab[v], ['ipad', 'ipad']))\n", | |
"X.append(map(lambda c:categories[c], ['TECH']))\n", | |
"W.append(map(lambda v:vocab[v], ['ipad', 'ipad']))\n", | |
"X.append(map(lambda c:categories[c], ['TECH']))\n", | |
"W.append(map(lambda v:vocab[v], ['banana', 'orange']))\n", | |
"X.append(map(lambda c:categories[c], ['FOODS']))\n", | |
"W.append(map(lambda v:vocab[v], ['banana', 'orange']))\n", | |
"X.append(map(lambda c:categories[c], ['FOODS']))\n", | |
"W.append(map(lambda v:vocab[v], ['windows', 'windows']))\n", | |
"X.append(map(lambda c:categories[c], ['TECH']))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# モデル定義\n", | |
"K = 3\n", | |
"alpha = 0.1\n", | |
"beta= 0.1\n", | |
"gamma= 0.1\n", | |
"max_iter = 100\n", | |
"\n", | |
"jtm = JTM(K=K, alpha=alpha, beta=beta, max_iter=max_iter)\n", | |
"ctm = CTM(K=K, alpha=alpha, beta=beta, gamma=gamma, max_iter=max_iter)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<ctm.CTM instance at 0x10a7a7fc8>" | |
] | |
}, | |
"execution_count": 4, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# CTMのフィッティング\n", | |
"ctm.fit(W,X,Vw,Vx)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[array([2, 2, 2]), array([2, 2]), array([2, 2]), array([2, 2]), array([1, 1]), array([1, 1]), array([0, 0])]\n", | |
"[array([2]), array([2]), array([2]), array([2]), array([1]), array([1]), array([0])]\n" | |
] | |
} | |
], | |
"source": [ | |
"# うまいことトピック割当できているし、\n", | |
"# 単語に割り当てられていないトピックが付加情報に割り当てられることもない\n", | |
"print ctm.Z\n", | |
"print ctm.Y" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<jtm.JTM instance at 0x10a7f1098>" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# JTMのフィッティング\n", | |
"jtm.fit([W,X],[Vw,Vx])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[array([1, 0, 0]), array([0, 0]), array([0, 0]), array([0, 0]), array([2, 2]), array([2, 2]), array([1, 1])]\n", | |
"[array([0]), array([0]), array([0]), array([0]), array([2]), array([2]), array([0])]\n" | |
] | |
} | |
], | |
"source": [ | |
"# あまりうまくトピック割り当てができていない\n", | |
"# 最後の文書ではトピック 0 は単語に割り当てられていないのに付加情報に割り当てられてしまっている\n", | |
"print jtm.Z[0]\n", | |
"print jtm.Z[1]" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.11" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment