Skip to content

Commit

Permalink
fix merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Oct 7, 2024
2 parents 8d0bc46 + 11c3d38 commit 125eedc
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 25 deletions.
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cpu/ml/ml_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ enum class NODE_MODE_V5 : uint8_t {
BRANCH_NEQ = 5,
BRANCH_MEMBER = 6,
LEAF = 7
BRANCH_MEMBER = 14
};

static inline NODE_MODE MakeTreeNodeMode(const std::string& input) {
Expand Down
176 changes: 151 additions & 25 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,48 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes {
void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const;

private:
bool CheckIfSubtreesAreEqual(const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE>& cmodes,
const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<float>& target_class_weights, const std::vector<ThresholdType>& target_class_weights_as_tensor,
const InlinedVector<TreeNodeElementId>& node_tree_ids, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices);
size_t AddNodes(const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping,
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids);
int64_t tree_id, const InlinedVector<TreeNodeElementId>& node_tree_ids, const std::vector<float>& target_class_weights,
const std::vector<ThresholdType>& target_class_weights_as_tensor, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices);
};

// C++17 stand-in for C++20's `std::bit_cast`: copies the object representation of `src`
// into a freshly created `To` of the same size.
// Delete this helper once the minimum supported language standard provides `std::bit_cast`.
//
// Participates in overload resolution only when `To` and `From` have the same size and
// are both trivially copyable. Unlike the standard version it is not constexpr (that
// needs compiler support) and it additionally requires `To` to be trivially constructible.
template <class To, class From>
static auto bit_cast(const From& src) noexcept
    -> std::enable_if_t<
        sizeof(To) == sizeof(From) &&
            std::is_trivially_copyable_v<From> &&
            std::is_trivially_copyable_v<To>,
        To> {
  static_assert(std::is_trivially_constructible_v<To>,
                "This implementation additionally requires "
                "destination type to be trivially constructible");

  To result;
  // memcpy is the portable way to reinterpret bytes without violating aliasing rules.
  std::memcpy(&result, &src, sizeof(result));
  return result;
}

// Returns the bit pattern of `val` as an unsigned integer of the same width:
// uint32_t for 4-byte types, uint64_t for 8-byte types. Any other size is rejected
// at compile time.
template <typename T>
std::conditional_t<sizeof(T) == sizeof(uint32_t), uint32_t, uint64_t> bit_cast_int(T val) {
  static_assert(sizeof(T) == sizeof(uint32_t) || sizeof(T) == sizeof(uint64_t));
  using Bits = std::conditional_t<sizeof(T) == sizeof(uint32_t), uint32_t, uint64_t>;
  return bit_cast<Bits>(val);
}

template <typename InputType, typename ThresholdType, typename OutputType>
Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(const OpKernelInfo& info) {
std::vector<ThresholdType> base_values_as_tensor, nodes_hitrates_as_tensor,
Expand Down Expand Up @@ -270,6 +305,16 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
}
}

// Sort targets
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
indices.reserve(target_class_nodeids.size());
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
indices.emplace_back(
TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i);
}

std::sort(indices.begin(), indices.end());

// Let's construct nodes_ such that the false branch is always the next element in nodes_.
// updated_mapping translates the old position of each node to its new position in nodes_.
std::vector<size_t> updated_mapping(nodes_treeids.size(), 0);
Expand All @@ -280,26 +325,13 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
int64_t tree_id = node_tree_ids[i].tree_id;
size_t root_position =
AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values,
nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
target_class_weights, target_class_weights_as_tensor, indices);
roots_.push_back(&nodes_[root_position]);
previous_tree_id = tree_id;
}
}

n_trees_ = roots_.size();
if (((int64_t)nodes_.size()) != n_nodes_) {
ORT_THROW("Number of nodes in nodes_ (", nodes_.size(), ") is different from n_nodes (", n_nodes_, ").");
}

// Sort targets
InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices;
indices.reserve(target_class_nodeids.size());
for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) {
indices.emplace_back(
std::pair<TreeNodeElementId, uint32_t>(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i));
}

