diff --git a/conda_requirements.txt b/conda_requirements.txt index f8f1f391..81034aa8 100644 --- a/conda_requirements.txt +++ b/conda_requirements.txt @@ -1,5 +1,5 @@ numpy -pytorch >= 1.9 +pytorch >= 2.0 torchtriton matplotlib numba @@ -8,4 +8,5 @@ ase h5py tqdm python-graphviz -lightning \ No newline at end of file +lightning +opt_einsum \ No newline at end of file diff --git a/docs/source/examples/controller.rst b/docs/source/examples/controller.rst index cc2d5016..c5d70cc5 100644 --- a/docs/source/examples/controller.rst +++ b/docs/source/examples/controller.rst @@ -2,11 +2,12 @@ Controller ========== How to define a controller for more customized control of the training process. -We assume that there is a set of ``training_modules`` assembled and a ``database`` object has been constructed. +We assume that there is a set of :class:`~hippynn.experiment.assembly.TrainingModules` assembled, called ``training_modules``, +and a :class:`~hippynn.databases.Database`-like object called ``database`` that has been constructed. The following snippet shows how to set up a controller using a custom scheduler or optimizer:: - from hippynn.experiment.controllers import RaiseBatchSizeOnPlateau,PatienceController + from hippynn.experiment.controllers import RaiseBatchSizeOnPlateau, PatienceController optimizer = torch.optim.Adam(training_modules.model.parameters(),lr=1e-3) diff --git a/docs/source/examples/ensembles.rst b/docs/source/examples/ensembles.rst index 6389bd50..11093d81 100644 --- a/docs/source/examples/ensembles.rst +++ b/docs/source/examples/ensembles.rst @@ -21,5 +21,6 @@ The ``ensemble_info`` object provides the counts for the inputs and targets of t and the counts of those corresponding quantities across the ensemble members. A typical use case would be to then build a Predictor or ASE Calculator from the ensemble. -See :file:`~examples/ensembling_models.py` for a detailed example. +See `/examples/ensembling_models.py`_ for a detailed example. +.. _/examples/ensembling_models.py: https://github.com/lanl/hippynn/blob/development/examples/ensembling_models.py diff --git a/docs/source/examples/plotting.rst b/docs/source/examples/plotting.rst index 8d1fdf4d..cfec981c 100644 --- a/docs/source/examples/plotting.rst +++ b/docs/source/examples/plotting.rst @@ -2,12 +2,13 @@ Plotting ======== -How to make a plotmaker. +:mod:`hippynn.plotting` is only available if matplotlib is installed. -Let's assume you have a ``molecule_energy`` node that you are training to. +By default, hippynn will plot loss metrics over time when training ends. +On top of this, hippynn can make diagnostic plots during its evaluation phase. +For example, Let's assume you have a ``molecule_energy`` node that you are training to. A simple plot maker would look like this:: - from hippynn import plotting plot_maker = hippynn.plotting.PlotMaker( @@ -19,7 +20,8 @@ A simple plot maker would look like this:: training_modules,db_info = assemble_for_training(train_loss, validation_losses, plot_maker=plot_maker) -The plot maker is thus passed to `assemble_for_training` and attached to the model evaluator. +The plot maker is thus passed to :func:`~hippynn.experiment.assemble_for_training` and attached to the model evaluator. + + -Note that :mod:`hippynn.plotting` is only available if matplotlib is installed. diff --git a/docs/source/examples/predictor.rst b/docs/source/examples/predictor.rst index cc204824..37f4e073 100644 --- a/docs/source/examples/predictor.rst +++ b/docs/source/examples/predictor.rst @@ -1,10 +1,10 @@ Predictor ========= -The predictor is a simple API for making predictions on an entire database. +The :class:`~hippynn.graphs.Predictor` is a class for making predictions on an entire database. Often you'll want to make predictions based on the model. For this, -use :meth:`Predictor.from_graph`. Let's assume you have a ``GraphModule`` called ``model``:: +use the :meth:`~hippynn.graphs.Predictor.from_graph`. method. Let's assume you have a :class:`~hippynn.GraphModule` called ``model``:: predictor = hippynn.graphs.Predictor.from_graph(model) diff --git a/docs/source/examples/restarting.rst b/docs/source/examples/restarting.rst index 18a00949..aa00afa7 100644 --- a/docs/source/examples/restarting.rst +++ b/docs/source/examples/restarting.rst @@ -117,7 +117,7 @@ Advanced Details - Here are a list of objects and their final device after loading. .. list-table:: - :widths: 40 30 + :widths: 30 70 :header-rows: 1 * - Objects diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 3538b678..abcfa8cf 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -7,7 +7,7 @@ Requirements Requirements: * Python_ >= 3.9 - * pytorch_ >= 1.9 + * pytorch_ >= 2.0 * numpy_ Optional Dependencies: @@ -20,6 +20,7 @@ Optional Dependencies: * graphviz_ (for visualizing model graphs) * h5py_ (for loading ani-h5 datasets) * pytorch-lightning_ (for distributed training) + * opt_einsum_ (backend for accelerating some pytorch expressions) Interfacing codes: * ASE_ @@ -41,7 +42,7 @@ Interfacing codes: .. _PYSEQM: https://github.com/lanl/PYSEQM .. _pytorch-lightning: https://github.com/Lightning-AI/pytorch-lightning .. _hippynn: https://github.com/lanl/hippynn/ - +.. _opt_einsum: https://github.com/dgasmith/opt_einsum Installation Instructions ^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/user_guide/ckernels.rst b/docs/source/user_guide/ckernels.rst index eb504da7..5b6f16e1 100644 --- a/docs/source/user_guide/ckernels.rst +++ b/docs/source/user_guide/ckernels.rst @@ -4,14 +4,106 @@ Custom Kernels Bottom line up front -------------------- -We use custom kernels in `hippynn` to accelerate the HIP-NN neural network message passing. -On the GPU, the best implementation to select is ``triton``, followed by ``cupy``, -followed by ``numba``. On the CPU, only ``numba`` is available. In general, these +If possible, install ``triton`` and ``numba``, as they will accelerate HIP-NN networks +and reduce memory cost on GPU and CPU, respectively. + + +Brief Description +----------------- + +We use custom kernels in hippynn to accelerate the HIP-NN neural network message passing and +to significantly reduce the amount of memory required in passing messages. +On the GPU, the best implementation to select is ``"triton"``, followed by ``"cupy"``, +followed by ``"numba"``. On the CPU, only ``"numba"`` is available. In general, these custom kernels are very useful, and the only reasons for them to be off is if are if the packages are not available for installation in your environment or if diagnosing whether or not a bug could be related to potential misconfiguration of these additional packages. -``triton`` comes with recent versions of ``pytorch``, so optimistically you may already be -configured to use the custom kernels. +``"triton"`` comes with recent versions of ``"pytorch"``, so optimistically you may already be +configured to use the custom kernels. Finally, there is the ``"sparse"`` implementation, which +uses torch.sparse functions. This saves memory much as the kernels from external packages, +however, it does not currently achieve a significant speedup over pytorch. + + +Comparison Table +---------------- + + +.. list-table:: Hippynn Custom Kernels Options Summary + :widths: 4 30 3 3 3 3 10 30 + :header-rows: 1 + + * - Name + - Description + - Low memory + - Speedup + - CPU + - GPU + - Required Packages + - Notes + * - pytorch + - Dense operations and index add operations + - No + - No + - Yes + - Yes + - None + - lowest overhead, gauranteed to run, but poorest performance + for large data + * - triton + - CSR-dense with OpenAI's triton compiler + using autotuning. + - Yes + - Excellent + - no + - yes + - triton + - Best option for GPU. Does incur some start-up lag due to autotuning. + * - numba + - CSR-dense hybrid with numba + - Yes + - Good + - Yes + - Yes + - numba + - Best option for CPU; non-CPU implementations fall back to this on CPU when available. + * - cupy + - CSR-dense hybrid with cupy/C code. + - Yes + - Great + - no + - yes + - cupy + - Direct translation of numba algorithm, but has improved performance. + * - sparse + - CSR-dense using torch.sparse operations. + - Yes + - None + - Yes + - Yes + - pytorch>=2.4 + - Cannot handle all systems, but raises an error on failure. + +.. note:: + Kernels which do not support the CPU fall back to numba if it is available, and + to pytorch if it is not. + +.. note:: + Custom Kernels do come with some launch overheads compared to the pytorch implementation. + If your workload is small (small batch sizes, networks, and/or small systems) + and you're using a GPU, then you may find best performance with kernels set to ``"pytorch"``. + +.. note:: + The sparse implementation is slow for very small workload sizes. At large workload + sizes, it is about as fast as pytorch (while using less memory), but still slower + than numba. + +.. note:: + The sparse implementation does not handle message-passing where atoms can appear + together in two or more sets of pairs due to small systems with periodic boundary conditions. + + +For information on how to set the custom kernels, see :doc:`settings` + Detailed Explanation -------------------- diff --git a/hippynn/_settings_setup.py b/hippynn/_settings_setup.py index 675fae91..c2393340 100644 --- a/hippynn/_settings_setup.py +++ b/hippynn/_settings_setup.py @@ -88,13 +88,12 @@ def kernel_handler(kernel_string): kernel = { "0": False, "false": False, - "pytorch": False, "1": True, "true": True, }.get(kernel_string, kernel_string) - if kernel not in [True, False, "auto", "triton", "cupy", "numba"]: - warnings.warn(f"Unexpected custom kernel setting: {kernel_string}.", stacklevel=3) + # This function used to warn about unexpected kernel settings. + # Now this is an error which is raised in the custom_kernels module. return kernel diff --git a/hippynn/custom_kernels/__init__.py b/hippynn/custom_kernels/__init__.py index 14bc5eda..1cdfa73c 100644 --- a/hippynn/custom_kernels/__init__.py +++ b/hippynn/custom_kernels/__init__.py @@ -13,20 +13,42 @@ On import, this module attempts to set the custom kernels as specified by the user in hippynn.settings. +See the :doc:`/user_guide/ckernels` section of the documentation for more information. +Depending on your available packages, you may have the following options: -.. .. autofunction:: envsum +* "pytorch": dense pytorch operations. +* "sparse": sparse pytorch operations. Can be faster than pure pytorch for large + enough systems, and will not require as much memory. + May require latest pytorch version. Cannot cover all circumstances; should error + if encountering a result that this implementation cannot cover.. +* "numba": numba implementation of custom kernels, beats pytorch-based kernels. +* "cupy": cupy implementation of custom kernels, better than numba +* "triton": triton-based custom kernels, uses auto-tuning and the triton compiler. + This is usually the best option. -.. .. autofunction:: sensesum +The available items are stored in the variable :data:`hippynn.custom_kernels.CUSTOM_KERNELS_AVAILABLE`. +The active implementation is stored in :data:`hippynn.custom_kernels.CUSTOM_KERNELS_ACTIVE`. -.. .. autofunction:: featsum +For more information, see :doc:`/user_guide/ckernels` """ +# Dev notes: +# For a new implmentation, make a new MessagePassingKernels object in your implementation file. +# This will register your implementation with the system. +# Then do the following: +# - Add your implementation name to _POSSIBLE_CUSTOM_KERNELS +# - Add an import block for the file to the populate_custom_kernels() function. +# If your custom kernel has constraints on what devices or hardware configurations are possible, +# the have your module raise an ImportError. You may also want to use +# warnings.warn to warn the user of why the implementation is not available. import warnings from typing import Union import torch from .. import settings -from . import autograd_wrapper, env_pytorch +from .registry import MessagePassingKernels +from . import env_pytorch + class CustomKernelError(Exception): @@ -45,184 +67,145 @@ def populate_custom_kernel_availability(): # check order for kernels is numba, cupy, triton. global CUSTOM_KERNELS_AVAILABLE - CUSTOM_KERNELS_AVAILABLE = [] - + # Look for CPU-capable kernels. try: import numba + from . import env_numba + from . import env_atomic + except ImportError: + pass - CUSTOM_KERNELS_AVAILABLE.append("numba") + try: + from . import env_sparse except ImportError: pass + # these kernels only work if torch can get to the GPU. if torch.cuda.is_available(): try: import cupy - - if "numba" not in CUSTOM_KERNELS_AVAILABLE: - warnings.warn("Cupy was found, but numba was not. Cupy custom kernels not available.") - else: - CUSTOM_KERNELS_AVAILABLE.append("cupy") + from . import env_cupy except ImportError: pass try: import triton - if torch.cuda.is_available(): - device_capability = torch.cuda.get_device_capability() - if device_capability[0] > 6: - CUSTOM_KERNELS_AVAILABLE.append("triton") - else: - warnings.warn( - f"Triton found but not supported by GPU's compute capability: {device_capability}" - ) + # Note: Device capability check is located in env_triton now. + from . import env_triton except ImportError: pass - + CUSTOM_KERNELS_AVAILABLE = list(MessagePassingKernels.get_available_implementations()) return CUSTOM_KERNELS_AVAILABLE -def _check_numba(): - import numba.cuda - import torch - if not numba.cuda.is_available(): - if torch.cuda.is_available(): - warnings.warn("numba.cuda.is_available() returned False: Custom kernels will fail on GPU tensors.") - return True - else: - # atexit.register(numba.cuda.close) - # Dev note for the future: Do not attempt the above `atexit` call! - # Causes segfault on program exit on some systems. - # Probably due to both numba and torch trying to finalize the GPU. - # Leaving this note here in case anyone is tempted to try it in the future. - # (At one point this was the right strategy...) - return False - - -def _check_cupy(): - import cupy - import numba - import torch - - if not cupy.cuda.is_available(): - if torch.cuda.is_available(): - warnings.warn("cupy.cuda.is_available() returned False: Custom kernels will fail on GPU tensors.") - return - -def set_custom_kernels(active: Union[bool, str] = True): +def set_custom_kernels(active: Union[bool, str] = True) -> str: """ Activate or deactivate custom kernels for interaction. This function changes the global variables: - :func:`hippynn.custom_kernels.envsum` - - :func:`hippynn.custom_kernels.sensum` + - :func:`hippynn.custom_kernels.sensesum` - :func:`hippynn.custom_kernels.featsum` - - :data:`hippynn.custom_kernels.CUSTOM_KERNELS_ACTIVE` - :param active: If true, set custom kernels to the best available. If False, turn them off and default to pytorch. - If "triton", "numba" or "cupy", use those implementations explicitly. If "auto", use best available. - :return: None + Special non-implementation-name values are: + - True: - Use the best GPU kernel from recommended implementations, error if none are available. + - False: - equivalent to "pytorch" + - "auto": - Equivalently to True if recommended is available, else equivalent to "pytorch" + + :param active: implementation name to activate + :return: active, actual implementation selected. """ + populate_custom_kernel_availability() + global envsum, sensesum, featsum, CUSTOM_KERNELS_ACTIVE + if active is False: + active = "pytorch" + if isinstance(active, str): active = active.lower() if active not in _POSSIBLE_CUSTOM_KERNELS: - raise CustomKernelError(f"Unrecognized custom kernel implementation: {active}") - - if not CUSTOM_KERNELS_AVAILABLE: + raise warnings.warn(f"Using non-standard custom kernel implementation: {active}") + + # Our goal is that this if-block is to handle the cases for values in the range of + # [True, "auto"] and turn them into the suitable actual implementation. + if any((impl in CUSTOM_KERNELS_AVAILABLE) for impl in _RECOMMENDED_CUSTOM_KERNELS): + # If recommended custom kernels are available, + # then True reverts to "auto" and "False" reverts to "pytorch". + if active is True: + active = "auto" + if active == "auto": + for impl_case in _RECOMMENDED_CUSTOM_KERNELS: + if impl_case in CUSTOM_KERNELS_AVAILABLE: + active = impl_case + break # exit the for loop, we found the best choice. + else: + # In this case, no recommended kernel is available. + # Use pytorch if active=="auto", and error if active==True. if active == "auto": warnings.warn( "triton, cupy and numba are not available: " "Custom kernels will be disabled and performance may be degraded.\n" "To silence this warning, set HIPPYNN_USE_CUSTOM_KERNELS=False", stacklevel=2) - - if active in ("auto", "pytorch"): # These are equivalent to "false" when custom kernels are not available. - active = False - elif active: - # The user explicitly set a custom kernel implementation or just True. + active = "pytorch" + elif active is True: + # The user explicitly set a custom kernel implementation to true, but no recommended ones. raise CustomKernelError( - "Triton, numba and cupy were not found." + - f"Custom kernels are not available, but they were required by library setting: {active}") - else: - # If custom kernels are available, then "auto" and "pytorch" revert to bool values. - active_map = {"auto": True, "pytorch": False} - active = active_map.get(active, active) - - # Handle fallback to pytorch kernels. - if not active: - envsum = env_pytorch.envsum - sensesum = env_pytorch.sensesum - featsum = env_pytorch.featsum - CUSTOM_KERNELS_ACTIVE = False - return - - # Select custom kernel implementation - if not CUSTOM_KERNELS_AVAILABLE: - raise CustomKernelError("Numba was not found. Custom kernels are not available.") - - if active is True: - if "triton" in CUSTOM_KERNELS_AVAILABLE: - active = "triton" - elif "cupy" in CUSTOM_KERNELS_AVAILABLE: - active = "cupy" - else: - active = "numba" - - if active not in CUSTOM_KERNELS_AVAILABLE: - raise CustomKernelError(f"Unavailable custom kernel implementation: {active}") - - if active == "triton": - from .env_triton import envsum as triton_envsum, sensesum as triton_sensesum, featsum as triton_featsum - - envsum, sensesum, featsum = autograd_wrapper.wrap_envops(triton_envsum, triton_sensesum, triton_featsum) - - elif active == "cupy": - _check_numba() - _check_cupy() - from .env_cupy import cupy_envsum, cupy_featsum, cupy_sensesum - - envsum, sensesum, featsum = autograd_wrapper.wrap_envops(cupy_envsum, cupy_sensesum, cupy_featsum) - - elif active == "numba": - _check_numba() - from .env_numba import new_envsum, new_featsum, new_sensesum - - envsum, sensesum, featsum = autograd_wrapper.wrap_envops(new_envsum, new_sensesum, new_featsum) - - else: - # We shouldn't get here except possibly mid-development, but just in case: - # if you add a custom kernel implementation remember to add to this - # dispatch block. - raise CustomKernelError(f"Unknown Implementation: {active}") - + "Triton, numba and cupy were not found. " + + f"Recommended custom kernels are not available, " + f"but they were required by library setting: {active}") + + # Ok, finally set the implementation. Note that get_implementation + # will error with type CustomKernelError if not found. + kernel_implementation = MessagePassingKernels.get_implementation(active) + envsum = kernel_implementation.envsum + sensesum = kernel_implementation.sensesum + featsum = kernel_implementation.featsum CUSTOM_KERNELS_ACTIVE = active - return + + return active CUSTOM_KERNELS_AVAILABLE = [] #: List of available kernel implementations based on currently installed packages.. -_POSSIBLE_CUSTOM_KERNELS = [True, False, "triton", "numba", "cupy", "pytorch", "auto"] +_POSSIBLE_CUSTOM_KERNELS = ( + True, + False, + "triton", + "numba", + "cupy", + "pytorch", + "sparse", + "auto", # This means, if possible, use order in _RECOMMENDED_CUSTOM_KERNELS below. +) + +# These are in order of preference! If you change the order, you change the default for "auto". +_RECOMMENDED_CUSTOM_KERNELS = ( + "triton", + "numba", + "cupy", +) try_custom_kernels = settings.USE_CUSTOM_KERNELS -CUSTOM_KERNELS_ACTIVE = None #: Which custom kernel implementation is currently active. +CUSTOM_KERNELS_ACTIVE = None #: Which custom kernel implementation is currently active. envsum = None #: See :func:`hippynn.custom_kernels.env_pytorch.envsum` for more information. sensesum = None #: See :func:`hippynn.custom_kernels.env_pytorch.sensesum` for more information. featsum = None #: See :func:`hippynn.custom_kernels.env_pytorch.featsum` for more information. try: - populate_custom_kernel_availability() set_custom_kernels(try_custom_kernels) except CustomKernelError as eee: - raise + raise # We re-raise custom kernel releated errors. except Exception as ee: warnings.warn(f"Custom kernels are disabled due to an unexpected error:\n" f"\t{ee}", stacklevel=2) del ee - + # Since we don't know what caused the error in the above, + # let's not re-call the function. envsum = env_pytorch.envsum sensesum = env_pytorch.sensesum featsum = env_pytorch.featsum diff --git a/hippynn/custom_kernels/autograd_wrapper.py b/hippynn/custom_kernels/autograd_wrapper.py index d7c3c96f..e29222a2 100644 --- a/hippynn/custom_kernels/autograd_wrapper.py +++ b/hippynn/custom_kernels/autograd_wrapper.py @@ -7,7 +7,7 @@ from contextlib import contextmanager _DEVICE_CONTEXT_LOCK = threading.Lock() -_DEVICE_TIMEOUT = 10 # if custom kernels have locked for 10s, throw an error +_DEVICE_TIMEOUT = 30 # if custom kernels have locked for 10s, throw an error @contextmanager @@ -123,3 +123,5 @@ def backward(ctx, grad_output): featsum = AGFeatsum.apply return envsum, sensesum, featsum + + diff --git a/hippynn/custom_kernels/env_atomic.py b/hippynn/custom_kernels/env_atomic.py new file mode 100644 index 00000000..b1cf749d --- /dev/null +++ b/hippynn/custom_kernels/env_atomic.py @@ -0,0 +1,183 @@ +""" +Atomic-operation version of custom kernels. + +These kernels are not recommended for actualy use; they exist +for benchmarking purposes. +""" +import functools + +import torch +import numba +import numba.cuda +import numpy as np + +from .tensor_wrapper import NumbaCompatibleTensorFunction, via_numpy +from .registry import MessagePassingKernels + + +# conventions: +# pidx : index of pair +# pfidx : index of first atom in pair (receiver) +# psidx : index of second atom in pair (sender) +# fidx : index of feature +# nidx : index of sensitivity (nu) + +# Kernel which sums sensitivities and features to get environment. +# Numpy core signature: (p,n),(a,f),(p),(p),(a,n,f) +class WrappedEnvsum(NumbaCompatibleTensorFunction): + def out_shape(self, sense_shape, feat_shape, *other_shapes): + n_pair, n_nu = sense_shape + n_atom, n_feature = feat_shape + return n_atom, n_nu, n_feature + + def launch_bounds(self, sense_shape, *other_shapes): + n_pairs, n_nu = sense_shape + TPB = 1024 + BPG = (n_pairs + TPB - 1) // TPB + return BPG, TPB + + @staticmethod + def make_kernel(KERNEL_DTYPE): + sig = "void({DTYPE}[:,:,],{DTYPE}[:,:],int64[:],int64[:],{DTYPE}[:,:,:])".format(DTYPE=KERNEL_DTYPE) + @numba.cuda.jit(sig) + def kernel(sens, feat, pfirst, psecond, env): + n_pairs, n_nu = sens.shape + n_atom, n_feat = feat.shape + + pidx, nidx, fidx = numba.cuda.grid(3) + + if pidx < n_pairs: + pfidx = pfirst[pidx] + psidx = psecond[pidx] + for nidx in range(n_nu): + s = sens[pidx, nidx] + if abs(s) > 1e-10: + for fidx in range(n_feat): + out = s * feat[psidx, fidx] + numba.cuda.atomic.add(env, (pfidx, nidx, fidx), out) + return kernel + + @staticmethod + @via_numpy + @numba.jit(parallel=True) + def cpu_kernel(sens, feat, pfirst, psecond): + n_pairs, n_nu = sens.shape + n_atom, n_feat = feat.shape + env_features = np.zeros((n_atom, n_nu, n_feat), dtype=sens.dtype) + for nidx in numba.prange(n_nu): + for pidx, (pfidx, psidx) in enumerate(zip(pfirst, psecond)): + s = sens[pidx, nidx] + if abs(s) > 1e-10: + for fidx in range(n_feat): + env_features[pfidx, nidx, fidx] += s * feat[psidx, fidx] + return env_features + + +# Kernel which sums environment and features to get sensitivity +#Numpy core signature: (a,n,f),(a,f),(p),(p),(p,n), +class WrappedSensesum(NumbaCompatibleTensorFunction): + def out_shape(self, env_shape, feat_shape, pfirst_shape, psecond_shape): + n_pair, = pfirst_shape + n_atom, n_nu, n_feat = env_shape + return n_pair, n_nu + + def launch_bounds(self, env_shape, feat_shape, pfirst_shape, psecond_shape): + n_pairs, = pfirst_shape + TPB = 1024 + BPG = (n_pairs + TPB - 1) // TPB + return BPG, TPB + + @staticmethod + def make_kernel(KERNEL_DTYPE): + sig = "void({DTYPE}[:,:,:],{DTYPE}[:,:],int64[:], int64[:],{DTYPE}[:,:])".format(DTYPE=KERNEL_DTYPE) + @numba.cuda.jit(sig) + def kernel(env, feat, pfirst, psecond, sense): + n_pairs, = pfirst.shape + n_atom, n_nu, n_feat = env.shape + pidx, nidx, fidx = numba.cuda.grid(3) + + if pidx < n_pairs: + pfidx = pfirst[pidx] + psidx = psecond[pidx] + for nidx in range(n_nu): + tmp = 0. + for fidx in range(n_feat): + tmp += env[pfidx, nidx, fidx] * feat[psidx, fidx] + numba.cuda.atomic.add(sense, (pidx, nidx), tmp) + + return kernel + + @staticmethod + @via_numpy + @numba.jit(parallel=True) + def cpu_kernel(env, feat, pfirst, psecond): + n_atom, n_nu, n_feat = env.shape + n_pairs, = pfirst.shape + sense = np.zeros((n_pairs, n_nu), dtype=env.dtype) + for nidx in numba.prange(n_nu): + for pidx in numba.prange(n_pairs): + pfidx = pfirst[pidx] + psidx = psecond[pidx] + for fidx in range(n_feat): + sense[pidx, nidx]+=env[pfidx, nidx, fidx] * feat[psidx, fidx] + return sense + +# Kernel which sums environment and sensitivity to get features +#Numpy core signature: (a,n,f),(p,n),(p),(p),(a,f), +class WrappedFeatsum(NumbaCompatibleTensorFunction): + def out_shape(self, env_shape, sense_shape, pfirst_shape, psecond_shape): + n_atom, n_nu, n_feature = env_shape + return n_atom, n_feature + + def launch_bounds(self, env_shape, sense_shape, pfirst_shape, psecond_shape): + n_pairs, n_nu = sense_shape + TPB = 1024 + BPG = (n_pairs + TPB - 1) // TPB + return BPG, TPB + + @staticmethod + def make_kernel(KERNEL_DTYPE): + sig = "void({DTYPE}[:,:,:],{DTYPE}[:,:],int64[:],int64[:],{DTYPE}[:,:])".format(DTYPE=KERNEL_DTYPE) + + @numba.cuda.jit(sig) + def kernel(env, sense, pfirst, psecond, feat): + n_pairs, n_nu = sense.shape + n_atom, n_feat = feat.shape + + pidx, nidx, fidx = numba.cuda.grid(3) + + if pidx < n_pairs: + pfidx = pfirst[pidx] + psidx = psecond[pidx] + for fidx in range(n_feat): + tmp = 0 + for nidx in range(n_nu): + tmp += env[pfidx, nidx, fidx] * sense[pidx, nidx] + numba.cuda.atomic.add(feat, (psidx, fidx), tmp) + + return kernel + + @staticmethod + @via_numpy + @numba.jit(parallel=False) + def cpu_kernel(env, sens, pfirst, psecond): + n_atom, n_nu, n_feat = env.shape + n_pairs, = pfirst.shape + feat = np.zeros((n_atom, n_feat), dtype=sens.dtype) + + for pidx, (pfidx, psidx) in enumerate(zip(pfirst, psecond)): + for nidx in range(n_nu): + for fidx in range(n_feat): + feat[psidx, fidx] += env[pfidx, nidx, fidx] * sens[pidx, nidx] + return feat + +atomic_envsum = WrappedEnvsum() +atomic_sensesum = WrappedSensesum() +atomic_featsum = WrappedFeatsum() + +numba_kernels = MessagePassingKernels( + "_numba_atomic", + atomic_envsum, + atomic_sensesum, + atomic_featsum, +) diff --git a/hippynn/custom_kernels/env_cupy.py b/hippynn/custom_kernels/env_cupy.py index 38c0ad49..44b6fa1c 100644 --- a/hippynn/custom_kernels/env_cupy.py +++ b/hippynn/custom_kernels/env_cupy.py @@ -1,15 +1,26 @@ """ -CuPy implementation of envsum custom kernels for GPU. +Cupy implementation of envsum custom kernels for GPU. """ -# Dev Note: CPU implementation of these ops is still done by numba. -# As such, numba is still required and calls to CPU ops must -# obey the same API as the numba implementations. - +import warnings import torch import cupy -from hippynn.custom_kernels.env_numba import WrappedEnvsum, WrappedSensesum, WrappedFeatsum -from hippynn.custom_kernels.utils import resort_pairs_cached +if not cupy.cuda.is_available(): + if torch.cuda.is_available(): + warnings.warn("Cupy is installed but cupy.cuda.is_available() returned False. " + "Custom kernels will most likely fail on GPU tensors. ") + +# If numba is available, this implementation will default to numba on CPU. If not, use vanilla pytorch. +try: + from .env_numba import new_envsum as envsum_alternative, new_sensesum as sensesum_alternative, new_featsum as featsum_alternative +except ImportError: + # Load backup implementation for CPU tensors. + from .env_pytorch import envsum as envsum_alternative, sensesum as sensesum_alternative, featsum as featsum_alternative + +from .env_numba import WrappedEnvsum, WrappedSensesum, WrappedFeatsum +from .utils import resort_pairs_cached + +from hippynn.custom_kernels import MessagePassingKernels CUPY_KERNEL_CODE = r""" extern "C" __global__ @@ -149,16 +160,17 @@ def __call__(self, dtype, BPG, TPB, array_args, shape_args): return out_array -class CupyEnvsum(CupyGPUKernel, WrappedEnvsum): +class CupyEnvsum(CupyGPUKernel): _cupy_name = "cupy_envsum" def __call__(self, sense, feat, pfirst, psecond): + dev = sense.device + if dev.type == "cpu": + return envsum_alternative(sense, feat, pfirst, psecond) + psecond_hold = psecond argsort, atom1_ids, atom1_starts, pfirst, (sense, psecond) = resort_pairs_cached(pfirst, [sense, psecond]) resort_pairs_cached(psecond_hold, []) - dev = sense.device - if dev.type == "cpu": - return self.cpu_kernel(sense, feat, pfirst, psecond, atom1_ids, atom1_starts) n_pairs, n_nu = sense.shape n_atoms, n_feat = feat.shape @@ -178,8 +190,7 @@ def __call__(self, sense, feat, pfirst, psecond): shape_args = n_nu, n_feat, n_interact if n_feat > 512: - raise ValueError(f"Numba GPU custom kernels are not compatible with feature sizes greater than 512 (got {n_feat})") - + raise ValueError(f"Cupy GPU custom kernels are not compatible with feature sizes greater than 512 (got {n_feat})") TPB_MAX = 512 TPB_X = n_feat @@ -193,13 +204,13 @@ def __call__(self, sense, feat, pfirst, psecond): return super().__call__(dtype, BPG, TPB, array_args, shape_args) -class CupySensesum(CupyGPUKernel, WrappedSensesum): +class CupySensesum(CupyGPUKernel): _cupy_name = "cupy_sensesum" def __call__(self, env, feat, pfirst, psecond): dev = env.device if dev.type == "cpu": - return self.cpu_kernel(env, feat, pfirst, psecond) + return sensesum_alternative(env, feat, pfirst, psecond) (n_pairs,) = pfirst.shape n_atoms, n_nu, n_feat = env.shape @@ -210,7 +221,7 @@ def __call__(self, env, feat, pfirst, psecond): shape_args = n_pairs, n_nu, n_feat if n_nu > 512: - raise ValueError(f"Numba GPU custom kernels are not compatible with sensitivity sizes greater than 512 (got {n_nu})") + raise ValueError(f"Cupy GPU custom kernels are not compatible with sensitivity sizes greater than 512 (got {n_nu})") TPB_MAX = 512 TPB_Y = n_nu @@ -222,16 +233,17 @@ def __call__(self, env, feat, pfirst, psecond): return super().__call__(dtype, BPG, TPB, array_args, shape_args) -class CupyFeatsum(CupyGPUKernel, WrappedFeatsum): +class CupyFeatsum(CupyGPUKernel): _cupy_name = "cupy_featsum" def __call__(self, env, sense, pfirst, psecond): + dev = env.device + if dev.type == "cpu": + return featsum_alternative(env, sense, pfirst, psecond) + pfirst_hold = pfirst argsort, atom2_ids, atom2_starts, psecond, (sense, pfirst) = resort_pairs_cached(psecond, [sense, pfirst]) resort_pairs_cached(pfirst_hold, []) - dev = env.device - if dev.type == "cpu": - return self.cpu_kernel(env, sense, pfirst, psecond, atom2_ids, atom2_starts) (n_pairs,) = pfirst.shape n_atoms, n_nu, n_feat = env.shape @@ -265,3 +277,10 @@ def __call__(self, env, sense, pfirst, psecond): cupy_envsum = CupyEnvsum() cupy_sensesum = CupySensesum() cupy_featsum = CupyFeatsum() + +cupy_kernels = MessagePassingKernels( + "cupy", + cupy_envsum, + cupy_sensesum, + cupy_featsum, +) diff --git a/hippynn/custom_kernels/env_numba.py b/hippynn/custom_kernels/env_numba.py index 1b840767..0d502462 100644 --- a/hippynn/custom_kernels/env_numba.py +++ b/hippynn/custom_kernels/env_numba.py @@ -1,15 +1,26 @@ """ Numba implementation of envsum operations. """ +# Dev note for the future: Do not attempt the `atexit` call: +# >>> atexit.register(numba.cuda.close) +# Causes segfault on program exit on some systems. +# Probably due to both numba and torch trying to finalize the GPU. +# Leaving this note here in case anyone is tempted to try it in the future. +# (At one point in history, this was the right strategy.) +import warnings +import torch import numba import numba.cuda import numpy as np from .utils import resort_pairs_cached from .tensor_wrapper import via_numpy, NumbaCompatibleTensorFunction +from .registry import MessagePassingKernels -# Very basic implementation. -# While simple, it beats a set of pytorch operations simply by using far less memory. +if not numba.cuda.is_available(): + if torch.cuda.is_available(): + warnings.warn("Numba is installed but numba.cuda.is_available() returned False. " + "Custom kernels will most likely fail on GPU tensors. ") # conventions: # pidx : index of pair @@ -18,7 +29,6 @@ # fidx : index of feature # nidx : index of sensitivity (nu) - # Kernel which sums sensitivities and features to get environment. # Numpy core signature: (p,n),(a,f),(p),(p),(a,n,f) class WrappedEnvsum(NumbaCompatibleTensorFunction): @@ -281,3 +291,10 @@ def cpu_kernel( new_envsum = WrappedEnvsum() new_sensesum = WrappedSensesum() new_featsum = WrappedFeatsum() + +numba_kernels = MessagePassingKernels( + "numba", + new_envsum, + new_sensesum, + new_featsum, +) diff --git a/hippynn/custom_kernels/env_pytorch.py b/hippynn/custom_kernels/env_pytorch.py index 6724b8fd..c6c4013e 100644 --- a/hippynn/custom_kernels/env_pytorch.py +++ b/hippynn/custom_kernels/env_pytorch.py @@ -7,6 +7,7 @@ """ import torch from torch import Tensor +from .registry import MessagePassingKernels def envsum(sensitivities: Tensor, features: Tensor, pair_first: Tensor, pair_second: Tensor) -> Tensor: @@ -48,7 +49,8 @@ def sensesum(env, features, pair_first, pair_second): n_atoms, n_nu, n_feat = env.shape pair_env = env[pair_first] pair_feat = features[pair_second] - sense = (pair_env * pair_feat.unsqueeze(1)).sum(dim=2) + # sense = torch.einsum("psf,pf->ps", pair_env, pair_feat) # einsum notation; should be completely equivalent + sense = torch.bmm(pair_env, pair_feat.unsqueeze(2)).squeeze(2) # bmm return sense @@ -64,6 +66,81 @@ def featsum(env, sense, pair_first, pair_second): See the :doc:`/user_guide/ckernels` section of the documentation for more information. + :param env: (n_atoms, n_sensitivities, n_features) floating tensor + :param sense: (n_pairs, n_sensitivities) floating point tensor + :param pair_first: (n_pairs,) index tensor indicating first atom of pair + :param pair_second: (n_pairs,) index tensor indicating second atom of pair + :return: feat (n_atoms, n_features) floating point tensor + """ + n_atoms, n_nu, n_feat = env.shape + pair_env = env[pair_first] + # pair_feat = torch.einsum("psf,ps->pf", pair_env, sense) # einsum notation; should be completely equivalent + pair_feat = torch.bmm(sense.unsqueeze(1), pair_env).squeeze(1) # bmm + feat = torch.zeros(n_atoms, n_feat, device=env.device, dtype=env.dtype) + feat.index_add_(0, pair_second, pair_feat) + return feat + + +def _envsum_legacy(sensitivities: Tensor, features: Tensor, pair_first: Tensor, pair_second: Tensor) -> Tensor: + """ + Original envsum implementation. + + Computes outer product of sensitivities of pairs and atom features from pair_second, + whilst accumulating them onto indices pair_first. + + See the :doc:`/user_guide/ckernels` section of the documentation for more information. + + :param sensitivities: (n_pairs, n_sensitivities) floating point tensor + :param features: (n_atoms, n_features) floating point tensor + :param pair_first: (n_pairs,) index tensor indicating first atom of pair + :param pair_second: (n_pairs,) index tensor indicating second atom of pair + :return: env (n_atoms, n_sensitivities, n_features) floating tensor + """ + n_pairs, n_nu = sensitivities.shape + n_atom, n_feat = features.shape + pair_features = features[pair_second].unsqueeze(1) + sensitivities = sensitivities.unsqueeze(2) + pair_env_features = sensitivities * pair_features + env_features = torch.zeros((n_atom, n_nu, n_feat), device=features.device, dtype=features.dtype) + env_features.index_add_(0, pair_first, pair_env_features) + return env_features + + +def _sensesum_legacy(env, features, pair_first, pair_second): + """ + Original sensesum implementation. + + Computes product of environment at pair_first with features from pair_second, + whilst summing over feature indices. + + See the :doc:`/user_guide/ckernels` section of the documentation for more information. + + :param env: (n_atoms, n_sensitivities, n_features) floating tensor + :param features: (n_atoms, n_features) floating point tensor + :param pair_first: (n_pairs,) index tensor indicating first atom of pair + :param pair_second: (n_pairs,) index tensor indicating second atom of pair + :return: sense (n_pairs, n_sensitivities) floating point tensor + """ + n_atoms, n_nu, n_feat = env.shape + pair_env = env[pair_first] + pair_feat = features[pair_second] + sense = (pair_env * pair_feat.unsqueeze(1)).sum(dim=2) + return sense + + +def _featsum_legacy(env, sense, pair_first, pair_second): + """ + Original featsum implementation. + + Compute inner product of sensitivities with environment tensor over atoms + from pair_first, while accumulating them on to pair_second. + + The summation order is different from envsum because this + signature naturally supports the use of featsum as a backwards pass + for envsum, and vise-versa. + + See the :doc:`/user_guide/ckernels` section of the documentation for more information. + :param env: (n_atoms, n_sensitivities, n_features) floating tensor :param sense: (n_pairs, n_sensitivities) floating point tensor :param pair_first: (n_pairs,) index tensor indicating first atom of pair @@ -76,3 +153,55 @@ def featsum(env, sense, pair_first, pair_second): feat = torch.zeros(n_atoms, n_feat, device=env.device, dtype=env.dtype) feat.index_add_(0, pair_second, pair_feat) return feat + + +# Note: torch.compile functions always need to be wrapped because +# at least at the moment, AOT autograd does not allow double-backwards passes. + +old_kernels = MessagePassingKernels( + "_legacy", + _envsum_legacy, _sensesum_legacy, _featsum_legacy, + wrap=False, # Important distinction! +) + +old_kernels_jit = MessagePassingKernels( + "_legacy_jit", + _envsum_legacy, _sensesum_legacy, _featsum_legacy, + compiler=torch.jit.script, +) + +old_kernels_compile = MessagePassingKernels( + "_legacy_compile", + _envsum_legacy, _sensesum_legacy, _featsum_legacy, + compiler=torch.compile, +) + +pytorch_kernels_raw = MessagePassingKernels( + "_pytorch_raw", + envsum, + sensesum, + featsum, + wrap=False, # Important distinction! +) + +pytorch_kernels_wrapped = MessagePassingKernels( + "_pytorch_raw_wrapped", + envsum, sensesum, featsum, +) + +pytorch_kernels_jit = MessagePassingKernels( + "_pytorch_jit", + envsum, + sensesum, + featsum, + compiler=torch.jit.script, +) + +pytorch_kernels_compile = MessagePassingKernels( + "pytorch", + envsum, sensesum, featsum, + compiler=torch.compile, +) + + + diff --git a/hippynn/custom_kernels/env_sparse.py b/hippynn/custom_kernels/env_sparse.py new file mode 100644 index 00000000..18ddb886 --- /dev/null +++ b/hippynn/custom_kernels/env_sparse.py @@ -0,0 +1,132 @@ +""" +Pure pytorch implementation of envsum operations +""" +import torch + +from .registry import MessagePassingKernels + + +# TODO: Does resort_pairs_cached give enough to allow direct construction of CSR? + + +def make_sparse_sense(sensitivities, pair_first, pair_second, n_atom: int): + """ + Construct sensitivities as a sparse matrix with shape + (n_atoms * n_nu, n_atoms). + + The n_atoms * n_nu is needed because of limitations in the hybrid-sparse + matrix multiply routines available in pytorch. This function + is implemented seprately because both envsum and sensesum use it. + The to_csr call is done afterwards because envsum needs to be + transposed. + + :param sensitivities: + :param pair_first: + :param pair_second: + :param n_atom: + :return: + """ + n_pairs, n_nu = sensitivities.shape + + pf_unsqueeze = pair_first.unsqueeze(1).expand(n_pairs, n_nu) + nu_range = torch.arange(n_nu, device=pair_first.device).unsqueeze(0) + first_index = (pf_unsqueeze * n_nu + nu_range).flatten() + second_index = pair_second.unsqueeze(1).expand(n_pairs, n_nu).flatten() + + indices = torch.stack([first_index, second_index]) + sparse_sense = torch.sparse_coo_tensor( + values=sensitivities.flatten(), + indices=indices, + size=(n_atom * n_nu, n_atom), + dtype=sensitivities.dtype, + device=sensitivities.device) + + return sparse_sense + + +def envsum(sensitivities, features, pair_first, pair_second): + n_pairs, n_nu = sensitivities.shape + n_atom, n_feat = features.shape + + sparse_sense = make_sparse_sense(sensitivities, pair_first, pair_second, n_atom) + sparse_sense = sparse_sense.to_sparse_csr() + + env = torch.mm(sparse_sense, features).reshape(n_atom, n_nu, n_feat) + + return env + + +def sensesum(env, features, pair_first, pair_second): + """ + + Sparse sensesum implementation uses a sparsity matrix + of zeros with shape (n_atoms x n_atoms), combined with + the crucial function torch.sparse sampled_addmm. + + :param env: + :param features: + :param pair_first: + :param pair_second: + :return: + """ + + n_atoms, n_nu, n_feat = env.shape + + indices = torch.stack([pair_first, pair_second]) + sparse_pairs = torch.sparse_coo_tensor( + values=torch.zeros_like(pair_first, dtype=env.dtype), + indices=indices, + size=(n_atoms, n_atoms), + dtype=env.dtype, + device=pair_first.device) + + sparse_pairs = sparse_pairs.to_sparse_csr() + + env_rs = env.permute(1, 0, 2) # Putting sensitivity index first + feat_rs = features.transpose(0, 1) # 2D transpose + feat_rs = feat_rs.unsqueeze(0).expand(n_nu, -1, -1) # feat needs same batch size + sense_sparse = torch.sparse.sampled_addmm(sparse_pairs, env_rs, feat_rs) + sense_sparse = sense_sparse.to_sparse_coo() + + sense_values = sense_sparse.values() + + # This will error if the same pair appears twice. + try: + n_pair, = pair_first.shape + sense_values = sense_values.reshape(n_nu, n_pair) + except RuntimeError as ee: + raise ValueError( + f"Sensitivity values shape changed. Likely more than one pair entry " + f"connecting the same atoms. The 'sparse' implementation custom kernels do not support " + f"pair lists containing duplicate items. " + f"Input shape: {n_pair} Output shape: {sense_values.shape[0] // n_nu}") from ee + + sense_values = sense_values.transpose(0, 1) + + # Note: indices emerge sorted. If we have not sorted by pfirst and psecond, + # we must then invert how they appear. + pair_rank_array = n_atoms * pair_first + pair_second + inverse_order = torch.argsort(torch.argsort(pair_rank_array)) + sense_values = sense_values[inverse_order] + + return sense_values + + +def featsum(env, sense, pair_first, pair_second): + n_atoms, n_nu, n_feat = env.shape + + sparse_sense = make_sparse_sense(sense, pair_first, pair_second, n_atoms) + sparse_sense = sparse_sense.transpose(0, 1) + sparse_sense = sparse_sense.to_sparse_csr() + + feat = torch.mm(sparse_sense, env.reshape(n_atoms * n_nu, n_feat)) + return feat + + +sparse_kernels = MessagePassingKernels( + "sparse", + envsum, sensesum, featsum, +) + +# Note: no sparse_jit because try/except not supported. +# to sparse_compile because it won't transpose a matrix?? \ No newline at end of file diff --git a/hippynn/custom_kernels/env_triton.py b/hippynn/custom_kernels/env_triton.py index 266318d7..e5ee28ca 100644 --- a/hippynn/custom_kernels/env_triton.py +++ b/hippynn/custom_kernels/env_triton.py @@ -1,7 +1,12 @@ +""" +triton implementation of envsum custom kernels for GPU. +""" +import warnings import torch import triton import triton.language as tl from .utils import resort_pairs_cached +from .registry import MessagePassingKernels # If numba is available, this implementation will default to numba on CPU. If not, use vanilla pytorch. try: @@ -10,11 +15,23 @@ # Load backup implementation for CPU tensors. from .env_pytorch import envsum as envsum_alternative, sensesum as sensesum_alternative, featsum as featsum_alternative + +if torch.cuda.is_available(): + device_capability = torch.cuda.get_device_capability() + if not device_capability[0] > 6: + msg = f"`triton` package found, but does not support GPU's compute capability: {device_capability}" + # First warn, then error, because: + # - the warning should be seen by the user. + # - The error is caught by the __init__ module and uses this as a signal not to include + # 'triton' as an available implementation + warnings.warn(msg, stacklevel=2) + raise ImportError(msg) + + def config_pruner(configs, nargs, **kwargs): """ Trims the unnecessary config options based on the sens. and feat. sizes """ - #print("For some reason the config pruner also gets arguments:",kwargs) p2_sens_size = triton.next_power_of_2(nargs["sens_size"]) p2_feat_size = triton.next_power_of_2(nargs["feat_size"]) @@ -40,6 +57,7 @@ def config_pruner(configs, nargs, **kwargs): num_warps=config.num_warps, ) + def get_autotune_config(): """ Create a list of config options for the kernels @@ -325,3 +343,11 @@ def featsum(env, sense, pfirst, psecond): argsort, atom2_ids, atom2_starts, psecond, (sense, pfirst) = resort_pairs_cached(psecond, [sense, pfirst]) resort_pairs_cached(pfirst_hold, []) # preemptively sort (probably no-op) return featsum_triton(env, sense, pfirst, psecond, atom2_ids, atom2_starts, out_feat=None) + + +triton_kernels = MessagePassingKernels( + "triton", + envsum, + sensesum, + featsum, +) diff --git a/hippynn/custom_kernels/fast_convert.py b/hippynn/custom_kernels/fast_convert.py index 119e5f04..f8b55ed6 100644 --- a/hippynn/custom_kernels/fast_convert.py +++ b/hippynn/custom_kernels/fast_convert.py @@ -1,5 +1,7 @@ """ -This module implements a version of converting +This module implements a version of converting from pytorch tensors +to numba DeviceNDArrays that skips much of the indirection that takes place +in the numba implementation. Note: This is not entirely API safe as numba has not exposed all of these functions directly. diff --git a/hippynn/custom_kernels/registry.py b/hippynn/custom_kernels/registry.py new file mode 100644 index 00000000..3ba46641 --- /dev/null +++ b/hippynn/custom_kernels/registry.py @@ -0,0 +1,63 @@ +from hippynn.custom_kernels.autograd_wrapper import wrap_envops + + +class MessagePassingKernels: + _registered_implementations = {} # Registry for custom kernel implementations. + + def __init__(self, impl_name: str, envsum_impl, sensesum_impl, featsum_impl, wrap=True, + compiler=None,): + """ + :param impl_name: name for implementation. + :param envsum_impl: non-autograd-wrapped envsum implementation + :param sensesum_impl: non-autograd-wrapped sensesum implementation + :param featsum_impl: non-autograd-wrapped featsum implementation + :param wrap: set to false if implementations are already autograd-capable. + """ + + if compiler is not None: + envsum_impl, sensesum_impl, featsum_impl = \ + map(compiler, (envsum_impl, sensesum_impl, featsum_impl)) + + self.envsum_impl = envsum_impl + self.sensesum_impl = sensesum_impl + self.featsum_impl = featsum_impl + + if wrap: + envsum, sensesum, featsum = wrap_envops(envsum_impl, sensesum_impl, featsum_impl) + else: + envsum, sensesum, featsum = envsum_impl, sensesum_impl, featsum_impl + + self.envsum = envsum + self.sensesum = sensesum + self.featsum = featsum + + impl_name = impl_name.lower() + if impl_name in self._registered_implementations: + raise ValueError(f"Already have implementation of kernels named {impl_name}!") + else: + self._registered_implementations[impl_name] = self + + @classmethod + def get_implementation(cls, impl_name): + """ + + :param impl_name: + :return: + :raises CustomKernelError if implementation is not available or known. + """ + from . import CustomKernelError + try: + impl = cls._registered_implementations[impl_name.lower()] + except KeyError: + raise CustomKernelError(f"Unavailable custom kernel implementation: {impl_name}") + return impl + + @classmethod + def get_available_implementations(self, hidden=False): + """ + Return the available implementations of the custom kernels. + + :param hidden: Show all implementations, even those which have no improved performance characteristics. + :return: + """ + return [k for k in self._registered_implementations.keys() if not k.startswith("_")] diff --git a/hippynn/custom_kernels/test_env_numba.py b/hippynn/custom_kernels/test_env.py similarity index 65% rename from hippynn/custom_kernels/test_env_numba.py rename to hippynn/custom_kernels/test_env.py index 616a2eb8..15940077 100644 --- a/hippynn/custom_kernels/test_env_numba.py +++ b/hippynn/custom_kernels/test_env.py @@ -1,15 +1,17 @@ """ Test module for verifying implementation correctness against pytorch. """ +import platform +import warnings +import time + import numpy as np import torch from . import env_pytorch -from . import autograd_wrapper from .utils import clear_pair_cache -import warnings - +from .registry import MessagePassingKernels try: from . import env_numba @@ -26,9 +28,17 @@ except ImportError: warnings.warn("triton implementation not importable.") env_triton = None +try: + from . import env_sparse +except ImportError: + warnings.warn("sparse implementation not importable.") + env_sparse = None -def get_simulated_data(n_molecules, n_atoms, atom_prob, n_features, n_nu, printinfo=False, dtype=None, device=torch.device("cpu")): +def get_simulated_data( + n_molecules, n_atoms, atom_prob, n_features, n_nu, printinfo=False, dtype=None, device=torch.device("cpu"), randomize_order=False, + use_duplicate=False, +): """ Get semi-realistic test data for hipnn. n_molecules : number of molecules in the batch @@ -85,17 +95,31 @@ def get_simulated_data(n_molecules, n_atoms, atom_prob, n_features, n_nu, printi ) atom_index = np.arange(n_molecules * n_atoms).reshape(molatom_shp) + nonzero_pair = np.nonzero(pair_presence) # We'll just do fully connected molecules. - pair_first_pre = np.repeat(atom_index[:, :, np.newaxis], n_atoms, axis=2)[np.nonzero(pair_presence)] - pair_second_pre = np.repeat(atom_index[:, np.newaxis, :], n_atoms, axis=1)[np.nonzero(pair_presence)] + pair_first_pre = np.repeat(atom_index[:, :, np.newaxis], n_atoms, axis=2)[nonzero_pair] + pair_second_pre = np.repeat(atom_index[:, np.newaxis, :], n_atoms, axis=1)[nonzero_pair] pair_first = inv_real_atoms[pair_first_pre] pair_second = inv_real_atoms[pair_second_pre] + + if randomize_order: + random_order = np.random.permutation(np.arange(len(pair_first))) + pair_first = pair_first[random_order] + pair_second = pair_second[random_order] + n_pairs = len(pair_first) - # NOTE: These fake sensitivities are NONSYMMETRIC. - # Current HIP-NN does not do that, but a future one could. + # reduplicate for testing if an implementation works with duplicate pairs. + if use_duplicate: + pair_first[0] = pair_first[-1] + pair_second[0] = pair_second[-1] + + # NOTE: These synthetic sensitivities are NONSYMMETRIC, that is, j->i does not mean i->j, + # and also does not mean that the value of the sensitivity is the same + # even when i<->j. + on_sensitivites = np.random.choice([True, False], p=[3 / n_nu, 1 - 3 / n_nu], size=(n_pairs, n_nu)) pair_sensitivites = np.random.random(size=(n_pairs, n_nu)) * on_sensitivites assert not (pair_first == pair_second).any() @@ -123,39 +147,28 @@ def get_simulated_data(n_molecules, n_atoms, atom_prob, n_features, n_nu, printi TEST_MEGA_PARAMS = dict(n_molecules=500, n_atoms=30, atom_prob=0.7, n_features=128, n_nu=100) TEST_ULTRA_PARAMS = dict(n_molecules=500, n_atoms=30, atom_prob=0.7, n_features=128, n_nu=320) TEST_GIGA_PARAMS = dict(n_molecules=32, n_atoms=30, atom_prob=0.7, n_features=512, n_nu=320) +TEST_PARAMS = dict( + tiny=TEST_TINY_PARAMS, + small=TEST_SMALL_PARAMS, + medium=TEST_MEDIUM_PARAMS, + large=TEST_LARGE_PARAMS, + mega=TEST_MEGA_PARAMS, + ultra=TEST_ULTRA_PARAMS, + giga=TEST_GIGA_PARAMS, +) + + +# A class for testing the correctness and speed of the kernels function implementations. + + +class EnvOpsTester: + def __init__(self, name: str, suspicious_deviation: int = 0.5): + implementation = MessagePassingKernels.get_implementation(name) + self.envsum = implementation.envsum + self.sensesum = implementation.sensesum + self.featsum = implementation.featsum + self.name = name -# reference implementation - - -def test_pytorch_reference(): - sense, feat, pfirst, psecond = get_simulated_data(**TEST_TINY_PARAMS, dtype=torch.float64) - sense.requires_grad_(True) - feat.requires_grad_(True) - env = env_pytorch.envsum(sense, feat, pfirst, psecond) - pt_gradsense, pt_gradfeat = torch.autograd.grad(env.sum(), [sense, feat]) - sense.requires_grad_(False) - feat.requires_grad_(False) - ref_gradsense = env_pytorch.sensesum(torch.ones_like(env), feat, pfirst, psecond) - ref_gradfeat = env_pytorch.featsum(torch.ones_like(env), sense, pfirst, psecond) - - assert torch.allclose(pt_gradsense, ref_gradsense) - assert torch.allclose(pt_gradfeat, ref_gradfeat) - assert (pt_gradsense == ref_gradsense).all() - assert (pt_gradfeat == ref_gradfeat).all() - - -test_pytorch_reference() -del test_pytorch_reference - - -# A class for testing the correctness of the kernels functions. - - -class Envops_tester: - def __init__(self, envsum_raw, sensesum_raw, featsum_raw, suspicious_deviation=0.5): - self.envsum, self.sensesum, self.featsum, *_ = autograd_wrapper.wrap_envops( - envsum_impl=envsum_raw, sensesum_impl=sensesum_raw, featsum_impl=featsum_raw - ) self.tol_f64 = dict(atol=1e-8, rtol=1e-5) self.tol_f32 = dict(atol=1e-5, rtol=1e-5) # Absolute tolerance a bit fuzzier for float32. self.suspicious_deviation = suspicious_deviation @@ -189,11 +202,18 @@ def check_all_grad(self, repeats=3, device=torch.device("cpu")): for i in range(repeats): self.check_all_grad_once(device=device) - def check_allclose_once(self, use_large=False, device=torch.device("cpu")): + def check_allclose_once(self, use_large=False, device=torch.device("cpu"), randomize_order=False): if use_large: - sense, feat, pfirst, psecond = get_simulated_data(**TEST_LARGE_PARAMS, dtype=torch.float32, device=device) + params = TEST_LARGE_PARAMS else: - sense, feat, pfirst, psecond = get_simulated_data(**TEST_TINY_PARAMS, dtype=torch.float64, device=device) + params = TEST_TINY_PARAMS + + sense, feat, pfirst, psecond = get_simulated_data( + **params, + dtype=torch.float32, + device=device, + randomize_order=randomize_order, + ) n_atoms, n_features = feat.shape n_pairs, n_nu = sense.shape @@ -223,8 +243,17 @@ def check_empty(self, device=torch.device("cpu")): sense_g = self.sensesum(env, feat, pfirst, psecond) feat_g = self.featsum(env, sense, pfirst, psecond) except Exception as ee: - raise ValueError("Failed an operation on data with zero pairs") from ee - print("Passed zero-pair check") + raise ValueError("Failed an operation on data with zero pairs!") from ee + print("Passed zero-pair check!") + + def check_forward_noerr(self, device=torch.device("cpu")): + + sense, feat, pfirst, psecond = get_simulated_data(**TEST_TINY_PARAMS, dtype=torch.float64, device=device) + + env = self.envsum(sense, feat, pfirst, psecond) + sense_g = self.sensesum(env, feat, pfirst, psecond) + feat_g = self.featsum(env, sense, pfirst, psecond) + print("Passed forward execution check!") def all_close_witherror(self, r1, r2): r1 = r1.data.cpu().numpy() @@ -269,43 +298,43 @@ def check_allclose(self, repeats=30, use_large=False, device=torch.device("cpu") raise RuntimeError("Failed during iteration {}".format(i)) from ee def check_correctness(self, n_grad=1, n_small=100, n_large=3, device=torch.device("cpu")): - self.check_empty(device=device) + + print("Checking that functions execute...") + safe_synchronize() + self.check_forward_noerr() + + safe_synchronize() + self.check_empty(device=device) # this is now covered in autograd wrapper... + + safe_synchronize() + self.check_allclose_once(device=device, randomize_order=True) + print("Passed random pair order test...") + + safe_synchronize() print("Checking gradients {} times...".format(n_grad)) self.check_all_grad(repeats=n_grad, device=device) print("Passed gradient checks!") + safe_synchronize() print("Checking forward methods on small data {} times...".format(n_small), flush=True) self.check_allclose(repeats=n_small, use_large=False, device=device) print("Passed small tensor forward checks!") + safe_synchronize() print("Checking forward methods on large data {} times...".format(n_large), flush=True) self.check_allclose(repeats=n_large, use_large=True, device=device) print("Passed large tensor forward checks!") - def check_speed(self, n_repetitions=10, device=torch.device("cpu"), data_size=TEST_LARGE_PARAMS, compare_against="pytorch"): - if compare_against.lower() == "pytorch": - comp_envsum = env_pytorch.envsum - comp_sensesum = env_pytorch.sensesum - comp_featsum = env_pytorch.featsum - elif compare_against.lower() == "numba": - comp_envsum = env_numba.new_envsum - comp_sensesum = env_numba.new_sensesum - comp_featsum = env_numba.new_featsum - elif compare_against.lower() == "cupy": - comp_envsum = env_cupy.cupy_envsum - comp_sensesum = env_cupy.cupy_sensesum - comp_featsum = env_cupy.cupy_featsum - elif compare_against.lower() == "triton": - comp_envsum = env_triton.envsum - comp_sensesum = env_triton.featsum - comp_featsum = env_triton.featsum + def check_speed(self, n_repetitions=10, device=torch.device("cpu"), data_size=TEST_LARGE_PARAMS, compare_against="pytorch"): - else: - raise ValueError("Unknown implementation to comapre against:'{}'".format(compare_against)) + comparison_impl = MessagePassingKernels.get_implementation(compare_against) + comp_envsum = comparison_impl.envsum + comp_sensesum = comparison_impl.sensesum + comp_featsum = comparison_impl.featsum - te, ts, tf = (TimerHolder(name) for name in ("Envsum", "Sensesum", "Featsum")) - tne, tns, tnf = (TimerHolder("{}_{}".format(compare_against, name)) for name in ("Envsum", "Sensesum", "Featsum")) + te, ts, tf = (TimerHolder(f"{self.name}_{name}", device=device) for name in ("Envsum", "Sensesum", "Featsum")) + tne, tns, tnf = (TimerHolder(f"{compare_against}_{name}", device=device) for name in ("Envsum", "Sensesum", "Featsum")) - print("Repetitions: {}".format(n_repetitions)) + print(f"Repetitions: {n_repetitions}") with torch.autograd.no_grad(): # Warming up by running on data of this specific size sense, feat, pfirst, psecond = get_simulated_data(**data_size, dtype=torch.float32, device=device) @@ -323,7 +352,6 @@ def check_speed(self, n_repetitions=10, device=torch.device("cpu"), data_size=TE for i in range(n_repetitions): print(".", end="", flush=True) sense, feat, pfirst, psecond = get_simulated_data(**data_size, dtype=torch.float32, device=device) - torch.cuda.synchronize() with tne.add(): env = comp_envsum(sense, feat, pfirst, psecond) with tns.add(): @@ -337,7 +365,7 @@ def check_speed(self, n_repetitions=10, device=torch.device("cpu"), data_size=TE self.sensesum(env, feat, pfirst, psecond) with tf.add(): self.featsum(env, sense, pfirst, psecond) - print() # Newline to terminate the ... printing + print() # Newline to terminate the '...' printing for t in [tne, tns, tnf] + [te, ts, tf]: print("Mean {} time: {} Median: {}".format(t.name, t.mean_elapsed, t.median_elapsed)) for tn, t in zip([tne, tns, tnf], [te, ts, tf]): @@ -348,16 +376,27 @@ def check_speed(self, n_repetitions=10, device=torch.device("cpu"), data_size=TE print("Overall {} time: {}".format(compare_against, tnsum)) print("Overall time now: {}".format(tsum)) print("Overall speedup: {}".format(tnsum / tsum)) - return # prof + new_results = [t.to_dict() for t in [tne, tns, tnf]] + compare_results = [t.to_dict() for t in [te, ts, tf]] + return new_results, compare_results -import time +def safe_synchronize(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + return class TimerHolder: - def __init__(self, name=None): + def __init__(self, name, device): self.snippets = [] self.name = name + if device.type == "cuda": + self.device = torch.cuda.get_device_name(device) + elif device.type == "cpu": + self.device = platform.processor() + else: + self.device = device.type def add(self): t = TimedSnippet() @@ -376,19 +415,27 @@ def mean_elapsed(self): def median_elapsed(self): return np.median([t.elapsed for t in self.snippets]) + def to_dict(self): + return { + "name": self.name, + "count": len(self.snippets), + "mean": self.mean_elapsed, + "median": self.median_elapsed, + "individual": [t.elapsed for t in self.snippets], + "device": self.device, + } + class TimedSnippet: def __init__(self): self.start = self.end = None def __enter__(self): - if torch.cuda.is_available(): - torch.cuda.synchronize() + safe_synchronize() self.start = time.time() def __exit__(self, exc_type, exc_value, exc_tb): - if torch.cuda.is_available(): - torch.cuda.synchronize() + safe_synchronize() self.end = time.time() @property @@ -396,7 +443,18 @@ def elapsed(self): return self.end - self.start -def main(env_impl, sense_impl, feat_impl, args=None): +def speed_test_loop(tester: EnvOpsTester, speed_test_spec: dict, device: torch.device, compare_against: str): + all_results = {} + for sys_type, repetitions in speed_test_spec.items(): + params = TEST_PARAMS[sys_type] + print("-" * 80) + print(f"{sys_type} systems:", params) + results = tester.check_speed(n_repetitions=repetitions, data_size=params, device=device, compare_against=compare_against) + all_results[sys_type] = results + return results + + +def main(args=None): if args is None: # calling without arguments looks for them from command line @@ -408,121 +466,155 @@ def main(env_impl, sense_impl, feat_impl, args=None): args = SimpleNamespace(**args) np.random.seed(args.seed) - tester = Envops_tester( - env_impl, - sense_impl, - feat_impl, - ) + tester = EnvOpsTester(name=args.implementation) + + test_gpu = not args.no_gpu + device = torch.device(args.accelerator) + if test_gpu: + device_type = device.type + if device_type == "cuda": + available = torch.cuda.is_available() + elif device_type == "mps": + available = torch.backends.mps.is_available() + elif device_type == "cpu": + raise ValueError("Accelerator cannot be set to 'cpu'.") + else: + available = True # hopefully... + + if not available: + print(f"Backend {device_type} not available, skipping GPU tests.") + test_gpu = False + + test_cpu = not args.no_cpu + speed = not args.no_speed compare_against = args.compare_against - test_gpu = not args.no_test_gpu - test_cpu = not args.no_test_cpu + if speed: + # just so that we error early if a bad name is specified. + compare_impl = MessagePassingKernels.get_implementation(args.compare_against) correctness = not args.no_correctness - if torch.cuda.is_available() and not args.no_test_gpu: - print("Running GPU tests") - free_mem, total_mem = torch.cuda.memory.mem_get_info() + # Standard test numbers. + speed_tests = dict(mega=3, large=5, medium=100, small=100) + # if compare_against == 'pytorch': + # del speed_tests['mega'] # This takes about 15s/iteration on CPU + + if test_gpu: + print("Running GPU tests.") + gpu_speed_tests = speed_tests.copy() + if torch.cuda.is_available(): + free_mem, total_mem = torch.cuda.memory.mem_get_info() + del total_mem + else: + free_mem = 0 # Don't assume there is a lot of memory. + # More than 2GB memory, we can run large systems use_large_gpu = free_mem > 2**31 + if not use_large_gpu: + print("Torch indicates less than 2GB free GPU memory -- skipping large system test") + del gpu_speed_tests["large"] + else: + gpu_speed_tests["large"] = 20 + + # More than 30GB memory, we can run mega and ultra tests use_verylarge_gpu = free_mem > 30 * (2**30) + if use_verylarge_gpu: + gpu_speed_tests.pop("mega") + gpu_speed_tests = dict(mega=20, **gpu_speed_tests) + else: + gpu_speed_tests.pop("mega", None) + print("Numba indicates less than 30GB free GPU memory -- skipping mega system test") + # Note: base pytorch implementation will error on ultra configurations. use_ultra = (not correctness) and use_verylarge_gpu and (compare_against.lower() != "pytorch") - - n_large_gpu = args.n_large if use_large_gpu else 0 + if use_ultra: + gpu_speed_tests = dict(giga=20, ultra=20, **gpu_speed_tests) if correctness: - tester.check_correctness(device=torch.device("cuda"), n_large=n_large_gpu) - - if use_verylarge_gpu: - if use_ultra: - - print("-" * 80) - print("Giga systems:", TEST_GIGA_PARAMS) - tester.check_speed( - n_repetitions=20, data_size=TEST_GIGA_PARAMS, device=torch.device("cuda"), compare_against=compare_against - ) - print("-" * 80) - print("Ultra systems:", TEST_ULTRA_PARAMS) - tester.check_speed( - n_repetitions=20, data_size=TEST_ULTRA_PARAMS, device=torch.device("cuda"), compare_against=compare_against - ) - print("-" * 80) - print("Mega systems:", TEST_MEGA_PARAMS) - tester.check_speed(n_repetitions=20, data_size=TEST_MEGA_PARAMS, device=torch.device("cuda"), compare_against=compare_against) - else: - print("Numba indicates less than 30GB free GPU memory -- skipping mega system test") - if use_large_gpu: - print("-" * 80) - print("Large systems:", TEST_LARGE_PARAMS) - tester.check_speed(n_repetitions=20, data_size=TEST_LARGE_PARAMS, device=torch.device("cuda"), compare_against=compare_against) + n_large_gpu = args.n_large if use_large_gpu else 0 + tester.check_correctness(device=device, n_large=n_large_gpu) else: - print("Numba indicates less than 2GB free GPU memory -- skipping large system test") - - print("-" * 80) - print("Medium systems:", TEST_MEDIUM_PARAMS) - tester.check_speed(n_repetitions=100, data_size=TEST_MEDIUM_PARAMS, device=torch.device("cuda"), compare_against=compare_against) - print("-" * 80) - print("Small systems:", TEST_SMALL_PARAMS) - tester.check_speed(n_repetitions=100, data_size=TEST_SMALL_PARAMS, device=torch.device("cuda"), compare_against=compare_against) + print("Skipped correctness checks.") - else: - if not args.no_test_gpu: - print("Cuda not available, not running GPU tests.") + if speed: + speed_test_loop(tester, gpu_speed_tests, device=device, compare_against=compare_against) else: - print("Skipped GPU tests.") + print("Skipped speed tests.") + else: + print("Skipped GPU tests.") if test_cpu: - print("Running CPU tests") + print("Running CPU tests.") if correctness: tester.check_correctness(n_large=args.n_large) + else: + print("Skipped correctness checks.") - print("-" * 80) - print("Large systems:", TEST_LARGE_PARAMS) - tester.check_speed(n_repetitions=10, compare_against=compare_against) - print("-" * 80) - print("Medium systems:", TEST_MEDIUM_PARAMS) - tester.check_speed(n_repetitions=100, data_size=TEST_MEDIUM_PARAMS, compare_against=compare_against) - print("-" * 80) - print("Small systems:", TEST_SMALL_PARAMS) - tester.check_speed(n_repetitions=100, compare_against=compare_against, data_size=TEST_SMALL_PARAMS) + if speed: + speed_test_loop(tester, speed_tests, device=torch.device("cpu"), compare_against=compare_against) + else: + print("Skipped speed tests.") else: print("Skipped CPU tests.") +# reference implementation +def test_pytorch_reference(): + sense, feat, pfirst, psecond = get_simulated_data(**TEST_TINY_PARAMS, dtype=torch.float64) + sense.requires_grad_(True) + feat.requires_grad_(True) + env = env_pytorch.envsum(sense, feat, pfirst, psecond) + pt_gradsense, pt_gradfeat = torch.autograd.grad(env.sum(), [sense, feat]) + sense.requires_grad_(False) + feat.requires_grad_(False) + ref_gradsense = env_pytorch.sensesum(torch.ones_like(env), feat, pfirst, psecond) + ref_gradfeat = env_pytorch.featsum(torch.ones_like(env), sense, pfirst, psecond) + + # Note: Numerical differences + assert torch.allclose(pt_gradsense, ref_gradsense, atol=1e-15, rtol=1e-15) + assert torch.allclose(pt_gradfeat, ref_gradfeat, atol=1e-15, rtol=1e-15) + + def parse_args(): import argparse parser = argparse.ArgumentParser() - parser.add_argument("--seed", type=int, default=0, help="name for run") + parser.add_argument("implementation", type=str, help="Implementation to test.") parser.add_argument( "--compare-against", type=str, default="pytorch", help=""" - implementation to compare speed with. Options are: pytorch, numba, cupy, triton""", + Implementation to compare speed with. Options are: pytorch, numba, cupy, triton + Correctness is always compared against pytorch. + """, ) + parser.add_argument("--accelerator", type=str, default="cuda", help="Device to treat as the GPU.") + parser.add_argument("--no-cpu", action="store_true", default=False, help="Flag to skip CPU tests.") + parser.add_argument("--no-gpu", action="store_true", default=False, help="Flag to skip GPU tests.") + parser.add_argument("--no-speed", action="store_true", default=False, help="Flag to skip speed tests.") + parser.add_argument("--no-correctness", action="store_true", default=False, help="Flag to skip correctness tests.") + + parser.add_argument("--seed", type=int, default=0, help="Seed") parser.add_argument( "--n_large", type=int, default=5, help=""" Number of times to check correctness of forward pass. Set this to a large number (e.g. 200) to - stress-test a new implementation against corner-cases.""", + stress-test a new implementation against corner-cases. + """, ) - parser.add_argument("--no-test-cpu", action="store_true", default=False, help="Set to false to skip CPU tests.") - parser.add_argument("--no-test-gpu", action="store_true", default=False, help="Set to false to skip GPU tests.") - parser.add_argument("--no-correctness", action="store_true", default=False, help="Set to false to skip GPU tests.") args = parser.parse_args() return args if __name__ == "__main__": - main( - env_numba.new_envsum, - env_numba.new_sensesum, - env_numba.new_featsum, - ) + # Ensure nothing is going wrong with reference sensesum and featsum + test_pytorch_reference() + + main() diff --git a/hippynn/custom_kernels/test_env_cupy.py b/hippynn/custom_kernels/test_env_cupy.py deleted file mode 100644 index 638dc642..00000000 --- a/hippynn/custom_kernels/test_env_cupy.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch -import numba -from . import env_cupy -from .test_env_numba import Envops_tester, main -from .test_env_numba import TEST_MEGA_PARAMS, TEST_LARGE_PARAMS, TEST_MEDIUM_PARAMS, TEST_SMALL_PARAMS - -if __name__ == "__main__": - main( - env_cupy.cupy_envsum, - env_cupy.cupy_sensesum, - env_cupy.cupy_featsum, - ) diff --git a/hippynn/custom_kernels/test_env_triton.py b/hippynn/custom_kernels/test_env_triton.py deleted file mode 100644 index 95788068..00000000 --- a/hippynn/custom_kernels/test_env_triton.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch -from .env_triton import envsum, sensesum, featsum -from .test_env_numba import Envops_tester, main, get_simulated_data -from .test_env_numba import TEST_MEGA_PARAMS, TEST_LARGE_PARAMS, TEST_MEDIUM_PARAMS, TEST_SMALL_PARAMS -from .utils import resort_pairs_cached - -if __name__ == "__main__": - - main( - envsum, - sensesum, - featsum, - ) diff --git a/hippynn/custom_kernels/test_speed_env.py b/hippynn/custom_kernels/test_speed_env.py new file mode 100644 index 00000000..b3a3b674 --- /dev/null +++ b/hippynn/custom_kernels/test_speed_env.py @@ -0,0 +1,97 @@ +import json +import pathlib + +import numpy as np +import torch +from .registry import MessagePassingKernels + +from .test_env import TEST_PARAMS, EnvOpsTester + + +def parse_args(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("implementations", type=str, nargs="*", help="Implementation(s) to test.") + parser.add_argument("--seed", type=int, default=0, help="Seed") + + parser.add_argument("--all-hidden", action="store_true", default=False, help="Use all implementations, even with _ beginning.") + parser.add_argument("--all-impl", action="store_true", default=False, help="Use all non-hidden implementations.") + parser.add_argument("--all-gpu", action="store_true", default=False, help="Use low-mem implementations suitable for GPU.") + parser.add_argument("--all-gpu", action="store_true", default=False, help="CPU-capable implementaitons.") + + for param_type in TEST_PARAMS.keys(): + parser.add_argument(f"--{param_type}", type=int, default=0, help=f"Count for param type {param_type}") + + parser.add_argument(f"--test-all-count", type=int, default=0, help=f"Apply m inimumcount for all param types.") + + parser.add_argument("--accelerator", type=str, default="cuda", help="Device to use.") + parser.add_argument("--file", type=str, default="speed_tests.json", help="Where to store results.") + parser.add_argument("--overwrite", default=False, action="store_true", help="Whether to overwrite.") + + args = parser.parse_args() + return args + + +def main(args=None): + if args is None: + args = parse_args() + + default = args.test_all_count + if default > 0: + for k in TEST_PARAMS: + setattr(args, k, default) + + test_spec = {k: count for k in TEST_PARAMS if (count := getattr(args, k, 0)) > 0} + print(TEST_PARAMS.keys()) + print(test_spec) + results = {} + + implementations = args.implementations + + path = pathlib.Path(args.file) + + if args.all_gpu: + implementations = ["sparse", "numba", "cupy", "triton"] + if args.all_cpu: + implementations = ["sparse", "pytorch", "numba"] + + if args.all_impl: + implementations = MessagePassingKernels.get_available_implementations() + if args.all_hidden: + implementations = MessagePassingKernels.get_available_implementations(hidden=True) + + # Error if implementation does not exist. + for impl in implementations: + MessagePassingKernels.get_implementation(impl) + + if path.suffix != ".json": + raise AssertionError(f"File extension not allowed! Suffix: '{path.suffix}'") + + if not args.overwrite: + if path.exists(): + raise FileExistsError(f"Will not overwrite existing file! {path}") + + if len(implementations) == 0: + raise ValueError("nothing to test") + print("Testing implementations:", implementations) + for impl in implementations: + print("Testing implementation:", impl) + tester = EnvOpsTester(impl) + results[impl] = impl_results = {} + for k, count in test_spec.items(): + print(f"Testing {k} {count} times:") + np.random.seed(args.seed) + out0, out1 = tester.check_speed( + n_repetitions=count, device=torch.device(args.accelerator), data_size=TEST_PARAMS[k], compare_against=impl + ) + impl_results[k] = dict(tested=out0, comparison=out1) + + with open(path, "wt") as f: + json.dump(results, f) + + return results + + +if __name__ == "__main__": + main() diff --git a/hippynn/custom_kernels/utils.py b/hippynn/custom_kernels/utils.py index 7623977f..0a28bc63 100644 --- a/hippynn/custom_kernels/utils.py +++ b/hippynn/custom_kernels/utils.py @@ -1,3 +1,8 @@ +""" + +Utilities for the custom kernels, including pre-sorting the indices. + +""" import torch import threading from functools import partial diff --git a/hippynn/experiment/assembly.py b/hippynn/experiment/assembly.py index 8338ff70..845daac3 100644 --- a/hippynn/experiment/assembly.py +++ b/hippynn/experiment/assembly.py @@ -180,11 +180,15 @@ def precompute_pairs(model, database, batch_size=10, device=None, make_dense=Fal :param device: where to do the precomputation. :param make_dense: return a dense array of pairs. Warning, this can be memory-expensive. However, it is necessary if you are going to use num_workers>0 in your dataloaders. If False, the cache is stored as a sparse array. - :param n_images: number of images for cache storage, increase this if it fails. - However, large values can incur a large memory cost if make_dense is True. + :param n_images: number of images for cache storage; increase this if it fails. + However, large values can incur a very large memory cost if make_dense is True. :return: None-- changes the model graph. + .. note :: + After running pre-compute pairs, your model will expect to load pairs directly from the database, + and your database will contain cached pair entries. + Note that the returned model needs to be re-assembled with the new graph for the cache to take effect. Example usage: >>> precompute_pairs(training_modules.model,database,device='cuda') diff --git a/hippynn/experiment/lightning_trainer.py b/hippynn/experiment/lightning_trainer.py index 3b6e1d52..e141e1f9 100644 --- a/hippynn/experiment/lightning_trainer.py +++ b/hippynn/experiment/lightning_trainer.py @@ -14,6 +14,7 @@ import warnings import copy from pathlib import Path +from typing import Optional import torch @@ -86,7 +87,7 @@ def __init__( raise NotImplementedError("Generic args and kwargs not supported.") @classmethod - def from_experiment_setup(cls, training_modules: TrainingModules, database: Database, setup_params: SetupParams, **kwargs): + def from_experiment_setup(cls, training_modules: TrainingModules, database: Optional[Database], setup_params: SetupParams, **kwargs): """ Create a lightning module using the same arguments as for :func:`hippynn.experiment.setup_and_train`. @@ -94,7 +95,7 @@ def from_experiment_setup(cls, training_modules: TrainingModules, database: Data :param database: :param setup_params: :param kwargs: - :return: + :return: lightning_module, database """ training_modules, controller, metric_tracker = setup_training(training_modules, setup_params) return cls.from_train_setup(training_modules, database, controller, metric_tracker, **kwargs) @@ -103,7 +104,7 @@ def from_experiment_setup(cls, training_modules: TrainingModules, database: Data def from_train_setup( cls, training_modules: TrainingModules, - database: Database, + database: Optional[Database], controller: Controller, metric_tracker: MetricTracker, callbacks=None, @@ -120,7 +121,7 @@ def from_train_setup( :param callbacks: :param batch_callbacks: :param kwargs: - :return: + :return: lightning_module, database """ @@ -153,7 +154,10 @@ def from_train_setup( if callbacks is not None or batch_callbacks is not None: return NotImplemented("arbitrary callbacks are not yet supported with pytorch lightning.") - return trainer, HippynnDataModule(database, controller.batch_size) + if database is not None: + database = HippynnDataModule(database, controller.batch_size) + + return trainer, database def on_save_checkpoint(self, checkpoint) -> None: """ diff --git a/hippynn/layers/hiplayers.py b/hippynn/layers/hiplayers.py index b93aae60..f02cc1bc 100644 --- a/hippynn/layers/hiplayers.py +++ b/hippynn/layers/hiplayers.py @@ -195,6 +195,9 @@ def forward(self, in_features, pair_first, pair_second, dist_pairs): # Z: input features # E: environment features (S*Z) + # Q = (VZ) # torch.mm + # E = (QS) # custom_kernels.featsum + # E = (SZ) env_features = custom_kernels.envsum(sense_vals, in_features, pair_first, pair_second) diff --git a/hippynn/plotting/plotmaker.py b/hippynn/plotting/plotmaker.py index c9d5d796..636badf8 100644 --- a/hippynn/plotting/plotmaker.py +++ b/hippynn/plotting/plotmaker.py @@ -14,6 +14,13 @@ class PlotMaker: """ def __init__(self, *plotters, plot_every, save_dir="plots/"): + """ + + :param plotters: Individual plotters to use. + :param plot_every: How often to make plots during training (in epochs) + :param save_dir: What directory to store the plots in, relative to the + training experiment path. + """ self.plotters = plotters self.save_dir = save_dir self.plot_every = plot_every diff --git a/hippynn/plotting/plotters.py b/hippynn/plotting/plotters.py index 998c13a5..38776a55 100644 --- a/hippynn/plotting/plotters.py +++ b/hippynn/plotting/plotters.py @@ -16,6 +16,14 @@ class Plotter: """ def __init__(self, parents, plt_fn=None, saved=False, shown=False): + """ + Base plotter arguments inherited by all plotters. + + :param parents: nodes reflecting the data required to make the plotter + :param plt_fn: a function to use to make the plot + :param saved: whether to save the plot to a file + :param shown: whether to show the plot using ``plt.show`` + """ self.parents = parents self.shown = shown self.saved = saved diff --git a/pyproject.toml b/pyproject.toml index 4b500306..95d61955 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ classifiers=[ readme="README.rst" dependencies=[ "numpy", - "torch", + "torch>2.0", ] [project.optional-dependencies] @@ -39,4 +39,5 @@ full=[ "h5py", "lightning", "scipy", + "opt_einsum", ] \ No newline at end of file