Skip to content

Commit

Permalink
Implement Kotlin-style DSL for DER encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
nodh committed Sep 21, 2023
1 parent ef330bb commit b23f104
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 96 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package at.asitplus.wallet.lib.asn1

import at.asitplus.wallet.lib.CryptoPublicKey
import at.asitplus.wallet.lib.jws.JwsAlgorithm
import at.asitplus.wallet.lib.jws.JwsExtensions.encodeToByteArray
import at.asitplus.wallet.lib.jws.TbsCertificate
import io.matthewnelson.encoding.base16.Base16
import io.matthewnelson.encoding.core.Decoder.Companion.decodeToByteArray
import kotlinx.datetime.Instant


fun tag(tag: Int, block: () -> ByteArray): ByteArray {
val value = block()
return byteArrayOf(tag.toByte()) + value.size.encodeLength() + value
}

fun long(block: () -> Long) = tag(0x02) { block().encodeToByteArray().dropWhile { it == 0.toByte() }.toByteArray() }

fun int(block: () -> Int) = tag(0x02) { block().encodeToByteArray().dropWhile { it == 0.toByte() }.toByteArray() }

fun bitString(block: () -> ByteArray) = tag(0x03) { (byteArrayOf(0x00) + block()) }

fun oid(block: () -> String): ByteArray = tag(0x06) { block().decodeToByteArray(Base16()) }

fun sequence(block: () -> List<ByteArray>) = tag(0x30) { block().fold(byteArrayOf()) { acc, bytes -> acc + bytes } }

fun set(block: () -> List<ByteArray>) = tag(0x31) { block().fold(byteArrayOf()) { acc, bytes -> acc + bytes } }

fun utf8String(block: () -> String) = tag(0x0c) { block().encodeToByteArray() }

fun commonName(block: () -> String) = oid { "550403" } + utf8String { block() }

fun subjectPublicKey(block: () -> CryptoPublicKey) = when (val value = block()) {
is CryptoPublicKey.Ec -> value.encodeToDer()
}

fun utcTime(block: () -> Instant): ByteArray {
val value = block()
val matchResult =
Regex("[0-9]{2}([0-9]{2})-([0-9]{2})-([0-9]{2})T([0-9]{2}):([0-9]{2}):([0-9]{2})\\.([0-9]+)Z")
.matchEntire(value.toString())
?: throw IllegalArgumentException("instant serialization failed: ${value}")
val year =
matchResult.groups[1]?.value
?: throw IllegalArgumentException("instant serialization year failed: ${value}")
val month =
matchResult.groups[2]?.value
?: throw IllegalArgumentException("instant serialization month failed: ${value}")
val day =
matchResult.groups[3]?.value ?: throw IllegalArgumentException("instant serialization day failed: ${value}")
val hour =
matchResult.groups[4]?.value
?: throw IllegalArgumentException("instant serialization hour failed: ${value}")
val minute =
matchResult.groups[5]?.value
?: throw IllegalArgumentException("instant serialization minute failed: ${value}")
val seconds =
matchResult.groups[6]?.value
?: throw IllegalArgumentException("instant serialization seconds failed: ${value}")
val string = "$year$month$day$hour$minute${seconds}Z"
return tag(0x17) { string.encodeToByteArray() }
}

fun tbsCertificate(block: () -> TbsCertificate) = block().encodeToDer()

fun sigAlg(block: () -> JwsAlgorithm): ByteArray = when (val value = block()) {
JwsAlgorithm.ES256 -> sequence { listOf(oid { "2A8648CE3D040302" }) }
else -> throw IllegalArgumentException("sigAlg: $value")
}

private fun CryptoPublicKey.Ec.encodeToDer(): ByteArray {
val ecKeyTag = oid { "2A8648CE3D0201" }
val ecEncryptionNullTag = oid { "2A8648CE3D030107" }
val content = bitString { (byteArrayOf(0x04.toByte()) + x + y) }
return sequence { listOf(sequence { listOf(ecKeyTag, ecEncryptionNullTag) }, content) }
}

private fun Int.encodeLength(): ByteArray {
if (this < 128) {
return byteArrayOf(this.toByte())
}
if (this < 0x100) {
return byteArrayOf(0x81.toByte(), this.toByte())
}
if (this < 0x8000) {
return byteArrayOf(0x82.toByte(), (this ushr 8).toByte(), this.toByte())
}
throw IllegalArgumentException("length $this")
}
Original file line number Diff line number Diff line change
@@ -1,42 +1,80 @@
package at.asitplus.wallet.lib.jws

