diff --git a/transactions/notifications_test.go b/transactions/notifications_test.go new file mode 100644 index 00000000..25721bd4 --- /dev/null +++ b/transactions/notifications_test.go @@ -0,0 +1,179 @@ +package transactions + +import ( + "context" + "testing" + + "github.com/getAlby/hub/constants" + "github.com/getAlby/hub/db" + "github.com/getAlby/hub/events" + "github.com/getAlby/hub/tests" + "github.com/stretchr/testify/assert" +) + +func TestNotifications_ReceivedKnownPayment(t *testing.T) { + ctx := context.TODO() + + defer tests.RemoveTestService() + svc, err := tests.CreateTestService() + assert.NoError(t, err) + + mockPreimage := tests.MockLNClientTransaction.Preimage + svc.DB.Create(&db.Transaction{ + State: constants.TRANSACTION_STATE_PENDING, + Type: constants.TRANSACTION_TYPE_INCOMING, + PaymentRequest: tests.MockLNClientTransaction.Invoice, + PaymentHash: tests.MockLNClientTransaction.PaymentHash, + Preimage: &mockPreimage, + AmountMsat: 123000, + }) + + transactionsService := NewTransactionsService(svc.DB) + + transactionsService.ConsumeEvent(ctx, &events.Event{ + Event: "nwc_payment_received", + Properties: tests.MockLNClientTransaction, + }, map[string]interface{}{}) + + incomingTransaction, err := transactionsService.LookupTransaction(ctx, tests.MockLNClientTransaction.PaymentHash, nil, svc.LNClient, nil) + assert.NoError(t, err) + assert.Equal(t, uint64(123000), incomingTransaction.AmountMsat) + assert.Equal(t, constants.TRANSACTION_STATE_SETTLED, incomingTransaction.State) + assert.Equal(t, tests.MockLNClientTransaction.Preimage, *incomingTransaction.Preimage) + assert.Nil(t, incomingTransaction.FeeReserveMsat) + + transactions := []db.Transaction{} + result := svc.DB.Find(&transactions) + assert.Equal(t, int64(1), result.RowsAffected) +} + +func TestNotifications_ReceivedUnknownPayment(t *testing.T) { + ctx := context.TODO() + + defer tests.RemoveTestService() + svc, err := tests.CreateTestService() + assert.NoError(t, err) + + transactionsService := NewTransactionsService(svc.DB) + + transactionsService.ConsumeEvent(ctx, &events.Event{ + Event: "nwc_payment_received", + Properties: tests.MockLNClientTransaction, + }, map[string]interface{}{}) + + transactionType := constants.TRANSACTION_TYPE_INCOMING + incomingTransaction, err := transactionsService.LookupTransaction(ctx, tests.MockLNClientTransaction.PaymentHash, &transactionType, svc.LNClient, nil) + assert.NoError(t, err) + assert.Equal(t, uint64(tests.MockLNClientTransaction.Amount), incomingTransaction.AmountMsat) + assert.Equal(t, constants.TRANSACTION_STATE_SETTLED, incomingTransaction.State) + assert.Equal(t, tests.MockLNClientTransaction.Preimage, *incomingTransaction.Preimage) + assert.Nil(t, incomingTransaction.FeeReserveMsat) + + transactions := []db.Transaction{} + result := svc.DB.Find(&transactions) + assert.Equal(t, int64(1), result.RowsAffected) +} + +func TestNotifications_SentKnownPayment(t *testing.T) { + ctx := context.TODO() + + defer tests.RemoveTestService() + svc, err := tests.CreateTestService() + assert.NoError(t, err) + + feeReserve := uint64(10000) + svc.DB.Create(&db.Transaction{ + State: constants.TRANSACTION_STATE_PENDING, + Type: constants.TRANSACTION_TYPE_OUTGOING, + PaymentRequest: tests.MockLNClientTransaction.Invoice, + PaymentHash: tests.MockLNClientTransaction.PaymentHash, + AmountMsat: 123000, + FeeReserveMsat: &feeReserve, + }) + + transactionsService := NewTransactionsService(svc.DB) + + transactionsService.ConsumeEvent(ctx, &events.Event{ + Event: "nwc_payment_sent", + Properties: tests.MockLNClientTransaction, + }, map[string]interface{}{}) + + transactionType := constants.TRANSACTION_TYPE_OUTGOING + outgoingTransaction, err := transactionsService.LookupTransaction(ctx, tests.MockLNClientTransaction.PaymentHash, &transactionType, svc.LNClient, nil) + assert.NoError(t, err) + assert.Equal(t, uint64(123000), outgoingTransaction.AmountMsat) + assert.Equal(t, constants.TRANSACTION_STATE_SETTLED, outgoingTransaction.State) + assert.Equal(t, tests.MockLNClientTransaction.Preimage, *outgoingTransaction.Preimage) + assert.Zero(t, *outgoingTransaction.FeeReserveMsat) + + transactions := []db.Transaction{} + result := svc.DB.Find(&transactions) + assert.Equal(t, int64(1), result.RowsAffected) +} + +func TestNotifications_SentUnknownPayment(t *testing.T) { + ctx := context.TODO() + + defer tests.RemoveTestService() + svc, err := tests.CreateTestService() + assert.NoError(t, err) + + transactionsService := NewTransactionsService(svc.DB) + + transactions := []db.Transaction{} + result := svc.DB.Find(&transactions) + assert.Equal(t, int64(0), result.RowsAffected) + + transactionsService.ConsumeEvent(ctx, &events.Event{ + Event: "nwc_payment_sent", + Properties: tests.MockLNClientTransaction, + }, map[string]interface{}{}) + + transactionType := constants.TRANSACTION_TYPE_OUTGOING + outgoingTransaction, err := transactionsService.LookupTransaction(ctx, tests.MockLNClientTransaction.PaymentHash, &transactionType, svc.LNClient, nil) + assert.NoError(t, err) + assert.Equal(t, uint64(tests.MockLNClientTransaction.Amount), outgoingTransaction.AmountMsat) + assert.Equal(t, constants.TRANSACTION_STATE_SETTLED, outgoingTransaction.State) + assert.Equal(t, tests.MockLNClientTransaction.Preimage, *outgoingTransaction.Preimage) + assert.Zero(t, *outgoingTransaction.FeeReserveMsat) + + transactions = []db.Transaction{} + result = svc.DB.Find(&transactions) + assert.Equal(t, int64(1), result.RowsAffected) +} + +func TestNotifications_FailedKnownPayment(t *testing.T) { + ctx := context.TODO() + + defer tests.RemoveTestService() + svc, err := tests.CreateTestService() + assert.NoError(t, err) + + feeReserve := uint64(10000) + svc.DB.Create(&db.Transaction{ + State: constants.TRANSACTION_STATE_PENDING, + Type: constants.TRANSACTION_TYPE_OUTGOING, + PaymentRequest: tests.MockLNClientTransaction.Invoice, + PaymentHash: tests.MockLNClientTransaction.PaymentHash, + AmountMsat: 123000, + FeeReserveMsat: &feeReserve, + }) + + transactionsService := NewTransactionsService(svc.DB) + + transactionsService.ConsumeEvent(ctx, &events.Event{ + Event: "nwc_payment_failed_async", + Properties: tests.MockLNClientTransaction, + }, map[string]interface{}{}) + + transactionType := constants.TRANSACTION_TYPE_OUTGOING + outgoingTransaction, err := transactionsService.LookupTransaction(ctx, tests.MockLNClientTransaction.PaymentHash, &transactionType, svc.LNClient, nil) + assert.NoError(t, err) + assert.Equal(t, constants.TRANSACTION_STATE_FAILED, outgoingTransaction.State) + assert.Nil(t, outgoingTransaction.Preimage) + assert.Zero(t, *outgoingTransaction.FeeReserveMsat) + + transactions := []db.Transaction{} + result := svc.DB.Find(&transactions) + assert.Equal(t, int64(1), result.RowsAffected) +} diff --git a/transactions/todo_test.go b/transactions/todo_test.go index d267ee25..511fb72c 100644 --- a/transactions/todo_test.go +++ b/transactions/todo_test.go @@ -5,4 +5,5 @@ package transactions // TODO: list transactions // TODO: notifications -// TODO: fee reserve removed +// TODO: fee reserve removed - successful payment +// TODO: fee reserve removed - failed payment diff --git a/transactions/transactions_service.go b/transactions/transactions_service.go index 38f2923a..787e9ee5 100644 --- a/transactions/transactions_service.go +++ b/transactions/transactions_service.go @@ -486,7 +486,7 @@ func (svc *transactionsService) ConsumeEvent(ctx context.Context, event *events. }) if result.RowsAffected == 0 { - // Note: brand new payments (keysend only) cannot be associated with an app + // Note: brand new payments cannot be associated with an app var metadata string if lnClientTransaction.Metadata != nil { metadataBytes, err := json.Marshal(lnClientTransaction.Metadata) @@ -554,26 +554,60 @@ func (svc *transactionsService) ConsumeEvent(ctx context.Context, event *events. } var dbTransaction db.Transaction - result := svc.db.Find(&dbTransaction, &db.Transaction{ - Type: constants.TRANSACTION_TYPE_OUTGOING, - PaymentHash: lnClientTransaction.PaymentHash, - }) + err := svc.db.Transaction(func(tx *gorm.DB) error { + result := tx.Find(&dbTransaction, &db.Transaction{ + Type: constants.TRANSACTION_TYPE_OUTGOING, + PaymentHash: lnClientTransaction.PaymentHash, + }) - if result.RowsAffected == 0 { - logger.Logger.WithField("event", event).Error("Failed to find outgoing transaction by payment hash") - return - } + if result.RowsAffected == 0 { + // Note: brand new payments cannot be associated with an app + var metadata string + if lnClientTransaction.Metadata != nil { + metadataBytes, err := json.Marshal(lnClientTransaction.Metadata) + if err != nil { + logger.Logger.WithError(err).Error("Failed to serialize transaction metadata") + return err + } + metadata = string(metadataBytes) + } + var expiresAt *time.Time + if lnClientTransaction.ExpiresAt != nil { + expiresAtValue := time.Unix(*lnClientTransaction.ExpiresAt, 0) + expiresAt = &expiresAtValue + } + dbTransaction = db.Transaction{ + Type: constants.TRANSACTION_TYPE_OUTGOING, + AmountMsat: uint64(lnClientTransaction.Amount), + PaymentRequest: lnClientTransaction.Invoice, + PaymentHash: lnClientTransaction.PaymentHash, + Description: lnClientTransaction.Description, + DescriptionHash: lnClientTransaction.DescriptionHash, + ExpiresAt: expiresAt, + Metadata: metadata, + } + err := tx.Create(&dbTransaction).Error + if err != nil { + logger.Logger.WithFields(logrus.Fields{ + "payment_hash": lnClientTransaction.PaymentHash, + }).WithError(err).Error("Failed to create transaction") + return err + } + } + + settledAt := time.Now() + fee := uint64(lnClientTransaction.FeesPaid) + feeReserve := uint64(0) + err := tx.Model(&dbTransaction).Updates(&db.Transaction{ + FeeMsat: &fee, + FeeReserveMsat: &feeReserve, + Preimage: &lnClientTransaction.Preimage, + State: constants.TRANSACTION_STATE_SETTLED, + SettledAt: &settledAt, + }).Error + return err + }) - settledAt := time.Now() - fee := uint64(lnClientTransaction.FeesPaid) - feeReserve := uint64(0) - err := svc.db.Model(&dbTransaction).Updates(&db.Transaction{ - FeeMsat: &fee, - FeeReserveMsat: &feeReserve, - Preimage: &lnClientTransaction.Preimage, - State: constants.TRANSACTION_STATE_SETTLED, - SettledAt: &settledAt, - }).Error if err != nil { logger.Logger.WithFields(logrus.Fields{ "payment_hash": lnClientTransaction.PaymentHash, @@ -602,8 +636,10 @@ func (svc *transactionsService) ConsumeEvent(ctx context.Context, event *events. return } + feeReserve := uint64(0) err := svc.db.Model(&dbTransaction).Updates(&db.Transaction{ - State: constants.TRANSACTION_STATE_FAILED, + State: constants.TRANSACTION_STATE_FAILED, + FeeReserveMsat: &feeReserve, }).Error if err != nil { logger.Logger.WithFields(logrus.Fields{