-
Notifications
You must be signed in to change notification settings - Fork 8
/
096.py
83 lines (75 loc) · 2.22 KB
/
096.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# time cost = 221 ms ± 2.83 ms
from itertools import product
def exact_cover(X, Y):
    """Index the rows of ``Y`` by the columns they cover.

    Args:
        X: Iterable of hashable column identifiers.
        Y: Mapping from row identifier to an iterable of the columns that
            row covers. A column that is not listed in ``X`` raises
            ``KeyError``.

    Returns:
        A ``(columns, Y)`` tuple. ``columns`` maps every column of ``X`` to
        the set of row identifiers covering it; ``Y`` is the very object
        that was passed in.
    """
    columns = {}
    for col in X:
        columns[col] = set()
    for row_id, covered in Y.items():
        for col in covered:
            columns[col].add(row_id)
    return columns, Y
def select(X, Y, r):
    """Commit row ``r`` to the partial solution, shrinking ``X`` in place.

    For each column covered by ``r``, every row listed under that column is
    removed from all the *other* columns it appears in, and then the column
    itself is popped from ``X``.

    Args:
        X: Mapping from column to the set of rows still covering it
            (mutated).
        Y: Mapping from row to the list of columns it covers (read only).
        r: Identifier of the row being chosen.

    Returns:
        The popped column sets, in the order of ``Y[r]``. Pass this list to
        ``deselect`` to undo the operation.
    """
    removed = []
    for col in Y[r]:
        for row in X[col]:
            for other in Y[row]:
                if other == col:
                    # The column itself is dropped wholesale just below.
                    continue
                X[other].remove(row)
        removed.append(X.pop(col))
    return removed
def deselect(X, Y, r, cols):
    """Undo a previous ``select(X, Y, r)``, restoring ``X`` in place.

    The columns of ``r`` are walked in reverse order. Each one gets back the
    set popped off the end of ``cols``, and the rows in that set are added
    again to every other column they cover.

    Args:
        X: Mapping from column to the set of rows covering it (mutated).
        Y: Mapping from row to the list of columns it covers (read only).
        r: Identifier of the row being withdrawn.
        cols: The list returned by the matching ``select`` call; it is
            consumed (popped from) by this function.
    """
    for col in reversed(Y[r]):
        restored = cols.pop()
        X[col] = restored
        for row in restored:
            for other in Y[row]:
                if other != col:
                    X[other].add(row)
def solve(X, Y, solution):
    """Lazily enumerate the exact covers reachable from the current state.

    Args:
        X: Mapping from each still-uncovered column to the set of rows that
            can cover it. Mutated during the search via ``select`` and
            ``deselect``.
        Y: Mapping from row to the list of columns it covers.
        solution: List of rows chosen so far. Rows are appended and popped
            as the search advances and backtracks.

    Yields:
        A new list holding a copy of ``solution`` each time no column is
        left in ``X``.
    """
    if not X:
        # No columns remain: the rows chosen so far form a cover.
        yield list(solution)
        return

    # Branch on the column with the fewest candidate rows.
    pivot = min(X, key=lambda col: len(X[col]))
    # Iterate over a snapshot: select/deselect mutate the sets held in X.
    for row in list(X[pivot]):
        solution.append(row)
        removed = select(X, Y, row)
        for found in solve(X, Y, solution):
            yield found
        deselect(X, Y, row, removed)
        solution.pop()
def solve_sudoku(size, grid):
    """Solve a sudoku by reducing it to an exact-cover problem.

    Args:
        size: ``(R, C)`` pair. The board has ``N = R * C`` rows and columns
            and uses the digits ``1..N``.
        grid: ``N`` lists of ``N`` ints. Falsy entries (such as ``0``) are
            treated as empty cells. The grid is filled in place.

    Yields:
        ``grid`` itself, once per solution found, after the cells of that
        solution have been written into it. The same list object is yielded
        every time.
    """
    box_rows, box_cols = size
    side = box_rows * box_cols
    indices = range(side)
    digits = range(1, side + 1)

    # Constraints: each cell filled once, each digit once per row,
    # once per column, and once per box.
    constraints = [("rc", cell) for cell in product(indices, indices)]
    for tag in ("rn", "cn", "bn"):
        constraints.extend((tag, pair) for pair in product(indices, digits))

    # Candidates: every (row, col, digit) placement and the four
    # constraints it satisfies.
    candidates = {}
    for row, col, digit in product(indices, indices, digits):
        box = (row // box_rows) * box_rows + (col // box_cols)  # Box number
        candidates[(row, col, digit)] = [
            ("rc", (row, col)),
            ("rn", (row, digit)),
            ("cn", (col, digit)),
            ("bn", (box, digit)),
        ]

    X, Y = exact_cover(constraints, candidates)

    # Apply the given digits before searching.
    for row, line in enumerate(grid):
        for col, digit in enumerate(line):
            if not digit:
                continue
            select(X, Y, (row, col, digit))

    for placement in solve(X, Y, []):
        for row, col, digit in placement:
            grid[row][col] = digit
        yield grid
def import_grids(path='data/pe096.txt'):
    """Load the sudoku puzzles stored in a text file.

    The file is read as consecutive 10-line records: one header line, which
    is ignored, followed by nine puzzle lines. Every character of a puzzle
    line (after stripping surrounding whitespace) is converted with
    ``int``.

    The number of records is derived from the file length rather than being
    fixed, so files with fewer or more puzzles than the default data set
    load correctly. Reading stops at the first record that is truncated or
    contains a blank line.

    Args:
        path: Location of the puzzle file. Defaults to ``'data/pe096.txt'``.

    Returns:
        A list of grids. Each grid is a list of nine rows, and each row is a
        list of ints.

    Raises:
        OSError: If ``path`` cannot be opened.
        ValueError: If a puzzle line contains a non-digit character.
    """
    with open(path) as f:
        lines = [line.strip() for line in f]

    grids = []
    # Index 0 of every 10-line record is the header; rows follow it.
    for start in range(1, len(lines), 10):
        rows = lines[start:start + 9]
        if len(rows) < 9 or not all(rows):
            # Truncated or blank trailing record: nothing more to read.
            break
        grids.append([[int(ch) for ch in row] for row in rows])
    return grids
def main():
    """Solve every loaded puzzle and total their top-left 3-digit numbers.

    Each grid returned by ``import_grids`` is solved as a ``(3, 3)`` sudoku.
    The first three cells of the solved grid's first row are read as a
    decimal number (hundreds, tens, ones) and accumulated.

    Returns:
        The sum of those three-digit numbers over all puzzles.
    """
    total = 0
    for puzzle in import_grids():
        solved = list(solve_sudoku((3, 3), puzzle))[0]
        top = solved[0]
        total += top[0] * 100 + top[1] * 10 + top[2]
    return total