-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add convnd, conv1d, expand, refactor conv2d, pad (#290)
* initial conv1d support * generalize conv1d to convnd * reshape weight * update * fix convnd * remove old conv2d tests * add nev conv2d eval tests * remove unused pad code * temprorarily skip shape_expand_dims when shape and axes are clipped index * temprorarily disable some conv1d and conv2d cases when on utl * temporarily disable conv tests on functional and constexpr suite * temporarily skp some graph functional tests * add eager expand * reduce some conv1d testing precision on utl build * skip expand test on utl build * temporarily disable pad test on gpu * add maybe_unused attributes for gcc werror * temporarily disable mbed-platformio ci * sprinkle more [[maybe_unused]]
- Loading branch information
Showing
166 changed files
with
11,140 additions
and
6,684 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ on: | |
|
||
jobs: | ||
build: | ||
|
||
if: ${{ false }} | ||
runs-on: ubuntu-latest | ||
|
||
strategy: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#ifndef NMTOOLS_ARRAY_ARRAY_CONV1D_HPP | ||
#define NMTOOLS_ARRAY_ARRAY_CONV1D_HPP | ||
|
||
#include "nmtools/array/view/conv1d.hpp" | ||
#include "nmtools/array/eval.hpp" | ||
|
||
namespace nmtools::array | ||
{ | ||
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<> | ||
, typename input_t, typename weight_t, typename bias_t=none_t | ||
, typename stride_t=none_t, typename padding_t=none_t, typename dilation_t=none_t, typename groups_t=meta::ct<1>> | ||
constexpr auto conv1d(const input_t& input, const weight_t& weight, const bias_t& bias=bias_t{} | ||
, const stride_t& stride=stride_t{}, const padding_t& padding=padding_t{}, const dilation_t& dilation=dilation_t{}, groups_t groups=groups_t{} | ||
, context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>) | ||
{ | ||
auto result = view::conv1d(input,weight,bias,stride,padding,dilation,groups); | ||
return eval(result | ||
, nmtools::forward<context_t>(context) | ||
, nmtools::forward<output_t>(output) | ||
, resolver | ||
); | ||
} | ||
} | ||
|
||
#endif // NMTOOLS_ARRAY_ARRAY_CONV1D_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#ifndef NMTOOLS_ARRAY_ARRAY_CONV2D_HPP | ||
#define NMTOOLS_ARRAY_ARRAY_CONV2D_HPP | ||
|
||
#include "nmtools/array/view/conv2d.hpp" | ||
#include "nmtools/array/eval.hpp" | ||
|
||
namespace nmtools::array | ||
{ | ||
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<> | ||
, typename input_t, typename weight_t, typename bias_t=none_t | ||
, typename stride_t=none_t, typename padding_t=none_t, typename dilation_t=none_t, typename groups_t=meta::ct<1>> | ||
constexpr auto conv2dv2(const input_t& input, const weight_t& weight, const bias_t& bias=bias_t{} | ||
, const stride_t& stride=stride_t{}, const padding_t& padding=padding_t{}, const dilation_t& dilation=dilation_t{}, groups_t groups=groups_t{} | ||
, context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>) | ||
{ | ||
auto result = view::conv2dv2(input,weight,bias,stride,padding,dilation,groups); | ||
return eval(result | ||
, nmtools::forward<context_t>(context) | ||
, nmtools::forward<output_t>(output) | ||
, resolver | ||
); | ||
} | ||
} | ||
|
||
#endif // NMTOOLS_ARRAY_ARRAY_CONV2D_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
#ifndef NMTOOLS_ARRAY_ARRAY_EXPAND_HPP | ||
#define NMTOOLS_ARRAY_ARRAY_EXPAND_HPP | ||
|
||
#include "nmtools/array/view/expand.hpp" | ||
#include "nmtools/array/eval.hpp" | ||
|
||
namespace nmtools::array | ||
{ | ||
/** | ||
* @brief Eagerly expand the contents of an array. | ||
* | ||
* @tparam output_t | ||
* @tparam context_t | ||
* @tparam array_t | ||
* @tparam axis_t | ||
* @param array input array | ||
* @param axis position in the expanded axes where the new axis (or axes) is placed. | ||
* @param context evaluation context | ||
* @param output | ||
* @return constexpr auto | ||
*/ | ||
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<>, | ||
typename array_t, typename axis_t, typename spacing_t=nm_index_t, typename fill_value_t=nm_index_t> | ||
constexpr auto expand(const array_t& array, const axis_t& axis, const spacing_t& spacing=spacing_t{1}, fill_value_t fill_value=fill_value_t{0}, | ||
context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>) | ||
{ | ||
auto expanded = view::expand(array,axis,spacing,fill_value); | ||
return eval(expanded | ||
,nmtools::forward<context_t>(context) | ||
,nmtools::forward<output_t>(output) | ||
,resolver | ||
); | ||
} // expand | ||
} // namespace nmtools::array | ||
|
||
#endif // NMTOOLS_ARRAY_ARRAY_EXPAND_HPP |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.