introduce LossFunction trait to gradient boosting and regression tree.
From now on, you can use a custom loss function.
rmitsuboshi committed Jan 13, 2025
1 parent 3347cf5 commit 4a156aa
Showing 11 changed files with 272 additions and 294 deletions.
7 changes: 3 additions & 4 deletions src/booster/gradient_boost/gbm.rs
@@ -57,13 +57,12 @@ use std::ops::ControlFlow;
/// // Note that the default tolerance parameter is set as `1 / n_sample`,
/// // where `n_sample = data.shape().0` is
/// // the number of training examples in `data`.
/// let booster = GBM::init(&sample)
/// .loss(GBMLoss::L1);
/// let booster = GBM::init_with_loss(&sample, GBMLoss::L2);
///
/// // Set the weak learner with setting parameters.
/// let weak_learner = RegressionTreeBuilder::new(&sample)
/// .max_depth(2)
/// .loss(LossType::L1)
/// .loss(LossType::L2)
/// .build();
///
/// // Run `GBM` and obtain the resulting hypothesis `f`.
@@ -80,7 +79,7 @@ use std::ops::ControlFlow;
/// let training_loss = sample.target()
/// .into_iter()
/// .zip(predictions)
/// .map(|(y, fx)| (y - fx).abs())
/// .map(|(y, fx)| (y - fx).powi(2))
/// .sum::<f64>()
/// / n_sample;
///
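The two hunks above switch the doc example from `GBMLoss::L1` to `GBMLoss::L2`: construction moves from the removed builder call `.loss(GBMLoss::L1)` to `GBM::init_with_loss(&sample, GBMLoss::L2)`, and the reported training loss changes from mean absolute error to mean squared error. A standalone toy computation of both metrics, with made-up data; `target` and `predictions` merely stand in for `sample.target()` and the learned regressor's outputs:

fn main() {
    // Made-up stand-ins for `sample.target()` and the regressor's predictions.
    let target = vec![1.0_f64, 2.0, 3.0];
    let predictions = vec![0.5_f64, 2.5, 2.0];
    let n_sample = target.len() as f64;

    // Old doc example: mean absolute error, matching `GBMLoss::L1`.
    let mae = target.iter()
        .zip(&predictions)
        .map(|(y, fx)| (y - fx).abs())
        .sum::<f64>() / n_sample;

    // New doc example: mean squared error, matching `GBMLoss::L2`.
    let mse = target.iter()
        .zip(&predictions)
        .map(|(y, fx)| (y - fx).powi(2))
        .sum::<f64>() / n_sample;

    println!("MAE = {mae:.3}, MSE = {mse:.3}");
}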
82 changes: 75 additions & 7 deletions src/common/loss_functions.rs
@@ -20,10 +20,16 @@ pub trait LossFunction {
/ n_items as f64
}

/// Gradient vector at current point.
/// Gradient vector at the current point.
fn gradient(&self, predictions: &[f64], target: &[f64]) -> Vec<f64>;


/// Hessian at the current point.
/// Here, this method assumes that the Hessian is diagonal,
/// so that it returns a diagonal vector.
fn hessian(&self, predictions: &[f64], target: &[f64]) -> Vec<f64>;


/// Best coefficient for the newly-attained hypothesis.
fn best_coefficient(
&self,
@@ -45,6 +51,17 @@ pub enum GBMLoss {
/// This loss function is also known as
/// **Mean Squared Error (MSE)**.
L2,


// /// Huber loss with parameter `delta`.
// /// Huber loss maps the given scalar `z` to
// /// `0.5 * z.powi(2)` if `z.abs() < delta`,
// /// `delta * (z.abs() - 0.5 * delta)`, otherwise.
// Huber(f64),


// /// Quantile loss
// Quantile(f64),
}


@@ -53,6 +70,7 @@ impl LossFunction for GBMLoss {
match self {
Self::L1 => "L1 loss",
Self::L2 => "L2 loss",
// Self::Huber(_) => "Huber loss",
}
}

@@ -61,6 +79,14 @@ impl LossFunction for GBMLoss {
match self {
Self::L1 => (prediction - true_value).abs(),
Self::L2 => (prediction - true_value).powi(2),
// Self::Huber(delta) => {
// let diff = (prediction - true_value).abs();
// if diff < *delta {
// 0.5 * diff.powi(2)
// } else {
// delta * (diff - 0.5 * delta)
// }
// },
}
}

@@ -73,17 +99,59 @@ impl LossFunction for GBMLoss {

match self {
Self::L1 => {
predictions.iter()
.zip(target)
.map(|(p, y)| (p - y).signum() / n_sample)
target.iter()
.zip(predictions)
.map(|(y, p)| (y - p).signum())
.collect()
},
Self::L2 => {
target.iter()
.zip(predictions)
.map(|(y, p)| p - y)
.collect()
},
// Self::Huber(delta) => {
// target.iter()
// .zip(predictions)
// .map(|(y, p)| {
// let diff = y - p;
// if diff.abs() < *delta {
// -diff
// } else {
// delta * diff.signum()
// }
// })
// .collect::<Vec<_>>()
// },
}
}


fn hessian(&self, predictions: &[f64], target: &[f64]) -> Vec<f64>
{
let n_sample = predictions.len();
assert_eq!(n_sample as usize, target.len());

match self {
Self::L1 => {
std::iter::repeat(0f64)
.take(n_sample)
.collect()
},
Self::L2 => {
predictions.iter()
.zip(target)
.map(|(p, y)| 2.0 * (p - y) / n_sample)
std::iter::repeat(1f64)
.take(n_sample)
.collect()
},
// Self::Huber(delta) => {
// target.iter()
// .zip(predictions)
// .map(|(y, p)| {
// let diff = (y - p).abs();
// if diff < *delta { 1f64 } else { 0f64 }
// })
// .collect::<Vec<_>>()
// },
}
}

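The commented-out `Huber(f64)` variant hints at how further losses could look. Below is a minimal sketch of its gradient and diagonal Hessian, written as free functions with the same `&[f64] -> Vec<f64>` shape as `LossFunction::gradient` and `LossFunction::hessian`. The function names are hypothetical, the sign convention follows the `p - y` used by the `L2` arm above, and a real custom loss would also need the remaining trait methods (the per-example loss and `best_coefficient` shown in the hunks), whose full signatures are not visible here.

fn huber_gradient(delta: f64, predictions: &[f64], target: &[f64]) -> Vec<f64> {
    assert_eq!(predictions.len(), target.len());
    predictions.iter()
        .zip(target)
        .map(|(p, y)| {
            let diff = p - y;
            // Quadratic region: d/dp of 0.5 * diff^2 is diff.
            // Linear tails: d/dp of delta * (|diff| - 0.5 * delta) is delta * sign(diff).
            if diff.abs() < delta { diff } else { delta * diff.signum() }
        })
        .collect()
}

fn huber_hessian(delta: f64, predictions: &[f64], target: &[f64]) -> Vec<f64> {
    assert_eq!(predictions.len(), target.len());
    predictions.iter()
        .zip(target)
        // Diagonal Hessian: 1 inside the quadratic region, 0 in the linear tails.
        .map(|(p, y)| if (p - y).abs() < delta { 1f64 } else { 0f64 })
        .collect()
}

Once such a loss implements the full trait, it would be handed to the booster the same way as the built-ins, e.g. via `GBM::init_with_loss`.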
1 change: 0 additions & 1 deletion src/lib.rs
@@ -216,7 +216,6 @@ pub use weak_learner::{
RegressionTree,
RegressionTreeBuilder,
RegressionTreeRegressor,
LossType,
};

/// Some useful functions / traits
2 changes: 1 addition & 1 deletion src/prelude.rs
@@ -64,7 +64,6 @@ pub use crate::weak_learner::{
RegressionTree,
RegressionTreeBuilder,
RegressionTreeRegressor,
LossType,
};


@@ -81,6 +80,7 @@ pub use crate::{

pub use crate::common::{
loss_functions::GBMLoss,
loss_functions::LossFunction,
frank_wolfe::FWType,
};

1 change: 0 additions & 1 deletion src/weak_learner.rs
@@ -43,7 +43,6 @@ pub use self::naive_bayes::{


pub use self::regression_tree::{
LossType,
RegressionTree,
RegressionTreeBuilder,
RegressionTreeRegressor,
4 changes: 0 additions & 4 deletions src/weak_learner/regression_tree.rs
@@ -3,9 +3,6 @@ mod regression_tree_algorithm;
// This file defines the regression tree regressor.
mod regression_tree_regressor;

// This file defines the loss type.
mod loss;

// Regression Tree builder.
mod builder;

@@ -18,5 +15,4 @@ mod train_node;

pub use regression_tree_algorithm::RegressionTree;
pub use regression_tree_regressor::RegressionTreeRegressor;
pub use loss::LossType;
pub use builder::RegressionTreeBuilder;
66 changes: 30 additions & 36 deletions src/weak_learner/regression_tree/bin.rs
@@ -18,19 +18,8 @@ const EPS: f64 = 0.001;
const NUM_TOLERANCE: f64 = 1e-9;


/// A struct that stores the first/second order derivative information.
#[derive(Clone,Default)]
pub(crate) struct GradientHessian {
pub(crate) grad: f64,
pub(crate) hess: f64,
}


impl GradientHessian {
pub(super) fn new(grad: f64, hess: f64) -> Self {
Self { grad, hess }
}
}
type Gradient = f64;
type Hessian = f64;


/// Binning: A feature processing.
@@ -209,11 +198,13 @@ impl Bins {
&self,
indices: &[usize],
feat: &Feature,
gh: &[GradientHessian],
) -> Vec<(Bin, GradientHessian)>
gradient: &[Gradient],
hessian: &[Hessian],
) -> Vec<(Bin, Gradient, Hessian)>
{
let n_bins = self.0.len();
let mut packed = vec![GradientHessian::default(); n_bins];
let mut grad_pack = vec![0f64; n_bins];
let mut hess_pack = vec![0f64; n_bins];

for &i in indices {
let xi = feat[i];
@@ -226,10 +217,10 @@ impl Bins {
range.0.start.partial_cmp(&xi).unwrap()
})
.unwrap();
packed[pos].grad += gh[i].grad;
packed[pos].hess += gh[i].hess;
grad_pack[pos] += gradient[i];
hess_pack[pos] += hessian[i];
}
self.remove_zero_weight_pack_and_normalize(packed)
self.remove_zero_weight_pack_and_normalize(grad_pack, hess_pack)
}


@@ -254,42 +245,45 @@ impl Bins {
/// -
fn remove_zero_weight_pack_and_normalize(
&self,
pack: Vec<GradientHessian>,
) -> Vec<(Bin, GradientHessian)>
grad_pack: Vec<Gradient>,
hess_pack: Vec<Hessian>,
) -> Vec<(Bin, Gradient, Hessian)>
{
let mut iter = self.0.iter().zip(pack);
let mut iter = self.0.iter().zip(grad_pack.into_iter().zip(hess_pack));

let (prev_bin, mut prev_gh) = iter.next().unwrap();
let (prev_bin, (mut prev_grad, mut prev_hess)) = iter.next().unwrap();

let mut prev_bin = Bin::new(prev_bin.0.clone());
let mut iter = iter.filter(|(_, gh)| {
gh.grad != 0.0 || gh.hess != 0.0
let mut iter = iter.filter(|(_, (grad, hess))| {
*grad != 0.0 || *hess != 0.0
});

// The left-most bin might have zero weight.
// In this case, find the next non-zero weight bin and merge.
if prev_gh.grad == 0.0 && prev_gh.hess == 0.0 {
let (next_bin, next_gh) = iter.next().unwrap();
if prev_grad == 0.0 && prev_hess == 0.0 {
let (next_bin, (next_grad, next_hess)) = iter.next().unwrap();

let start = prev_bin.0.start;
let end = next_bin.0.end;
prev_bin = Bin::new(start..end);
prev_gh = next_gh;
prev_grad = next_grad;
prev_hess = next_hess;
}

let mut bin_and_gh = Vec::new();
for (next_bin, next_gh) in iter {
for (next_bin, (next_grad, next_hess)) in iter {
let start = prev_bin.0.start;
let end = (prev_bin.0.end + next_bin.0.start) / 2.0;
let bin = Bin::new(start..end);
bin_and_gh.push((bin, prev_gh));
bin_and_gh.push((bin, prev_grad, prev_hess));


prev_bin = Bin::new(next_bin.0.clone());
prev_bin.0.start = end;
prev_gh = next_gh;
prev_grad = next_grad;
prev_hess = next_hess;
}
bin_and_gh.push((prev_bin, prev_gh));
bin_and_gh.push((prev_bin, prev_grad, prev_hess));

bin_and_gh
}
@@ -310,7 +304,7 @@ impl fmt::Display for Bins {
let tail = bins.last()
.map(|bin| format!("{bin}"))
.unwrap();
write!(f, "{head}, ... , {tail}")
write!(f, "{head}, ... , {tail}")
} else {
let line = bins.iter()
.map(|bin| format!("{}", bin))
@@ -335,7 +329,7 @@ impl fmt::Display for Bin {
' '
};
let start = start.abs();
format!("{sgn}{start: >.2}")
format!("{sgn}{start: >.1}")
};
let end = if self.0.end == f64::MAX {
String::from("+Inf")
@@ -349,9 +343,9 @@ impl fmt::Display for Bin {
' '
};
let end = end.abs();
format!("{sgn}{end: >.2}")
format!("{sgn}{end: >.1}")
};

write!(f, "[{start}, {end})")
write!(f, "[{start: >8}, {end: >8})")
}
}
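For orientation, the reworked `Bins::pack` above now carries gradients and Hessians as two parallel slices instead of the removed `GradientHessian` struct, accumulating per-bin sums before merging zero-weight bins. A self-contained sketch of that accumulation, with hypothetical names and plain `f64` boundaries standing in for the crate's `Bin` and `Feature` types (the merging done by `remove_zero_weight_pack_and_normalize` is omitted):

fn pack_sketch(
    edges: &[f64],      // sorted bin boundaries; bin k covers [edges[k], edges[k + 1])
    feature: &[f64],    // one feature column
    gradient: &[f64],   // per-example first derivatives
    hessian: &[f64],    // per-example (diagonal) second derivatives
) -> Vec<(f64, f64)> {  // per-bin (gradient sum, Hessian sum)
    let n_bins = edges.len() - 1;
    let mut grad_pack = vec![0f64; n_bins];
    let mut hess_pack = vec![0f64; n_bins];

    for ((&x, &g), &h) in feature.iter().zip(gradient).zip(hessian) {
        // Index of the last boundary not exceeding x, clamped to a valid bin.
        let pos = edges[..n_bins]
            .partition_point(|&start| start <= x)
            .saturating_sub(1);
        grad_pack[pos] += g;
        hess_pack[pos] += h;
    }
    grad_pack.into_iter().zip(hess_pack).collect()
}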