Skip to content

Commit

Permalink
ASN.1 Overhaul:
Browse files Browse the repository at this point in the history
* weed out legacy stuff in Asn1Encoder
* better explicit tagging
* better recursive OCTET STRINGS
* more consitent API use
  • Loading branch information
JesusMcCloud committed Oct 21, 2023
1 parent f9c9305 commit b599996
Show file tree
Hide file tree
Showing 10 changed files with 129 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ sealed class CryptoPublicKey : Asn1Encodable<Asn1Sequence>, Identifiable {
override fun encodeToTlv() = when (this) {
is Ec -> asn1Sequence {
sequence {
oid { oid }
oid { curve.oid }
append(oid)
append(curve.oid)
}
bitString {
(byteArrayOf(
Expand All @@ -51,18 +51,10 @@ sealed class CryptoPublicKey : Asn1Encodable<Asn1Sequence>, Identifiable {
is Rsa -> {
asn1Sequence {
sequence {
oid { oid }
append(oid)
asn1null()
}
bitString(asn1Sequence {
append {
Asn1Primitive(
BERTags.INTEGER,
n.ensureSize(bits.number / 8u)
.let { if (it.first() == 0x00.toByte()) it else byteArrayOf(0x00, *it) })
}
int { e }
})
bitString { iosEncoded }
}
}
}
Expand Down Expand Up @@ -192,11 +184,11 @@ sealed class CryptoPublicKey : Asn1Encodable<Asn1Sequence>, Identifiable {
*/
@Transient
override val iosEncoded = asn1Sequence {
append {
append(
Asn1Primitive(BERTags.INTEGER,
n.ensureSize(bits.number / 8u)
.let { if (it.first() == 0x00.toByte()) it else byteArrayOf(0x00, *it) })
}
)
int { e }
}.derEncoded

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,27 @@ enum class JwsAlgorithm(val identifier: String, override val oid: ObjectIdentifi
}

override fun encodeToTlv() = when (this) {
ES256 -> asn1Sequence { oid { oid } }
ES384 -> asn1Sequence { oid { oid } }
ES512 -> asn1Sequence { oid { oid } }
ES256 -> asn1Sequence { append(oid) }
ES384 -> asn1Sequence { append(oid) }
ES512 -> asn1Sequence { append(oid) }

RS256 -> asn1Sequence {
oid { oid }
append(oid)
asn1null()
}

RS384 -> asn1Sequence {
oid { oid }
append(oid)
asn1null()
}

RS512 -> asn1Sequence {
oid { oid }
append(oid)
asn1null()
}

NON_JWS_SHA1_WITH_RSA -> asn1Sequence {
oid { oid }
append(oid)
asn1null()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ sealed class Asn1Structure(tag: UByte, children: List<Asn1Element>?) :
* @param tag the ASN.1 Tag to be used
* @param children the child nodes to be contained in this tag
*/
//TODO check if explicitly tagged
class Asn1Tagged(tag: UByte, children: List<Asn1Element>) : Asn1Structure(tag, children) {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import at.asitplus.crypto.datatypes.asn1.BERTags.INTEGER
import at.asitplus.crypto.datatypes.asn1.BERTags.NULL
import at.asitplus.crypto.datatypes.asn1.BERTags.OBJECT_IDENTIFIER
import at.asitplus.crypto.datatypes.asn1.BERTags.UTC_TIME
import at.asitplus.crypto.datatypes.asn1.DERTags.toExplicitTag
import io.matthewnelson.encoding.base16.Base16
import io.matthewnelson.encoding.core.Decoder.Companion.decodeToByteArray
import io.matthewnelson.encoding.core.Encoder.Companion.encodeToString
Expand Down Expand Up @@ -63,100 +64,118 @@ class Asn1TreeBuilder {
/**
* appends a single [Asn1Element] to this ASN.1 structure
*/
fun append(child: () -> Asn1Element) = apply { elements += child() }
fun append(child: Asn1Element) {
elements += child
}

/**
* EXPLICITLY tags a single [Asn1Element] and adds it to this ASN.1 structure
* appends a single [Asn1Encodable] to this ASN.1 structure
*/
fun tagged(tag: UByte, child: () -> Asn1Element) = apply { elements += Asn1Tagged(tag, child()) }
fun append(child: Asn1Encodable<*>) = append(child.encodeToTlv())

/**
* Adds a BOOL [Asn1Primitive] to this ASN.1 structure
* EXPLICITLY tags and encapsulates the result of [init]
* <br>
* **NOTE:** automatically calls [at.asitplus.crypto.datatypes.asn1.DERTags.toExplicitTag] on [tag]
*/
fun bool(block: () -> Boolean) = apply { elements += block().encodeToTlv() }
fun tagged(tag: UByte, init: Asn1TreeBuilder.() -> Unit) {
val seq = Asn1TreeBuilder()
seq.init()
elements += Asn1Tagged(tag.toExplicitTag(), seq.elements)
}

/**
* Adds an INTEGER [Asn1Primitive] to this ASN.1 structure
* Adds a BOOL [Asn1Primitive] to this ASN.1 structure
*/
fun int(block: () -> Int) = apply { elements += block().encodeToTlv() }
fun bool(block: () -> Boolean) {
elements += block().encodeToTlv()
}

/**
* Adds an INTEGER [Asn1Primitive] to this ASN.1 structure
*/
fun long(block: () -> Long) = apply { elements += block().encodeToTlv() }
fun int(block: () -> Int) {
elements += block().encodeToTlv()
}

/**
* Adds the passed bytes as OCTET STRING [Asn1Primitive] to this ASN.1 structure
* Adds an INTEGER [Asn1Primitive] to this ASN.1 structure
*/
fun octetString(block: () -> ByteArray) = apply { elements += block().encodeToTlvOctetString() }
fun long(block: () -> Long) {
elements += block().encodeToTlv()
}

/**
* Adds passed [Asn1Element] as OCTET STRING [Asn1Primitive] to this ASN.1 structure
* Adds the passed bytes as OCTET STRING [Asn1Element] to this ASN.1 structure
*/
fun octetString(child: Asn1Element) = apply { octetString(block = { child.derEncoded }) }
fun octetString(block: () -> ByteArray) = apply { elements += block().encodeToTlvOctetString() }

/**
* Adds the passed bytes as BIT STRING [Asn1Primitive] to this ASN.1 structure
*/
fun bitString(block: () -> ByteArray) = apply { elements += block().encodeToTlvBitString() }

/**
* Adds the passed [Asn1Element] as BIT STRING [Asn1Primitive] to this ASN.1 structure
*/
fun bitString(child: Asn1Element) = apply { bitString(block = { child.derEncoded }) }

/**
* Adds the passed [ObjectIdentifier] as OBJECT IDENTIFIER to this ASN.1 structure
*/
fun oid(block: () -> ObjectIdentifier) = apply { elements += block().encodeToTlv() }
fun bitString(block: () -> ByteArray) {
elements += block().encodeToTlvBitString()
}

/**
* Shorthand method taking a HEX representation of an OID value, adding it as an OBJECT IDENTIFIER to this ASN.1 structure.
* Really only useful for quick debugging against other ASN.1 decoders, such as https://lapo.it/asn1js/, so we're keeping it for now
*/
@Deprecated("Used only for quick debugging. May be removed in the future")
fun hexEncoded(block: () -> String) = apply { elements += block().encodeTolvOid() }

fun hexEncoded(block: () -> String) {
elements += block().encodeTolvOid()
}

/**
* Adds the passed string as UTF8 STRING to this ASN.1 structure
*/
fun utf8String(block: () -> String) = apply { elements += Asn1String.UTF8(block()).encodeToTlv() }
fun utf8String(block: () -> String) {
elements += Asn1String.UTF8(block()).encodeToTlv()
}

/**
* Adds the passed string as PRINTABLE STRING to this ASN.1 structure
*/
fun printableString(block: () -> String) = apply { elements += Asn1String.Printable(block()).encodeToTlv() }
fun printableString(block: () -> String) {
elements += Asn1String.Printable(block()).encodeToTlv()
}

/**
* Adds the passed [Asn1String] to this ASN.1 structure
*/
fun string(block: () -> Asn1String) = apply {
fun string(block: () -> Asn1String) {
val str = block()
append { str.encodeToTlv() }
str.encodeToTlv()
}


/**
* Adds a NULL to this ASN.1 structure
*/
fun asn1null() = apply { elements += Asn1Primitive(NULL, byteArrayOf()) }
fun asn1null() {
elements += Asn1Primitive(NULL, byteArrayOf())
}

/**
* Adds the passed instant as UTC TIME to this ASN.1 structure
*/
fun utcTime(block: () -> Instant) = apply { elements += block().encodeToAsn1UtcTime() }
fun utcTime(block: () -> Instant) {
elements += block().encodeToAsn1UtcTime()
}

/**
* Adds the passed instant as GENERALIZED TIME to this ASN.1 structure
*/
fun generalizedTime(block: () -> Instant) =
apply { elements += block().encodeToAsn1GeneralizedTime() }
fun generalizedTime(block: () -> Instant) {
elements += block().encodeToAsn1GeneralizedTime()
}

private fun nest(type: CollectionType, init: Asn1TreeBuilder.() -> Unit) = apply {
private fun nest(type: CollectionType, init: Asn1TreeBuilder.() -> Unit) {
val seq = Asn1TreeBuilder()
seq.init()
elements += if (type == CollectionType.SEQUENCE) Asn1Sequence(seq.elements) else Asn1Set(seq.elements.let {
elements += if (type == CollectionType.SEQUENCE) Asn1Sequence(seq.elements)
else if (type == CollectionType.OCTET_STRING) Asn1EncapsulatingOctetString(seq.elements)
else Asn1Set(seq.elements.let {
if (type == CollectionType.SET) it.sortedBy { it.tag }
else {
if (it.any { elem -> elem.tag != it.first().tag }) throw IllegalArgumentException("SET_OF must only contain elements of the same tag")
Expand All @@ -172,11 +191,11 @@ class Asn1TreeBuilder {
* ```kotlin
* set {
* sequence {
* setOf { //note: DER encoding enfoces sorting here, so the result switches those
* setOf { //note: DER encoding enforces sorting here, so the result switches those
* printableString { "World" }
* printableString { "Hello" }
* }
* set { //note: DER encoding enfoces sorting by tags, so the order changes in the ourput
* set { //note: DER encoding enforces sorting by tags, so the order changes in the output
* printableString { "World" }
* printableString { "Hello" }
* utf8String { "!!!" }
Expand All @@ -194,11 +213,11 @@ class Asn1TreeBuilder {
* ```kotlin
* set {
* sequence {
* setOf { //note: DER encoding enfoces sorting here, so the result switches those
* setOf { //note: DER encoding enforces sorting here, so the result switches those
* printableString { "World" }
* printableString { "Hello" }
* }
* set { //note: DER encoding enfoces sorting by tags, so the order changes in the ourput
* set { //note: DER encoding enforces sorting by tags, so the order changes in the output
* printableString { "World" }
* printableString { "Hello" }
* utf8String { "!!!" }
Expand All @@ -216,11 +235,11 @@ class Asn1TreeBuilder {
* ```kotlin
* set {
* sequence {
* setOf { //note: DER encoding enfoces sorting here, so the result switches those
* setOf { //note: DER encoding enforces sorting here, so the result switches those
* printableString { "World" }
* printableString { "Hello" }
* }
* set { //note: DER encoding enfoces sorting by tags, so the order changes in the ourput
* set { //note: DER encoding enforces sorting by tags, so the order changes in the output
* printableString { "World" }
* printableString { "Hello" }
* utf8String { "!!!" }
Expand All @@ -231,12 +250,31 @@ class Asn1TreeBuilder {
*/
fun setOf(init: Asn1TreeBuilder.() -> Unit) = nest(CollectionType.SET_OF, init)

/**
* OCTET STRING builder. The result of [init] is encapsulated into an ASN.1 OCTET STRING and then added to this ASN.1 structure
* ```kotlin
* set {
* octetString {
* printableString { "Hello" }
* printableString { "World" }
* sequence {
* printableString { "World" }
* printableString { "Hello" }
* utf8String { "!!!" }
* }
* }
* }
* ```
*/
fun octetStringEncapsulated(init: Asn1TreeBuilder.() -> Unit) = nest(CollectionType.OCTET_STRING, init)

}

private enum class CollectionType {
SET,
SEQUENCE,
SET_OF
SET_OF,
OCTET_STRING
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ sealed class DistinguishedName : Asn1Encodable<Asn1Set>, Identifiable {

override fun encodeToTlv() = asn1Set {
sequence {
oid { oid }
append { value }
append(oid)
append(value)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,19 @@ data class TbsCertificationRequest(
version: Int = 0,
attributes: List<Pkcs10CertificationRequestAttribute>? = null,
) : this(version, subjectName, publicKey, mutableListOf<Pkcs10CertificationRequestAttribute>().also { attrs ->
if(extensions.isEmpty()) throw IllegalArgumentException("No extensions provided!")
if (extensions.isEmpty()) throw IllegalArgumentException("No extensions provided!")
attributes?.let { attrs.addAll(it) }
attrs.add(Pkcs10CertificationRequestAttribute(KnownOIDs.extensionRequest, asn1Sequence {
extensions.forEach {
append { it.encodeToTlv() }
}
extensions.forEach { append(it) }
}))
})

override fun encodeToTlv() = asn1Sequence {
int { version }
sequence { subjectName.forEach { append { it.encodeToTlv() } } }
sequence { subjectName.forEach { append(it) } }
subjectPublicKey { publicKey }
append {
Asn1Tagged(0u.toExplicitTag(), attributes?.map { it.encodeToTlv() } ?: listOf())
tagged(0u) {
attributes?.map { append(it) }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ data class Pkcs10CertificationRequestAttribute(
constructor(id: ObjectIdentifier, value: Asn1Element) : this(id, listOf(value))

override fun encodeToTlv() = asn1Sequence {
oid { oid }
set { value.forEach { append { it } } }
append ( oid )
set { value.forEach { append(it) } }
}

override fun equals(other: Any?): Boolean {
Expand Down
Loading

0 comments on commit b599996

Please sign in to comment.