diff --git a/sweepbatcher/greedy_batch_selection.go b/sweepbatcher/greedy_batch_selection.go index 15585af80..78a2a5dc6 100644 --- a/sweepbatcher/greedy_batch_selection.go +++ b/sweepbatcher/greedy_batch_selection.go @@ -51,7 +51,7 @@ func (b *Batcher) greedyAddSweep(ctx context.Context, sweep *sweep) error { // Run the algorithm. Get batchId of possible batches, sorted from best // to worst. batchesIds, err := selectBatches( - batches, sweepFeeDetails, newBatchFeeDetails, + batches, sweepFeeDetails, newBatchFeeDetails, b.mixedBatch, ) if err != nil { return fmt.Errorf("batch selection algorithm failed for sweep "+ @@ -125,10 +125,11 @@ func estimateSweepFeeIncrement(s *sweep) (feeDetails, feeDetails, error) { // Create feeDetails for sweep. sweepFeeDetails := feeDetails{ FeeRate: s.minFeeRate, - NonCoopHint: s.nonCoopHint, + NonCoopHint: s.nonCoopHint || s.coopFailed, IsExternalAddr: s.isExternalAddr, // Calculate sweep weight as a difference. + MixedWeight: fd2.MixedWeight - fd1.MixedWeight, CoopWeight: fd2.CoopWeight - fd1.CoopWeight, NonCoopWeight: fd2.NonCoopWeight - fd1.NonCoopWeight, } @@ -152,7 +153,7 @@ func estimateBatchWeight(batch *batch) (feeDetails, error) { // Find if the batch has at least one non-cooperative sweep. hasNonCoop := false for _, sweep := range batch.sweeps { - if sweep.nonCoopHint { + if sweep.nonCoopHint || sweep.coopFailed { hasNonCoop = true } } @@ -177,11 +178,16 @@ func estimateBatchWeight(batch *batch) (feeDetails, error) { destAddr = (*btcutil.AddressTaproot)(nil) } - // Make two estimators: for coop and non-coop cases. - var coopWeight, nonCoopWeight input.TxWeightEstimator + // Make three estimators: for mixed, coop and non-coop cases. + var mixedWeight, coopWeight, nonCoopWeight input.TxWeightEstimator // Add output weight to the estimator. - err := sweeppkg.AddOutputEstimate(&coopWeight, destAddr) + err := sweeppkg.AddOutputEstimate(&mixedWeight, destAddr) + if err != nil { + return feeDetails{}, fmt.Errorf("sweep.AddOutputEstimate: %w", + err) + } + err = sweeppkg.AddOutputEstimate(&coopWeight, destAddr) if err != nil { return feeDetails{}, fmt.Errorf("sweep.AddOutputEstimate: %w", err) @@ -194,6 +200,19 @@ func estimateBatchWeight(batch *batch) (feeDetails, error) { // Add inputs. for _, sweep := range batch.sweeps { + if sweep.nonCoopHint || sweep.coopFailed { + err = sweep.htlcSuccessEstimator(&mixedWeight) + if err != nil { + return feeDetails{}, fmt.Errorf( + "htlcSuccessEstimator failed: %w", err, + ) + } + } else { + mixedWeight.AddTaprootKeySpendInput( + txscript.SigHashDefault, + ) + } + coopWeight.AddTaprootKeySpendInput(txscript.SigHashDefault) err = sweep.htlcSuccessEstimator(&nonCoopWeight) @@ -206,6 +225,7 @@ func estimateBatchWeight(batch *batch) (feeDetails, error) { return feeDetails{ BatchId: batch.id, FeeRate: batch.rbfCache.FeeRate, + MixedWeight: mixedWeight.Weight(), CoopWeight: coopWeight.Weight(), NonCoopWeight: nonCoopWeight.Weight(), NonCoopHint: hasNonCoop, @@ -222,6 +242,7 @@ const newBatchSignal = -1 type feeDetails struct { BatchId int32 FeeRate chainfee.SatPerKWeight + MixedWeight lntypes.WeightUnit CoopWeight lntypes.WeightUnit NonCoopWeight lntypes.WeightUnit NonCoopHint bool @@ -229,11 +250,14 @@ type feeDetails struct { } // fee returns fee of onchain transaction representing this instance. -func (e feeDetails) fee() btcutil.Amount { +func (e feeDetails) fee(mixedBatch bool) btcutil.Amount { var weight lntypes.WeightUnit - if e.NonCoopHint { + switch { + case mixedBatch: + weight = e.MixedWeight + case e.NonCoopHint: weight = e.NonCoopWeight - } else { + default: weight = e.CoopWeight } @@ -250,6 +274,7 @@ func (e1 feeDetails) combine(e2 feeDetails) feeDetails { return feeDetails{ FeeRate: feeRate, + MixedWeight: e1.MixedWeight + e2.MixedWeight, CoopWeight: e1.CoopWeight + e2.CoopWeight, NonCoopWeight: e1.NonCoopWeight + e2.NonCoopWeight, NonCoopHint: e1.NonCoopHint || e2.NonCoopHint, @@ -259,21 +284,26 @@ func (e1 feeDetails) combine(e2 feeDetails) feeDetails { // selectBatches returns the list of id of batches sorted from best to worst. // Creation a new batch is encoded as newBatchSignal. For each batch its fee -// rate and two weights are provided: weight in case of cooperative spending and -// weight in case non-cooperative spending (using preimages instead of taproot -// key spend). Also, a hint is provided to signal if the batch has to use -// non-cooperative spending path. The same data is also provided to the sweep -// for which we are selecting a batch to add. In case of the sweep weights are -// weight deltas resulted from adding the sweep. Finally, the same data is -// provided for new batch having this sweep only. The algorithm compares costs -// of adding the sweep to each existing batch, and costs of new batch creation -// for this sweep and returns BatchId of the winning batch. If the best option -// is to create a new batch, return newBatchSignal. Each fee details has also -// IsExternalAddr flag. There is a rule that sweeps having flag IsExternalAddr -// must go in individual batches. Cooperative spending is only available if all -// the sweeps support cooperative spending path. -func selectBatches(batches []feeDetails, sweep, oneSweepBatch feeDetails) ( - []int32, error) { +// rate and a set of weights are provided: weight in case of a mixed batch, +// weight in case of cooperative spending and weight in case non-cooperative +// spending. Also, a hint is provided to signal what spending path will be used +// by the batch. +// +// The same data is also provided for the sweep for which we are selecting a +// batch to add. In case of the sweep weights are weight deltas resulted from +// adding the sweep. Finally, the same data is provided for new batch having +// this sweep only. +// +// The algorithm compares costs of adding the sweep to each existing batch, and +// costs of new batch creation for this sweep and returns BatchId of the winning +// batch. If the best option is to create a new batch, return newBatchSignal. +// +// Each fee details has also IsExternalAddr flag. There is a rule that sweeps +// having flag IsExternalAddr must go in individual batches. Cooperative +// spending is only available if all the sweeps support cooperative spending +// path of in a mixed batch. +func selectBatches(batches []feeDetails, sweep, oneSweepBatch feeDetails, + mixedBatch bool) ([]int32, error) { // If the sweep has IsExternalAddr flag, the sweep can't be added to // a batch, so create new batch for it. @@ -294,7 +324,7 @@ func selectBatches(batches []feeDetails, sweep, oneSweepBatch feeDetails) ( // creation with this sweep only in it. The cost is its full fee. alternatives = append(alternatives, alternative{ batchId: newBatchSignal, - cost: oneSweepBatch.fee(), + cost: oneSweepBatch.fee(mixedBatch), }) // Try to add the sweep to every batch, calculate the costs and @@ -310,7 +340,7 @@ func selectBatches(batches []feeDetails, sweep, oneSweepBatch feeDetails) ( combinedBatch := batch.combine(sweep) // The cost is the fee increase. - cost := combinedBatch.fee() - batch.fee() + cost := combinedBatch.fee(mixedBatch) - batch.fee(mixedBatch) // The cost must be positive, because we added a sweep. if cost <= 0 { diff --git a/sweepbatcher/greedy_batch_selection_test.go b/sweepbatcher/greedy_batch_selection_test.go index a67e7db78..44a0ec40a 100644 --- a/sweepbatcher/greedy_batch_selection_test.go +++ b/sweepbatcher/greedy_batch_selection_test.go @@ -29,9 +29,9 @@ const ( ) * 4 coopTwoSweepBatchWeight = coopNewBatchWeight + coopInputWeight - nonCoopTwoSweepBatchWeight = coopTwoSweepBatchWeight + - 2*nonCoopPenalty - v2v3BatchWeight = nonCoopTwoSweepBatchWeight - 25 + nonCoopTwoSweepBatchWeight = coopTwoSweepBatchWeight + 2*nonCoopPenalty + v2v3BatchWeight = nonCoopTwoSweepBatchWeight - 25 + mixedTwoSweepBatchWeight = coopTwoSweepBatchWeight + nonCoopPenalty ) // testHtlcV2SuccessEstimator adds weight of non-cooperative input to estimator @@ -86,11 +86,13 @@ func TestEstimateSweepFeeIncrement(t *testing.T) { }, wantSweepFeeDetails: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: coopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, }, wantNewBatchFeeDetails: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, @@ -104,11 +106,13 @@ func TestEstimateSweepFeeIncrement(t *testing.T) { }, wantSweepFeeDetails: feeDetails{ FeeRate: highFeeRate, + MixedWeight: coopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, }, wantNewBatchFeeDetails: feeDetails{ FeeRate: highFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, @@ -124,12 +128,14 @@ func TestEstimateSweepFeeIncrement(t *testing.T) { }, wantSweepFeeDetails: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: coopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, IsExternalAddr: true, }, wantNewBatchFeeDetails: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, IsExternalAddr: true, @@ -146,12 +152,15 @@ func TestEstimateSweepFeeIncrement(t *testing.T) { }, wantSweepFeeDetails: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: coopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, IsExternalAddr: true, }, wantNewBatchFeeDetails: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight - + p2pkhDiscount, CoopWeight: coopNewBatchWeight - p2pkhDiscount, NonCoopWeight: nonCoopNewBatchWeight - @@ -169,12 +178,37 @@ func TestEstimateSweepFeeIncrement(t *testing.T) { }, wantSweepFeeDetails: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: nonCoopInputWeight, + CoopWeight: coopInputWeight, + NonCoopWeight: nonCoopInputWeight, + NonCoopHint: true, + }, + wantNewBatchFeeDetails: feeDetails{ + FeeRate: lowFeeRate, + MixedWeight: nonCoopNewBatchWeight, + CoopWeight: coopNewBatchWeight, + NonCoopWeight: nonCoopNewBatchWeight, + NonCoopHint: true, + }, + }, + + { + name: "coop-failed", + sweep: &sweep{ + minFeeRate: lowFeeRate, + htlcSuccessEstimator: se3, + coopFailed: true, + }, + wantSweepFeeDetails: feeDetails{ + FeeRate: lowFeeRate, + MixedWeight: nonCoopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, NonCoopHint: true, }, wantNewBatchFeeDetails: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: nonCoopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, NonCoopHint: true, @@ -230,6 +264,7 @@ func TestEstimateBatchWeight(t *testing.T) { wantBatchFeeDetails: feeDetails{ BatchId: 1, FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, @@ -254,6 +289,7 @@ func TestEstimateBatchWeight(t *testing.T) { wantBatchFeeDetails: feeDetails{ BatchId: 1, FeeRate: lowFeeRate, + MixedWeight: coopTwoSweepBatchWeight, CoopWeight: coopTwoSweepBatchWeight, NonCoopWeight: nonCoopTwoSweepBatchWeight, }, @@ -278,6 +314,7 @@ func TestEstimateBatchWeight(t *testing.T) { wantBatchFeeDetails: feeDetails{ BatchId: 1, FeeRate: lowFeeRate, + MixedWeight: coopTwoSweepBatchWeight, CoopWeight: coopTwoSweepBatchWeight, NonCoopWeight: v2v3BatchWeight, }, @@ -299,6 +336,7 @@ func TestEstimateBatchWeight(t *testing.T) { wantBatchFeeDetails: feeDetails{ BatchId: 1, FeeRate: highFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, @@ -324,6 +362,34 @@ func TestEstimateBatchWeight(t *testing.T) { wantBatchFeeDetails: feeDetails{ BatchId: 1, FeeRate: lowFeeRate, + MixedWeight: mixedTwoSweepBatchWeight, + CoopWeight: coopTwoSweepBatchWeight, + NonCoopWeight: nonCoopTwoSweepBatchWeight, + NonCoopHint: true, + }, + }, + + { + name: "coop-failed", + batch: &batch{ + id: 1, + rbfCache: rbfCache{ + FeeRate: lowFeeRate, + }, + sweeps: map[lntypes.Hash]sweep{ + swapHash1: { + htlcSuccessEstimator: se3, + }, + swapHash2: { + htlcSuccessEstimator: se3, + coopFailed: true, + }, + }, + }, + wantBatchFeeDetails: feeDetails{ + BatchId: 1, + FeeRate: lowFeeRate, + MixedWeight: mixedTwoSweepBatchWeight, CoopWeight: coopTwoSweepBatchWeight, NonCoopWeight: nonCoopTwoSweepBatchWeight, NonCoopHint: true, @@ -348,6 +414,7 @@ func TestEstimateBatchWeight(t *testing.T) { wantBatchFeeDetails: feeDetails{ BatchId: 1, FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, IsExternalAddr: true, @@ -373,6 +440,7 @@ func TestSelectBatches(t *testing.T) { name string batches []feeDetails sweep, oneSweepBatch feeDetails + mixedBatch bool wantBestBatchesIds []int32 }{ { @@ -380,11 +448,13 @@ func TestSelectBatches(t *testing.T) { batches: []feeDetails{}, sweep: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: coopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, }, oneSweepBatch: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, @@ -397,17 +467,20 @@ func TestSelectBatches(t *testing.T) { { BatchId: 1, FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, }, sweep: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: coopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, }, oneSweepBatch: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, @@ -420,17 +493,20 @@ func TestSelectBatches(t *testing.T) { { BatchId: 1, FeeRate: highFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, }, sweep: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: coopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, }, oneSweepBatch: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, @@ -443,23 +519,27 @@ func TestSelectBatches(t *testing.T) { { BatchId: 1, FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, { BatchId: 2, FeeRate: highFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, }, sweep: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: coopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, }, oneSweepBatch: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, @@ -472,23 +552,27 @@ func TestSelectBatches(t *testing.T) { { BatchId: 1, FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, { BatchId: 2, FeeRate: highFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, }, sweep: feeDetails{ FeeRate: highFeeRate, + MixedWeight: coopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, }, oneSweepBatch: feeDetails{ FeeRate: highFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, @@ -501,24 +585,28 @@ func TestSelectBatches(t *testing.T) { { BatchId: 1, FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, { BatchId: 2, FeeRate: highFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, }, sweep: feeDetails{ FeeRate: highFeeRate, + MixedWeight: nonCoopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, NonCoopHint: true, }, oneSweepBatch: feeDetails{ FeeRate: highFeeRate, + MixedWeight: nonCoopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, NonCoopHint: true, @@ -532,24 +620,28 @@ func TestSelectBatches(t *testing.T) { { BatchId: 1, FeeRate: lowFeeRate, + MixedWeight: 10000, CoopWeight: 10000, NonCoopWeight: 15000, }, { BatchId: 2, FeeRate: highFeeRate, + MixedWeight: 10000, CoopWeight: 10000, NonCoopWeight: 15000, }, }, sweep: feeDetails{ FeeRate: highFeeRate, + MixedWeight: nonCoopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, NonCoopHint: true, }, oneSweepBatch: feeDetails{ FeeRate: highFeeRate, + MixedWeight: nonCoopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, NonCoopHint: true, @@ -557,18 +649,56 @@ func TestSelectBatches(t *testing.T) { wantBestBatchesIds: []int32{newBatchSignal, 2, 1}, }, + { + name: "high fee noncoop sweep, large batches, mixed", + batches: []feeDetails{ + { + BatchId: 1, + FeeRate: lowFeeRate, + MixedWeight: 10000, + CoopWeight: 10000, + NonCoopWeight: 15000, + }, + { + BatchId: 2, + FeeRate: highFeeRate, + MixedWeight: 10000, + CoopWeight: 10000, + NonCoopWeight: 15000, + }, + }, + sweep: feeDetails{ + FeeRate: highFeeRate, + MixedWeight: nonCoopInputWeight, + CoopWeight: coopInputWeight, + NonCoopWeight: nonCoopInputWeight, + NonCoopHint: true, + }, + oneSweepBatch: feeDetails{ + FeeRate: highFeeRate, + MixedWeight: nonCoopNewBatchWeight, + CoopWeight: coopNewBatchWeight, + NonCoopWeight: nonCoopNewBatchWeight, + NonCoopHint: true, + }, + mixedBatch: true, + wantBestBatchesIds: []int32{2, newBatchSignal, 1}, + }, + { name: "high fee noncoop sweep, high batch noncoop", batches: []feeDetails{ { BatchId: 1, FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, { BatchId: 2, FeeRate: highFeeRate, + MixedWeight: nonCoopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, NonCoopHint: true, @@ -576,12 +706,14 @@ func TestSelectBatches(t *testing.T) { }, sweep: feeDetails{ FeeRate: highFeeRate, + MixedWeight: nonCoopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, NonCoopHint: true, }, oneSweepBatch: feeDetails{ FeeRate: highFeeRate, + MixedWeight: nonCoopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, NonCoopHint: true, @@ -595,24 +727,28 @@ func TestSelectBatches(t *testing.T) { { BatchId: 1, FeeRate: lowFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, { BatchId: 2, FeeRate: highFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, }, sweep: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: nonCoopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, NonCoopHint: true, }, oneSweepBatch: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: nonCoopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, NonCoopHint: true, @@ -626,24 +762,28 @@ func TestSelectBatches(t *testing.T) { { BatchId: 1, FeeRate: lowFeeRate, + MixedWeight: 10000, CoopWeight: 10000, NonCoopWeight: 15000, }, { BatchId: 2, FeeRate: highFeeRate, + MixedWeight: 10000, CoopWeight: 10000, NonCoopWeight: 15000, }, }, sweep: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: nonCoopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, NonCoopHint: true, }, oneSweepBatch: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: nonCoopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, NonCoopHint: true, @@ -651,12 +791,49 @@ func TestSelectBatches(t *testing.T) { wantBestBatchesIds: []int32{newBatchSignal, 1, 2}, }, + { + name: "low fee noncoop sweep, large batches, mixed", + batches: []feeDetails{ + { + BatchId: 1, + FeeRate: lowFeeRate, + MixedWeight: 10000, + CoopWeight: 10000, + NonCoopWeight: 15000, + }, + { + BatchId: 2, + FeeRate: highFeeRate, + MixedWeight: 10000, + CoopWeight: 10000, + NonCoopWeight: 15000, + }, + }, + sweep: feeDetails{ + FeeRate: lowFeeRate, + MixedWeight: nonCoopInputWeight, + CoopWeight: coopInputWeight, + NonCoopWeight: nonCoopInputWeight, + NonCoopHint: true, + }, + oneSweepBatch: feeDetails{ + FeeRate: lowFeeRate, + MixedWeight: nonCoopNewBatchWeight, + CoopWeight: coopNewBatchWeight, + NonCoopWeight: nonCoopNewBatchWeight, + NonCoopHint: true, + }, + mixedBatch: true, + wantBestBatchesIds: []int32{1, newBatchSignal, 2}, + }, + { name: "low fee noncoop sweep, low batch noncoop", batches: []feeDetails{ { BatchId: 1, FeeRate: lowFeeRate, + MixedWeight: nonCoopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, NonCoopHint: true, @@ -664,18 +841,21 @@ func TestSelectBatches(t *testing.T) { { BatchId: 2, FeeRate: highFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, }, sweep: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: nonCoopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, NonCoopHint: true, }, oneSweepBatch: feeDetails{ FeeRate: lowFeeRate, + MixedWeight: nonCoopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, NonCoopHint: true, @@ -689,24 +869,28 @@ func TestSelectBatches(t *testing.T) { { BatchId: 1, FeeRate: highFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, { BatchId: 2, FeeRate: highFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, }, sweep: feeDetails{ FeeRate: highFeeRate, + MixedWeight: coopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, IsExternalAddr: true, }, oneSweepBatch: feeDetails{ FeeRate: highFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, IsExternalAddr: true, @@ -720,12 +904,14 @@ func TestSelectBatches(t *testing.T) { { BatchId: 1, FeeRate: highFeeRate - 1, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, { BatchId: 2, FeeRate: highFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, IsExternalAddr: true, @@ -733,11 +919,13 @@ func TestSelectBatches(t *testing.T) { }, sweep: feeDetails{ FeeRate: highFeeRate, + MixedWeight: coopInputWeight, CoopWeight: coopInputWeight, NonCoopWeight: nonCoopInputWeight, }, oneSweepBatch: feeDetails{ FeeRate: highFeeRate, + MixedWeight: coopNewBatchWeight, CoopWeight: coopNewBatchWeight, NonCoopWeight: nonCoopNewBatchWeight, }, @@ -750,6 +938,7 @@ func TestSelectBatches(t *testing.T) { t.Run(tc.name, func(t *testing.T) { gotBestBatchesIds, err := selectBatches( tc.batches, tc.sweep, tc.oneSweepBatch, + tc.mixedBatch, ) require.NoError(t, err) require.Equal( diff --git a/sweepbatcher/sweep_batch.go b/sweepbatcher/sweep_batch.go index 7ee758211..549d8f89f 100644 --- a/sweepbatcher/sweep_batch.go +++ b/sweepbatcher/sweep_batch.go @@ -7,9 +7,11 @@ import ( "errors" "fmt" "math" + "strings" "sync" "time" + "github.com/btcsuite/btcd/blockchain" "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/btcutil" @@ -107,6 +109,11 @@ type sweep struct { // has to be spent using preimage. This is only used in fee estimations // when selecting a batch for the sweep to minimize fees. nonCoopHint bool + + // coopFailed is set, if we have tried to spend the sweep cooperatively, + // but it failed. We try to spend a sweep cooperatively only once. This + // status is not persisted in the DB. + coopFailed bool } // batchState is the state of the batch. @@ -160,6 +167,16 @@ type batchConfig struct { // Note that musig2SignSweep must be nil in this case, however signer // client must still be provided, as it is used for non-coop spendings. customMuSig2Signer SignMuSig2 + + // mixedBatch instructs sweepbatcher to create mixed batches with regard + // to cooperativeness. Such a batch can include sweeps signed both + // cooperatively and non-cooperatively. If cooperative signing fails for + // a sweep, transaction is updated to sign that sweep non-cooperatively + // and another round of cooperative signing runs on the remaining + // sweeps. The remaining sweeps are signed in non-cooperative (more + // expensive) way. If the whole procedure fails for whatever reason, the + // batch is signed non-cooperatively (the fallback). + mixedBatch bool } // rbfCache stores data related to our last fee bump. @@ -421,13 +438,18 @@ func (b *batch) addSweep(ctx context.Context, sweep *sweep) (bool, error) { // Before we run through the acceptance checks, let's just see if this // sweep is already in our batch. In that case, just update the sweep. - _, ok := b.sweeps[sweep.swapHash] + oldSweep, ok := b.sweeps[sweep.swapHash] if ok { + // Preserve coopFailed value not to forget about cooperative + // spending failure in this sweep. + tmp := *sweep + tmp.coopFailed = oldSweep.coopFailed + // If the sweep was resumed from storage, and the swap requested // to sweep again, a new sweep notifier will be created by the // swap. By re-assigning to the batch's sweep we make sure that // everything, including the notifier, is up to date. - b.sweeps[sweep.swapHash] = *sweep + b.sweeps[sweep.swapHash] = tmp // If this is the primary sweep, we also need to update the // batch's confirmation target and fee rate. @@ -724,7 +746,7 @@ func (b *batch) publish(ctx context.Context) error { var ( err error fee btcutil.Amount - coopSuccess bool + signSuccess bool ) if len(b.sweeps) == 0 { @@ -739,12 +761,19 @@ func (b *batch) publish(ctx context.Context) error { return err } - fee, err, coopSuccess = b.publishBatchCoop(ctx) - if err != nil { - b.log.Warnf("co-op publish error: %v", err) + if b.cfg.mixedBatch { + fee, err, signSuccess = b.publishMixedBatch(ctx) + if err != nil { + b.log.Warnf("Mixed batch publish error: %v", err) + } + } else { + fee, err, signSuccess = b.publishBatchCoop(ctx) + if err != nil { + b.log.Warnf("co-op publish error: %v", err) + } } - if !coopSuccess { + if !signSuccess { fee, err = b.publishBatch(ctx) } if err != nil { @@ -922,6 +951,47 @@ func (b *batch) publishBatch(ctx context.Context) (btcutil.Amount, error) { return fee, nil } +// createPsbt creates serialized PSBT and prevOuts map from unsignedTx and +// the list of sweeps. +func (b *batch) createPsbt(unsignedTx *wire.MsgTx, sweeps []sweep) ([]byte, + map[wire.OutPoint]*wire.TxOut, error) { + + // Create PSBT packet object. + packet, err := psbt.NewFromUnsignedTx(unsignedTx) + if err != nil { + return nil, nil, fmt.Errorf("failed to create PSBT: %w", err) + } + + // Sanity check: the number of inputs in PSBT must be equal to the + // number of sweeps. + if len(packet.Inputs) != len(sweeps) { + return nil, nil, fmt.Errorf("invalid number of packet inputs") + } + + // Create prevOuts map. + prevOuts := make(map[wire.OutPoint]*wire.TxOut, len(sweeps)) + + // Fill input info in PSBT and prevOuts. + for i, sweep := range sweeps { + txOut := &wire.TxOut{ + Value: int64(sweep.value), + PkScript: sweep.htlc.PkScript, + } + + prevOuts[sweep.outpoint] = txOut + packet.Inputs[i].WitnessUtxo = txOut + } + + // Serialize PSBT. + var psbtBuf bytes.Buffer + err = packet.Serialize(&psbtBuf) + if err != nil { + return nil, nil, fmt.Errorf("failed to serialize PSBT: %w", err) + } + + return psbtBuf.Bytes(), prevOuts, nil +} + // publishBatchCoop attempts to construct and publish a batch transaction that // collects all the required signatures interactively from the server. This // helps with collecting the funds immediately without revealing any information @@ -1013,36 +1083,15 @@ func (b *batch) publishBatchCoop(ctx context.Context) (btcutil.Amount, Value: int64(batchAmt - fee), }) - packet, err := psbt.NewFromUnsignedTx(batchTx) - if err != nil { - return fee, err, false - } - - if len(packet.Inputs) != len(sweeps) { - return fee, fmt.Errorf("invalid number of packet inputs"), false - } - - prevOuts := make(map[wire.OutPoint]*wire.TxOut) - - for i, sweep := range sweeps { - txOut := &wire.TxOut{ - Value: int64(sweep.value), - PkScript: sweep.htlc.PkScript, - } - - prevOuts[sweep.outpoint] = txOut - packet.Inputs[i].WitnessUtxo = txOut - } - - var psbtBuf bytes.Buffer - err = packet.Serialize(&psbtBuf) + // Create PSBT and prevOuts. + psbtBytes, prevOuts, err := b.createPsbt(batchTx, sweeps) if err != nil { return fee, err, false } // Attempt to cooperatively sign the batch tx with the server. err = b.coopSignBatchTx( - ctx, packet, sweeps, prevOuts, psbtBuf.Bytes(), + ctx, batchTx, sweeps, prevOuts, psbtBytes, ) if err != nil { return fee, err, false @@ -1071,6 +1120,361 @@ func (b *batch) publishBatchCoop(ctx context.Context) (btcutil.Amount, return fee, nil, true } +// constructUnsignedTx creates unsigned tx from the sweeps, paying to the addr. +// It also returns absolute fee (from weight and clamped). +func (b *batch) constructUnsignedTx(sweeps []sweep, + address btcutil.Address) (*wire.MsgTx, lntypes.WeightUnit, + btcutil.Amount, btcutil.Amount, error) { + + // Sanity check, there should be at least 1 sweep in this batch. + if len(sweeps) == 0 { + return nil, 0, 0, 0, fmt.Errorf("no sweeps in batch") + } + + // Create the batch transaction. + batchTx := &wire.MsgTx{ + Version: 2, + LockTime: uint32(b.currentHeight), + } + + // Add transaction inputs and estimate its weight. + var weightEstimate input.TxWeightEstimator + for _, sweep := range sweeps { + if sweep.nonCoopHint || sweep.coopFailed { + // Non-cooperative sweep. + batchTx.AddTxIn(&wire.TxIn{ + PreviousOutPoint: sweep.outpoint, + Sequence: sweep.htlc.SuccessSequence(), + }) + + err := sweep.htlcSuccessEstimator(&weightEstimate) + if err != nil { + return nil, 0, 0, 0, fmt.Errorf("sweep."+ + "htlcSuccessEstimator failed: %w", err) + } + } else { + // Cooperative sweep. + batchTx.AddTxIn(&wire.TxIn{ + PreviousOutPoint: sweep.outpoint, + }) + + weightEstimate.AddTaprootKeySpendInput( + txscript.SigHashDefault, + ) + } + } + + // Convert the destination address to pkScript. + batchPkScript, err := txscript.PayToAddrScript(address) + if err != nil { + return nil, 0, 0, 0, fmt.Errorf("txscript.PayToAddrScript "+ + "failed: %w", err) + } + + // Add the output to weight estimates. + err = sweeppkg.AddOutputEstimate(&weightEstimate, address) + if err != nil { + return nil, 0, 0, 0, fmt.Errorf("sweep.AddOutputEstimate "+ + "failed: %w", err) + } + + // Keep track of the total amount this batch is sweeping back. + batchAmt := btcutil.Amount(0) + for _, sweep := range sweeps { + batchAmt += sweep.value + } + + // Find weight and fee. + weight := weightEstimate.Weight() + feeForWeight := b.rbfCache.FeeRate.FeeForWeight(weight) + + // Clamp the calculated fee to the max allowed fee amount for the batch. + fee := clampBatchFee(feeForWeight, batchAmt) + + // Add the batch transaction output, which excludes the fees paid to + // miners. + batchTx.AddTxOut(&wire.TxOut{ + PkScript: batchPkScript, + Value: int64(batchAmt - fee), + }) + + return batchTx, weight, feeForWeight, fee, nil +} + +// publishMixedBatch constructs and publishes a batch transaction that can +// include sweeps spent both cooperatively and non-cooperatively. If a sweep is +// marked with nonCoopHint or coopFailed flags, it is spent non-cooperatively. +// If a cooperative sweep fails to sign cooperatively, the whole transaction +// is re-signed again, with this sweep signing non-cooperatively. This process +// is optimized, trying to detect all non-cooperative sweeps in one round. The +// function returns the absolute fee. The last result of the function indicates +// if signing succeeded. +func (b *batch) publishMixedBatch(ctx context.Context) (btcutil.Amount, error, + bool) { + + // Sanity check, there should be at least 1 sweep in this batch. + if len(b.sweeps) == 0 { + return 0, fmt.Errorf("no sweeps in batch"), false + } + + // Append this sweep to an array of sweeps. This is needed to keep the + // order of sweeps stored, as iterating the sweeps map does not + // guarantee same order. + sweeps := make([]sweep, 0, len(b.sweeps)) + for _, sweep := range b.sweeps { + sweeps = append(sweeps, sweep) + } + + // Determine if an external address is used. + addrOverride := false + for _, sweep := range sweeps { + if sweep.isExternalAddr { + addrOverride = true + } + } + + // Find destination address. + var address btcutil.Address + if addrOverride { + // Sanity check, there should be exactly 1 sweep in this batch. + if len(sweeps) != 1 { + return 0, fmt.Errorf("external address sweep batched " + + "with other sweeps"), false + } + + address = sweeps[0].destAddr + } else { + var err error + address, err = b.getBatchDestAddr(ctx) + if err != nil { + return 0, err, false + } + } + + // Each iteration of this loop is one attempt to sign the transaction + // cooperatively. We try cooperative signing only for the sweeps not + // known in advance to be non-cooperative (nonCoopHint) and not failed + // to sign cooperatively in previous rounds (coopFailed). If any of them + // fails, the sweep is excluded from all following rounds and another + // round is attempted. Otherwise the cycle completes and we sign the + // remaining sweeps non-cooperatively. + var ( + tx *wire.MsgTx + weight lntypes.WeightUnit + feeForWeight btcutil.Amount + fee btcutil.Amount + coopInputs int + ) + for attempt := 1; ; attempt++ { + b.log.Infof("Attempt %d of collecting cooperative signatures.", + attempt) + + // Construct unsigned batch transaction. + var err error + tx, weight, feeForWeight, fee, err = b.constructUnsignedTx( + sweeps, address, + ) + if err != nil { + return 0, fmt.Errorf("failed to construct tx: %w", err), + false + } + + // Create PSBT and prevOutsMap. + psbtBytes, prevOutsMap, err := b.createPsbt(tx, sweeps) + if err != nil { + return 0, fmt.Errorf("createPsbt failed: %w", err), + false + } + + // Keep track if any new sweep failed to sign cooperatively. + newCoopFailures := false + + // Try to sign all cooperative sweeps first. + coopInputs = 0 + for i, sweep := range sweeps { + // Skip non-cooperative sweeps. + if sweep.nonCoopHint || sweep.coopFailed { + continue + } + + // Try to sign the sweep cooperatively. + finalSig, err := b.musig2sign( + ctx, i, sweep, tx, prevOutsMap, psbtBytes, + ) + if err != nil { + b.log.Infof("cooperative signing failed for "+ + "sweep %x: %v", sweep.swapHash[:6], err) + + // Set coopFailed flag for this sweep in all the + // places we store the sweep. + sweep.coopFailed = true + sweeps[i] = sweep + b.sweeps[sweep.swapHash] = sweep + + // Update newCoopFailures to know if we need + // another attempt of cooperative signing. + newCoopFailures = true + } else { + // Put the signature to witness of the input. + tx.TxIn[i].Witness = wire.TxWitness{finalSig} + coopInputs++ + } + } + + // If there was any failure of cooperative signing, we need to + // update weight estimates (since non-cooperative signing has + // larger witness) and hence update the whole transaction and + // all the signatures. Otherwise we complete cooperative part. + if !newCoopFailures { + break + } + } + + // Calculate the expected number of non-cooperative sweeps. + nonCoopInputs := len(sweeps) - coopInputs + + // Now sign the remaining sweeps' inputs non-cooperatively. + // For that, first collect sign descriptors for the signatures. + // Also collect prevOuts for all inputs. + signDescs := make([]*lndclient.SignDescriptor, 0, nonCoopInputs) + prevOutsList := make([]*wire.TxOut, 0, len(sweeps)) + for i, sweep := range sweeps { + // Create and store the previous outpoint for this sweep. + prevOut := &wire.TxOut{ + Value: int64(sweep.value), + PkScript: sweep.htlc.PkScript, + } + prevOutsList = append(prevOutsList, prevOut) + + // Skip cooperative sweeps. + if !sweep.nonCoopHint && !sweep.coopFailed { + continue + } + + key, err := btcec.ParsePubKey( + sweep.htlcKeys.ReceiverScriptKey[:], + ) + if err != nil { + return 0, fmt.Errorf("btcec.ParsePubKey failed: %w", + err), false + } + + // Create and store the sign descriptor for this sweep. + signDesc := lndclient.SignDescriptor{ + WitnessScript: sweep.htlc.SuccessScript(), + Output: prevOut, + HashType: sweep.htlc.SigHash(), + InputIndex: i, + KeyDesc: keychain.KeyDescriptor{ + PubKey: key, + }, + } + + if sweep.htlc.Version == swap.HtlcV3 { + signDesc.SignMethod = input.TaprootScriptSpendSignMethod + } + + signDescs = append(signDescs, &signDesc) + } + + // Sanity checks. + if len(signDescs) != nonCoopInputs { + // This must not happen by construction. + return 0, fmt.Errorf("unexpected size of signDescs: %d != %d", + len(signDescs), nonCoopInputs), false + } + if len(prevOutsList) != len(sweeps) { + // This must not happen by construction. + return 0, fmt.Errorf("unexpected size of prevOutsList: "+ + "%d != %d", len(prevOutsList), len(sweeps)), false + } + + var rawSigs [][]byte + if nonCoopInputs > 0 { + // Produce the signatures for our inputs using sign descriptors. + var err error + rawSigs, err = b.signerClient.SignOutputRaw( + ctx, tx, signDescs, prevOutsList, + ) + if err != nil { + return 0, fmt.Errorf("signerClient.SignOutputRaw "+ + "failed: %w", err), false + } + } + + // Sanity checks. + if len(rawSigs) != nonCoopInputs { + // This must not happen by construction. + return 0, fmt.Errorf("unexpected size of rawSigs: %d != %d", + len(rawSigs), nonCoopInputs), false + } + + // Generate success witnesses for non-cooperative sweeps. + sigIndex := 0 + for i, sweep := range sweeps { + // Skip cooperative sweeps. + if !sweep.nonCoopHint && !sweep.coopFailed { + continue + } + + witness, err := sweep.htlc.GenSuccessWitness( + rawSigs[sigIndex], sweep.preimage, + ) + if err != nil { + return 0, fmt.Errorf("sweep.htlc.GenSuccessWitness "+ + "failed: %w", err), false + } + sigIndex++ + + // Add the success witness to our batch transaction's inputs. + tx.TxIn[i].Witness = witness + } + + // Log transaction's details. + var coopHexs, nonCoopHexs []string + for _, sweep := range sweeps { + swapHex := fmt.Sprintf("%x", sweep.swapHash[:6]) + if sweep.nonCoopHint || sweep.coopFailed { + nonCoopHexs = append(nonCoopHexs, swapHex) + } else { + coopHexs = append(coopHexs, swapHex) + } + } + txHash := tx.TxHash() + b.log.Infof("attempting to publish mixed tx=%v with feerate=%v, "+ + "weight=%v, feeForWeight=%v, fee=%v, sweeps=%d, "+ + "%d cooperative: (%s) and %d non-cooperative (%s), destAddr=%s", + txHash, b.rbfCache.FeeRate, weight, feeForWeight, fee, + len(tx.TxIn), coopInputs, strings.Join(coopHexs, ", "), + nonCoopInputs, strings.Join(nonCoopHexs, ", "), address) + + b.debugLogTx("serialized mixed batch", tx) + + // Make sure tx weight matches the expected value. + realWeight := lntypes.WeightUnit( + blockchain.GetTransactionWeight(btcutil.NewTx(tx)), + ) + if realWeight != weight { + b.log.Warnf("actual weight of tx %v is %v, estimated as %d", + txHash, realWeight, weight) + } + + // Publish the transaction. + err := b.wallet.PublishTransaction( + ctx, tx, b.cfg.txLabeler(b.id), + ) + if err != nil { + return 0, fmt.Errorf("publishing tx failed: %w", err), true + } + + // Store the batch transaction's txid and pkScript, for monitoring + // purposes. + b.batchTxid = &txHash + b.batchPkScript = tx.TxOut[0].PkScript + + return fee, nil, true +} + func (b *batch) debugLogTx(msg string, tx *wire.MsgTx) { // Serialize the transaction and convert to hex string. buf := bytes.NewBuffer(make([]byte, 0, tx.SerializeSize())) @@ -1084,7 +1488,7 @@ func (b *batch) debugLogTx(msg string, tx *wire.MsgTx) { // coopSignBatchTx collects the necessary signatures from the server in order // to cooperatively sweep the funds. -func (b *batch) coopSignBatchTx(ctx context.Context, packet *psbt.Packet, +func (b *batch) coopSignBatchTx(ctx context.Context, tx *wire.MsgTx, sweeps []sweep, prevOuts map[wire.OutPoint]*wire.TxOut, psbt []byte) error { @@ -1092,13 +1496,13 @@ func (b *batch) coopSignBatchTx(ctx context.Context, packet *psbt.Packet, sweep := sweep finalSig, err := b.musig2sign( - ctx, i, sweep, packet.UnsignedTx, prevOuts, psbt, + ctx, i, sweep, tx, prevOuts, psbt, ) if err != nil { return err } - packet.UnsignedTx.TxIn[i].Witness = wire.TxWitness{ + tx.TxIn[i].Witness = wire.TxWitness{ finalSig, } } diff --git a/sweepbatcher/sweep_batcher.go b/sweepbatcher/sweep_batcher.go index fb6f26631..6165b509e 100644 --- a/sweepbatcher/sweep_batcher.go +++ b/sweepbatcher/sweep_batcher.go @@ -286,6 +286,16 @@ type Batcher struct { // Note that musig2SignSweep must be nil in this case, however signer // client must still be provided, as it is used for non-coop spendings. customMuSig2Signer SignMuSig2 + + // mixedBatch instructs sweepbatcher to create mixed batches with regard + // to cooperativeness. Such a batch can include sweeps signed both + // cooperatively and non-cooperatively. If cooperative signing fails for + // a sweep, transaction is updated to sign that sweep non-cooperatively + // and another round of cooperative signing runs on the remaining + // sweeps. The remaining sweeps are signed in non-cooperative (more + // expensive) way. If the whole procedure fails for whatever reason, the + // batch is signed non-cooperatively (the fallback). + mixedBatch bool } // BatcherConfig holds batcher configuration. @@ -321,6 +331,16 @@ type BatcherConfig struct { // Note that musig2SignSweep must be nil in this case, however signer // client must still be provided, as it is used for non-coop spendings. customMuSig2Signer SignMuSig2 + + // mixedBatch instructs sweepbatcher to create mixed batches with regard + // to cooperativeness. Such a batch can include sweeps signed both + // cooperatively and non-cooperatively. If cooperative signing fails for + // a sweep, transaction is updated to sign that sweep non-cooperatively + // and another round of cooperative signing runs on the remaining + // sweeps. The remaining sweeps are signed in non-cooperative (more + // expensive) way. If the whole procedure fails for whatever reason, the + // batch is signed non-cooperatively (the fallback). + mixedBatch bool } // BatcherOption configures batcher behaviour. @@ -386,6 +406,20 @@ func WithCustomSignMuSig2(customMuSig2Signer SignMuSig2) BatcherOption { } } +// WithMixedBatch instructs sweepbatcher to create mixed batches with +// regard to cooperativeness. Such a batch can include both sweeps signed +// both cooperatively and non-cooperatively. If cooperative signing fails +// for a sweep, transaction is updated to sign that sweep non-cooperatively +// and another round of cooperative signing runs on the remaining sweeps. +// The remaining sweeps are signed in non-cooperative (more expensive) way. +// If the whole procedure fails for whatever reason, the batch is signed +// non-cooperatively (the fallback). +func WithMixedBatch() BatcherOption { + return func(cfg *BatcherConfig) { + cfg.mixedBatch = true + } +} + // NewBatcher creates a new Batcher instance. func NewBatcher(wallet lndclient.WalletKitClient, chainNotifier lndclient.ChainNotifierClient, @@ -433,6 +467,7 @@ func NewBatcher(wallet lndclient.WalletKitClient, customFeeRate: cfg.customFeeRate, txLabeler: cfg.txLabeler, customMuSig2Signer: cfg.customMuSig2Signer, + mixedBatch: cfg.mixedBatch, } } @@ -1050,6 +1085,7 @@ func (b *Batcher) newBatchConfig(maxTimeoutDistance int32) batchConfig { txLabeler: b.txLabeler, customMuSig2Signer: b.customMuSig2Signer, clock: b.clock, + mixedBatch: b.mixedBatch, } } diff --git a/sweepbatcher/sweep_batcher_test.go b/sweepbatcher/sweep_batcher_test.go index ef001700f..e48100e20 100644 --- a/sweepbatcher/sweep_batcher_test.go +++ b/sweepbatcher/sweep_batcher_test.go @@ -2880,6 +2880,552 @@ func testCustomSignMuSig2(t *testing.T, store testStore, checkBatcherError(t, runErr) } +// testWithMixedBatch tests mixed batches construction. It also tests +// non-cooperative sweeping (using a preimage). Sweeps are added one by one. +func testWithMixedBatch(t *testing.T, store testStore, + batcherStore testBatcherStore) { + + defer test.Guard(t)() + + lnd := test.NewMockLnd() + ctx, cancel := context.WithCancel(context.Background()) + + // Extract payment address from the invoice. + swapPaymentAddr, err := utils.ObtainSwapPaymentAddr( + swapInvoice, lnd.ChainParams, + ) + require.NoError(t, err) + + // Use sweepFetcher to provide NonCoopHint for swapHash1. + sweepFetcher := &sweepFetcherMock{ + store: map[lntypes.Hash]*SweepInfo{}, + } + + // Create 3 sweeps: + // 1. known in advance to be non-cooperative, + // 2. fails cosigning during an attempt, + // 3. co-signs successfully. + + // Create 3 preimages, for 3 sweeps. + var preimages = []lntypes.Preimage{ + {1}, + {2}, + {3}, + } + + // Swap hashes must match the preimages, for non-cooperative spending + // path to work. + var swapHashes = []lntypes.Hash{ + preimages[0].Hash(), + preimages[1].Hash(), + preimages[2].Hash(), + } + + // Create muSig2SignSweep working only for 3rd swapHash. + muSig2SignSweep := func(ctx context.Context, + protocolVersion loopdb.ProtocolVersion, swapHash lntypes.Hash, + paymentAddr [32]byte, nonce []byte, sweepTxPsbt []byte, + prevoutMap map[wire.OutPoint]*wire.TxOut) ( + []byte, []byte, error) { + + if swapHash == swapHashes[2] { + return nil, nil, nil + } else { + return nil, nil, fmt.Errorf("test error") + } + } + + // Use mixed batches. + batcher := NewBatcher(lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + muSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, sweepFetcher, WithMixedBatch()) + + var wg sync.WaitGroup + wg.Add(1) + + var runErr error + go func() { + defer wg.Done() + runErr = batcher.Run(ctx) + }() + + // Wait for the batcher to be initialized. + <-batcher.initDone + + // Expected weights for transaction having 1, 2, and 3 sweeps. + wantWeights := []lntypes.WeightUnit{559, 952, 1182} + + // Two non-cooperative sweeps, one cooperative. + wantWitnessSizes := []int{4, 4, 1} + + // Create 3 swaps and 3 sweeps. + for i, swapHash := range swapHashes { + // Publish a block to trigger republishing. + err = lnd.NotifyHeight(601 + int32(i)) + require.NoError(t, err) + + // Put a swap into store to satisfy SQL constraints. + swap := &loopdb.LoopOutContract{ + SwapContract: loopdb.SwapContract{ + CltvExpiry: 111, + AmountRequested: 1_000_000, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HtlcKeys: htlcKeys, + + // Make preimage unique to pass SQL constraints. + Preimage: preimages[i], + }, + + DestAddr: destAddr, + SwapInvoice: swapInvoice, + SweepConfTarget: 111, + } + + require.NoError(t, store.CreateLoopOut(ctx, swapHash, swap)) + store.AssertLoopOutStored() + + // Add SweepInfo to sweepFetcher. + htlc, err := utils.GetHtlc( + swapHash, &swap.SwapContract, lnd.ChainParams, + ) + require.NoError(t, err) + + sweepInfo := &SweepInfo{ + Preimage: preimages[i], + ConfTarget: 123, + Timeout: 111, + SwapInvoicePaymentAddr: *swapPaymentAddr, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HTLCKeys: htlcKeys, + HTLC: *htlc, + HTLCSuccessEstimator: htlc.AddSuccessToEstimator, + DestAddr: destAddr, + } + // The first sweep is known in advance to be non-cooperative. + if i == 0 { + sweepInfo.NonCoopHint = true + } + sweepFetcher.store[swapHash] = sweepInfo + + // Create sweep request. + sweepReq := SweepRequest{ + SwapHash: swapHash, + Value: 1_000_000, + Outpoint: wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + }, + Notifier: &dummyNotifier, + } + require.NoError(t, batcher.AddSweep(&sweepReq)) + + if i == 0 { + // Since a batch was created we check that it registered + // for its primary sweep's spend. + <-lnd.RegisterSpendChannel + } + + // Expect mockSigner.SignOutputRaw call to sign non-cooperative + // sweeps. + <-lnd.SignOutputRawChannel + + // A transaction is published. + tx := <-lnd.TxPublishChannel + require.Equal(t, i+1, len(tx.TxIn)) + + // Check types of inputs. + var witnessSizes []int + for _, txIn := range tx.TxIn { + witnessSizes = append(witnessSizes, len(txIn.Witness)) + } + // The order of inputs is not deterministic, because they + // are stored in map. + require.ElementsMatch(t, wantWitnessSizes[:i+1], witnessSizes) + + // Calculate expected values. + feeRate := test.DefaultMockFee + for range i { + // Bump fee the number of blocks passed. + feeRate += defaultFeeRateStep + } + amt := btcutil.Amount(1_000_000 * (i + 1)) + weight := wantWeights[i] + expectedFee := feeRate.FeeForWeight(weight) + + // Check weight. + gotWeight := lntypes.WeightUnit( + blockchain.GetTransactionWeight(btcutil.NewTx(tx)), + ) + require.Equal(t, weight, gotWeight, "weights don't match") + + // Check fee. + out := btcutil.Amount(tx.TxOut[0].Value) + gotFee := amt - out + require.Equal(t, expectedFee, gotFee, "fees don't match") + + // Check fee rate. + gotFeeRate := chainfee.NewSatPerKWeight(gotFee, gotWeight) + require.Equal(t, feeRate, gotFeeRate, "fee rates don't match") + } + + // Make sure we have stored the batch. + batches, err := batcherStore.FetchUnconfirmedSweepBatches(ctx) + require.NoError(t, err) + require.Len(t, batches, 1) + + // Now make the batcher quit by canceling the context. + cancel() + wg.Wait() + + // Make sure the batcher exited without an error. + checkBatcherError(t, runErr) +} + +// testWithMixedBatchCustom tests mixed batches construction, custom scenario. +// All sweeps are added at once. +func testWithMixedBatchCustom(t *testing.T, store testStore, + batcherStore testBatcherStore, preimages []lntypes.Preimage, + muSig2SignSweep MuSig2SignSweep, nonCoopHints []bool, + expectSignOutputRawChannel bool, wantWeight lntypes.WeightUnit, + wantWitnessSizes []int) { + + defer test.Guard(t)() + + lnd := test.NewMockLnd() + ctx, cancel := context.WithCancel(context.Background()) + + // Extract payment address from the invoice. + swapPaymentAddr, err := utils.ObtainSwapPaymentAddr( + swapInvoice, lnd.ChainParams, + ) + require.NoError(t, err) + + // Use sweepFetcher to provide NonCoopHint for swapHash1. + sweepFetcher := &sweepFetcherMock{ + store: map[lntypes.Hash]*SweepInfo{}, + } + + // Swap hashes must match the preimages, for non-cooperative spending + // path to work. + swapHashes := make([]lntypes.Hash, len(preimages)) + for i, preimage := range preimages { + swapHashes[i] = preimage.Hash() + } + + // Use mixed batches. + batcher := NewBatcher(lnd.WalletKit, lnd.ChainNotifier, lnd.Signer, + muSig2SignSweep, testVerifySchnorrSig, lnd.ChainParams, + batcherStore, sweepFetcher, WithMixedBatch()) + + var wg sync.WaitGroup + wg.Add(1) + + var runErr error + go func() { + defer wg.Done() + runErr = batcher.Run(ctx) + }() + + // Wait for the batcher to be initialized. + <-batcher.initDone + + // Create swaps and sweeps. + for i, swapHash := range swapHashes { + // Put a swap into store to satisfy SQL constraints. + swap := &loopdb.LoopOutContract{ + SwapContract: loopdb.SwapContract{ + CltvExpiry: 111, + AmountRequested: 1_000_000, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HtlcKeys: htlcKeys, + + // Make preimage unique to pass SQL constraints. + Preimage: preimages[i], + }, + + DestAddr: destAddr, + SwapInvoice: swapInvoice, + SweepConfTarget: 111, + } + + require.NoError(t, store.CreateLoopOut(ctx, swapHash, swap)) + store.AssertLoopOutStored() + + // Add SweepInfo to sweepFetcher. + htlc, err := utils.GetHtlc( + swapHash, &swap.SwapContract, lnd.ChainParams, + ) + require.NoError(t, err) + + sweepFetcher.store[swapHash] = &SweepInfo{ + Preimage: preimages[i], + NonCoopHint: nonCoopHints[i], + + ConfTarget: 123, + Timeout: 111, + SwapInvoicePaymentAddr: *swapPaymentAddr, + ProtocolVersion: loopdb.ProtocolVersionMuSig2, + HTLCKeys: htlcKeys, + HTLC: *htlc, + HTLCSuccessEstimator: htlc.AddSuccessToEstimator, + DestAddr: destAddr, + } + + // Create sweep request. + sweepReq := SweepRequest{ + SwapHash: swapHash, + Value: 1_000_000, + Outpoint: wire.OutPoint{ + Hash: chainhash.Hash{1, 1}, + Index: 1, + }, + Notifier: &dummyNotifier, + } + require.NoError(t, batcher.AddSweep(&sweepReq)) + + if i == 0 { + // Since a batch was created we check that it registered + // for its primary sweep's spend. + <-lnd.RegisterSpendChannel + } + } + + if expectSignOutputRawChannel { + // Expect mockSigner.SignOutputRaw call to sign non-cooperative + // sweeps. + <-lnd.SignOutputRawChannel + } + + // A transaction is published. + tx := <-lnd.TxPublishChannel + require.Equal(t, len(preimages), len(tx.TxIn)) + + // Check types of inputs. + var witnessSizes []int + for _, txIn := range tx.TxIn { + witnessSizes = append(witnessSizes, len(txIn.Witness)) + } + // The order of inputs is not deterministic, because they + // are stored in map. + require.ElementsMatch(t, wantWitnessSizes, witnessSizes) + + // Calculate expected values. + feeRate := test.DefaultMockFee + amt := btcutil.Amount(1_000_000 * len(preimages)) + expectedFee := feeRate.FeeForWeight(wantWeight) + + // Check weight. + gotWeight := lntypes.WeightUnit( + blockchain.GetTransactionWeight(btcutil.NewTx(tx)), + ) + require.Equal(t, wantWeight, gotWeight, "weights don't match") + + // Check fee. + out := btcutil.Amount(tx.TxOut[0].Value) + gotFee := amt - out + require.Equal(t, expectedFee, gotFee, "fees don't match") + + // Check fee rate. + gotFeeRate := chainfee.NewSatPerKWeight(gotFee, gotWeight) + require.Equal(t, feeRate, gotFeeRate, "fee rates don't match") + + // Make sure we have stored the batch. + batches, err := batcherStore.FetchUnconfirmedSweepBatches(ctx) + require.NoError(t, err) + require.Len(t, batches, 1) + + // Now make the batcher quit by canceling the context. + cancel() + wg.Wait() + + // Make sure the batcher exited without an error. + checkBatcherError(t, runErr) +} + +// testWithMixedBatchLarge tests mixed batches construction, many sweeps. +// All sweeps are added at once. +func testWithMixedBatchLarge(t *testing.T, store testStore, + batcherStore testBatcherStore) { + + // Create 9 sweeps. 3 groups of 3 sweeps. + // 1. known in advance to be non-cooperative, + // 2. fails cosigning during an attempt, + // 3. co-signs successfully. + var preimages = []lntypes.Preimage{ + {1}, {2}, {3}, + {4}, {5}, {6}, + {7}, {8}, {9}, + } + + // Create muSig2SignSweep. It fails all the sweeps, works only one time + // for swapHashes[2] and works any number of times for 5 and 8. This + // emulates client disconnect after first successful co-signing. + swapHash2Used := false + muSig2SignSweep := func(ctx context.Context, + protocolVersion loopdb.ProtocolVersion, swapHash lntypes.Hash, + paymentAddr [32]byte, nonce []byte, sweepTxPsbt []byte, + prevoutMap map[wire.OutPoint]*wire.TxOut) ( + []byte, []byte, error) { + + switch { + case swapHash == preimages[2].Hash(): + if swapHash2Used { + return nil, nil, fmt.Errorf("disconnected") + } else { + swapHash2Used = true + + return nil, nil, nil + } + + case swapHash == preimages[5].Hash(): + return nil, nil, nil + + case swapHash == preimages[8].Hash(): + return nil, nil, nil + + default: + return nil, nil, fmt.Errorf("test error") + } + } + + // The first sweep in a group is known in advance to be + // non-cooperative. + nonCoopHints := []bool{ + true, false, false, + true, false, false, + true, false, false, + } + + // Expect mockSigner.SignOutputRaw call to sign non-cooperative + // sweeps. + expectSignOutputRawChannel := true + + // Two non-cooperative sweeps, one cooperative. + wantWitnessSizes := []int{4, 4, 4, 4, 4, 1, 4, 4, 1} + + // Expected weight. + wantWeight := lntypes.WeightUnit(3377) + + testWithMixedBatchCustom(t, store, batcherStore, preimages, + muSig2SignSweep, nonCoopHints, expectSignOutputRawChannel, + wantWeight, wantWitnessSizes) +} + +// testWithMixedBatchCoopOnly tests mixed batches construction, +// All sweeps are added at once. All the sweeps are cooperative. +func testWithMixedBatchCoopOnly(t *testing.T, store testStore, + batcherStore testBatcherStore) { + + // Create 3 sweeps, all cooperative. + var preimages = []lntypes.Preimage{ + {1}, {2}, {3}, + } + + // Create muSig2SignSweep, working for all sweeps. + muSig2SignSweep := func(ctx context.Context, + protocolVersion loopdb.ProtocolVersion, swapHash lntypes.Hash, + paymentAddr [32]byte, nonce []byte, sweepTxPsbt []byte, + prevoutMap map[wire.OutPoint]*wire.TxOut) ( + []byte, []byte, error) { + + return nil, nil, nil + } + + // All the sweeps are cooperative. + nonCoopHints := []bool{false, false, false} + + // Do not expect a mockSigner.SignOutputRaw call, because there are no + // non-cooperative sweeps. + expectSignOutputRawChannel := false + + // Two non-cooperative sweeps, one cooperative. + wantWitnessSizes := []int{1, 1, 1} + + // Expected weight. + wantWeight := lntypes.WeightUnit(856) + + testWithMixedBatchCustom(t, store, batcherStore, preimages, + muSig2SignSweep, nonCoopHints, expectSignOutputRawChannel, + wantWeight, wantWitnessSizes) +} + +// testWithMixedBatchNonCoopHintOnly tests mixed batches construction, +// All sweeps are added at once. All the sweeps are known to be non-cooperative +// in advance. +func testWithMixedBatchNonCoopHintOnly(t *testing.T, store testStore, + batcherStore testBatcherStore) { + + // Create 3 sweeps, all known to be non-cooperative in advance. + var preimages = []lntypes.Preimage{ + {1}, {2}, {3}, + } + + // Create muSig2SignSweep, panicking for all sweeps. + muSig2SignSweep := func(ctx context.Context, + protocolVersion loopdb.ProtocolVersion, swapHash lntypes.Hash, + paymentAddr [32]byte, nonce []byte, sweepTxPsbt []byte, + prevoutMap map[wire.OutPoint]*wire.TxOut) ( + []byte, []byte, error) { + + panic("must not be called in this test") + } + + // All the sweeps are non-cooperative, this is known in advance. + nonCoopHints := []bool{true, true, true} + + // Expect mockSigner.SignOutputRaw call to sign non-cooperative + // sweeps. + expectSignOutputRawChannel := true + + // Two non-cooperative sweeps, one cooperative. + wantWitnessSizes := []int{4, 4, 4} + + // Expected weight. + wantWeight := lntypes.WeightUnit(1345) + + testWithMixedBatchCustom(t, store, batcherStore, preimages, + muSig2SignSweep, nonCoopHints, expectSignOutputRawChannel, + wantWeight, wantWitnessSizes) +} + +// testWithMixedBatchCoopFailedOnly tests mixed batches construction, +// All sweeps are added at once. All the sweeps fail co-signing. +func testWithMixedBatchCoopFailedOnly(t *testing.T, store testStore, + batcherStore testBatcherStore) { + + // Create 3 sweeps, all fail co-signing. + var preimages = []lntypes.Preimage{ + {1}, {2}, {3}, + } + + // Create muSig2SignSweep, failing any co-sign attempt. + muSig2SignSweep := func(ctx context.Context, + protocolVersion loopdb.ProtocolVersion, swapHash lntypes.Hash, + paymentAddr [32]byte, nonce []byte, sweepTxPsbt []byte, + prevoutMap map[wire.OutPoint]*wire.TxOut) ( + []byte, []byte, error) { + + return nil, nil, fmt.Errorf("test error") + } + + // All the sweeps are non-cooperative, but this is not known in advance. + nonCoopHints := []bool{false, false, false} + + // Expect mockSigner.SignOutputRaw call to sign non-cooperative + // sweeps. + expectSignOutputRawChannel := true + + // Two non-cooperative sweeps, one cooperative. + wantWitnessSizes := []int{4, 4, 4} + + // Expected weight. + wantWeight := lntypes.WeightUnit(1345) + + testWithMixedBatchCustom(t, store, batcherStore, preimages, + muSig2SignSweep, nonCoopHints, expectSignOutputRawChannel, + wantWeight, wantWitnessSizes) +} + // TestSweepBatcherBatchCreation tests that sweep requests enter the expected // batch based on their timeout distance. func TestSweepBatcherBatchCreation(t *testing.T) { @@ -2980,6 +3526,37 @@ func TestCustomSignMuSig2(t *testing.T) { runTests(t, testCustomSignMuSig2) } +// TestWithMixedBatch tests mixed batches construction. It also tests +// non-cooperative sweeping (using a preimage). Sweeps are added one by one. +func TestWithMixedBatch(t *testing.T) { + runTests(t, testWithMixedBatch) +} + +// TestWithMixedBatchLarge tests mixed batches construction, many sweeps. +// All sweeps are added at once. +func TestWithMixedBatchLarge(t *testing.T) { + runTests(t, testWithMixedBatchLarge) +} + +// TestWithMixedBatchCoopOnly tests mixed batches construction, +// All sweeps are added at once. All the sweeps are cooperative. +func TestWithMixedBatchCoopOnly(t *testing.T) { + runTests(t, testWithMixedBatchCoopOnly) +} + +// TestWithMixedBatchNonCoopHintOnly tests mixed batches construction, +// All sweeps are added at once. All the sweeps are known to be non-cooperative +// in advance. +func TestWithMixedBatchNonCoopHintOnly(t *testing.T) { + runTests(t, testWithMixedBatchNonCoopHintOnly) +} + +// TestWithMixedBatchCoopFailedOnly tests mixed batches construction, +// All sweeps are added at once. All the sweeps fail co-signing. +func TestWithMixedBatchCoopFailedOnly(t *testing.T) { + runTests(t, testWithMixedBatchCoopFailedOnly) +} + // testBatcherStore is BatcherStore used in tests. type testBatcherStore interface { BatcherStore diff --git a/test/signer_mock.go b/test/signer_mock.go index e7f37624d..985e8577d 100644 --- a/test/signer_mock.go +++ b/test/signer_mock.go @@ -27,7 +27,12 @@ func (s *mockSigner) SignOutputRaw(ctx context.Context, tx *wire.MsgTx, SignDescriptors: signDescriptors, } - rawSigs := [][]byte{{1, 2, 3}} + rawSigs := make([][]byte, len(signDescriptors)) + for i := range signDescriptors { + sig := make([]byte, 64) + sig[0] = byte(i + 1) + rawSigs[i] = sig + } return rawSigs, nil } @@ -118,7 +123,10 @@ func (s *mockSigner) MuSig2Sign(context.Context, [32]byte, [32]byte, func (s *mockSigner) MuSig2CombineSig(context.Context, [32]byte, [][]byte) (bool, []byte, error) { - return true, nil, nil + sig := make([]byte, 64) + sig[0] = 42 + + return true, sig, nil } // MuSig2Cleanup removes a session from memory to free up resources.