diff --git a/crypto/crypto.go b/crypto/crypto.go index 5278f70acf..1b873bf976 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -280,7 +280,27 @@ func PubkeyToAddress(p ecdsa.PublicKey, nodeLocation common.Location) common.Add } func PubkeyBytesToAddress(pub []byte, nodeLocation common.Location) common.Address { - return common.BytesToAddress(Keccak256(pub[1:])[12:], nodeLocation) + var publicKey *ecdsa.PublicKey + var err error + // Check if the public key is compressed (33 bytes) or uncompressed (65 bytes) + switch len(pub) { + case 33: + publicKey, err = DecompressPubkey(pub) + if err != nil { + return common.Address{} + } + case 65: + publicKey, err = UnmarshalPubkey(pub) + if err != nil { + return common.Address{} + } + default: + return common.Address{} + } + + pubBytes := FromECDSAPub(publicKey) + + return common.BytesToAddress(Keccak256(pubBytes[1:])[12:], nodeLocation) } func zeroBytes(bytes []byte) { diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go index 0db67bc09c..2466b23615 100644 --- a/crypto/crypto_test.go +++ b/crypto/crypto_test.go @@ -318,3 +318,67 @@ func TestGenerateAddress(t *testing.T) { } } } + +func TestPubkeyBytesToAddress(t *testing.T) { + tests := []struct { + name string + pubKeyHex string + expectedAddr string + shouldFail bool + }{ + { + name: "Compressed public key", + pubKeyHex: "0250495cb2f9535c684ebe4687b501c0d41a623d68c118b8dcecd393370f1d90e6", + expectedAddr: "00a65C75c4BE400C38B66bAc1103931Ab55b1597", + shouldFail: false, + }, + { + name: "Uncompressed public key", + pubKeyHex: "0450495cb2f9535c684ebe4687b501c0d41a623d68c118b8dcecd393370f1d90e65c4c6c44cd3fe809b41dfac9060ad84cb57e2d575fad24d25a7efa3396e73c10", + expectedAddr: "00a65C75c4BE400C38B66bAc1103931Ab55b1597", + shouldFail: false, + }, + { + name: "Invalid public key (32 bytes)", + pubKeyHex: "50495cb2f9535c684ebe4687b501c0d41a623d68c118b8dcecd393370f1d90e6", + expectedAddr: "", + shouldFail: true, + }, + { + name: "Invalid public key (64 bytes)", + pubKeyHex: "50495cb2f9535c684ebe4687b501c0d41a623d68c118b8dcecd393370f1d90e650495cb2f9535c684ebe4687b501c0d41a623d68c118b8dcecd393370f1d90e6", + expectedAddr: "", + shouldFail: true, + }, + { + name: "Empty public key", + pubKeyHex: "", + expectedAddr: "", + shouldFail: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pubKeyBytes, err := hex.DecodeString(tt.pubKeyHex) + if err != nil { + t.Fatalf("Failed to decode public key: %v", err) + } + + nodeLocation := common.Location{0, 0} + + addr := PubkeyBytesToAddress(pubKeyBytes, nodeLocation) + + if tt.shouldFail { + if addr != (common.Address{}) { + t.Errorf("Expected failure, but got address: %s", addr.Hex()) + } + } else { + expectedAddr := common.HexToAddress(tt.expectedAddr, nodeLocation) + if !addr.Equal(expectedAddr) { + t.Errorf("Address mismatch: got %s, want %s", addr.Hex(), expectedAddr.Hex()) + } + } + }) + } +}