Created
April 6, 2019 00:42
-
-
Save kyunghyuncho/28e5a6b8d1b8394bcd85fcf025d60190 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import argparse | |
import logging | |
from collections import OrderedDict | |
import numpy | |
import torch | |
from torch import nn | |
import pyro | |
import pyro.infer | |
import pyro.infer.mcmc | |
import pyro.distributions as dist | |
# Switch on Pyro's validation at import time, before any model code runs.
pyro.enable_validation(True)
def parse_options(argv=None):
    """Parse the command line into an ordered configuration mapping.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` falls back to ``sys.argv[1:]``,
            which is what the original zero-argument call did.

    Returns:
        An ``OrderedDict`` with the keys ``'num-samples'``,
        ``'warmup-steps'``, ``'logging-level'``, ``'type'`` and ``'data'``,
        in that order.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-num-samples', type=int, default=500,
                        help="number of samples to draw during inference")
    parser.add_argument('-warmup-steps', type=int, default=100,
                        help="number of warm-up steps for the MCMC sampler")
    parser.add_argument('-logging-level', type=str, default="None",
                        help="name of a logging level such as DEBUG; "
                             "the string None leaves logging unconfigured")
    parser.add_argument('-type', type=str, default="mcmc",
                        help="inference method: mcmc or importance")
    parser.add_argument('data', type=str,
                        help="path of the array file passed to numpy.load")
    args = parser.parse_args(argv)

    # Built from pairs rather than from a dict literal so that the key order
    # is the declared one even on interpreters whose dicts are unordered.
    return OrderedDict([
        ('num-samples', args.num_samples),
        ('warmup-steps', args.warmup_steps),
        ('logging-level', args.logging_level),
        ('type', args.type),
        ('data', args.data),
    ])
def model(scores, config):
    """Generative model of the observed scores with a per-turker offset.

    Prior for each algorithm ``a``::

        p(m_a) = Unif(0, 5)
        p(s_a) = N(s_a | m_a, 1^2)

    Prior for each turker's bias (no bias from each turker a priori)::

        p(s_t) = N(s_t | 0, 1^2)

    Likelihood of each score given its (algorithm, turker) pair::

        p(s | a, t) = N(s | s_a + s_t, 1^2)

    Args:
        scores: Sequence of records; element 0 of a record is the turker
            index and element 1 is the model index. Other elements are not
            read by this function.
        config: Mapping that provides ``'n_models'`` and ``'n_turkers'``.

    The sample sites are named ``model-mean-<i>``, ``model-<i>``,
    ``turker-mean-<i>`` and ``scores``.
    """
    n_models = int(config['n_models'])
    n_turkers = int(config['n_turkers'])

    # Site order (mean then score for each model, then all turkers) is kept
    # as-is so the random draws line up with the original implementation.
    model_scores = []
    for mi in range(n_models):
        model_mean = pyro.sample("model-mean-{}".format(mi),
                                 dist.Uniform(0., 5.))
        model_scores.append(
            pyro.sample("model-{}".format(mi), dist.Normal(model_mean, 1.)))

    turker_bias = [
        pyro.sample("turker-mean-{}".format(ti), dist.Normal(0., 1.))
        for ti in range(n_turkers)
    ]

    # int() makes the indices usable even if the records came from a
    # floating-point array; the tensors let the per-observation means be
    # gathered in one indexing step instead of one tensor add per record.
    model_idx = torch.tensor([int(sc[1]) for sc in scores], dtype=torch.long)
    turker_idx = torch.tensor([int(sc[0]) for sc in scores], dtype=torch.long)
    mu = torch.stack(model_scores)[model_idx] \
        + torch.stack(turker_bias)[turker_idx]

    pyro.sample("scores", dist.Normal(mu, 1.))
def do_inference(category, scores, config):
    """Estimate the per-model score sites for one score category.

    The model is conditioned on the observed scores of ``category`` and the
    ``model-<i>`` sites are summarised with the method named by
    ``config['type']``.

    Args:
        category: Index into the score vector held as element 2 of each
            record in ``scores``.
        scores: Sequence of records; element 2 of each record is indexable
            by ``category``. The whole sequence is also forwarded to
            ``model``.
        config: Mapping with ``'type'``, ``'num-samples'`` and
            ``'n_models'``; the MCMC branch also reads ``'warmup-steps'``.

    Returns:
        A ``(mean, stddev)`` pair. For ``"mcmc"`` these are the ``mean`` and
        ``stddev`` of the empirical marginal. For ``"importance"`` the first
        item is the value obtained by calling the marginal and the second is
        that value multiplied by zero.

    Raises:
        ValueError: If ``config['type']`` is neither ``"mcmc"`` nor
            ``"importance"`` (previously this fell through and returned
            ``None``).
    """
    observed = torch.from_numpy(
        numpy.array([sc[2][category] for sc in scores])).float()
    conditioned_model = pyro.condition(model, data={"scores": observed})
    sites = ["model-{}".format(mi) for mi in range(int(config['n_models']))]

    if config['type'] == "mcmc":
        # No U-Turn sampler: automatically adjust the step size for HMC.
        nuts_kernel = pyro.infer.mcmc.NUTS(conditioned_model,
                                           adapt_step_size=True,
                                           step_size=0.1)
        mcmc_run = pyro.infer.mcmc.MCMC(
            nuts_kernel,
            num_samples=config['num-samples'],
            warmup_steps=config['warmup-steps']).run(scores, config)
        posterior = pyro.infer.EmpiricalMarginal(mcmc_run, sites=sites)
        return posterior.mean, posterior.stddev

    if config['type'] == "importance":
        # Importance sampling using the prior as a proposal: horrible idea.
        importance_run = pyro.infer.Importance(
            conditioned_model,
            num_samples=config['num-samples']).run(scores, config)
        marginal = pyro.infer.EmpiricalMarginal(importance_run, sites=sites)
        # NOTE(review): calling the marginal looks like it draws a single
        # sample rather than computing the mean; behaviour is kept as-is,
        # confirm against the Pyro documentation before relying on it.
        mean = marginal()
        return mean, 0. * mean

    raise ValueError(
        "unknown inference type {!r}; expected 'mcmc' or 'importance'".format(
            config['type']))
def main():
    """Run inference for every score category and print the estimates.

    Reads the options with ``parse_options``, loads the array stored at
    ``config['data']`` with ``numpy.load``, derives the number of models and
    turkers from the largest value in columns 1 and 0, calls
    ``do_inference`` once per category and prints one ``mean+-stddev`` entry
    per model for each category.
    """
    config = parse_options()

    if config['logging-level'] != "None":
        # getattr resolves the same ``logging.<NAME>`` attribute the former
        # eval() did, without evaluating arbitrary command-line text.
        logging.basicConfig(level=getattr(logging, config['logging-level']))

    data = numpy.load(config['data'])

    # int() keeps the counts and indices usable with range() and list
    # indexing even when the array is stored with a floating-point dtype.
    config['n_models'] = int(data[:, 1].max()) + 1
    config['n_turkers'] = int(data[:, 0].max()) + 1

    # Each record is (turker index, model index, scores from columns 2..4).
    all_scores = [(int(row[0]), int(row[1]), row[2:5]) for row in data]

    cname = ['fluency', 'consistency', 'engagingness']
    inferred = []
    for ci in range(len(cname)):
        mean, stddev = do_inference(ci, all_scores, config)
        inferred.append((mean.data.numpy(), stddev.data.numpy()))

    for name, (mean, stddev) in zip(cname, inferred):
        print("inference score category {}".format(name))
        # One complete row per category, so the next category header no
        # longer ends up on the same line as the previous scores.
        print(" ".join(
            "{0:1.2f}+-{1:1.2f}".format(mean[mi], stddev[mi])
            for mi in range(config['n_models'])))
# Run the analysis only when this file is executed as a script.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment