Skip to content

Commit

Permalink
Merge pull request #288 from MeasureTransport/issue287
Browse files Browse the repository at this point in the history
Add LogDeterminantInputGrad function to bindings.
  • Loading branch information
dannys4 authored Dec 19, 2022
2 parents bf34ed7 + ed105a1 commit 53c2c1a
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 0 deletions.
7 changes: 7 additions & 0 deletions bindings/julia/src/ConditionalMapBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ void mpart::binding::ConditionalMapBaseWrapper(jlcxx::Module &mod) {
map.LogDeterminantCoeffGradImpl(JuliaToKokkos(pts), JuliaToKokkos(output));
return output;
})
.method("LogDeterminantInputGrad", [](ConditionalMapBase<Kokkos::HostSpace>& map, jlcxx::ArrayRef<double,2> pts){
unsigned int numPts = size(pts,1);
unsigned int numInputs = map.inputDim;
jlcxx::ArrayRef<double,2> output = jlMalloc<double>(numInputs, numPts);
map.LogDeterminantInputGradImpl(JuliaToKokkos(pts), JuliaToKokkos(output));
return output;
})
.method("Inverse", [](ConditionalMapBase<Kokkos::HostSpace> &map, jlcxx::ArrayRef<double,2> x1, jlcxx::ArrayRef<double,2> r) {
unsigned int numPts = size(r,1);
unsigned int outputDim = map.outputDim;
Expand Down
5 changes: 5 additions & 0 deletions bindings/matlab/mat/ConditionalMap.m
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ function SetCoeffs(this,coeffs)
MParT_('ConditionalMap_LogDeterminantCoeffGrad',this.id_,pts,result);
end

function result = LogDeterminantInputGrad(this,pts)
result = zeros(this.inputDim, size(pts,2));
MParT_('ConditionalMap_LogDeterminantInputGrad',this.id_,pts,result);
end

function result = get_id(this)
result = this.id_;
end
Expand Down
13 changes: 13 additions & 0 deletions bindings/matlab/src/ConditionalMap_mex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,19 @@ MEX_DEFINE(ConditionalMap_LogDeterminantCoeffGrad) (int nlhs, mxArray* plhs[],
condMap.map_ptr->LogDeterminantCoeffGradImpl(pts,out);
}

MEX_DEFINE(ConditionalMap_LogDeterminantInputGrad) (int nlhs, mxArray* plhs[],
int nrhs, const mxArray* prhs[]) {
InputArguments input(nrhs, prhs, 3);
OutputArguments output(nlhs, plhs, 0);

const ConditionalMapMex& condMap = Session<ConditionalMapMex>::getConst(input.get(0));

auto pts = MexToKokkos2d(prhs[1]);
auto out = MexToKokkos2d(prhs[2]);

condMap.map_ptr->LogDeterminantInputGradImpl(pts,out);
}

} // namespace

MEX_DISPATCH // Don't forget to add this if MEX_DEFINE() is used.
1 change: 1 addition & 0 deletions bindings/python/src/ConditionalMapBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ void mpart::binding::ConditionalMapBaseWrapper(py::module &m)
.def("LogDeterminant", static_cast<Eigen::VectorXd (ConditionalMapBase<MemorySpace>::*)(Eigen::Ref<const Eigen::RowMatrixXd> const&)>(&ConditionalMapBase<MemorySpace>::LogDeterminant))
.def("Inverse", static_cast<Eigen::RowMatrixXd (ConditionalMapBase<MemorySpace>::*)(Eigen::Ref<const Eigen::RowMatrixXd> const&, Eigen::Ref<const Eigen::RowMatrixXd> const&)>(&ConditionalMapBase<MemorySpace>::Inverse))
.def("LogDeterminantCoeffGrad", static_cast<Eigen::RowMatrixXd (ConditionalMapBase<MemorySpace>::*)(Eigen::Ref<const Eigen::RowMatrixXd> const&)>(&ConditionalMapBase<MemorySpace>::LogDeterminantCoeffGrad))
.def("LogDeterminantInputGrad", static_cast<Eigen::RowMatrixXd (ConditionalMapBase<MemorySpace>::*)(Eigen::Ref<const Eigen::RowMatrixXd> const&)>(&ConditionalMapBase<MemorySpace>::LogDeterminantInputGrad))
.def("GetBaseFunction", &ConditionalMapBase<MemorySpace>::GetBaseFunction)
;

Expand Down
3 changes: 3 additions & 0 deletions docs/source/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ This installation should also automatically install and build Kokkos, Eigen, Cer
Feel free to mix and match previous installations of Eigen, Cereal, Kokkos, Pybind11, and Catch2 with libraries you don't already have using these :code:`X_ROOT` flags. Note that Catch2 and Kokkos in this example will need to be compiled with shared libraries. MParT has not been tested with all versions of all dependencies, but it does require CMake version >=3.13. Further, it has been tested with Kokkos 3.7.0, Eigen 3.4.0, Pybind11 2.9.2, Cereal 1.3.2, and Catch2 3.1.0 (there have been some issues encountered when compiling MParT with Catch2 3.0.1).

.. tip::
If you are using Kokkos <3.7.0, you will need to use the :code:`Kokkos_ENABLE_PTHREAD` flag instead of :code:`Kokkos_ENABLE_THREADS` in the CMake configuration.

You can force MParT to use previously installed versions of the dependencies by setting :code:`MPART_FETCH_DEPS=OFF`. The default value of :code:`MPART_FETCH_DEPS=ON` will allow MParT to download and locally install any external dependencies using CMake's :code:`FetchContent` directive.

Note that if you do not wish to compile bindings for Python, Julia, or Matlab, you can turn off binding compilation by setting the :code:`MPART_<language>=OFF` variable during CMake configuration. For a default build with only the core c++ library, and without requiring the Cereal library, you can use
Expand Down

0 comments on commit 53c2c1a

Please sign in to comment.