Skip to content

Commit

Permalink
Move HalfStepHook inheritance to python (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz authored Sep 19, 2023
2 parents 9deea21 + d06ab59 commit e90de42
Show file tree
Hide file tree
Showing 11 changed files with 217 additions and 253 deletions.
13 changes: 10 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
FROM ssages/pysages-base:latest
WORKDIR /hoomd-dlext/.docker_build
WORKDIR /hoomd-dlext

COPY . ../
RUN cmake .. && make install
# Install python dependencies
# hadolint ignore=DL3013
RUN python3 -m pip install --no-cache-dir --upgrade pip "setuptools-scm==7.1.0"

# Build the plugin
COPY . .
RUN cmake -S . -B build && cmake --build build --target install && rm -rf build

# Test it can be loaded
RUN python3 -c "import hoomd; import hoomd.dlext"
22 changes: 22 additions & 0 deletions cmake/Modules/FindHOOMDTools.cmake
Original file line number Diff line number Diff line change
@@ -1,3 +1,25 @@
# This requires setuptools-scm to be installed as we will use it to setup the module version
function(set_version target)
find_package(Python QUIET COMPONENTS Interpreter)
set(GET_VERSION_SCRIPT "
from setuptools_scm import get_version
print(get_version(), end='')"
)
execute_process(
COMMAND ${Python_EXECUTABLE} -c ${GET_VERSION_SCRIPT}
ERROR_VARIABLE error
OUTPUT_VARIABLE GIT_VERSION
RESULT_VARIABLE exit_code
)
if(NOT exit_code EQUAL 0)
message(FATAL_ERROR
"The build process depends on setuptools-scm, make sure it is installed. "
"Got the following error:\n${error}"
)
endif()
target_compile_definitions(${target} PUBLIC GIT_VERSION=${GIT_VERSION})
endfunction()

# Try finding HOOMD first from the current environment
set(HOOMD_GPU_PLATFORM "CUDA" CACHE STRING "GPU backend: CUDA or HIP.")
find_package(HOOMD QUIET)
Expand Down
32 changes: 7 additions & 25 deletions dlext/include/Sampler.h → dlext/include/CallbackHandler.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
// SPDX-License-Identifier: MIT
// This file is part of `hoomd-dlext`, see LICENSE.md

#ifndef DLEXT_SAMPLER_H_
#define DLEXT_SAMPLER_H_
#ifndef DLEXT_CALLBACKHANDLER_H_
#define DLEXT_CALLBACKHANDLER_H_

#include "SystemView.h"
#include "hoomd/HalfStepHook.h"
#include "DLExt.h"

namespace hoomd
{
Expand All @@ -20,28 +19,14 @@ using TimeStep = unsigned int;
using TimeStep = uint64_t;
#endif

template <typename ExternalUpdater, template <typename> class Wrapper>
class DEFAULT_VISIBILITY Sampler : public HalfStepHook {
template <template <typename> class Wrapper>
class DEFAULT_VISIBILITY CallbackHandler {
public:
//! Constructor
Sampler(
SystemView& sysview,
ExternalUpdater update_callback,
AccessLocation location,
AccessMode mode
)
CallbackHandler(SystemView& sysview)
: _sysview { sysview }
, _update_callback { update_callback }
, _location { location }
, _mode { mode }
{ }

void setSystemDefinition(SPtr<SystemDefinition> sysdef) override { }
void update(TimeStep timestep) override
{
forward_data(_update_callback, _location, _mode, timestep);
}

const SystemView& system_view() const { return _sysview; }

//! Wraps the system positions, velocities, reverse tags, images and forces as
Expand All @@ -67,13 +52,10 @@ class DEFAULT_VISIBILITY Sampler : public HalfStepHook {

private:
SystemView _sysview;
ExternalUpdater _update_callback;
AccessLocation _location;
AccessMode _mode;
};

} // namespace dlext
} // namespace md
} // namespace hoomd

#endif // DLEXT_SAMPLER_H_
#endif // DLEXT_CALLBACKHANDLER_H_
99 changes: 88 additions & 11 deletions dlext/include/DLExt.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
#ifndef HOOMD_DLPACK_EXTENSION_H_
#define HOOMD_DLPACK_EXTENSION_H_

#include "cxx11utils.h"
#include "SystemView.h"
#include "dlpack/dlpack.h"
#include "hoomd/GlobalArray.h"

#include <type_traits>
#include <vector>
Expand All @@ -18,17 +17,21 @@ namespace md
namespace dlext
{

using namespace cxx11utils;
using namespace hoomd;
namespace cxx11 = cxx11utils;

// { // Aliases

using DLManagedTensorPtr = DLManagedTensor*;
using DLManagedTensorDeleter = void (*)(DLManagedTensorPtr);
using DLManagedTensorDeleter = void (*)(DLManagedTensor*);

template <typename T>
using ArrayHandleUPtr = std::unique_ptr<ArrayHandle<T>>;

template <template <typename> class Array, typename T, typename Object>
using ArrayPropertyGetter = const Array<T>& (Object::*)() const;

template <typename T>
using PropertyGetter = T (*)(const SystemView&, AccessLocation, AccessMode);

// } // Aliases

// { // Constants
Expand Down Expand Up @@ -64,23 +67,30 @@ struct DLDataBridge {
};

template <typename T>
using DLDataBridgeUPtr = std::unique_ptr<DLDataBridge<T>>;

template <typename T>
void delete_bridge(DLManagedTensorPtr tensor)
void delete_bridge(DLManagedTensor* tensor)
{
if (tensor)
delete static_cast<DLDataBridge<T>*>(tensor->manager_ctx);
}

void do_not_delete(DLManagedTensorPtr tensor) { }
void do_not_delete(DLManagedTensor* tensor) { }

template <typename T>
inline void* opaque(T* data) { return static_cast<void*>(data); }

template <typename T>
inline void* opaque(const T* data) { return (void*)(data); }

inline DLDevice device_info(const SystemView& sysview, AccessLocation location)
{
#ifdef ENABLE_CUDA
auto gpu_flag = (location == kOnDevice);
#else
auto gpu_flag = false;
#endif
return DLDevice { gpu_flag ? kDLCUDA : kDLCPU, sysview.get_device_id(gpu_flag) };
}

template <typename>
constexpr DLDataType dtype();
template <>
Expand All @@ -107,6 +117,73 @@ constexpr int64_t stride1<int3>() { return 3; }
template <>
constexpr int64_t stride1<unsigned int>() { return 1; }

template <template <typename> class A, typename T, typename O>
DLManagedTensor* wrap(
const SystemView& sysview, ArrayPropertyGetter<A, T, O> getter,
AccessLocation requested_location, AccessMode mode,
int64_t size2 = 1, uint64_t offset = 0, uint64_t stride1_offset = 0
)
{
assert((size2 >= 1));

auto location = sysview.is_gpu_enabled() ? requested_location : kOnHost;
auto handle = cxx11::make_unique<ArrayHandle<T>>(
INVOKE(*(sysview.particle_data()), getter)(), location, mode
);
auto bridge = cxx11::make_unique<DLDataBridge<T>>(handle);

bridge->tensor.manager_ctx = bridge.get();
bridge->tensor.deleter = delete_bridge<T>;

auto& dltensor = bridge->tensor.dl_tensor;
dltensor.data = opaque(bridge->handle->data);
dltensor.device = device_info(sysview, location);
dltensor.dtype = dtype<T>();

auto& shape = bridge->shape;
shape.push_back(particle_number<A>(sysview));
if (size2 > 1)
shape.push_back(size2);

auto& strides = bridge->strides;
strides.push_back(stride1<T>() + stride1_offset);
if (size2 > 1)
strides.push_back(1);

dltensor.ndim = shape.size();
dltensor.shape = reinterpret_cast<std::int64_t*>(shape.data());
dltensor.strides = reinterpret_cast<std::int64_t*>(strides.data());
dltensor.byte_offset = offset;

return &(bridge.release()->tensor);
}

#define DLEXT_PROPERTY_WRAPPER(PROPERTY, GETTER, SIZE1) \
struct PROPERTY final { \
static DLManagedTensor* from( \
const SystemView& sysview, AccessLocation location, AccessMode mode = kReadWrite \
) \
{ \
return wrap(sysview, GETTER, location, mode, SIZE1); \
} \
};

DLEXT_PROPERTY_WRAPPER(PositionsTypes, &ParticleData::getPositions, 4)
DLEXT_PROPERTY_WRAPPER(VelocitiesMasses, &ParticleData::getVelocities, 4)
DLEXT_PROPERTY_WRAPPER(Orientations, &ParticleData::getOrientationArray, 4)
DLEXT_PROPERTY_WRAPPER(AngularMomenta, &ParticleData::getAngularMomentumArray, 4)
DLEXT_PROPERTY_WRAPPER(MomentsOfInertia, &ParticleData::getMomentsOfInertiaArray, 3)
DLEXT_PROPERTY_WRAPPER(Charges, &ParticleData::getCharges, 1)
DLEXT_PROPERTY_WRAPPER(Diameters, &ParticleData::getDiameters, 1)
DLEXT_PROPERTY_WRAPPER(Images, &ParticleData::getImages, 3)
DLEXT_PROPERTY_WRAPPER(Tags, &ParticleData::getTags, 1)
DLEXT_PROPERTY_WRAPPER(RTags, &ParticleData::getRTags, 1)
DLEXT_PROPERTY_WRAPPER(NetForces, &ParticleData::getNetForce, 4)
DLEXT_PROPERTY_WRAPPER(NetTorques, &ParticleData::getNetTorqueArray, 4)
DLEXT_PROPERTY_WRAPPER(NetVirial, &ParticleData::getNetVirial, 6)

#undef DLEXT_PROPERTY_WRAPPER

} // namespace dlext
} // namespace md
} // namespace hoomd
Expand Down
Loading

0 comments on commit e90de42

Please sign in to comment.