Created
May 3, 2020 17:48
-
-
Save marcofavorito/f0c7db4dc5b58045308a64b157502cbc to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
from copy import copy | |
from typing import Optional | |
from pysat.solvers import Glucose3 | |
def is_sat(clauses) -> bool:
    """Return True if the CNF formula ``clauses`` is satisfiable.

    :param clauses: iterable of clauses; each clause is a list of non-zero
        integer literals (``x`` for a variable, ``-x`` for its negation).
    :return: the result of ``Glucose3.solve()`` on the formula.
    """
    # The context manager calls ``delete()`` on exit, releasing the native
    # solver instance (the previous version never freed it).
    with Glucose3() as solver:
        for clause in clauses:
            solver.add_clause(clause)
        return solver.solve()
def find_all_models(clauses):
    """Enumerate every model of the CNF formula ``clauses``.

    Each model is the list returned by ``Glucose3.get_model()``. After a model
    is found, a blocking clause negating every literal of it is added, so the
    solver must produce a different assignment on the next ``solve()``.

    :param clauses: iterable of clauses of non-zero integer literals.
    :return: list of models; empty if the formula is unsatisfiable.
    """
    models = []
    # Context manager releases the native solver when enumeration finishes.
    with Glucose3() as solver:
        for clause in clauses:
            solver.add_clause(clause)
        while solver.solve():
            model = solver.get_model()
            models.append(model)
            # Ban this exact assignment: at least one literal must flip.
            solver.add_clause([-lit for lit in model])
    return models
class QuickXPlain:
    """Find one subset-minimal model of a CNF formula via QuickXPlain.

    A candidate is represented by the set ``W`` of variables that are allowed
    to be true; every other variable is forced false. ``p(W)`` checks whether
    the formula is satisfiable under that restriction. Allowing more variables
    to be true can only keep a satisfiable formula satisfiable, so ``p`` is
    monotone. ``quickxplain`` uses divide and conquer to shrink the allowed set
    to a minimal one.
    """

    def __init__(self, clauses):
        """Store the formula and the variable universe ``1..N``.

        :param clauses: sequence of clauses of non-zero integer literals.
        """
        self.original_clauses = clauses
        # Highest variable index in the formula. It is computed directly from
        # the literals instead of instantiating (and leaking) a solver just
        # to call ``nof_vars()``.
        self.N = max((abs(lit) for clause in clauses for lit in clause), default=0)
        self.vars = set(range(1, self.N + 1))

    def p(self, W):
        """Return True if the formula has a model whose true variables lie in ``W``.

        :param W: set of variable indices (each in ``1..N``) allowed to be true.
        """
        # Internal invariant: callers only pass subsets of ``self.vars``.
        assert all(0 < w <= self.N for w in W)
        with Glucose3() as solver:
            for clause in self.original_clauses:
                solver.add_clause(clause)
            # Force every variable outside W to false with a unit clause.
            for var in range(1, self.N + 1):
                if var not in W:
                    solver.add_clause([-var])
            return solver.solve()

    def quickxplain(self, b, t, has_set) -> set:
        """Return a minimal subset of ``t`` that, joined with ``b``, satisfies ``p``.

        :param b: set of variables already allowed to be true.
        :param t: set of candidate variables still to be minimised.
        :param has_set: whether the caller just added variables to ``b``. When
            False, nothing was added, so the solver-backed ``p(b)`` check is
            skipped.
        :return: a new set holding the minimal subset of ``t``.
        """
        if has_set and self.p(b):
            return set()
        if len(t) <= 1:
            # ``<= 1`` also covers an empty ``t``, which previously recursed
            # forever. Return a copy so the result never aliases the caller's
            # set (at the top level that is ``self.vars``).
            return set(t)
        ordered = sorted(t)
        half = len(ordered) // 2
        t1, t2 = set(ordered[:half]), set(ordered[half:])
        m2 = self.quickxplain(b | t1, t2, len(t1) > 0)
        m1 = self.quickxplain(b | m2, t1, len(m2) > 0)
        return m1 | m2
def find_all_minimal_models(clauses):
    """Enumerate all subset-minimal models of the CNF formula ``clauses``.

    Each model is returned as the set of variables assigned true. After a
    minimal model ``M`` is found, the clause ``[-x for x in M]`` is added. That
    clause is violated exactly by assignments making all of ``M`` true, so it
    excludes ``M`` and every superset of ``M``. The next minimal model found
    is therefore a new one.

    :param clauses: iterable of clauses of non-zero integer literals.
    :return: list of sets; empty when the formula is unsatisfiable.
    """
    # ``list`` (rather than ``copy``) lets tuples and other iterables be
    # concatenated with the list of blocking clauses below.
    base_clauses = list(clauses)
    blocking_clauses = []
    models = []
    while True:
        # Returns None once the formula plus blocking clauses is UNSAT, which
        # also covers an unsatisfiable input on the first iteration.
        model = find_minimal_model(base_clauses + blocking_clauses)
        if model is None:
            break
        models.append(model)
        if not model:
            # The empty set is a subset of every model, so it is the unique
            # minimal model. Its blocking clause would be the empty clause.
            break
        blocking_clauses.append([-x for x in model])
    return models
def find_minimal_model(clauses) -> Optional[set]:
    """Return one subset-minimal model of ``clauses``, or None if it is UNSAT.

    The model is the set of variable indices assigned true.
    """
    if is_sat(clauses):
        explainer = QuickXPlain(clauses)
        # Start with nothing committed and every variable as a candidate.
        return explainer.quickxplain(set(), explainer.vars, True)
    return None
if __name__ == '__main__':
    # Example formula a & (b | c), encoded with a=1, b=2, c=3.
    # Its subset-minimal models are {1, 2} and {1, 3}.
    example_clauses = [
        [1],
        [2, 3],
    ]
    print(find_all_minimal_models(example_clauses))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment