Skip to content

Commit

Permalink
Fix windows error
Browse files Browse the repository at this point in the history
  • Loading branch information
bili2002 committed Sep 18, 2024
1 parent d54b084 commit 1b15a1d
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,17 @@ std::enable_if_t<
return dst;
}

template <typename T>
std::conditional_t<sizeof(T) == sizeof(uint32_t), uint32_t, uint64_t> bit_cast_int(T val) {
if constexpr (sizeof(T) == sizeof(uint32_t)) {
return bit_cast<uint32_t>(val);
}
else if constexpr (sizeof(T) == sizeof(uint64_t)) {
return bit_cast<uint64_t>(val);
}
static_assert(sizeof(T) == sizeof(uint32_t) || sizeof(T) == sizeof(uint64_t));
}

template <typename InputType, typename ThresholdType, typename OutputType>
Status TreeEnsembleCommon<InputType, ThresholdType, OutputType>::Init(const OpKernelInfo& info) {
std::vector<ThresholdType> base_values_as_tensor, nodes_hitrates_as_tensor,
Expand Down Expand Up @@ -376,11 +387,8 @@ bool TreeEnsembleCommon<InputType, ThresholdType, OutputType>::CheckIfSubtreesAr
}

if (cmodes[left_id] == NODE_MODE::LEAF) {
const auto left_tree_node = node_tree_ids[left_id];
const auto left_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(left_tree_node, uint32_t(0)))->second;

const auto right_tree_node = node_tree_ids[right_id];
const auto right_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(right_tree_node, uint32_t(0)))->second;
const auto left_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[left_id], uint32_t(0)))->second;
const auto right_target_node = std::lower_bound(indices.begin(), indices.end(), std::make_pair(node_tree_ids[right_id], uint32_t(0)))->second;

if (target_class_weights_as_tensor.empty()) {
return target_class_weights[left_target_node] == target_class_weights[right_target_node];
Expand Down Expand Up @@ -439,7 +447,7 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
}

node.value_or_unique_weight = 0;
const auto node_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
const ThresholdType node_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[i]) : nodes_values_as_tensor[i];
if (node.flags == NODE_MODE::BRANCH_EQ && CANMASK(node_threshold, ThresholdType)) {
UpdateThreshold(node_threshold, node.value_or_unique_weight);
node.flags = NODE_MODE::BRANCH_MEMBER;
Expand All @@ -452,7 +460,7 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
}
nodes_.push_back(std::move(node));
if (nodes_[node_pos].is_not_leaf()) {
auto falsenode_id = falsenode_ids[i];
size_t falsenode_id = falsenode_ids[i];

// Categoricals are represented as a chain of `EQ` nodes where the subtree for the true child is identical for all nodes in the chain
// Below we are folding together these nodes into one of mode `BRANCH_MEMBER`
Expand All @@ -461,7 +469,7 @@ size_t TreeEnsembleCommon<InputType, ThresholdType, OutputType>::AddNodes(
// and the one of the feature (the mask has only one bit set on the place for its value)
// Beware that if a category is bigger than the threshold type, the node stays as `EQ` and no combination is done
if (nodes_[node_pos].flags == NODE_MODE::BRANCH_MEMBER) {
auto falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];
ThresholdType falsenode_threshold = nodes_values_as_tensor.empty() ? static_cast<ThresholdType>(node_values[falsenode_id]) : nodes_values_as_tensor[falsenode_id];

while (cmodes[falsenode_id] == NODE_MODE::BRANCH_EQ && nodes_[node_pos].feature_id == nodes_featureids[falsenode_id] &&
CANMASK(falsenode_threshold, ThresholdType) &&
Expand Down Expand Up @@ -783,16 +791,15 @@ void TreeEnsembleCommon<InputType, ThresholdType, OutputType>::ComputeAgg(concur
} \
}



// Check whether the feature value is set true in the mask
inline bool SetMembershipCheck(double val, double mask) {
const auto val_as_int = static_cast<int64_t>(val);
return CANMASK(val_as_int, double) && (((1ll << (val_as_int - 1)) & bit_cast<uint64_t>(mask)) != 0);
template <typename T1, typename T2>
inline bool SetMembershipCheck(T1 val, T2 mask) {
const int64_t val_as_int = static_cast<int64_t>(val);
return CANMASK(val, T2) && (((1ll << (val_as_int - 1)) & bit_cast_int(mask)) != 0);
}

inline bool SetMembershipCheck(float val, float mask) {
const auto val_as_int = static_cast<int64_t>(val);
return CANMASK(val_as_int, float) && (((1ll << (val_as_int - 1)) & bit_cast<uint32_t>(mask)) != 0);
}

inline bool _isnan_(float x) { return std::isnan(x); }
inline bool _isnan_(double x) { return std::isnan(x); }
Expand Down

0 comments on commit 1b15a1d

Please sign in to comment.