From 6c599a7ed4105935e03b145cece3267a58effea7 Mon Sep 17 00:00:00 2001
From: Chaoming Wang
Date: Tue, 28 Nov 2023 12:45:54 +0800
Subject: [PATCH] updates (#550)

* [running] fix multiprocessing bugs
* fix tests
* [doc] update doc
* update
* [math] add `brainpy.math.gpu_memory_preallocation()` for controlling GPU memory preallocation
* [math] `clear_buffer_memory` support to clear array and compilation both
* [dyn] compatible old version of `.reset_state()` function
* [setup] update installation info
---
 brainpy/_src/dynsys.py                 | 52 +++++++++++--------
 brainpy/_src/math/environment.py       | 48 ++++++++++++++---
 brainpy/_src/mixin.py                  | 13 -----
 brainpy/math/environment.py            |  1 +
 .../operator_custom_with_numba.ipynb   |  2 +-
 .../operator_custom_with_taichi.ipynb  | 11 +++-
 setup.py                               | 13 ++++-
 7 files changed, 93 insertions(+), 47 deletions(-)

diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py
index 00120a666..10d2de792 100644
--- a/brainpy/_src/dynsys.py
+++ b/brainpy/_src/dynsys.py
@@ -2,8 +2,8 @@
 
 import collections
 import inspect
-import warnings
 import numbers
+import warnings
 from typing import Union, Dict, Callable, Sequence, Optional, Any
 
 import numpy as np
@@ -13,7 +13,7 @@
 from brainpy._src.deprecations import _update_deprecate_msg
 from brainpy._src.initialize import parameter, variable_
 from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, _get_delay_tool
-from brainpy.errors import NoImplementationError, UnsupportedError, APIChangedError
+from brainpy.errors import NoImplementationError, UnsupportedError
 from brainpy.types import ArrayType, Shape
 
 __all__ = [
@@ -27,9 +27,9 @@
   'Dynamic',
   'Projection',
 ]
-
 IonChaDyn = None
 SLICE_VARS = 'slice_vars'
+the_top_layer_reset_state = True
 
 
 def not_implemented(fun):
@@ -138,16 +138,12 @@ def update(self, *args, **kwargs):
     """
     raise NotImplementedError('Must implement "update" function by subclass self.')
 
-  def reset(self, *args, include_self: bool = False, **kwargs):
+  def reset(self, *args, **kwargs):
     """Reset function which reset the whole variables in the model (including its children models).
 
     ``reset()`` function is a collective behavior which resets all states in this model.
 
     See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details.
-
-    Args::
-      include_self: bool. Reset states including the node self. Please turn on this if the node has
-        implemented its ".reset_state()" function.
     """
     from brainpy._src.helpers import reset_state
     reset_state(self, *args, **kwargs)
@@ -162,19 +158,6 @@ def reset_state(self, *args, **kwargs):
     """
     pass
 
-    # raise APIChangedError(
-    #   '''
-    #   From version >= 2.4.6, the policy of ``.reset_state()`` has been changed.
-    #
-    #   1. If you are resetting all states in a network by calling "net.reset_state()", please use
-    #      "bp.reset_state(net)" function. ".reset_state()" only defines the resetting of local states
-    #      in a local node (excluded its children nodes).
-    #
-    #   2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass.
-    #
-    #   '''
-    # )
-
   def clear_input(self, *args, **kwargs):
     """Clear the input at the current time step."""
     pass
@@ -344,14 +327,37 @@ def _compatible_update(self, *args, **kwargs):
         return ret
     return update_fun(*args, **kwargs)
 
+  def _compatible_reset_state(self, *args, **kwargs):
+    global the_top_layer_reset_state
+    the_top_layer_reset_state = False
+    try:
+      self.reset(*args, **kwargs)
+    finally:
+      the_top_layer_reset_state = True
+    warnings.warn(
+      '''
+  From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. See https://brainpy.tech/docs/tutorial_toolbox/state_saving_and_loading.html for details.
+
+  1. If you are resetting all states in a network by calling "net.reset_state(*args, **kwargs)", please use
+     "bp.reset_state(net, *args, **kwargs)" function, or "net.reset(*args, **kwargs)".
+     ".reset_state()" only defines the resetting of local states in a local node (excluding its children nodes).
+
+  2. If you do not customize the "reset_state()" function for a local node, please implement it in your subclass.
+
+      ''',
+      DeprecationWarning
+    )
+
   def _get_update_fun(self):
     return object.__getattribute__(self, 'update')
 
   def __getattribute__(self, item):
     if item == 'update':
       return self._compatible_update  # update function compatible with previous ``update()`` function
-    else:
-      return super().__getattribute__(item)
+    if item == 'reset_state':
+      if the_top_layer_reset_state:
+        return self._compatible_reset_state  # reset_state function compatible with previous ``reset_state()`` function
+    return super().__getattribute__(item)
 
   def __repr__(self):
     return f'{self.name}(mode={self.mode})'
diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py
index eef0361fc..b7a17bb9e 100644
--- a/brainpy/_src/math/environment.py
+++ b/brainpy/_src/math/environment.py
@@ -9,6 +9,7 @@
 import warnings
 from typing import Any, Callable, TypeVar, cast
 
+import jax
 from jax import config, numpy as jnp, devices
 from jax.lib import xla_bridge
 
@@ -682,7 +683,11 @@ def set_host_device_count(n):
   os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
 
 
-def clear_buffer_memory(platform=None):
+def clear_buffer_memory(
+    platform: str = None,
+    array: bool = True,
+    compilation: bool = False
+):
   """Clear all on-device buffers.
 
   This function will be very useful when you call models in a Python loop,
@@ -697,18 +702,47 @@
   ----------
   platform: str
     The device to clear its memory.
+  array: bool
+    Clear all buffer arrays.
+  compilation: bool
+    Clear the compilation cache.
+
   """
-  for buf in xla_bridge.get_backend(platform=platform).live_buffers():
-    buf.delete()
+  if array:
+    for buf in xla_bridge.get_backend(platform=platform).live_buffers():
+      buf.delete()
+  if compilation:
+    jax.clear_caches()
 
 
-def disable_gpu_memory_preallocation():
-  """Disable pre-allocating the GPU memory."""
+def disable_gpu_memory_preallocation(release_memory: bool = True):
+  """Disable pre-allocating the GPU memory.
+
+  This disables the preallocation behavior. JAX will instead allocate GPU memory as needed,
+  potentially decreasing the overall memory usage. However, this behavior is more prone to
+  GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory
+  may OOM with preallocation disabled.
+
+  Args:
+    release_memory: bool. Whether to release memory during the computation.
+ """ os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' - os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform' + if release_memory: + os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform' def enable_gpu_memory_preallocation(): """Disable pre-allocating the GPU memory.""" os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true' - os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR') + os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR', None) + + +def gpu_memory_preallocation(percent: float): + """GPU memory allocation. + + If preallocation is enabled, this makes JAX preallocate ``percent`` of the total GPU memory, + instead of the default 75%. Lowering the amount preallocated can fix OOMs that occur when the JAX program starts. + """ + assert 0. <= percent < 1., f'GPU memory preallocation must be in [0., 1.]. But we got {percent}.' + os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = str(percent) + diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index 8ea8a5216..fe7c39940 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -519,19 +519,6 @@ def __subclasscheck__(self, subclass): return all([issubclass(subclass, cls) for cls in self.__bases__]) -class UnionType2(MixIn): - """Union type for multiple types. - - >>> import brainpy as bp - >>> - >>> isinstance(bp.dyn.Expon(1), JointType[bp.DynamicalSystem, bp.mixin.ParamDesc, bp.mixin.SupportAutoDelay]) - """ - - @classmethod - def __class_getitem__(cls, types: Union[type, Sequence[type]]) -> type: - return _MetaUnionType('UnionType', types, {}) - - if sys.version_info.minor > 8: class _JointGenericAlias(_UnionGenericAlias, _root=True): def __subclasscheck__(self, subclass): diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py index a283cc921..d654a0217 100644 --- a/brainpy/math/environment.py +++ b/brainpy/math/environment.py @@ -30,6 +30,7 @@ clear_buffer_memory as clear_buffer_memory, enable_gpu_memory_preallocation as enable_gpu_memory_preallocation, disable_gpu_memory_preallocation as disable_gpu_memory_preallocation, + gpu_memory_preallocation as gpu_memory_preallocation, ditype as ditype, dftype as dftype, ) diff --git a/docs/tutorial_advanced/operator_custom_with_numba.ipynb b/docs/tutorial_advanced/operator_custom_with_numba.ipynb index 215d41418..b38cd0694 100644 --- a/docs/tutorial_advanced/operator_custom_with_numba.ipynb +++ b/docs/tutorial_advanced/operator_custom_with_numba.ipynb @@ -6,7 +6,7 @@ "collapsed": true }, "source": [ - "# Operator Customization with Numba" + "# CPU Operator Customization with Numba" ] }, { diff --git a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb index 183a8a251..0443aed9d 100644 --- a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb +++ b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb @@ -4,9 +4,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Operator Customization with Taichi" + "# CPU and GPU Operator Customization with Taichi" ] }, + { + "cell_type": "markdown", + "source": [ + "This functionality is only available for ``brainpylib>=0.2.0``. 
" + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, diff --git a/setup.py b/setup.py index 69c33cdfe..f867e3078 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,6 @@ # installation packages packages = find_packages(exclude=['lib*', 'docs', 'tests']) - # setup setup( name='brainpy', @@ -51,13 +50,23 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.8', - install_requires=['numpy>=1.15', 'jax', 'tqdm', 'msgpack', 'numba'], + install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'msgpack', 'numba'], url='https://github.com/brainpy/BrainPy', project_urls={ "Bug Tracker": "https://github.com/brainpy/BrainPy/issues", "Documentation": "https://brainpy.readthedocs.io/", "Source Code": "https://github.com/brainpy/BrainPy", }, + dependency_links=[ + 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html', + ], + extras_require={ + 'cpu': ['jaxlib>=0.4.13', 'brainpylib'], + 'cuda': ['jax[cuda]', 'brainpylib-cu11x'], + 'cuda11': ['jax[cuda11_local]', 'brainpylib-cu11x'], + 'cuda12': ['jax[cuda12_local]', 'brainpylib-cu12x'], + 'tpu': ['jax[tpu]'], + }, keywords=('computational neuroscience, ' 'brain-inspired computation, ' 'dynamical systems, '