Skip to content

Commit

Permalink
Support categorical features with more than 64 categories
Browse files Browse the repository at this point in the history
  • Loading branch information
hcho3 committed Mar 8, 2018
1 parent 36dc714 commit 46a5d5e
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 46 deletions.
2 changes: 1 addition & 1 deletion include/treelite/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ TREELITE_DLL int TreeliteTreeBuilderSetNumericalTestNode(
TREELITE_DLL int TreeliteTreeBuilderSetCategoricalTestNode(
TreeBuilderHandle handle,
int node_key, unsigned feature_id,
const unsigned char* left_categories,
const unsigned int* left_categories,
size_t left_categories_len,
int default_left,
int left_child_key,
Expand Down
2 changes: 1 addition & 1 deletion include/treelite/frontend.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class TreeBuilder {
*/
bool SetCategoricalTestNode(int node_key,
unsigned feature_id,
const std::vector<uint8_t>& left_categories,
const std::vector<uint32_t>& left_categories,
bool default_left, int left_child_key,
int right_child_key);
/*!
Expand Down
9 changes: 5 additions & 4 deletions include/treelite/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class Tree {
return cmp_;
}
/*! \brief get categories for left child node */
inline const std::vector<uint8_t>& left_categories() const {
inline const std::vector<uint32_t>& left_categories() const {
return left_categories_;
}
/*! \brief get feature split type */
Expand Down Expand Up @@ -117,7 +117,7 @@ class Tree {
* threshold
*/
inline void set_categorical_split(unsigned split_index, bool default_left,
const std::vector<uint8_t>& left_categories) {
const std::vector<uint32_t>& left_categories) {
CHECK_LT(split_index, (1U << 31) - 1) << "split_index too big";
if (default_left) split_index |= (1U << 31);
this->sindex_ = split_index;
Expand Down Expand Up @@ -194,9 +194,10 @@ class Tree {
* \brief list of all categories belonging to the left node.
* Categories not in this list will belong to the right node.
* Categories are integers ranging from 0 to (n-1), where n is the number of
* categories in that particular feature. Let's assume n <= 64.
* categories in that particular feature.
* This list is assumed to be in ascending order.
*/
std::vector<uint8_t> left_categories_;
std::vector<uint32_t> left_categories_;
};

private:
Expand Down
3 changes: 1 addition & 2 deletions src/annotator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ void Traverse_(const treelite::Tree& tree, const Entry* data,
result = treelite::semantic::CompareWithOp(fvalue, op, threshold);
} else {
const auto fvalue = data[split_index].fvalue;
CHECK_LT(fvalue, 64) << "Cannot have more than 64 categories";
const uint8_t fvalue2 = static_cast<uint8_t>(fvalue);
const uint32_t fvalue2 = static_cast<uint32_t>(fvalue);
const auto left_categories = node.left_categories();
result = (std::binary_search(left_categories.begin(),
left_categories.end(), fvalue));
Expand Down
8 changes: 4 additions & 4 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -465,17 +465,17 @@ int TreeliteTreeBuilderSetNumericalTestNode(TreeBuilderHandle handle,
int TreeliteTreeBuilderSetCategoricalTestNode(
TreeBuilderHandle handle,
int node_key, unsigned feature_id,
const unsigned char* left_categories,
const unsigned int* left_categories,
size_t left_categories_len,
int default_left,
int left_child_key,
int right_child_key) {
API_BEGIN();
auto builder = static_cast<frontend::TreeBuilder*>(handle);
std::vector<uint8_t> vec(left_categories_len);
std::vector<uint32_t> vec(left_categories_len);
for (size_t i = 0; i < left_categories_len; ++i) {
CHECK(left_categories[i] <= std::numeric_limits<uint8_t>::max());
vec[i] = static_cast<uint8_t>(left_categories[i]);
CHECK(left_categories[i] <= std::numeric_limits<uint32_t>::max());
vec[i] = static_cast<uint32_t>(left_categories[i]);
}
return (builder->SetCategoricalTestNode(node_key, feature_id, vec,
(default_left != 0),
Expand Down
49 changes: 32 additions & 17 deletions src/compiler/recursive.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,27 +62,42 @@ class CategoricalSplitCondition : public treelite::semantic::Condition {
inline std::string Compile() const override {
const std::string bitmap
= std::string("data[") + std::to_string(split_index) + "].missing != -1";
const std::string comp
= std::string("((") + std::to_string(categorical_bitmap)
+ "U >> (unsigned int)(data[" + std::to_string(split_index)
+ "].fvalue)) & 1)";
return ((default_left) ? (std::string("!(") + bitmap + ") || ")
: (std::string(" (") + bitmap + ") && "))
+ ((categorical_bitmap == 0) ? std::string("0") : comp);
CHECK_GE(categorical_bitmap.size(), 1);
std::ostringstream comp;
comp << "(tmp = (unsigned int)(data[" << split_index << "].fvalue) ), "
<< "(tmp >= 0 && tmp < 64 && (( (uint64_t)"
<< categorical_bitmap[0] << "U >> tmp) & 1) )";
for (size_t i = 1; i < categorical_bitmap.size(); ++i) {
comp << " || (tmp >= " << (i * 64)
<< " && tmp < " << ((i + 1) * 64)
<< " && (( (uint64_t)" << categorical_bitmap[i]
<< "U >> (tmp - " << (i * 64) << ") ) & 1) )";
}
bool all_zeros = true;
for (uint64_t e : categorical_bitmap) {
all_zeros &= (e == 0);
}
return ((default_left) ? (std::string("!(") + bitmap + ") || (")
: (std::string(" (") + bitmap + ") && ("))
+ (all_zeros ? std::string("0") : comp.str()) + ")";
}

private:
unsigned split_index;
bool default_left;
uint64_t categorical_bitmap;

inline uint64_t to_bitmap(const std::vector<uint8_t>& left_categories) const {
uint64_t result = 0;
for (uint8_t e : left_categories) {
CHECK_LT(e, 64) << "Cannot have more than 64 categories in a feature";
result |= (static_cast<uint64_t>(1) << e);
std::vector<uint64_t> categorical_bitmap;

inline std::vector<uint64_t> to_bitmap(const std::vector<uint32_t>& left_categories) const {
const size_t num_left_categories = left_categories.size();
const uint32_t max_left_category = left_categories[num_left_categories - 1];
std::vector<uint64_t> bitmap((max_left_category + 1 + 63) / 64, 0);
for (size_t i = 0; i < left_categories.size(); ++i) {
const uint32_t cat = left_categories[i];
const size_t idx = cat / 64;
const uint32_t offset = cat % 64;
bitmap[idx] |= (static_cast<uint64_t>(1) << offset);
}
return result;
return bitmap;
}
};

Expand Down Expand Up @@ -598,9 +613,9 @@ inline std::string
GroupPolicy::Accumulator() const {
if (num_output_group > 1) {
return std::string("float sum[") + std::to_string(num_output_group)
+ "] = {0.0f};";
+ "] = {0.0f};\n unsigned int tmp;";
} else {
return "float sum = 0.0f;";
return "float sum = 0.0f;\n unsigned int tmp;";
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/frontend/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ struct _Node {
// All others not in the list belong to the right child node.
// Categories are integers ranging from 0 to (n-1), where n is the number of
// categories in that particular feature; n may exceed 64.
std::vector<uint8_t> left_categories;
std::vector<uint32_t> left_categories;

inline _Node()
: status(_Status::kEmpty),
Expand Down Expand Up @@ -188,7 +188,7 @@ TreeBuilder::SetNumericalTestNode(int node_key,
bool
TreeBuilder::SetCategoricalTestNode(int node_key,
unsigned feature_id,
const std::vector<uint8_t>& left_categories,
const std::vector<uint32_t>& left_categories,
bool default_left, int left_child_key,
int right_child_key) {
auto& tree = pimpl->tree;
Expand Down
20 changes: 8 additions & 12 deletions src/frontend/lightgbm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,13 @@ inline bool GetDecisionType(int8_t decision_type, int8_t mask) {
return (decision_type & mask) > 0;
}

inline std::vector<uint8_t> BitsetToList(const uint32_t* bits,
uint8_t nslots) {
std::vector<uint8_t> result;
CHECK(nslots == 1 || nslots == 2);
const uint8_t nbits = nslots * 32;
for (uint8_t i = 0; i < nbits; ++i) {
const uint8_t i1 = i / 32;
const uint8_t i2 = i % 32;
inline std::vector<uint32_t> BitsetToList(const uint32_t* bits,
uint8_t nslots) {
std::vector<uint32_t> result;
const uint32_t nbits = static_cast<uint32_t>(nslots) * 32;
for (uint32_t i = 0; i < nbits; ++i) {
const uint32_t i1 = i / 32;
const uint32_t i2 = i % 32;
if ((bits[i1] >> i2) & 1) {
result.push_back(i);
}
Expand Down Expand Up @@ -339,10 +338,7 @@ inline treelite::Model ParseStream(dmlc::Stream* fi) {
if (GetDecisionType(lgb_tree.decision_type[old_id], kCategoricalMask)) {
// categorical
const int cat_idx = static_cast<int>(lgb_tree.threshold[old_id]);
CHECK_LE(lgb_tree.cat_boundaries[cat_idx + 1]
- lgb_tree.cat_boundaries[cat_idx], 2)
<< "Categorical features must have 64 categories or fewer.";
const std::vector<uint8_t> left_categories
const std::vector<uint32_t> left_categories
= BitsetToList(lgb_tree.cat_threshold.data()
+ lgb_tree.cat_boundaries[cat_idx],
lgb_tree.cat_boundaries[cat_idx + 1]
Expand Down
6 changes: 3 additions & 3 deletions src/frontend/protobuf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,11 @@ Model LoadProtobufModel(const char* filename) {
<< "split_index must be between 0 and [num_feature] - 1.";
CHECK_GE(split_index, 0) << "split_index must be positive.";
const int left_categories_size = node.left_categories_size();
std::vector<uint8_t> left_categories;
std::vector<uint32_t> left_categories;
for (int i = 0; i < left_categories_size; ++i) {
const auto cat = node.left_categories(i);
CHECK(cat <= std::numeric_limits<uint8_t>::max());
left_categories.push_back(static_cast<uint8_t>(cat));
CHECK(cat <= std::numeric_limits<uint32_t>::max());
left_categories.push_back(static_cast<uint32_t>(cat));
}
tree.AddChilds(id);
tree[id].set_categorical_split(static_cast<unsigned>(split_index),
Expand Down

0 comments on commit 46a5d5e

Please sign in to comment.