From b499b37b52f8ff79f90b69c5ac930ad7f80d6906 Mon Sep 17 00:00:00 2001 From: Oliver Gugger Date: Sat, 15 Jun 2024 17:09:00 +0200 Subject: [PATCH] zombierecovery: support MuSig2 --- .golangci.yml | 1 + cmd/chantools/zombierecovery_findmatches.go | 23 +- cmd/chantools/zombierecovery_makeoffer.go | 585 ++++++++++++++---- .../zombierecovery_makeoffer_test.go | 2 +- cmd/chantools/zombierecovery_preparekeys.go | 63 +- cmd/chantools/zombierecovery_signoffer.go | 15 + go.mod | 4 +- go.sum | 4 +- lnd/channel.go | 59 ++ lnd/channel_test.go | 111 ++++ 10 files changed, 722 insertions(+), 145 deletions(-) create mode 100644 lnd/channel_test.go diff --git a/.golangci.yml b/.golangci.yml index 3439c36..51013e7 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -54,6 +54,7 @@ linters: - protogetter - depguard - mnd + - gomoddirectives issues: exclude-rules: diff --git a/cmd/chantools/zombierecovery_findmatches.go b/cmd/chantools/zombierecovery_findmatches.go index 76188cf..4277310 100644 --- a/cmd/chantools/zombierecovery_findmatches.go +++ b/cmd/chantools/zombierecovery_findmatches.go @@ -95,16 +95,19 @@ type nodeInfo struct { } type channel struct { - ChannelID string `json:"short_channel_id"` - ChanPoint string `json:"chan_point"` - Address string `json:"address"` - Capacity int64 `json:"capacity"` - txid string - vout uint32 - ourKeyIndex uint32 - ourKey *btcec.PublicKey - theirKey *btcec.PublicKey - witnessScript []byte + ChannelID string `json:"short_channel_id"` + ChanPoint string `json:"chan_point"` + Address string `json:"address"` + Capacity int64 `json:"capacity"` + MuSig2NonceRandomness string `json:"musig2_nonce_randomness,omitempty"` + MuSig2Nonces string `json:"musig2_nonces,omitempty"` + txid string + vout uint32 + ourKeyIndex uint32 + ourKey *btcec.PublicKey + theirKey *btcec.PublicKey + pkScript []byte + witnessScript []byte } type match struct { diff --git a/cmd/chantools/zombierecovery_makeoffer.go b/cmd/chantools/zombierecovery_makeoffer.go index dc52611..a59cd11 100644 --- a/cmd/chantools/zombierecovery_makeoffer.go +++ b/cmd/chantools/zombierecovery_makeoffer.go @@ -3,6 +3,8 @@ package main import ( "bufio" "bytes" + "crypto/rand" + "crypto/sha256" "encoding/hex" "encoding/json" "errors" @@ -12,13 +14,19 @@ import ( "strings" "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcec/v2/schnorr" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/btcutil/hdkeychain" "github.com/btcsuite/btcd/btcutil/psbt" "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" - "github.com/btcsuite/btcwallet/wallet/txrules" + "github.com/btcsuite/btcwallet/wallet" "github.com/lightninglabs/chantools/lnd" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/spf13/cobra" ) @@ -157,6 +165,20 @@ func (c *zombieRecoveryMakeOfferCommand) Execute(_ *cobra.Command, return errors.New("invalid files, channel address " + "missing") } + + if len(keys2.Channels[idx].MuSig2Nonces) != + len(node1Channel.MuSig2Nonces) { + + return errors.New("invalid files, MuSig2 nonce " + + "lengths don't match") + } + + if len(keys2.Channels[idx].MuSig2NonceRandomness) != + len(node1Channel.MuSig2NonceRandomness) { + + return errors.New("invalid files, MuSig2 randomness " + + "lengths don't match") + } } // If we're only matching, we can stop here. @@ -194,32 +216,42 @@ func (c *zombieRecoveryMakeOfferCommand) Execute(_ *cobra.Command, var ( ourKeys []string ourPayoutAddr string + ourChannels []*channel theirKeys []string theirPayoutAddr string + theirChannels []*channel ) if keys1.Node1.PubKey == pubKeyStr && len(keys1.Node1.MultisigKeys) > 0 { ourKeys = keys1.Node1.MultisigKeys ourPayoutAddr = keys1.Node1.PayoutAddr + ourChannels = keys1.Channels theirKeys = keys2.Node2.MultisigKeys theirPayoutAddr = keys2.Node2.PayoutAddr + theirChannels = keys2.Channels } if keys1.Node2.PubKey == pubKeyStr && len(keys1.Node2.MultisigKeys) > 0 { ourKeys = keys1.Node2.MultisigKeys ourPayoutAddr = keys1.Node2.PayoutAddr + ourChannels = keys1.Channels theirKeys = keys2.Node1.MultisigKeys theirPayoutAddr = keys2.Node1.PayoutAddr + theirChannels = keys2.Channels } if keys2.Node1.PubKey == pubKeyStr && len(keys2.Node1.MultisigKeys) > 0 { ourKeys = keys2.Node1.MultisigKeys ourPayoutAddr = keys2.Node1.PayoutAddr + ourChannels = keys2.Channels theirKeys = keys1.Node2.MultisigKeys theirPayoutAddr = keys1.Node2.PayoutAddr + theirChannels = keys1.Channels } if keys2.Node2.PubKey == pubKeyStr && len(keys2.Node2.MultisigKeys) > 0 { ourKeys = keys2.Node2.MultisigKeys ourPayoutAddr = keys2.Node2.PayoutAddr + ourChannels = keys2.Channels theirKeys = keys1.Node1.MultisigKeys theirPayoutAddr = keys1.Node1.PayoutAddr + theirChannels = keys1.Channels } if len(ourKeys) == 0 || len(theirKeys) == 0 { return errors.New("couldn't find necessary keys") @@ -243,14 +275,23 @@ func (c *zombieRecoveryMakeOfferCommand) Execute(_ *cobra.Command, return err } + // Let's prepare the PSBT. + packet, err := psbt.NewFromUnsignedTx(wire.NewMsgTx(2)) + if err != nil { + return fmt.Errorf("error creating PSBT from TX: %w", err) + } + // Let's now sum up the tally of how much of the rescued funds should // go to which party. var ( - inputs = make([]*wire.TxIn, 0, len(keys1.Channels)) - ourSum int64 - theirSum int64 + ourSum int64 + theirSum int64 + estimator input.TxWeightEstimator + signDescs = make( + []*input.SignDescriptor, 0, len(keys1.Channels), + ) ) - for idx, channel := range keys1.Channels { + for idx, channel := range ourChannels { op, err := lnd.ParseOutpoint(channel.ChanPoint) if err != nil { return fmt.Errorf("error parsing channel out point: %w", @@ -269,30 +310,155 @@ func (c *zombieRecoveryMakeOfferCommand) Execute(_ *cobra.Command, ourSum += ourPart theirSum += theirPart - inputs = append(inputs, &wire.TxIn{ + txIn := &wire.TxIn{ PreviousOutPoint: *op, - // It's not actually an old sig script but a witness - // script but we'll move that to the correct place once - // we create the PSBT. - SignatureScript: channel.witnessScript, - }) + } + pIn := psbt.PInput{ + WitnessScript: channel.witnessScript, + WitnessUtxo: &wire.TxOut{ + PkScript: channel.pkScript, + Value: channel.Capacity, + }, + // We'll be signing with our key, so we can just add the + // other party's pubkey as additional info, so it's easy + // for them to sign as well. + Unknowns: []*psbt.Unknown{{ + Key: PsbtKeyTypeOutputMissingSigPubkey, + Value: channel.theirKey.SerializeCompressed(), + }}, + } + + channelAddr, err := lnd.ParseAddress( + channel.Address, chainParams, + ) + if err != nil { + return fmt.Errorf("error parsing channel address: %w", + err) + } + + prevOutFetcher := txscript.NewCannedPrevOutputFetcher( + pIn.WitnessUtxo.PkScript, pIn.WitnessUtxo.Value, + ) + signDesc := &input.SignDescriptor{ + KeyDesc: keychain.KeyDescriptor{ + PubKey: channel.ourKey, + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamilyMultiSig, + Index: channel.ourKeyIndex, + }, + }, + WitnessScript: channel.witnessScript, + Output: pIn.WitnessUtxo, + InputIndex: idx, + PrevOutputFetcher: prevOutFetcher, + } + + switch a := channelAddr.(type) { + case *btcutil.AddressWitnessScriptHash: + estimator.AddWitnessInput(input.MultiSigWitnessSize) + pIn.SighashType = txscript.SigHashAll + signDesc.HashType = txscript.SigHashAll + signDesc.SignMethod = input.WitnessV0SignMethod + + case *btcutil.AddressTaproot: + estimator.AddTaprootKeySpendInput( + txscript.SigHashDefault, + ) + pIn.SighashType = txscript.SigHashDefault + signDesc.HashType = txscript.SigHashDefault + signDesc.SignMethod = input.TaprootKeySpendSignMethod + + err := addMuSig2Data( + extendedKey, &pIn, channel, theirChannels[idx], + op, a.WitnessProgram(), + ) + if err != nil { + return fmt.Errorf("error adding MuSig2 data: "+ + "%w", err) + } + + default: + return errors.New("unsupported address type for " + + "channel address") + } + + packet.UnsignedTx.TxIn = append(packet.UnsignedTx.TxIn, txIn) + packet.Inputs = append(packet.Inputs, pIn) + signDescs = append(signDescs, signDesc) } - // Let's create a fee estimator now to give an overview over the - // deducted fees. - estimator := input.TxWeightEstimator{} + // Don't create dust. + dustLimit := int64(lnwallet.DustLimitForSize(input.P2WSHSize)) + if ourSum < dustLimit { + ourSum = 0 + } + if theirSum < dustLimit { + theirSum = 0 + } // Only add output for us if we should receive something. + var ourOutput, theirOutput *wire.TxOut if ourSum > 0 { - estimator.AddP2WKHOutput() + err = lnd.CheckAddress( + ourPayoutAddr, chainParams, false, "our payout", + lnd.AddrTypeP2WKH, lnd.AddrTypeP2TR, + ) + if err != nil { + return fmt.Errorf("error verifying our payout "+ + "address: %w", err) + } + + pkScript, err := lnd.PrepareWalletAddress( + ourPayoutAddr, chainParams, &estimator, nil, + "our payout", + ) + if err != nil { + return fmt.Errorf("error preparing our payout "+ + "address: %w", err) + } + + ourOutput = &wire.TxOut{ + PkScript: pkScript, + Value: ourSum, + } + packet.UnsignedTx.TxOut = append( + packet.UnsignedTx.TxOut, ourOutput, + ) + packet.Outputs = append(packet.Outputs, psbt.POutput{}) } + if theirSum > 0 { - estimator.AddP2WKHOutput() - } - for range inputs { - estimator.AddWitnessInput(input.MultiSigWitnessSize) + err = lnd.CheckAddress( + theirPayoutAddr, chainParams, false, "their payout", + lnd.AddrTypeP2WKH, lnd.AddrTypeP2TR, + ) + if err != nil { + return fmt.Errorf("error verifying their payout "+ + "address: %w", err) + } + + pkScript, err := lnd.PrepareWalletAddress( + theirPayoutAddr, chainParams, &estimator, nil, + "their payout", + ) + if err != nil { + return fmt.Errorf("error preparing their payout "+ + "address: %w", err) + } + + theirOutput = &wire.TxOut{ + PkScript: pkScript, + Value: theirSum, + } + packet.UnsignedTx.TxOut = append( + packet.UnsignedTx.TxOut, theirOutput, + ) + packet.Outputs = append(packet.Outputs, psbt.POutput{}) } - feeRateKWeight := chainfee.SatPerKVByte(1000 * c.FeeRate).FeePerKWeight() + + feeRateKWeight := chainfee.SatPerKVByte( + 1000 * c.FeeRate, + ).FeePerKWeight() totalFee := int64(feeRateKWeight.FeeForWeight(estimator.Weight())) fmt.Printf("Current tally (before fees):\n\t"+ @@ -307,125 +473,78 @@ func (c *zombieRecoveryMakeOfferCommand) Execute(_ *cobra.Command, switch { case ourSum-halfFee > 0 && theirSum-halfFee > 0: ourSum -= halfFee + ourOutput.Value -= halfFee theirSum -= halfFee + theirOutput.Value -= halfFee case ourSum-totalFee > 0: ourSum -= totalFee + ourOutput.Value -= totalFee case theirSum-totalFee > 0: theirSum -= totalFee + theirOutput.Value -= totalFee default: return errors.New("error distributing fees, unhandled case") } - // Our output. - pkScript, err := lnd.GetP2WPKHScript(ourPayoutAddr, chainParams) - if err != nil { - return fmt.Errorf("error parsing our payout address: %w", err) - } - ourTxOut := &wire.TxOut{ - PkScript: pkScript, - Value: ourSum, - } - - // Their output - pkScript, err = lnd.GetP2WPKHScript(theirPayoutAddr, chainParams) - if err != nil { - return fmt.Errorf("error parsing their payout address: %w", err) - } - theirTxOut := &wire.TxOut{ - PkScript: pkScript, - Value: theirSum, - } - - // Don't create dust. - if txrules.IsDustOutput(ourTxOut, txrules.DefaultRelayFeePerKb) { - ourSum = 0 - } - if txrules.IsDustOutput(theirTxOut, txrules.DefaultRelayFeePerKb) { - theirSum = 0 - } - fmt.Printf("Current tally (after fees):\n\t"+ "To our address (%s): %d sats\n\t"+ "To their address (%s): %d sats\n", ourPayoutAddr, ourSum, theirPayoutAddr, theirSum) - // And now create the PSBT. - tx := wire.NewMsgTx(2) - if ourSum > 0 { - tx.TxOut = append(tx.TxOut, ourTxOut) - } - if theirSum > 0 { - tx.TxOut = append(tx.TxOut, theirTxOut) - } - for _, txIn := range inputs { - tx.TxIn = append(tx.TxIn, &wire.TxIn{ - PreviousOutPoint: txIn.PreviousOutPoint, - }) - } - packet, err := psbt.NewFromUnsignedTx(tx) - if err != nil { - return fmt.Errorf("error creating PSBT from TX: %w", err) - } - - // First we add the necessary information to the psbt package so that - // we can sign the transaction with SIGHASH_ALL. - for idx, txIn := range inputs { - channel := keys1.Channels[idx] - - // We've mis-used this field to transport the witness script, - // let's now copy it to the correct place. - packet.Inputs[idx].WitnessScript = txIn.SignatureScript - - // Let's prepare the witness UTXO. - pkScript, err := input.WitnessScriptHash(channel.witnessScript) - if err != nil { - return err - } - packet.Inputs[idx].WitnessUtxo = &wire.TxOut{ - PkScript: pkScript, - Value: channel.Capacity, - } - - // We'll be signing with our key so we can just add the other - // party's pubkey as additional info so it's easy for them to - // sign as well. - packet.Inputs[idx].Unknowns = append( - packet.Inputs[idx].Unknowns, &psbt.Unknown{ - Key: PsbtKeyTypeOutputMissingSigPubkey, - Value: channel.theirKey.SerializeCompressed(), - }, - ) - } - // Loop a second time through the inputs and sign each input. We now - // have all the witness/nonwitness data filled in the psbt package. + // have all the witness/non-witness data filled in the psbt package. signer := &lnd.Signer{ ExtendedKey: extendedKey, ChainParams: chainParams, } - for idx, txIn := range inputs { - channel := keys1.Channels[idx] + for idx := range packet.UnsignedTx.TxIn { + signDesc := signDescs[idx] + + // If we're dealing with a taproot channel, we'll need to + // create a MuSig2 partial signature. + if signDesc.SignMethod == input.TaprootKeySpendSignMethod { + err := muSig2PartialSign( + signer, &signDesc.KeyDesc, packet, idx, + ) + if err != nil { + return fmt.Errorf("error creating MuSig2 "+ + "partial signature: %w", err) + } - keyDesc := keychain.KeyDescriptor{ - PubKey: channel.ourKey, - KeyLocator: keychain.KeyLocator{ - Family: keychain.KeyFamilyMultiSig, - Index: channel.ourKeyIndex, - }, + continue + } + + ourSigRaw, err := signer.SignOutputRaw( + packet.UnsignedTx, signDesc, + ) + if err != nil { + return fmt.Errorf("error signing with our key: %w", err) } - utxo := &wire.TxOut{ - Value: channel.Capacity, + ourSig := append(ourSigRaw.Serialize(), byte(signDesc.HashType)) + + // Great, we were able to create our sig, let's add it to the + // PSBT. + updater, err := psbt.NewUpdater(packet) + if err != nil { + return fmt.Errorf("error creating PSBT updater: %w", + err) } - err = signer.AddPartialSignature( - packet, keyDesc, utxo, txIn.SignatureScript, idx, + status, err := updater.Sign( + idx, ourSig, + signDesc.KeyDesc.PubKey.SerializeCompressed(), nil, + signDesc.WitnessScript, ) if err != nil { - return fmt.Errorf("error signing input %d: %w", idx, + return fmt.Errorf("error adding signature to PSBT: %w", err) } + if status != 0 { + return fmt.Errorf("unexpected status for signature "+ + "update, got %d wanted 0", status) + } } // Looks like we're done! @@ -466,7 +585,7 @@ channelLoop: for _, channel := range channels { for ourKeyIndex, ourKey := range ourPubKeys { for _, theirKey := range theirPubKeys { - match, witnessScript, err := matchScript( + match, witScript, pkScript, err := matchScript( channel.Address, ourKey, theirKey, chainParams, ) @@ -479,7 +598,8 @@ channelLoop: channel.ourKeyIndex = uint32(ourKeyIndex) channel.ourKey = ourKey channel.theirKey = theirKey - channel.witnessScript = witnessScript + channel.witnessScript = witScript + channel.pkScript = pkScript log.Infof("Found keys for channel %s: "+ "our key %x, their key %x", @@ -500,25 +620,47 @@ channelLoop: } func matchScript(address string, key1, key2 *btcec.PublicKey, - params *chaincfg.Params) (bool, []byte, error) { + params *chaincfg.Params) (bool, []byte, []byte, error) { - channelScript, err := lnd.GetP2WSHScript(address, params) + addr, err := lnd.ParseAddress(address, params) if err != nil { - return false, nil, err + return false, nil, nil, fmt.Errorf("error parsing channel "+ + "funding address '%s': %w", address, err) } - witnessScript, err := input.GenMultiSigScript( - key1.SerializeCompressed(), key2.SerializeCompressed(), - ) - if err != nil { - return false, nil, err - } - pkScript, err := input.WitnessScriptHash(witnessScript) + channelScript, err := txscript.PayToAddrScript(addr) if err != nil { - return false, nil, err + return false, nil, nil, err } - return bytes.Equal(channelScript, pkScript), witnessScript, nil + switch addr.(type) { + case *btcutil.AddressWitnessScriptHash: + witnessScript, err := input.GenMultiSigScript( + key1.SerializeCompressed(), key2.SerializeCompressed(), + ) + if err != nil { + return false, nil, nil, err + } + pkScript, err := input.WitnessScriptHash(witnessScript) + if err != nil { + return false, nil, nil, err + } + + return bytes.Equal(channelScript, pkScript), witnessScript, + pkScript, nil + + case *btcutil.AddressTaproot: + pkScript, _, err := input.GenTaprootFundingScript(key1, key2, 0) + if err != nil { + return false, nil, nil, err + } + + return bytes.Equal(channelScript, pkScript), nil, pkScript, nil + + default: + return false, nil, nil, fmt.Errorf("unsupported address type "+ + "for channel funding address: %T", addr) + } } func askAboutChannel(channel *channel, current, total int, ourAddr, @@ -560,3 +702,198 @@ func askAboutChannel(channel *channel, current, total int, ourAddr, return int64(ourPart), theirPart, nil } + +func addMuSig2Data(extendedKey *hdkeychain.ExtendedKey, pIn *psbt.PInput, + ourChannel, theirChannel *channel, channelPoint *wire.OutPoint, + xOnlyPubKey []byte) error { + + aggKey, err := schnorr.ParsePubKey(xOnlyPubKey) + if err != nil { + return fmt.Errorf("error parsing x-only pubkey: %w", err) + } + + ourRandomnessBytes, err := hex.DecodeString( + ourChannel.MuSig2NonceRandomness, + ) + if err != nil { + return fmt.Errorf("error decoding nonce randomness: %w", err) + } + + theirRandomnessBytes, err := hex.DecodeString( + theirChannel.MuSig2NonceRandomness, + ) + if err != nil { + return fmt.Errorf("error decoding nonce randomness: %w", err) + } + + ourNonceBytes, err := hex.DecodeString(ourChannel.MuSig2Nonces) + if err != nil { + return fmt.Errorf("error decoding nonce: %w", err) + } + + theirNonceBytes, err := hex.DecodeString(theirChannel.MuSig2Nonces) + if err != nil { + return fmt.Errorf("error decoding nonce: %w", err) + } + + // We first make sure that the nonces we got are correct, and we created + // them initially, before we create new ones (to avoid security issues + // when signing multiple offers). + var ourRandomness [32]byte + copy(ourRandomness[:], ourRandomnessBytes) + ourNonces, err := lnd.GenerateMuSig2Nonces( + extendedKey, ourRandomness, channelPoint, chainParams, nil, + ) + if err != nil { + return fmt.Errorf("error generating MuSig2 nonces: %w", err) + } + + if !bytes.Equal(ourNonces.PubNonce[:], ourNonceBytes) { + return errors.New("MuSig2 nonces don't match") + } + + // Because at this point we're going to create a partial signature, we + // create a new nonce pair for the session. This is to make sure that + // the nonce is unique for each session, in case we're signing multiple + // offers. + if _, err := rand.Read(ourRandomness[:]); err != nil { + return fmt.Errorf("error generating randomness: %w", err) + } + + ourNonces, err = lnd.GenerateMuSig2Nonces( + extendedKey, ourRandomness, channelPoint, chainParams, nil, + ) + if err != nil { + return fmt.Errorf("error generating MuSig2 nonces: %w", err) + } + + var theirNonces [musig2.PubNonceSize]byte + copy(theirNonces[:], theirNonceBytes) + + pIn.MuSig2PubNonces = append(pIn.MuSig2PubNonces, &psbt.MuSig2PubNonce{ + PubKey: ourChannel.ourKey, + AggregateKey: aggKey, + TapLeafHash: ourRandomness[:], + PubNonce: ourNonces.PubNonce, + }, &psbt.MuSig2PubNonce{ + PubKey: ourChannel.theirKey, + AggregateKey: aggKey, + TapLeafHash: theirRandomnessBytes, + PubNonce: theirNonces, + }) + + return nil +} + +func muSig2PartialSign(signer *lnd.Signer, keyDesc *keychain.KeyDescriptor, + packet *psbt.Packet, idx int) error { + + signingKey, err := signer.FetchPrivateKey(keyDesc) + if err != nil { + return fmt.Errorf("error fetching private key: %w", err) + } + + pIn := packet.Inputs[idx] + if len(pIn.MuSig2PubNonces) != 2 { + return fmt.Errorf("expected 2 MuSig2 nonces in packet input, "+ + "got %d", len(pIn.MuSig2PubNonces)) + } + channelPoint := &packet.UnsignedTx.TxIn[idx].PreviousOutPoint + + var ourNonces, theirNonces *psbt.MuSig2PubNonce + for idx := range pIn.MuSig2PubNonces { + nonce := pIn.MuSig2PubNonces[idx] + if nonce.PubKey.IsEqual(keyDesc.PubKey) { + ourNonces = nonce + } else { + theirNonces = nonce + } + } + if ourNonces == nil || theirNonces == nil { + return errors.New("couldn't find our or their nonce") + } + + keys := []*btcec.PublicKey{ourNonces.PubKey, theirNonces.PubKey} + aggKey, _, _, err := musig2.AggregateKeys( + keys, true, musig2.WithBIP86KeyTweak(), + ) + if err != nil { + return fmt.Errorf("error aggregating keys: %w", err) + } + + ctx, err := musig2.NewContext( + signingKey, true, musig2.WithBip86TweakCtx(), + musig2.WithKnownSigners(keys), + ) + if err != nil { + return fmt.Errorf("error creating MuSig2 context: %w", err) + } + + // Check that the randomness in the tap leaf hash is correct. We'll then + // later check that it also corresponds to the public nonces. + var emptyHash [32]byte + if len(ourNonces.TapLeafHash) != sha256.Size || + bytes.Equal(ourNonces.TapLeafHash, emptyHash[:]) { + + return errors.New("invalid nonce randomness in tap leaf hash") + } + + // Generate the secure nonces from the information we got. We use the + // tap leaf hash to transport our randomness. + var ourRandomness [32]byte + copy(ourRandomness[:], ourNonces.TapLeafHash) + ourSecNonces, err := lnd.GenerateMuSig2Nonces( + signer.ExtendedKey, ourRandomness, channelPoint, chainParams, + signingKey, + ) + if err != nil { + return fmt.Errorf("error generating MuSig2 nonces: %w", err) + } + + // Make sure the re-derived nonces match the public nonces in the PSBT. + if !bytes.Equal(ourSecNonces.PubNonce[:], ourNonces.PubNonce[:]) { + return errors.New("re-derived public nonce doesn't match") + } + + sess, err := ctx.NewSession(musig2.WithPreGeneratedNonce(ourSecNonces)) + if err != nil { + return fmt.Errorf("error creating MuSig2 session: %w", err) + } + + haveAll, err := sess.RegisterPubNonce(theirNonces.PubNonce) + if err != nil { + return fmt.Errorf("error registering remote nonce: %w", err) + } + + if !haveAll { + return errors.New("didn't receive all nonces") + } + + prevOutFetcher := wallet.PsbtPrevOutputFetcher(packet) + sigHashes := txscript.NewTxSigHashes(packet.UnsignedTx, prevOutFetcher) + sigHash, err := txscript.CalcTaprootSignatureHash( + sigHashes, packet.Inputs[idx].SighashType, packet.UnsignedTx, + idx, prevOutFetcher, + ) + if err != nil { + return fmt.Errorf("error calculating signature hash: %w", err) + } + + var sigHashMsg [32]byte + copy(sigHashMsg[:], sigHash) + partialSig, err := sess.Sign(sigHashMsg, musig2.WithSortedKeys()) + if err != nil { + return fmt.Errorf("error signing with MuSig2: %w", err) + } + + psbtPartialSig := &psbt.MuSig2PartialSig{ + PubKey: ourNonces.PubKey, + AggregateKey: aggKey.PreTweakedKey, + PartialSig: *partialSig, + } + packet.Inputs[idx].MuSig2PartialSigs = append( + packet.Inputs[idx].MuSig2PartialSigs, psbtPartialSig, + ) + + return nil +} diff --git a/cmd/chantools/zombierecovery_makeoffer_test.go b/cmd/chantools/zombierecovery_makeoffer_test.go index 6a536d9..1c250ae 100644 --- a/cmd/chantools/zombierecovery_makeoffer_test.go +++ b/cmd/chantools/zombierecovery_makeoffer_test.go @@ -25,7 +25,7 @@ var ( ) func TestMatchScript(t *testing.T) { - ok, _, err := matchScript(addr, key1, key2, &chaincfg.MainNetParams) + ok, _, _, err := matchScript(addr, key1, key2, &chaincfg.MainNetParams) require.NoError(t, err) require.True(t, ok) } diff --git a/cmd/chantools/zombierecovery_preparekeys.go b/cmd/chantools/zombierecovery_preparekeys.go index 3871987..fc5c058 100644 --- a/cmd/chantools/zombierecovery_preparekeys.go +++ b/cmd/chantools/zombierecovery_preparekeys.go @@ -1,7 +1,7 @@ package main import ( - "bytes" + "crypto/rand" "encoding/hex" "encoding/json" "errors" @@ -9,6 +9,8 @@ import ( "os" "time" + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/wire" "github.com/lightninglabs/chantools/lnd" "github.com/spf13/cobra" ) @@ -50,7 +52,7 @@ correct ones for the matched channels.`, cc.cmd.Flags().StringVar( &cc.PayoutAddr, "payout_addr", "", "the address where this "+ "node's rescued funds should be sent to, must be a "+ - "P2WPKH (native SegWit) address", + "P2WPKH (native SegWit) or P2TR (Taproot) address", ) cc.cmd.Flags().Uint32Var( &cc.NumKeys, "num_keys", numMultisigKeys, "the number of "+ @@ -70,9 +72,13 @@ func (c *zombieRecoveryPrepareKeysCommand) Execute(_ *cobra.Command, return fmt.Errorf("error reading root key: %w", err) } - _, err = lnd.GetP2WPKHScript(c.PayoutAddr, chainParams) + err = lnd.CheckAddress( + c.PayoutAddr, chainParams, false, "payout", lnd.AddrTypeP2WKH, + lnd.AddrTypeP2TR, + ) if err != nil { - return errors.New("invalid payout address, must be P2WPKH") + return errors.New("invalid payout address, must be P2WPKH or " + + "P2TR") } matchFileBytes, err := os.ReadFile(c.MatchFile) @@ -81,9 +87,8 @@ func (c *zombieRecoveryPrepareKeysCommand) Execute(_ *cobra.Command, c.MatchFile, err) } - decoder := json.NewDecoder(bytes.NewReader(matchFileBytes)) - match := &match{} - if err := decoder.Decode(&match); err != nil { + var match match + if err := json.Unmarshal(matchFileBytes, &match); err != nil { return fmt.Errorf("error decoding match file %s: %w", c.MatchFile, err) } @@ -115,6 +120,50 @@ func (c *zombieRecoveryPrepareKeysCommand) Execute(_ *cobra.Command, nodeInfo = match.Node2 } + // If there are any Simple Taproot channels, we need to generate some + // randomness and nonces from that randomness for each channel. + for idx := range match.Channels { + matchChannel := match.Channels[idx] + addr, err := lnd.ParseAddress(matchChannel.Address, chainParams) + if err != nil { + return fmt.Errorf("error parsing channel funding "+ + "address '%s': %w", matchChannel.Address, err) + } + + _, isP2TR := addr.(*btcutil.AddressTaproot) + if isP2TR { + chanPoint, err := wire.NewOutPointFromString( + matchChannel.ChanPoint, + ) + if err != nil { + return fmt.Errorf("error parsing channel "+ + "point %s: %w", matchChannel.ChanPoint, + err) + } + + var randomness [32]byte + if _, err := rand.Read(randomness[:]); err != nil { + return err + } + + nonces, err := lnd.GenerateMuSig2Nonces( + extendedKey, randomness, chanPoint, chainParams, + nil, + ) + if err != nil { + return fmt.Errorf("error generating MuSig2 "+ + "nonces: %w", err) + } + + matchChannel.MuSig2NonceRandomness = hex.EncodeToString( + randomness[:], + ) + matchChannel.MuSig2Nonces = hex.EncodeToString( + nonces.PubNonce[:], + ) + } + } + // Derive all 2500 keys now, this might take a while. for index := range c.NumKeys { _, pubKey, _, err := lnd.DeriveKey( diff --git a/cmd/chantools/zombierecovery_signoffer.go b/cmd/chantools/zombierecovery_signoffer.go index 4541e64..d433ee3 100644 --- a/cmd/chantools/zombierecovery_signoffer.go +++ b/cmd/chantools/zombierecovery_signoffer.go @@ -151,6 +151,21 @@ func signOffer(rootKey *hdkeychain.ExtendedKey, return fmt.Errorf("could not find local multisig key: "+ "%w", err) } + + // If this is a Simple Taproot channel, we need to generate a + // partial MuSig2 signature instead. + if len(packet.Inputs[idx].MuSig2PartialSigs) > 0 { + err = muSig2PartialSign( + signer, localKeyDesc, packet, idx, + ) + if err != nil { + return fmt.Errorf("error adding partial "+ + "signature: %w", err) + } + + continue + } + if len(packet.Inputs[idx].WitnessScript) == 0 { return errors.New("invalid PSBT, missing witness " + "script") diff --git a/go.mod b/go.mod index 7d4d71d..cd0b0af 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,7 @@ require ( github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 github.com/btcsuite/btclog v0.0.0-20170628155309-84c8d2346e9f github.com/btcsuite/btcwallet v0.16.10-0.20240410030101-6fe19a472a62 - github.com/btcsuite/btcwallet/wallet/txrules v1.2.1 + github.com/btcsuite/btcwallet/wallet/txrules v1.2.1 // indirect github.com/btcsuite/btcwallet/walletdb v1.4.2 github.com/coreos/bbolt v1.3.3 github.com/davecgh/go-spew v1.1.1 @@ -211,3 +211,5 @@ require ( // allows us to specify that as an option. This is required for the // taproot-assets dependency to function properly. replace google.golang.org/protobuf => github.com/lightninglabs/protobuf-go-hex-display v1.33.0-hex-display + +replace github.com/btcsuite/btcd/btcutil/psbt => github.com/guggero/btcd/btcutil/psbt v0.0.0-20240615145141-63f97ed9872a diff --git a/go.sum b/go.sum index 7892285..e504a9c 100644 --- a/go.sum +++ b/go.sum @@ -658,8 +658,6 @@ github.com/btcsuite/btcd/btcutil v1.0.0/go.mod h1:Uoxwv0pqYWhD//tfTiipkxNfdhG9Ur github.com/btcsuite/btcd/btcutil v1.1.0/go.mod h1:5OapHB7A2hBBWLm48mmw4MOHNJCcUBTwmWH/0Jn8VHE= github.com/btcsuite/btcd/btcutil v1.1.5 h1:+wER79R5670vs/ZusMTF1yTcRYE5GUsFbdjdisflzM8= github.com/btcsuite/btcd/btcutil v1.1.5/go.mod h1:PSZZ4UitpLBWzxGd5VGOrLnmOjtPP/a6HaFo12zMs00= -github.com/btcsuite/btcd/btcutil/psbt v1.1.8 h1:4voqtT8UppT7nmKQkXV+T9K8UyQjKOn2z/ycpmJK8wg= -github.com/btcsuite/btcd/btcutil/psbt v1.1.8/go.mod h1:kA6FLH/JfUx++j9pYU0pyu+Z8XGBQuuTmuKYUf6q7/U= github.com/btcsuite/btcd/chaincfg/chainhash v1.0.0/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= github.com/btcsuite/btcd/chaincfg/chainhash v1.0.1/go.mod h1:7SFka0XMvUgj3hfZtydOrQY2mwhPclbT2snogU7SQQc= github.com/btcsuite/btcd/chaincfg/chainhash v1.1.0 h1:59Kx4K6lzOW5w6nFlA0v5+lk/6sjybR934QNHSJZPTQ= @@ -1004,6 +1002,8 @@ github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFb github.com/grpc-ecosystem/grpc-gateway/v2 v2.7.0/go.mod h1:hgWBS7lorOAVIJEQMi4ZsPv9hVvWI6+ch50m39Pf2Ks= github.com/grpc-ecosystem/grpc-gateway/v2 v2.11.3 h1:lLT7ZLSzGLI08vc9cpd+tYmNWjdKDqyr/2L+f6U12Fk= github.com/grpc-ecosystem/grpc-gateway/v2 v2.11.3/go.mod h1:o//XUCC/F+yRGJoPO/VU0GSB0f8Nhgmxx0VIRUvaC0w= +github.com/guggero/btcd/btcutil/psbt v0.0.0-20240615145141-63f97ed9872a h1:8TM7i6cMdvNZWtw6eVVP5wDOvVbc9Cjf/ZIc3+APo34= +github.com/guggero/btcd/btcutil/psbt v0.0.0-20240615145141-63f97ed9872a/go.mod h1:7+GB/GHXQM8xCb9q1A5sHDT3LNgrK7fZofPddOAgc3U= github.com/hashicorp/consul/api v1.1.0/go.mod h1:VmuI/Lkw1nC05EYQWNKwWGbkg+FbDBtguAZLlVdkD9Q= github.com/hashicorp/consul/sdk v0.1.1/go.mod h1:VKf9jXwCTEY1QZP2MOLRhb5i/I/ssyNV1vwHyQBF0x8= github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= diff --git a/lnd/channel.go b/lnd/channel.go index 18eb7ac..bb3a0fa 100644 --- a/lnd/channel.go +++ b/lnd/channel.go @@ -1,17 +1,23 @@ package lnd import ( + "bytes" "errors" "fmt" "strconv" "strings" + "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" + "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" + "github.com/btcsuite/btcd/btcutil/hdkeychain" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lnwire" ) type LightningChannel struct { @@ -111,3 +117,56 @@ func ParseOutpoint(s string) (*wire.OutPoint, error) { Index: uint32(index), }, nil } + +// GenerateMuSig2Nonces generates random nonces for a MuSig2 signing session. +func GenerateMuSig2Nonces(extendedKey *hdkeychain.ExtendedKey, + randomness [32]byte, chanPoint *wire.OutPoint, + chainParams *chaincfg.Params, + signingKey *btcec.PrivateKey) (*musig2.Nonces, error) { + + privKey, err := DeriveMuSig2NoncePrivKey(extendedKey, chainParams) + if err != nil { + return nil, err + } + + chanID := lnwire.NewChanIDFromOutPoint(*chanPoint) + nonces, err := musig2.GenNonces( + musig2.WithPublicKey(privKey.PubKey()), + musig2.WithNonceSecretKeyAux(privKey), + musig2.WithCustomRand(bytes.NewReader(randomness[:])), + musig2.WithNonceAuxInput(chanID[:]), + ) + if err != nil { + return nil, err + } + + // If we actually know the final signing key, we need to update it in + // the secret nonce to bypass a check in the MuSig2 library. + if signingKey != nil { + copy( + nonces.SecNonce[btcec.PrivKeyBytesLen*2:], + signingKey.PubKey().SerializeCompressed(), + ) + } + + return nonces, nil +} + +// DeriveMuSig2NoncePrivKey derives a private key to be used as a nonce in a +// MuSig2 signing session. +func DeriveMuSig2NoncePrivKey(extendedKey *hdkeychain.ExtendedKey, + chainParams *chaincfg.Params) (*btcec.PrivateKey, error) { + + // We use a derivation path that is not used by lnd, to make sure we + // don't put any keys at risk. + path := fmt.Sprintf( + LndDerivationPath+"/0/%d", chainParams.HDCoinType, 1337, 1337, + ) + + key, _, _, err := DeriveKey(extendedKey, path, chainParams) + if err != nil { + return nil, err + } + + return key.ECPrivKey() +} diff --git a/lnd/channel_test.go b/lnd/channel_test.go new file mode 100644 index 0000000..8bf5e08 --- /dev/null +++ b/lnd/channel_test.go @@ -0,0 +1,111 @@ +package lnd + +import ( + "encoding/hex" + "testing" + + "github.com/btcsuite/btcd/btcutil/hdkeychain" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/stretchr/testify/require" +) + +var ( + rootKey = "tprv8ZgxMBicQKsPejNXQLJKe3dBBs9Zrt53EZrsBzVLQ8rZji3" + + "hVb3wcoRvgrjvTmjPG2ixoGUUkCyC6yBEy9T5gbLdvD2a5VmJbcFd5Q9pkAs" + + staticRand = [32]byte{ + 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, + 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, + 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, + 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, + } + + staticChanPoint = &wire.OutPoint{ + Hash: chainhash.Hash{ + 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, + }, + Index: 123, + } + + testNetParams = &chaincfg.TestNet3Params + mainNetParams = &chaincfg.MainNetParams + + staticPubNonceHex = "0275757be33335347132895c3cf7c9d5d4c6dbfbc2b8090b" + + "0c311929b5b3304629026f5183811ea44bd60110f9d4c30525bb1e8c72f9" + + "19b766464e91db7739d4123a" +) + +func TestGenerateMuSig2Nonces(t *testing.T) { + extendedKey, err := hdkeychain.NewKeyFromString(rootKey) + require.NoError(t, err) + + staticNonces, err := GenerateMuSig2Nonces( + extendedKey, staticRand, staticChanPoint, testNetParams, nil, + ) + require.NoError(t, err) + + require.Equal( + t, staticPubNonceHex, + hex.EncodeToString(staticNonces.PubNonce[:]), + ) + + testCases := []struct { + name string + randomness [32]byte + chanPoint *wire.OutPoint + chainParams *chaincfg.Params + pubNonce string + }{{ + name: "mainnet", + randomness: staticRand, + chanPoint: staticChanPoint, + chainParams: mainNetParams, + pubNonce: "02045795da7cffa1e2d8e64c4dfe606cd54b9d727e93f8c277" + + "5b1d5442b80f605c024471a42dae0583f08262dcd09162d692bd" + + "2ceb44178f37599c925b6465e92786", + }, { + name: "channel point", + randomness: staticRand, + chanPoint: &wire.OutPoint{ + Hash: chainhash.Hash{ + 0x1, 0x2, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, + }, + Index: 124, + }, + chainParams: testNetParams, + pubNonce: "025c22d8bc5fd0605fa007db40977da4caff2d5312bd865b62" + + "b7db6a184ea1b0d803278833480a3b7005cd1fad18c9a2740407" + + "8a325d3c85f33b0370663d2943e44d", + }, { + name: "randomness", + randomness: [32]byte{0x1}, + chanPoint: staticChanPoint, + chainParams: testNetParams, + pubNonce: "02a0d0b3281e92130e64a454ad122b37c8fd771647eb442769" + + "103b583db5d73753030ed0c6e04f4fb729b2db5a34331a4e5283" + + "b3872004222c401a4b9ec6d0540f64", + }} + + for idx := range testCases { + tc := testCases[idx] + + t.Run(tc.name, func(t *testing.T) { + nonces, err := GenerateMuSig2Nonces( + extendedKey, tc.randomness, tc.chanPoint, + tc.chainParams, nil, + ) + require.NoError(t, err) + + require.NotEqual( + t, staticNonces.PubNonce, nonces.PubNonce, + ) + + require.Equal( + t, tc.pubNonce, + hex.EncodeToString(nonces.PubNonce[:]), + ) + }) + } +}