std::sort(indices.begin(), indices.end());

TreeNodeElementId ind;
SparseValue<ThresholdType> w;
Expand Down Expand Up @@ -341,13 +373,56 @@ Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(
return Status::OK();
}

// Returns true when the subtrees rooted at `left_id` and `right_id` are structurally
// identical: same node modes, same feature ids, same thresholds and, for leaves, the same
// sequence of target weights.
//
// Parameters (all indexed by the original node position unless stated otherwise):
//   left_id, right_id    roots of the two subtrees to compare.
//   tree_id              unused; kept so the signature matches the declaration.
//   cmodes               node modes.
//   truenode_ids,
//   falsenode_ids        positions of the true/false children.
//   nodes_featureids     feature tested by each node.
//   nodes_values_as_tensor / node_values
//                        thresholds; the tensor variant is used when it is not empty.
//   target_class_weights_as_tensor / target_class_weights
//                        leaf weights; the tensor variant is used when it is not empty.
//   node_tree_ids        (tree id, node id) key of each node.
//   indices              (leaf key, index into the weight arrays), sorted ascending.
//
// A leaf that has no entry in `indices` is never reported as equal to a different leaf,
// and leaves carrying several weights must agree on all of them.
//
// NOTE(review): target ids are not available here, so two leaves with the same weights but
// different target ids still compare equal - confirm whether that can happen in practice.
// NOTE(review): `indices` is taken by value, so every external call copies it; changing the
// declaration and this definition to `const&` would remove that copy.
template <typename InputType, typename ThresholdType, typename OutputType>
bool TreeEnsembleCommon<InputType, ThresholdType, OutputType>::CheckIfSubtreesAreEqual(
    const size_t left_id, const size_t right_id, const int64_t tree_id, const InlinedVector<NODE_MODE>& cmodes,
    const InlinedVector<size_t>& truenode_ids, const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
    const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
    const std::vector<float>& target_class_weights, const std::vector<ThresholdType>& target_class_weights_as_tensor,
    const InlinedVector<TreeNodeElementId>& node_tree_ids, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices) {
  static_cast<void>(tree_id);

  const bool thresholds_in_tensor = !nodes_values_as_tensor.empty();
  const bool weights_in_tensor = !target_class_weights_as_tensor.empty();

  // First entry of `indices` whose key is not smaller than `key`.
  const auto first_target = [&](const TreeNodeElementId& key) {
    return std::lower_bound(indices.begin(), indices.end(), std::make_pair(key, uint32_t(0)));
  };

  // True when `it` is a valid entry that belongs to `key`. Iterators handed to this helper
  // never point before the lower bound of `key`, so "not greater" means "equal".
  const auto belongs_to = [&](auto it, const TreeNodeElementId& key) {
    return it != indices.end() && !(key < it->first);
  };

  const auto leaves_equal = [&](size_t left, size_t right) -> bool {
    const TreeNodeElementId& left_key = node_tree_ids[left];
    const TreeNodeElementId& right_key = node_tree_ids[right];
    auto left_it = first_target(left_key);
    auto right_it = first_target(right_key);

    // A leaf without any weight cannot be proven equal to another leaf.
    if (!belongs_to(left_it, left_key) || !belongs_to(right_it, right_key)) {
      return false;
    }

    while (true) {
      const bool same_weight =
          weights_in_tensor
              ? target_class_weights_as_tensor[left_it->second] == target_class_weights_as_tensor[right_it->second]
              : target_class_weights[left_it->second] == target_class_weights[right_it->second];
      if (!same_weight) {
        return false;
      }

      ++left_it;
      ++right_it;
      const bool left_has_more = belongs_to(left_it, left_key);
      const bool right_has_more = belongs_to(right_it, right_key);
      if (!left_has_more || !right_has_more) {
        // Equal only if both leaves ran out of weights at the same time.
        return left_has_more == right_has_more;
      }
    }
  };

  // The recursion lives in a lambda so that `indices` is shared by reference instead of
  // being copied at every level.
  const auto subtrees_equal = [&](const auto& self, size_t left, size_t right) -> bool {
    // A subtree is trivially equal to itself (chains usually share one true child).
    if (left == right) {
      return true;
    }

    if (cmodes[left] != cmodes[right] || nodes_featureids[left] != nodes_featureids[right]) {
      return false;
    }

    // Leaves have values set at 0, so this comparison is harmless for them.
    const bool same_threshold = thresholds_in_tensor
                                    ? nodes_values_as_tensor[left] == nodes_values_as_tensor[right]
                                    : node_values[left] == node_values[right];
    if (!same_threshold) {
      return false;
    }

    if (cmodes[left] == NODE_MODE::LEAF) {
      return leaves_equal(left, right);
    }

    return self(self, falsenode_ids[left], falsenode_ids[right]) &&
           self(self, truenode_ids[left], truenode_ids[right]);
  };

  return subtrees_equal(subtrees_equal, left_id, right_id);
}

