Created
March 20, 2017 00:41
-
-
Save yamaguchiyuto/182dc366f291e66a2a902f5efa1809e1 to your computer and use it in GitHub Desktop.
JTM experiments
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 40, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"from jtm import JTM" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 41, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"vocab = ['apple', 'banana', 'computer', 'mac', 'burger', 'ipad']\n", | |
"categories = ['PC', 'FOOD']\n", | |
"\n", | |
"# Toy corpus: one (word ids, category ids) pair per document.\n", | |
"# The last two documents have identical words and differ only in their category.\n", | |
"docs = [\n", | |
"    ([0, 0, 2, 3, 3, 5], [0]),  # apple, apple, computer, mac, mac, ipad / PC\n", | |
"    ([0, 0, 1, 3, 3, 4], [1]),  # apple, apple, banana, mac, mac, burger / FOOD\n", | |
"    ([2, 2, 5, 5], [0]),        # computer, computer, ipad, ipad / PC\n", | |
"    ([1, 1, 4, 4], [1]),        # banana, banana, burger, burger / FOOD\n", | |
"    ([0, 3], [0]),              # apple, mac / PC\n", | |
"    ([0, 3], [1]),              # apple, mac / FOOD\n", | |
"]\n", | |
"\n", | |
"# X[0]: word-id list per document, X[1]: category-id list per document\n", | |
"word_ids, category_ids = zip(*docs)\n", | |
"X = [list(word_ids), list(category_ids)]\n", | |
"\n", | |
"# V: vocabulary sizes of the two channels (words, categories)\n", | |
"V = [len(vocab), len(categories)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 42, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Model definition\n", | |
"K = 2           # number of topics\n", | |
"alpha = 0.01    # presumably the document-topic Dirichlet prior (LDA-style) -- TODO confirm in jtm\n", | |
"beta = 0.01     # presumably the topic-word Dirichlet prior (LDA-style) -- TODO confirm in jtm\n", | |
"max_iter = 1000\n", | |
"\n", | |
"# Fix the RNG so that Restart & Run All gives the same topic assignments every time.\n", | |
"# NOTE(review): assumes JTM draws its randomness from numpy's global RNG -- TODO confirm.\n", | |
"SEED = 0\n", | |
"np.random.seed(SEED)\n", | |
"\n", | |
"model = JTM(K=K, alpha=alpha, beta=beta, max_iter=max_iter)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 43, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Fit JTM on both channels: words (X[0]) and categories (X[1]).\n", | |
"# fit() returns the JTM instance; the trailing ';' hides that uninformative repr.\n", | |
"model.fit(X, V);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 44, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[array([1, 1, 1, 1, 1, 1]), array([0, 0, 0, 0, 0, 0]), array([1, 1, 1, 1]), array([0, 0, 0, 0]), array([1, 1]), array([0, 0])]\n", | |
"[array([1]), array([0]), array([1]), array([0]), array([1]), array([0])]\n" | |
] | |
} | |
], | |
"source": [ | |
"# JTM topic assignments: model.Z[0] holds one topic per word, model.Z[1] one per category label.\n", | |
"# The last two documents contain the same words, yet the category information\n", | |
"# lets JTM assign them to the correct (different) topics.\n", | |
"print(model.Z[0])\n", | |
"print(model.Z[1])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 45, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# LDA fit\n", | |
"# Without the category (side) information, JTM is equivalent to LDA.\n", | |
"# Use new names so the JTM inputs (X, V) and the fitted JTM `model` stay intact;\n", | |
"# reassigning them here would silently turn the JTM cells above into LDA on re-run.\n", | |
"X_lda = [X[0]]\n", | |
"V_lda = [V[0]]\n", | |
"lda_model = JTM(K=K, alpha=alpha, beta=beta, max_iter=max_iter)\n", | |
"lda_model.fit(X_lda, V_lda);" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 46, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[array([1, 1, 0, 1, 1, 0]), array([1, 1, 1, 1, 1, 1]), array([0, 0, 0, 0]), array([1, 1, 1, 1]), array([1, 1]), array([1, 1])]\n" | |
] | |
} | |
], | |
"source": [ | |
"# LDA topic assignments (one topic per word)\n", | |
"# No category information is given, so plain LDA naturally cannot assign\n", | |
"# the last two documents (identical words) to the correct topics.\n", | |
"print(lda_model.Z[0])" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 2", | |
"language": "python", | |
"name": "python2" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 2 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython2", | |
"version": "2.7.11" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 0 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment