Skip to content

Commit

Permalink
address PR #23 comments
Browse files Browse the repository at this point in the history
  • Loading branch information
JesusMcCloud committed Oct 30, 2023
1 parent 776d747 commit a34db53
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 205 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package at.asitplus.crypto.datatypes.cose

import at.asitplus.crypto.datatypes.CryptoPublicKey
import at.asitplus.crypto.datatypes.EcCurve
import at.asitplus.crypto.datatypes.asn1.decodeFromDer
import at.asitplus.crypto.datatypes.asn1.encodeToByteArray
import at.asitplus.crypto.datatypes.cose.io.cborSerializer
import io.github.aakira.napier.Napier
Expand All @@ -18,148 +17,9 @@ import kotlinx.serialization.encoding.Decoder
import kotlinx.serialization.encoding.Encoder
import kotlinx.serialization.encoding.decodeStructure

// Class needed to handle overlapping serial labels in COSE standard
sealed class CoseKeyParams() {

abstract fun toCryptoPublicKey(): CryptoPublicKey?

// Implements elliptic curve public key parameters in case of y being a Bytearray
data class EcYByteArrayParams(
val curve: CoseEllipticCurve? = null,
val x: ByteArray? = null,
val y: ByteArray? = null,
val d: ByteArray? = null
) : CoseKeyParams() {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false

other as EcYByteArrayParams

if (curve != other.curve) return false
if (x != null) {
if (other.x == null) return false
if (!x.contentEquals(other.x)) return false
} else if (other.x != null) return false
if (y != null) {
if (other.y == null) return false
if (!y.contentEquals(other.y)) return false
} else if (other.y != null) return false
if (d != null) {
if (other.d == null) return false
if (!d.contentEquals(other.d)) return false
} else if (other.d != null) return false

return true
}

override fun hashCode(): Int {
var result = curve?.hashCode() ?: 0
result = 31 * result + (x?.contentHashCode() ?: 0)
result = 31 * result + (y?.contentHashCode() ?: 0)
result = 31 * result + (d?.contentHashCode() ?: 0)
return result
}

override fun toCryptoPublicKey(): CryptoPublicKey? {
return let {
CryptoPublicKey.Ec.fromCoordinates(
curve = curve?.toJwkCurve() ?: return null,
x = x ?: return null,
y = y ?: return null
)
}
}
}

// Implements elliptic curve public key parameters in case of y being a bool value
data class EcYBoolParams(
val curve: CoseEllipticCurve? = null,
val x: ByteArray? = null,
val y: Boolean? = null,
val d: ByteArray? = null
) : CoseKeyParams() {

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false

other as EcYBoolParams

if (curve != other.curve) return false
if (x != null) {
if (other.x == null) return false
if (!x.contentEquals(other.x)) return false
} else if (other.x != null) return false
if (y != other.y) return false
if (d != null) {
if (other.d == null) return false
if (!d.contentEquals(other.d)) return false
} else if (other.d != null) return false

return true
}

override fun hashCode(): Int {
var result = curve?.hashCode() ?: 0
result = 31 * result + (x?.contentHashCode() ?: 0)
result = 31 * result + (y?.hashCode() ?: 0)
result = 31 * result + (d?.contentHashCode() ?: 0)
return result
}

override fun toCryptoPublicKey(): CryptoPublicKey? = TODO()

// TODO conversion to cryptoPublicKey (needs de-/compression of Y coordinate)
}

// Implements RSA public key parameters
data class RsaParams(
val n: ByteArray? = null,
val e: ByteArray? = null,
val d: ByteArray? = null
) : CoseKeyParams() {
override fun equals(other: Any?): Boolean {
if (this === other) return true
if (other == null || this::class != other::class) return false

other as RsaParams

if (n != null) {
if (other.n == null) return false
if (!n.contentEquals(other.n)) return false
} else if (other.n != null) return false
if (e != null) {
if (other.e == null) return false
if (!e.contentEquals(other.e)) return false
} else if (other.e != null) return false
if (d != null) {
if (other.d == null) return false
if (!d.contentEquals(other.d)) return false
} else if (other.d != null) return false

return true
}

override fun hashCode(): Int {
var result = n?.contentHashCode() ?: 0
result = 31 * result + (e?.contentHashCode() ?: 0)
result = 31 * result + (d?.contentHashCode() ?: 0)
return result
}

override fun toCryptoPublicKey(): CryptoPublicKey? {
return let {
CryptoPublicKey.Rsa(
n = n ?: return null,
e = e?.let { bytes -> Int.decodeFromDer(bytes) } ?: return null
)
}
}
}
}


