Merge pull request #3 from rmitsuboshi/custom-loss-function-for-gbm
Custom loss function for gbm #2
rmitsuboshi authored Jan 13, 2025
2 parents 4cdf19d + 4a156aa commit 82e2dd0
Showing 11 changed files with 282 additions and 303 deletions.
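
The change this commit lands: `GBM` becomes generic over a loss type `L` supplied at construction (`GBM::init_with_loss`), and the `LossFunction` trait gains a diagonal `hessian` method, so callers can plug in their own loss. The following standalone sketch shows the pattern this enables; the trait here mirrors only the `gradient` and `hessian` signatures visible in this diff (the crate's real `LossFunction` has further items), and `PseudoHuber` is a hypothetical user-defined loss, not part of the crate.

// Standalone sketch (not part of this commit): a user-defined loss in the
// shape the new API expects.
trait LossFunction {
    /// Per-example gradient of the loss at the current predictions.
    fn gradient(&self, predictions: &[f64], target: &[f64]) -> Vec<f64>;

    /// Per-example diagonal of the Hessian at the current predictions.
    fn hessian(&self, predictions: &[f64], target: &[f64]) -> Vec<f64>;
}

/// Pseudo-Huber loss with scale `delta`:
/// `L(z) = delta^2 * (sqrt(1 + (z / delta)^2) - 1)` for residual `z = p - y`.
/// Smooth everywhere, so both derivatives below are exact.
struct PseudoHuber {
    delta: f64,
}

impl LossFunction for PseudoHuber {
    fn gradient(&self, predictions: &[f64], target: &[f64]) -> Vec<f64> {
        predictions.iter()
            .zip(target)
            .map(|(p, y)| {
                let z = p - y;
                // dL/dp = z / sqrt(1 + (z / delta)^2)
                z / (1.0 + (z / self.delta).powi(2)).sqrt()
            })
            .collect()
    }

    fn hessian(&self, predictions: &[f64], target: &[f64]) -> Vec<f64> {
        predictions.iter()
            .zip(target)
            .map(|(p, y)| {
                // d2L/dp2 = (1 + (z / delta)^2)^(-3/2)
                (1.0 + ((p - y) / self.delta).powi(2)).powf(-1.5)
            })
            .collect()
    }
}

fn main() {
    let loss = PseudoHuber { delta: 1.0 };
    let predictions = vec![0.0, 2.0, 5.0];
    let target = vec![0.5, 1.0, 1.0];
    println!("g = {:?}", loss.gradient(&predictions, &target));
    println!("h = {:?}", loss.hessian(&predictions, &target));
}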
24 changes: 12 additions & 12 deletions src/booster/gradient_boost/gbm.rs
@@ -57,13 +57,12 @@ use std::ops::ControlFlow;
 /// // Note that the default tolerance parameter is set as `1 / n_sample`,
 /// // where `n_sample = data.shape().0` is
 /// // the number of training examples in `data`.
-/// let booster = GBM::init(&sample)
-///     .loss(GBMLoss::L1);
+/// let booster = GBM::init_with_loss(&sample, GBMLoss::L2);
 ///
 /// // Set the weak learner with setting parameters.
 /// let weak_learner = RegressionTreeBuilder::new(&sample)
 ///     .max_depth(2)
-///     .loss(LossType::L1)
+///     .loss(LossType::L2)
 ///     .build();
 ///
 /// // Run `GBM` and obtain the resulting hypothesis `f`.
@@ -80,14 +79,14 @@ use std::ops::ControlFlow;
 /// let training_loss = sample.target()
 ///     .into_iter()
 ///     .zip(predictions)
-///     .map(|(y, fx)| (y - fx).abs())
+///     .map(|(y, fx)| (y - fx).powi(2))
 ///     .sum::<f64>()
 ///     / n_sample;
 ///
 ///
 /// println!("Training Loss is: {training_loss}");
 /// ```
-pub struct GBM<'a, F> {
+pub struct GBM<'a, F, L> {
     // Training data
     sample: &'a Sample,
 
@@ -103,7 +102,7 @@ pub struct GBM<'a, F> {
 
 
     // Some struct that implements `LossFunction` trait
-    loss: GBMLoss,
+    loss: L,
 
 
     // Max iteration until GBM guarantees the optimality.
@@ -122,11 +121,11 @@ pub struct GBM<'a, F> {
 
 
 
-impl<'a, F> GBM<'a, F>
+impl<'a, F, L> GBM<'a, F, L>
 {
     /// Initialize the `GBM`.
     /// This method sets some parameters `GBM` holds.
-    pub fn init(sample: &'a Sample) -> Self {
+    pub fn init_with_loss(sample: &'a Sample, loss: L) -> Self {
 
         let n_sample = sample.shape().0;
         let predictions = vec![0.0; n_sample];
@@ -138,7 +137,7 @@ impl<'a, F> GBM<'a, F>
             weights: Vec::new(),
             hypotheses: Vec::new(),
 
-            loss: GBMLoss::L2,
+            loss,
 
             max_iter: 100,
 
@@ -150,7 +149,7 @@
 }
 
 
-impl<'a, F> GBM<'a, F> {
+impl<'a, F, L> GBM<'a, F, L> {
     /// Returns the maximum iteration
     /// of the `GBM` to find a combined hypothesis
     /// that has error at most `tolerance`.
@@ -170,15 +169,16 @@ impl<'a, F> GBM<'a, F> {
 
 
     /// Set the Loss Type.
-    pub fn loss(mut self, loss_type: GBMLoss) -> Self {
+    pub fn loss(mut self, loss_type: L) -> Self {
         self.loss = loss_type;
         self
     }
 }
 
 
-impl<F> Booster<F> for GBM<'_, F>
+impl<F, L> Booster<F> for GBM<'_, F, L>
     where F: Regressor + Clone,
+          L: LossFunction,
 {
     type Output = WeightedMajority<F>;
 
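
Net effect for callers of `gbm.rs`: the loss is no longer hard-coded to a `GBMLoss::L2` default inside `init`; it is supplied at construction, and the `loss` setter now accepts any `L`. Roughly, following the old and new doc examples above (illustrative; `sample` is assumed to be an already-loaded `Sample`):

// Before this commit: `init`, then an optional builder-style setter
// (defaulting to `GBMLoss::L2` if never called).
let booster = GBM::init(&sample)
    .loss(GBMLoss::L1);

// After this commit: the loss is passed at construction, and `GBM`
// accepts any type implementing `LossFunction`.
let booster = GBM::init_with_loss(&sample, GBMLoss::L1);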
82 changes: 75 additions & 7 deletions src/common/loss_functions.rs
@@ -20,10 +20,16 @@ pub trait LossFunction {
             / n_items as f64
     }
 
-    /// Gradient vector at current point.
+    /// Gradient vector at the current point.
     fn gradient(&self, predictions: &[f64], target: &[f64]) -> Vec<f64>;
 
 
+    /// Hessian at the current point.
+    /// Here, this method assumes that the Hessian is diagonal,
+    /// so that it returns a diagonal vector.
+    fn hessian(&self, predictions: &[f64], target: &[f64]) -> Vec<f64>;
+
+
     /// Best coefficient for the newly-attained hypothesis.
     fn best_coefficient(
         &self,
@@ -45,6 +51,17 @@ pub enum GBMLoss {
     /// This loss function is also known as
     /// **Mean Squared Error (MSE)**.
     L2,
+
+
+    // /// Huber loss with parameter `delta`.
+    // /// Huber loss maps the given scalar `z` to
+    // /// `0.5 * z.powi(2)` if `z.abs() < delta`,
+    // /// `delta * (z.abs() - 0.5 * delta)`, otherwise.
+    // Huber(f64),
+
+
+    // /// Quantile loss
+    // Quantile(f64),
 }
 
 
@@ -53,6 +70,7 @@ impl LossFunction for GBMLoss {
         match self {
             Self::L1 => "L1 loss",
             Self::L2 => "L2 loss",
+            // Self::Huber(_) => "Huber loss",
         }
     }
 
@@ -61,6 +79,14 @@ impl LossFunction for GBMLoss {
         match self {
             Self::L1 => (prediction - true_value).abs(),
             Self::L2 => (prediction - true_value).powi(2),
+            // Self::Huber(delta) => {
+            //     let diff = (prediction - true_value).abs();
+            //     if diff < *delta {
+            //         0.5 * diff.powi(2)
+            //     } else {
+            //         delta * (diff - 0.5 * delta)
+            //     }
+            // },
         }
     }
 
@@ -73,17 +99,59 @@ impl LossFunction for GBMLoss {
 
         match self {
             Self::L1 => {
-                predictions.iter()
-                    .zip(target)
-                    .map(|(p, y)| (p - y).signum() / n_sample)
+                target.iter()
+                    .zip(predictions)
+                    .map(|(y, p)| (y - p).signum())
                     .collect()
             },
             Self::L2 => {
+                target.iter()
+                    .zip(predictions)
+                    .map(|(y, p)| p - y)
+                    .collect()
+            },
+            // Self::Huber(delta) => {
+            //     target.iter()
+            //         .zip(predictions)
+            //         .map(|(y, p)| {
+            //             let diff = y - p;
+            //             if diff.abs() < *delta {
+            //                 -diff
+            //             } else {
+            //                 delta * diff.signum()
+            //             }
+            //         })
+            //         .collect::<Vec<_>>()
+            // },
+        }
+    }
+
+
+    fn hessian(&self, predictions: &[f64], target: &[f64]) -> Vec<f64>
+    {
+        let n_sample = predictions.len();
+        assert_eq!(n_sample as usize, target.len());
+
+        match self {
+            Self::L1 => {
+                std::iter::repeat(0f64)
+                    .take(n_sample)
+                    .collect()
+            },
+            Self::L2 => {
-                predictions.iter()
-                    .zip(target)
-                    .map(|(p, y)| 2.0 * (p - y) / n_sample)
+                std::iter::repeat(1f64)
+                    .take(n_sample)
+                    .collect()
             },
+            // Self::Huber(delta) => {
+            //     target.iter()
+            //         .zip(predictions)
+            //         .map(|(y, p)| {
+            //             let diff = (y - p).abs();
+            //             if diff < *delta { 1f64 } else { 0f64 }
+            //         })
+            //         .collect::<Vec<_>>()
+            // },
         }
     }
 
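
The new `hessian` implementations return one curvature value per example, all zeros for `L1` and all ones for `L2`, alongside `gradient`, which now returns unscaled per-example values (the old code divided by `n_sample`). Nothing in this diff shows how the booster consumes the Hessian; the sketch below is one standard use, a Newton-style (XGBoost-like) per-example target with a damping term `lambda` guarding against zero curvature such as `L1`'s. All names here are hypothetical.

// Hypothetical sketch: turning per-example gradients and a diagonal
// Hessian into Newton-style targets for the next weak learner.
fn newton_targets(gradient: &[f64], hessian: &[f64], lambda: f64) -> Vec<f64> {
    gradient.iter()
        .zip(hessian)
        // Newton step per example: -g / (h + lambda).
        .map(|(g, h)| -g / (h + lambda))
        .collect()
}

fn main() {
    let g = vec![0.8, -1.2, 0.1];   // e.g., from `loss.gradient(..)`
    let h = vec![1.0, 1.0, 1.0];    // e.g., `GBMLoss::L2`'s hessian
    println!("{:?}", newton_targets(&g, &h, 1e-12));
}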
1 change: 0 additions & 1 deletion src/lib.rs
@@ -216,7 +216,6 @@ pub use weak_learner::{
     RegressionTree,
     RegressionTreeBuilder,
     RegressionTreeRegressor,
-    LossType,
 };
 
 /// Some useful functions / traits
2 changes: 1 addition & 1 deletion src/prelude.rs
@@ -64,7 +64,6 @@ pub use crate::weak_learner::{
     RegressionTree,
     RegressionTreeBuilder,
     RegressionTreeRegressor,
-    LossType,
 };
 
 
@@ -81,6 +80,7 @@ pub use crate::{
 
 pub use crate::common::{
     loss_functions::GBMLoss,
+    loss_functions::LossFunction,
     frank_wolfe::FWType,
 };
 
1 change: 0 additions & 1 deletion src/weak_learner.rs
@@ -43,7 +43,6 @@ pub use self::naive_bayes::{
 
 
 pub use self::regression_tree::{
-    LossType,
     RegressionTree,
     RegressionTreeBuilder,
     RegressionTreeRegressor,
4 changes: 0 additions & 4 deletions src/weak_learner/regression_tree.rs
@@ -3,9 +3,6 @@ mod regression_tree_algorithm;
 // This file defines the regression tree regressor.
 mod regression_tree_regressor;
 
-// This file defines the loss type.
-mod loss;
-
 // Regression Tree builder.
 mod builder;
 
@@ -18,5 +15,4 @@ mod train_node;
 
 pub use regression_tree_algorithm::RegressionTree;
 pub use regression_tree_regressor::RegressionTreeRegressor;
-pub use loss::LossType;
 pub use builder::RegressionTreeBuilder;