TraversalModel: avoid changing source/target data type
- the old implementation led to an int -> float -> int conversion which
  caused 64-bit integer IDs to be mangled (see the sketch below)
- also add warning about non-numeric source/target IDs
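
For context, a minimal sketch (not part of the commit) of why pulling int64 ID
columns and a float weight column into one NumPy array mangles large IDs: the
mixed array is upcast to float64, which can only represent integers exactly up
to 2**53. Column names below are illustrative only.

    import pandas as pd

    # Toy edge list: 64-bit integer IDs above 2**53 plus a float weight column
    edges = pd.DataFrame({'source': [2**60 + 1],
                          'target': [2**60 + 3],
                          'weight': [0.5]})

    # Old approach: one mixed array -> NumPy upcasts everything to float64
    mixed = edges[['source', 'target', 'weight']].values
    print(mixed.dtype)                    # float64
    print(int(mixed[0, 0]) == 2**60 + 1)  # False: the ID was silently rounded

    # New approach: one array per column keeps each dtype intact
    sources = edges['source'].values      # stays int64
    print(int(sources[0]) == 2**60 + 1)   # True
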
schlegelp committed Aug 5, 2023
1 parent bd92718 commit dec0434
Showing 1 changed file with 55 additions and 29 deletions.
84 changes: 55 additions & 29 deletions navis/models/network_models.py
@@ -24,7 +24,8 @@
# Set up logging
logger = config.get_logger(__name__)

- __all__ = ['BayesianTraversalModel', 'TraversalModel', 'linear_activation_p', 'random_linear_activation_function']
+ __all__ = ['BayesianTraversalModel', 'TraversalModel', 'linear_activation_p',
+            'random_linear_activation_function']


class BaseNetworkModel:
@@ -39,6 +40,11 @@ def __init__(self, edges: pd.DataFrame, source: str, target: str):
self.source = source
self.target = target

+ if (self.edges.dtypes[source] == object) or (self.edges.dtypes[target] == object):
+     logger.warning('Looks like sources and/or targets in your edge list '
+                    'might be strings? This can massively slow down '
+                    'computations. If at all possible try to use numeric IDs.')

@property
def n_nodes(self) -> int:
"""Return unique nodes in network."""
@@ -247,67 +253,87 @@ def run(self, iterations: int = 100, return_iterations=False, **kwargs) -> pd.Da
# For some reason this is required for progress bars in Jupyter to show
print(' ', end='', flush=True)
# For faster access, use the raw array
- edges = self.edges[[self.source, self.target, self.weights]].values
+ # Note: we're splitting the columns in case we have different datatypes
+ # (e.g. int64 for IDs and float for weights)
+ sources = self.edges[self.source].values
+ targets = self.edges[self.target].values
+ weights = self.edges[self.weights].values

# For some reason the progress bar does not show unless we have a print here
- all_trav = None
+ all_enc_nodes = None
for it in config.trange(1, iterations + 1,
disable=config.pbar_hide,
leave=config.pbar_leave,
position=kwargs.get('position', 0)):
# Set seeds as encountered in step 1
- if not return_iterations:
-     enc = np.array([[1, s] for s in self.seeds])
- else:
-     enc = np.array([[1, s, it] for s in self.seeds])
+ enc_nodes = self.seeds
+ enc_steps = np.repeat(1, len(self.seeds))
+ if return_iterations:
+     enc_it = np.repeat(it, len(self.seeds))

# Start with all edges
- this_edges = edges
+ this_weights = weights
+ this_sources = sources
+ this_targets = targets
for i in range(2, self.max_steps + 1):
# Which edges have their presynaptic node already traversed?
- pre_trav = np.isin(this_edges[:, 0], enc[:, 1])  # 21
+ pre_trav = np.isin(this_sources, enc_nodes)
# Among those, which edges have the postsynaptic node traversed?
- post_trav = np.isin(this_edges[pre_trav, 1], enc[:, 1])  # 63
+ post_trav = np.isin(this_targets[pre_trav], enc_nodes)

# Combine conditions to find edges where the presynaptic node
# has been traversed but not the postsynaptic node
pre_not_post = np.where(pre_trav)[0][~post_trav]
- out_edges = this_edges[pre_not_post]
+ out_targets = this_targets[pre_not_post]
+ out_weights = this_weights[pre_not_post]

# Drop edges that have already been traversed - speeds up things
pre_and_post = np.where(pre_trav)[0][post_trav]
- this_edges = np.delete(this_edges, pre_and_post, axis=0)
+ this_targets = np.delete(this_targets, pre_and_post, axis=0)
+ this_sources = np.delete(this_sources, pre_and_post, axis=0)
+ this_weights = np.delete(this_weights, pre_and_post, axis=0)

# Stop if we traversed the entire (reachable) graph
- if out_edges.size == 0:
+ if out_targets.size == 0:
break

- # Collect weights
- w = out_edges[:, 2]

# Edges traversed in this round
- trav_edges = out_edges[self.traversal_func(w)]
+ trav = self.traversal_func(out_weights)
+ if not trav.sum():
+     continue
+
+ trav_targets = out_targets[trav]

# Keep track
- if not trav_edges.size == 0:
-     new_trav = np.unique(trav_edges[:, 1]).astype(int)
-     if not return_iterations:
-         enc = np.concatenate((enc, [[i, b] for b in new_trav]), axis=0)
-     else:
-         enc = np.concatenate((enc, [[i, b, it] for b in new_trav]), axis=0)
+ new_trav = np.unique(trav_targets)

+ # Store results
+ enc_nodes = np.concatenate((enc_nodes, new_trav))
+ enc_steps = np.concatenate((enc_steps, np.repeat(i, len(new_trav))))
+ if return_iterations:
+     enc_it = np.concatenate((enc_it, np.repeat(it, len(new_trav))))

# Save this round of traversal
- if not isinstance(all_trav, np.ndarray):
-     all_trav = enc
+ if all_enc_nodes is None:
+     all_enc_nodes = enc_nodes
+     all_enc_steps = enc_steps
+     if return_iterations:
+         all_enc_it = enc_it
else:
- all_trav = np.concatenate((all_trav, enc), axis=0)
+ all_enc_nodes = np.concatenate((all_enc_nodes, enc_nodes))
+ all_enc_steps = np.concatenate((all_enc_steps, enc_steps))
+ if return_iterations:
+     all_enc_it = np.concatenate((all_enc_it, enc_it))

self.iterations = iterations

- cols = ['steps', 'node']
+ # Combine results into DataFrame
+ self.results = pd.DataFrame()
+ self.results['steps'] = all_enc_steps
+ self.results['node'] = all_enc_nodes
if return_iterations:
-     cols += ['iteration']
- self.results = pd.DataFrame(all_trav, columns=cols).astype(int)
+     self.results['iteration'] = all_enc_it

return self.results


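The dtype warning added to BaseNetworkModel.__init__ recommends numeric IDs.
Below is a minimal sketch (not part of the commit) of one way to map string IDs
to integer codes before building the model; column names are illustrative only.

    import pandas as pd

    edges = pd.DataFrame({'source': ['neuron_a', 'neuron_b'],
                          'target': ['neuron_b', 'neuron_c'],
                          'weight': [5, 3]})

    # Factorize over the union of both columns so sources and targets share codes
    all_ids = pd.concat([edges['source'], edges['target']], ignore_index=True)
    codes, uniques = pd.factorize(all_ids)
    edges['source'] = codes[:len(edges)]
    edges['target'] = codes[len(edges):]
    # Keep 'uniques' to translate results back to the original string IDs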
