From 173ad949385ae5733332a15b08245c7030b87a11 Mon Sep 17 00:00:00 2001 From: goncalo-frade-iohk <87179681+goncalo-frade-iohk@users.noreply.github.com> Date: Tue, 26 Sep 2023 10:13:08 +0100 Subject: [PATCH] feat: add mnemonic validation (#94) --- .../iohk/atala/prism/apollo/utils/Mnemonic.kt | 16 ++++++++-- .../atala/prism/apollo/utils/MnemonicTests.kt | 29 +++++++++++++++---- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/base-asymmetric-encryption/src/commonMain/kotlin/io/iohk/atala/prism/apollo/utils/Mnemonic.kt b/base-asymmetric-encryption/src/commonMain/kotlin/io/iohk/atala/prism/apollo/utils/Mnemonic.kt index 223cf5081..7c2344278 100644 --- a/base-asymmetric-encryption/src/commonMain/kotlin/io/iohk/atala/prism/apollo/utils/Mnemonic.kt +++ b/base-asymmetric-encryption/src/commonMain/kotlin/io/iohk/atala/prism/apollo/utils/Mnemonic.kt @@ -16,15 +16,25 @@ final class Mnemonic { private const val PBKDF2C = 2048 private const val PBKDF2DKLen = 64 + class InvalidMnemonicCode(code: String) : RuntimeException(code) + + fun isValidMnemonicCode(code: Array): Boolean { + return code.all { it in MnemonicCodeEnglish.wordList } + } fun createRandomMnemonics(): Array { val entropyBytes = SecureRandom.generateSeed(SEED_ENTROPY_BITS_24_WORDS / 8) return MnemonicCode(MnemonicCodeEnglish.wordList.toTypedArray()).toMnemonic(entropyBytes) } - fun createSeed(mnemonics: String, passphrase: String = "AtalaPrism"): ByteArray { + @Throws(InvalidMnemonicCode::class) + fun createSeed(mnemonics: Array, passphrase: String = "AtalaPrism"): ByteArray { + if (!isValidMnemonicCode(mnemonics)) { + throw InvalidMnemonicCode(mnemonics.joinToString(separator = " ")) + } + val mnemonicString = mnemonics.joinToString(separator = " ") return PBKDF2SHA512.derive( - mnemonics, + mnemonicString, passphrase, PBKDF2C, PBKDF2DKLen @@ -32,7 +42,7 @@ final class Mnemonic { } fun createRandomSeed(passphrase: String = "AtalaPrism"): ByteArray { - val mnemonics = this.createRandomMnemonics().joinToString(separator = " ") + val mnemonics = this.createRandomMnemonics() return this.createSeed(mnemonics, passphrase) } } diff --git a/base-asymmetric-encryption/src/commonTest/kotlin/io/iohk/atala/prism/apollo/utils/MnemonicTests.kt b/base-asymmetric-encryption/src/commonTest/kotlin/io/iohk/atala/prism/apollo/utils/MnemonicTests.kt index dbaa5cced..c6009110d 100644 --- a/base-asymmetric-encryption/src/commonTest/kotlin/io/iohk/atala/prism/apollo/utils/MnemonicTests.kt +++ b/base-asymmetric-encryption/src/commonTest/kotlin/io/iohk/atala/prism/apollo/utils/MnemonicTests.kt @@ -4,11 +4,20 @@ import io.iohk.atala.prism.apollo.hashing.internal.toHexString import kotlin.test.Test import kotlin.test.assertContains import kotlin.test.assertEquals +import kotlin.test.assertFailsWith +import kotlin.test.assertFalse class MnemonicTests { + + @Test + fun testValidateMnemonics() { + val invalidMnemonics = arrayOf("abc", "ddd", "inv") + assertFalse(Mnemonic.isValidMnemonicCode(invalidMnemonics)) + } + @Test fun testCreateRandomMnemonics() { - val mnemonics = Mnemonic.createRandomMnemonics().joinToString(separator = " ") + val mnemonics = Mnemonic.createRandomMnemonics() val seed = Mnemonic.createSeed(mnemonics) assertEquals(seed.size, 64) } @@ -21,24 +30,34 @@ class MnemonicTests { @Test fun testCreateSeed() { - val mnemonics = "random seed mnemonic words" + val mnemonics = arrayOf("adjust", "animal", "anger", "around") val seed = Mnemonic.createSeed(mnemonics) assertEquals(seed.size, 64) val privateKey = seed.slice(IntRange(0, 31)) - assertContains(privateKey.toByteArray().toHexString(), "feac83cecc84531575eb67250a03d8ac112d4d6678674968bf3f6576ad028ae3") + assertContains(privateKey.toByteArray().toHexString(), "a078d8a0f3beca52ef17a1d0279eb6e9c410cd3837d3db38d31e43df6da95ac6") + } + + @Test + fun testCreateSeedInvalidMnemonics() { + val mnemonics = arrayOf("abc", "ddd", "adsada", "testing") + + assertFailsWith { + Mnemonic.createSeed(mnemonics) + } } @Test fun testCreateSeedWithPW() { - val mnemonics = "random seed mnemonic words" + val mnemonics = arrayOf("adjust", "animal", "anger", "around") val password = "123456" val seed = Mnemonic.createSeed(mnemonics, password) assertEquals(seed.size, 64) val privateKey = seed.slice(IntRange(0, 31)) - assertContains(privateKey.toByteArray().toHexString(), "b3a8af66eca002e8b4ca868c5b55a8c865f15e0cfea483d6a164a6fbecf83625") + + assertContains(privateKey.toByteArray().toHexString(), "815b70655ca4c9675f5fc15fe8f82315f07521d034eec45bf4d5912bd3a61218") } }