// Sets bit (val - 1) in the bit pattern stored in `mask`, which is a 64-bit bitmask kept
// inside a double.
// `val` must be an integral value in [1, 64]; the callers guard this with CANMASK.
// The bit is built from an unsigned 64-bit one so that val == 64 does not shift a signed
// value into its sign bit.
inline void UpdateThreshold(double val, double& mask) {
  const uint64_t bit = uint64_t{1} << (static_cast<uint32_t>(val) - 1);
  mask = bit_cast<double>(bit_cast<uint64_t>(mask) | bit);
}

// Sets bit (val - 1) in the bit pattern stored in `mask`, which is a 32-bit bitmask kept
// inside a float.
// `val` must be an integral value in [1, 32]; the callers guard this with CANMASK.
// The bit is built from an unsigned 32-bit one so that val == 32 does not shift a signed
// int into its sign bit.
inline void UpdateThreshold(float val, float& mask) {
  const uint32_t bit = uint32_t{1} << (static_cast<uint32_t>(val) - 1);
  mask = bit_cast<float>(bit_cast<uint32_t>(mask) | bit);
}

// Number of bits in type T, as an int64_t.
#define BITCOUNT(T) (int64_t(sizeof(T) * 8))
// True when value `v` can be stored as bit (v - 1) of a mask as wide as T, i.e. `v` is an
// integral value in [1, BITCOUNT(T)].
// The expansion is fully parenthesized so it can be negated or combined with other
// operators safely. `v` is evaluated several times: do not pass expressions with side effects.
#define CANMASK(v, T) (((v) >= 1) && ((v) <= BITCOUNT(T)) && ((v) == std::floor(v)))

