From de311d1bb5f55c2ec44bddcb59cade48beae687e Mon Sep 17 00:00:00 2001 From: Jens Pots Date: Thu, 4 Jul 2024 16:10:16 +0200 Subject: [PATCH] feat: type safe argument parsing using reflection --- src/main/kotlin/runner/jvm/Arguments.kt | 130 ++++++++++++++++++++ src/test/kotlin/runner/jvm/ArgumentsTest.kt | 125 +++++++++++++++++++ 2 files changed, 255 insertions(+) create mode 100644 src/main/kotlin/runner/jvm/Arguments.kt create mode 100644 src/test/kotlin/runner/jvm/ArgumentsTest.kt diff --git a/src/main/kotlin/runner/jvm/Arguments.kt b/src/main/kotlin/runner/jvm/Arguments.kt new file mode 100644 index 0000000..46a8c35 --- /dev/null +++ b/src/main/kotlin/runner/jvm/Arguments.kt @@ -0,0 +1,130 @@ +package technology.idlab.runner.jvm + +import kotlin.reflect.KType +import kotlin.reflect.full.isSubclassOf +import kotlin.reflect.full.isSuperclassOf +import kotlin.reflect.jvm.jvmErasure +import kotlin.reflect.typeOf +import technology.idlab.util.Log + +/** + * Recursively check if a value corresponds to a given KType. This function can be run either + * strictly or loosely. This is due to type erasure, which means we must cast the value to the KType + * and deal with type parameters manually. For example, for Pair, we call the function + * recursively for both the `first` as `second` data field. For List, we call the function for + * each element of the list. If a given container is not supported in that fashion, the result will + * depend on the strict parameter. + */ +fun safeCast(to: KType, from: Any, strict: Boolean = false): Boolean { + // The base case, where the Any type matches everything. + if (to.jvmErasure == Any::class) { + return true + } + + // The requested type must be an actual superclass of the value given. + if (!to.jvmErasure.isSuperclassOf(from::class)) { + return false + } + + // Retrieve the type arguments. If these are empty, then we can safely assume that the type is + // cast correctly and safely. + val typeArguments = to.arguments + if (typeArguments.isEmpty()) { + return true + } + + // If the value is pair, check both first and second. + if (to.jvmErasure == Pair::class) { + if (!from::class.isSubclassOf(Pair::class)) { + return false + } + + // Extract pair and the type arguments. + @Suppress("UNCHECKED_CAST") val pair = from as Pair + val first = to.arguments[0].type!! + val second = to.arguments[1].type!! + + return safeCast(first, pair.first) && safeCast(second, pair.second) + } + + // If the value is a list, check all elements. + if (to.jvmErasure == List::class) { + if (!from::class.isSubclassOf(List::class)) { + return false + } + + // Extract values. + @Suppress("UNCHECKED_CAST") val list = from as List + val elementType = to.arguments[0].type!! + return list.all { safeCast(elementType, it, strict) } + } + + // We will never be able to exhaustively go over all types, due to type erasure. However, we're + // if the user is okay with non-strict type checking, we may end here. + return !strict +} + +data class Arguments( + val args: Map>, +) { + /** + * Get an argument in a type safe way. The type parameter, either inferred or explicitly given, + * will be used to recursively check the resulting type. Note that if you want to retrieve an + * argument with type T which has Argument.Count.REQUIRED, you can either request type T directly + * or the list with one element using the List type. + */ + inline operator fun get(name: String, strict: Boolean = false): T { + val type = typeOf() + + // Retrieve the value from the map. + val argumentList = + this.args[name] + ?: if (type.isMarkedNullable) { + return null as T + } else { + Log.shared.fatal("Argument $name is missing") + } + + // Special case: check if the type is not a list, because in that case, we would need to get the + // first element instead. + val arg = + if (T::class.isSuperclassOf(List::class)) { + argumentList + } else { + if (argumentList.size != 1) { + Log.shared.fatal("Cannot obtain single argument if there is not exactly one value.") + } + + argumentList[0] + } + + if (safeCast(type, arg, strict)) { + return arg as T + } else { + Log.shared.fatal("Could not parse $name to ${T::class.simpleName}") + } + } + + companion object { + /** + * Parse a (nested) map into type-safe arguments. This method calls itself recursively for all + * values which are maps as well. + */ + fun from(args: Map>): Arguments { + return Arguments( + args.mapValues { (_, list) -> + list.map { arg -> + if (arg::class.isSubclassOf(Map::class)) { + if (safeCast(typeOf>>(), arg)) { + @Suppress("UNCHECKED_CAST") Arguments.from(arg as Map>) + } else { + Log.shared.fatal("Cannot have raw maps in arguments.") + } + } else { + arg + } + } + }) + } + } +} diff --git a/src/test/kotlin/runner/jvm/ArgumentsTest.kt b/src/test/kotlin/runner/jvm/ArgumentsTest.kt new file mode 100644 index 0000000..3e712a8 --- /dev/null +++ b/src/test/kotlin/runner/jvm/ArgumentsTest.kt @@ -0,0 +1,125 @@ +package runner.jvm + +import kotlin.test.assertEquals +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertThrows +import technology.idlab.exception.RunnerException +import technology.idlab.runner.jvm.Arguments + +class ArgumentsTest { + @Test + fun single() { + val args = Arguments(mapOf("key" to listOf("value"))) + assertEquals("value", args.get("key")) + } + + @Test + fun notSingle() { + val args = Arguments(mapOf("key" to listOf("value1", "value2"))) + assertThrows { args.get("key") } + } + + @Test + fun singleList() { + val args = Arguments(mapOf("key" to listOf("value"))) + assertEquals(listOf("value"), args.get>("key")) + } + + @Test + fun longList() { + val args = Arguments(mapOf("key" to listOf("value1", "value2"))) + assertEquals(listOf("value1", "value2"), args.get>("key")) + } + + @Test + fun longListWrong() { + val args = Arguments(mapOf("key" to listOf("value1", "value2"))) + + assertThrows { args.get>("key", strict = true) } + } + + @Test + fun nullable() { + val args = Arguments(mapOf()) + assertEquals(null, args.get("key")) + } + + @Test + fun nonNullable() { + val args = Arguments(mapOf()) + assertThrows { args.get("key") } + } + + @Test + fun invalidCast() { + val args = Arguments(mapOf("key" to listOf("value"))) + assertThrows { args.get("key") } + } + + @Test + fun pairs() { + val args = Arguments(mapOf("first" to listOf(Pair(1, "a")), "second" to listOf(Pair(2, "b")))) + + // Get first pair correctly. + val first = args.get>("first") + assertEquals(1, first.first) + assertEquals("a", first.second) + + // Get second pair correctly, use operator syntax. + val second: Pair = args["second"] + assertEquals(2, second.first) + assertEquals("b", second.second) + + // Get first pair as a list. + val firstList = args.get>>("first") + assertEquals(1, firstList[0].first) + assertEquals("a", firstList[0].second) + + // Same for second, use operator syntax. + val secondList: List> = args["second"] + assertEquals(2, secondList[0].first) + assertEquals("b", secondList[0].second) + + // Attempt to get integer as double, in strict mode. + assertThrows { args.get>("first", strict = true) } + + // Attempt to get string as integer. + assertThrows { args.get>("first") } + } + + @Test + fun nested() { + val args = Arguments.from(mapOf("root" to listOf(mapOf("leaf" to listOf("Hello, World!"))))) + + val value = args.get("root").get("leaf") + assertEquals("Hello, World!", value) + } + + @Test + fun inheritance() { + // The base class. + open class A + + // The extended class. + open class B : A() + + // The extended, extended class. + class C : B() + + // Create three arguments, each with the lists. + val args = + Arguments(mapOf("a" to listOf(A(), A()), "b" to listOf(B(), B()), "c" to listOf(C(), C()))) + + assertEquals(2, args.get>("a", strict = true).size) + assertEquals(2, args.get>("b", strict = true).size) + assertEquals(2, args.get>("c", strict = true).size) + + assertThrows { args.get>("a", strict = true) } + assertEquals(2, args.get>("b", strict = true).size) + assertEquals(2, args.get>("c", strict = true).size) + + assertThrows { args.get>("a", strict = true) } + assertThrows { args.get>("b", strict = true) } + assertEquals(2, args.get>("c", strict = true).size) + } +}