generated from JetBrains/intellij-platform-plugin-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
introduced operations interface + basic impl
- Loading branch information
Peva Blanchard
committed
Aug 17, 2023
1 parent
22c7928
commit 501f46e
Showing
6 changed files
with
342 additions
and
0 deletions.
There are no files selected for viewing
43 changes: 43 additions & 0 deletions
43
src/main/kotlin/ch/kleis/lcaplugin/core/math/Operations.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package ch.kleis.lcaplugin.core.math | ||
|
||
interface Operations<Q, M> { | ||
/* | ||
Quantities | ||
*/ | ||
operator fun Q.plus(other: Q): Q | ||
operator fun Q.minus(other: Q): Q | ||
operator fun Q.times(other: Q): Q | ||
operator fun Q.div(other: Q): Q | ||
|
||
fun pure(value: Double): Q | ||
operator fun Q.plus(other: Double): Q = this + pure(other) | ||
operator fun Q.minus(other: Double): Q = this - pure(other) | ||
operator fun Q.times(other: Double): Q = this * pure(other) | ||
operator fun Q.div(other: Double): Q = this / pure(other) | ||
|
||
/* | ||
Matrices | ||
*/ | ||
|
||
fun zeros( | ||
rowDim: Int, colDim: Int, | ||
): M | ||
|
||
fun M.rowDim(): Int | ||
fun M.colDim(): Int | ||
|
||
fun M.negate(): M | ||
fun M.transpose(): M | ||
|
||
operator fun M.get(row: Int, col: Int): Q | ||
operator fun M.set(row: Int, col: Int, value: Q) | ||
|
||
fun M.matMul(other: M): M | ||
fun M.matDiv(other: M): M? | ||
fun M.matTransposeDiv(other: M): M? { | ||
return this.transpose().matDiv(other.transpose())?.transpose() | ||
} | ||
fun M.add(row: Int, col: Int, value: Q) { | ||
this[row, col] = this[row, col] + value | ||
} | ||
} |
5 changes: 5 additions & 0 deletions
5
src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicMatrix.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
package ch.kleis.lcaplugin.core.math.basic | ||
|
||
import org.ejml.simple.SimpleMatrix | ||
|
||
data class BasicMatrix(internal val inner: SimpleMatrix) |
3 changes: 3 additions & 0 deletions
3
src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicNumber.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
package ch.kleis.lcaplugin.core.math.basic | ||
|
||
data class BasicNumber(val value: Double) |
85 changes: 85 additions & 0 deletions
85
src/main/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicOperations.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
package ch.kleis.lcaplugin.core.math.basic | ||
|
||
import ch.kleis.lcaplugin.core.math.Operations | ||
import com.intellij.openapi.diagnostic.Logger | ||
import org.ejml.data.DMatrixSparseCSC | ||
import org.ejml.data.MatrixType | ||
import org.ejml.simple.SimpleMatrix | ||
import org.ejml.sparse.csc.CommonOps_DSCC | ||
|
||
class BasicOperations : Operations<BasicNumber, BasicMatrix> { | ||
companion object { | ||
private val LOG = Logger.getInstance(BasicOperations::class.java) | ||
} | ||
|
||
override fun BasicNumber.plus(other: BasicNumber): BasicNumber { | ||
return BasicNumber(value + other.value) | ||
} | ||
|
||
override fun BasicNumber.minus(other: BasicNumber): BasicNumber { | ||
return BasicNumber(value - other.value) | ||
} | ||
|
||
override fun BasicNumber.times(other: BasicNumber): BasicNumber { | ||
return BasicNumber(value * other.value) | ||
} | ||
|
||
override fun BasicNumber.div(other: BasicNumber): BasicNumber { | ||
return BasicNumber(value / other.value) | ||
} | ||
|
||
override fun BasicMatrix.matMul(other: BasicMatrix): BasicMatrix { | ||
val result = DMatrixSparseCSC(this.rowDim(), other.colDim()) | ||
CommonOps_DSCC.mult(inner.dscc, other.inner.dscc, result) | ||
return BasicMatrix( | ||
SimpleMatrix.wrap(result) | ||
) | ||
} | ||
|
||
override fun BasicMatrix.matDiv(other: BasicMatrix): BasicMatrix? { | ||
LOG.info("Start solving lhs(${other.rowDim()}, ${other.rowDim()}) * x = rhs(${this.rowDim()}, ${this.colDim()})") | ||
val lhs = other.inner.dscc | ||
val rhs = this.inner.dscc | ||
val result = DMatrixSparseCSC(lhs.numCols, rhs.numCols) | ||
return if (CommonOps_DSCC.solve(lhs, rhs, result)) { | ||
LOG.info("End solving with result(${result.numRows}, ${result.numCols})") | ||
BasicMatrix( | ||
SimpleMatrix.wrap(result) | ||
) | ||
} else { | ||
null | ||
} | ||
} | ||
|
||
override fun pure(value: Double): BasicNumber { | ||
return BasicNumber(value) | ||
} | ||
|
||
override fun zeros(rowDim: Int, colDim: Int): BasicMatrix { | ||
return BasicMatrix(SimpleMatrix(rowDim, colDim, MatrixType.DSCC)) | ||
} | ||
|
||
override fun BasicMatrix.colDim(): Int { | ||
return this.inner.numCols | ||
} | ||
|
||
override fun BasicMatrix.rowDim(): Int { | ||
return this.inner.numRows | ||
} | ||
|
||
override fun BasicMatrix.set(row: Int, col: Int, value: BasicNumber) { | ||
this.inner[row, col] = value.value | ||
} | ||
|
||
override fun BasicMatrix.get(row: Int, col: Int): BasicNumber { | ||
return pure(this.inner[row, col]) | ||
} | ||
|
||
override fun BasicMatrix.transpose(): BasicMatrix { | ||
return BasicMatrix(inner.transpose()) | ||
} | ||
|
||
override fun BasicMatrix.negate(): BasicMatrix { | ||
return BasicMatrix(inner.negative()) | ||
} | ||
} |
183 changes: 183 additions & 0 deletions
183
src/test/kotlin/ch/kleis/lcaplugin/core/math/basic/BasicOperationsTest.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,183 @@ | ||
package ch.kleis.lcaplugin.core.math.basic | ||
|
||
import ch.kleis.lcaplugin.core.matrix.BasicMatrixFixture | ||
import org.junit.Assert | ||
import org.junit.Test | ||
import kotlin.test.assertEquals | ||
import kotlin.test.assertNull | ||
|
||
|
||
class BasicOperationsTest { | ||
private val precision = 1e-6 | ||
private val ops = BasicOperations() | ||
|
||
@Test | ||
fun test_pure() { | ||
with(ops) { | ||
// given | ||
val v = 1.0 | ||
|
||
// when | ||
val actual = pure(v) | ||
|
||
// then | ||
assertEquals(BasicNumber(1.0), actual) | ||
} | ||
} | ||
|
||
@Test | ||
fun test_plus() { | ||
with(ops) { | ||
// given | ||
val l = BasicNumber(1.0) | ||
val r = BasicNumber(2.0) | ||
|
||
// when | ||
val actual = l + r | ||
|
||
// then | ||
assertEquals(BasicNumber(1.0 + 2.0), actual) | ||
} | ||
} | ||
|
||
@Test | ||
fun test_minus() { | ||
with(ops) { | ||
// given | ||
val l = BasicNumber(1.0) | ||
val r = BasicNumber(2.0) | ||
|
||
// when | ||
val actual = l - r | ||
|
||
// then | ||
assertEquals(BasicNumber(1.0 - 2.0), actual) | ||
} | ||
} | ||
|
||
@Test | ||
fun test_times() { | ||
with(ops) { | ||
// given | ||
val l = BasicNumber(1.0) | ||
val r = BasicNumber(2.0) | ||
|
||
// when | ||
val actual = l * r | ||
|
||
// then | ||
assertEquals(BasicNumber(1.0 * 2.0), actual) | ||
} | ||
} | ||
|
||
@Test | ||
fun test_div() { | ||
with(ops) { | ||
// given | ||
val l = BasicNumber(1.0) | ||
val r = BasicNumber(2.0) | ||
|
||
// when | ||
val actual = l / r | ||
|
||
// then | ||
assertEquals(BasicNumber(1.0 / 2.0), actual) | ||
} | ||
} | ||
|
||
@Test | ||
fun test_matDiv() { | ||
with(ops) { | ||
// given | ||
val lhs = BasicMatrixFixture.make( | ||
2, 2, arrayOf( | ||
2.0, 0.0, | ||
0.0, 4.0, | ||
) | ||
) | ||
val rhs = BasicMatrixFixture.make( | ||
2, 3, arrayOf( | ||
1.0, 0.0, 0.0, | ||
0.0, 2.0, 0.0, | ||
) | ||
) | ||
|
||
// when | ||
val actual = rhs.matDiv(lhs)!! | ||
|
||
// then | ||
assertBasicMatrixEqual( | ||
actual, arrayOf( | ||
0.5, 0.0, 0.0, | ||
0.0, 0.5, 0.0, | ||
) | ||
) | ||
} | ||
} | ||
|
||
@Test | ||
fun test_matDiv_whenZeroCols() { | ||
with(ops) { | ||
// given | ||
val lhs = zeros(3, 0) | ||
val rhs = BasicMatrixFixture.make( | ||
3, 1, arrayOf( | ||
1.0, | ||
2.0, | ||
3.0 | ||
) | ||
) | ||
|
||
// when | ||
val actual = rhs.matDiv(lhs)!! | ||
|
||
// then | ||
Assert.assertEquals(actual.rowDim(), 0) | ||
Assert.assertEquals(actual.colDim(), 1) | ||
} | ||
} | ||
|
||
@Test | ||
fun test_matDiv_whenNonInvertible() { | ||
with(ops) { | ||
// given | ||
val lhs = BasicMatrixFixture.make( | ||
3, 3, arrayOf( | ||
1.0, -2.0, 0.0, | ||
0.0, 1.0, 0.0, | ||
0.0, 0.0, 0.0, | ||
) | ||
) | ||
val rhs = BasicMatrixFixture.make( | ||
3, 2, arrayOf( | ||
1.0, 4.0, | ||
1.0, 2.0, | ||
1.0, 1.0, | ||
) | ||
) | ||
|
||
// when | ||
val actual = rhs.matDiv(lhs) | ||
|
||
// then | ||
assertNull(actual) | ||
} | ||
} | ||
|
||
|
||
private fun assertBasicMatrixEqual(actual: BasicMatrix, expected: Array<Double>) { | ||
with(ops) { | ||
Assert.assertEquals(expected.size, actual.rowDim() * actual.colDim()) | ||
for (row in 0 until actual.rowDim()) { | ||
for (col in 0 until actual.colDim()) { | ||
Assert.assertEquals( | ||
"(row ${row}, col ${col}):", | ||
expected[row * actual.colDim() + col], | ||
actual[row, col].value, | ||
precision | ||
) | ||
} | ||
} | ||
} | ||
} | ||
} |
23 changes: 23 additions & 0 deletions
23
src/test/kotlin/ch/kleis/lcaplugin/core/matrix/BasicMatrixFixture.kt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package ch.kleis.lcaplugin.core.matrix | ||
|
||
import ch.kleis.lcaplugin.core.math.basic.BasicMatrix | ||
import ch.kleis.lcaplugin.core.math.basic.BasicOperations | ||
|
||
class BasicMatrixFixture { | ||
companion object { | ||
private val ops = BasicOperations() | ||
fun make(rows: Int, cols: Int, data: Array<Double>): BasicMatrix { | ||
with(ops) { | ||
val a = zeros(rows, cols) | ||
for (row in 0 until rows) { | ||
for (col in 0 until cols) { | ||
if(data[cols * row + col] != 0.0) { | ||
a.add(row, col, pure(data[cols * row + col])) | ||
} | ||
} | ||
} | ||
return a | ||
} | ||
} | ||
} | ||
} |