diff --git a/sorn/sorn.py b/sorn/sorn.py index d34f96e..7bacc94 100644 --- a/sorn/sorn.py +++ b/sorn/sorn.py @@ -4,7 +4,7 @@ import numpy as np import os import random -import concurrent +import concurrent.futures try: from sorn.utils import Initializer @@ -372,44 +372,46 @@ def initialize_plasticity(): return wee, wei, wie, te, ti, x, y - class Aync: - def __init__(self, max_workers=4): - self.plasticity = Plasticity() - self.max_workers = max_workers - def step(self, X, Y, Wee, Wei, Te, freeze): +class Async: + def __init__(self, max_workers=4): + super().__init__() + self.max_workers = max_workers + self.plasticity = Plasticity() - with concurrent.futures.ThreadPoolExecutor( - max_workers=self.max_workers - ) as executor: + def step(self, X, Y, Wee, Wei, Te, freeze): - # STDP - if "stdp" not in freeze: - stdp = executor.submit( - self.plasticity.stdp, Wee, X, cutoff_weights=(0.0, 1.0) - ) - Wee = stdp.result() - - # Intrinsic plasticity - if "ip" not in freeze: - ip = executor.submit(self.plasticity.ip, Te, X) - Te = ip.result() - # Structural plasticity - if "sp" not in freeze: - sp = executor.submit(self.plasticity.structural_plasticity, Wee) - Wee = sp.result() - # iSTDP - if "istdp" not in freeze: - istdp = executor.submit( - self.plasticity.istdp, self.Wei, X, Y, cutoff_weights=(0.0, 1.0) - ) - Wei = istdp.result() + with concurrent.futures.ThreadPoolExecutor( + max_workers=self.max_workers + ) as executor: + + # STDP + if "stdp" not in freeze: + stdp = executor.submit( + self.plasticity.stdp, Wee, X, cutoff_weights=(0.0, 1.0) + ) + Wee = stdp.result() + + # Intrinsic plasticity + if "ip" not in freeze: + ip = executor.submit(self.plasticity.ip, Te, X) + Te = ip.result() + # Structural plasticity + if "sp" not in freeze: + sp = executor.submit(self.plasticity.structural_plasticity, Wee) + Wee = sp.result() + # iSTDP + if "istdp" not in freeze: + istdp = executor.submit( + self.plasticity.istdp, Wei, X, Y, cutoff_weights=(0.0, 1.0) + ) + Wei = istdp.result() - # Synaptic scaling Wee - if "ss" not in freeze: - Wee = self.plasticity.ss(Wee) - Wei = self.plasticity.ss(Wei) - return Wee, Wei, Te + # Synaptic scaling Wee + if "ss" not in freeze: + Wee = self.plasticity.ss(Wee) + Wei = self.plasticity.ss(Wei) + return Wee, Wei, Te class MatrixCollection(Sorn): @@ -909,10 +911,8 @@ def simulate_sorn( y_buffer[:, 0] = Y[i][:, 1] y_buffer[:, 1] = inhibitory_state_yt_buffer.T - Wee[i], Wei[i], Te[i] = ( - Plasticity() - .Async(max_workers=max_workers) - .step(x_buffer, y_buffer, Wee[i], Wei[i], Te[i], self.freeze) + Wee[i], Wei[i], Te[i] = Async(max_workers=max_workers).step( + x_buffer, y_buffer, Wee[i], Wei[i], Te[i], self.freeze ) # Assign the matrices to the matrix collections matrix_collection.weight_matrix(Wee[i], Wei[i], Wie[i], i) @@ -1017,7 +1017,7 @@ def train_sorn( Sorn.time_steps = time_steps self.inputs = np.asarray(inputs) self.freeze = [] if freeze == None else freeze - + self.max_workers = max_workers X_all = [0] * self.time_steps Y_all = [0] * self.time_steps R_all = [0] * self.time_steps @@ -1043,7 +1043,6 @@ def train_sorn( # Buffers to get the resulting x and y vectors at the current time step and update the master matrix x_buffer, y_buffer = np.zeros((Sorn.ne, 2)), np.zeros((Sorn.ni, 2)) - te_buffer, ti_buffer = np.zeros((Sorn.ne, 1)), np.zeros((Sorn.ni, 1)) Wee, Wei, Wie = ( matrix_collection.Wee, @@ -1077,11 +1076,10 @@ def train_sorn( y_buffer[:, 1] = inhibitory_state_yt_buffer.T if self.phase == "plasticity": - Wee[i], Wei[i], Te[i] = ( - Plasticity() - .Async(max_workers=self.max_workers) - .step(x_buffer, y_buffer, Wee[i], Wei[i], Te[i], self.freeze) + Wee[i], Wei[i], Te[i] = Async(max_workers=self.max_workers).step( + x_buffer, y_buffer, Wee[i], Wei[i], Te[i], self.freeze ) + else: # Wee[i], Wei[i], Te[i] remain same pass diff --git a/test_sorn.py b/test_sorn.py index 542b7d6..709dd92 100644 --- a/test_sorn.py +++ b/test_sorn.py @@ -40,7 +40,7 @@ def test_runsorn(self): matrices=None, time_steps=2, noise=True, - nu=num_features + nu=num_features, ), ) # Initilize and resume the simulation of SORN using the state dictionary, state_dict @@ -62,7 +62,8 @@ def test_runsorn(self): phase="plasticity", matrices=state_dict, time_steps=2, - noise=False, freeze=['ip'] + noise=False, + freeze=["ip"], ), ) @@ -74,7 +75,8 @@ def test_runsorn(self): phase="plasticity", matrices=state_dict, time_steps=2, - noise=False, freeze=['stdp','istdp','ss','sp'] + noise=False, + freeze=["stdp", "istdp", "ss", "sp"], ), ) @@ -97,7 +99,8 @@ def test_runsorn(self): phase="training", matrices=state_dict, time_steps=1, - noise=True, freeze=['stdp','istdp','ss','sp'] + noise=True, + freeze=["stdp", "istdp", "ss", "sp"], ), ) @@ -134,8 +137,7 @@ def test_runsorn(self): ) def test_plotter(self): - """Test the Plotter class methods in utils module - """ + """Test the Plotter class methods in utils module""" # Histogram of number of postsynaptic connections per neuron in the excitatory pool self.assertRaises( @@ -210,8 +212,7 @@ def test_plotter(self): ) def test_statistics(self): - """Test the functions in Statistics class - """ + """Test the functions in Statistics class""" # Firing rate of a neuron self.assertRaises( Exception, @@ -272,6 +273,6 @@ def test_statistics(self): ), ) -if __name__ == "__main__": - unittest.main(argv=['first-arg-is-ignored'], exit=False) +if __name__ == "__main__": + unittest.main(argv=["first-arg-is-ignored"], exit=False)