Skip to content

Commit

Permalink
Merge pull request #160 from MeasureTransport/rubiop/matlab-bindings
Browse files Browse the repository at this point in the history
Completed matlab-bindings (hopefully)
  • Loading branch information
mparno authored Jul 22, 2022
2 parents 47cf06c + 02f15d5 commit 96eae14
Show file tree
Hide file tree
Showing 24 changed files with 1,344 additions and 204 deletions.
12 changes: 12 additions & 0 deletions MParT/MapFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,18 @@ namespace mpart{
unsigned int outputDim,
unsigned int totalOrder,
MapOptions options = MapOptions());

/**
@brief Constructs a (generally) non-monotone multivariate expansion.
@param outputDim The output dimension of the expansion. Each output will be defined by the same multiindex set but will have different coefficients.
@param mset The multiindex set specifying which terms should be used in the multivariate expansion.
@param options Options specifying the 1d basis functions used in the parameterization.
*/
template<typename MemorySpace>
std::shared_ptr<ParameterizedFunctionBase<MemorySpace>> CreateExpansion(unsigned int outputDim,
FixedMultiIndexSet<MemorySpace> const& mset,
MapOptions options = MapOptions());

}
}

Expand Down
6 changes: 3 additions & 3 deletions MParT/MultiIndices/MultiIndex.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class MultiIndex {
*/
bool Set(unsigned int ind, unsigned int val);

/** Obtain the a particular component of the multiindex. Notice that this function can be slow for multiindices with many nonzero components. The worst case performance requires \f$O(|\mathbf{j}|_0)\f$ integer comparisons, where \f$|\mathbf{j}|_0\f$ denotes the number of nonzero entries in the multiindex.
/** Obtain a particular component of the multiindex. Notice that this function can be slow for multiindices with many nonzero components. The worst case performance requires \f$O(|\mathbf{j}|_0)\f$ integer comparisons, where \f$|\mathbf{j}|_0\f$ denotes the number of nonzero entries in the multiindex.
@param[in] ind The component to return.
@return The integer stored in component dim of the multiindex.
*/
Expand Down Expand Up @@ -177,10 +177,10 @@ class MultiIndex {
std::vector<unsigned int> nzVals;

/// The maximum index over all nzInds pairs.
unsigned maxValue;
unsigned int maxValue;

// the total order of the multiindex (i.e. the sum of the indices)
unsigned totalOrder;
unsigned int totalOrder;

}; // class MultiIndex

Expand Down
6 changes: 3 additions & 3 deletions MParT/MultiIndices/MultiIndexSet.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ MultiIndexSet set(length, limiter);
@param[in] activeIndex Linear index of interest.
@return A constant reference to the MultiIndex.
*/
MultiIndex const& IndexToMulti(unsigned int activeIndex) const{return allMultis.at(active2global.at(activeIndex));};
MultiIndex IndexToMulti(unsigned int activeIndex) const{return allMultis.at(active2global.at(activeIndex));};

/** Given a multiindex, return the linear index where it is located.
@param[in] input An instance of the MultiIndex class.
Expand All @@ -138,7 +138,7 @@ MultiIndexSet set(length, limiter);
@param[in] activeIndex The index of the active MultiIndex to return.
@return A pointer to the MultiIndex at index outputIndex.
*/
MultiIndex const& at(int activeIndex) const{return IndexToMulti(activeIndex);}
MultiIndex at(int activeIndex) const{return IndexToMulti(activeIndex);}

/**
* This function provides access to each of the MultiIndices without any bounds checking on the vector.
Expand Down Expand Up @@ -170,7 +170,7 @@ MultiIndexSet set(length, limiter);
already in the set and if the input function is unique, it is
added to the set.
@param[in] rhs The MultiIndex we want to add to the set.
@return A reference to this MultiIndex set, which may now contain the new
@return A reference to this MultiIndex, which may now contain the new
MultiIndex in rhs.
*/
MultiIndexSet& operator+=(MultiIndex const& rhs);
Expand Down
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,7 @@ cd build
cmake \
-DCMAKE_INSTALL_PREFIX=<your/install/path> \
-DPYTHON_EXECUTABLE=`which python` \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_OSX_ARCHITECTURES=x86_64 \
-DMatlab_MEX_EXTENSION="mexmaci64" \
-DKokkos_ENABLE_PTHREAD=ON \
-DKokkos_ENABLE_SERIAL=ON \
..
Expand Down
2 changes: 2 additions & 0 deletions bindings/matlab/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ set(MEX_SOURCE
src/KokkosUtilities_mex.cpp
src/ConditionalMap_mex.cpp
src/MultiIndexSet_mex.cpp
src/MultiIndex_mex.cpp
src/FixedMultiIndexSet_mex.cpp
src/ParameterizedFunctionBase_mex.cpp
src/MexArrayConversions.cpp
src/MexMapOptionsConversions.cpp
)
Expand Down
64 changes: 24 additions & 40 deletions bindings/matlab/external/mexplus_eigen.h
Original file line number Diff line number Diff line change
@@ -1,22 +1,8 @@
/*
* eos - A 3D Morphable Model fitting library written in modern C++11/14.
*
* File: matlab/include/mexplus_eigen.hpp
*
* Copyright 2016-2018 Patrik Huber
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
/**
* Adapted from: https://github.com/patrikhuber/eos/blob/master/matlab/include/mexplus_eigen.hpp
* @brief Handle custom conversion between mxarray and C++ matrix types
*/

#pragma once

#ifndef MPART_MEXPLUS_EIGEN_HPP
Expand All @@ -34,12 +20,7 @@
namespace mexplus {

/**
* @brief Define a template specialisation for Eigen::MatrixXd for ... .
*
* The default precision in Matlab is double, but most matrices in eos (for example the PCA basis matrices
* are stored as float values, so this defines conversion from these matrices to Matlab.
*
* Todo: Documentation.
* @brief Define a template specialisation for Eigen::MatrixXd
*/
template <>
inline mxArray* MxArray::from(const Eigen::MatrixXd& eigen_matrix)
Expand All @@ -48,12 +29,6 @@ inline mxArray* MxArray::from(const Eigen::MatrixXd& eigen_matrix)
const int num_cols = static_cast<int>(eigen_matrix.cols());
MxArray out_array(MxArray::Numeric<double>(num_rows, num_cols));

// This might not copy the data but it's evil and probably really dangerous!!!:
// mxSetData(const_cast<mxArray*>(matrix.get()), (void*)value.data());

// This copies the data. But I suppose it makes sense that we copy the data when we go
// from C++ to Matlab, since Matlab can unload the C++ mex module at any time I think.
// Loop is column-wise
for (int c = 0; c < num_cols; ++c)
{
for (int r = 0; r < num_rows; ++r)
Expand All @@ -64,13 +39,28 @@ inline mxArray* MxArray::from(const Eigen::MatrixXd& eigen_matrix)
return out_array.release();
};

/**
* @brief Define a template specialisation for Eigen::VectorXd
*/
template <>
inline mxArray* MxArray::from(const Eigen::VectorXd& eigen_vector)
{
const int num_elems = static_cast<int>(eigen_vector.size());
//Choose (1,n) or (n,1) ?
MxArray out_array(MxArray::Numeric<double>(1,num_elems));

for (int c = 0; c < num_elems; ++c)
{
out_array.set(1, c, eigen_vector(c));
}
return out_array.release();
};



/**
* @brief Define a template specialisation for Eigen::MatrixXd for ... .
* @brief Define a template specialisation for Eigen::MatrixXd
*
* Todo: Documentation.
* TODO: Maybe provide this one as MatrixXf as well as MatrixXd? Matlab's default is double?
*/
template <>
inline void MxArray::to(const mxArray* in_array, Eigen::MatrixXd* eigen_matrix)
Expand Down Expand Up @@ -102,14 +92,8 @@ inline void MxArray::to(const mxArray* in_array, Eigen::MatrixXd* eigen_matrix)
const auto nrows = array.dimensions()[0]; // or use array.rows()
const auto ncols = array.dimensions()[1];

// I think I can just use Eigen::Matrix, not a Map - the Matrix c'tor that we call creates a Map anyway?
Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>> eigen_map(
array.getData<double>(), nrows, ncols);
// Not sure that's alright - who owns the data? I think as it is now, everything points to the data in the
// mxArray owned by Matlab, but I'm not 100% sure.
// Actually, doesn't eigen_map go out of scope and get destroyed? This might be trouble? But this
// assignment should (or might) copy, then it's fine? Check if it invokes the copy c'tor.
// 2 May 2018: Yes this copies.
*eigen_matrix = eigen_map;
};

Expand Down Expand Up @@ -182,4 +166,4 @@ inline void MxArray::to(const mxArray* in_array, Kokkos::View<double**, Kokkos::

} /* namespace mexplus */

#endif /* MPART_MEXPLUS_EIGEN_HPP */
#endif /* MPART_MEXPLUS_EIGEN_H */
44 changes: 30 additions & 14 deletions bindings/matlab/mat/ConditionalMap.m
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,22 @@
function this = ConditionalMap(varargin)

if(nargin==2)
mset = varargin{1};
mapOptions = varargin{2};
mexOptions = mapOptions.getMexOptions;

input_str=['MParT_(',char(39),'ConditionalMap_newMap',char(39),',mset.get_id()'];
for o=1:length(mexOptions)
input_o=[',mexOptions{',num2str(o),'}'];
input_str=[input_str,input_o];
if(isstring(varargin{2}))
if(varargin{2}=="id")
this.id_=varargin{1};
end
else
mset = varargin{1};
mapOptions = varargin{2};
mexOptions = mapOptions.getMexOptions;
input_str=['MParT_(',char(39),'ConditionalMap_newMap',char(39),',mset.get_id()'];
for o=1:length(mexOptions)
input_o=[',mexOptions{',num2str(o),'}'];
input_str=[input_str,input_o];
end
input_str=[input_str,')'];
this.id_ = eval(input_str);
end
input_str=[input_str,')'];

this.id_ = eval(input_str);

elseif(nargin==4)
inputDim = varargin{1};
outputDim = varargin{2};
Expand All @@ -54,8 +57,7 @@
this.id_ = eval(input_str);

elseif(nargin==1)
MParT_('ConditionalMap_newTriMap', varargin{1});

this.id_=MParT_('ConditionalMap_newTriMap', varargin{1});
else
error('Invalid number of inputs')
end
Expand All @@ -66,6 +68,16 @@ function delete(this)
MParT_('ConditionalMap_deleteMap', this.id_);
end

function condMap = GetComponent(this,i)
condMap_id = MParT_('ConditionalMap_GetComponent',this.id_,i-1);
condMap = ConditionalMap(condMap_id,"id");
end

function parFunc = GetBaseFunction(this)
parFunc_id=MParT_('ConditionalMap_GetBaseFunction',this.id_);
parFunc = ParameterizedFunction(parFunc_id,"id");
end

function SetCoeffs(this,coeffs)
MParT_('ConditionalMap_SetCoeffs',this.id_,coeffs(:));
end
Expand All @@ -74,6 +86,10 @@ function SetCoeffs(this,coeffs)
result = MParT_('ConditionalMap_Coeffs',this.id_);
end

function result = CoeffMap(this)
result = MParT_('ConditionalMap_CoeffMap',this.id_);
end

function result = numCoeffs(this)
result = MParT_('ConditionalMap_numCoeffs',this.id_);
end
Expand Down
111 changes: 111 additions & 0 deletions bindings/matlab/mat/MultiIndex.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
classdef MultiIndex < handle
%DATABASE Example usage of the mexplus development kit.
%
% This class definition gives an interface to the underlying MEX functions
% built in the private directory. It is a good practice to wrap MEX functions
% with Matlab script so that the API is well documented and separated from
% its C++ implementation. Also such a wrapper is a good place to validate
% input arguments.
%
% Build
% -----
%
% make
%
% See `make.m` for details.
%

properties (Access = private)
id_
end

methods
function this = MultiIndex(varargin)
if(nargin==2)
if(isstring(varargin{2}))
if(varargin{2}=="id")
this.id_ = varargin{1};
end
else
this.id_ = MParT_('MultiIndex_newDefault', varargin{1},varargin{2});
end
else
this.id_ = MParT_('MultiIndex_newEigen',varargin{1});
end
end

function delete(this)
%DELETE Destructor.
MParT_('MultiIndex_delete', this.id_);
end

function result = Vector(this)
result = MParT_('MultiIndex_Vector', this.id_);
end

function result = Sum(this)
result = MParT_('MultiIndex_Sum', this.id_);
end

function result = Max(this)
result = MParT_('MultiIndex_Max', this.id_);
end

function result = Set(this,ind,val)
result = MParT_('MultiIndex_Set', this.id_,ind-1,val);
end

function result = Get(this,ind)
result = MParT_('MultiIndex_Get', this.id_,ind-1);
end

function result = NumNz(this)
result = MParT_('MultiIndex_NumNz', this.id_);
end

function result = String(this)
result = MParT_('MultiIndex_String', this.id_);
end

function result = Length(this)
result = MParT_('MultiIndex_Length', this.id_);
end

% == operator
function result = eq(this,multi)
result = MParT_('MultiIndex_Eq',this.id_,multi.get_id());
end

% ~= operator
function result = ne(this,multi)
result = MParT_('MultiIndex_Ne',this.id_,multi.get_id());
end

% < operator
function result = lt(this,multi)
result = MParT_('MultiIndex_Lt',this.id_,multi.get_id());
end

% > operator
function result = gt(this,multi)
result = MParT_('MultiIndex_Gt',this.id_,multi.get_id());
end

% >= operator
function result = Ge(this,multi)
result = MParT_('MultiIndex_Ge',this.id_,multi.get_id());
end

% <= operator
function result = Le(this,multi)
result = MParT_('MultiIndex_Le',this.id_,multi.get_id());
end

function result = get_id(this)
result = this.id_;
end


end

end
Loading

0 comments on commit 96eae14

Please sign in to comment.