/**
* COSE public key as per [RFC 8152](https://www.rfc-editor.org/rfc/rfc8152.html#page-33). Since this is used as part of a COSE-specific DTO, every property is nullable
*/
@OptIn(ExperimentalSerializationApi::class)
@Serializable(with = CoseKeySerializer::class)
data class CoseKey(
Expand Down Expand Up @@ -214,55 +74,42 @@ data class CoseKey(
return result
}

/**
* @return a [CryptoPublicKey] equivalent if conversion is possibl (i.e. if all key params are set)<br> or `null` in case the required key params are not contained in this COSE key (i.e. if only a `kid` is used))
*/
fun toCryptoPublicKey() = keyParams.toCryptoPublicKey()

fun serialize() = cborSerializer.encodeToByteArray(this)

companion object {

fun deserialize(it: ByteArray) = kotlin.runCatching {
cborSerializer.decodeFromByteArray<CoseHeader>(it)
cborSerializer.decodeFromByteArray<CoseKey>(it)
}.getOrElse {
Napier.w("deserialize failed", it)
null
}

fun fromAnsiX963Bytes(input: ByteArray, algorithm: CoseAlgorithm? = null): CoseKey? =
CryptoPublicKey.Ec.fromAnsiX963Bytes(input).toCoseKey(algorithm)

fun fromCoordinates(
curve: CoseEllipticCurve,
x: ByteArray,
y: ByteArray,
algorithm: CoseAlgorithm? = null
): CoseKey? = CryptoPublicKey.Ec.fromCoordinates(curve.toJwkCurve(), x, y).toCoseKey(algorithm)

fun fromPKCS1encoded(input: ByteArray, algorithm: CoseAlgorithm? = null): CoseKey? =
CryptoPublicKey.Rsa.fromPKCS1encoded(input).toCoseKey(algorithm)

@Deprecated("Use function [fromAnsiX963Bytes] above instead!")
@Deprecated("Use [CryptoPublicKey.fromAnsiX963Bytes] and [toCoseKey] instead!")
fun fromAnsiX963Bytes(type: CoseKeyType, curve: CoseEllipticCurve, it: ByteArray) =
if (type == CoseKeyType.EC2 && curve == CoseEllipticCurve.P256) {
val pubKey = CryptoPublicKey.Ec.fromAnsiX963Bytes(it)
pubKey.toCoseKey()
} else null

@Deprecated("Use function [fromCoordinates] above instead")
@Deprecated("Use function [CryptoPublicKey.fromCoordinates] and [toCoseKey] above instead")
fun fromCoordinates(
type: CoseKeyType,
curve: CoseEllipticCurve,
x: ByteArray,
y: ByteArray
): CoseKey? {
return fromCoordinates(curve, x, y)
}
): CoseKey? = CryptoPublicKey.Ec.fromCoordinates(curve.toJwkCurve(), x, y).toCoseKey()

}
}

