From 35aa6ca944d864ccff2c6b893f19ae9babfc38cd Mon Sep 17 00:00:00 2001 From: rubiop Date: Sat, 16 Jul 2022 17:21:40 -0400 Subject: [PATCH 01/20] removed useless dependency --- bindings/matlab/src/ConditionalMap_mex.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/bindings/matlab/src/ConditionalMap_mex.cpp b/bindings/matlab/src/ConditionalMap_mex.cpp index e65a9bee..f464bddf 100644 --- a/bindings/matlab/src/ConditionalMap_mex.cpp +++ b/bindings/matlab/src/ConditionalMap_mex.cpp @@ -8,7 +8,6 @@ #include "MParT/ConditionalMapBase.h" #include "MParT/TriangularMap.h" #include -#include "mexplus_eigen.h" #include From 2ccf672e57ec15c917409476edc93927153f4c48 Mon Sep 17 00:00:00 2001 From: rubiop Date: Sat, 16 Jul 2022 22:03:15 -0400 Subject: [PATCH 02/20] completed MultiIndexSet methods w/o Mutlndex --- README.md | 2 - bindings/matlab/external/mexplus_eigen.h | 47 ++---------- bindings/matlab/mat/MultiIndexSet.m | 47 ++++++++++++ bindings/matlab/notes_binding.md | 9 +++ bindings/matlab/src/MultiIndexSet_mex.cpp | 91 ++++++++++++++++++++++- 5 files changed, 152 insertions(+), 44 deletions(-) create mode 100644 bindings/matlab/notes_binding.md diff --git a/README.md b/README.md index ac695d8c..8df9943e 100644 --- a/README.md +++ b/README.md @@ -42,9 +42,7 @@ cd build cmake \ -DCMAKE_INSTALL_PREFIX= \ -DPYTHON_EXECUTABLE=`which python` \ - -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_OSX_ARCHITECTURES=x86_64 \ - -DMatlab_MEX_EXTENSION="mexmaci64" \ -DKokkos_ENABLE_PTHREAD=ON \ -DKokkos_ENABLE_SERIAL=ON \ .. diff --git a/bindings/matlab/external/mexplus_eigen.h b/bindings/matlab/external/mexplus_eigen.h index 81049719..d93c8b99 100644 --- a/bindings/matlab/external/mexplus_eigen.h +++ b/bindings/matlab/external/mexplus_eigen.h @@ -1,22 +1,8 @@ -/* - * eos - A 3D Morphable Model fitting library written in modern C++11/14. - * - * File: matlab/include/mexplus_eigen.hpp - * - * Copyright 2016-2018 Patrik Huber - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. +/** + * Adapted from: https://github.com/patrikhuber/eos/blob/master/matlab/include/mexplus_eigen.hpp + * @brief Handle custom conversion between mxarray and C++ matrix types */ + #pragma once #ifndef MPART_MEXPLUS_EIGEN_HPP @@ -34,12 +20,7 @@ namespace mexplus { /** - * @brief Define a template specialisation for Eigen::MatrixXd for ... . - * - * The default precision in Matlab is double, but most matrices in eos (for example the PCA basis matrices - * are stored as float values, so this defines conversion from these matrices to Matlab. - * - * Todo: Documentation. + * @brief Define a template specialisation for Eigen::MatrixXd */ template <> inline mxArray* MxArray::from(const Eigen::MatrixXd& eigen_matrix) @@ -48,12 +29,6 @@ inline mxArray* MxArray::from(const Eigen::MatrixXd& eigen_matrix) const int num_cols = static_cast(eigen_matrix.cols()); MxArray out_array(MxArray::Numeric(num_rows, num_cols)); - // This might not copy the data but it's evil and probably really dangerous!!!: - // mxSetData(const_cast(matrix.get()), (void*)value.data()); - - // This copies the data. But I suppose it makes sense that we copy the data when we go - // from C++ to Matlab, since Matlab can unload the C++ mex module at any time I think. - // Loop is column-wise for (int c = 0; c < num_cols; ++c) { for (int r = 0; r < num_rows; ++r) @@ -67,10 +42,8 @@ inline mxArray* MxArray::from(const Eigen::MatrixXd& eigen_matrix) /** - * @brief Define a template specialisation for Eigen::MatrixXd for ... . + * @brief Define a template specialisation for Eigen::MatrixXd * - * Todo: Documentation. - * TODO: Maybe provide this one as MatrixXf as well as MatrixXd? Matlab's default is double? */ template <> inline void MxArray::to(const mxArray* in_array, Eigen::MatrixXd* eigen_matrix) @@ -102,14 +75,8 @@ inline void MxArray::to(const mxArray* in_array, Eigen::MatrixXd* eigen_matrix) const auto nrows = array.dimensions()[0]; // or use array.rows() const auto ncols = array.dimensions()[1]; - // I think I can just use Eigen::Matrix, not a Map - the Matrix c'tor that we call creates a Map anyway? Eigen::Map> eigen_map( array.getData(), nrows, ncols); - // Not sure that's alright - who owns the data? I think as it is now, everything points to the data in the - // mxArray owned by Matlab, but I'm not 100% sure. - // Actually, doesn't eigen_map go out of scope and get destroyed? This might be trouble? But this - // assignment should (or might) copy, then it's fine? Check if it invokes the copy c'tor. - // 2 May 2018: Yes this copies. *eigen_matrix = eigen_map; }; @@ -182,4 +149,4 @@ inline void MxArray::to(const mxArray* in_array, Kokkos::View; template class mexplus::Session>; @@ -44,7 +45,6 @@ MEX_DEFINE(MultiIndexSet_delete) (int nlhs, mxArray* plhs[], Session::destroy(input.get(0)); } -// Defines MEX API for delete. MEX_DEFINE(MultiIndexSet_MaxOrders) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 1); @@ -53,4 +53,91 @@ MEX_DEFINE(MultiIndexSet_MaxOrders) (int nlhs, mxArray* plhs[], output.set(0, mset.MaxOrders()); } +MEX_DEFINE(MultiIndexSet_Size) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + output.set(0, mset.Size()); +} + +MEX_DEFINE(MultiIndexSet_Expand) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + unsigned int activeInd = input.get(1); + output.set(0, mset->Expand(activeInd)); +} + +MEX_DEFINE(MultiIndexSet_ForciblyExpand) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + const unsigned int activeIndex = input.get(1); + output.set(0, mset->ForciblyExpand(activeIndex)); +} + +MEX_DEFINE(MultiIndexSet_Frontier) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + output.set(0, mset.Frontier()); +} + +MEX_DEFINE(MultiIndexSet_StrictFrontier) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + output.set(0, mset.StrictFrontier()); +} + +MEX_DEFINE(MultiIndexSet_BackwardNeighbors) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + unsigned int activeIndex = input.get(1); + output.set(0, mset.BackwardNeighbors(activeIndex)); +} + +MEX_DEFINE(MultiIndexSet_IsExpandable) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + unsigned int activeIndex = input.get(1); + output.set(0, mset.IsExpandable(activeIndex)); +} + +MEX_DEFINE(MultiIndexSet_NumActiveForward) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + unsigned int activeInd = input.get(1); + output.set(0, mset.NumActiveForward(activeInd)); +} + +MEX_DEFINE(MultiIndexSet_NumForward) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + unsigned int activeInd = input.get(1); + output.set(0, mset.NumForward(activeInd)); +} + +MEX_DEFINE(MultiIndexSet_Visualize) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 0); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + mset.Visualize(); +} + + } // namespace \ No newline at end of file From 5202278d054b92e3f2af0c992bb83122a6adc3f6 Mon Sep 17 00:00:00 2001 From: rubiop Date: Sat, 16 Jul 2022 22:14:15 -0400 Subject: [PATCH 03/20] update notes binding --- bindings/matlab/notes_binding.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/bindings/matlab/notes_binding.md b/bindings/matlab/notes_binding.md index 0ee3f680..44dc8e22 100644 --- a/bindings/matlab/notes_binding.md +++ b/bindings/matlab/notes_binding.md @@ -3,7 +3,12 @@ # Filling .m files # getConst +'const MultiIndexSet& mset = Session::getConst(input.get(0))' +'output.set(0, mset.Frontier())' # get +'MultiIndexSet *mset = Session::get(input.get(0))' +'output.set(0, mset->ForciblyExpand(activeIndex))' + # id_ From 1d123fb744c0b0ea5767492c15a03ca585f21a86 Mon Sep 17 00:00:00 2001 From: rubiop Date: Sat, 16 Jul 2022 22:17:43 -0400 Subject: [PATCH 04/20] update bonding notes --- bindings/matlab/notes_binding.md | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/bindings/matlab/notes_binding.md b/bindings/matlab/notes_binding.md index 44dc8e22..15b43ec4 100644 --- a/bindings/matlab/notes_binding.md +++ b/bindings/matlab/notes_binding.md @@ -3,12 +3,17 @@ # Filling .m files # getConst -'const MultiIndexSet& mset = Session::getConst(input.get(0))' -'output.set(0, mset.Frontier())' +``` +const MultiIndexSet& mset = Session::getConst(input.get(0)) +output.set(0, mset.Frontier()) +``` + # get +``` +MultiIndexSet *mset = Session::get(input.get(0)) +output.set(0, mset->ForciblyExpand(activeIndex)) +``` -'MultiIndexSet *mset = Session::get(input.get(0))' -'output.set(0, mset->ForciblyExpand(activeIndex))' # id_ From 3b3b3d8513b2b7e81840b8c258c15692b83a998f Mon Sep 17 00:00:00 2001 From: rubiop Date: Sun, 17 Jul 2022 14:30:54 -0400 Subject: [PATCH 05/20] first methods of MultiIndex --- bindings/matlab/CMakeLists.txt | 1 + bindings/matlab/mat/MultiIndex.m | 47 +++++++++++++++++++++ bindings/matlab/notes_binding.md | 2 + bindings/matlab/src/MultiIndex_mex.cpp | 57 ++++++++++++++++++++++++++ 4 files changed, 107 insertions(+) create mode 100644 bindings/matlab/mat/MultiIndex.m create mode 100644 bindings/matlab/src/MultiIndex_mex.cpp diff --git a/bindings/matlab/CMakeLists.txt b/bindings/matlab/CMakeLists.txt index 05aa22e6..c2343b95 100644 --- a/bindings/matlab/CMakeLists.txt +++ b/bindings/matlab/CMakeLists.txt @@ -8,6 +8,7 @@ set(MEX_SOURCE src/KokkosUtilities_mex.cpp src/ConditionalMap_mex.cpp src/MultiIndexSet_mex.cpp + src/MultiIndex_mex.cpp src/FixedMultiIndexSet_mex.cpp src/MexArrayConversions.cpp src/MexMapOptionsConversions.cpp diff --git a/bindings/matlab/mat/MultiIndex.m b/bindings/matlab/mat/MultiIndex.m new file mode 100644 index 00000000..0d9ce2e2 --- /dev/null +++ b/bindings/matlab/mat/MultiIndex.m @@ -0,0 +1,47 @@ +classdef MultiIndex < handle +%DATABASE Example usage of the mexplus development kit. +% +% This class definition gives an interface to the underlying MEX functions +% built in the private directory. It is a good practice to wrap MEX functions +% with Matlab script so that the API is well documented and separated from +% its C++ implementation. Also such a wrapper is a good place to validate +% input arguments. +% +% Build +% ----- +% +% make +% +% See `make.m` for details. +% + +properties (Access = private) + id_ +end + +methods + function this = MultiIndex(varargin) + if(nargin==2) + this.id_ = MParT_('MultiIndex_newDefault', varargin{1},varargin{2}); + else + if length(varargin{1})==1 + this.id_ = MParT_('MultiIndex_newDefault', varargin{1},0); + else + this.id_ = MParT_('MultiIndex_newEigen',varargin{1}); + end + end + end + + function delete(this) + %DELETE Destructor. + MParT_('MultiIndex_delete', this.id_); + end + + function result = String(this) + result = MParT_('MultiIndex_String', this.id_); + end + + +end + +end diff --git a/bindings/matlab/notes_binding.md b/bindings/matlab/notes_binding.md index 15b43ec4..f80160e7 100644 --- a/bindings/matlab/notes_binding.md +++ b/bindings/matlab/notes_binding.md @@ -15,5 +15,7 @@ MultiIndexSet *mset = Session::get(input.get(0)) output.set(0, mset->ForciblyExpand(activeIndex)) ``` +# MapOption + # id_ diff --git a/bindings/matlab/src/MultiIndex_mex.cpp b/bindings/matlab/src/MultiIndex_mex.cpp new file mode 100644 index 00000000..d0afd614 --- /dev/null +++ b/bindings/matlab/src/MultiIndex_mex.cpp @@ -0,0 +1,57 @@ +#include +#include "MParT/MultiIndices/MultiIndexSet.h" +#include "MParT/MultiIndices/MultiIndex.h" + +#include "MParT/Utilities/ArrayConversions.h" +#include "MexArrayConversions.h" +#include "mexplus_eigen.h" +#include + + +using namespace mpart; +using namespace mexplus; + + +// Instance manager for MultiIndex +template class mexplus::Session; + + +namespace { + +MEX_DEFINE(MultiIndex_newDefault) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const unsigned int lengthIn = input.get(0); + const unsigned int val = input.get(1); + output.set(0, Session::create(new MultiIndex(lengthIn,val))); +} + +MEX_DEFINE(MultiIndex_newEigen) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const auto multi = input.get(0); + output.set(0, Session::create(new MultiIndex(multi.cast()))); +} + +// Defines MEX API for delete. +MEX_DEFINE(MultiIndex_delete) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 0); + Session::destroy(input.get(0)); +} + +MEX_DEFINE(MultiIndex_String) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& mset = Session::getConst(input.get(0)); + output.set(0, mset.String()); +} + + +} // namespace \ No newline at end of file From ef7b2aeef02cf331541f99e644ffdffb969e8858 Mon Sep 17 00:00:00 2001 From: rubiop Date: Sun, 17 Jul 2022 16:29:17 -0400 Subject: [PATCH 06/20] more multiIndex methods --- bindings/matlab/mat/MultiIndex.m | 16 ++++++++++ bindings/matlab/src/MultiIndex_mex.cpp | 43 +++++++++++++++++++++++--- 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/bindings/matlab/mat/MultiIndex.m b/bindings/matlab/mat/MultiIndex.m index 0d9ce2e2..9d0840b6 100644 --- a/bindings/matlab/mat/MultiIndex.m +++ b/bindings/matlab/mat/MultiIndex.m @@ -37,6 +37,22 @@ function delete(this) MParT_('MultiIndex_delete', this.id_); end + function result = Vector(this) + result = MParT_('MultiIndex_Vector', this.id_); + end + + function result = Sum(this) + result = MParT_('MultiIndex_Sum', this.id_); + end + + function result = Max(this) + result = MParT_('MultiIndex_Max', this.id_); + end + + function result = Set(this,ind,val) + result = MParT_('MultiIndex_Set', this.id_,ind-1,val); + end + function result = String(this) result = MParT_('MultiIndex_String', this.id_); end diff --git a/bindings/matlab/src/MultiIndex_mex.cpp b/bindings/matlab/src/MultiIndex_mex.cpp index d0afd614..df57cce1 100644 --- a/bindings/matlab/src/MultiIndex_mex.cpp +++ b/bindings/matlab/src/MultiIndex_mex.cpp @@ -33,8 +33,8 @@ MEX_DEFINE(MultiIndex_newEigen) (int nlhs, mxArray* plhs[], InputArguments input(nrhs, prhs, 1); OutputArguments output(nlhs, plhs, 1); - const auto multi = input.get(0); - output.set(0, Session::create(new MultiIndex(multi.cast()))); + const auto mult = input.get(0); + output.set(0, Session::create(new MultiIndex(mult.cast()))); } // Defines MEX API for delete. @@ -45,12 +45,47 @@ MEX_DEFINE(MultiIndex_delete) (int nlhs, mxArray* plhs[], Session::destroy(input.get(0)); } +MEX_DEFINE(MultiIndex_Vector) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + output.set(0, multi.Vector()); +} + +MEX_DEFINE(MultiIndex_Sum) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + output.set(0, multi.Sum()); +} + +MEX_DEFINE(MultiIndex_Max) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + output.set(0, multi.Max()); +} + +MEX_DEFINE(MultiIndex_Set) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 3); + OutputArguments output(nlhs, plhs, 1); + MultiIndex *multi = Session::get(input.get(0)); + unsigned int ind = input.get(1); + unsigned int val = input.get(2); + output.set(0, multi->Set(ind,val)); +} + + MEX_DEFINE(MultiIndex_String) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 1); OutputArguments output(nlhs, plhs, 1); - const MultiIndex& mset = Session::getConst(input.get(0)); - output.set(0, mset.String()); + const MultiIndex& multi = Session::getConst(input.get(0)); + output.set(0, multi.String()); } From 2843b6ed939d32ed14285c4c21a5386dd2f7a8ac Mon Sep 17 00:00:00 2001 From: rubiop Date: Sun, 17 Jul 2022 17:36:27 -0400 Subject: [PATCH 07/20] binding overloaded operator --- MParT/MultiIndices/MultiIndex.h | 2 +- bindings/matlab/mat/MultiIndex.m | 20 +++++++++++++++ bindings/matlab/src/MultiIndex_mex.cpp | 35 ++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 1 deletion(-) diff --git a/MParT/MultiIndices/MultiIndex.h b/MParT/MultiIndices/MultiIndex.h index 9e731d79..3ae272c2 100644 --- a/MParT/MultiIndices/MultiIndex.h +++ b/MParT/MultiIndices/MultiIndex.h @@ -85,7 +85,7 @@ class MultiIndex { */ bool Set(unsigned int ind, unsigned int val); - /** Obtain the a particular component of the multiindex. Notice that this function can be slow for multiindices with many nonzero components. The worst case performance requires \f$O(|\mathbf{j}|_0)\f$ integer comparisons, where \f$|\mathbf{j}|_0\f$ denotes the number of nonzero entries in the multiindex. + /** Obtain a particular component of the multiindex. Notice that this function can be slow for multiindices with many nonzero components. The worst case performance requires \f$O(|\mathbf{j}|_0)\f$ integer comparisons, where \f$|\mathbf{j}|_0\f$ denotes the number of nonzero entries in the multiindex. @param[in] ind The component to return. @return The integer stored in component dim of the multiindex. */ diff --git a/bindings/matlab/mat/MultiIndex.m b/bindings/matlab/mat/MultiIndex.m index 9d0840b6..a60121af 100644 --- a/bindings/matlab/mat/MultiIndex.m +++ b/bindings/matlab/mat/MultiIndex.m @@ -53,10 +53,30 @@ function delete(this) result = MParT_('MultiIndex_Set', this.id_,ind-1,val); end + function result = Get(this,ind) + result = MParT_('MultiIndex_Get', this.id_,ind-1); + end + + function result = NumNz(this) + result = MParT_('MultiIndex_NumNz', this.id_); + end + function result = String(this) result = MParT_('MultiIndex_String', this.id_); end + function result = Length(this) + result = MParT_('MultiIndex_Length', this.id_); + end + + function result = compare(this,multi) + result = MParT_('MultiIndex_Eq',this.id_,multi.get_id()); + end + + function result = get_id(this) + result = this.id_; + end + end diff --git a/bindings/matlab/src/MultiIndex_mex.cpp b/bindings/matlab/src/MultiIndex_mex.cpp index df57cce1..af513ad5 100644 --- a/bindings/matlab/src/MultiIndex_mex.cpp +++ b/bindings/matlab/src/MultiIndex_mex.cpp @@ -79,6 +79,22 @@ MEX_DEFINE(MultiIndex_Set) (int nlhs, mxArray* plhs[], output.set(0, multi->Set(ind,val)); } +MEX_DEFINE(MultiIndex_Get) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + unsigned int ind = input.get(1); + output.set(0, multi.Get(ind)); +} + +MEX_DEFINE(MultiIndex_NumNz) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + output.set(0, multi.NumNz()); +} MEX_DEFINE(MultiIndex_String) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { @@ -88,5 +104,24 @@ MEX_DEFINE(MultiIndex_String) (int nlhs, mxArray* plhs[], output.set(0, multi.String()); } +MEX_DEFINE(MultiIndex_Length) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + output.set(0, multi.Length()); +} + +MEX_DEFINE(MultiIndex_Eq) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + const MultiIndex& multi2 = Session::getConst(input.get(1)); + output.set(0, multi==multi2); +} + + + } // namespace \ No newline at end of file From a141d3782627f96135f8de24aabf9ff49c61d0ff Mon Sep 17 00:00:00 2001 From: rubiop Date: Sun, 17 Jul 2022 18:11:46 -0400 Subject: [PATCH 08/20] finished operator overload for MultiIndex --- bindings/matlab/mat/MultiIndex.m | 28 +++++++++++++++- bindings/matlab/src/MultiIndex_mex.cpp | 44 ++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/bindings/matlab/mat/MultiIndex.m b/bindings/matlab/mat/MultiIndex.m index a60121af..9fa3ef91 100644 --- a/bindings/matlab/mat/MultiIndex.m +++ b/bindings/matlab/mat/MultiIndex.m @@ -69,10 +69,36 @@ function delete(this) result = MParT_('MultiIndex_Length', this.id_); end - function result = compare(this,multi) + % == operator + function result = eq(this,multi) result = MParT_('MultiIndex_Eq',this.id_,multi.get_id()); end + % ~= operator + function result = ne(this,multi) + result = MParT_('MultiIndex_Ne',this.id_,multi.get_id()); + end + + % < operator + function result = lt(this,multi) + result = MParT_('MultiIndex_Lt',this.id_,multi.get_id()); + end + + % > operator + function result = gt(this,multi) + result = MParT_('MultiIndex_Gt',this.id_,multi.get_id()); + end + + % >= operator + function result = Ge(this,multi) + result = MParT_('MultiIndex_Ge',this.id_,multi.get_id()); + end + + % <= operator + function result = Le(this,multi) + result = MParT_('MultiIndex_Le',this.id_,multi.get_id()); + end + function result = get_id(this) result = this.id_; end diff --git a/bindings/matlab/src/MultiIndex_mex.cpp b/bindings/matlab/src/MultiIndex_mex.cpp index af513ad5..4f331710 100644 --- a/bindings/matlab/src/MultiIndex_mex.cpp +++ b/bindings/matlab/src/MultiIndex_mex.cpp @@ -121,6 +121,50 @@ MEX_DEFINE(MultiIndex_Eq) (int nlhs, mxArray* plhs[], output.set(0, multi==multi2); } +MEX_DEFINE(MultiIndex_Ne) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + const MultiIndex& multi2 = Session::getConst(input.get(1)); + output.set(0, multi!=multi2); +} + +MEX_DEFINE(MultiIndex_Lt) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + const MultiIndex& multi2 = Session::getConst(input.get(1)); + output.set(0, multi::getConst(input.get(0)); + const MultiIndex& multi2 = Session::getConst(input.get(1)); + output.set(0, multi>multi2); +} + +MEX_DEFINE(MultiIndex_Ge) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + const MultiIndex& multi2 = Session::getConst(input.get(1)); + output.set(0, multi>=multi2); +} + +MEX_DEFINE(MultiIndex_Le) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + const MultiIndex& multi2 = Session::getConst(input.get(1)); + output.set(0, multi<=multi2); +} From 1f3ecfc07738c448fa31aed5e465129b82f1396c Mon Sep 17 00:00:00 2001 From: rubiop Date: Mon, 18 Jul 2022 00:44:44 -0400 Subject: [PATCH 09/20] more methods and overloads --- bindings/matlab/mat/MultiIndex.m | 12 +++--- bindings/matlab/mat/MultiIndexSet.m | 39 ++++++++++++++++++- bindings/matlab/notes_binding.md | 2 + bindings/matlab/src/MultiIndexSet_mex.cpp | 47 +++++++++++++++++++++++ 4 files changed, 93 insertions(+), 7 deletions(-) diff --git a/bindings/matlab/mat/MultiIndex.m b/bindings/matlab/mat/MultiIndex.m index 9fa3ef91..145caa2e 100644 --- a/bindings/matlab/mat/MultiIndex.m +++ b/bindings/matlab/mat/MultiIndex.m @@ -21,14 +21,14 @@ methods function this = MultiIndex(varargin) - if(nargin==2) - this.id_ = MParT_('MultiIndex_newDefault', varargin{1},varargin{2}); - else - if length(varargin{1})==1 - this.id_ = MParT_('MultiIndex_newDefault', varargin{1},0); + if(nargin==2) + if(varargin{2}=='id') + this.id_ = varargin{1}; else - this.id_ = MParT_('MultiIndex_newEigen',varargin{1}); + this.id_ = MParT_('MultiIndex_newDefault', varargin{1},varargin{2}); end + else + this.id_ = MParT_('MultiIndex_newEigen',varargin{1}); end end diff --git a/bindings/matlab/mat/MultiIndexSet.m b/bindings/matlab/mat/MultiIndexSet.m index 94850e3f..96daadaf 100644 --- a/bindings/matlab/mat/MultiIndexSet.m +++ b/bindings/matlab/mat/MultiIndexSet.m @@ -39,10 +39,46 @@ function delete(this) MParT_('MultiIndexSet_delete', this.id_); end + function multi = IndexToMulti(this,activeIndex) + multi_id = MParT_('MultiIndexSet_IndexToMulti',this.id_,activeIndex-1); + multi = MultiIndex(multi_id,'id'); + end + + function result = MultiToIndex(this,multi) + result= MParT_('MultiIndexSet_MultiToIndex',this.id_,multi.get_id()); + result = result + 1; + end + + function result = Length(this) + result = MParT_('MultiIndexSet_Length', this.id_); + end + function result = MaxOrders(this) result = MParT_('MultiIndexSet_MaxOrders', this.id_); end + function multi = at(this,activeIndex) + multi_id = MParT_('MultiIndexSet_at', this.id_,activeIndex-1); + multi = MultiIndex(multi_id,'id'); + end + + function varargout = subsref(this,s) %seems dangerous + switch s(1).type + case '.' + % Keep built-in functionality for '.' + [varargout{1:nargout}] = builtin('subsref', this, s); + case '()' + % Keep built-in functionality for '()' + [varargout{1:nargout}] = builtin('subsref', this, s); + case '{}' + activeIndex = s(1).subs{1}; + multi_id = MParT_('MultiIndexSet_subsref',this.id_,activeIndex-1); + [varargout{1:nargout}] = MultiIndex(multi_id,'id'); + otherwise + error('Indexing expression invalid.') + end + end + function result = Size(this) result = MParT_('MultiIndexSet_Size', this.id_); end @@ -89,7 +125,6 @@ function Visualize(this) MParT_('MultiIndexSet_Visualize',this.id_); end - function result = get_id(this) result = this.id_; end @@ -97,6 +132,8 @@ function Visualize(this) function fixed_mset = Fix(this) fixed_mset = FixedMultiIndexSet(this); end + + end diff --git a/bindings/matlab/notes_binding.md b/bindings/matlab/notes_binding.md index f80160e7..2e157a55 100644 --- a/bindings/matlab/notes_binding.md +++ b/bindings/matlab/notes_binding.md @@ -19,3 +19,5 @@ output.set(0, mset->ForciblyExpand(activeIndex)) # id_ + +# How to pass MParT object as argument of other object methods diff --git a/bindings/matlab/src/MultiIndexSet_mex.cpp b/bindings/matlab/src/MultiIndexSet_mex.cpp index 58469ba9..2e2a90aa 100644 --- a/bindings/matlab/src/MultiIndexSet_mex.cpp +++ b/bindings/matlab/src/MultiIndexSet_mex.cpp @@ -45,6 +45,32 @@ MEX_DEFINE(MultiIndexSet_delete) (int nlhs, mxArray* plhs[], Session::destroy(input.get(0)); } +MEX_DEFINE(MultiIndexSet_IndexToMulti) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + unsigned int activeIndex = input.get(1); + output.set(0, Session::create(new MultiIndex(mset.IndexToMulti(activeIndex)))); +} + +MEX_DEFINE(MultiIndexSet_MultiToIndex) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + const MultiIndex& multi = Session::getConst(input.get(1)); + output.set(0,mset.MultiToIndex(multi)); +} + +MEX_DEFINE(MultiIndexSet_Length) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + output.set(0, mset.Length()); +} + MEX_DEFINE(MultiIndexSet_MaxOrders) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 1); @@ -53,6 +79,24 @@ MEX_DEFINE(MultiIndexSet_MaxOrders) (int nlhs, mxArray* plhs[], output.set(0, mset.MaxOrders()); } +MEX_DEFINE(MultiIndexSet_at) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + int activeIndex = input.get(1); + output.set(0, Session::create(new MultiIndex(mset.at(activeIndex)))); +} + +MEX_DEFINE(MultiIndexSet_subsref) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + int activeIndex = input.get(1); + output.set(0, Session::create(new MultiIndex(mset[activeIndex]))); +} + MEX_DEFINE(MultiIndexSet_Size) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 1); @@ -140,4 +184,7 @@ MEX_DEFINE(MultiIndexSet_Visualize) (int nlhs, mxArray* plhs[], } + + + } // namespace \ No newline at end of file From 3357b2eab3150e3d98db1c9a3ea98f4e76d0772f Mon Sep 17 00:00:00 2001 From: rubiop Date: Mon, 18 Jul 2022 11:50:59 -0400 Subject: [PATCH 10/20] more bindings work --- MParT/MultiIndices/MultiIndexSet.h | 2 +- bindings/matlab/mat/MultiIndexSet.m | 15 ++++++++++++ bindings/matlab/src/MultiIndexSet_mex.cpp | 28 +++++++++++++++++++++++ 3 files changed, 44 insertions(+), 1 deletion(-) diff --git a/MParT/MultiIndices/MultiIndexSet.h b/MParT/MultiIndices/MultiIndexSet.h index 02c8cf93..563caf7a 100644 --- a/MParT/MultiIndices/MultiIndexSet.h +++ b/MParT/MultiIndices/MultiIndexSet.h @@ -170,7 +170,7 @@ MultiIndexSet set(length, limiter); already in the set and if the input function is unique, it is added to the set. @param[in] rhs The MultiIndex we want to add to the set. - @return A reference to this MultiIndex set, which may now contain the new + @return A reference to this MultiIndex, which may now contain the new MultiIndex in rhs. */ MultiIndexSet& operator+=(MultiIndex const& rhs); diff --git a/bindings/matlab/mat/MultiIndexSet.m b/bindings/matlab/mat/MultiIndexSet.m index 96daadaf..64e227e7 100644 --- a/bindings/matlab/mat/MultiIndexSet.m +++ b/bindings/matlab/mat/MultiIndexSet.m @@ -83,6 +83,21 @@ function delete(this) result = MParT_('MultiIndexSet_Size', this.id_); end + function result = plus(this,toAdd) + if strcmp(class(toAdd),'MultiIndexSet') + MParT_('MultiIndexSet_addMultiIndexSet',this.id_,toAdd.get_id()); + elseif strcmp(class(toAdd),'MultiIndex') + MParT_('MultiIndexSet_addMultiIndex',this.id_,toAdd.get_id()); + else + error('Unrecognized type to add to MultiIndexSet') + end + end + + function result = Union(this,mset) + %-1 to keep consitent with matlab ordering + result = MParT_('MultiIndexSet_Union',this.id_,mset.get_id()); + end + function result = Expand(this,activeInd) %-1 to keep consitent with matlab ordering result = MParT_('MultiIndexSet_Expand',this.id_,activeInd-1); diff --git a/bindings/matlab/src/MultiIndexSet_mex.cpp b/bindings/matlab/src/MultiIndexSet_mex.cpp index 2e2a90aa..550afa25 100644 --- a/bindings/matlab/src/MultiIndexSet_mex.cpp +++ b/bindings/matlab/src/MultiIndexSet_mex.cpp @@ -105,6 +105,34 @@ MEX_DEFINE(MultiIndexSet_Size) (int nlhs, mxArray* plhs[], output.set(0, mset.Size()); } +MEX_DEFINE(MultiIndexSet_addMultiIndexSet) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 0); + MultiIndexSet *mset = Session::get(input.get(0)); + const MultiIndexSet& msetToAdd = Session::getConst(input.get(1)); + (*mset)+=msetToAdd; +} + +MEX_DEFINE(MultiIndexSet_addMultiIndex) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 0); + MultiIndexSet *mset = Session::get(input.get(0)); + const MultiIndex& multiToAdd = Session::getConst(input.get(1)); + (*mset)+=multiToAdd; +} + +MEX_DEFINE(MultiIndexSet_Union) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + const MultiIndexSet& rhs = Session::getConst(input.get(1)); + output.set(0, mset->Union(rhs)); +} + + MEX_DEFINE(MultiIndexSet_Expand) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 2); From 95b322b7d866fd744012c9deea7a2cfe46dc7edf Mon Sep 17 00:00:00 2001 From: rubiop Date: Mon, 18 Jul 2022 12:43:26 -0400 Subject: [PATCH 11/20] overeloaded plus --- bindings/matlab/mat/MultiIndexSet.m | 11 +++++++++-- bindings/matlab/src/MultiIndexSet_mex.cpp | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/bindings/matlab/mat/MultiIndexSet.m b/bindings/matlab/mat/MultiIndexSet.m index 64e227e7..03815f2f 100644 --- a/bindings/matlab/mat/MultiIndexSet.m +++ b/bindings/matlab/mat/MultiIndexSet.m @@ -94,9 +94,16 @@ function delete(this) end function result = Union(this,mset) - %-1 to keep consitent with matlab ordering result = MParT_('MultiIndexSet_Union',this.id_,mset.get_id()); - end + end + + function Activate(this,multi) + MParT_('MultiIndexSet_Activate',this.id_,multi.get_id()); + end + + function result = AddActive(this,multi) + result = MParT_('MultiIndexSet_AddActive',this.id_,multi.get_id()); + end function result = Expand(this,activeInd) %-1 to keep consitent with matlab ordering diff --git a/bindings/matlab/src/MultiIndexSet_mex.cpp b/bindings/matlab/src/MultiIndexSet_mex.cpp index 550afa25..57a4a997 100644 --- a/bindings/matlab/src/MultiIndexSet_mex.cpp +++ b/bindings/matlab/src/MultiIndexSet_mex.cpp @@ -132,6 +132,23 @@ MEX_DEFINE(MultiIndexSet_Union) (int nlhs, mxArray* plhs[], output.set(0, mset->Union(rhs)); } +MEX_DEFINE(MultiIndexSet_Activate) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 0); + MultiIndexSet *mset = Session::get(input.get(0)); + const MultiIndex& multi = Session::getConst(input.get(1)); + mset->Activate(multi); +} + +MEX_DEFINE(MultiIndexSet_AddActive) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + const MultiIndex& multi = Session::getConst(input.get(1)); + mset->AddActive(multi); +} MEX_DEFINE(MultiIndexSet_Expand) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { From e9b9bac3192838aed72a1c776ed184f2efd46d6a Mon Sep 17 00:00:00 2001 From: rubiop Date: Mon, 18 Jul 2022 15:26:27 -0400 Subject: [PATCH 12/20] work done on AdmissibleFowardNeighbors --- bindings/matlab/mat/MultiIndexSet.m | 25 ++++++++++++++-- bindings/matlab/src/MultiIndexSet_mex.cpp | 35 ++++++++++++++++++++++- 2 files changed, 56 insertions(+), 4 deletions(-) diff --git a/bindings/matlab/mat/MultiIndexSet.m b/bindings/matlab/mat/MultiIndexSet.m index 03815f2f..d914339e 100644 --- a/bindings/matlab/mat/MultiIndexSet.m +++ b/bindings/matlab/mat/MultiIndexSet.m @@ -105,14 +105,33 @@ function Activate(this,multi) result = MParT_('MultiIndexSet_AddActive',this.id_,multi.get_id()); end - function result = Expand(this,activeInd) - %-1 to keep consitent with matlab ordering - result = MParT_('MultiIndexSet_Expand',this.id_,activeInd-1); + function result = Expand(this,varargin) + if(nargin == 2) + %-1 to keep consitent with matlab ordering + result = MParT_('MultiIndexSet_Expand',this.id_,varargin{1}-1); + elseif(nargin == 1) + result = MParT_('MultiIndexSet_ExpandAny',this.id_); + else + error('Wrong number of inputs') + end end function result = ForciblyExpand(this,activeInd) %-1 to keep consitent with matlab ordering result = MParT_('MultiIndexSet_ForciblyExpand',this.id_,activeInd-1); + end + + function result = ForciblyActivate(this,multi) + result = MParT_('MultiIndexSet_ForciblyActivate',this.id_,multi.get_id()); + end + + function listMultis = AdmissibleForwardNeighbors(this,activeInd) + %-1 to keep consistent with matlab ordering + multi_ids = MParT_('MultiIndexSet_AdmissibleFowardNeighbors',this.id_,activeInd-1); + listMultis = []; + for i = 1:length(multi_ids) + listMultis=[listMultis,MultiIndex(multi_ids(i),'id')]; + end end function result = Frontier(this) diff --git a/bindings/matlab/src/MultiIndexSet_mex.cpp b/bindings/matlab/src/MultiIndexSet_mex.cpp index 57a4a997..062a3db7 100644 --- a/bindings/matlab/src/MultiIndexSet_mex.cpp +++ b/bindings/matlab/src/MultiIndexSet_mex.cpp @@ -147,7 +147,7 @@ MEX_DEFINE(MultiIndexSet_AddActive) (int nlhs, mxArray* plhs[], OutputArguments output(nlhs, plhs, 1); MultiIndexSet *mset = Session::get(input.get(0)); const MultiIndex& multi = Session::getConst(input.get(1)); - mset->AddActive(multi); + output.set(0,mset->AddActive(multi)); } MEX_DEFINE(MultiIndexSet_Expand) (int nlhs, mxArray* plhs[], @@ -159,6 +159,14 @@ MEX_DEFINE(MultiIndexSet_Expand) (int nlhs, mxArray* plhs[], output.set(0, mset->Expand(activeInd)); } +// MEX_DEFINE(MultiIndexSet_ExpandAny) (int nlhs, mxArray* plhs[], +// int nrhs, const mxArray* prhs[]) { +// InputArguments input(nrhs, prhs, 1); +// OutputArguments output(nlhs, plhs, 0); +// MultiIndexSet *mset = Session::get(input.get(0)); +// output.set(0, mset->Expand()); +// } + MEX_DEFINE(MultiIndexSet_ForciblyExpand) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 2); @@ -168,6 +176,31 @@ MEX_DEFINE(MultiIndexSet_ForciblyExpand) (int nlhs, mxArray* plhs[], output.set(0, mset->ForciblyExpand(activeIndex)); } +MEX_DEFINE(MultiIndexSet_ForciblyActivate) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + const MultiIndex& multi = Session::getConst(input.get(1)); + output.set(0,mset->ForciblyActivate(multi)); +} + +MEX_DEFINE(MultiIndexSet_AdmissibleFowardNeighbors) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + MultiIndexSet *mset = Session::get(input.get(0)); + unsigned int activeIndex = input.get(1); + std::vector vecMultiIndex = mset->AdmissibleForwardNeighbors(activeIndex); + OutputArguments output(nlhs, plhs, 1); + std::vector multi_ids(vecMultiIndex.size()); + for (int i=0; i::create(new MultiIndex(vecMultiIndex[i])); + } + output.set(0,multi_ids); + +} + + MEX_DEFINE(MultiIndexSet_Frontier) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 1); From 304541e8880fda941ff846953aaa70096688f1c5 Mon Sep 17 00:00:00 2001 From: rubiop Date: Mon, 18 Jul 2022 15:48:12 -0400 Subject: [PATCH 13/20] finished MultiIndexSet --- bindings/matlab/mat/MultiIndexSet.m | 30 ++++++++++++++- bindings/matlab/src/MultiIndexSet_mex.cpp | 46 ++++++++++++++++++++++- 2 files changed, 72 insertions(+), 4 deletions(-) diff --git a/bindings/matlab/mat/MultiIndexSet.m b/bindings/matlab/mat/MultiIndexSet.m index d914339e..b5ab4c62 100644 --- a/bindings/matlab/mat/MultiIndexSet.m +++ b/bindings/matlab/mat/MultiIndexSet.m @@ -136,6 +136,24 @@ function Activate(this,multi) function result = Frontier(this) result = MParT_('MultiIndexSet_Frontier',this.id_); + end + + function listMultis = Margin(this) + %-1 to keep consistent with matlab ordering + multi_ids = MParT_('MultiIndexSet_Margin',this.id_); + listMultis = []; + for i = 1:length(multi_ids) + listMultis=[listMultis,MultiIndex(multi_ids(i),'id')]; + end + end + + function listMultis = ReducedMargin(this) + %-1 to keep consistent with matlab ordering + multi_ids = MParT_('MultiIndexSet_ReducedMargin',this.id_); + listMultis = []; + for i = 1:length(multi_ids) + listMultis=[listMultis,MultiIndex(multi_ids(i),'id')]; + end end function result = StrictFrontier(this) @@ -145,12 +163,20 @@ function Activate(this,multi) function result = BackwardNeighbors(this,activeIndex) %-1 to keep consitent with matlab ordering result = MParT_('MultiIndexSet_BackwardNeighbors',this.id_,activeIndex-1); - end + end + + function result = IsAdmissible(this,multi) + result = MParT_('MultiIndexSet_IsAdmissible',this.id_,multi.get_id()); + end function result = IsExpandable(this,activeIndex) %-1 to keep consitent with matlab ordering result = MParT_('MultiIndexSet_IsExpandable',this.id_,activeIndex-1); - end + end + + function result = IsActive(this,multi) + result = MParT_('MultiIndexSet_IsActive',this.id_,multi.get_id()); + end function result = NumActiveForward(this,activeIndex) %-1 to keep consitent with matlab ordering diff --git a/bindings/matlab/src/MultiIndexSet_mex.cpp b/bindings/matlab/src/MultiIndexSet_mex.cpp index 062a3db7..01000688 100644 --- a/bindings/matlab/src/MultiIndexSet_mex.cpp +++ b/bindings/matlab/src/MultiIndexSet_mex.cpp @@ -197,10 +197,8 @@ MEX_DEFINE(MultiIndexSet_AdmissibleFowardNeighbors) (int nlhs, mxArray* plhs[], multi_ids[i] = Session::create(new MultiIndex(vecMultiIndex[i])); } output.set(0,multi_ids); - } - MEX_DEFINE(MultiIndexSet_Frontier) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 1); @@ -209,6 +207,32 @@ MEX_DEFINE(MultiIndexSet_Frontier) (int nlhs, mxArray* plhs[], output.set(0, mset.Frontier()); } +MEX_DEFINE(MultiIndexSet_Margin) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + std::vector vecMultiIndex = mset->Margin(); + std::vector multi_ids(vecMultiIndex.size()); + for (int i=0; i::create(new MultiIndex(vecMultiIndex[i])); + } + output.set(0,multi_ids); +} + +MEX_DEFINE(MultiIndexSet_ReducedMargin) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + std::vector vecMultiIndex = mset->ReducedMargin(); + std::vector multi_ids(vecMultiIndex.size()); + for (int i=0; i::create(new MultiIndex(vecMultiIndex[i])); + } + output.set(0,multi_ids); +} + MEX_DEFINE(MultiIndexSet_StrictFrontier) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 1); @@ -226,6 +250,15 @@ MEX_DEFINE(MultiIndexSet_BackwardNeighbors) (int nlhs, mxArray* plhs[], output.set(0, mset.BackwardNeighbors(activeIndex)); } +MEX_DEFINE(MultiIndexSet_IsAdmissible) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + const MultiIndex& multi = Session::getConst(input.get(1)); + output.set(0, mset.IsAdmissible(multi)); +} + MEX_DEFINE(MultiIndexSet_IsExpandable) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 2); @@ -235,6 +268,15 @@ MEX_DEFINE(MultiIndexSet_IsExpandable) (int nlhs, mxArray* plhs[], output.set(0, mset.IsExpandable(activeIndex)); } +MEX_DEFINE(MultiIndexSet_IsActive) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + const MultiIndex& multi = Session::getConst(input.get(1)); + output.set(0, mset.IsActive(multi)); +} + MEX_DEFINE(MultiIndexSet_NumActiveForward) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 2); From c2d573cbf6d82799a33a7c8367e92537d314fbf4 Mon Sep 17 00:00:00 2001 From: rubiop Date: Mon, 18 Jul 2022 21:44:09 -0400 Subject: [PATCH 14/20] correction bug TriangularMap from list --- bindings/matlab/CMakeLists.txt | 1 + bindings/matlab/mat/ConditionalMap.m | 35 +++++--- bindings/matlab/mat/MultiIndex.m | 6 +- bindings/matlab/mat/TriangularMap.m | 90 +------------------ bindings/matlab/src/ConditionalMap_mex.cpp | 10 +++ .../src/ParameterizedFunctionBase_mex.cpp | 40 +++++++++ 6 files changed, 79 insertions(+), 103 deletions(-) create mode 100644 bindings/matlab/src/ParameterizedFunctionBase_mex.cpp diff --git a/bindings/matlab/CMakeLists.txt b/bindings/matlab/CMakeLists.txt index c2343b95..81137337 100644 --- a/bindings/matlab/CMakeLists.txt +++ b/bindings/matlab/CMakeLists.txt @@ -10,6 +10,7 @@ set(MEX_SOURCE src/MultiIndexSet_mex.cpp src/MultiIndex_mex.cpp src/FixedMultiIndexSet_mex.cpp + #src/ParameterizedFunctionBase_mex.cpp src/MexArrayConversions.cpp src/MexMapOptionsConversions.cpp ) diff --git a/bindings/matlab/mat/ConditionalMap.m b/bindings/matlab/mat/ConditionalMap.m index 0cef86de..11dc5caa 100644 --- a/bindings/matlab/mat/ConditionalMap.m +++ b/bindings/matlab/mat/ConditionalMap.m @@ -24,19 +24,22 @@ function this = ConditionalMap(varargin) if(nargin==2) - mset = varargin{1}; - mapOptions = varargin{2}; - mexOptions = mapOptions.getMexOptions; - - input_str=['MParT_(',char(39),'ConditionalMap_newMap',char(39),',mset.get_id()']; - for o=1:length(mexOptions) - input_o=[',mexOptions{',num2str(o),'}']; - input_str=[input_str,input_o]; + if(isstring(varargin{2})) + if(varargin{2}=="id") + this.id_=varargin{1}; + end + else + mset = varargin{1}; + mapOptions = varargin{2}; + mexOptions = mapOptions.getMexOptions; + input_str=['MParT_(',char(39),'ConditionalMap_newMap',char(39),',mset.get_id()']; + for o=1:length(mexOptions) + input_o=[',mexOptions{',num2str(o),'}']; + input_str=[input_str,input_o]; + end + input_str=[input_str,')']; + this.id_ = eval(input_str); end - input_str=[input_str,')']; - - this.id_ = eval(input_str); - elseif(nargin==4) inputDim = varargin{1}; outputDim = varargin{2}; @@ -54,8 +57,7 @@ this.id_ = eval(input_str); elseif(nargin==1) - MParT_('ConditionalMap_newTriMap', varargin{1}); - + this.id_=MParT_('ConditionalMap_newTriMap', varargin{1}); else error('Invalid number of inputs') end @@ -66,6 +68,11 @@ function delete(this) MParT_('ConditionalMap_deleteMap', this.id_); end + function condMap = GetComponent(this,i) + condMap_id = MParT_('ConditionalMap_GetComponent',this.id_,i-1); + condMap = ConditionalMap(condMap_id,"id") + end + function SetCoeffs(this,coeffs) MParT_('ConditionalMap_SetCoeffs',this.id_,coeffs(:)); end diff --git a/bindings/matlab/mat/MultiIndex.m b/bindings/matlab/mat/MultiIndex.m index 145caa2e..9c512a2b 100644 --- a/bindings/matlab/mat/MultiIndex.m +++ b/bindings/matlab/mat/MultiIndex.m @@ -22,8 +22,10 @@ methods function this = MultiIndex(varargin) if(nargin==2) - if(varargin{2}=='id') - this.id_ = varargin{1}; + if(isstring(varargin{2})) + if(varargin{2}=="id") + this.id_ = varargin{1}; + end else this.id_ = MParT_('MultiIndex_newDefault', varargin{1},varargin{2}); end diff --git a/bindings/matlab/mat/TriangularMap.m b/bindings/matlab/mat/TriangularMap.m index c086cc0a..21f0c448 100644 --- a/bindings/matlab/mat/TriangularMap.m +++ b/bindings/matlab/mat/TriangularMap.m @@ -1,87 +1,3 @@ -classdef TriangularMap < handle -%DATABASE Example usage of the mexplus development kit. -% -% This class definition gives an interface to the underlying MEX functions -% built in the private directory. It is a good practice to wrap MEX functions -% with Matlab script so that the API is well documented and separated from -% its C++ implementation. Also such a wrapper is a good place to validate -% input arguments. -% -% Build -% ----- -% -% make -% -% See `make.m` for details. -% - -properties (Access = private) - id_ -end - -methods - function this = TriangularMap(list_id) - %DATABASE Create a new database. - this.id_ = MParT_('newTriMap',list_id); - end - - function delete(this) - %DELETE Destructor. - MParT_('deleteMap', this.id_); - end - - function SetCoeffs(this,coeffs) - MParT_('SetCoeffs',this.id_,coeffs(:)); - end - - function result = Coeffs(this) - result = MParT_('Coeffs',this.id_); - end - - function result = numCoeffs(this) - result = MParT_('numCoeffs',this.id_); - end - - %function result = Evaluate(this,pts) - % result = zeros(1,size(pts,2)); - % MParT_('Evaluate',this.id_,pts,result); - %end - - function result = Evaluate(this,pts) - result = MParT_('Evaluate',this.id_,pts); - end - - function result = LogDeterminant(this,pts) - result = MParT_('LogDeterminant',this.id_,pts); - end - - function result = Inverse(this,x1,r) - result = MParT_('Inverse',this.id_,x1,r); - end - - function result = CoeffGrad(this,pts,sens) - result = MParT_('CoeffGrad',this.id_,pts,sens); - end - - function result = LogDeterminantCoeffGrad(this,pts) - result = MParT_('LogDeterminantCoeffGrad',this.id_,pts); - end - - function result = get_id(this) - result = this.id_; - end -end - -methods (Static) - function environment = getEnvironment() - %GETENVIRONMENT Get environment info. - environment = mexmpart('getEnvironment'); - end - - function setEnvironment(environment) - %SETENVIRONMENT Set environment info. - mexmpart('setEnvironment', environment); - end -end - -end +function map = TriangularMap(listCondMaps) + map = ConditionalMap(listCondMaps); +end \ No newline at end of file diff --git a/bindings/matlab/src/ConditionalMap_mex.cpp b/bindings/matlab/src/ConditionalMap_mex.cpp index f464bddf..2ad4cea5 100644 --- a/bindings/matlab/src/ConditionalMap_mex.cpp +++ b/bindings/matlab/src/ConditionalMap_mex.cpp @@ -107,6 +107,16 @@ MEX_DEFINE(ConditionalMap_SetCoeffs) (int nlhs, mxArray* plhs[], condMap.map_ptr->SetCoeffs(coeffs); } +MEX_DEFINE(ConditionalMap_GetComponent) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + TriangularMap *triMap = Session>::get(input.get(0)); + unsigned int i = input.get(1); + output.set(0, Session>>::create(new std::shared_ptr>(triMap->GetComponent(i)))); +} + MEX_DEFINE(ConditionalMap_Coeffs) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 1); diff --git a/bindings/matlab/src/ParameterizedFunctionBase_mex.cpp b/bindings/matlab/src/ParameterizedFunctionBase_mex.cpp new file mode 100644 index 00000000..aa3ef5de --- /dev/null +++ b/bindings/matlab/src/ParameterizedFunctionBase_mex.cpp @@ -0,0 +1,40 @@ +#include +#include "MParT/MultiIndices/MultiIndexSet.h" +#include "MParT/MultiIndices/FixedMultiIndexSet.h" +#include "MParT/ParameterizedFunctionBase.h" + +#include "MParT/Utilities/ArrayConversions.h" +#include "MexArrayConversions.h" +#include "mexplus_eigen.h" +#include + + +using namespace mpart; +using namespace mexplus; +using MemorySpace = Kokkos::HostSpace; + + +// Instance manager for ParameterizedFunctionBase +template class mexplus::Session>; + + +namespace { +MEX_DEFINE(ParameterizedFunctionBase_new) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 3); + OutputArguments output(nlhs, plhs, 1); + unsigned int inDim = input.get(0); + unsigned int outDim = input.get(1); + unsigned int nCoeffs = input.get(2); + output.set(0, Session>::create(new ParameterizedFunctionBase(inDim,outDim,nCoeffs))); +} + +// Defines MEX API for delete. +MEX_DEFINE(ParameterizedFunctionBase_delete) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 0); + Session>::destroy(input.get(0)); +} + +} // namespace \ No newline at end of file From 9ae3f865f3c1a4ab3004af036e4aa01d724dc791 Mon Sep 17 00:00:00 2001 From: rubiop Date: Tue, 19 Jul 2022 12:51:54 -0400 Subject: [PATCH 15/20] GetComponent worked --- bindings/matlab/mat/ConditionalMap.m | 2 +- bindings/matlab/src/ConditionalMap_mex.cpp | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/bindings/matlab/mat/ConditionalMap.m b/bindings/matlab/mat/ConditionalMap.m index 11dc5caa..ba011223 100644 --- a/bindings/matlab/mat/ConditionalMap.m +++ b/bindings/matlab/mat/ConditionalMap.m @@ -70,7 +70,7 @@ function delete(this) function condMap = GetComponent(this,i) condMap_id = MParT_('ConditionalMap_GetComponent',this.id_,i-1); - condMap = ConditionalMap(condMap_id,"id") + condMap = ConditionalMap(condMap_id,"id"); end function SetCoeffs(this,coeffs) diff --git a/bindings/matlab/src/ConditionalMap_mex.cpp b/bindings/matlab/src/ConditionalMap_mex.cpp index 2ad4cea5..70c80031 100644 --- a/bindings/matlab/src/ConditionalMap_mex.cpp +++ b/bindings/matlab/src/ConditionalMap_mex.cpp @@ -26,6 +26,10 @@ class ConditionalMapMex { // The class map_ptr = MapFactory::CreateComponent(mset,opts); } + ConditionalMapMex(std::shared_ptr> init_ptr){ + map_ptr = init_ptr; + } + ConditionalMapMex(std::vector>> blocks){ map_ptr = std::make_shared>(blocks); } @@ -112,9 +116,15 @@ MEX_DEFINE(ConditionalMap_GetComponent) (int nlhs, mxArray* plhs[], InputArguments input(nrhs, prhs, 2); OutputArguments output(nlhs, plhs, 1); - TriangularMap *triMap = Session>::get(input.get(0)); unsigned int i = input.get(1); - output.set(0, Session>>::create(new std::shared_ptr>(triMap->GetComponent(i)))); + ConditionalMapMex *condMap = Session::get(input.get(0)); + std::shared_ptr> condMap_ptr = condMap->map_ptr; + std::shared_ptr> tri_ptr = std::dynamic_pointer_cast>(condMap_ptr); + if(tri_ptr==nullptr){ + throw std::runtime_error("Tried to access GetComponent with a type other than TriangularMap"); + }else{ + output.set(0, Session::create(new ConditionalMapMex(tri_ptr->GetComponent(i)))); + } } MEX_DEFINE(ConditionalMap_Coeffs) (int nlhs, mxArray* plhs[], From ff50bf57191466748099cf8261099594e75790c3 Mon Sep 17 00:00:00 2001 From: rubiop Date: Tue, 19 Jul 2022 14:00:00 -0400 Subject: [PATCH 16/20] moved test in one folder --- bindings/matlab/CMakeLists.txt | 2 +- bindings/matlab/src/ConditionalMap_mex.cpp | 1 + bindings/matlab/src/ParameterizedFunctionBase_mex.cpp | 3 --- bindings/matlab/{mat => tests}/TestExp.m | 0 4 files changed, 2 insertions(+), 4 deletions(-) rename bindings/matlab/{mat => tests}/TestExp.m (100%) diff --git a/bindings/matlab/CMakeLists.txt b/bindings/matlab/CMakeLists.txt index 81137337..c3a848f3 100644 --- a/bindings/matlab/CMakeLists.txt +++ b/bindings/matlab/CMakeLists.txt @@ -10,7 +10,7 @@ set(MEX_SOURCE src/MultiIndexSet_mex.cpp src/MultiIndex_mex.cpp src/FixedMultiIndexSet_mex.cpp - #src/ParameterizedFunctionBase_mex.cpp + src/ParameterizedFunctionBase_mex.cpp src/MexArrayConversions.cpp src/MexMapOptionsConversions.cpp ) diff --git a/bindings/matlab/src/ConditionalMap_mex.cpp b/bindings/matlab/src/ConditionalMap_mex.cpp index 70c80031..9060d0c1 100644 --- a/bindings/matlab/src/ConditionalMap_mex.cpp +++ b/bindings/matlab/src/ConditionalMap_mex.cpp @@ -33,6 +33,7 @@ class ConditionalMapMex { // The class ConditionalMapMex(std::vector>> blocks){ map_ptr = std::make_shared>(blocks); } + ConditionalMapMex(unsigned int inputDim, unsigned int outputDim, unsigned int totalOrder, MapOptions opts){ map_ptr = MapFactory::CreateTriangular(inputDim,outputDim,totalOrder,opts); } diff --git a/bindings/matlab/src/ParameterizedFunctionBase_mex.cpp b/bindings/matlab/src/ParameterizedFunctionBase_mex.cpp index aa3ef5de..95a230a7 100644 --- a/bindings/matlab/src/ParameterizedFunctionBase_mex.cpp +++ b/bindings/matlab/src/ParameterizedFunctionBase_mex.cpp @@ -1,8 +1,5 @@ #include -#include "MParT/MultiIndices/MultiIndexSet.h" -#include "MParT/MultiIndices/FixedMultiIndexSet.h" #include "MParT/ParameterizedFunctionBase.h" - #include "MParT/Utilities/ArrayConversions.h" #include "MexArrayConversions.h" #include "mexplus_eigen.h" diff --git a/bindings/matlab/mat/TestExp.m b/bindings/matlab/tests/TestExp.m similarity index 100% rename from bindings/matlab/mat/TestExp.m rename to bindings/matlab/tests/TestExp.m From b607d46a6e5ce9182dc12e3dbaf2b5db5d757e59 Mon Sep 17 00:00:00 2001 From: Matthew Parno Date: Tue, 19 Jul 2022 14:15:11 -0400 Subject: [PATCH 17/20] Added factory method for MultivariateExpansion. --- MParT/MapFactory.h | 12 ++++++++++++ src/MapFactory.cpp | 38 ++++++++++++++++++++++++++++++++++++++ tests/Test_MapFactory.cpp | 25 +++++++++++++++++++++++++ 3 files changed, 75 insertions(+) diff --git a/MParT/MapFactory.h b/MParT/MapFactory.h index e50c8e44..e579e442 100644 --- a/MParT/MapFactory.h +++ b/MParT/MapFactory.h @@ -37,6 +37,18 @@ namespace mpart{ unsigned int outputDim, unsigned int totalOrder, MapOptions options = MapOptions()); + + /** + @brief Constructs a (generally) non-monotone multivariate expansion. + @param outputDim The output dimension of the expansion. Each output will be defined by the same multiindex set but will have different coefficients. + @param mset The multiindex set specifying which terms should be used in the multivariate expansion. + @param options Options specifying the 1d basis functions used in the parameterization. + */ + template + std::shared_ptr> CreateExpansion(unsigned int outputDim, + FixedMultiIndexSet const& mset, + MapOptions options = MapOptions()); + } } diff --git a/src/MapFactory.cpp b/src/MapFactory.cpp index 3c6b3a30..d2e1fb5a 100644 --- a/src/MapFactory.cpp +++ b/src/MapFactory.cpp @@ -95,9 +95,47 @@ std::shared_ptr> mpart::MapFactory::CreateTriang } +template +std::shared_ptr> mpart::MapFactory::CreateExpansion(unsigned int outputDim, + FixedMultiIndexSet const& mset, + MapOptions opts) +{ + std::shared_ptr> output; + + if(opts.basisType==BasisTypes::ProbabilistHermite){ + + ProbabilistHermite basis1d; + output = std::make_shared>(outputDim, mset, basis1d); + + }else if(opts.basisType==BasisTypes::PhysicistHermite){ + + PhysicistHermite basis1d; + output = std::make_shared>(outputDim, mset, basis1d); + + }else if(opts.basisType==BasisTypes::HermiteFunctions){ + + HermiteFunction basis1d; + output = std::make_shared>(outputDim, mset, basis1d); + } + + if(output){ + output->SetCoeffs(Kokkos::View("Component Coefficients", output->numCoeffs)); + return output; + } + + std::stringstream msg; + msg << "Could not parse options in CreateExpansion. Unknown 1d basis type."; + throw std::runtime_error(msg.str()); + + return nullptr; +} + + template std::shared_ptr> mpart::MapFactory::CreateComponent(FixedMultiIndexSet const&, MapOptions); +template std::shared_ptr> mpart::MapFactory::CreateExpansion(unsigned int, FixedMultiIndexSet const&, MapOptions); template std::shared_ptr> mpart::MapFactory::CreateTriangular(unsigned int, unsigned int, unsigned int, MapOptions); #if defined(KOKKOS_ENABLE_CUDA ) || defined(KOKKOS_ENABLE_SYCL) template std::shared_ptr> mpart::MapFactory::CreateComponent(FixedMultiIndexSet const&, MapOptions); + template std::shared_ptr> mpart::MapFactory::CreateExpansion(unsigned int, FixedMultiIndexSet const&, MapOptions); template std::shared_ptr> mpart::MapFactory::CreateTriangular(unsigned int, unsigned int, unsigned int, MapOptions); #endif \ No newline at end of file diff --git a/tests/Test_MapFactory.cpp b/tests/Test_MapFactory.cpp index 7450318d..14b6e079 100644 --- a/tests/Test_MapFactory.cpp +++ b/tests/Test_MapFactory.cpp @@ -31,6 +31,31 @@ TEST_CASE( "Testing map component factory", "[MapFactoryComponent]" ) { Kokkos::View eval = map->Evaluate(pts); } + +TEST_CASE( "Testing multivariate expansion factory", "[MapFactoryExpansion]" ) { + + MapOptions options; + options.basisType = BasisTypes::ProbabilistHermite; + + unsigned int outDim = 5; + unsigned int inDim = 3; + unsigned int maxDegree = 5; + FixedMultiIndexSet mset(inDim,maxDegree); + + std::shared_ptr> func = MapFactory::CreateExpansion(outDim, mset, options); + REQUIRE(func!=nullptr); + + unsigned int numPts = 100; + Kokkos::View pts("Points", inDim, numPts); + for(unsigned int i=0; i eval = func->Evaluate(pts); + CHECK(eval.extent(0)==outDim); + CHECK(eval.extent(1)==numPts); +} + + TEST_CASE( "Testing factory method for triangular map", "[MapFactoryTriangular]" ) { MapOptions options; From 1a06ca7334d8488858ea6da735150641fa152389 Mon Sep 17 00:00:00 2001 From: rubiop Date: Tue, 19 Jul 2022 18:00:26 -0400 Subject: [PATCH 18/20] ParameterizedFunction and finished ConditionalMap --- bindings/matlab/external/mexplus_eigen.h | 17 +++ bindings/matlab/mat/ConditionalMap.m | 9 ++ bindings/matlab/mat/ParameterizedFunction.m | 93 +++++++++++++ bindings/matlab/src/ConditionalMap_mex.cpp | 34 +++++ .../src/ParameterizedFunctionBase_mex.cpp | 126 ++++++++++++++++-- 5 files changed, 268 insertions(+), 11 deletions(-) create mode 100644 bindings/matlab/mat/ParameterizedFunction.m diff --git a/bindings/matlab/external/mexplus_eigen.h b/bindings/matlab/external/mexplus_eigen.h index d93c8b99..84667ebe 100644 --- a/bindings/matlab/external/mexplus_eigen.h +++ b/bindings/matlab/external/mexplus_eigen.h @@ -39,6 +39,23 @@ inline mxArray* MxArray::from(const Eigen::MatrixXd& eigen_matrix) return out_array.release(); }; +/** + * @brief Define a template specialisation for Eigen::VectorXd + */ +template <> +inline mxArray* MxArray::from(const Eigen::VectorXd& eigen_vector) +{ + const int num_elems = static_cast(eigen_vector.size()); + //Choose (1,n) or (n,1) ? + MxArray out_array(MxArray::Numeric(1,num_elems)); + + for (int c = 0; c < num_elems; ++c) + { + out_array.set(1, c, eigen_vector(c)); + } + return out_array.release(); +}; + /** diff --git a/bindings/matlab/mat/ConditionalMap.m b/bindings/matlab/mat/ConditionalMap.m index ba011223..948fec09 100644 --- a/bindings/matlab/mat/ConditionalMap.m +++ b/bindings/matlab/mat/ConditionalMap.m @@ -73,6 +73,11 @@ function delete(this) condMap = ConditionalMap(condMap_id,"id"); end + function parFunc = GetBaseFunction(this) + parFunc_id=MParT_('ConditionalMap_GetBaseFunction',this.id_); + parFunc = ParameterizedFunction(parFunc_id,"id"); + end + function SetCoeffs(this,coeffs) MParT_('ConditionalMap_SetCoeffs',this.id_,coeffs(:)); end @@ -81,6 +86,10 @@ function SetCoeffs(this,coeffs) result = MParT_('ConditionalMap_Coeffs',this.id_); end + function result = CoeffMap(this) + result = MParT_('ConditionalMap_CoeffMap',this.id_); + end + function result = numCoeffs(this) result = MParT_('ConditionalMap_numCoeffs',this.id_); end diff --git a/bindings/matlab/mat/ParameterizedFunction.m b/bindings/matlab/mat/ParameterizedFunction.m new file mode 100644 index 00000000..8ca53d88 --- /dev/null +++ b/bindings/matlab/mat/ParameterizedFunction.m @@ -0,0 +1,93 @@ +classdef ParameterizedFunction < handle +%DATABASE Example usage of the mexplus development kit. +% +% This class definition gives an interface to the underlying MEX functions +% built in the private directory. It is a good practice to wrap MEX functions +% with Matlab script so that the API is well documented and separated from +% its C++ implementation. Also such a wrapper is a good place to validate +% input arguments. +% +% Build +% ----- +% +% make +% +% See `make.m` for details. +% + +properties (Access = private) + id_ +end + +methods + + function this = ParameterizedFunction(varargin) + if(nargin==2) + if(isstring(varargin{2})) + if(varargin{2}=="id") + this.id_=varargin{1}; + end + end + elseif(nargin==3) + outputDim = varargin{1}; + mset = varargin{2}; + mapOptions = varargin{3}; + mexOptions = mapOptions.getMexOptions; + input_str=['MParT_(',char(39),'ParameterizedFunction_newMap',char(39),',outputDim,mset.get_id()']; + for o=1:length(mexOptions) + input_o=[',mexOptions{',num2str(o),'}']; + input_str=[input_str,input_o]; + end + input_str=[input_str,')']; + this.id_ = eval(input_str); + else + error('Invalid number of inputs') + end + end + + function delete(this) + %DELETE Destructor. + MParT_('ParameterizedFunction_delete', this.id_); + end + + function SetCoeffs(this,coeffs) + MParT_('ParameterizedFunction_SetCoeffs',this.id_,coeffs(:)); + end + + function result = Coeffs(this) + result = MParT_('ParameterizedFunction_Coeffs',this.id_); + end + + function result = CoeffMap(this) + result = MParT_('ParameterizedFunction_CoeffMap',this.id_); + end + + function result = numCoeffs(this) + result = MParT_('ParameterizedFunction_numCoeffs',this.id_); + end + + function result = Evaluate(this,pts) + result = zeros(this.outputDim, size(pts,2)); + MParT_('ParameterizedFunction_Evaluate',this.id_,pts,result); + end + + function result = CoeffGrad(this,pts,sens) + result = zeros(this.numCoeffs, size(pts,2)); + MParT_('ParameterizedFunction_CoeffGrad',this.id_,pts,sens,result); + end + + function result = get_id(this) + result = this.id_; + end + + function result = outputDim(this) + result = MParT_('ParameterizedFunction_outputDim',this.id_); + end + + function result = inputDim(this) + result = MParT_('ParameterizedFunction_inputDim',this.id_); + end + +end + +end diff --git a/bindings/matlab/src/ConditionalMap_mex.cpp b/bindings/matlab/src/ConditionalMap_mex.cpp index 9060d0c1..117ca5b8 100644 --- a/bindings/matlab/src/ConditionalMap_mex.cpp +++ b/bindings/matlab/src/ConditionalMap_mex.cpp @@ -39,8 +39,23 @@ class ConditionalMapMex { // The class } }; //end class +class ParameterizedFunctionMex { // The class +public: + std::shared_ptr> fun_ptr; + + ParameterizedFunctionMex(unsigned int outputDim, FixedMultiIndexSet const& mset, + MapOptions opts){ + fun_ptr = MapFactory::CreateExpansion(outputDim,mset,opts); + } + + ParameterizedFunctionMex(std::shared_ptr> init_ptr){ + fun_ptr = init_ptr; + } +}; //end class + // Instance manager for ConditionalMap. template class mexplus::Session; +template class mexplus::Session; namespace { @@ -128,6 +143,16 @@ MEX_DEFINE(ConditionalMap_GetComponent) (int nlhs, mxArray* plhs[], } } +MEX_DEFINE(ConditionalMap_GetBaseFunction) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + ConditionalMapMex *condMap = Session::get(input.get(0)); + std::shared_ptr> func_ptr = condMap->map_ptr->GetBaseFunction(); + output.set(0, Session::create(new ParameterizedFunctionMex(func_ptr))); +} + MEX_DEFINE(ConditionalMap_Coeffs) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 1); @@ -137,6 +162,15 @@ MEX_DEFINE(ConditionalMap_Coeffs) (int nlhs, mxArray* plhs[], output.set(0,coeffs); } +MEX_DEFINE(ConditionalMap_CoeffMap) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const ConditionalMapMex& condMap = Session::getConst(input.get(0)); + auto coeffs = condMap.map_ptr->CoeffMap(); + output.set(0,coeffs); +} + MEX_DEFINE(ConditionalMap_numCoeffs) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 1); diff --git a/bindings/matlab/src/ParameterizedFunctionBase_mex.cpp b/bindings/matlab/src/ParameterizedFunctionBase_mex.cpp index 95a230a7..75919dd0 100644 --- a/bindings/matlab/src/ParameterizedFunctionBase_mex.cpp +++ b/bindings/matlab/src/ParameterizedFunctionBase_mex.cpp @@ -1,8 +1,12 @@ #include -#include "MParT/ParameterizedFunctionBase.h" +#include "MParT/MultiIndices/MultiIndexSet.h" #include "MParT/Utilities/ArrayConversions.h" #include "MexArrayConversions.h" -#include "mexplus_eigen.h" +#include "MexMapOptionsConversions.h" +#include "MParT/MapOptions.h" +#include "MParT/MapFactory.h" +#include "MParT/ConditionalMapBase.h" +#include "MParT/TriangularMap.h" #include @@ -10,28 +14,128 @@ using namespace mpart; using namespace mexplus; using MemorySpace = Kokkos::HostSpace; +class ParameterizedFunctionMex { // The class +public: + std::shared_ptr> fun_ptr; + + ParameterizedFunctionMex(unsigned int outputDim, FixedMultiIndexSet const& mset, + MapOptions opts){ + fun_ptr = MapFactory::CreateExpansion(outputDim,mset,opts); + } + + ParameterizedFunctionMex(std::shared_ptr> init_ptr){ + fun_ptr = init_ptr; + } +}; //end class + // Instance manager for ParameterizedFunctionBase template class mexplus::Session>; namespace { -MEX_DEFINE(ParameterizedFunctionBase_new) (int nlhs, mxArray* plhs[], - int nrhs, const mxArray* prhs[]) { - InputArguments input(nrhs, prhs, 3); +MEX_DEFINE(ParameterizedFunction_newMap) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + + InputArguments input(nrhs, prhs, 11); OutputArguments output(nlhs, plhs, 1); - unsigned int inDim = input.get(0); - unsigned int outDim = input.get(1); - unsigned int nCoeffs = input.get(2); - output.set(0, Session>::create(new ParameterizedFunctionBase(inDim,outDim,nCoeffs))); + unsigned int outputDim = input.get(0); + const MultiIndexSet& mset = Session::getConst(input.get(1)); + MapOptions opts = MapOptionsFromMatlab(input.get(2),input.get(3), + input.get(4),input.get(5), + input.get(6),input.get(7), + input.get(8),input.get(9), + input.get(10)); + + output.set(0, Session::create(new ParameterizedFunctionMex(outputDim,mset.Fix(),opts))); } // Defines MEX API for delete. -MEX_DEFINE(ParameterizedFunctionBase_delete) (int nlhs, mxArray* plhs[], +MEX_DEFINE(ParameterizedFunction_delete) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 1); OutputArguments output(nlhs, plhs, 0); - Session>::destroy(input.get(0)); + Session::destroy(input.get(0)); +} + +MEX_DEFINE(ParameterizedFunction_SetCoeffs) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 0); + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + auto coeffs = MexToKokkos1d(prhs[1]); + parFunc.fun_ptr->SetCoeffs(coeffs); +} + +MEX_DEFINE(ParameterizedFunction_Coeffs) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + auto coeffs = KokkosToVec(parFunc.fun_ptr->Coeffs()); + output.set(0,coeffs); +} + +MEX_DEFINE(ParameterizedFunction_CoeffMap) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + auto coeffs = parFunc.fun_ptr->CoeffMap(); + output.set(0,coeffs); +} + +MEX_DEFINE(ParameterizedFunction_numCoeffs) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + auto numcoeffs = parFunc.fun_ptr->numCoeffs; + output.set(0,numcoeffs); +} + +MEX_DEFINE(ParameterizedFunction_outputDim) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + unsigned int outDim = parFunc.fun_ptr->outputDim; + output.set(0, outDim); +} + +MEX_DEFINE(ParameterizedFunction_inputDim) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + unsigned int inDim = parFunc.fun_ptr->inputDim; + output.set(0, inDim); +} + +MEX_DEFINE(ParameterizedFunction_Evaluate) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 3); + OutputArguments output(nlhs, plhs, 0); + + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + StridedMatrix pts = MexToKokkos2d(prhs[1]); + StridedMatrix out = MexToKokkos2d(prhs[2]); + parFunc.fun_ptr->EvaluateImpl(pts, out); +} + +MEX_DEFINE(ParameterizedFunction_CoeffGrad) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + + InputArguments input(nrhs, prhs, 4); + OutputArguments output(nlhs, plhs, 0); + + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + + auto pts = MexToKokkos2d(prhs[1]); + auto sens = MexToKokkos2d(prhs[2]); + auto out = MexToKokkos2d(prhs[3]); + + parFunc.fun_ptr->CoeffGradImpl(pts,sens,out); } } // namespace \ No newline at end of file From b489dd109fe9f24b2b9df1edee87530eae54d8dc Mon Sep 17 00:00:00 2001 From: rubiop Date: Tue, 19 Jul 2022 18:04:32 -0400 Subject: [PATCH 19/20] added tests --- bindings/matlab/notes_binding.md | 23 ---- bindings/matlab/tests/TestExp.m | 51 ------- bindings/matlab/tests/Test_MultiIndex.m | 31 +++++ bindings/matlab/tests/Test_MultiIndexSet.m | 129 ++++++++++++++++++ .../matlab/tests/Test_ParameterizedFunction.m | 20 +++ bindings/matlab/tests/Test_TriangularMap.m | 38 ++++++ 6 files changed, 218 insertions(+), 74 deletions(-) delete mode 100644 bindings/matlab/notes_binding.md delete mode 100644 bindings/matlab/tests/TestExp.m create mode 100644 bindings/matlab/tests/Test_MultiIndex.m create mode 100644 bindings/matlab/tests/Test_MultiIndexSet.m create mode 100644 bindings/matlab/tests/Test_ParameterizedFunction.m create mode 100644 bindings/matlab/tests/Test_TriangularMap.m diff --git a/bindings/matlab/notes_binding.md b/bindings/matlab/notes_binding.md deleted file mode 100644 index 2e157a55..00000000 --- a/bindings/matlab/notes_binding.md +++ /dev/null @@ -1,23 +0,0 @@ -# Filling _mex.cpp files - -# Filling .m files - -# getConst -``` -const MultiIndexSet& mset = Session::getConst(input.get(0)) -output.set(0, mset.Frontier()) -``` - - -# get -``` -MultiIndexSet *mset = Session::get(input.get(0)) -output.set(0, mset->ForciblyExpand(activeIndex)) -``` - -# MapOption - - -# id_ - -# How to pass MParT object as argument of other object methods diff --git a/bindings/matlab/tests/TestExp.m b/bindings/matlab/tests/TestExp.m deleted file mode 100644 index 44423d16..00000000 --- a/bindings/matlab/tests/TestExp.m +++ /dev/null @@ -1,51 +0,0 @@ -classdef TestExp < handle -%DATABASE Example usage of the mexplus development kit. -% -% This class definition gives an interface to the underlying MEX functions -% built in the private directory. It is a good practice to wrap MEX functions -% with Matlab script so that the API is well documented and separated from -% its C++ implementation. Also such a wrapper is a good place to validate -% input arguments. -% -% Build -% ----- -% -% make -% -% See `make.m` for details. -% - -properties (Access = private) - id_ % ID of the session. -end - -methods - function this = TestExp() - %DATABASE Create a new database. - this.id_ = MParT_('new'); - end - - function delete(this) - %DELETE Destructor. - MParT_('delete', this.id_); - end - - function result = Evaluate1d(this, x) - result = MParT_('Evaluate1d',this.id_, x); - end - -end - -methods (Static) - function environment = getEnvironment() - %GETENVIRONMENT Get environment info. - environment = mexmpart('getEnvironment'); - end - - function setEnvironment(environment) - %SETENVIRONMENT Set environment info. - mexmpart('setEnvironment', environment); - end -end - -end diff --git a/bindings/matlab/tests/Test_MultiIndex.m b/bindings/matlab/tests/Test_MultiIndex.m new file mode 100644 index 00000000..6580fa3d --- /dev/null +++ b/bindings/matlab/tests/Test_MultiIndex.m @@ -0,0 +1,31 @@ +clear; +addpath(genpath('~/Installations/MParT/matlab/')) + +addpath(genpath('.')); + +KokkosInitialize(8); + +a=[2,3,4]; +multi=MultiIndex(a); +multi.String() + +multi2=MultiIndex(6,1); +multi2.String() + +%multi3=MultiIndex(6); %Not sure we want to keep this +multi3.String() + +multi3.Set(1,9); +multi3.Vector(); + +disp(multi2==multi3); +disp(multi2~=multi2); + +disp(multi2>multi3) + +disp(multi2>=multi3) +disp(multi2<=multi2) +disp(multi3 Date: Fri, 22 Jul 2022 14:48:48 -0400 Subject: [PATCH 20/20] Minor fixes and implemenation of Expand() to multiindexset. --- MParT/MultiIndices/MultiIndex.h | 4 ++-- MParT/MultiIndices/MultiIndexSet.h | 4 ++-- bindings/matlab/mat/MultiIndexSet.m | 14 +++++++------- bindings/matlab/src/BasisTypes_mex.cpp | 2 +- bindings/matlab/src/MultiIndexSet_mex.cpp | 16 +++++++++------- bindings/matlab/tests/Test_MultiIndex.m | 4 +--- bindings/matlab/tests/Test_MultiIndexSet.m | 2 -- .../matlab/tests/Test_ParameterizedFunction.m | 2 -- bindings/matlab/tests/Test_TriangularMap.m | 1 - src/MultiIndices/MultiIndexSet.cpp | 12 ++++++++++++ 10 files changed, 34 insertions(+), 27 deletions(-) diff --git a/MParT/MultiIndices/MultiIndex.h b/MParT/MultiIndices/MultiIndex.h index 3ae272c2..9790a011 100644 --- a/MParT/MultiIndices/MultiIndex.h +++ b/MParT/MultiIndices/MultiIndex.h @@ -177,10 +177,10 @@ class MultiIndex { std::vector nzVals; /// The maximum index over all nzInds pairs. - unsigned maxValue; + unsigned int maxValue; // the total order of the multiindex (i.e. the sum of the indices) - unsigned totalOrder; + unsigned int totalOrder; }; // class MultiIndex diff --git a/MParT/MultiIndices/MultiIndexSet.h b/MParT/MultiIndices/MultiIndexSet.h index 563caf7a..f6694ef9 100644 --- a/MParT/MultiIndices/MultiIndexSet.h +++ b/MParT/MultiIndices/MultiIndexSet.h @@ -111,7 +111,7 @@ MultiIndexSet set(length, limiter); @param[in] activeIndex Linear index of interest. @return A constant reference to the MultiIndex. */ - MultiIndex const& IndexToMulti(unsigned int activeIndex) const{return allMultis.at(active2global.at(activeIndex));}; + MultiIndex IndexToMulti(unsigned int activeIndex) const{return allMultis.at(active2global.at(activeIndex));}; /** Given a multiindex, return the linear index where it is located. @param[in] input An instance of the MultiIndex class. @@ -138,7 +138,7 @@ MultiIndexSet set(length, limiter); @param[in] activeIndex The index of the active MultiIndex to return. @return A pointer to the MultiIndex at index outputIndex. */ - MultiIndex const& at(int activeIndex) const{return IndexToMulti(activeIndex);} + MultiIndex at(int activeIndex) const{return IndexToMulti(activeIndex);} /** * This function provides access to each of the MultiIndices without any bounds checking on the vector. diff --git a/bindings/matlab/mat/MultiIndexSet.m b/bindings/matlab/mat/MultiIndexSet.m index b5ab4c62..cbee5eef 100644 --- a/bindings/matlab/mat/MultiIndexSet.m +++ b/bindings/matlab/mat/MultiIndexSet.m @@ -40,8 +40,8 @@ function delete(this) end function multi = IndexToMulti(this,activeIndex) - multi_id = MParT_('MultiIndexSet_IndexToMulti',this.id_,activeIndex-1); - multi = MultiIndex(multi_id,'id'); + multi_id = MParT_('MultiIndexSet_IndexToMulti', this.id_, activeIndex-1); + multi = MultiIndex(multi_id,"id"); end function result = MultiToIndex(this,multi) @@ -59,7 +59,7 @@ function delete(this) function multi = at(this,activeIndex) multi_id = MParT_('MultiIndexSet_at', this.id_,activeIndex-1); - multi = MultiIndex(multi_id,'id'); + multi = MultiIndex(multi_id,"id"); end function varargout = subsref(this,s) %seems dangerous @@ -73,7 +73,7 @@ function delete(this) case '{}' activeIndex = s(1).subs{1}; multi_id = MParT_('MultiIndexSet_subsref',this.id_,activeIndex-1); - [varargout{1:nargout}] = MultiIndex(multi_id,'id'); + [varargout{1:nargout}] = MultiIndex(multi_id,"id"); otherwise error('Indexing expression invalid.') end @@ -130,7 +130,7 @@ function Activate(this,multi) multi_ids = MParT_('MultiIndexSet_AdmissibleFowardNeighbors',this.id_,activeInd-1); listMultis = []; for i = 1:length(multi_ids) - listMultis=[listMultis,MultiIndex(multi_ids(i),'id')]; + listMultis=[listMultis,MultiIndex(multi_ids(i),"id")]; end end @@ -143,7 +143,7 @@ function Activate(this,multi) multi_ids = MParT_('MultiIndexSet_Margin',this.id_); listMultis = []; for i = 1:length(multi_ids) - listMultis=[listMultis,MultiIndex(multi_ids(i),'id')]; + listMultis=[listMultis,MultiIndex(multi_ids(i),"id")]; end end @@ -152,7 +152,7 @@ function Activate(this,multi) multi_ids = MParT_('MultiIndexSet_ReducedMargin',this.id_); listMultis = []; for i = 1:length(multi_ids) - listMultis=[listMultis,MultiIndex(multi_ids(i),'id')]; + listMultis=[listMultis,MultiIndex(multi_ids(i),"id")]; end end diff --git a/bindings/matlab/src/BasisTypes_mex.cpp b/bindings/matlab/src/BasisTypes_mex.cpp index affa5cff..fff1d2d9 100644 --- a/bindings/matlab/src/BasisTypes_mex.cpp +++ b/bindings/matlab/src/BasisTypes_mex.cpp @@ -163,4 +163,4 @@ MEX_DEFINE(deleteBasisTypesOpt) (int nlhs, mxArray* plhs[], } // namespace -MEX_DISPATCH // Don't forget to add this if MEX_DEFINE() is used. \ No newline at end of file +//MEX_DISPATCH // Don't forget to add this if MEX_DEFINE() is used. \ No newline at end of file diff --git a/bindings/matlab/src/MultiIndexSet_mex.cpp b/bindings/matlab/src/MultiIndexSet_mex.cpp index 01000688..2960adf8 100644 --- a/bindings/matlab/src/MultiIndexSet_mex.cpp +++ b/bindings/matlab/src/MultiIndexSet_mex.cpp @@ -14,6 +14,7 @@ using namespace mexplus; // Instance manager for MultiIndexSet // To do: bind functions using MultiIndex objects +//template class mexplus::Session; template class mexplus::Session; template class mexplus::Session>; @@ -51,6 +52,7 @@ MEX_DEFINE(MultiIndexSet_IndexToMulti) (int nlhs, mxArray* plhs[], OutputArguments output(nlhs, plhs, 1); const MultiIndexSet& mset = Session::getConst(input.get(0)); unsigned int activeIndex = input.get(1); + output.set(0, Session::create(new MultiIndex(mset.IndexToMulti(activeIndex)))); } @@ -159,13 +161,13 @@ MEX_DEFINE(MultiIndexSet_Expand) (int nlhs, mxArray* plhs[], output.set(0, mset->Expand(activeInd)); } -// MEX_DEFINE(MultiIndexSet_ExpandAny) (int nlhs, mxArray* plhs[], -// int nrhs, const mxArray* prhs[]) { -// InputArguments input(nrhs, prhs, 1); -// OutputArguments output(nlhs, plhs, 0); -// MultiIndexSet *mset = Session::get(input.get(0)); -// output.set(0, mset->Expand()); -// } +MEX_DEFINE(MultiIndexSet_ExpandAny) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + output.set(0, mset->Expand()); +} MEX_DEFINE(MultiIndexSet_ForciblyExpand) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { diff --git a/bindings/matlab/tests/Test_MultiIndex.m b/bindings/matlab/tests/Test_MultiIndex.m index 6580fa3d..2de3cb0c 100644 --- a/bindings/matlab/tests/Test_MultiIndex.m +++ b/bindings/matlab/tests/Test_MultiIndex.m @@ -1,6 +1,4 @@ clear; -addpath(genpath('~/Installations/MParT/matlab/')) - addpath(genpath('.')); KokkosInitialize(8); @@ -12,7 +10,7 @@ multi2=MultiIndex(6,1); multi2.String() -%multi3=MultiIndex(6); %Not sure we want to keep this +multi3=MultiIndex([0,1,2,3]); %Not sure we want to keep this multi3.String() multi3.Set(1,9); diff --git a/bindings/matlab/tests/Test_MultiIndexSet.m b/bindings/matlab/tests/Test_MultiIndexSet.m index 44a92079..6def512c 100644 --- a/bindings/matlab/tests/Test_MultiIndexSet.m +++ b/bindings/matlab/tests/Test_MultiIndexSet.m @@ -1,6 +1,4 @@ clear; -addpath(genpath('~/Installations/MParT/matlab/')) - addpath(genpath('.')); KokkosInitialize(8); diff --git a/bindings/matlab/tests/Test_ParameterizedFunction.m b/bindings/matlab/tests/Test_ParameterizedFunction.m index b33b8ae6..df914b91 100644 --- a/bindings/matlab/tests/Test_ParameterizedFunction.m +++ b/bindings/matlab/tests/Test_ParameterizedFunction.m @@ -1,6 +1,4 @@ clear -addpath(genpath('~/Installations/MParT/matlab/')) - addpath(genpath('.')); KokkosInitialize(8); diff --git a/bindings/matlab/tests/Test_TriangularMap.m b/bindings/matlab/tests/Test_TriangularMap.m index 47fd5404..1fd064c7 100644 --- a/bindings/matlab/tests/Test_TriangularMap.m +++ b/bindings/matlab/tests/Test_TriangularMap.m @@ -1,5 +1,4 @@ clear -addpath(genpath('~/Installations/MParT/matlab/')) addpath(genpath('.')); diff --git a/src/MultiIndices/MultiIndexSet.cpp b/src/MultiIndices/MultiIndexSet.cpp index 5cdb5b61..4f6419ce 100644 --- a/src/MultiIndices/MultiIndexSet.cpp +++ b/src/MultiIndices/MultiIndexSet.cpp @@ -645,6 +645,18 @@ std::vector MultiIndexSet::Expand(unsigned int activeIndex) return newIndices; } + +std::vector MultiIndexSet::Expand() +{ + std::vector frontier = Frontier(); + std::vector newInds, allNewInds; + for(auto& ind : frontier){ + newInds = Expand(ind); + allNewInds.insert(allNewInds.end(), newInds.begin(), newInds.end()); + } + return allNewInds; +} + std::vector MultiIndexSet::ForciblyExpand(unsigned int const activeIndex) { assert(activeIndex