Created February 7, 2020 00:01
Save maropu/fb6c0501469794d5e5f111d8e40f6d0b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Exact inference in a tiny Markov logic network (the Smokes/Cancer/Friends
# example) by enumerating every possible world.
# References:
#   https://qiita.com/9_ties/items/3bdb177384937ddc88df
#   https://homes.cs.washington.edu/~pedrod/papers/mlj05.pdf
import pandas as pd
import numpy as np
from scipy.special import logsumexp
from itertools import product

# Constants of the domain; every logical variable ranges over these.
const = ['A', 'B']
# (predicate name, arity) pairs.
preds = [('Smokes', 1), ('Cancer', 1), ('Friends', 2)]

# Ground atoms: each predicate applied to every tuple of constants of its
# arity, e.g. ('Smokes', 'A') or ('Friends', 'A', 'B').
ground_atoms = []
for pred_name, n_args in preds:
    for const_args in product(const, repeat=n_args):
        ground_atoms.append((pred_name,) + const_args)
print('=== Ground Atoms ===')
print(ground_atoms)
# Weighted clauses. Each entry is (literals, negations, arity, weight):
#   literals  -- list of (predicate, argument_slots); argument_slots holds
#                indices into the clause's variable tuple, e.g. (0, 1) means
#                predicate(x0, x1).
#   negations -- one flag per literal; 1 means the literal is negated.
#   arity     -- number of distinct variables x0..x(arity-1) in the clause.
#   weight    -- weight added to a world's score for each satisfied grounding.
# A clause is the disjunction of its (possibly negated) literals, so the
# entries below encode:
#   Smokes(x) => Cancer(x)
#   Friends(x, y) and Smokes(y) => Smokes(x)
#   Friends(x, y) and Smokes(x) => Smokes(y)
formulas = [
    ([('Smokes', (0,)), ('Cancer', (0,))], [1, 0], 1, 1.5),
    ([('Friends', (0, 1)), ('Smokes', (0,)), ('Smokes', (1,))], [1, 0, 1], 2, 1.1),
    ([('Friends', (0, 1)), ('Smokes', (0,)), ('Smokes', (1,))], [1, 1, 0], 2, 1.1),
]

# Ground every clause by substituting each tuple of constants for its
# variables. Each grounding is stored as (ground_literal_atoms, negations, weight).
ground_formulas = []
for clauses, neg, arity, w in formulas:
    for args in product(const, repeat=arity):
        ground_formula = [(p, *(args[i] for i in v)) for p, v in clauses]
        ground_formulas.append((ground_formula, neg, w))
print('=== Ground Formulas ===')
print(ground_formulas)
# Enumerate every possible world: one row per truth assignment (1/0) to the
# ground atoms, one column per ground atom (2 ** len(ground_atoms) rows).
X = pd.DataFrame(columns=ground_atoms, data=list(product([1, 0], repeat=len(ground_atoms))))

# Score each world with sum_i(w_i * n_i(x)), where n_i(x) counts the groundings
# of formula i that world x satisfies. A grounding is a disjunction: it holds
# when at least one literal (atom value XOR its negation flag) is true.
S = np.zeros(len(X))
for f, neg, w in ground_formulas:
    # Pass axis by keyword: newer pandas versions make DataFrame.any's
    # parameters keyword-only, so the positional form `.any(1)` fails there.
    satisfied = np.logical_xor(X[f], neg).any(axis=1)
    S += w * satisfied.to_numpy()

# log Z = log(sum over worlds of exp(score)), computed stably in log space.
logZ = logsumexp(S)

# log P(x) = score(x) - log Z for every world, stored next to its assignment.
jointP = X.copy()
jointP['logP'] = S - logZ
print('=== Joint Probability ===')
print(jointP)
# Example queries, answered by marginalizing the joint distribution: group the
# worlds by the queried ground atoms and log-sum-exp their log-probabilities.
def _marginal(atoms):
    """Return P(atoms) as a Series indexed by the atoms' truth values."""
    return np.exp(jointP.groupby(list(atoms))['logP'].agg(logsumexp))


print('=== P(Friends(A, B)) ===')
P_FrAB = _marginal([('Friends', 'A', 'B')])
print(P_FrAB)

print('=== P(Friends(A, B)|Smokes(A)) ===')
P_FrAB_SmoA = _marginal([('Smokes', 'A'), ('Friends', 'A', 'B')])
P_SmoA = _marginal([('Smokes', 'A')])
# Dividing the joint marginal by P(Smokes(A)) is meant to give
# P(Friends(A, B) | Smokes(A)); pandas aligns the two Series on the shared
# ('Smokes', 'A') index level.
print(P_FrAB_SmoA / P_SmoA)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.