template <typename InputType, typename ThresholdType, typename OutputType>
size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
const size_t i, const InlinedVector<NODE_MODE>& cmodes, const InlinedVector<size_t>& truenode_ids,
const InlinedVector<size_t>& falsenode_ids, const std::vector<int64_t>& nodes_featureids,
const std::vector<ThresholdType>& nodes_values_as_tensor, const std::vector<float>& node_values,
const std::vector<int64_t>& nodes_missing_value_tracks_true, std::vector<size_t>& updated_mapping, int64_t tree_id,
const InlinedVector<TreeNodeElementId>& node_tree_ids) {
const InlinedVector<TreeNodeElementId>& node_tree_ids, const std::vector<float>& target_class_weights,
const std::vector<ThresholdType>& target_class_weights_as_tensor, InlinedVector<std::pair<TreeNodeElementId, uint32_t>> indices) {
// Validate this index maps to the same tree_id as the one we should be building.
if (node_tree_ids[i].tree_id != tree_id) {
ORT_THROW("Tree id mismatch. Expected ", tree_id, " but got ", node_tree_ids[i].tree_id, " at position ", i);
Expand All @@ -369,23 +444,54 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
if (node.feature_id > max_feature_id_) {
max_feature_id_ = node.feature_id;
}
node.value_or_unique_weight =
nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];

node.value_or_unique_weight = 0;
const ThresholdType node_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
if (node.flags == NODE_MODE::BRANCH_EQ && CANMASK(node_threshold, ThresholdType)) {
UpdateThreshold(node_threshold, node.value_or_unique_weight);
node.flags = NODE_MODE::BRANCH_MEMBER;
} else {
node.value_or_unique_weight = node_threshold;
}

if (i < static_cast<size_t>(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) {
node.flags |= static_cast<uint8_t>(MissingTrack::kTrue);
}
nodes_.push_back(std::move(node));
if (nodes_[node_pos].is_not_leaf()) {
size_t falsenode_id = falsenode_ids[i];

// Categorical features are represented as a chain of `EQ` nodes on the same feature where the subtree of the true child is identical for all nodes in the chain.
// Below we fold these nodes together into a single node of mode `BRANCH_MEMBER`.
// The threshold of that node must be interpreted as a bitmask recording which categorical values were found in the chain.
// Afterwards, to check whether a feature value belongs to the set, we can do an `and` between the mask of the node
// and the mask of the feature value (which has a single bit set, at the position of its value).
// Beware that if a category does not fit in the number of bits of the threshold type, the node stays as `EQ` and no folding is done.
if (nodes_[node_pos].flags == NODE_MODE::BRANCH_MEMBER) {
ThresholdType falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];

while (cmodes[falsenode_id] == NODE_MODE::BRANCH_EQ && nodes_[node_pos].feature_id == nodes_featureids[falsenode_id] &&
CANMASK(falsenode_threshold, ThresholdType) &&
CheckIfSubtreesAreEqual(truenode_ids[i], truenode_ids[falsenode_id], tree_id, cmodes, truenode_ids, falsenode_ids,
nodes_featureids, nodes_values_as_tensor, node_values, target_class_weights, target_class_weights_as_tensor, node_tree_ids, indices)) {
UpdateThreshold(falsenode_threshold, nodes_[node_pos].value_or_unique_weight);
falsenode_id = falsenode_ids[falsenode_id];
falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];
}
}

size_t false_branch =
AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
AddNodes(falsenode_id, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
target_class_weights, target_class_weights_as_tensor, indices);
if (false_branch != node_pos + 1) {
ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ",
static_cast<int>(nodes_[node_pos].flags));
}
size_t true_branch =
AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor,
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids);
node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids,
target_class_weights, target_class_weights_as_tensor, indices);
// We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_.
// nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch];
nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch];
Expand Down Expand Up @@ -684,6 +790,13 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
} \
}

// Checks whether the feature value `val` is a member of the set encoded in `mask`.
// `mask` is a bitmask stored in a floating-point variable: bit (v - 1) set means that
// value v belongs to the set.
// Returns false for every value that cannot be represented in the mask (NaN, infinities,
// non-integral values, values outside [1, number of bits of T2]).
// The range check runs before any float-to-integer conversion because converting NaN or an
// out-of-range floating-point value to an integer is undefined behaviour.
template <typename T1, typename T2>
inline bool SetMembershipCheck(T1 val, T2 mask) {
  if (!(CANMASK(val, T2))) {
    return false;
  }
  // val is now an integral value in [1, 64], so the conversion and the shift are safe.
  const uint64_t bit = uint64_t{1} << (static_cast<uint32_t>(val) - 1);
  return (bit & bit_cast_int(mask)) != 0;
}

// NaN test usable with every supported input type: floating-point inputs defer to
// std::isnan, integer inputs can never be NaN.
inline bool _isnan_(float x) {
  return std::isnan(x);
}

inline bool _isnan_(double x) {
  return std::isnan(x);
}

