diff --git a/api/tbcapi/tbcapi.go b/api/tbcapi/tbcapi.go index 4111b91d..ede64578 100644 --- a/api/tbcapi/tbcapi.go +++ b/api/tbcapi/tbcapi.go @@ -11,11 +11,15 @@ import ( "reflect" "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" "github.com/hemilabs/heminetwork/api" "github.com/hemilabs/heminetwork/api/protocol" ) +// XXX we should kill the wrapping types that are basically identical to wire. +// Wire is a full citizen so treat it as such. + const ( APIVersion = 1 @@ -54,6 +58,12 @@ const ( CmdTxByIdRequest = "tbcapi-tx-by-id-request" CmdTxByIdResponse = "tbcapi-tx-by-id-response" + + CmdTxBroadcastRequest = "tbcapi-tx-broadcast-request" + CmdTxBroadcastResponse = "tbcapi-tx-broadcast-response" + + CmdTxBroadcastRawRequest = "tbcapi-tx-broadcast-raw-request" + CmdTxBroadcastRawResponse = "tbcapi-tx-broadcast-raw-response" ) var ( @@ -229,6 +239,26 @@ type TxByIdResponse struct { Error *protocol.Error `json:"error,omitempty"` } +type TxBroadcastRequest struct { + Tx *wire.MsgTx `json:"tx"` + Force bool `json:"force"` +} + +type TxBroadcastResponse struct { + TxID *chainhash.Hash `json:"tx_id"` + Error *protocol.Error `json:"error,omitempty"` +} + +type TxBroadcastRawRequest struct { + Tx api.ByteSlice `json:"tx"` + Force bool `json:"force"` +} + +type TxBroadcastRawResponse struct { + TxID *chainhash.Hash `json:"tx_id"` + Error *protocol.Error `json:"error,omitempty"` +} + var commands = map[protocol.Command]reflect.Type{ CmdPingRequest: reflect.TypeOf(PingRequest{}), CmdPingResponse: reflect.TypeOf(PingResponse{}), @@ -254,6 +284,10 @@ var commands = map[protocol.Command]reflect.Type{ CmdTxByIdRawResponse: reflect.TypeOf(TxByIdRawResponse{}), CmdTxByIdRequest: reflect.TypeOf(TxByIdRequest{}), CmdTxByIdResponse: reflect.TypeOf(TxByIdResponse{}), + CmdTxBroadcastRequest: reflect.TypeOf(TxBroadcastRequest{}), + CmdTxBroadcastResponse: reflect.TypeOf(TxBroadcastResponse{}), + CmdTxBroadcastRawRequest: reflect.TypeOf(TxBroadcastRawRequest{}), + CmdTxBroadcastRawResponse: reflect.TypeOf(TxBroadcastRawResponse{}), } type tbcAPI struct{} diff --git a/service/tbc/peer_manager.go b/service/tbc/peer_manager.go index 93414af7..4760de39 100644 --- a/service/tbc/peer_manager.go +++ b/service/tbc/peer_manager.go @@ -281,6 +281,29 @@ func (pm *PeerManager) All(ctx context.Context, f func(ctx context.Context, p *p } } +func (pm *PeerManager) AllBlock(ctx context.Context, f func(ctx context.Context, p *peer)) { + log.Tracef("AllBlock") + defer log.Tracef("AllBlock") + + var wgAll sync.WaitGroup + + pm.mtx.RLock() + for _, p := range pm.peers { + if !p.isConnected() { + continue + } + wgAll.Add(1) + go func() { + defer wgAll.Done() + f(ctx, p) + }() + } + pm.mtx.RUnlock() + + log.Infof("AllBlock waiting") + wgAll.Wait() +} + // RandomConnect blocks until there is a peer ready to use. func (pm *PeerManager) RandomConnect(ctx context.Context) (*peer, error) { log.Tracef("RandomConnect") diff --git a/service/tbc/rpc.go b/service/tbc/rpc.go index 30b36759..5f22a0c2 100644 --- a/service/tbc/rpc.go +++ b/service/tbc/rpc.go @@ -147,6 +147,20 @@ func (s *Server) handleWebsocketRead(ctx context.Context, ws *tbcWs) { return s.handleTxByIdRawRequest(ctx, req) } + go s.handleRequest(ctx, ws, id, cmd, handler) + case tbcapi.CmdTxBroadcastRequest: + handler := func(ctx context.Context) (any, error) { + req := payload.(*tbcapi.TxBroadcastRequest) + return s.handleTxBroadcastRequest(ctx, req) + } + + go s.handleRequest(ctx, ws, id, cmd, handler) + case tbcapi.CmdTxBroadcastRawRequest: + handler := func(ctx context.Context) (any, error) { + req := payload.(*tbcapi.TxBroadcastRawRequest) + return s.handleTxBroadcastRawRequest(ctx, req) + } + go s.handleRequest(ctx, ws, id, cmd, handler) default: err = fmt.Errorf("unknown command: %v", cmd) @@ -495,6 +509,45 @@ func (s *Server) handleTxByIdRequest(ctx context.Context, req *tbcapi.TxByIdRequ }, nil } +func (s *Server) handleTxBroadcastRequest(ctx context.Context, req *tbcapi.TxBroadcastRequest) (any, error) { + log.Tracef("handleTxBroadcastRequest") + defer log.Tracef("handleTxBroadcastRequest exit") + + txid, err := s.TxBroadcast(ctx, req.Tx, req.Force) + if err != nil { + if errors.Is(err, ErrTxAlreadyBroadcast) || errors.Is(err, ErrTxBroadcastNoPeers) { + return &tbcapi.TxBroadcastResponse{Error: protocol.RequestError(err)}, err + } + e := protocol.NewInternalError(err) + return &tbcapi.TxBroadcastResponse{Error: e.ProtocolError()}, e + } + + return &tbcapi.TxBroadcastResponse{TxID: txid}, nil +} + +func (s *Server) handleTxBroadcastRawRequest(ctx context.Context, req *tbcapi.TxBroadcastRawRequest) (any, error) { + log.Tracef("handleTxBroadcastRawRequest") + defer log.Tracef("handleTxBroadcastRawRequest exit") + + tx := wire.NewMsgTx(0) + err := tx.Deserialize(bytes.NewBuffer(req.Tx)) + if err != nil { + return &tbcapi.TxBroadcastResponse{ + Error: protocol.RequestError(err), + }, nil + } + txid, err := s.TxBroadcast(ctx, tx, req.Force) + if err != nil { + if errors.Is(err, ErrTxAlreadyBroadcast) || errors.Is(err, ErrTxBroadcastNoPeers) { + return &tbcapi.TxBroadcastResponse{Error: protocol.RequestError(err)}, err + } + e := protocol.NewInternalError(err) + return &tbcapi.TxBroadcastResponse{Error: e.ProtocolError()}, e + } + + return &tbcapi.TxBroadcastResponse{TxID: txid}, nil +} + func (s *Server) handleWebsocket(w http.ResponseWriter, r *http.Request) { log.Tracef("handleWebsocket: %v", r.RemoteAddr) defer log.Tracef("handleWebsocket exit: %v", r.RemoteAddr) diff --git a/service/tbc/tbc.go b/service/tbc/tbc.go index 807798e3..6397ea00 100644 --- a/service/tbc/tbc.go +++ b/service/tbc/tbc.go @@ -16,6 +16,7 @@ import ( "reflect" "strconv" "sync" + "sync/atomic" "time" "github.com/btcsuite/btcd/blockchain" @@ -63,6 +64,9 @@ var ( localnetSeeds = []string{ "127.0.0.1:18444", } + + ErrTxAlreadyBroadcast = errors.New("tx already broadcast") + ErrTxBroadcastNoPeers = errors.New("can't broadcast tx, no peers") ) var log = loggo.GetLogger("tbc") @@ -114,6 +118,9 @@ type Server struct { // mempool mempool *mempool + // broadcast + broadcast map[chainhash.Hash]*wire.MsgTx + // bitcoin network seeds []string // XXX remove wireNet wire.BitcoinNet @@ -173,6 +180,7 @@ func NewServer(cfg *Config) (*Server, error) { }), sessions: make(map[string]*tbcWs), requestTimeout: defaultRequestTimeout, + broadcast: make(map[chainhash.Hash]*wire.MsgTx, 16), } log.Infof("MEMPOOL IS CURRENTLY BROKEN AND HAS BEEN DISABLED") @@ -309,6 +317,14 @@ func (s *Server) handleGeneric(ctx context.Context, p *peer, msg wire.Message, r return false, fmt.Errorf("handle generic not found: %w", err) } + case *wire.MsgGetData: + if err := s.handleGetData(ctx, p, m, raw); err != nil { + return false, fmt.Errorf("handle generic get data: %w", err) + } + + case *wire.MsgMemPool: + log.Infof("mempool: %v", spew.Sdump(m)) + default: return false, nil } @@ -639,14 +655,20 @@ func (s *Server) handlePeer(ctx context.Context, p *peer) error { // Get p2p information. err = p.write(defaultCmdTimeout, wire.NewMsgGetAddr()) - if err != nil && !errors.Is(err, net.ErrClosed) { + if err != nil { + return err + } + + // Broadcast all tx's to new node. + err = s.TxBroadcastAllToPeer(ctx, p) + if err != nil { return err } if s.cfg.MempoolEnabled { // Start building the mempool. err = p.write(defaultCmdTimeout, wire.NewMsgMemPool()) - if err != nil && !errors.Is(err, net.ErrClosed) { + if err != nil { return err } } @@ -1205,9 +1227,19 @@ func (s *Server) handleBlock(ctx context.Context, p *peer, msg *wire.MsgBlock, r len(msg.Transactions), msg.Header.Timestamp) } + // Reap broadcast messages. + txHashes, _ := block.MsgBlock().TxHashes() + s.mtx.Lock() + for _, v := range txHashes { + if _, ok := s.broadcast[v]; ok { + delete(s.broadcast, v) + log.Infof("broadcast tx %v included in %v %v", v, bhs, height) + } + } + s.mtx.Unlock() + // Reap txs from mempool, no need to log error. if s.cfg.MempoolEnabled { - txHashes, _ := block.MsgBlock().TxHashes() _ = s.mempool.txsRemove(ctx, txHashes) } @@ -1301,6 +1333,43 @@ func (s *Server) handleNotFound(ctx context.Context, p *peer, msg *wire.MsgNotFo return nil } +func (s *Server) handleGetData(ctx context.Context, p *peer, msg *wire.MsgGetData, raw []byte) error { + log.Tracef("handleGetData %v", p) + defer log.Tracef("handleGetData %v exit", p) + + for _, v := range msg.InvList { + switch v.Type { + case wire.InvTypeError: + log.Errorf("get data error: %v", v.Hash) + case wire.InvTypeTx: + s.mtx.RLock() + if tx, ok := s.broadcast[v.Hash]; ok { + log.Debugf("handleGetData %v", spew.Sdump(msg)) + txc := tx.Copy() + err := p.write(defaultCmdTimeout, txc) + if err != nil { + log.Errorf("write tx: %v", err) + } + } + s.mtx.RUnlock() + case wire.InvTypeBlock: + log.Infof("get data block: %v", v.Hash) + case wire.InvTypeFilteredBlock: + log.Infof("get data filtered block: %v", v.Hash) + case wire.InvTypeWitnessBlock: + log.Infof("get data witness block: %v", v.Hash) + case wire.InvTypeWitnessTx: + log.Infof("get data witness tx: %v", v.Hash) + case wire.InvTypeFilteredWitnessBlock: + log.Infof("get data filtered witness block: %v", v.Hash) + default: + log.Errorf("get data unknown: %v", spew.Sdump(v.Hash)) + } + } + + return nil +} + func (s *Server) insertGenesis(ctx context.Context) error { log.Tracef("insertGenesis") defer log.Tracef("insertGenesis exit") @@ -1538,6 +1607,76 @@ func (s *Server) TxById(ctx context.Context, txId *chainhash.Hash) (*wire.MsgTx, return nil, database.ErrNotFound } +func (s *Server) TxBroadcastAllToPeer(ctx context.Context, p *peer) error { + log.Tracef("TxBroadcastAllToPeer %v", p) + defer log.Tracef("TxBroadcastAllToPeer %v exit", p) + + s.mtx.RLock() + if len(s.broadcast) == 0 { + s.mtx.RUnlock() + return nil + } + + invTx := wire.NewMsgInv() + for k := range s.broadcast { + err := invTx.AddInvVect(wire.NewInvVect(wire.InvTypeTx, &k)) + if err != nil { + s.mtx.RUnlock() + return fmt.Errorf("invalid vector: %w", err) + } + } + s.mtx.RUnlock() + + err := p.write(defaultCmdTimeout, invTx) + if err != nil { + return fmt.Errorf("broadcast all %v: %w", p, err) + } + + log.Debugf("broadcast all txs to peer %v: tx count %v", p, len(invTx.InvList)) + + return nil +} + +func (s *Server) TxBroadcast(ctx context.Context, tx *wire.MsgTx, force bool) (*chainhash.Hash, error) { + log.Tracef("TxBroadcast") + defer log.Tracef("TxBroadcast exit") + + s.mtx.Lock() + if _, ok := s.broadcast[tx.TxHash()]; ok && !force { + s.mtx.Unlock() + return nil, ErrTxAlreadyBroadcast + } + s.broadcast[tx.TxHash()] = tx + txb := tx.Copy() + s.mtx.Unlock() + + txHash := txb.TxHash() + invTx := wire.NewMsgInv() + err := invTx.AddInvVect(wire.NewInvVect(wire.InvTypeTx, &txHash)) + if err != nil { + return nil, fmt.Errorf("invalid vector: %w", err) + } + var success atomic.Uint64 + inv := func(ctx context.Context, p *peer) { + log.Tracef("inv %v", p) + defer log.Tracef("inv %v exit", p) + + err := p.write(defaultCmdTimeout, invTx) + if err != nil { + log.Debugf("inv %v: %v", p, err) + return + } + success.Add(1) + } + s.pm.AllBlock(ctx, inv) + + if success.Load() == 0 { + return nil, ErrTxBroadcastNoPeers + } + + return &txHash, nil +} + func feesFromTransactions(txs []*btcutil.Tx) error { for idx, tx := range txs { for _, txIn := range tx.MsgTx().TxIn {