forked from 00111011/HU-AI-Spring2024
-
Notifications
You must be signed in to change notification settings - Fork 0
/
csp.py
executable file
·148 lines (111 loc) · 4.87 KB
/
csp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from csp.utils import *
# ___ csp structure ___
class CSP:
"""constraints a function f(A, a, B, b) that returns true
if neighbors A, B satisfy the constraint when they have values A=a, B=b
"""
def __init__(self, variables: list, domains: dict, neighbors: dict, constraints):
variables = variables or list(domains.keys())
self.variables = variables
self.domains = domains
self.neighbors = neighbors
self.constraints = constraints
self.curr_domains = None
self.nassigns = 0
def assign(self, var, val, assignment):
assignment[var] = val
self.nassigns += 1
def unassign(self, var, assignment):
if var in assignment:
del assignment[var]
def nconflicts(self, var, val, assignment):
def conflict(var2):
return var2 in assignment and not self.constraints(var, val, var2, assignment[var2])
return count(conflict(v) for v in self.neighbors[var])
def display(self, assignment):
print(assignment)
def goal_test(self, state):
assignment = dict(state)
return (len(assignment) == len(self.variables)
and all(self.nconflicts(variable, assignment[variable], assignment) == 0
for variable in self.variables))
def support_pruning(self):
"""Make sure we can prune values from domains."""
if self.curr_domains is None:
self.curr_domains = {v: list(self.domains[v]) for v in self.variables}
def suppose(self, var, value):
"""Start accumulating inferences from assuming var=value."""
self.support_pruning()
removals = [(var, a) for a in self.curr_domains[var] if a != value]
self.curr_domains[var] = [value]
return removals
def prune(self, var, value, removals):
self.curr_domains[var].remove(value)
if removals is not None:
removals.append((var, value))
def choices(self, var):
"""Return all values for var that aren't currently ruled out."""
return (self.curr_domains or self.domains)[var]
def restore(self, removals):
"""Undo a supposition and all inferences from it."""
for B, b in removals:
self.curr_domains[B].append(b)
# ___ CSP Backtracking Search ___
# Variable ordering
def first_unassigned_variable(assignment, csp):
return first([var for var in csp.variables if var not in assignment])
def mrv(assignment, csp):
"""Minimum-remaining-values heuristic."""
return min([v for v in csp.variables if v not in assignment],
key=lambda var: num_legal_values(csp, var, assignment))
def num_legal_values(csp, var, assignment):
if csp.curr_domains:
return len(csp.curr_domains[var])
else:
return count(csp.nconflicts(var, val, assignment) == 0 for val in csp.domains[var])
# Value ordering
def unordered_domain_values(var, assignment, csp, probability=0.2):
if random.random() < probability:
return shuffled(csp.choices(var))
return csp.choices(var)
def lcv(var, assignment, csp, probability=0.0):
"""Least-constraining-values heuristic."""
if random.random() < probability:
return shuffled(csp.choices(var))
return sorted(csp.choices(var), key=lambda val: csp.nconflicts(var, val, assignment))
# Inference
def no_inference(csp, var, value, assignment, removals):
return True
def forward_checking(csp: CSP, var, value, assignment, removals):
"""Prune neighbor values inconsistent with var=value."""
csp.support_pruning()
for B in csp.neighbors[var]:
if B not in assignment:
for b in csp.curr_domains[B][:]:
if not csp.constraints(var, value, B, b):
csp.prune(B, b, removals)
if not csp.curr_domains[B]:
return False
return True
# The search
def backtracking_search(csp, assignment: dict={}, select_unassigned_variable=first_unassigned_variable,
order_domain_values=unordered_domain_values, inference=no_inference, probability=0.0):
def backtrack(assignment):
if len(assignment) == len(csp.variables):
return assignment
var = select_unassigned_variable(assignment, csp)
for value in order_domain_values(var, assignment, csp, probability):
if 0 == csp.nconflicts(var, value, assignment):
csp.assign(var, value, assignment)
removals = csp.suppose(var, value)
if inference(csp, var, value, assignment, removals):
result = backtrack(assignment)
if result is not None:
return result
csp.restore(removals)
csp.unassign(var, assignment)
return None
result = backtrack(assignment)
if not (result is None or csp.goal_test(result)):
return None
return result