Skip to content

Commit

Permalink
Merge pull request #330 from MeasureTransport/ExpandAllpy
Browse files Browse the repository at this point in the history
MultiIndexSet Expand() for python and deep copy in all languages
  • Loading branch information
dannys4 authored Apr 18, 2023
2 parents 5549746 + a1a9a95 commit f6c3ebd
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 4 deletions.
1 change: 1 addition & 0 deletions bindings/julia/src/MultiIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ void mpart::binding::MultiIndexWrapper(jlcxx::Module &mod) {
mod.method("length", [](FixedMultiIndexSet<Kokkos::HostSpace> &mset){return mset.Length();});
mod.method("size", [](FixedMultiIndexSet<Kokkos::HostSpace> &mset){return mset.Size();});
mod.method("vec", [](MultiIndex const& idx){ return idx.Vector(); });
mod.method("copy", [](MultiIndexSet const& mset){ MultiIndexSet mset_copy = mset; return mset_copy;});
mod.method("==", [](MultiIndex const& idx1, MultiIndex const& idx2){ return idx1 == idx2; });
mod.method("!=", [](MultiIndex const& idx1, MultiIndex const& idx2){ return idx1 != idx2; });
mod.method("<", [](MultiIndex const& idx1, MultiIndex const& idx2){ return idx1 < idx2; });
Expand Down
5 changes: 5 additions & 0 deletions bindings/matlab/mat/MultiIndexSet.m
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ function delete(this)
multi = MultiIndex(multi_id,"id");
end

function mset = DeepCopy(this)
mset_id = MParT_('MultiIndexSet_DeepCopy', this.id_);
mset = MultiIndeSet(mset_id,"id");
end

function result = MultiToIndex(this,multi)
result= MParT_('MultiIndexSet_MultiToIndex',this.id_,multi.get_id());
result = result + 1;
Expand Down
10 changes: 10 additions & 0 deletions bindings/matlab/src/MultiIndexSet_mex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,16 @@ MEX_DEFINE(MultiIndexSet_ExpandAny) (int nlhs, mxArray* plhs[],
output.set(0, mset->Expand());
}

MEX_DEFINE(MultiIndexSet_DeepCopy) (int nlhs, mxArray* plhs[],
int nrhs, const mxArray* prhs[]) {
InputArguments input(nrhs, prhs, 1);
OutputArguments output(nlhs, plhs, 1);
const MultiIndexSet& mset = Session<MultiIndexSet>::getConst(input.get(0));
MultiIndexSet mset_copy = MultiIndexSet::CreateTotalOrder(mset.Length(), 0);
mset_copy += mset;
output.set(0, Session<MultiIndexSet>::create(new MultiIndexSet(mset_copy)));
}

MEX_DEFINE(MultiIndexSet_ForciblyExpand) (int nlhs, mxArray* plhs[],
int nrhs, const mxArray* prhs[]) {
InputArguments input(nrhs, prhs, 2);
Expand Down
12 changes: 9 additions & 3 deletions bindings/python/src/MultiIndex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ void mpart::binding::MultiIndexWrapper(py::module &m)
.def(py::init<const unsigned int>())
.def(py::init<Eigen::Ref<const Eigen::MatrixXi> const&>())
.def("fix", &MultiIndexSet::Fix)
.def("__len__", &MultiIndexSet::Length)
.def("__len__", &MultiIndexSet::Length, "Retrieves the length of _each_ multiindex within this set (i.e. the dimension of the input)")
.def("__getitem__", &MultiIndexSet::at)
.def("at", &MultiIndexSet::at)
.def("Size", &MultiIndexSet::Size)
.def("Size", &MultiIndexSet::Size, "Retrieves the number of elements in this MultiIndexSet")

.def_static("CreateTotalOrder", &MultiIndexSet::CreateTotalOrder, py::arg("length"), py::arg("maxOrder"), py::arg("limiter")=MultiIndexLimiter::None())
.def_static("CreateTensorProduct", &MultiIndexSet::CreateTensorProduct, py::arg("length"), py::arg("maxOrder"), py::arg("limiter")=MultiIndexLimiter::None())
Expand All @@ -127,9 +127,15 @@ void mpart::binding::MultiIndexWrapper(py::module &m)
.def("IndexToMulti",&MultiIndexSet::IndexToMulti)
.def("MultiToIndex", &MultiIndexSet::MultiToIndex)
.def("MaxOrders", &MultiIndexSet::MaxOrders)
.def("Expand", py::overload_cast<unsigned int>(&MultiIndexSet::Expand))
.def("Expand", py::overload_cast<unsigned int>(&MultiIndexSet::Expand), "Expand frontier w.r.t one multiindex")
.def("Expand", py::overload_cast<>(&MultiIndexSet::Expand), "Expand all frontiers of a MultiIndexSet")
.def("append", py::overload_cast<MultiIndex const&>(&MultiIndexSet::operator+=))
.def("__iadd__", py::overload_cast<MultiIndex const&>(&MultiIndexSet::operator+=))
.def("DeepCopy",[](const MultiIndexSet& mset)
{
MultiIndexSet mset_copy = mset;
return mset_copy;
})
.def("Activate", py::overload_cast<MultiIndex const&>(&MultiIndexSet::Activate))
.def("AddActive", &MultiIndexSet::AddActive)
.def("Frontier", &MultiIndexSet::Frontier)
Expand Down
4 changes: 3 additions & 1 deletion src/TrainMap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ nlopt::opt SetupOptimization(unsigned int dim, TrainOptions options) {
template<>
double mpart::TrainMap(std::shared_ptr<ConditionalMapBase<Kokkos::HostSpace>> map, std::shared_ptr<MapObjective<Kokkos::HostSpace>> objective, TrainOptions options) {
if(map->Coeffs().extent(0) == 0) {
std::cout << "TrainMap: Initializing map coeffs to 1." << std::endl;
if(options.verbose) {
std::cout << "TrainMap: Initializing map coeffs to 1." << std::endl;
}
Kokkos::View<double*, Kokkos::HostSpace> coeffs ("Default coeffs", map->numCoeffs);
Kokkos::parallel_for("Setting default coeff val", map->numCoeffs, KOKKOS_LAMBDA(const unsigned int i){
coeffs(i) = 1.;
Expand Down

0 comments on commit f6c3ebd

Please sign in to comment.