Created October 13, 2016 18:46
-
-
Save leegao/9b1e79229b2dc4273d638ec3ec0b32bf to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import re | |
from itertools import permutations | |
class Instance(object):
    """A partially filled sum puzzle of the form ``t1 + t2 + ... = u1 + u2 + ...``.

    Every term is a string of digit characters in which ``'x'`` marks an
    unknown digit. ``groups`` holds the left- and right-hand terms as two
    tuples. ``level`` is the total number of terms divided by 3 (integer
    division), and ``residuals`` lists, in ascending order, the digits 1-9
    still available for the unknown slots: each digit appears ``level``
    times minus how often it already occurs among the given digits.
    """

    def __init__(self, groups):
        """Build an instance from ``(left_terms, right_terms)``."""
        self.groups = (tuple(groups[0]), tuple(groups[1]))
        # Count from the materialized tuples so iterator inputs are not
        # consumed twice.
        self.level = sum(len(group) for group in self.groups) // 3
        reserved = ''.join(x for group in self.groups
                           for triplet in group for x in triplet if x != 'x')
        # A digit already used `level` times or more gets range(<=0): no copies.
        self.residuals = ''.join(
            str(d) for d in range(1, 10)
            for _ in range(self.level - reserved.count(str(d))))

    def is_final(self):
        """Return True when no ``'x'`` remains and both sides sum to the same total."""
        if any(x == 'x' for group in self.groups for triplet in group for x in triplet):
            return False
        left, right = [sum(int(triplet) for triplet in group) for group in self.groups]
        return left == right

    def generate_constraints(self):
        """Return the rightmost column that still contains an unknown digit.

        Columns are scanned from the units digit (``i = 1``) leftwards, up to
        three columns (terms are assumed to be three characters wide). For the
        first column ``i`` containing an ``'x'``, returns ``(constraints, i)``,
        where ``constraints`` is a ``(left, right)`` pair of tuples. Each
        tuple holds that column's character from every term on that side,
        followed by the integer carry flowing into the column from the fully
        known columns to its right.

        Returns None if none of the scanned columns contains an ``'x'``.
        """
        carry = (0, 0)
        for i in range(1, 4):
            column = [[triplet[-i] for triplet in group] for group in self.groups]
            if any(digit == 'x' for side in column for digit in side):
                constraints = tuple(tuple(side + [carry[j]])
                                    for j, side in enumerate(column))
                return constraints, i
            # Carry out of the i rightmost (fully known) columns into column i + 1.
            carry = [sum(int(triplet[-i:]) for triplet in group) // 10 ** i
                     for group in self.groups]
        return None

    @classmethod
    def solve_constraints(cls, constraints, residuals):
        """Yield each distinct way to fill the unknowns of one column.

        ``constraints`` is the ``(left, right)`` pair produced by
        ``generate_constraints``: per side, the column's characters (``'x'``
        for unknowns) followed by the incoming carry. Unknowns are filled with
        digits drawn from ``residuals`` without replacement. An assignment is
        yielded, as a pair of tuples shaped like ``constraints``, when both
        sides' column sums agree modulo 10.

        Each distinct assignment is generated exactly once, in lexicographic
        digit order. This avoids enumerating every permutation of
        ``residuals`` and discarding the duplicates. If there are fewer
        residual digits than unknowns, nothing is yielded.
        """
        holes = [(i, j) for i, side in enumerate(constraints)
                 for j, x in enumerate(side) if x == 'x']
        counts = {}
        for digit in residuals:
            counts[digit] = counts.get(digit, 0) + 1

        def arrangements(k):
            # Distinct length-k sequences drawable from the multiset `counts`.
            if k == 0:
                yield ()
                return
            for digit in sorted(counts):
                if counts[digit]:
                    counts[digit] -= 1
                    for rest in arrangements(k - 1):
                        yield (digit,) + rest
                    counts[digit] += 1

        for substitution in arrangements(len(holes)):
            fulfilled = (list(constraints[0]), list(constraints[1]))
            for (i, j), digit in zip(holes, substitution):
                fulfilled[i][j] = digit
            # The trailing carry is an int; int() accepts both it and digit chars.
            left, right = [sum(int(x) for x in side) for side in fulfilled]
            if left % 10 == right % 10:
                yield (tuple(fulfilled[0]), tuple(fulfilled[1]))

    def commit(self, level, fulfilled):
        """Return a new Instance whose column ``level`` (1 = units) is taken from ``fulfilled``.

        ``fulfilled`` is shaped like ``solve_constraints`` output. Its trailing
        carry entries are ignored because only indices of real terms are read.
        """
        new = [[list(triplet) for triplet in group] for group in self.groups]
        for i, group in enumerate(new):
            for j, digits in enumerate(group):
                digits[-level] = fulfilled[i][j]
        return Instance([[''.join(digits) for digits in group] for group in new])

    def __repr__(self):
        return (' = '.join(' + '.join(group) for group in self.groups)
                + ' with residuals [%s].' % (self.residuals,))
def parse(s):
    """Parse a puzzle string such as ``'xx1 + 2x3 = x45'`` into an Instance.

    The text is split around ``'='`` by the greedy regex ``(.+)=(.+)``. Each
    side is then split on ``'+'`` and every term is whitespace-stripped. Input
    without a usable ``'='`` makes ``re.match`` return None, so ``.groups()``
    raises AttributeError.
    """
    sides = re.match(r'(.+)=(.+)', s).groups()
    terms = [[term.strip() for term in side.split('+')] for side in sides]
    return Instance(terms)
def solve(instance):
    """Depth-first search for a complete, balanced filling of ``instance``.

    The puzzle is filled one column at a time, starting at the units digit.
    For each column, every digit assignment whose two sides agree modulo 10
    is tried, and the search recurses on the resulting instance.

    Returns the first solved Instance found, or None when no solution exists
    along the explored branches.
    """
    if instance.is_final():
        return instance
    if not instance.residuals:
        return None
    step = instance.generate_constraints()
    if step is None:
        # No unknown digit is left in the scanned columns, yet the sides do
        # not balance and some digits remain unused (the given digits did not
        # match the expected per-digit counts). Nothing left to try.
        return None
    constraints, level = step
    for solution in Instance.solve_constraints(constraints, instance.residuals):
        result = solve(instance.commit(level, solution))
        if result is not None:
            return result
    return None
# Solve each sample puzzle and print the solved instance (or None).
for puzzle in (
    'xxx + xxx + xxx + x29 + 821 = xxx + xxx + 8xx + 867',
    'xxx + xxx + xxx + 4x1 + 689 = xxx + xxx + x5x + 957',
    'xxx + xxx + xxx + 64x + 581 = xxx + xxx + xx2 + 623',
    'xxx + xxx + xxx + x81 + 759 = xxx + xxx + 8xx + 462',
    'xxx + xxx + xxx + 6x3 + 299 = xxx + xxx + x8x + 423',
    'xxx + xxx + xxx + 58x + 561 = xxx + xxx + xx7 + 993',
):
    print(solve(parse(puzzle)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.