"""Toy EM-style trainers for LDA topic models.

Three trainers share one data layout and M-step and differ only in the
E-step:

* ``EmLda``  - normalised per-token topic posteriors.
* ``VarLda`` - posteriors passed through ``exp_digamma``.
* ``PrLda``  - posteriors rescaled by ``exp(-lambda)`` where the per-word
  multipliers ``lambda`` are kept inside an L1 ball of size ``sigma``.

The code is written to run unchanged on Python 2 and Python 3.
"""
import os
import re
import sys

import numpy as np

# Tokens dropped by get_words() when it is called with ``stop=False``.
# The original module referenced this name without ever defining it; it is
# empty by default so that branch is usable, and callers may rebind it.
stoplist = frozenset()

wre = re.compile(r"(\w)+")


def symdirichlet(alpha, n):
    """Draw one sample from an ``n``-dimensional symmetric Dirichlet.

    Every concentration parameter equals ``alpha``.  Uses NumPy's global
    random state, so seed ``np.random`` for reproducible results.
    """
    return np.random.dirichlet(np.zeros(n) + alpha)


def _exp_digamma_scalar(x):
    """Scalar kernel behind ``exp_digamma``."""
    if x < 0.1:
        return x / 100.0
    # Leading terms x - 1/2 + 1/(24 x) of the large-x expansion of
    # exp(digamma(x)).
    # NOTE(review): this expression is negative for roughly
    # 0.106 < x < 0.394, whereas exp(digamma(x)) is always positive.
    return x - 0.5 + 1.0 / (24 * x)


# Element-wise version.  ``otypes`` is given explicitly so that empty arrays
# are accepted and the result dtype does not depend on the first element.
exp_digamma = np.vectorize(_exp_digamma_scalar, otypes=[float])


def get_words(text, stop=True):
    """A simple tokenizer.

    Yields every maximal run of ``\\w`` characters in ``text``, lower-cased,
    in order of appearance.

    ``stop`` -- when true (the default) every token is yielded; when false,
    tokens found in the module-level ``stoplist`` are skipped.
    """
    for match in wre.finditer(text):
        word = match.group(0).lower()
        if stop or word not in stoplist:
            yield word


def project_l1_ball(v, radius):
    """Euclidean projection of ``v`` onto ``{x : x >= 0, sum(x) <= radius}``.

    Returns a new array; ``v`` is not modified.  Negative entries are
    clipped to zero first.  If the clipped vector already fits in the ball
    it is returned as is, otherwise a common threshold is subtracted (and
    the result clipped again) so that the entries sum to ``radius``.
    """
    clipped = np.maximum(v, 0.0)
    if radius <= 0:
        return np.zeros_like(clipped)
    if clipped.sum() <= radius:
        return clipped
    ordered = np.sort(clipped)[::-1]
    excess = np.cumsum(ordered) - radius
    ranks = np.arange(1, len(ordered) + 1)
    # Largest prefix whose members all stay positive after thresholding.
    # Index 0 always qualifies because radius > 0.
    rho = np.nonzero(ordered - excess / ranks > 0)[0][-1]
    theta = excess[rho] / (rho + 1.0)
    return np.maximum(clipped - theta, 0.0)


