
Commit

fix(loss): update binary_cross_entropy loss implementation
Account for the negation of targets, the (1 - predictions) term, and averaging across the batch.
TODO: Implement categorical cross-entropy loss / softmax cross-entropy for multiclass classification.
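
For reference, the quantity this change targets, as described above, is the standard binary cross-entropy averaged over the batch, with a small epsilon inside the logarithms to keep them finite:

    BCE = -(1/N) * sum_i [ y_i * ln(p_i + eps) + (1 - y_i) * ln(1 - p_i + eps) ]

where p_i are the predicted probabilities, y_i the 0/1 targets, and N the batch size.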
drewxs committed Jul 23, 2023
1 parent ea5f522 commit 5ef6998
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/loss.rs
@@ -7,5 +7,12 @@ pub fn mean_squared_error(predictions: &Tensor, targets: &Tensor) -> f64 {
 
 /// Binary cross entropy loss.
 pub fn binary_cross_entropy(predictions: &Tensor, targets: &Tensor) -> f64 {
-    targets.mul(&predictions.ln()).sum()
+    let epsilon = 1e-8; // Small value to avoid taking the logarithm of zero
+    let loss = &(predictions.add_scalar(epsilon).ln().mul(&targets).add(
+        &targets
+            .sub_scalar(1.0)
+            .mul(&predictions.sub_scalar(1.0).add_scalar(epsilon).ln()),
+    ))
+    .mul_scalar(-1.0);
+    loss.mean()
 }
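
For reference, a minimal standalone sketch of the same computation on plain f64 slices, not using the crate's Tensor type; the function name and example values below are illustrative only, and predictions are assumed to be probabilities in (0, 1) with 0/1 targets:

/// Binary cross-entropy over plain slices, averaged across the batch.
/// epsilon keeps the logarithms finite when a prediction is exactly 0 or 1.
fn bce_reference(predictions: &[f64], targets: &[f64]) -> f64 {
    let epsilon = 1e-8;
    let n = predictions.len() as f64;
    predictions
        .iter()
        .zip(targets)
        .map(|(&p, &y)| -(y * (p + epsilon).ln() + (1.0 - y) * (1.0 - p + epsilon).ln()))
        .sum::<f64>()
        / n
}

fn main() {
    let predictions = [0.9, 0.2, 0.7];
    let targets = [1.0, 0.0, 1.0];
    println!("{}", bce_reference(&predictions, &targets)); // ≈ 0.228
}

Averaging with mean() rather than summing keeps the loss magnitude independent of batch size, which is why the new implementation ends with loss.mean() instead of the previous sum().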
