Skip to content

Commit

Permalink
Refactor ASN.1 DSL
Browse files Browse the repository at this point in the history
  • Loading branch information
JesusMcCloud authored and nodh committed Sep 22, 2023
1 parent 75e9a3f commit d2d4fd4
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,72 @@ import io.matthewnelson.encoding.base16.Base16
import io.matthewnelson.encoding.core.Decoder.Companion.decodeToByteArray
import kotlinx.datetime.Instant

class SequenceBuilder {

fun tag(tag: Int, block: () -> ByteArray): ByteArray {
val value = block()
return byteArrayOf(tag.toByte()) + value.size.encodeLength() + value
}
internal val elements = mutableListOf<ByteArray>()

fun long(block: () -> Long) = apply { elements += block().encodeToAsn1() }

fun bitString(block: () -> ByteArray) = apply { elements += block().encodeToBitString() }

fun oid(block: () -> String) = apply { elements += block().encodeToOid() }

fun utf8String(block: () -> String) = apply { elements += asn1Tag(0x0c, block().encodeToByteArray()) }

fun long(block: () -> Long) = tag(0x02) { block().encodeToByteArray().dropWhile { it == 0.toByte() }.toByteArray() }
fun version(block: () -> Int) = apply { elements += asn1Tag(0xA0, block().encodeToAsn1()) }

fun int(block: () -> Int) = tag(0x02) { block().encodeToByteArray().dropWhile { it == 0.toByte() }.toByteArray() }
fun commonName(block: () -> String) = apply {
oid { "550403" }
utf8String { block() }
}

fun subjectPublicKey(block: () -> CryptoPublicKey) = apply { elements += block().encodeToAsn1() }

fun bitString(block: () -> ByteArray) = tag(0x03) { (byteArrayOf(0x00) + block()) }
fun tbsCertificate(block: () -> TbsCertificate) = apply { elements += block().encodeToDer() }

fun oid(block: () -> String): ByteArray = tag(0x06) { block().decodeToByteArray(Base16()) }
fun sigAlg(block: () -> JwsAlgorithm) = apply { elements += block().encodeToAsn1() }

fun sequence(block: () -> List<ByteArray>) = tag(0x30) { block().fold(byteArrayOf()) { acc, bytes -> acc + bytes } }
fun utcTime(block: () -> Instant) = apply { elements += block().encodeToAsn1() }

fun set(block: () -> List<ByteArray>) = tag(0x31) { block().fold(byteArrayOf()) { acc, bytes -> acc + bytes } }
fun sequence(init: SequenceBuilder.() -> Unit) = apply {
val seq = SequenceBuilder()
seq.init()
elements += asn1Tag(0x30, seq.elements.fold(byteArrayOf()) { acc, bytes -> acc + bytes })
}

fun utf8String(block: () -> String) = tag(0x0c) { block().encodeToByteArray() }
fun set(init: SequenceBuilder.() -> Unit) = apply {
val seq = SequenceBuilder()
seq.init()
elements += asn1Tag(0x31, seq.elements.fold(byteArrayOf()) { acc, bytes -> acc + bytes })
}
}

fun commonName(block: () -> String) = oid { "550403" } + utf8String { block() }

