Created March 23, 2017 03:09
-
-
Save yamaguchiyuto/261e275d0d43148c9feaec344bab880d to your computer and use it in GitHub Desktop.
ATM experiment
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# ATM experiment: Author Topic Model on a toy corpus\n", | |
"\n", | |
"Fit an Author Topic Model (`atm.ATM`) to seven tiny documents written by two authors: Yamada writes only about computers and Suzuki writes only about food. The word \"apple\" is used by both authors, so the experiment checks that its topic is decided by who wrote it, and that the co-authored last document is attributed to the right author." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"from atm import ATM" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 199, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Data generation: 7 toy documents over a 3-word vocabulary and 2 authors\n", | |
"vocab = {'computer':0, 'banana':1, 'apple':2}\n", | |
"authors = {'Yamada':0, 'Suzuki':1}\n", | |
"\n", | |
"V = len(vocab)    # vocabulary size\n", | |
"S = len(authors)  # number of authors\n", | |
"\n", | |
"# One (words, authors) pair per document\n", | |
"docs = [\n", | |
"    (['computer', 'computer'], ['Yamada']),\n", | |
"    (['banana', 'banana'], ['Suzuki']),\n", | |
"    (['computer', 'computer', 'apple'], ['Yamada']),\n", | |
"    (['banana', 'banana', 'apple'], ['Suzuki']),\n", | |
"    (['apple'], ['Yamada']),\n", | |
"    (['apple'], ['Suzuki']),\n", | |
"    (['computer', 'computer'], ['Suzuki', 'Yamada']),  # co-authored\n", | |
"]\n", | |
"\n", | |
"# List comprehensions rather than map(): on Python 3, map() returns a lazy\n", | |
"# single-use iterator, so W and A would no longer be lists of id lists\n", | |
"W = [[vocab[w] for w in words] for words, _ in docs]    # word ids per document\n", | |
"A = [[authors[a] for a in names] for _, names in docs]  # author ids per document" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 200, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Model definition\n", | |
"K = 2           # number of topics: a computer topic and a food topic are expected\n", | |
"alpha = 0.01    # presumably the prior on author-topic distributions; confirm in atm.py\n", | |
"beta = 0.01     # presumably the prior on topic-word distributions; confirm in atm.py\n", | |
"max_iter = 1000  # maximum number of training iterations (presumably sampler sweeps; confirm in atm.py)\n", | |
"\n", | |
"atm = ATM(K=K, alpha=alpha, beta=beta, max_iter=max_iter)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 201, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Fit the Author Topic Model\n", | |
"# fit() returns an ATM instance whose repr is not informative, so the trailing ';' suppresses it\n", | |
"atm.fit(W, A, V, S);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 202, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[array([0, 0]), array([1, 1]), array([0, 0, 0]), array([1, 1, 1]), array([0]), array([1]), array([0, 0])]\n", | |
"[array([0, 0]), array([0, 0]), array([0, 0, 0]), array([0, 0, 0]), array([0]), array([0]), array([1, 1])]\n" | |
] | |
} | |
], | |
"source": [ | |
"# Topics are assigned as expected:\n", | |
"# - The \"apple\" in document 5 was written by Yamada, who writes only about computers, so it gets the computer topic\n", | |
"# - The \"apple\" in document 6 was written by Suzuki, who writes only about food, so it gets the food topic\n", | |
"# - Document 7 is co-authored, but Suzuki only writes about food, so it was clearly written by Yamada alone\n", | |
"# max_Y appears to index each document's own author list (Suzuki's solo documents show 0, not his id 1),\n", | |
"# so the 1s in document 7 point to A[6][1], i.e. 'Yamada'\n", | |
"print(atm.max_Z)\n", | |
"print(atm.max_Y)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.11" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.