Skip to content

Commit

Permalink
Merge pull request #8 from rarimo/fix/convert-bits-length
Browse files Browse the repository at this point in the history
Fix ExtractFirstNBits function implementation. Fix passports with different hash algos for dg and signed attr
  • Loading branch information
artemskriabin authored Jan 2, 2025
2 parents 3957fc8 + f7d6566 commit c52ff85
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 51 deletions.
101 changes: 60 additions & 41 deletions internal/service/api/handlers/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ func Register(w http.ResponseWriter, r *http.Request) {
}

algorithmPair := types.AlgorithmPair{
HashAlgorithm: types.HashAlgorithmFromString(req.Data.Attributes.DocumentSod.HashAlgorithm),
DgHashAlgorithm: types.HashAlgorithmFromString(req.Data.Attributes.DocumentSod.HashAlgorithm),
SignedAttrHashAlg: types.HashAlgorithmFromString(req.Data.Attributes.DocumentSod.HashAlgorithm),
SignatureAlgorithm: types.SignatureAlgorithmFromString(req.Data.Attributes.DocumentSod.SignatureAlgorithm),
}

documentSOD := data.DocumentSOD{
HashAlgorigthm: algorithmPair.HashAlgorithm,
HashAlgorigthm: algorithmPair.DgHashAlgorithm,
SignatureAlgorithm: algorithmPair.SignatureAlgorithm,
SignedAttributes: *utils.TruncateHexPrefix(&req.Data.Attributes.DocumentSod.SignedAttributes),
EncapsulatedContent: *utils.TruncateHexPrefix(&req.Data.Attributes.DocumentSod.EncapsulatedContent),
Expand Down Expand Up @@ -112,7 +113,7 @@ func Register(w http.ResponseWriter, r *http.Request) {

if err := verifier.VerifyGroth16(
req.Data.Attributes.ZkProof,
verifierCfg.VerificationKeys[algorithmPair.HashAlgorithm],
verifierCfg.VerificationKeys[algorithmPair.DgHashAlgorithm],
); err != nil {
log.WithError(err).Error("failed to verify zk proof")
// TODO: Add documentSOD.ErrorKind and documentSOD.Error initialization for all errors in this handler
Expand Down Expand Up @@ -199,7 +200,7 @@ func Register(w http.ResponseWriter, r *http.Request) {
return
}

err = verifySod(signedAttributes, encapsulatedContent, slaveSignature, cert, algorithmPair, verifierCfg)
saHashBytes, err := verifySod(signedAttributes, encapsulatedContent, slaveSignature, cert, &algorithmPair, verifierCfg)
if err != nil {
sodError := new(types.SodError)
if !errors2.As(err, &sodError) {
Expand All @@ -224,10 +225,6 @@ func Register(w http.ResponseWriter, r *http.Request) {
return
}

saHash := types.GeneralHash(algorithmPair.HashAlgorithm)
saHash.Write(signedAttributes)
saHashBytes := saHash.Sum(nil)

truncatedSignedAttributes, err := utils.ExtractFirstNBits(saHashBytes, 252)
if err != nil {
log.WithError(err).Error("failed to extract bits from signed attributes")
Expand Down Expand Up @@ -265,7 +262,7 @@ func Register(w http.ResponseWriter, r *http.Request) {
return
}

extractedDg15Hash := types.GeneralHash(algorithmPair.HashAlgorithm)
extractedDg15Hash := types.GeneralHash(algorithmPair.DgHashAlgorithm)
extractedDg15Hash.Write(extractedDg15)

if !bytes.Equal(dg15Hash, extractedDg15Hash.Sum(nil)) {
Expand All @@ -285,9 +282,9 @@ func Register(w http.ResponseWriter, r *http.Request) {
}

addressesCfg := api.AddressesConfig(r)
verifierContract, ok := addressesCfg.Verifiers[algorithmPair.HashAlgorithm]
verifierContract, ok := addressesCfg.Verifiers[algorithmPair.DgHashAlgorithm]
if !ok {
log.Errorf("No verifier contract found for hash algorithm %s", algorithmPair.HashAlgorithm)
log.Errorf("No verifier contract found for hash algorithm %s", algorithmPair.DgHashAlgorithm)
jsonError = append(jsonError, problems.InternalError())
return
}
Expand Down Expand Up @@ -330,15 +327,20 @@ func Register(w http.ResponseWriter, r *http.Request) {
}

func verifySod(
signedAttributes []byte,
encapsulatedContent []byte,
signature []byte,
cert *x509.Certificate,
algorithmPair types.AlgorithmPair,
cfg *config.VerifierConfig,
) error {
if err := validateSignedAttributes(signedAttributes, encapsulatedContent, algorithmPair.HashAlgorithm); err != nil {
return &types.SodError{
signedAttributes []byte,
encapsulatedContent []byte,
signature []byte,
cert *x509.Certificate,
algorithmPair *types.AlgorithmPair,
cfg *config.VerifierConfig,
) ([]byte, error) {
if algorithmPair == nil {
return nil, errors.New("algorithm pair is nil")
}

err := validateSignedAttributes(signedAttributes, encapsulatedContent, &algorithmPair.SignedAttrHashAlg)
if err != nil {
return nil, &types.SodError{
VerboseError: err,
Details: &types.SodErrorDetails{
Kind: types.SAValidateErr,
Expand All @@ -347,10 +349,11 @@ func verifySod(
}
}

if err := verifySignature(signature, cert, signedAttributes, algorithmPair); err != nil {
signedAttrHash, err := verifySignature(signature, cert, signedAttributes, *algorithmPair)
if err != nil {
unwrappedErr := errors2.Unwrap(err)
if errors2.Is(unwrappedErr, types.ErrInvalidPublicKey{}) {
return &types.SodError{
return nil, &types.SodError{
VerboseError: err,
Details: &types.SodErrorDetails{
Kind: types.PEMFilePubKeyErr,
Expand All @@ -359,7 +362,7 @@ func verifySod(
}
}

return &types.SodError{
return nil, &types.SodError{
VerboseError: err,
Details: &types.SodErrorDetails{
Kind: types.SigVerifyErr,
Expand All @@ -368,8 +371,9 @@ func verifySod(
}
}

if err := validateCert(cert, cfg.MasterCerts, cfg.DisableTimeChecks, cfg.DisableNameChecks); err != nil {
return &types.SodError{
err = validateCert(cert, cfg.MasterCerts, cfg.DisableTimeChecks, cfg.DisableNameChecks)
if err != nil {
return nil, &types.SodError{
VerboseError: err,
Details: &types.SodErrorDetails{
Kind: types.PEMFileValidateErr,
Expand All @@ -378,7 +382,7 @@ func verifySod(
}
}

return nil
return signedAttrHash, nil
}

func parseCertificate(pemFile []byte) (*x509.Certificate, error) {
Expand All @@ -396,9 +400,9 @@ func parseCertificate(pemFile []byte) (*x509.Certificate, error) {
}

func validateSignedAttributes(
signedAttributes,
encapsulatedContent []byte,
hashAlgorithm types.HashAlgorithm,
signedAttributes,
encapsulatedContent []byte,
hashAlgorithm *types.HashAlgorithm,
) error {
signedAttributesASN1 := make([]asn1.RawValue, 0)

Expand All @@ -415,14 +419,29 @@ func validateSignedAttributes(
return errors.Wrap(err, "failed to unmarshal ASN1")
}

h := types.GeneralHash(hashAlgorithm)
h.Write(encapsulatedContent)
d := h.Sum(nil)

if len(digestAttr.Digest) == 0 {
return errors.New("signed attributes digest values amount is 0")
}

hashAlgorithmFromDigest := types.HashAlgorithmFromSize(len(digestAttr.Digest[0].Bytes))
if hashAlgorithm == nil {
fmt.Printf("passed hash algorithm is nil, using from signed attr: %s\n", hashAlgorithmFromDigest.String())
hashAlgorithm = &hashAlgorithmFromDigest
}

if hashAlgorithmFromDigest != *hashAlgorithm {
// TODO use log
fmt.Printf("found different hash algorithm in signed attr %s\n", hashAlgorithmFromDigest.String())
if _, ok := types.IsValidHashAlgorithm(hashAlgorithmFromDigest.String()); ok {
fmt.Printf("changing hash algorithm from %s to %s\n", hashAlgorithm.String(), hashAlgorithmFromDigest.String())
*hashAlgorithm = hashAlgorithmFromDigest
}
}

h := types.GeneralHash(*hashAlgorithm)
h.Write(encapsulatedContent)
d := h.Sum(nil)

if !bytes.Equal(digestAttr.Digest[0].Bytes, d) {
return errors.From(
errors.New("digest values are not equal"), logan.F{
Expand All @@ -436,20 +455,20 @@ func validateSignedAttributes(
}

func verifySignature(
signature []byte,
cert *x509.Certificate,
signedAttributes []byte,
algorithmPair types.AlgorithmPair,
) error {
h := types.GeneralHash(algorithmPair.HashAlgorithm)
signature []byte,
cert *x509.Certificate,
signedAttributes []byte,
algorithmPair types.AlgorithmPair,
) ([]byte, error) {
h := types.GeneralHash(algorithmPair.SignedAttrHashAlg)
h.Write(signedAttributes)
d := h.Sum(nil)

if err := types.GeneralVerify(cert.PublicKey, d, signature, algorithmPair); err != nil {
return err
return nil, err
}

return nil
return d, nil
}

func validateCert(cert *x509.Certificate, masterCerts *x509.CertPool, disableTimeChecks, disableNameChecks bool) error {
Expand Down
16 changes: 16 additions & 0 deletions internal/types/enums.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ var hashAlgorithmMap = map[string]HashAlgorithm{
"SHA512": SHA512,
}

var hashAlgorithmSizeMap = map[int]HashAlgorithm{
20: SHA1,
28: SHA224,
32: SHA256,
48: SHA384,
64: SHA512,
}

func (h HashAlgorithm) String() string {
switch h {
case SHA1:
Expand All @@ -45,6 +53,14 @@ func HashAlgorithmFromString(alg string) HashAlgorithm {
return h
}

func HashAlgorithmFromSize(size int) HashAlgorithm {
h, ok := hashAlgorithmSizeMap[size]
if !ok {
return HashAlgorithm(0)
}
return h
}

func IsValidHashAlgorithm(alg string) (HashAlgorithm, bool) {
h, ok := hashAlgorithmMap[alg]
return h, ok
Expand Down
7 changes: 4 additions & 3 deletions internal/types/signature_algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ func (e ErrInvalidPublicKey) Error() string {

// AlgorithmPair defines a hash and signature algorithm combination.
type AlgorithmPair struct {
HashAlgorithm
DgHashAlgorithm HashAlgorithm
SignedAttrHashAlg HashAlgorithm
SignatureAlgorithm
}

Expand All @@ -32,13 +33,13 @@ func GeneralVerify(publicKey interface{}, hash []byte, signature []byte, algo Al
if !ok {
return ErrInvalidPublicKey{Expected: algo.SignatureAlgorithm}
}
return rsa.VerifyPKCS1v15(rsaKey, getCryptoHash(algo.HashAlgorithm), hash, signature)
return rsa.VerifyPKCS1v15(rsaKey, getCryptoHash(algo.SignedAttrHashAlg), hash, signature)
case RSAPSS:
rsaKey, ok := publicKey.(*rsa.PublicKey)
if !ok {
return ErrInvalidPublicKey{Expected: algo.SignatureAlgorithm}
}
return rsa.VerifyPSS(rsaKey, getCryptoHash(algo.HashAlgorithm), hash, signature, nil)
return rsa.VerifyPSS(rsaKey, getCryptoHash(algo.SignedAttrHashAlg), hash, signature, nil)
case ECDSA:
ecdsaKey, ok := publicKey.(*ecdsa.PublicKey)
if !ok {
Expand Down
16 changes: 9 additions & 7 deletions internal/utils/asn1_operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,22 @@ import (
"gitlab.com/distributed_lab/logan/v3/errors"
)

// ExtractFirstNBits extracts the first n bits from data.
// If data contains fewer than n bits, it pads the result with zeros.
func ExtractFirstNBits(data []byte, n uint) ([]byte, error) {
if n == 0 {
return []byte{}, nil
}

numBytes := (n + 7) / 8
result := make([]byte, numBytes)

if uint(len(data))*8 < n {
return nil, fmt.Errorf("not enough bits in data: required %d, available %d", n, len(data)*8)
bytesToCopy := numBytes
if uint(len(data)) < numBytes {
bytesToCopy = uint(len(data))
}

result := make([]byte, numBytes)

copy(result, data[:numBytes])
copy(result, data[:bytesToCopy])

remainingBits := n % 8
if remainingBits != 0 {
Expand Down Expand Up @@ -71,8 +73,8 @@ func TruncateHexPrefix(hexString *string) *string {
}

func BuildSignedData(
contract, verifier *common.Address,
passportHash, dg1Commitment, publicKey [32]byte,
contract, verifier *common.Address,
passportHash, dg1Commitment, publicKey [32]byte,
) ([]byte, error) {
return abiEncodePacked(types.RegistrationSimplePrefix, contract, passportHash[:], dg1Commitment[:], publicKey[:], verifier)
}
Expand Down

0 comments on commit c52ff85

Please sign in to comment.