From 501f46e0777c5149db8b94021a0307a5a2730b83 Mon Sep 17 00:00:00 2001 From: Peva Blanchard Date: Thu, 17 Aug 2023 19:04:17 +0200 Subject: [PATCH] introduced operations interface + basic impl --- .../kleis/lcaplugin/core/math/Operations.kt | 43 ++++ .../lcaplugin/core/math/basic/BasicMatrix.kt | 5 + .../lcaplugin/core/math/basic/BasicNumber.kt | 3 + .../core/math/basic/BasicOperations.kt | 85 ++++++++ .../core/math/basic/BasicOperationsTest.kt | 183 ++++++++++++++++++ .../core/matrix/BasicMatrixFixture.kt | 23 +++ 6 files changed, 342 insertions(+) create mode 100644 src/main/kotlin/ch/kleis/lcaplugin/core/math/Operations.kt create mode 100644 src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicMatrix.kt create mode 100644 src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicNumber.kt create mode 100644 src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicOperations.kt create mode 100644 src/test/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicOperationsTest.kt create mode 100644 src/test/kotlin/ch/kleis/lcaplugin/core/matrix/BasicMatrixFixture.kt diff --git a/src/main/kotlin/ch/kleis/lcaplugin/core/math/Operations.kt b/src/main/kotlin/ch/kleis/lcaplugin/core/math/Operations.kt new file mode 100644 index 000000000..74a656a7c --- /dev/null +++ b/src/main/kotlin/ch/kleis/lcaplugin/core/math/Operations.kt @@ -0,0 +1,43 @@ +package ch.kleis.lcaplugin.core.math + +interface Operations { + /* + Quantities + */ + operator fun Q.plus(other: Q): Q + operator fun Q.minus(other: Q): Q + operator fun Q.times(other: Q): Q + operator fun Q.div(other: Q): Q + + fun pure(value: Double): Q + operator fun Q.plus(other: Double): Q = this + pure(other) + operator fun Q.minus(other: Double): Q = this - pure(other) + operator fun Q.times(other: Double): Q = this * pure(other) + operator fun Q.div(other: Double): Q = this / pure(other) + + /* + Matrices + */ + + fun zeros( + rowDim: Int, colDim: Int, + ): M + + fun M.rowDim(): Int + fun M.colDim(): Int + + fun M.negate(): M + fun M.transpose(): M + + operator fun M.get(row: Int, col: Int): Q + operator fun M.set(row: Int, col: Int, value: Q) + + fun M.matMul(other: M): M + fun M.matDiv(other: M): M? + fun M.matTransposeDiv(other: M): M? { + return this.transpose().matDiv(other.transpose())?.transpose() + } + fun M.add(row: Int, col: Int, value: Q) { + this[row, col] = this[row, col] + value + } +} diff --git a/src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicMatrix.kt b/src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicMatrix.kt new file mode 100644 index 000000000..f96ce1040 --- /dev/null +++ b/src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicMatrix.kt @@ -0,0 +1,5 @@ +package ch.kleis.lcaplugin.core.math.basic + +import org.ejml.simple.SimpleMatrix + +data class BasicMatrix(internal val inner: SimpleMatrix) diff --git a/src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicNumber.kt b/src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicNumber.kt new file mode 100644 index 000000000..8c7436889 --- /dev/null +++ b/src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicNumber.kt @@ -0,0 +1,3 @@ +package ch.kleis.lcaplugin.core.math.basic + +data class BasicNumber(val value: Double) diff --git a/src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicOperations.kt b/src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicOperations.kt new file mode 100644 index 000000000..c11055859 --- /dev/null +++ b/src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicOperations.kt @@ -0,0 +1,85 @@ +package ch.kleis.lcaplugin.core.math.basic + +import ch.kleis.lcaplugin.core.math.Operations +import com.intellij.openapi.diagnostic.Logger +import org.ejml.data.DMatrixSparseCSC +import org.ejml.data.MatrixType +import org.ejml.simple.SimpleMatrix +import org.ejml.sparse.csc.CommonOps_DSCC + +class BasicOperations : Operations { + companion object { + private val LOG = Logger.getInstance(BasicOperations::class.java) + } + + override fun BasicNumber.plus(other: BasicNumber): BasicNumber { + return BasicNumber(value + other.value) + } + + override fun BasicNumber.minus(other: BasicNumber): BasicNumber { + return BasicNumber(value - other.value) + } + + override fun BasicNumber.times(other: BasicNumber): BasicNumber { + return BasicNumber(value * other.value) + } + + override fun BasicNumber.div(other: BasicNumber): BasicNumber { + return BasicNumber(value / other.value) + } + + override fun BasicMatrix.matMul(other: BasicMatrix): BasicMatrix { + val result = DMatrixSparseCSC(this.rowDim(), other.colDim()) + CommonOps_DSCC.mult(inner.dscc, other.inner.dscc, result) + return BasicMatrix( + SimpleMatrix.wrap(result) + ) + } + + override fun BasicMatrix.matDiv(other: BasicMatrix): BasicMatrix? { + LOG.info("Start solving lhs(${other.rowDim()}, ${other.rowDim()}) * x = rhs(${this.rowDim()}, ${this.colDim()})") + val lhs = other.inner.dscc + val rhs = this.inner.dscc + val result = DMatrixSparseCSC(lhs.numCols, rhs.numCols) + return if (CommonOps_DSCC.solve(lhs, rhs, result)) { + LOG.info("End solving with result(${result.numRows}, ${result.numCols})") + BasicMatrix( + SimpleMatrix.wrap(result) + ) + } else { + null + } + } + + override fun pure(value: Double): BasicNumber { + return BasicNumber(value) + } + + override fun zeros(rowDim: Int, colDim: Int): BasicMatrix { + return BasicMatrix(SimpleMatrix(rowDim, colDim, MatrixType.DSCC)) + } + + override fun BasicMatrix.colDim(): Int { + return this.inner.numCols + } + + override fun BasicMatrix.rowDim(): Int { + return this.inner.numRows + } + + override fun BasicMatrix.set(row: Int, col: Int, value: BasicNumber) { + this.inner[row, col] = value.value + } + + override fun BasicMatrix.get(row: Int, col: Int): BasicNumber { + return pure(this.inner[row, col]) + } + + override fun BasicMatrix.transpose(): BasicMatrix { + return BasicMatrix(inner.transpose()) + } + + override fun BasicMatrix.negate(): BasicMatrix { + return BasicMatrix(inner.negative()) + } +} diff --git a/src/test/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicOperationsTest.kt b/src/test/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicOperationsTest.kt new file mode 100644 index 000000000..d5efb9467 --- /dev/null +++ b/src/test/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicOperationsTest.kt @@ -0,0 +1,183 @@ +package ch.kleis.lcaplugin.core.math.basic + +import ch.kleis.lcaplugin.core.matrix.BasicMatrixFixture +import org.junit.Assert +import org.junit.Test +import kotlin.test.assertEquals +import kotlin.test.assertNull + + +class BasicOperationsTest { + private val precision = 1e-6 + private val ops = BasicOperations() + + @Test + fun test_pure() { + with(ops) { + // given + val v = 1.0 + + // when + val actual = pure(v) + + // then + assertEquals(BasicNumber(1.0), actual) + } + } + + @Test + fun test_plus() { + with(ops) { + // given + val l = BasicNumber(1.0) + val r = BasicNumber(2.0) + + // when + val actual = l + r + + // then + assertEquals(BasicNumber(1.0 + 2.0), actual) + } + } + + @Test + fun test_minus() { + with(ops) { + // given + val l = BasicNumber(1.0) + val r = BasicNumber(2.0) + + // when + val actual = l - r + + // then + assertEquals(BasicNumber(1.0 - 2.0), actual) + } + } + + @Test + fun test_times() { + with(ops) { + // given + val l = BasicNumber(1.0) + val r = BasicNumber(2.0) + + // when + val actual = l * r + + // then + assertEquals(BasicNumber(1.0 * 2.0), actual) + } + } + + @Test + fun test_div() { + with(ops) { + // given + val l = BasicNumber(1.0) + val r = BasicNumber(2.0) + + // when + val actual = l / r + + // then + assertEquals(BasicNumber(1.0 / 2.0), actual) + } + } + + @Test + fun test_matDiv() { + with(ops) { + // given + val lhs = BasicMatrixFixture.make( + 2, 2, arrayOf( + 2.0, 0.0, + 0.0, 4.0, + ) + ) + val rhs = BasicMatrixFixture.make( + 2, 3, arrayOf( + 1.0, 0.0, 0.0, + 0.0, 2.0, 0.0, + ) + ) + + // when + val actual = rhs.matDiv(lhs)!! + + // then + assertBasicMatrixEqual( + actual, arrayOf( + 0.5, 0.0, 0.0, + 0.0, 0.5, 0.0, + ) + ) + } + } + + @Test + fun test_matDiv_whenZeroCols() { + with(ops) { + // given + val lhs = zeros(3, 0) + val rhs = BasicMatrixFixture.make( + 3, 1, arrayOf( + 1.0, + 2.0, + 3.0 + ) + ) + + // when + val actual = rhs.matDiv(lhs)!! + + // then + Assert.assertEquals(actual.rowDim(), 0) + Assert.assertEquals(actual.colDim(), 1) + } + } + + @Test + fun test_matDiv_whenNonInvertible() { + with(ops) { + // given + val lhs = BasicMatrixFixture.make( + 3, 3, arrayOf( + 1.0, -2.0, 0.0, + 0.0, 1.0, 0.0, + 0.0, 0.0, 0.0, + ) + ) + val rhs = BasicMatrixFixture.make( + 3, 2, arrayOf( + 1.0, 4.0, + 1.0, 2.0, + 1.0, 1.0, + ) + ) + + // when + val actual = rhs.matDiv(lhs) + + // then + assertNull(actual) + } + } + + + private fun assertBasicMatrixEqual(actual: BasicMatrix, expected: Array) { + with(ops) { + Assert.assertEquals(expected.size, actual.rowDim() * actual.colDim()) + for (row in 0 until actual.rowDim()) { + for (col in 0 until actual.colDim()) { + Assert.assertEquals( + "(row ${row}, col ${col}):", + expected[row * actual.colDim() + col], + actual[row, col].value, + precision + ) + } + } + } + } +} diff --git a/src/test/kotlin/ch/kleis/lcaplugin/core/matrix/BasicMatrixFixture.kt b/src/test/kotlin/ch/kleis/lcaplugin/core/matrix/BasicMatrixFixture.kt new file mode 100644 index 000000000..eb38fdffd --- /dev/null +++ b/src/test/kotlin/ch/kleis/lcaplugin/core/matrix/BasicMatrixFixture.kt @@ -0,0 +1,23 @@ +package ch.kleis.lcaplugin.core.matrix + +import ch.kleis.lcaplugin.core.math.basic.BasicMatrix +import ch.kleis.lcaplugin.core.math.basic.BasicOperations + +class BasicMatrixFixture { + companion object { + private val ops = BasicOperations() + fun make(rows: Int, cols: Int, data: Array): BasicMatrix { + with(ops) { + val a = zeros(rows, cols) + for (row in 0 until rows) { + for (col in 0 until cols) { + if(data[cols * row + col] != 0.0) { + a.add(row, col, pure(data[cols * row + col])) + } + } + } + return a + } + } + } +}