inline bool _isnan_(int64_t /*x*/) {
  return false;
}
Expand Down Expand Up @@ -727,8 +840,19 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
TREE_FIND_VALUE(!=)
break;
case NODE_MODE::BRANCH_MEMBER:
ORT_THROW("NODE_MODE::BRANCH_MEMBER is not implemented.");
break;
if (has_missing_tracks_) {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val)))
? root->truenode_or_weight.ptr
: root + 1;
}
} else {
while (root->is_not_leaf()) {
val = x_data[root->feature_id];
root = SetMembershipCheck(val, root->value_or_unique_weight) ? root->truenode_or_weight.ptr : root + 1;
}
}
case NODE_MODE::LEAF:
break;
}
Expand Down Expand Up @@ -763,7 +887,9 @@ TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ProcessTreeNodeLeave(
: root + 1;
break;
case NODE_MODE::BRANCH_MEMBER:
ORT_THROW("NODE_MODE::BRANCH_MEMBER is not implemented.");
root = (SetMembershipCheck(val, root->value_or_unique_weight) || (root->is_missing_track_true() && _isnan_(val)))
? root->truenode_or_weight.ptr
: root + 1;
break;
case NODE_MODE::LEAF:
return root;
Expand Down
84 changes: 84 additions & 0 deletions onnxruntime/test/providers/cpu/ml/treeregressor_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,90 @@ TEST(MLOpTest, TreeRegressorSingleTargetSum_as_tensor_precision) {
GenTreeAndRunTest1_as_tensor_precision(3);
}

// A chain of BRANCH_EQ nodes on feature 0 (values 1, 3, 4) that all share the same true
// subtree, i.e. the layout that the kernel folds into a single set-membership node.
//
//   node 0: x0 == 1 ? node 4 : node 1
//   node 1: x0 == 3 ? node 4 : node 2
//   node 2: x0 == 4 ? node 4 : node 3
//   node 3: leaf, weight -4.7
//   node 4: x1 <= 5.5 ? node 6 : node 5
//   node 5: leaf, weight 17.7
//   node 6: leaf, weight 11.1
TEST(MLOpTest, TreeRegressorCategoricals) {
  OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);

  // tree (child ids of leaf nodes are placeholders)
  int64_t n_targets = 1;
  std::vector<int64_t> nodes_featureids = {0, 0, 0, 0, 1, 0, 0};
  std::vector<std::string> nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "BRANCH_LEQ", "LEAF", "LEAF"};
  std::vector<float> nodes_values = {1.0f, 3.0f, 4.0f, 0.0f, 5.5f, 0.0f, 0.0f};

  std::vector<int64_t> nodes_treeids = {0, 0, 0, 0, 0, 0, 0};
  std::vector<int64_t> nodes_nodeids = {0, 1, 2, 3, 4, 5, 6};
  std::vector<int64_t> nodes_falsenodeids = {1, 2, 3, 0, 5, 0, 0};
  std::vector<int64_t> nodes_truenodeids = {4, 4, 4, 0, 6, 0, 0};

  // leaf weights, one per leaf node (3, 5, 6)
  std::vector<int64_t> target_ids = {0, 0, 0};
  std::vector<int64_t> target_nodeids = {3, 5, 6};
  std::vector<int64_t> target_treeids = {0, 0, 0};
  std::vector<float> target_weights = {-4.699999809265137f, 17.700000762939453f, 11.100000381469727f};

  // add attributes
  test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
  test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
  test.AddAttribute("nodes_treeids", nodes_treeids);
  test.AddAttribute("nodes_nodeids", nodes_nodeids);
  test.AddAttribute("nodes_featureids", nodes_featureids);
  test.AddAttribute("nodes_values", nodes_values);
  test.AddAttribute("nodes_modes", nodes_modes);
  test.AddAttribute("target_treeids", target_treeids);
  test.AddAttribute("target_nodeids", target_nodeids);
  test.AddAttribute("target_ids", target_ids);
  test.AddAttribute("target_weights", target_weights);
  test.AddAttribute("n_targets", n_targets);

  // fill input data, one row per sample:
  //   (3.0, 6.6): x0 in {1, 3, 4}, x1 >  5.5 -> node 5 -> 17.7
  //   (1.0, 5.0): x0 in {1, 3, 4}, x1 <= 5.5 -> node 6 -> 11.1
  //   (5.0, 5.5): x0 not in the set          -> node 3 -> -4.7
  std::vector<float> X = {3.0f, 6.6f, 1.0f, 5.0f, 5.0f, 5.5f};
  std::vector<float> Y = {17.700000762939453f, 11.100000381469727f, -4.699999809265137f};
  test.AddInput<float>("X", {3, 2}, X);
  test.AddOutput<float>("Y", {3, 1}, Y);
  test.Run();
}

