Skip to content

Commit

Permalink
implement balanced tree reduce for xilinxhls backend
Browse files Browse the repository at this point in the history
  • Loading branch information
francescobrivio authored and thesps committed May 13, 2024
1 parent 271401c commit a36ccde
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 14 deletions.
52 changes: 41 additions & 11 deletions conifer/backends/xilinxhls/firmware/BDT_unrolled.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,37 @@

namespace BDT{

/* ---
* Balanced tree reduce implementation.
* Reduces an array of inputs to a single value using the template binary operator 'Op',
* for example summing all elements with OpAdd, or finding the maximum with OpMax
* Use only when the input array is fully unrolled. Or, slice out a fully unrolled section
* before applying and accumulate the result over the rolled dimension.
* Required for emulation to guarantee equality of ordering.
* --- */
constexpr int floorlog2(int x) { return (x < 2) ? 0 : 1 + floorlog2(x / 2); }

constexpr int pow2(int x) { return x == 0 ? 1 : 2 * pow2(x - 1); }

template <class T, int N, class Op> T reduce(const T *x, Op op) {
static constexpr int leftN = pow2(floorlog2(N - 1)) > 0 ? pow2(floorlog2(N - 1)) : 0;
static constexpr int rightN = N - leftN > 0 ? N - leftN : 0;
if (N == 1) {
return x[0];
}
if (N == 2) {
return op(x[0], x[1]);
}
return op(reduce<T, leftN, Op>(x, op), reduce<T, rightN, Op>(x + leftN, op));
}

template <class T> class OpAdd {
public:
T operator()(T a, T b) { return a + b; }
};

// Number of trees given number of classes
constexpr int fn_classes(int n_classes){
// Number of trees given number of classes
return n_classes == 2 ? 1 : n_classes;
}

Expand Down Expand Up @@ -99,23 +128,24 @@ struct BDT{
public:
score_t normalisation;
score_t init_predict[fn_classes(n_classes)];
OpAdd<score_t> op_add;

void tree_scores(input_t x, score_t scores[n_trees][fn_classes(n_classes)]) const;
void tree_scores(input_t x, score_t scores[fn_classes(n_classes)][n_trees]) const;

void decision_function(input_t x, score_t score[fn_classes(n_classes)]) const{
score_t scores[n_trees][fn_classes(n_classes)];
score_t scores[fn_classes(n_classes)][n_trees];
#pragma HLS ARRAY_PARTITION variable=scores dim=0
// Get predictions scores
tree_scores(x, scores);
// Reduce
Reduce:
for(int j = 0; j < fn_classes(n_classes); j++){
// Init predictions
score[j] = init_predict[j];
// Sum predictions from trees via "reduce" method
score[j] += reduce<score_t, n_trees, OpAdd<score_t>>(scores[j], op_add);
}
tree_scores(x, scores);
Trees:
for(int i = 0; i < n_trees; i++){
Classes:
for(int j = 0; j < fn_classes(n_classes); j++){
score[j] += scores[i][j];
}
}
// Normalize predictions
for(int j = 0; j < fn_classes(n_classes); j++){
score[j] *= normalisation;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#include "parameters.h"

template<>
void BDT::BDT<n_trees, n_classes, input_arr_t, score_t, threshold_t>::tree_scores(input_arr_t x, score_t scores[n_trees][fn_classes(n_classes)]) const {
void BDT::BDT<n_trees, n_classes, input_arr_t, score_t, threshold_t>::tree_scores(input_arr_t x, score_t scores[fn_classes(n_classes)][n_trees]) const {
// conifer insert tree_scores
}

4 changes: 2 additions & 2 deletions conifer/backends/xilinxhls/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def write_bdt_h(self):
newline = ''
for it, trees in enumerate(self.trees):
for ic, tree in enumerate(trees):
newline += f' scores[{it}][{ic}] = tree_{it}_{ic}.decision_function(x);\n'
newline += f' scores[{ic}][{it}] = tree_{ic}_{it}.decision_function(x);\n'
else:
newline = line
fout.write(newline)
Expand Down Expand Up @@ -227,7 +227,7 @@ def _write_parameters_h_unrolled(self, fout):
for iclass, tree in enumerate(trees):
fout.write(f'static const BDT::Tree<{itree*nc+iclass}, {tree.n_nodes()}, {tree.n_leaves()}')
fout.write(f', input_arr_t, score_t, threshold_t>')
fout.write(f' tree_{itree}_{iclass} = {{\n')
fout.write(f' tree_{iclass}_{itree} = {{\n')
# loop over fields
for ifield, field in enumerate(tree_fields):
newline = ' {'
Expand Down

0 comments on commit a36ccde

Please sign in to comment.