diff --git a/MParT/MapFactory.h b/MParT/MapFactory.h index e50c8e44..e579e442 100644 --- a/MParT/MapFactory.h +++ b/MParT/MapFactory.h @@ -37,6 +37,18 @@ namespace mpart{ unsigned int outputDim, unsigned int totalOrder, MapOptions options = MapOptions()); + + /** + @brief Constructs a (generally) non-monotone multivariate expansion. + @param outputDim The output dimension of the expansion. Each output will be defined by the same multiindex set but will have different coefficients. + @param mset The multiindex set specifying which terms should be used in the multivariate expansion. + @param options Options specifying the 1d basis functions used in the parameterization. + */ + template + std::shared_ptr> CreateExpansion(unsigned int outputDim, + FixedMultiIndexSet const& mset, + MapOptions options = MapOptions()); + } } diff --git a/MParT/MultiIndices/MultiIndex.h b/MParT/MultiIndices/MultiIndex.h index 9e731d79..9790a011 100644 --- a/MParT/MultiIndices/MultiIndex.h +++ b/MParT/MultiIndices/MultiIndex.h @@ -85,7 +85,7 @@ class MultiIndex { */ bool Set(unsigned int ind, unsigned int val); - /** Obtain the a particular component of the multiindex. Notice that this function can be slow for multiindices with many nonzero components. The worst case performance requires \f$O(|\mathbf{j}|_0)\f$ integer comparisons, where \f$|\mathbf{j}|_0\f$ denotes the number of nonzero entries in the multiindex. + /** Obtain a particular component of the multiindex. Notice that this function can be slow for multiindices with many nonzero components. The worst case performance requires \f$O(|\mathbf{j}|_0)\f$ integer comparisons, where \f$|\mathbf{j}|_0\f$ denotes the number of nonzero entries in the multiindex. @param[in] ind The component to return. @return The integer stored in component dim of the multiindex. */ @@ -177,10 +177,10 @@ class MultiIndex { std::vector nzVals; /// The maximum index over all nzInds pairs. - unsigned maxValue; + unsigned int maxValue; // the total order of the multiindex (i.e. the sum of the indices) - unsigned totalOrder; + unsigned int totalOrder; }; // class MultiIndex diff --git a/MParT/MultiIndices/MultiIndexSet.h b/MParT/MultiIndices/MultiIndexSet.h index 02c8cf93..f6694ef9 100644 --- a/MParT/MultiIndices/MultiIndexSet.h +++ b/MParT/MultiIndices/MultiIndexSet.h @@ -111,7 +111,7 @@ MultiIndexSet set(length, limiter); @param[in] activeIndex Linear index of interest. @return A constant reference to the MultiIndex. */ - MultiIndex const& IndexToMulti(unsigned int activeIndex) const{return allMultis.at(active2global.at(activeIndex));}; + MultiIndex IndexToMulti(unsigned int activeIndex) const{return allMultis.at(active2global.at(activeIndex));}; /** Given a multiindex, return the linear index where it is located. @param[in] input An instance of the MultiIndex class. @@ -138,7 +138,7 @@ MultiIndexSet set(length, limiter); @param[in] activeIndex The index of the active MultiIndex to return. @return A pointer to the MultiIndex at index outputIndex. */ - MultiIndex const& at(int activeIndex) const{return IndexToMulti(activeIndex);} + MultiIndex at(int activeIndex) const{return IndexToMulti(activeIndex);} /** * This function provides access to each of the MultiIndices without any bounds checking on the vector. @@ -170,7 +170,7 @@ MultiIndexSet set(length, limiter); already in the set and if the input function is unique, it is added to the set. @param[in] rhs The MultiIndex we want to add to the set. - @return A reference to this MultiIndex set, which may now contain the new + @return A reference to this MultiIndex, which may now contain the new MultiIndex in rhs. */ MultiIndexSet& operator+=(MultiIndex const& rhs); diff --git a/README.md b/README.md index ac695d8c..8df9943e 100644 --- a/README.md +++ b/README.md @@ -42,9 +42,7 @@ cd build cmake \ -DCMAKE_INSTALL_PREFIX= \ -DPYTHON_EXECUTABLE=`which python` \ - -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_OSX_ARCHITECTURES=x86_64 \ - -DMatlab_MEX_EXTENSION="mexmaci64" \ -DKokkos_ENABLE_PTHREAD=ON \ -DKokkos_ENABLE_SERIAL=ON \ .. diff --git a/bindings/matlab/CMakeLists.txt b/bindings/matlab/CMakeLists.txt index 05aa22e6..c3a848f3 100644 --- a/bindings/matlab/CMakeLists.txt +++ b/bindings/matlab/CMakeLists.txt @@ -8,7 +8,9 @@ set(MEX_SOURCE src/KokkosUtilities_mex.cpp src/ConditionalMap_mex.cpp src/MultiIndexSet_mex.cpp + src/MultiIndex_mex.cpp src/FixedMultiIndexSet_mex.cpp + src/ParameterizedFunctionBase_mex.cpp src/MexArrayConversions.cpp src/MexMapOptionsConversions.cpp ) diff --git a/bindings/matlab/external/mexplus_eigen.h b/bindings/matlab/external/mexplus_eigen.h index 81049719..84667ebe 100644 --- a/bindings/matlab/external/mexplus_eigen.h +++ b/bindings/matlab/external/mexplus_eigen.h @@ -1,22 +1,8 @@ -/* - * eos - A 3D Morphable Model fitting library written in modern C++11/14. - * - * File: matlab/include/mexplus_eigen.hpp - * - * Copyright 2016-2018 Patrik Huber - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. +/** + * Adapted from: https://github.com/patrikhuber/eos/blob/master/matlab/include/mexplus_eigen.hpp + * @brief Handle custom conversion between mxarray and C++ matrix types */ + #pragma once #ifndef MPART_MEXPLUS_EIGEN_HPP @@ -34,12 +20,7 @@ namespace mexplus { /** - * @brief Define a template specialisation for Eigen::MatrixXd for ... . - * - * The default precision in Matlab is double, but most matrices in eos (for example the PCA basis matrices - * are stored as float values, so this defines conversion from these matrices to Matlab. - * - * Todo: Documentation. + * @brief Define a template specialisation for Eigen::MatrixXd */ template <> inline mxArray* MxArray::from(const Eigen::MatrixXd& eigen_matrix) @@ -48,12 +29,6 @@ inline mxArray* MxArray::from(const Eigen::MatrixXd& eigen_matrix) const int num_cols = static_cast(eigen_matrix.cols()); MxArray out_array(MxArray::Numeric(num_rows, num_cols)); - // This might not copy the data but it's evil and probably really dangerous!!!: - // mxSetData(const_cast(matrix.get()), (void*)value.data()); - - // This copies the data. But I suppose it makes sense that we copy the data when we go - // from C++ to Matlab, since Matlab can unload the C++ mex module at any time I think. - // Loop is column-wise for (int c = 0; c < num_cols; ++c) { for (int r = 0; r < num_rows; ++r) @@ -64,13 +39,28 @@ inline mxArray* MxArray::from(const Eigen::MatrixXd& eigen_matrix) return out_array.release(); }; +/** + * @brief Define a template specialisation for Eigen::VectorXd + */ +template <> +inline mxArray* MxArray::from(const Eigen::VectorXd& eigen_vector) +{ + const int num_elems = static_cast(eigen_vector.size()); + //Choose (1,n) or (n,1) ? + MxArray out_array(MxArray::Numeric(1,num_elems)); + + for (int c = 0; c < num_elems; ++c) + { + out_array.set(1, c, eigen_vector(c)); + } + return out_array.release(); +}; + /** - * @brief Define a template specialisation for Eigen::MatrixXd for ... . + * @brief Define a template specialisation for Eigen::MatrixXd * - * Todo: Documentation. - * TODO: Maybe provide this one as MatrixXf as well as MatrixXd? Matlab's default is double? */ template <> inline void MxArray::to(const mxArray* in_array, Eigen::MatrixXd* eigen_matrix) @@ -102,14 +92,8 @@ inline void MxArray::to(const mxArray* in_array, Eigen::MatrixXd* eigen_matrix) const auto nrows = array.dimensions()[0]; // or use array.rows() const auto ncols = array.dimensions()[1]; - // I think I can just use Eigen::Matrix, not a Map - the Matrix c'tor that we call creates a Map anyway? Eigen::Map> eigen_map( array.getData(), nrows, ncols); - // Not sure that's alright - who owns the data? I think as it is now, everything points to the data in the - // mxArray owned by Matlab, but I'm not 100% sure. - // Actually, doesn't eigen_map go out of scope and get destroyed? This might be trouble? But this - // assignment should (or might) copy, then it's fine? Check if it invokes the copy c'tor. - // 2 May 2018: Yes this copies. *eigen_matrix = eigen_map; }; @@ -182,4 +166,4 @@ inline void MxArray::to(const mxArray* in_array, Kokkos::View operator + function result = gt(this,multi) + result = MParT_('MultiIndex_Gt',this.id_,multi.get_id()); + end + + % >= operator + function result = Ge(this,multi) + result = MParT_('MultiIndex_Ge',this.id_,multi.get_id()); + end + + % <= operator + function result = Le(this,multi) + result = MParT_('MultiIndex_Le',this.id_,multi.get_id()); + end + + function result = get_id(this) + result = this.id_; + end + + +end + +end diff --git a/bindings/matlab/mat/MultiIndexSet.m b/bindings/matlab/mat/MultiIndexSet.m index 58ee8e77..cbee5eef 100644 --- a/bindings/matlab/mat/MultiIndexSet.m +++ b/bindings/matlab/mat/MultiIndexSet.m @@ -39,10 +39,159 @@ function delete(this) MParT_('MultiIndexSet_delete', this.id_); end + function multi = IndexToMulti(this,activeIndex) + multi_id = MParT_('MultiIndexSet_IndexToMulti', this.id_, activeIndex-1); + multi = MultiIndex(multi_id,"id"); + end + + function result = MultiToIndex(this,multi) + result= MParT_('MultiIndexSet_MultiToIndex',this.id_,multi.get_id()); + result = result + 1; + end + + function result = Length(this) + result = MParT_('MultiIndexSet_Length', this.id_); + end + function result = MaxOrders(this) result = MParT_('MultiIndexSet_MaxOrders', this.id_); end + function multi = at(this,activeIndex) + multi_id = MParT_('MultiIndexSet_at', this.id_,activeIndex-1); + multi = MultiIndex(multi_id,"id"); + end + + function varargout = subsref(this,s) %seems dangerous + switch s(1).type + case '.' + % Keep built-in functionality for '.' + [varargout{1:nargout}] = builtin('subsref', this, s); + case '()' + % Keep built-in functionality for '()' + [varargout{1:nargout}] = builtin('subsref', this, s); + case '{}' + activeIndex = s(1).subs{1}; + multi_id = MParT_('MultiIndexSet_subsref',this.id_,activeIndex-1); + [varargout{1:nargout}] = MultiIndex(multi_id,"id"); + otherwise + error('Indexing expression invalid.') + end + end + + function result = Size(this) + result = MParT_('MultiIndexSet_Size', this.id_); + end + + function result = plus(this,toAdd) + if strcmp(class(toAdd),'MultiIndexSet') + MParT_('MultiIndexSet_addMultiIndexSet',this.id_,toAdd.get_id()); + elseif strcmp(class(toAdd),'MultiIndex') + MParT_('MultiIndexSet_addMultiIndex',this.id_,toAdd.get_id()); + else + error('Unrecognized type to add to MultiIndexSet') + end + end + + function result = Union(this,mset) + result = MParT_('MultiIndexSet_Union',this.id_,mset.get_id()); + end + + function Activate(this,multi) + MParT_('MultiIndexSet_Activate',this.id_,multi.get_id()); + end + + function result = AddActive(this,multi) + result = MParT_('MultiIndexSet_AddActive',this.id_,multi.get_id()); + end + + function result = Expand(this,varargin) + if(nargin == 2) + %-1 to keep consitent with matlab ordering + result = MParT_('MultiIndexSet_Expand',this.id_,varargin{1}-1); + elseif(nargin == 1) + result = MParT_('MultiIndexSet_ExpandAny',this.id_); + else + error('Wrong number of inputs') + end + end + + function result = ForciblyExpand(this,activeInd) + %-1 to keep consitent with matlab ordering + result = MParT_('MultiIndexSet_ForciblyExpand',this.id_,activeInd-1); + end + + function result = ForciblyActivate(this,multi) + result = MParT_('MultiIndexSet_ForciblyActivate',this.id_,multi.get_id()); + end + + function listMultis = AdmissibleForwardNeighbors(this,activeInd) + %-1 to keep consistent with matlab ordering + multi_ids = MParT_('MultiIndexSet_AdmissibleFowardNeighbors',this.id_,activeInd-1); + listMultis = []; + for i = 1:length(multi_ids) + listMultis=[listMultis,MultiIndex(multi_ids(i),"id")]; + end + end + + function result = Frontier(this) + result = MParT_('MultiIndexSet_Frontier',this.id_); + end + + function listMultis = Margin(this) + %-1 to keep consistent with matlab ordering + multi_ids = MParT_('MultiIndexSet_Margin',this.id_); + listMultis = []; + for i = 1:length(multi_ids) + listMultis=[listMultis,MultiIndex(multi_ids(i),"id")]; + end + end + + function listMultis = ReducedMargin(this) + %-1 to keep consistent with matlab ordering + multi_ids = MParT_('MultiIndexSet_ReducedMargin',this.id_); + listMultis = []; + for i = 1:length(multi_ids) + listMultis=[listMultis,MultiIndex(multi_ids(i),"id")]; + end + end + + function result = StrictFrontier(this) + result = MParT_('MultiIndexSet_StrictFrontier',this.id_); + end + + function result = BackwardNeighbors(this,activeIndex) + %-1 to keep consitent with matlab ordering + result = MParT_('MultiIndexSet_BackwardNeighbors',this.id_,activeIndex-1); + end + + function result = IsAdmissible(this,multi) + result = MParT_('MultiIndexSet_IsAdmissible',this.id_,multi.get_id()); + end + + function result = IsExpandable(this,activeIndex) + %-1 to keep consitent with matlab ordering + result = MParT_('MultiIndexSet_IsExpandable',this.id_,activeIndex-1); + end + + function result = IsActive(this,multi) + result = MParT_('MultiIndexSet_IsActive',this.id_,multi.get_id()); + end + + function result = NumActiveForward(this,activeIndex) + %-1 to keep consitent with matlab ordering + result = MParT_('MultiIndexSet_NumActiveForward',this.id_,activeIndex-1); + end + + function result = NumForward(this,activeIndex) + %-1 to keep consitent with matlab ordering + result = MParT_('MultiIndexSet_NumForward',this.id_,activeIndex-1); + end + + function Visualize(this) + MParT_('MultiIndexSet_Visualize',this.id_); + end + function result = get_id(this) result = this.id_; end @@ -50,6 +199,8 @@ function delete(this) function fixed_mset = Fix(this) fixed_mset = FixedMultiIndexSet(this); end + + end diff --git a/bindings/matlab/mat/ParameterizedFunction.m b/bindings/matlab/mat/ParameterizedFunction.m new file mode 100644 index 00000000..8ca53d88 --- /dev/null +++ b/bindings/matlab/mat/ParameterizedFunction.m @@ -0,0 +1,93 @@ +classdef ParameterizedFunction < handle +%DATABASE Example usage of the mexplus development kit. +% +% This class definition gives an interface to the underlying MEX functions +% built in the private directory. It is a good practice to wrap MEX functions +% with Matlab script so that the API is well documented and separated from +% its C++ implementation. Also such a wrapper is a good place to validate +% input arguments. +% +% Build +% ----- +% +% make +% +% See `make.m` for details. +% + +properties (Access = private) + id_ +end + +methods + + function this = ParameterizedFunction(varargin) + if(nargin==2) + if(isstring(varargin{2})) + if(varargin{2}=="id") + this.id_=varargin{1}; + end + end + elseif(nargin==3) + outputDim = varargin{1}; + mset = varargin{2}; + mapOptions = varargin{3}; + mexOptions = mapOptions.getMexOptions; + input_str=['MParT_(',char(39),'ParameterizedFunction_newMap',char(39),',outputDim,mset.get_id()']; + for o=1:length(mexOptions) + input_o=[',mexOptions{',num2str(o),'}']; + input_str=[input_str,input_o]; + end + input_str=[input_str,')']; + this.id_ = eval(input_str); + else + error('Invalid number of inputs') + end + end + + function delete(this) + %DELETE Destructor. + MParT_('ParameterizedFunction_delete', this.id_); + end + + function SetCoeffs(this,coeffs) + MParT_('ParameterizedFunction_SetCoeffs',this.id_,coeffs(:)); + end + + function result = Coeffs(this) + result = MParT_('ParameterizedFunction_Coeffs',this.id_); + end + + function result = CoeffMap(this) + result = MParT_('ParameterizedFunction_CoeffMap',this.id_); + end + + function result = numCoeffs(this) + result = MParT_('ParameterizedFunction_numCoeffs',this.id_); + end + + function result = Evaluate(this,pts) + result = zeros(this.outputDim, size(pts,2)); + MParT_('ParameterizedFunction_Evaluate',this.id_,pts,result); + end + + function result = CoeffGrad(this,pts,sens) + result = zeros(this.numCoeffs, size(pts,2)); + MParT_('ParameterizedFunction_CoeffGrad',this.id_,pts,sens,result); + end + + function result = get_id(this) + result = this.id_; + end + + function result = outputDim(this) + result = MParT_('ParameterizedFunction_outputDim',this.id_); + end + + function result = inputDim(this) + result = MParT_('ParameterizedFunction_inputDim',this.id_); + end + +end + +end diff --git a/bindings/matlab/mat/TestExp.m b/bindings/matlab/mat/TestExp.m deleted file mode 100644 index 44423d16..00000000 --- a/bindings/matlab/mat/TestExp.m +++ /dev/null @@ -1,51 +0,0 @@ -classdef TestExp < handle -%DATABASE Example usage of the mexplus development kit. -% -% This class definition gives an interface to the underlying MEX functions -% built in the private directory. It is a good practice to wrap MEX functions -% with Matlab script so that the API is well documented and separated from -% its C++ implementation. Also such a wrapper is a good place to validate -% input arguments. -% -% Build -% ----- -% -% make -% -% See `make.m` for details. -% - -properties (Access = private) - id_ % ID of the session. -end - -methods - function this = TestExp() - %DATABASE Create a new database. - this.id_ = MParT_('new'); - end - - function delete(this) - %DELETE Destructor. - MParT_('delete', this.id_); - end - - function result = Evaluate1d(this, x) - result = MParT_('Evaluate1d',this.id_, x); - end - -end - -methods (Static) - function environment = getEnvironment() - %GETENVIRONMENT Get environment info. - environment = mexmpart('getEnvironment'); - end - - function setEnvironment(environment) - %SETENVIRONMENT Set environment info. - mexmpart('setEnvironment', environment); - end -end - -end diff --git a/bindings/matlab/mat/TriangularMap.m b/bindings/matlab/mat/TriangularMap.m index c086cc0a..21f0c448 100644 --- a/bindings/matlab/mat/TriangularMap.m +++ b/bindings/matlab/mat/TriangularMap.m @@ -1,87 +1,3 @@ -classdef TriangularMap < handle -%DATABASE Example usage of the mexplus development kit. -% -% This class definition gives an interface to the underlying MEX functions -% built in the private directory. It is a good practice to wrap MEX functions -% with Matlab script so that the API is well documented and separated from -% its C++ implementation. Also such a wrapper is a good place to validate -% input arguments. -% -% Build -% ----- -% -% make -% -% See `make.m` for details. -% - -properties (Access = private) - id_ -end - -methods - function this = TriangularMap(list_id) - %DATABASE Create a new database. - this.id_ = MParT_('newTriMap',list_id); - end - - function delete(this) - %DELETE Destructor. - MParT_('deleteMap', this.id_); - end - - function SetCoeffs(this,coeffs) - MParT_('SetCoeffs',this.id_,coeffs(:)); - end - - function result = Coeffs(this) - result = MParT_('Coeffs',this.id_); - end - - function result = numCoeffs(this) - result = MParT_('numCoeffs',this.id_); - end - - %function result = Evaluate(this,pts) - % result = zeros(1,size(pts,2)); - % MParT_('Evaluate',this.id_,pts,result); - %end - - function result = Evaluate(this,pts) - result = MParT_('Evaluate',this.id_,pts); - end - - function result = LogDeterminant(this,pts) - result = MParT_('LogDeterminant',this.id_,pts); - end - - function result = Inverse(this,x1,r) - result = MParT_('Inverse',this.id_,x1,r); - end - - function result = CoeffGrad(this,pts,sens) - result = MParT_('CoeffGrad',this.id_,pts,sens); - end - - function result = LogDeterminantCoeffGrad(this,pts) - result = MParT_('LogDeterminantCoeffGrad',this.id_,pts); - end - - function result = get_id(this) - result = this.id_; - end -end - -methods (Static) - function environment = getEnvironment() - %GETENVIRONMENT Get environment info. - environment = mexmpart('getEnvironment'); - end - - function setEnvironment(environment) - %SETENVIRONMENT Set environment info. - mexmpart('setEnvironment', environment); - end -end - -end +function map = TriangularMap(listCondMaps) + map = ConditionalMap(listCondMaps); +end \ No newline at end of file diff --git a/bindings/matlab/src/BasisTypes_mex.cpp b/bindings/matlab/src/BasisTypes_mex.cpp index affa5cff..fff1d2d9 100644 --- a/bindings/matlab/src/BasisTypes_mex.cpp +++ b/bindings/matlab/src/BasisTypes_mex.cpp @@ -163,4 +163,4 @@ MEX_DEFINE(deleteBasisTypesOpt) (int nlhs, mxArray* plhs[], } // namespace -MEX_DISPATCH // Don't forget to add this if MEX_DEFINE() is used. \ No newline at end of file +//MEX_DISPATCH // Don't forget to add this if MEX_DEFINE() is used. \ No newline at end of file diff --git a/bindings/matlab/src/ConditionalMap_mex.cpp b/bindings/matlab/src/ConditionalMap_mex.cpp index e65a9bee..117ca5b8 100644 --- a/bindings/matlab/src/ConditionalMap_mex.cpp +++ b/bindings/matlab/src/ConditionalMap_mex.cpp @@ -8,7 +8,6 @@ #include "MParT/ConditionalMapBase.h" #include "MParT/TriangularMap.h" #include -#include "mexplus_eigen.h" #include @@ -27,16 +26,36 @@ class ConditionalMapMex { // The class map_ptr = MapFactory::CreateComponent(mset,opts); } + ConditionalMapMex(std::shared_ptr> init_ptr){ + map_ptr = init_ptr; + } + ConditionalMapMex(std::vector>> blocks){ map_ptr = std::make_shared>(blocks); } + ConditionalMapMex(unsigned int inputDim, unsigned int outputDim, unsigned int totalOrder, MapOptions opts){ map_ptr = MapFactory::CreateTriangular(inputDim,outputDim,totalOrder,opts); } }; //end class +class ParameterizedFunctionMex { // The class +public: + std::shared_ptr> fun_ptr; + + ParameterizedFunctionMex(unsigned int outputDim, FixedMultiIndexSet const& mset, + MapOptions opts){ + fun_ptr = MapFactory::CreateExpansion(outputDim,mset,opts); + } + + ParameterizedFunctionMex(std::shared_ptr> init_ptr){ + fun_ptr = init_ptr; + } +}; //end class + // Instance manager for ConditionalMap. template class mexplus::Session; +template class mexplus::Session; namespace { @@ -108,6 +127,32 @@ MEX_DEFINE(ConditionalMap_SetCoeffs) (int nlhs, mxArray* plhs[], condMap.map_ptr->SetCoeffs(coeffs); } +MEX_DEFINE(ConditionalMap_GetComponent) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + unsigned int i = input.get(1); + ConditionalMapMex *condMap = Session::get(input.get(0)); + std::shared_ptr> condMap_ptr = condMap->map_ptr; + std::shared_ptr> tri_ptr = std::dynamic_pointer_cast>(condMap_ptr); + if(tri_ptr==nullptr){ + throw std::runtime_error("Tried to access GetComponent with a type other than TriangularMap"); + }else{ + output.set(0, Session::create(new ConditionalMapMex(tri_ptr->GetComponent(i)))); + } +} + +MEX_DEFINE(ConditionalMap_GetBaseFunction) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + ConditionalMapMex *condMap = Session::get(input.get(0)); + std::shared_ptr> func_ptr = condMap->map_ptr->GetBaseFunction(); + output.set(0, Session::create(new ParameterizedFunctionMex(func_ptr))); +} + MEX_DEFINE(ConditionalMap_Coeffs) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 1); @@ -117,6 +162,15 @@ MEX_DEFINE(ConditionalMap_Coeffs) (int nlhs, mxArray* plhs[], output.set(0,coeffs); } +MEX_DEFINE(ConditionalMap_CoeffMap) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const ConditionalMapMex& condMap = Session::getConst(input.get(0)); + auto coeffs = condMap.map_ptr->CoeffMap(); + output.set(0,coeffs); +} + MEX_DEFINE(ConditionalMap_numCoeffs) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 1); diff --git a/bindings/matlab/src/MultiIndexSet_mex.cpp b/bindings/matlab/src/MultiIndexSet_mex.cpp index 0359bd6d..2960adf8 100644 --- a/bindings/matlab/src/MultiIndexSet_mex.cpp +++ b/bindings/matlab/src/MultiIndexSet_mex.cpp @@ -12,7 +12,9 @@ using namespace mpart; using namespace mexplus; -// Instance manager for Multi_idxs_tr. +// Instance manager for MultiIndexSet +// To do: bind functions using MultiIndex objects +//template class mexplus::Session; template class mexplus::Session; template class mexplus::Session>; @@ -44,7 +46,33 @@ MEX_DEFINE(MultiIndexSet_delete) (int nlhs, mxArray* plhs[], Session::destroy(input.get(0)); } -// Defines MEX API for delete. +MEX_DEFINE(MultiIndexSet_IndexToMulti) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + unsigned int activeIndex = input.get(1); + + output.set(0, Session::create(new MultiIndex(mset.IndexToMulti(activeIndex)))); +} + +MEX_DEFINE(MultiIndexSet_MultiToIndex) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + const MultiIndex& multi = Session::getConst(input.get(1)); + output.set(0,mset.MultiToIndex(multi)); +} + +MEX_DEFINE(MultiIndexSet_Length) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + output.set(0, mset.Length()); +} + MEX_DEFINE(MultiIndexSet_MaxOrders) (int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { InputArguments input(nrhs, prhs, 1); @@ -53,4 +81,232 @@ MEX_DEFINE(MultiIndexSet_MaxOrders) (int nlhs, mxArray* plhs[], output.set(0, mset.MaxOrders()); } +MEX_DEFINE(MultiIndexSet_at) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + int activeIndex = input.get(1); + output.set(0, Session::create(new MultiIndex(mset.at(activeIndex)))); +} + +MEX_DEFINE(MultiIndexSet_subsref) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + int activeIndex = input.get(1); + output.set(0, Session::create(new MultiIndex(mset[activeIndex]))); +} + +MEX_DEFINE(MultiIndexSet_Size) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + output.set(0, mset.Size()); +} + +MEX_DEFINE(MultiIndexSet_addMultiIndexSet) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 0); + MultiIndexSet *mset = Session::get(input.get(0)); + const MultiIndexSet& msetToAdd = Session::getConst(input.get(1)); + (*mset)+=msetToAdd; +} + +MEX_DEFINE(MultiIndexSet_addMultiIndex) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 0); + MultiIndexSet *mset = Session::get(input.get(0)); + const MultiIndex& multiToAdd = Session::getConst(input.get(1)); + (*mset)+=multiToAdd; +} + +MEX_DEFINE(MultiIndexSet_Union) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + const MultiIndexSet& rhs = Session::getConst(input.get(1)); + output.set(0, mset->Union(rhs)); +} + +MEX_DEFINE(MultiIndexSet_Activate) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 0); + MultiIndexSet *mset = Session::get(input.get(0)); + const MultiIndex& multi = Session::getConst(input.get(1)); + mset->Activate(multi); +} + +MEX_DEFINE(MultiIndexSet_AddActive) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + const MultiIndex& multi = Session::getConst(input.get(1)); + output.set(0,mset->AddActive(multi)); +} + +MEX_DEFINE(MultiIndexSet_Expand) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + unsigned int activeInd = input.get(1); + output.set(0, mset->Expand(activeInd)); +} + +MEX_DEFINE(MultiIndexSet_ExpandAny) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + output.set(0, mset->Expand()); +} + +MEX_DEFINE(MultiIndexSet_ForciblyExpand) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + const unsigned int activeIndex = input.get(1); + output.set(0, mset->ForciblyExpand(activeIndex)); +} + +MEX_DEFINE(MultiIndexSet_ForciblyActivate) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + const MultiIndex& multi = Session::getConst(input.get(1)); + output.set(0,mset->ForciblyActivate(multi)); +} + +MEX_DEFINE(MultiIndexSet_AdmissibleFowardNeighbors) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + MultiIndexSet *mset = Session::get(input.get(0)); + unsigned int activeIndex = input.get(1); + std::vector vecMultiIndex = mset->AdmissibleForwardNeighbors(activeIndex); + OutputArguments output(nlhs, plhs, 1); + std::vector multi_ids(vecMultiIndex.size()); + for (int i=0; i::create(new MultiIndex(vecMultiIndex[i])); + } + output.set(0,multi_ids); +} + +MEX_DEFINE(MultiIndexSet_Frontier) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + output.set(0, mset.Frontier()); +} + +MEX_DEFINE(MultiIndexSet_Margin) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + std::vector vecMultiIndex = mset->Margin(); + std::vector multi_ids(vecMultiIndex.size()); + for (int i=0; i::create(new MultiIndex(vecMultiIndex[i])); + } + output.set(0,multi_ids); +} + +MEX_DEFINE(MultiIndexSet_ReducedMargin) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + MultiIndexSet *mset = Session::get(input.get(0)); + std::vector vecMultiIndex = mset->ReducedMargin(); + std::vector multi_ids(vecMultiIndex.size()); + for (int i=0; i::create(new MultiIndex(vecMultiIndex[i])); + } + output.set(0,multi_ids); +} + +MEX_DEFINE(MultiIndexSet_StrictFrontier) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + output.set(0, mset.StrictFrontier()); +} + +MEX_DEFINE(MultiIndexSet_BackwardNeighbors) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + unsigned int activeIndex = input.get(1); + output.set(0, mset.BackwardNeighbors(activeIndex)); +} + +MEX_DEFINE(MultiIndexSet_IsAdmissible) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + const MultiIndex& multi = Session::getConst(input.get(1)); + output.set(0, mset.IsAdmissible(multi)); +} + +MEX_DEFINE(MultiIndexSet_IsExpandable) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + unsigned int activeIndex = input.get(1); + output.set(0, mset.IsExpandable(activeIndex)); +} + +MEX_DEFINE(MultiIndexSet_IsActive) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + const MultiIndex& multi = Session::getConst(input.get(1)); + output.set(0, mset.IsActive(multi)); +} + +MEX_DEFINE(MultiIndexSet_NumActiveForward) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + unsigned int activeInd = input.get(1); + output.set(0, mset.NumActiveForward(activeInd)); +} + +MEX_DEFINE(MultiIndexSet_NumForward) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + unsigned int activeInd = input.get(1); + output.set(0, mset.NumForward(activeInd)); +} + +MEX_DEFINE(MultiIndexSet_Visualize) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 0); + const MultiIndexSet& mset = Session::getConst(input.get(0)); + mset.Visualize(); +} + + + + + } // namespace \ No newline at end of file diff --git a/bindings/matlab/src/MultiIndex_mex.cpp b/bindings/matlab/src/MultiIndex_mex.cpp new file mode 100644 index 00000000..4f331710 --- /dev/null +++ b/bindings/matlab/src/MultiIndex_mex.cpp @@ -0,0 +1,171 @@ +#include +#include "MParT/MultiIndices/MultiIndexSet.h" +#include "MParT/MultiIndices/MultiIndex.h" + +#include "MParT/Utilities/ArrayConversions.h" +#include "MexArrayConversions.h" +#include "mexplus_eigen.h" +#include + + +using namespace mpart; +using namespace mexplus; + + +// Instance manager for MultiIndex +template class mexplus::Session; + + +namespace { + +MEX_DEFINE(MultiIndex_newDefault) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const unsigned int lengthIn = input.get(0); + const unsigned int val = input.get(1); + output.set(0, Session::create(new MultiIndex(lengthIn,val))); +} + +MEX_DEFINE(MultiIndex_newEigen) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const auto mult = input.get(0); + output.set(0, Session::create(new MultiIndex(mult.cast()))); +} + +// Defines MEX API for delete. +MEX_DEFINE(MultiIndex_delete) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 0); + Session::destroy(input.get(0)); +} + +MEX_DEFINE(MultiIndex_Vector) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + output.set(0, multi.Vector()); +} + +MEX_DEFINE(MultiIndex_Sum) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + output.set(0, multi.Sum()); +} + +MEX_DEFINE(MultiIndex_Max) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + output.set(0, multi.Max()); +} + +MEX_DEFINE(MultiIndex_Set) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 3); + OutputArguments output(nlhs, plhs, 1); + MultiIndex *multi = Session::get(input.get(0)); + unsigned int ind = input.get(1); + unsigned int val = input.get(2); + output.set(0, multi->Set(ind,val)); +} + +MEX_DEFINE(MultiIndex_Get) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + unsigned int ind = input.get(1); + output.set(0, multi.Get(ind)); +} + +MEX_DEFINE(MultiIndex_NumNz) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + output.set(0, multi.NumNz()); +} + +MEX_DEFINE(MultiIndex_String) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + output.set(0, multi.String()); +} + +MEX_DEFINE(MultiIndex_Length) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + output.set(0, multi.Length()); +} + +MEX_DEFINE(MultiIndex_Eq) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + const MultiIndex& multi2 = Session::getConst(input.get(1)); + output.set(0, multi==multi2); +} + +MEX_DEFINE(MultiIndex_Ne) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + const MultiIndex& multi2 = Session::getConst(input.get(1)); + output.set(0, multi!=multi2); +} + +MEX_DEFINE(MultiIndex_Lt) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + const MultiIndex& multi2 = Session::getConst(input.get(1)); + output.set(0, multi::getConst(input.get(0)); + const MultiIndex& multi2 = Session::getConst(input.get(1)); + output.set(0, multi>multi2); +} + +MEX_DEFINE(MultiIndex_Ge) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + const MultiIndex& multi2 = Session::getConst(input.get(1)); + output.set(0, multi>=multi2); +} + +MEX_DEFINE(MultiIndex_Le) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 1); + const MultiIndex& multi = Session::getConst(input.get(0)); + const MultiIndex& multi2 = Session::getConst(input.get(1)); + output.set(0, multi<=multi2); +} + + + +} // namespace \ No newline at end of file diff --git a/bindings/matlab/src/ParameterizedFunctionBase_mex.cpp b/bindings/matlab/src/ParameterizedFunctionBase_mex.cpp new file mode 100644 index 00000000..75919dd0 --- /dev/null +++ b/bindings/matlab/src/ParameterizedFunctionBase_mex.cpp @@ -0,0 +1,141 @@ +#include +#include "MParT/MultiIndices/MultiIndexSet.h" +#include "MParT/Utilities/ArrayConversions.h" +#include "MexArrayConversions.h" +#include "MexMapOptionsConversions.h" +#include "MParT/MapOptions.h" +#include "MParT/MapFactory.h" +#include "MParT/ConditionalMapBase.h" +#include "MParT/TriangularMap.h" +#include + + +using namespace mpart; +using namespace mexplus; +using MemorySpace = Kokkos::HostSpace; + +class ParameterizedFunctionMex { // The class +public: + std::shared_ptr> fun_ptr; + + ParameterizedFunctionMex(unsigned int outputDim, FixedMultiIndexSet const& mset, + MapOptions opts){ + fun_ptr = MapFactory::CreateExpansion(outputDim,mset,opts); + } + + ParameterizedFunctionMex(std::shared_ptr> init_ptr){ + fun_ptr = init_ptr; + } +}; //end class + + +// Instance manager for ParameterizedFunctionBase +template class mexplus::Session>; + + +namespace { +MEX_DEFINE(ParameterizedFunction_newMap) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + + InputArguments input(nrhs, prhs, 11); + OutputArguments output(nlhs, plhs, 1); + unsigned int outputDim = input.get(0); + const MultiIndexSet& mset = Session::getConst(input.get(1)); + MapOptions opts = MapOptionsFromMatlab(input.get(2),input.get(3), + input.get(4),input.get(5), + input.get(6),input.get(7), + input.get(8),input.get(9), + input.get(10)); + + output.set(0, Session::create(new ParameterizedFunctionMex(outputDim,mset.Fix(),opts))); +} + +// Defines MEX API for delete. +MEX_DEFINE(ParameterizedFunction_delete) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 0); + Session::destroy(input.get(0)); +} + +MEX_DEFINE(ParameterizedFunction_SetCoeffs) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 2); + OutputArguments output(nlhs, plhs, 0); + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + auto coeffs = MexToKokkos1d(prhs[1]); + parFunc.fun_ptr->SetCoeffs(coeffs); +} + +MEX_DEFINE(ParameterizedFunction_Coeffs) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + auto coeffs = KokkosToVec(parFunc.fun_ptr->Coeffs()); + output.set(0,coeffs); +} + +MEX_DEFINE(ParameterizedFunction_CoeffMap) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + auto coeffs = parFunc.fun_ptr->CoeffMap(); + output.set(0,coeffs); +} + +MEX_DEFINE(ParameterizedFunction_numCoeffs) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + auto numcoeffs = parFunc.fun_ptr->numCoeffs; + output.set(0,numcoeffs); +} + +MEX_DEFINE(ParameterizedFunction_outputDim) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + unsigned int outDim = parFunc.fun_ptr->outputDim; + output.set(0, outDim); +} + +MEX_DEFINE(ParameterizedFunction_inputDim) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 1); + OutputArguments output(nlhs, plhs, 1); + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + unsigned int inDim = parFunc.fun_ptr->inputDim; + output.set(0, inDim); +} + +MEX_DEFINE(ParameterizedFunction_Evaluate) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + InputArguments input(nrhs, prhs, 3); + OutputArguments output(nlhs, plhs, 0); + + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + StridedMatrix pts = MexToKokkos2d(prhs[1]); + StridedMatrix out = MexToKokkos2d(prhs[2]); + parFunc.fun_ptr->EvaluateImpl(pts, out); +} + +MEX_DEFINE(ParameterizedFunction_CoeffGrad) (int nlhs, mxArray* plhs[], + int nrhs, const mxArray* prhs[]) { + + InputArguments input(nrhs, prhs, 4); + OutputArguments output(nlhs, plhs, 0); + + const ParameterizedFunctionMex& parFunc = Session::getConst(input.get(0)); + + auto pts = MexToKokkos2d(prhs[1]); + auto sens = MexToKokkos2d(prhs[2]); + auto out = MexToKokkos2d(prhs[3]); + + parFunc.fun_ptr->CoeffGradImpl(pts,sens,out); +} + +} // namespace \ No newline at end of file diff --git a/bindings/matlab/tests/Test_MultiIndex.m b/bindings/matlab/tests/Test_MultiIndex.m new file mode 100644 index 00000000..2de3cb0c --- /dev/null +++ b/bindings/matlab/tests/Test_MultiIndex.m @@ -0,0 +1,29 @@ +clear; +addpath(genpath('.')); + +KokkosInitialize(8); + +a=[2,3,4]; +multi=MultiIndex(a); +multi.String() + +multi2=MultiIndex(6,1); +multi2.String() + +multi3=MultiIndex([0,1,2,3]); %Not sure we want to keep this +multi3.String() + +multi3.Set(1,9); +multi3.Vector(); + +disp(multi2==multi3); +disp(multi2~=multi2); + +disp(multi2>multi3) + +disp(multi2>=multi3) +disp(multi2<=multi2) +disp(multi3> mpart::MapFactory::CreateTriang } +template +std::shared_ptr> mpart::MapFactory::CreateExpansion(unsigned int outputDim, + FixedMultiIndexSet const& mset, + MapOptions opts) +{ + std::shared_ptr> output; + + if(opts.basisType==BasisTypes::ProbabilistHermite){ + + ProbabilistHermite basis1d; + output = std::make_shared>(outputDim, mset, basis1d); + + }else if(opts.basisType==BasisTypes::PhysicistHermite){ + + PhysicistHermite basis1d; + output = std::make_shared>(outputDim, mset, basis1d); + + }else if(opts.basisType==BasisTypes::HermiteFunctions){ + + HermiteFunction basis1d; + output = std::make_shared>(outputDim, mset, basis1d); + } + + if(output){ + output->SetCoeffs(Kokkos::View("Component Coefficients", output->numCoeffs)); + return output; + } + + std::stringstream msg; + msg << "Could not parse options in CreateExpansion. Unknown 1d basis type."; + throw std::runtime_error(msg.str()); + + return nullptr; +} + + template std::shared_ptr> mpart::MapFactory::CreateComponent(FixedMultiIndexSet const&, MapOptions); +template std::shared_ptr> mpart::MapFactory::CreateExpansion(unsigned int, FixedMultiIndexSet const&, MapOptions); template std::shared_ptr> mpart::MapFactory::CreateTriangular(unsigned int, unsigned int, unsigned int, MapOptions); #if defined(KOKKOS_ENABLE_CUDA ) || defined(KOKKOS_ENABLE_SYCL) template std::shared_ptr> mpart::MapFactory::CreateComponent(FixedMultiIndexSet const&, MapOptions); + template std::shared_ptr> mpart::MapFactory::CreateExpansion(unsigned int, FixedMultiIndexSet const&, MapOptions); template std::shared_ptr> mpart::MapFactory::CreateTriangular(unsigned int, unsigned int, unsigned int, MapOptions); #endif \ No newline at end of file diff --git a/src/MultiIndices/MultiIndexSet.cpp b/src/MultiIndices/MultiIndexSet.cpp index 5cdb5b61..4f6419ce 100644 --- a/src/MultiIndices/MultiIndexSet.cpp +++ b/src/MultiIndices/MultiIndexSet.cpp @@ -645,6 +645,18 @@ std::vector MultiIndexSet::Expand(unsigned int activeIndex) return newIndices; } + +std::vector MultiIndexSet::Expand() +{ + std::vector frontier = Frontier(); + std::vector newInds, allNewInds; + for(auto& ind : frontier){ + newInds = Expand(ind); + allNewInds.insert(allNewInds.end(), newInds.begin(), newInds.end()); + } + return allNewInds; +} + std::vector MultiIndexSet::ForciblyExpand(unsigned int const activeIndex) { assert(activeIndex eval = map->Evaluate(pts); } + +TEST_CASE( "Testing multivariate expansion factory", "[MapFactoryExpansion]" ) { + + MapOptions options; + options.basisType = BasisTypes::ProbabilistHermite; + + unsigned int outDim = 5; + unsigned int inDim = 3; + unsigned int maxDegree = 5; + FixedMultiIndexSet mset(inDim,maxDegree); + + std::shared_ptr> func = MapFactory::CreateExpansion(outDim, mset, options); + REQUIRE(func!=nullptr); + + unsigned int numPts = 100; + Kokkos::View pts("Points", inDim, numPts); + for(unsigned int i=0; i eval = func->Evaluate(pts); + CHECK(eval.extent(0)==outDim); + CHECK(eval.extent(1)==numPts); +} + + TEST_CASE( "Testing factory method for triangular map", "[MapFactoryTriangular]" ) { MapOptions options;