import at.asitplus.wallet.lib.CryptoPublicKey
import at.asitplus.wallet.lib.jws.JwsExtensions.encodeToByteArray
import io.matthewnelson.encoding.base16.Base16
import io.matthewnelson.encoding.core.Decoder.Companion.decodeToByteArray
import at.asitplus.wallet.lib.asn1.bitString
import at.asitplus.wallet.lib.asn1.commonName
import at.asitplus.wallet.lib.asn1.int
import at.asitplus.wallet.lib.asn1.long
import at.asitplus.wallet.lib.asn1.sequence
import at.asitplus.wallet.lib.asn1.set
import at.asitplus.wallet.lib.asn1.sigAlg
import at.asitplus.wallet.lib.asn1.subjectPublicKey
import at.asitplus.wallet.lib.asn1.tag
import at.asitplus.wallet.lib.asn1.tbsCertificate
import at.asitplus.wallet.lib.asn1.utcTime
import kotlinx.datetime.Instant

/**
* Very simple implementation of the meat of an X.509 Certificate:
* The structure that gets signed
*/
data class TbsCertificate(
val version: Int = 2,
val serialNumber: Long,
val signatureAlgorithm: JwsAlgorithm,
val issuer: String,
val issuerCommonName: String,
val validFrom: Instant,
val validUntil: Instant,
val subject: String,
val subjectPublicKey: CryptoPublicKey
val subjectCommonName: String,
val publicKey: CryptoPublicKey
) {
fun encodeToDer(): ByteArray {
return (version.encodeAsVersion() +
serialNumber.encodeToDer() +
signatureAlgorithm.encodeToDer() +
issuer.encodeAsCommonName() +
(validFrom.encodeToDer() + validUntil.encodeToDer()).sequence() +
subject.encodeAsCommonName() +
subjectPublicKey.encodeToDer())
.sequence()
fun encodeToDer() = sequence {
listOf(
tag(0xA0) {
int { version }
},
long { serialNumber },
sigAlg { signatureAlgorithm },
sequence {
listOf(set {
listOf(sequence {
listOf(commonName { issuerCommonName })
})
})
},
sequence {
listOf(
utcTime { validFrom },
utcTime { validUntil }
)
},
sequence {
listOf(set {
listOf(sequence {
listOf(commonName { subjectCommonName })
})
})
},
subjectPublicKey { publicKey }
)
}

}

/**
* Very simple implementation of an X.509 Certificate
*/
data class X509Certificate(
val tbsCertificate: TbsCertificate,
val signatureAlgorithm: JwsAlgorithm,
val signature: ByteArray
) {
fun encodeToDer(): ByteArray {
return (tbsCertificate.encodeToDer() +
signatureAlgorithm.encodeToDer() +
signature.encodeAsBitString()).sequence()
fun encodeToDer() = sequence {
listOf(
tbsCertificate { tbsCertificate },
sigAlg { signatureAlgorithm },
bitString { signature }
)
}

override fun equals(other: Any?): Boolean {
Expand All @@ -60,74 +98,3 @@ data class X509Certificate(
}
}

private fun String.encodeAsCommonName(): ByteArray {
return ("550403".decodeToByteArray(Base16()).oid() + this.encodeToDer()).sequence().set().sequence()
}

private fun Int.encodeAsVersion(): ByteArray = encodeToDer().wrapInAsn1Tag(0xA0.toByte())

private fun Int.encodeToDer(): ByteArray =
encodeToByteArray().dropWhile { it == 0.toByte() }.toByteArray().wrapInAsn1Tag(0x02)

private fun Long.encodeToDer(): ByteArray =
encodeToByteArray().dropWhile { it == 0.toByte() }.toByteArray().wrapInAsn1Tag(0x02)

private fun CryptoPublicKey.encodeToDer(): ByteArray = when (this) {
is CryptoPublicKey.Ec -> this.encodeToDer()
}

private fun CryptoPublicKey.Ec.encodeToDer(): ByteArray {
val ecKeyTag = "2A8648CE3D0201".decodeToByteArray(Base16()).oid()
val ecEncryptionNullTag = "2A8648CE3D030107".decodeToByteArray(Base16()).oid()
val content = (byteArrayOf(0x04.toByte()) + x + y).encodeAsBitString()
return ((ecKeyTag + ecEncryptionNullTag).sequence() + content).sequence()
}

private fun ByteArray.encodeAsBitString(): ByteArray = (byteArrayOf(0x00) + this).wrapInAsn1Tag(0x03)

private fun String.encodeToDer(): ByteArray = this.encodeToByteArray().wrapInAsn1Tag(0x0c)

private fun Instant.encodeToDer(): ByteArray {
val matchResult =
Regex("[0-9]{2}([0-9]{2})-([0-9]{2})-([0-9]{2})T([0-9]{2}):([0-9]{2}):([0-9]{2})\\.([0-9]+)Z")
.matchEntire(toString())
?: throw IllegalArgumentException("instant serialization failed: $this")
val year = matchResult.groups[1]?.value ?: throw IllegalArgumentException("instant serialization year failed: $this")
val month = matchResult.groups[2]?.value ?: throw IllegalArgumentException("instant serialization month failed: $this")
val day = matchResult.groups[3]?.value ?: throw IllegalArgumentException("instant serialization day failed: $this")
val hour = matchResult.groups[4]?.value ?: throw IllegalArgumentException("instant serialization hour failed: $this")
val minute = matchResult.groups[5]?.value ?: throw IllegalArgumentException("instant serialization minute failed: $this")
val seconds = matchResult.groups[6]?.value ?: throw IllegalArgumentException("instant serialization seconds failed: $this")
val string = "$year$month$day$hour$minute${seconds}Z"
return string.encodeToByteArray().wrapInAsn1Tag(0x17)
}

private fun JwsAlgorithm.encodeToDer(): ByteArray {
return when (this) {
JwsAlgorithm.ES256 -> "2A8648CE3D040302".decodeToByteArray(Base16()).oid().sequence()
else -> TODO()
}
}

private fun ByteArray.sequence() = this.wrapInAsn1Tag(0x30)

private fun ByteArray.set() = this.wrapInAsn1Tag(0x31)

private fun ByteArray.oid() = this.wrapInAsn1Tag(0x06)

private fun ByteArray.wrapInAsn1Tag(tag: Byte): ByteArray {
return byteArrayOf(tag) + this.size.encodeLength() + this
}

private fun Int.encodeLength(): ByteArray {
if (this < 128) {
return byteArrayOf(this.toByte())
}
if (this < 0x100) {
return byteArrayOf(0x81.toByte(), this.toByte())
}
if (this < 0x8000) {
return byteArrayOf(0x82.toByte(), (this ushr 8).toByte(), this.toByte())
}
throw IllegalArgumentException("length $this")
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import kotlinx.cinterop.allocArrayOf
import kotlinx.cinterop.get
import kotlinx.cinterop.memScoped
import kotlinx.cinterop.reinterpret
import kotlinx.coroutines.runBlocking
import kotlinx.datetime.Clock
import kotlinx.datetime.DateTimeUnit
import kotlinx.datetime.plus
Expand Down Expand Up @@ -73,9 +72,22 @@ actual class DefaultCryptoService : CryptoService {
val publicKeyData = SecKeyCopyExternalRepresentation(publicKey, null)
val data = CFBridgingRelease(publicKeyData) as NSData
this.cryptoPublicKey = CryptoPublicKey.Ec.fromAnsiX963Bytes(EcCurve.SECP_256_R_1, data.toByteArray())!!
val tbsCertificate = TbsCertificate(version = 2, serialNumber = 3, signatureAlgorithm = JwsAlgorithm.ES256, issuer = "SelfSigned", validFrom = Clock.System.now(), validUntil = Clock.System.now().plus(10, DateTimeUnit.MINUTE), subject = "SelfSigned", subjectPublicKey = cryptoPublicKey)
val tbsCertificate = TbsCertificate(
version = 2,
serialNumber = 3,
signatureAlgorithm = JwsAlgorithm.ES256,
issuerCommonName = "SelfSigned",
validFrom = Clock.System.now(),
validUntil = Clock.System.now().plus(10, DateTimeUnit.MINUTE),
subjectCommonName = "SelfSigned",
publicKey = cryptoPublicKey
)
val signature = signInt(tbsCertificate.encodeToDer())
this.certificate = X509Certificate(tbsCertificate = tbsCertificate, signatureAlgorithm = JwsAlgorithm.ES256, signature = signature).encodeToDer()
this.certificate = X509Certificate(
tbsCertificate = tbsCertificate,
signatureAlgorithm = JwsAlgorithm.ES256,
signature = signature
).encodeToDer()
}

private fun signInt(input: ByteArray): ByteArray {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ class X509CertificateJvmTest : FreeSpec({
val tbsCertificate = TbsCertificate(
version = 2,
serialNumber = serialNumber.toLong(),
issuer = commonName,
issuerCommonName = commonName,
validFrom = notBeforeDate.toInstant().toKotlinInstant(),
validUntil = notAfterDate.toInstant().toKotlinInstant(),
signatureAlgorithm = signatureAlgorithm,
subject = commonName,
subjectPublicKey = cryptoPublicKey
subjectCommonName = commonName,
publicKey = cryptoPublicKey
)
val signed = Signature.getInstance(signatureAlgorithm.jcaName).apply {
initSign(keyPair.private)
Expand Down

0 comments on commit b23f104

Please sign in to comment.