class EmLda(object):
    """Topic model fitted by alternating ``e()`` and ``m()`` steps.

    Attributes set by ``__init__``:

    * ``docs``        - one list of word ids per input document
    * ``all_words``   - vocabulary, indexed by word id
    * ``reverse_map`` - word -> word id
    * ``Nd``, ``Nt``, ``V`` - number of documents, topics, vocabulary words
    * ``ctopics`` (Nt x V), ``doccounts`` (Nd x Nt) - expected counts
    * ``ptopics`` (Nt x V), ``docprobs`` (Nd x Nt) - row-normalised
      probabilities derived from the counts by ``m()``
    """

    def __init__(self, docs, nt):
        """Tokenise ``docs`` (an iterable of strings) and randomly initialise.

        Tokens shorter than five characters are discarded.
        """
        self.Nt = nt
        self.docs = []
        self.all_words = []
        self.reverse_map = {}
        self.Nd = 0
        for text in docs:
            ids = []
            self.docs.append(ids)
            self.Nd += 1
            for word in get_words(text):
                if len(word) < 5:
                    continue
                if word not in self.reverse_map:
                    self.reverse_map[word] = len(self.all_words)
                    self.all_words.append(word)
                ids.append(self.reverse_map[word])
        self.V = len(self.all_words)
        self.ctopics = np.zeros((self.Nt, self.V))
        self.doccounts = np.zeros((self.Nd, self.Nt))
        self.ptopics = np.zeros((self.Nt, self.V))
        self.docprobs = np.zeros((self.Nd, self.Nt))
        self.beta = 1.
        self.alpha = 1.
        # Seed the counts with a random soft topic assignment per token.
        for d, doc in enumerate(self.docs):
            for w in doc:
                self._accumulate(d, w, symdirichlet(self.beta, self.Nt))
        self.m()

    def _joint(self, d, w):
        """Unnormalised topic weights for word id ``w`` in document ``d``."""
        return self.ptopics[:, w] * self.docprobs[d]

    def _accumulate(self, d, w, zw):
        """Add the soft assignment ``zw`` to both count tables."""
        self.ctopics[:, w] += zw
        self.doccounts[d] += zw

    def e(self):
        """E-step: rebuild the count tables from normalised posteriors."""
        self.ctopics.fill(0)
        self.doccounts.fill(0)
        for d, doc in enumerate(self.docs):
            for w in doc:
                zw = self._joint(d, w)
                zw /= np.sum(zw)
                self._accumulate(d, w, zw)

    def m(self):
        """M-step: turn smoothed counts into row-normalised probabilities.

        ``alpha`` is added to the topic-word counts and ``beta`` to the
        document-topic counts.  Each table is divided by the row sums of
        the *smoothed* counts so that every row sums to one.
        """
        self.ptopics[...] = self.ctopics + self.alpha
        self.ptopics /= np.sum(self.ptopics, axis=1).reshape((-1, 1))
        self.docprobs[...] = self.doccounts + self.beta
        self.docprobs /= np.sum(self.docprobs, axis=1).reshape((-1, 1))

    def iterate(self):
        """Run one E-step followed by one M-step."""
        self.e()
        self.m()

    def run(self, n):
        """Run ``n`` iterations, then print a report to stdout.

        The report lists the per-topic weights of the first (up to) 100
        vocabulary words, followed by the 40 heaviest words of each topic.
        """
        for i in range(n):
            print("iter %s" % i)
            sys.stdout.flush()
            self.iterate()
        # Capped at V so a small vocabulary does not raise IndexError.
        for w in range(min(100, self.V)):
            column = self.ptopics[:, w]
            print("word %s topics %s"
                  % (self.all_words[w], column / np.sum(column)))
        for t in range(self.Nt):
            print("")
            print("")
            print("Topic %s" % t)
            print("")
            print_topic(self, self.ptopics[t], 40)


