Skip to content

Commit

Permalink
improve tag
Browse files Browse the repository at this point in the history
  • Loading branch information
JesusMcCloud committed Oct 21, 2024
1 parent e0fa1ba commit 4784b65
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,20 @@ sealed class Asn1Element(

val tag by lazy { tlv.tag }

val derEncoded: ByteArray by lazy {
(children?.fold(Buffer()) { acc, extendedTlv -> acc.apply { write(extendedTlv.derEncoded) } }
?.let {
Buffer().apply { write(tlv.tag.encodedTag); encodeLength(it.size); it.transferTo(this) }
}?.readByteArray()
?: Buffer().apply { write(tlv.tag.encodedTag); write(encodedLength);write(tlv.content) }.readByteArray())

internal val derBuffered: Source by lazy {
children?.fold(Buffer()) { acc, tlv -> acc.apply { /*Yes, we copy!*/ tlv.derBuffered.peek().transferTo(this) } }
?.let { Buffer().apply { write(tlv.tag.encodedTag); encodeLength(it.size); it.transferTo(this) } }
?: Buffer().apply { write(tlv.tag.encodedTag); write(encodedLength); tlv.content.wrapInUnsafeSource().transferTo(this) }
}

/**
* DER-encoded representation of this ASN.1 element.
* Each time, this property is read, a new copy of the underlying DER-encoded representation is created.
* This enables non-destructive modification of DER-encoded representations, because a fresh copy can always be obtained.
*/
val derEncoded: ByteArray get() = derBuffered.peek().readByteArray()

override fun toString(): String = "(tag=${tlv.tag}" +
", length=${length}" +
", overallLength=${overallLength}" +
Expand Down Expand Up @@ -235,17 +241,19 @@ sealed class Asn1Element(

@Serializable
@ConsistentCopyVisibility
data class Tag private constructor(
data class Tag
/**
* This constructor performs no validations on the passed byte array!
*/
internal constructor(
val tagValue: ULong, val encodedTagLength: Int,
@Serializable(with = ByteArrayBase64Serializer::class) val encodedTag: ByteArray
) : Comparable<Tag> {
private constructor(values: Triple<ULong, Int, ByteArray>) : this(values.first, values.second, values.third)

//TODO this CTOR is internally called only using already validated inputs.
// We need another CTOR to prevent double-parsing and byte copying
constructor(derEncoded: ByteArray) : this(
derEncoded.wrapInUnsafeSource().decodeTag()
.let { Triple(it.first, it.second.size, derEncoded) }
derEncoded.wrapInUnsafeSource().decodeTag(collect = false)
.let { Triple(it.first, it.second, derEncoded) }
)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ fun String.Companion.decodeFromAsn1ContentBytes(bytes: ByteString) = bytes.decod
private fun Source.readTlv(): TLV.Shallow = runRethrowing {
if (exhausted()) throw IllegalArgumentException("Can't read TLV, input empty")

val tag = decodeTag()
val tag = decodeTag(collect = true)
val length = decodeLength()
require(length < 1024 * 1024) { "Heap space" }
val value = Buffer().also {
Expand All @@ -381,9 +381,18 @@ private fun Source.readTlv(): TLV.Shallow = runRethrowing {
}
}

return TLV.Shallow(Asn1Element.Tag(tag.second), value)
return TLV.Shallow(
Asn1Element.Tag(
tagValue = tag.first,
encodedTagLength = length.toInt(),
encodedTag = tag.third!!
), value
)
}

/**
* [collect] tells this method to collect read bytes into a [ByteArray]
*/
@Throws(IllegalArgumentException::class)
private fun Source.decodeLength(): Long =
readByte().let { firstByte ->
Expand All @@ -392,9 +401,11 @@ private fun Source.decodeLength(): Long =
} else { // its BER long form!
val numOctets = (firstByte byteMask UVARINT_MASK).toInt()
require(numOctets <= Long.SIZE_BYTES) { "TLV Length too long ($numOctets bytes). Invalid data?" }

(0 until numOctets).fold(0L) { acc, index ->
require(!exhausted()) { "Can't decode length" }
acc + (readUByte().toLong() shl Byte.SIZE_BITS * (numOctets - index - 1))
acc + (readUByte()
.toLong() shl Byte.SIZE_BITS * (numOctets - index - 1))
}.also {
require(it >= 0) { "TLV length overflow: $it. Invalid data?" }
}
Expand All @@ -406,17 +417,20 @@ private fun Byte.isBerShortForm() = this byteMask UVARINT_SINGLEBYTE_MAXVALUE ==
internal infix fun Byte.byteMask(mask: Int) = (this and mask.toUInt().toByte()).toUByte()
internal infix fun Byte.byteMask(mask: Byte) = (this and mask).toUByte()

internal fun Source.decodeTag(): Pair<ULong, ByteArray> =
/**
* [collect] tells this method to collect read bytes into a [ByteArray]
*/
internal fun Source.decodeTag(collect: Boolean = false): Triple<ULong, Int, ByteArray?> =
readByte().let { firstByte ->
(firstByte byteMask 0x1F).let { tagNumber ->
if (tagNumber <= 30U) {
tagNumber.toULong() to byteArrayOf(firstByte)
Triple(tagNumber.toULong(), 1, if (collect) byteArrayOf(firstByte) else null)
} else {
decodeAsn1VarULongShallow().let { (l, b) ->
l to Buffer().apply {
Triple(l, 1 + b.size.toInt(), if (collect) Buffer().apply {
writeByte(firstByte)
b.transferTo(this)
}.readByteArray()
}.readByteArray() else null)
}
}
}
Expand Down

0 comments on commit 4784b65

Please sign in to comment.