// Two consecutive BRANCH_EQ chains on different features: nodes 0-1 test feature 0 and
// share leaf 5 as true child, nodes 2-3 test feature 1 and share leaf 6 as true child.
// Each chain can be folded on its own, but the two chains must not be merged together.
//
//   node 0: x0 == 1 ? node 5 : node 1
//   node 1: x0 == 3 ? node 5 : node 2
//   node 2: x1 == 2 ? node 6 : node 3
//   node 3: x1 == 3 ? node 6 : node 4
//   node 4: leaf, weight 17.7
//   node 5: leaf, weight 11.1
//   node 6: leaf, weight -4.7
TEST(MLOpTest, TreeRegressorCategoricalsFolding) {
  OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);

  // tree (child ids of leaf nodes are placeholders)
  int64_t n_targets = 1;
  std::vector<int64_t> nodes_featureids = {0, 0, 1, 1, 0, 0, 0};
  std::vector<std::string> nodes_modes = {"BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "BRANCH_EQ", "LEAF", "LEAF", "LEAF"};
  std::vector<float> nodes_values = {1.0f, 3.0f, 2.0f, 3.0f, 0.0f, 0.0f, 0.0f};

  std::vector<int64_t> nodes_treeids = {0, 0, 0, 0, 0, 0, 0};
  std::vector<int64_t> nodes_nodeids = {0, 1, 2, 3, 4, 5, 6};
  std::vector<int64_t> nodes_falsenodeids = {1, 2, 3, 4, 0, 0, 0};
  std::vector<int64_t> nodes_truenodeids = {5, 5, 6, 6, 0, 0, 0};

  // leaf weights, one per leaf node (4, 5, 6)
  std::vector<int64_t> target_ids = {0, 0, 0};
  std::vector<int64_t> target_nodeids = {4, 5, 6};
  std::vector<int64_t> target_treeids = {0, 0, 0};
  std::vector<float> target_weights = {17.700000762939453f, 11.100000381469727f, -4.699999809265137f};

  // add attributes
  test.AddAttribute("nodes_truenodeids", nodes_truenodeids);
  test.AddAttribute("nodes_falsenodeids", nodes_falsenodeids);
  test.AddAttribute("nodes_treeids", nodes_treeids);
  test.AddAttribute("nodes_nodeids", nodes_nodeids);
  test.AddAttribute("nodes_featureids", nodes_featureids);
  test.AddAttribute("nodes_values", nodes_values);
  test.AddAttribute("nodes_modes", nodes_modes);
  test.AddAttribute("target_treeids", target_treeids);
  test.AddAttribute("target_nodeids", target_nodeids);
  test.AddAttribute("target_ids", target_ids);
  test.AddAttribute("target_weights", target_weights);
  test.AddAttribute("n_targets", n_targets);

  // fill input data, one row per sample:
  //   (1, 2): x0 in {1, 3}                     -> node 5 -> 11.1
  //   (3, 1): x0 in {1, 3}                     -> node 5 -> 11.1
  //   (2, 3): x0 not in {1, 3}, x1 in {2, 3}   -> node 6 -> -4.7
  //   (2, 1): x0 not in {1, 3}, x1 not in {2, 3} -> node 4 -> 17.7
  std::vector<float> X = {1.0f, 2.0f, 3.0f, 1.0f, 2.0f, 3.0f, 2.0f, 1.0f};
  std::vector<float> Y = {11.100000381469727f, 11.100000381469727f, -4.699999809265137f, 17.700000762939453f};
  test.AddInput<float>("X", {4, 2}, X);
  test.AddOutput<float>("Y", {4, 1}, Y);
  test.Run();
}

TEST(MLOpTest, TreeRegressorTrueNodeBeforeNode) {
OpTester test("TreeEnsembleRegressor", 3, onnxruntime::kMLDomain);

Expand Down

0 comments on commit 125eedc

Please sign in to comment.