Skip to content

Commit

Permalink
dual numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
Peva Blanchard committed Jul 19, 2023
1 parent 74da549 commit 358ebca
Show file tree
Hide file tree
Showing 3 changed files with 448 additions and 0 deletions.
2 changes: 2 additions & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ dependencies {
builtBy("generateEmissionFactors30")
})
implementation("org.ejml:ejml-simple:0.43")
implementation("ai.djl:api:0.23.0")
implementation("ai.djl.pytorch:pytorch-engine:0.23.0")

val arrowVersion = "1.1.5"
val olcaSimaproVersion = "3.0.5"
Expand Down
130 changes: 130 additions & 0 deletions src/main/kotlin/ch/kleis/lcaplugin/core/lang/math/DualNumber.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package ch.kleis.lcaplugin.core.lang.math

import ai.djl.ndarray.NDArray
import ai.djl.ndarray.NDManager
import ai.djl.ndarray.types.DataType
import ai.djl.ndarray.types.Shape
import kotlin.math.pow


data class DualNumber(
val zeroth: Double,
val first: NDArray,
) {
override fun toString(): String {
return "$zeroth"
}

companion object {
private val manager = NDManager.newBaseManager()

fun constant(c: Double): DualNumber {
return constant(c, Shape())
}

fun constant(c: Double, shape: Shape): DualNumber {
return DualNumber(
c,
manager.zeros(shape, DataType.FLOAT64)
)
}

fun basis(dim: Int, index: Int): DualNumber {
return DualNumber(
0.0,
manager.create(IntRange(0, dim - 1).map {
if (index == it) 1.0 else 0.0
}.toDoubleArray())
)
}
}

/*
Arithmetic
*/

fun pow(e: Double): DualNumber {
return DualNumber(
this.zeroth.pow(e),
this.first.mul(e * this.zeroth.pow(e - 1.0)), // (x^e)' = e.x^(e-1).dx
)
}

operator fun plus(other: DualNumber): DualNumber {
return DualNumber(
this.zeroth + other.zeroth,
this.first.add(other.first),
)
}

operator fun minus(other: DualNumber): DualNumber {
return DualNumber(
this.zeroth - other.zeroth,
this.first.sub(other.first),
)
}

operator fun times(other: DualNumber): DualNumber {
return DualNumber(
this.zeroth * other.zeroth,
this.first.mul(other.zeroth).add(other.first.mul(zeroth)), // (a.b)' = a'.b + a.b'
)
}

operator fun div(other: DualNumber): DualNumber {
return DualNumber(
this.zeroth / other.zeroth,
this.first.div(other.zeroth)
.sub(other.first.mul(this.zeroth / other.zeroth.pow(2))), // (a/b)' = a'/b - a.b'/b^2
)
}

operator fun unaryMinus(): DualNumber {
TODO()
}

operator fun unaryPlus(): DualNumber {
TODO()
}

/*
Right Arithmetic with Double
*/

operator fun plus(other: Double): DualNumber {
return this.plus(constant(other, this.first.shape))
}

operator fun minus(other: Double): DualNumber {
return this.minus(constant(other, this.first.shape))
}

operator fun times(other: Double): DualNumber {
return this.times(constant(other, this.first.shape))
}

operator fun div(other: Double): DualNumber {
return this.div(constant(other, this.first.shape))
}
}

/*
Left Arithmetic with Double
*/

operator fun Double.plus(other: DualNumber): DualNumber {
return DualNumber.constant(this, other.first.shape).plus(other)
}

operator fun Double.minus(other: DualNumber): DualNumber {
return DualNumber.constant(this, other.first.shape).minus(other)
}

operator fun Double.times(other: DualNumber): DualNumber {
return DualNumber.constant(this, other.first.shape).times(other)
}

operator fun Double.div(other: DualNumber): DualNumber {
return DualNumber.constant(this, other.first.shape).div(other)
}

Loading

0 comments on commit 358ebca

Please sign in to comment.