Skip to content

Commit

Permalink
introduced operations interface + basic impl
Browse files Browse the repository at this point in the history
  • Loading branch information
Peva Blanchard committed Aug 17, 2023
1 parent 22c7928 commit 501f46e
Show file tree
Hide file tree
Showing 6 changed files with 342 additions and 0 deletions.
43 changes: 43 additions & 0 deletions src/main/kotlin/ch/kleis/lcaplugin/core/math/Operations.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package ch.kleis.lcaplugin.core.math

interface Operations<Q, M> {
/*
Quantities
*/
operator fun Q.plus(other: Q): Q
operator fun Q.minus(other: Q): Q
operator fun Q.times(other: Q): Q
operator fun Q.div(other: Q): Q

fun pure(value: Double): Q
operator fun Q.plus(other: Double): Q = this + pure(other)
operator fun Q.minus(other: Double): Q = this - pure(other)
operator fun Q.times(other: Double): Q = this * pure(other)
operator fun Q.div(other: Double): Q = this / pure(other)

/*
Matrices
*/

fun zeros(
rowDim: Int, colDim: Int,
): M

fun M.rowDim(): Int
fun M.colDim(): Int

fun M.negate(): M
fun M.transpose(): M

operator fun M.get(row: Int, col: Int): Q
operator fun M.set(row: Int, col: Int, value: Q)

fun M.matMul(other: M): M
fun M.matDiv(other: M): M?
fun M.matTransposeDiv(other: M): M? {
return this.transpose().matDiv(other.transpose())?.transpose()
}
fun M.add(row: Int, col: Int, value: Q) {
this[row, col] = this[row, col] + value
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package ch.kleis.lcaplugin.core.math.basic

import org.ejml.simple.SimpleMatrix

data class BasicMatrix(internal val inner: SimpleMatrix)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package ch.kleis.lcaplugin.core.math.basic

data class BasicNumber(val value: Double)
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package ch.kleis.lcaplugin.core.math.basic

import ch.kleis.lcaplugin.core.math.Operations
import com.intellij.openapi.diagnostic.Logger
import org.ejml.data.DMatrixSparseCSC
import org.ejml.data.MatrixType
import org.ejml.simple.SimpleMatrix
import org.ejml.sparse.csc.CommonOps_DSCC

class BasicOperations : Operations<BasicNumber, BasicMatrix> {
companion object {
private val LOG = Logger.getInstance(BasicOperations::class.java)
}

override fun BasicNumber.plus(other: BasicNumber): BasicNumber {
return BasicNumber(value + other.value)
}

override fun BasicNumber.minus(other: BasicNumber): BasicNumber {
return BasicNumber(value - other.value)
}

override fun BasicNumber.times(other: BasicNumber): BasicNumber {
return BasicNumber(value * other.value)
}

override fun BasicNumber.div(other: BasicNumber): BasicNumber {
return BasicNumber(value / other.value)
}

override fun BasicMatrix.matMul(other: BasicMatrix): BasicMatrix {
val result = DMatrixSparseCSC(this.rowDim(), other.colDim())
CommonOps_DSCC.mult(inner.dscc, other.inner.dscc, result)
return BasicMatrix(
SimpleMatrix.wrap(result)
)
}

override fun BasicMatrix.matDiv(other: BasicMatrix): BasicMatrix? {
LOG.info("Start solving lhs(${other.rowDim()}, ${other.rowDim()}) * x = rhs(${this.rowDim()}, ${this.colDim()})")
val lhs = other.inner.dscc
val rhs = this.inner.dscc
val result = DMatrixSparseCSC(lhs.numCols, rhs.numCols)
return if (CommonOps_DSCC.solve(lhs, rhs, result)) {
LOG.info("End solving with result(${result.numRows}, ${result.numCols})")
BasicMatrix(
SimpleMatrix.wrap(result)
)
} else {
null
}
}

override fun pure(value: Double): BasicNumber {
return BasicNumber(value)
}

override fun zeros(rowDim: Int, colDim: Int): BasicMatrix {
return BasicMatrix(SimpleMatrix(rowDim, colDim, MatrixType.DSCC))
}

override fun BasicMatrix.colDim(): Int {
return this.inner.numCols
}

override fun BasicMatrix.rowDim(): Int {
return this.inner.numRows
}

override fun BasicMatrix.set(row: Int, col: Int, value: BasicNumber) {
this.inner[row, col] = value.value
}

override fun BasicMatrix.get(row: Int, col: Int): BasicNumber {
return pure(this.inner[row, col])
}

override fun BasicMatrix.transpose(): BasicMatrix {
return BasicMatrix(inner.transpose())
}

override fun BasicMatrix.negate(): BasicMatrix {
return BasicMatrix(inner.negative())
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package ch.kleis.lcaplugin.core.math.basic

import ch.kleis.lcaplugin.core.matrix.BasicMatrixFixture
import org.junit.Assert
import org.junit.Test
import kotlin.test.assertEquals
import kotlin.test.assertNull


class BasicOperationsTest {
private val precision = 1e-6
private val ops = BasicOperations()

@Test
fun test_pure() {
with(ops) {
// given
val v = 1.0

// when
val actual = pure(v)

// then
assertEquals(BasicNumber(1.0), actual)
}
}

@Test
fun test_plus() {
with(ops) {
// given
val l = BasicNumber(1.0)
val r = BasicNumber(2.0)

// when
val actual = l + r

// then
assertEquals(BasicNumber(1.0 + 2.0), actual)
}
}

@Test
fun test_minus() {
with(ops) {
// given
val l = BasicNumber(1.0)
val r = BasicNumber(2.0)

// when
val actual = l - r

// then
assertEquals(BasicNumber(1.0 - 2.0), actual)
}
}

@Test
fun test_times() {
with(ops) {
// given
val l = BasicNumber(1.0)
val r = BasicNumber(2.0)

// when
val actual = l * r

// then
assertEquals(BasicNumber(1.0 * 2.0), actual)
}
}

@Test
fun test_div() {
with(ops) {
// given
val l = BasicNumber(1.0)
val r = BasicNumber(2.0)

// when
val actual = l / r

// then
assertEquals(BasicNumber(1.0 / 2.0), actual)
}
}

@Test
fun test_matDiv() {
with(ops) {
// given
val lhs = BasicMatrixFixture.make(
2, 2, arrayOf(
2.0, 0.0,
0.0, 4.0,
)
)
val rhs = BasicMatrixFixture.make(
2, 3, arrayOf(
1.0, 0.0, 0.0,
0.0, 2.0, 0.0,
)
)

// when
val actual = rhs.matDiv(lhs)!!

// then
assertBasicMatrixEqual(
actual, arrayOf(
0.5, 0.0, 0.0,
0.0, 0.5, 0.0,
)
)
}
}

@Test
fun test_matDiv_whenZeroCols() {
with(ops) {
// given
val lhs = zeros(3, 0)
val rhs = BasicMatrixFixture.make(
3, 1, arrayOf(
1.0,
2.0,
3.0
)
)

// when
val actual = rhs.matDiv(lhs)!!

// then
Assert.assertEquals(actual.rowDim(), 0)
Assert.assertEquals(actual.colDim(), 1)
}
}

@Test
fun test_matDiv_whenNonInvertible() {
with(ops) {
// given
val lhs = BasicMatrixFixture.make(
3, 3, arrayOf(
1.0, -2.0, 0.0,
0.0, 1.0, 0.0,
0.0, 0.0, 0.0,
)
)
val rhs = BasicMatrixFixture.make(
3, 2, arrayOf(
1.0, 4.0,
1.0, 2.0,
1.0, 1.0,
)
)

// when
val actual = rhs.matDiv(lhs)

// then
assertNull(actual)
}
}


private fun assertBasicMatrixEqual(actual: BasicMatrix, expected: Array<Double>) {
with(ops) {
Assert.assertEquals(expected.size, actual.rowDim() * actual.colDim())
for (row in 0 until actual.rowDim()) {
for (col in 0 until actual.colDim()) {
Assert.assertEquals(
"(row ${row}, col ${col}):",
expected[row * actual.colDim() + col],
actual[row, col].value,
precision
)
}
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package ch.kleis.lcaplugin.core.matrix

import ch.kleis.lcaplugin.core.math.basic.BasicMatrix
import ch.kleis.lcaplugin.core.math.basic.BasicOperations

class BasicMatrixFixture {
companion object {
private val ops = BasicOperations()
fun make(rows: Int, cols: Int, data: Array<Double>): BasicMatrix {
with(ops) {
val a = zeros(rows, cols)
for (row in 0 until rows) {
for (col in 0 until cols) {
if(data[cols * row + col] != 0.0) {
a.add(row, col, pure(data[cols * row + col]))
}
}
}
return a
}
}
}
}

0 comments on commit 501f46e

Please sign in to comment.