Skip to content

Commit

Permalink
slight update
Browse files Browse the repository at this point in the history
  • Loading branch information
xunzheng committed Feb 7, 2020
1 parent e713918 commit 06c84be
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
16 changes: 10 additions & 6 deletions experiments/expt_twovars.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,20 @@ def run_expt(num_graph, num_data_per_graph, n, d, s0, graph_type, sem_type, w_ra
for ii in tqdm(range(num_graph)):
B_true = utils.simulate_dag(d, s0, graph_type)
W_true = utils.simulate_parameter(B_true, w_ranges=w_ranges)
np.savetxt(f'{expt_name}/graph{ii:05}_W_true.csv', W_true, delimiter=',')
W_true_fn = os.path.join(expt_name, f'graph{ii:05}_W_true.csv')
np.savetxt(W_true_fn, W_true, delimiter=',')
for jj in range(num_data_per_graph):
X = utils.simulate_linear_sem(W_true, n, sem_type, noise_scale=noise_scale)
np.savetxt(f'{expt_name}/graph{ii:05}_data{jj:05}_X.csv', X, delimiter=',')
X_fn = os.path.join(expt_name, f'graph{ii:05}_data{jj:05}_X.csv')
np.savetxt(X_fn, X, delimiter=',')
# notears
W_est = notears.notears_linear_l1(X, lambda1=0, loss_type='l2')
assert utils.is_dag(W_est)
np.savetxt(f'{expt_name}/graph{ii:05}_data{jj:05}_W_notears.csv', W_est, delimiter=',')
W_notears = notears.notears_linear_l1(X, lambda1=0, loss_type='l2')
assert utils.is_dag(W_notears)
W_notears_fn = os.path.join(expt_name, f'graph{ii:05}_data{jj:05}_W_notears.csv')
np.savetxt(W_notears_fn, W_notears, delimiter=',')
# eval
acc = utils.count_accuracy(B_true, W_est != 0)
B_notears = (W_notears != 0)
acc = utils.count_accuracy(B_true, B_notears)
for metric in acc:
perf[metric].append(acc[metric])
# print stats
Expand Down
16 changes: 14 additions & 2 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ def _simulate_single_equation(X, w, scale):
def count_accuracy(B_true, B_est):
"""Compute various accuracy metrics for B_est.
true positive = predicted association exists in condition in correct direction
reverse = predicted association exists in condition in opposite direction
false positive = predicted association does not exist in condition
Args:
B_true (np.ndarray): [d, d] ground truth graph, {0, 1}
B_est (np.ndarray): [d, d] estimate, {0, 1, -1}, -1 is undirected edge in CPDAG
Expand All @@ -139,8 +143,16 @@ def count_accuracy(B_true, B_est):
shd: undirected extra + undirected missing + reverse
nnz: prediction positive
"""
if ((B_est == -1) & (B_est.T == -1)).any():
raise ValueError('undirected edge should only appear once')
if (B_est == -1).any(): # cpdag
if not ((B_est == 0) | (B_est == 1) | (B_est == -1)).all():
raise ValueError('B_est should take value in {0,1,-1}')
if ((B_est == -1) & (B_est.T == -1)).any():
raise ValueError('undirected edge should only appear once')
else: # dag
if not ((B_est == 0) | (B_est == 1)).all():
raise ValueError('B_est should take value in {0,1}')
if not is_dag(B_est):
raise ValueError('B_est should be a DAG')
d = B_true.shape[0]
# linear index of nonzeros
pred_und = np.flatnonzero(B_est == -1)
Expand Down

0 comments on commit 06c84be

Please sign in to comment.