generated from JetBrains/intellij-platform-plugin-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Peva Blanchard
committed
Jul 20, 2023
1 parent
3736120
commit 3c2b192
Showing
10 changed files
with
644 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
package ch.kleis.lcaplugin.core.math | ||
|
||
import ch.kleis.lcaplugin.core.lang.evaluator.EvaluatorException | ||
import ch.kleis.lcaplugin.core.math.ejml.EJMLMatrix | ||
import ch.kleis.lcaplugin.core.math.ejml.EJMLMatrixFactory | ||
import org.ejml.data.DMatrixSparseCSC | ||
import org.jetbrains.kotlinx.multik.api.mk | ||
import org.jetbrains.kotlinx.multik.api.zeros | ||
import org.jetbrains.kotlinx.multik.ndarray.data.D1Array | ||
import org.jetbrains.kotlinx.multik.ndarray.data.D2Array | ||
import org.jetbrains.kotlinx.multik.ndarray.data.D3Array | ||
import org.jetbrains.kotlinx.multik.ndarray.data.set | ||
import org.jetbrains.kotlinx.multik.ndarray.operations.forEachMultiIndexed | ||
|
||
fun d1MatchDimensions(l: D1Array<Double>, r: D1Array<Double>): Pair<D1Array<Double>, D1Array<Double>> { | ||
return when { | ||
l.isEmpty() -> mk.zeros<Double>(r.size) to r | ||
r.isEmpty() -> l to mk.zeros(l.size) | ||
l.size != r.size -> throw EvaluatorException("d1arrays cannot be broadcast") | ||
else -> l to r | ||
} | ||
} | ||
|
||
fun d2Ejml(m: D2Array<Double>): EJMLMatrix { | ||
val (rows, cols) = m.shape | ||
val r = EJMLMatrixFactory.INSTANCE.zero(rows, cols) | ||
m.forEachMultiIndexed { index, d -> | ||
val (i, j) = index | ||
r.set(i, j, d) | ||
} | ||
return r | ||
} | ||
|
||
fun d3Ejml(m: D3Array<Double>): EJMLMatrix { | ||
val (rows, cols, nps) = m.shape | ||
val nParams = if (nps == 0) 1 else nps | ||
val r = EJMLMatrixFactory.INSTANCE.zero(rows, cols * nParams) | ||
m.forEachMultiIndexed { index, d -> | ||
val (i, j, k) = index | ||
r.set(i, j * nParams + k, d) | ||
} | ||
return r | ||
} | ||
|
||
fun ejmlD2(m: EJMLMatrix): D2Array<Double> { | ||
val (rows, cols) = Pair(m.rowDim(), m.colDim()) | ||
val r = mk.zeros<Double>(rows, cols) | ||
val iterator = m.matrix.getMatrix<DMatrixSparseCSC>().createCoordinateIterator() | ||
while (iterator.hasNext()) { | ||
val v = iterator.next() | ||
r[v.row, v.col] = v.value | ||
} | ||
return r | ||
} | ||
|
||
fun ejmlD3(m: EJMLMatrix, nParams: Int): D3Array<Double> { | ||
val (rows, cols) = Pair(m.rowDim(), m.colDim()) | ||
val r = mk.zeros<Double>(rows, cols, nParams) | ||
if (nParams == 0) return r | ||
val iterator = m.matrix.getMatrix<DMatrixSparseCSC>().createCoordinateIterator() | ||
while (iterator.hasNext()) { | ||
val v = iterator.next() | ||
val i = v.row | ||
val j = v.col / nParams | ||
val k = v.col % nParams | ||
r[i, j, k] = v.value | ||
} | ||
return r | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
92 changes: 92 additions & 0 deletions
92
src/test/kotlin/ch/kleis/lcaplugin/core/math/DualMatrixTest.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
package ch.kleis.lcaplugin.core.math | ||
|
||
import ch.kleis.lcaplugin.core.math.MatrixFixture.Companion.makeDualMatrix | ||
import org.junit.Test | ||
import kotlin.test.assertEquals | ||
|
||
|
||
class DualMatrixTest { | ||
private val dx = DualNumber.basis(2, 0) | ||
private val dy = DualNumber.basis(2, 1) | ||
private fun c(d: Double): DualNumber { | ||
return DualNumber.constant(d) | ||
} | ||
|
||
@Test | ||
fun value() { | ||
// given | ||
val m = makeDualMatrix(2, 3, 2, | ||
arrayOf( | ||
c(1.0), c(2.0) + dx, c(3.0), | ||
c(4.0), c(5.0), c(6.0), | ||
) | ||
) | ||
|
||
// when | ||
val actual = m.value(0, 1, ) | ||
|
||
// then | ||
val expected = c(2.0) + dx | ||
assertEquals(expected, actual) | ||
} | ||
|
||
@Test | ||
fun set() { | ||
// given | ||
val m = makeDualMatrix(2, 3, 2, | ||
arrayOf( | ||
c(1.0), c(2.0), c(3.0), | ||
c(4.0), c(5.0), c(6.0), | ||
) | ||
) | ||
|
||
// when | ||
m.set(0, 1, c(3.0) + dx) | ||
|
||
// then | ||
val expected = makeDualMatrix(2, 3, 2, | ||
arrayOf( | ||
c(1.0), c(3.0) + dx, c(3.0), | ||
c(4.0), c(5.0), c(6.0), | ||
) | ||
) | ||
assertEquals(expected, m) | ||
} | ||
|
||
|
||
@Test | ||
fun minus() { | ||
} | ||
|
||
@Test | ||
fun plus() { | ||
} | ||
|
||
@Test | ||
fun matMul() { | ||
} | ||
|
||
@Test | ||
fun matDiv() { | ||
} | ||
|
||
@Test | ||
fun negate() { | ||
} | ||
|
||
@Test | ||
fun transpose() { | ||
} | ||
|
||
@Test | ||
fun rowDim() { | ||
} | ||
|
||
@Test | ||
fun colDim() { | ||
} | ||
|
||
@Test | ||
fun add() { | ||
} | ||
} |
Oops, something went wrong.