diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 952cd0a5..e85991a8 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -28,4 +28,3 @@ sphinx:
python:
install:
- requirements: docs/requirements.txt
- system_packages: true
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 46219eb1..5f8a9655 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,7 +1,7 @@
cmake_policy(SET CMP0091 NEW)
set(CMAKE_FIND_NO_INSTALL_PREFIX TRUE FORCE)
cmake_minimum_required (VERSION 3.16)
-project(treelite LANGUAGES CXX C VERSION 3.9.0)
+project(treelite LANGUAGES CXX C VERSION 3.9.1)
# check MSVC version
if(MSVC)
diff --git a/ops/conda_env/dev.yml b/ops/conda_env/dev.yml
index 18b485f8..52824cbb 100644
--- a/ops/conda_env/dev.yml
+++ b/ops/conda_env/dev.yml
@@ -18,10 +18,10 @@ dependencies:
- llvm-openmp
- cython
- lightgbm
-- xgboost
- cpplint
- pylint
- awscli
- pip
- pip:
- cibuildwheel
+ - xgboost>=2.0
diff --git a/ops/cpp-python-coverage.sh b/ops/cpp-python-coverage.sh
index 7e98702c..6f142361 100755
--- a/ops/cpp-python-coverage.sh
+++ b/ops/cpp-python-coverage.sh
@@ -9,7 +9,7 @@ echo "##[section]Building Treelite..."
mkdir build/
cd build/
cmake .. -DTEST_COVERAGE=ON -DCMAKE_BUILD_TYPE=Debug -DBUILD_CPP_TEST=ON -GNinja
-ninja
+ninja install -v
cd ..
echo "##[section]Running Google C++ tests..."
@@ -17,7 +17,7 @@ echo "##[section]Running Google C++ tests..."
echo "##[section]Build Cython extension..."
cd tests/cython
-python setup.py build_ext --inplace
+pip install -vvv .
cd ../..
echo "##[section]Running Python integration tests..."
diff --git a/python/treelite/VERSION b/python/treelite/VERSION
index a5c4c763..6bd10744 100644
--- a/python/treelite/VERSION
+++ b/python/treelite/VERSION
@@ -1 +1 @@
-3.9.0
+3.9.1
diff --git a/runtime/java/treelite4j/pom.xml b/runtime/java/treelite4j/pom.xml
index aa6e8267..e9fe04f2 100644
--- a/runtime/java/treelite4j/pom.xml
+++ b/runtime/java/treelite4j/pom.xml
@@ -5,7 +5,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>ml.dmlc</groupId>
<artifactId>treelite4j</artifactId>
- <version>3.9.0</version>
+ <version>3.9.1</version>
<packaging>jar</packaging>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
diff --git a/runtime/python/treelite_runtime/VERSION b/runtime/python/treelite_runtime/VERSION
index a5c4c763..6bd10744 100644
--- a/runtime/python/treelite_runtime/VERSION
+++ b/runtime/python/treelite_runtime/VERSION
@@ -1 +1 @@
-3.9.0
+3.9.1
diff --git a/src/frontend/xgboost.cc b/src/frontend/xgboost.cc
index 863ae173..a022c1bd 100644
--- a/src/frontend/xgboost.cc
+++ b/src/frontend/xgboost.cc
@@ -296,7 +296,7 @@ class XGBTree {
nodes[nodes[nid].cleft() ].set_parent(nid, true);
nodes[nodes[nid].cright()].set_parent(nid, false);
}
- inline void Load(PeekableInputStream* fi) {
+ inline void Load(PeekableInputStream* fi, LearnerModelParam const& mparam) {
TREELITE_CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam))
<< "Ill-formed XGBoost model file: can't read TreeParam";
TREELITE_CHECK_GT(param.num_nodes, 0)
@@ -309,13 +309,17 @@ class XGBTree {
TREELITE_CHECK_EQ(fi->Read(stats.data(), sizeof(NodeStat) * stats.size()),
sizeof(NodeStat) * stats.size())
<< "Ill-formed XGBoost model file: cannot read specified number of nodes";
- if (param.size_leaf_vector != 0) {
+ if (param.size_leaf_vector != 0 && mparam.major_version < 2) {
uint64_t len;
TREELITE_CHECK_EQ(fi->Read(&len, sizeof(len)), sizeof(len))
<< "Ill-formed XGBoost model file";
if (len > 0) {
CONSUME_BYTES(fi, sizeof(bst_float) * len);
}
+ } else if (mparam.major_version == 2) {
+ TREELITE_CHECK_EQ(param.size_leaf_vector, 1)
+ << "Multi-target models are not supported with binary serialization. "
+ << "Please save the XGBoost model using the JSON format.";
}
TREELITE_CHECK_EQ(param.num_roots, 1)
<< "Invalid XGBoost model file: treelite does not support trees "
@@ -378,7 +382,7 @@ inline std::unique_ptr<treelite::Model> ParseStream(std::istream& fi) {
<< "Invalid XGBoost model file: num_trees must be 0 or greater";
for (int i = 0; i < gbm_param_.num_trees; ++i) {
xgb_trees_.emplace_back();
- xgb_trees_.back().Load(fp.get());
+ xgb_trees_.back().Load(fp.get(), mparam_);
}
if (mparam_.major_version < 1 || (mparam_.major_version == 1 && mparam_.minor_version < 6)) {
// In XGBoost 1.6, num_roots is used as num_parallel_tree, so don't check
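The binary-format loader now threads LearnerModelParam into XGBTree::Load so it can tell XGBoost 2.x files apart: the old leaf-vector payload is only consumed for pre-2.0 models, and 2.x binary files are rejected unless size_leaf_vector is 1. Below is a minimal sketch of the recommended path for XGBoost 2.0 models, especially multi-target ones: save to JSON and import that. The synthetic data, file name, and parameter values are illustrative only; xgb.train, Booster.save_model, and treelite.Model.load with model_format="xgboost_json" are the standard XGBoost / treelite 3.x APIs.

    import numpy as np
    import treelite
    import xgboost as xgb

    # Tiny synthetic regression problem so the sketch is self-contained.
    rng = np.random.default_rng(0)
    X = rng.standard_normal((100, 4))
    y = X @ np.array([1.0, -2.0, 0.5, 0.0]) + rng.normal(scale=0.1, size=100)
    dtrain = xgb.DMatrix(X, label=y)
    bst = xgb.train({"objective": "reg:squarederror"}, dtrain, num_boost_round=10)

    # With XGBoost >= 2.0, prefer the JSON format; the legacy binary format is
    # only accepted when each leaf holds a single output (size_leaf_vector == 1).
    bst.save_model("model.json")
    model = treelite.Model.load("model.json", model_format="xgboost_json")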
diff --git a/src/frontend/xgboost_json.cc b/src/frontend/xgboost_json.cc
index 725eff5c..7b37d89f 100644
--- a/src/frontend/xgboost_json.cc
+++ b/src/frontend/xgboost_json.cc
@@ -366,7 +366,8 @@ bool GBTreeModelHandler::StartArray() {
return (push_key_handler<ArrayHandler<treelite::Tree<float, float>, RegTreeHandler>,
std::vector<treelite::Tree<float, float>>>(
"trees", output.model->trees) ||
- push_key_handler<ArrayHandler<int>, std::vector<int>>("tree_info", output.tree_info));
+ push_key_handler<ArrayHandler<int>, std::vector<int>>("tree_info", output.tree_info) ||
+ push_key_handler<IgnoreHandler>("iteration_indptr"));
}
bool GBTreeModelHandler::StartObject() {
@@ -377,7 +378,8 @@ bool GBTreeModelHandler::StartObject() {
}
bool GBTreeModelHandler::is_recognized_key(const std::string& key) {
- return (key == "trees" || key == "tree_info" || key == "gbtree_model_param");
+ return (key == "trees" || key == "tree_info" || key == "gbtree_model_param"
+ || key == "iteration_indptr");
}
/******************************************************************************
@@ -460,7 +462,8 @@ bool ObjectiveHandler::StartObject() {
push_key_handler("lambda_rank_param") ||
push_key_handler("aft_loss_param") ||
push_key_handler("pseduo_huber_param") ||
- push_key_handler("pseudo_huber_param"));
+ push_key_handler("pseudo_huber_param") ||
+ push_key_handler("lambdarank_param"));
}
bool ObjectiveHandler::String(const char *str, std::size_t length, bool) {
@@ -474,7 +477,8 @@ bool ObjectiveHandler::is_recognized_key(const std::string& key) {
return (key == "reg_loss_param" || key == "poisson_regression_param"
|| key == "tweedie_regression_param" || key == "softmax_multiclass_param"
|| key == "lambda_rank_param" || key == "aft_loss_param"
- || key == "pseduo_huber_param" || key == "pseudo_huber_param" || key == "name");
+ || key == "pseduo_huber_param" || key == "pseudo_huber_param"
+ || key == "lambdarank_param" || key == "name");
}
/******************************************************************************
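XGBoost 2.0 writes an iteration_indptr array next to trees and tree_info in the gbtree model, and its ranking objectives serialize their configuration under lambdarank_param rather than lambda_rank_param; both keys are now recognized and skipped instead of being flagged as unknown fields. A quick way to see the new layout, assuming a model saved as model.json with XGBoost 2.0 (path and file name are illustrative):

    import json

    with open("model.json", "r") as f:
        doc = json.load(f)

    # For the gbtree booster, the trees live under learner/gradient_booster/model;
    # XGBoost 2.0 adds "iteration_indptr" alongside "trees" and "tree_info".
    gbtree_model = doc["learner"]["gradient_booster"]["model"]
    print(sorted(gbtree_model.keys()))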
diff --git a/tests/cython/CMakeLists.txt b/tests/cython/CMakeLists.txt
new file mode 100644
index 00000000..4750b0cd
--- /dev/null
+++ b/tests/cython/CMakeLists.txt
@@ -0,0 +1,40 @@
+cmake_minimum_required(VERSION 3.18)
+
+project(treelite_serializer_ext LANGUAGES CXX)
+
+find_package(
+ Python
+ COMPONENTS Interpreter Development.Module
+ REQUIRED)
+
+find_program(CYTHON "cython")
+
+find_package(Treelite REQUIRED)
+
+add_custom_command(
+ OUTPUT serializer.cpp
+ DEPENDS serializer.pyx
+ VERBATIM
+ COMMAND "${CYTHON}" "${PROJECT_SOURCE_DIR}/serializer.pyx" --output-file
+ "${PROJECT_BINARY_DIR}/serializer.cpp")
+
+if(DEFINED ENV{CONDA_PREFIX})
+ set(CMAKE_PREFIX_PATH "$ENV{CONDA_PREFIX};${CMAKE_PREFIX_PATH}")
+ message(STATUS "Detected Conda environment, CMAKE_PREFIX_PATH set to: ${CMAKE_PREFIX_PATH}")
+ if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT)
+ message(STATUS "No CMAKE_INSTALL_PREFIX argument detected, setting to: $ENV{CONDA_PREFIX}")
+ set(CMAKE_INSTALL_PREFIX $ENV{CONDA_PREFIX})
+ endif()
+else()
+ message(STATUS "No Conda environment detected")
+endif()
+
+python_add_library(serializer MODULE "${PROJECT_BINARY_DIR}/serializer.cpp" WITH_SOABI)
+target_link_libraries(serializer PRIVATE treelite::treelite)
+set_target_properties(serializer
+ PROPERTIES
+ POSITION_INDEPENDENT_CODE ON
+ CXX_STANDARD 17
+ CXX_STANDARD_REQUIRED ON)
+
+install(TARGETS serializer DESTINATION "${PROJECT_SOURCE_DIR}")
diff --git a/tests/cython/pyproject.toml b/tests/cython/pyproject.toml
new file mode 100644
index 00000000..01ed7451
--- /dev/null
+++ b/tests/cython/pyproject.toml
@@ -0,0 +1,7 @@
+[build-system]
+requires = ["scikit-build-core", "cython"]
+build-backend = "scikit_build_core.build"
+
+[project]
+name = "example"
+version = "0.0.1"
diff --git a/tests/cython/setup.py b/tests/cython/setup.py
deleted file mode 100644
index f5500662..00000000
--- a/tests/cython/setup.py
+++ /dev/null
@@ -1,27 +0,0 @@
-import os
-from distutils.sysconfig import get_python_lib
-from setuptools import setup, find_packages
-from setuptools.extension import Extension
-from Cython.Distutils.build_ext import new_build_ext as build_ext
-
-extensions = [
- Extension('*',
- sources=['*.pyx'],
- include_dirs=['../../include', '../../build/include'],
- library_dirs=[get_python_lib(), '../../build/'],
- #runtime_library_dirs=[os.path.join(os.sys.prefix, 'lib')],
- libraries=['treelite'],
- language='c++',
- extra_compile_args=['--std=c++14'])
-]
-
-setup(
- name='cython_test',
- version='0.0.1',
- setup_requires=['cython'],
- ext_modules=extensions,
- packages=find_packages(),
- install_requires=['cython'],
- cmdclass={'build_ext': build_ext},
- zip_safe=False
-)
diff --git a/tests/python/test_gtil.py b/tests/python/test_gtil.py
index 4dd01504..761f4833 100644
--- a/tests/python/test_gtil.py
+++ b/tests/python/test_gtil.py
@@ -8,7 +8,7 @@
import numpy as np
import pytest
import scipy
-from hypothesis import assume, given, settings
+from hypothesis import given, settings
from hypothesis.strategies import data as hypothesis_callback
from hypothesis.strategies import floats, integers, just, sampled_from
from sklearn.datasets import load_svmlight_file
@@ -198,16 +198,15 @@ def test_skl_hist_gradient_boosting_with_categorical():
treelite.sklearn.import_model(clf)
+@pytest.mark.parametrize("objective",
+ [
+ "reg:linear",
+ "reg:squarederror",
+ "reg:squaredlogerror",
+ "reg:pseudohubererror",
+ ])
@given(
dataset=standard_regression_datasets(),
- objective=sampled_from(
- [
- "reg:linear",
- "reg:squarederror",
- "reg:squaredlogerror",
- "reg:pseudohubererror",
- ]
- ),
model_format=sampled_from(["binary", "json"]),
num_boost_round=integers(min_value=5, max_value=50),
num_parallel_tree=integers(min_value=1, max_value=5),
@@ -218,9 +217,14 @@ def test_xgb_regression(
):
# pylint: disable=too-many-locals
"""Test XGBoost with regression data"""
+
+ # See https://github.com/dmlc/xgboost/pull/9574
+ if objective == "reg:pseudohubererror":
+ pytest.xfail("XGBoost 2.0 has a bug in the serialization of Pseudo-Huber error")
+
X, y = dataset
if objective == "reg:squaredlogerror":
- assume(np.all(y > -1))
+ y = np.where(y <= -1, -0.9, y)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, shuffle=False
)
@@ -330,7 +334,7 @@ def test_xgb_multiclass_classifier(
("count:poisson", 4),
("rank:pairwise", 5),
("rank:ndcg", 5),
- ("rank:map", 5),
+ ("rank:map", 2),
],
),
model_format=sampled_from(["binary", "json"]),
diff --git a/tests/python/test_xgboost_integration.py b/tests/python/test_xgboost_integration.py
index 54e67550..82c21178 100644
--- a/tests/python/test_xgboost_integration.py
+++ b/tests/python/test_xgboost_integration.py
@@ -7,7 +7,7 @@
import numpy as np
import pytest
-from hypothesis import assume, given, settings
+from hypothesis import given, settings
from hypothesis.strategies import integers, lists, sampled_from
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
@@ -27,16 +27,15 @@
pytest.skip("XGBoost not installed; skipping", allow_module_level=True)
+@pytest.mark.parametrize("objective",
+ [
+ "reg:linear",
+ "reg:squarederror",
+ "reg:squaredlogerror",
+ "reg:pseudohubererror",
+ ])
@given(
toolchain=sampled_from(os_compatible_toolchains()),
- objective=sampled_from(
- [
- "reg:linear",
- "reg:squarederror",
- "reg:squaredlogerror",
- "reg:pseudohubererror",
- ]
- ),
model_format=sampled_from(["binary", "json"]),
num_parallel_tree=integers(min_value=1, max_value=10),
dataset=standard_regression_datasets(),
@@ -45,9 +44,14 @@
def test_xgb_regression(toolchain, objective, model_format, num_parallel_tree, dataset):
# pylint: disable=too-many-locals
"""Test a random regression dataset"""
+
+ # See https://github.com/dmlc/xgboost/pull/9574
+ if objective == "reg:pseudohubererror":
+ pytest.xfail("XGBoost 2.0 has a bug in the serialization of Pseudo-Huber error")
+
X, y = dataset
if objective == "reg:squaredlogerror":
- assume(np.all(y > -1))
+ y = np.where(y <= -1, -0.9, y)
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, shuffle=False
)
@@ -59,6 +63,7 @@ def test_xgb_regression(toolchain, objective, model_format, num_parallel_tree, d
"verbosity": 0,
"objective": objective,
"num_parallel_tree": num_parallel_tree,
+ "base_score": 0.0
}
num_round = 10
bst = xgb.train(
@@ -96,7 +101,7 @@ def test_xgb_regression(toolchain, objective, model_format, num_parallel_tree, d
assert predictor.num_feature == dtrain.num_col()
assert predictor.num_class == 1
assert predictor.pred_transform == "identity"
- assert predictor.global_bias == 0.5
+ assert predictor.global_bias == 0.0
assert predictor.sigmoid_alpha == 1.0
dmat = treelite_runtime.DMatrix(X_test, dtype="float32")
out_pred = predictor.predict(dmat)
@@ -184,7 +189,7 @@ def test_xgb_iris(
("count:poisson", 4, math.log(0.5)),
("rank:pairwise", 5, 0.5),
("rank:ndcg", 5, 0.5),
- ("rank:map", 5, 0.5),
+ ("rank:map", 2, 0.5),
],
ids=[
"binary:logistic",
@@ -276,7 +281,8 @@ def test_xgb_deserializers(toolchain, dataset):
)
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)
- param = {"max_depth": 8, "eta": 1, "silent": 1, "objective": "reg:linear"}
+ param = {"max_depth": 8, "eta": 1, "silent": 1, "objective": "reg:linear",
+ "base_score": 0.5}
num_round = 10
bst = xgb.train(
param,
@@ -417,7 +423,6 @@ def test_xgb_dart(tmpdir, toolchain, model_format):
assert predictor.num_feature == dtrain.num_col()
assert predictor.num_class == 1
assert predictor.pred_transform == "sigmoid"
- np.testing.assert_almost_equal(predictor.global_bias, 0, decimal=5)
assert predictor.sigmoid_alpha == 1.0
dmat = treelite_runtime.DMatrix(X, dtype="float32")
out_pred = predictor.predict(dmat)
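The changed and dropped global_bias assertions, together with the explicit "base_score" entries, reflect an XGBoost 2.0 behavior change: unless base_score is pinned, the intercept is estimated from the training labels, so the default is no longer a constant 0.5. A short sketch of pinning it in user code, with synthetic data and a value of 0.0 chosen only for illustration:

    import numpy as np
    import xgboost as xgb

    rng = np.random.default_rng(0)
    X = rng.standard_normal((50, 3))
    y = rng.standard_normal(50)
    dtrain = xgb.DMatrix(X, label=y)

    # Pinning base_score keeps predictor.global_bias predictable after the model
    # is imported into treelite; leaving it unset lets XGBoost 2.0 estimate the
    # intercept from the labels.
    params = {"objective": "reg:squarederror", "base_score": 0.0}
    bst = xgb.train(params, dtrain, num_boost_round=5)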