class PrLda(EmLda):
    """``EmLda`` whose E-step damps posteriors by per-word multipliers.

    ``sigma`` is the size of the L1 ball the multipliers of each word are
    confined to; it may be changed on the instance before calling ``run``.
    """

    def __init__(self, *args):
        super(PrLda, self).__init__(*args)
        # Total number of retained tokens across all documents.
        self.Nw = sum(len(d) for d in self.docs)
        self.sigma = 15.

    def e(self):
        """E-step: fit the multipliers, then accumulate rescaled posteriors."""
        self.do_lambda()
        self.do_z()

    def do_lambda(self):
        """Collect per-word posteriors and fit one multiplier vector per word.

        Sets ``self.ptheta`` (one ``occurrences x Nt`` array per word id, rows
        in corpus order) and ``self.lbda`` (one length-``Nt`` vector per word
        id), and zeroes both count tables ready for ``do_z``.
        """
        per_word = [[] for _ in range(self.V)]
        self.ctopics.fill(0)
        self.doccounts.fill(0)
        for d, doc in enumerate(self.docs):
            for w in doc:
                zw = self._joint(d, w)
                zw /= np.sum(zw)
                per_word[w].append(zw)
        self.ptheta = [np.array(rows) for rows in per_word]
        self.lbda = [self.optimize_lambda(rows) for rows in self.ptheta]

    def optimize_lambda(self, ptheta, steps=50, lrate=1.):
        """Fit the multipliers for one word by projected gradient descent.

        Minimises ``sum_t mass[t] * exp(-lbda[t])``, where ``mass`` is the
        column sum of ``ptheta``, subject to ``lbda >= 0`` and
        ``sum(lbda) <= self.sigma``.

        ``ptheta`` -- ``occurrences x topics`` array of posteriors.
        ``steps``  -- maximum number of gradient steps.
        ``lrate``  -- gradient step size.

        Returns the multiplier vector (length ``topics``).
        """
        mass = np.sum(ptheta, axis=0)
        lbda = np.ones(ptheta.shape[1])
        lbda *= self.sigma / np.sum(lbda)
        prevobj = np.inf
        for step in range(steps):
            obj = np.sum(mass * np.exp(-lbda))
            # Always take a few steps, then stop once the relative
            # improvement drops below one percent.
            if step > 5 and prevobj - obj < 0.01 * obj:
                break
            prevobj = obj
            # The objective's gradient is -mass * exp(-lbda), so descending
            # means adding that term; then project back into the ball.
            lbda = project_l1_ball(lbda + lrate * mass * np.exp(-lbda),
                                   self.sigma)
        return lbda

    def do_z(self):
        """Accumulate counts from posteriors rescaled by ``exp(-lbda)``.

        Walks the corpus in the same order as ``do_lambda`` so that the
        k-th occurrence of a word picks up the k-th row of its ``ptheta``.
        """
        cursor = [0] * self.V
        for d, doc in enumerate(self.docs):
            for w in doc:
                # Build a fresh array so the stored posteriors stay intact.
                zw = self.ptheta[w][cursor[w]] * np.exp(-self.lbda[w])
                zw /= np.sum(zw)
                cursor[w] += 1
                self._accumulate(d, w, zw)


class VarLda(EmLda):
    """``EmLda`` whose E-step passes the weights through ``exp_digamma``."""

    def e(self):
        """E-step using ``exp_digamma(weights) / exp_digamma(sum(weights))``."""
        self.ctopics.fill(0)
        self.doccounts.fill(0)
        for d, doc in enumerate(self.docs):
            for w in doc:
                zw = self._joint(d, w)
                zw = exp_digamma(zw) / exp_digamma(np.sum(zw))
                self._accumulate(d, w, zw)


def print_topic(model, t, n):
    """Print the ``n`` words of ``model`` with the largest weight in ``t``.

    ``t`` is a vector indexed by word id; words are printed one per line,
    heaviest first, indented by three spaces.
    """
    for w in np.argsort(-t)[:n]:
        print("   %s" % model.all_words[w])


def main(argv):
    """Train and report every model variant on a directory of text files.

    ``argv[1]`` is the number of topics and ``argv[2]`` the directory; files
    whose names start with a dot are ignored.
    """
    if len(argv) < 3:
        sys.exit("usage: %s NUM_TOPICS CORPUS_DIR" % argv[0])
    nt = int(argv[1])
    corpus_dir = argv[2]
    docs = []
    for fname in os.listdir(corpus_dir):
        if fname.startswith("."):
            continue
        with open(os.path.join(corpus_dir, fname)) as handle:
            docs.append(handle.read())
    VarLda(docs, nt).run(10)
    EmLda(docs, nt).run(10)
    PrLda(docs, nt).run(10)
    el = PrLda(docs, nt)
    el.sigma = 150
    el.run(10)


if __name__ == '__main__':
    main(sys.argv)