diff --git a/navis/models/network_models.py b/navis/models/network_models.py index 20fd23e7..e8b050f5 100644 --- a/navis/models/network_models.py +++ b/navis/models/network_models.py @@ -24,7 +24,8 @@ # Set up logging logger = config.get_logger(__name__) -__all__ = ['BayesianTraversalModel', 'TraversalModel', 'linear_activation_p', 'random_linear_activation_function'] +__all__ = ['BayesianTraversalModel', 'TraversalModel', 'linear_activation_p', + 'random_linear_activation_function'] class BaseNetworkModel: @@ -39,6 +40,11 @@ def __init__(self, edges: pd.DataFrame, source: str, target: str): self.source = source self.target = target + if (self.edges.dtypes[source] == object) or (self.edges.dtypes[target] == object): + logger.warning('Looks like sources and/or targets in your edge list ' + 'might be strings? This can massively slow down ' + 'computations. If at all possible try to use numeric IDs.') + @property def n_nodes(self) -> int: """Return unique nodes in network.""" @@ -247,67 +253,87 @@ def run(self, iterations: int = 100, return_iterations=False, **kwargs) -> pd.Da # For some reason this is required for progress bars in Jupyter to show print(' ', end='', flush=True) # For faster access, use the raw array - edges = self.edges[[self.source, self.target, self.weights]].values + # Note: we're splitting the columns in case we have different datatypes + # (e.g. int64 for IDs and float for weights) + sources = self.edges[self.source].values + targets = self.edges[self.target].values + weights = self.edges[self.weights].values # For some reason the progress bar does not show unless we have a print here - all_trav = None + all_enc_nodes = None for it in config.trange(1, iterations + 1, disable=config.pbar_hide, leave=config.pbar_leave, position=kwargs.get('position', 0)): # Set seeds as encountered in step 1 - if not return_iterations: - enc = np.array([[1, s] for s in self.seeds]) - else: - enc = np.array([[1, s, it] for s in self.seeds]) + enc_nodes = self.seeds + enc_steps = np.repeat(1, len(self.seeds)) + if return_iterations: + enc_it = np.repeat(it, len(self.seeds)) # Start with all edges - this_edges = edges + this_weights = weights + this_sources = sources + this_targets = targets for i in range(2, self.max_steps + 1): # Which edges have their presynaptic node already traversed? - pre_trav = np.isin(this_edges[:, 0], enc[:, 1]) # 21 + pre_trav = np.isin(this_sources, enc_nodes) # Among those, which edges have the postsynaptic node traversed? - post_trav = np.isin(this_edges[pre_trav, 1], enc[:, 1]) # 63 + post_trav = np.isin(this_targets[pre_trav], enc_nodes) # Combine conditions to find edges where the presynaptic node # has been traversed but not the postsynaptic node pre_not_post = np.where(pre_trav)[0][~post_trav] - out_edges = this_edges[pre_not_post] + out_targets = this_targets[pre_not_post] + out_weights = this_weights[pre_not_post] # Drop edges that have already been traversed - speeds up things pre_and_post = np.where(pre_trav)[0][post_trav] - this_edges = np.delete(this_edges, pre_and_post, axis=0) + this_targets = np.delete(this_targets, pre_and_post, axis=0) + this_sources = np.delete(this_sources, pre_and_post, axis=0) + this_weights = np.delete(this_weights, pre_and_post, axis=0) # Stop if we traversed the entire (reachable) graph - if out_edges.size == 0: + if out_targets.size == 0: break - # Collect weights - w = out_edges[:, 2] - # Edges traversed in this round - trav_edges = out_edges[self.traversal_func(w)] + trav = self.traversal_func(out_weights) + if not trav.sum(): + continue + + trav_targets = out_targets[trav] # Keep track - if not trav_edges.size == 0: - new_trav = np.unique(trav_edges[:, 1]).astype(int) - if not return_iterations: - enc = np.concatenate((enc, [[i, b] for b in new_trav]), axis=0) - else: - enc = np.concatenate((enc, [[i, b, it] for b in new_trav]), axis=0) + new_trav = np.unique(trav_targets) + + # Store results + enc_nodes = np.concatenate((enc_nodes, new_trav)) + enc_steps = np.concatenate((enc_steps, np.repeat(i, len(new_trav)))) + if return_iterations: + enc_it = np.concatenate((enc_it, np.repeat(it, len(new_trav)))) # Save this round of traversal - if not isinstance(all_trav, np.ndarray): - all_trav = enc + if all_enc_nodes is None: + all_enc_nodes = enc_nodes + all_enc_steps = enc_steps + if return_iterations: + all_enc_it = enc_it else: - all_trav = np.concatenate((all_trav, enc), axis=0) + all_enc_nodes = np.concatenate((all_enc_nodes, enc_nodes)) + all_enc_steps = np.concatenate((all_enc_steps, enc_steps)) + if return_iterations: + all_enc_it = np.concatenate((all_enc_it, enc_it)) self.iterations = iterations - cols = ['steps', 'node'] + # Combine results into DataFrame + self.results = pd.DataFrame() + self.results['steps'] = all_enc_steps + self.results['node'] = all_enc_nodes if return_iterations: - cols += ['iteration'] - self.results = pd.DataFrame(all_trav, columns=cols).astype(int) + self.results['iteration'] = all_enc_it + return self.results