fun subjectPublicKey(block: () -> CryptoPublicKey) = when (val value = block()) {
is CryptoPublicKey.Ec -> value.encodeToDer()
fun sequence(init: SequenceBuilder.() -> Unit): ByteArray {
val seq = SequenceBuilder()
seq.init()
return asn1Tag(0x30, seq.elements.fold(byteArrayOf()) { acc, bytes -> acc + bytes })
}

fun utcTime(block: () -> Instant): ByteArray {
val value = block()
private fun Int.encodeToAsn1() = asn1Tag(0x02, encodeToDer())

private fun Int.encodeToDer() = encodeToByteArray().dropWhile { it == 0.toByte() }.toByteArray()

private fun Long.encodeToAsn1() = asn1Tag(0x02, encodeToDer())

private fun Long.encodeToDer() = encodeToByteArray().dropWhile { it == 0.toByte() }.toByteArray()

private fun ByteArray.encodeToBitString() = asn1Tag(0x03, (byteArrayOf(0x00) + this))

private fun asn1Tag(tag: Int, value: ByteArray) = byteArrayOf(tag.toByte()) + value.size.encodeLength() + value

private fun String.encodeToOid() = asn1Tag(0x06, decodeToByteArray(Base16()))

private fun Instant.encodeToAsn1(): ByteArray {
val value = this.toString()
if (value.isEmpty()) return asn1Tag(0x17, byteArrayOf())
val matchResult = Regex("[0-9]{2}([0-9]{2})-([0-9]{2})-([0-9]{2})T([0-9]{2}):([0-9]{2}):([0-9]{2})")
.matchAt(value.toString(), 0)
.matchAt(value, 0)
?: throw IllegalArgumentException("instant serialization failed: ${value}")
val year = matchResult.groups[1]?.value
?: throw IllegalArgumentException("instant serialization year failed: ${value}")
Expand All @@ -51,21 +87,22 @@ fun utcTime(block: () -> Instant): ByteArray {
?: throw IllegalArgumentException("instant serialization minute failed: ${value}")
val seconds = matchResult.groups[6]?.value
?: throw IllegalArgumentException("instant serialization seconds failed: ${value}")
return tag(0x17) { "$year$month$day$hour$minute${seconds}Z".encodeToByteArray() }
return asn1Tag(0x17, "$year$month$day$hour$minute${seconds}Z".encodeToByteArray())
}

fun tbsCertificate(block: () -> TbsCertificate) = block().encodeToDer()

fun sigAlg(block: () -> JwsAlgorithm): ByteArray = when (val value = block()) {
JwsAlgorithm.ES256 -> sequence { listOf(oid { "2A8648CE3D040302" }) }
else -> throw IllegalArgumentException("sigAlg: $value")
private fun JwsAlgorithm.encodeToAsn1() = when (this) {
JwsAlgorithm.ES256 -> sequence { oid { "2A8648CE3D040302" } }
else -> throw IllegalArgumentException("sigAlg: $this")
}

private fun CryptoPublicKey.Ec.encodeToDer(): ByteArray {
val ecKeyTag = oid { "2A8648CE3D0201" }
val ecEncryptionNullTag = oid { "2A8648CE3D030107" }
val content = bitString { (byteArrayOf(0x04.toByte()) + x + y) }
return sequence { listOf(sequence { listOf(ecKeyTag, ecEncryptionNullTag) }, content) }
private fun CryptoPublicKey.encodeToAsn1() = when (this) {
is CryptoPublicKey.Ec -> sequence {
sequence {
oid { "2A8648CE3D0201" }
oid { "2A8648CE3D030107" }
}
bitString { (byteArrayOf(0x04.toByte()) + x + y) }
}
}

private fun Int.encodeLength(): ByteArray {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,7 @@
package at.asitplus.wallet.lib.jws

import at.asitplus.wallet.lib.CryptoPublicKey
import at.asitplus.wallet.lib.asn1.bitString
import at.asitplus.wallet.lib.asn1.commonName
import at.asitplus.wallet.lib.asn1.int
import at.asitplus.wallet.lib.asn1.long
import at.asitplus.wallet.lib.asn1.sequence
import at.asitplus.wallet.lib.asn1.set
import at.asitplus.wallet.lib.asn1.sigAlg
import at.asitplus.wallet.lib.asn1.subjectPublicKey
import at.asitplus.wallet.lib.asn1.tag
import at.asitplus.wallet.lib.asn1.tbsCertificate
import at.asitplus.wallet.lib.asn1.utcTime
import kotlinx.datetime.Instant

/**
Expand All @@ -29,36 +19,29 @@ data class TbsCertificate(
val publicKey: CryptoPublicKey
) {
fun encodeToDer() = sequence {
listOf(
tag(0xA0) {
int { version }
},
long { serialNumber },
sigAlg { signatureAlgorithm },
sequence {
listOf(set {
listOf(sequence {
listOf(commonName { issuerCommonName })
})
})
},
sequence {
listOf(
utcTime { validFrom },
utcTime { validUntil }
)
},
sequence {
listOf(set {
listOf(sequence {
listOf(commonName { subjectCommonName })
})
})
},
subjectPublicKey { publicKey }
)
version { version }
long { serialNumber }
sigAlg { signatureAlgorithm }
sequence {
set {
sequence {
commonName { issuerCommonName }
}
}
}
sequence {
utcTime { validFrom }
utcTime { validUntil }
}
sequence {
set {
sequence {
commonName { subjectCommonName }
}
}
}
subjectPublicKey { publicKey }
}

}

/**
Expand All @@ -70,11 +53,9 @@ data class X509Certificate(
val signature: ByteArray
) {
fun encodeToDer() = sequence {
listOf(
tbsCertificate { tbsCertificate },
sigAlg { signatureAlgorithm },
bitString { signature }
)
tbsCertificate { tbsCertificate }
sigAlg { signatureAlgorithm }
bitString { signature }
}

override fun equals(other: Any?): Boolean {
Expand Down

0 comments on commit d2d4fd4

Please sign in to comment.