/**
* Converts CryptoPublicKey into CoseKey
* If algorithm is not set then key can be used for any algorithm with same kty (RFC 8152), returns null for invalid kty/algorithm pairs
* If [algorithm] is not set then key can be used for any algorithm with same kty (RFC 8152), returns null for invalid kty/algorithm pairs
*/
fun CryptoPublicKey.toCoseKey(algorithm: CoseAlgorithm? = null): CoseKey? =
when (this) {
Expand Down Expand Up @@ -374,7 +221,7 @@ object CoseKeySerializer : KSerializer<CoseKey> {
},
when (val params = src.keyParams) {
is CoseKeyParams.RsaParams -> params.d
is CoseKeyParams.EcYByteArrayParams -> params.d; else -> TODO()
is CoseKeyParams.EcYByteArrayParams -> params.d
},

)
Expand Down Expand Up @@ -470,7 +317,7 @@ object CoseKeySerializer : KSerializer<CoseKey> {
get() = CoseKeySerialContainer.serializer().descriptor

override fun deserialize(decoder: Decoder): CoseKey {
val labels = mapOf<String,Long>(
val labels = mapOf<String, Long>(
"kty" to 1,
"kid" to 2,
"alg" to 3,
Expand All @@ -494,50 +341,78 @@ object CoseKeySerializer : KSerializer<CoseKey> {
var d: ByteArray? = null

decoder.decodeStructure(descriptor) {
val e=this
while (true) {
val index= decodeElementIndex(descriptor)
if(index==-1) break
val label = descriptor.getElementAnnotations(index).filterIsInstance<SerialLabel>().first().label
if (label == labels["kty"]) type =
decodeSerializableElement(CoseKeyTypeSerializer.descriptor, index, CoseKeyTypeSerializer)
else if (label == labels["kid"]) keyId =
decodeNullableSerializableElement(ByteArraySerializer().descriptor, index, ByteArraySerializer())
else if (label == labels["alg"]) alg =
decodeNullableSerializableElement(CoseAlgorithmSerializer.descriptor, index, CoseAlgorithmSerializer)
else if (label == labels["key_ops"]) keyOps =
decodeNullableSerializableElement(
ArraySerializer(CoseKeyOperationSerializer).descriptor,
index,
ArraySerializer(CoseKeyOperationSerializer)
)
else if (label == labels["n/crv"]) {
when (type) {
CoseKeyType.EC2 -> {
val deser = CoseEllipticCurveSerializer
crv = decodeNullableSerializableElement(deser.descriptor, index, deser)
}

CoseKeyType.RSA -> {
val deser = ByteArraySerializer()
n = decodeNullableSerializableElement(deser.descriptor, index, deser)
val index = decodeElementIndex(descriptor)
if (index == -1) break
val label = descriptor.getElementAnnotations(index).filterIsInstance<SerialLabel>().first().label
when (label) {
labels["kty"] -> type =
decodeSerializableElement(CoseKeyTypeSerializer.descriptor, index, CoseKeyTypeSerializer)

labels["kid"] -> keyId =
decodeNullableSerializableElement(
ByteArraySerializer().descriptor,
index,
ByteArraySerializer()
)

labels["alg"] -> alg =
decodeNullableSerializableElement(
CoseAlgorithmSerializer.descriptor,
index,
CoseAlgorithmSerializer
)

labels["key_ops"] -> keyOps =
decodeNullableSerializableElement(
ArraySerializer(CoseKeyOperationSerializer).descriptor,
index,
ArraySerializer(CoseKeyOperationSerializer)
)

labels["n/crv"] -> {
when (type) {
CoseKeyType.EC2 -> {
val deser = CoseEllipticCurveSerializer
crv = decodeNullableSerializableElement(deser.descriptor, index, deser)
}

CoseKeyType.RSA -> {
val deser = ByteArraySerializer()
n = decodeNullableSerializableElement(deser.descriptor, index, deser)
}

CoseKeyType.SYMMETRIC -> TODO()
}

CoseKeyType.SYMMETRIC -> TODO()
null -> TODO()
}

} else if (label == labels["x/e"]) xOrE =
decodeNullableSerializableElement(ByteArraySerializer().descriptor, index, ByteArraySerializer())
else if (label == labels["y"]) y =
decodeNullableSerializableElement(ByteArraySerializer().descriptor, index, ByteArraySerializer())
else if (label == labels["d"]) d =
decodeNullableSerializableElement(ByteArraySerializer().descriptor, index, ByteArraySerializer())
else {
break
labels["x/e"] -> xOrE =
decodeNullableSerializableElement(
ByteArraySerializer().descriptor,
index,
ByteArraySerializer()
)

labels["y"] -> y =
decodeNullableSerializableElement(
ByteArraySerializer().descriptor,
index,
ByteArraySerializer()
)

labels["d"] -> d =
decodeNullableSerializableElement(
ByteArraySerializer().descriptor,
index,
ByteArraySerializer()
)

else -> {
break
}
}
}

}
return when (type) {
CoseKeyType.EC2 -> {
Expand Down
Loading

0 comments on commit a34db53

Please sign in to comment.