diff --git a/core/host/host.go b/core/host/host.go index e62be281f..f6d80d001 100644 --- a/core/host/host.go +++ b/core/host/host.go @@ -5,6 +5,7 @@ package host import ( "context" + "net" "github.com/libp2p/go-libp2p/core/connmgr" "github.com/libp2p/go-libp2p/core/event" @@ -72,4 +73,7 @@ type Host interface { // EventBus returns the hosts eventbus EventBus() event.Bus + + // MapPort is a utility function that attempts to set up a port mapping + MapPort(protocol string, internalPort int) (net.Addr, int, error) } diff --git a/p2p/host/basic/basic_host.go b/p2p/host/basic/basic_host.go index 70c40bf18..712ba916f 100644 --- a/p2p/host/basic/basic_host.go +++ b/p2p/host/basic/basic_host.go @@ -1125,6 +1125,21 @@ func (h *BasicHost) Close() error { return nil } +func (h *BasicHost) MapPort(protocol string, internalPort int) (net.Addr, int, error) { + for h.natmgr == nil || h.natmgr.NAT() == nil { + time.Sleep(time.Millisecond * 100) + } + mapping, err := h.natmgr.NAT().NewMapping(protocol, internalPort) + if err != nil { + return nil, 0, err + } + addr, err := mapping.ExternalAddr() + if err != nil { + return nil, 0, err + } + return addr, mapping.ExternalPort(), nil +} + type streamWrapper struct { network.Stream rw io.ReadWriteCloser diff --git a/p2p/host/basic/natmgr.go b/p2p/host/basic/natmgr.go index 782c116d4..0d3e8dfc8 100644 --- a/p2p/host/basic/natmgr.go +++ b/p2p/host/basic/natmgr.go @@ -14,6 +14,12 @@ import ( ma "github.com/multiformats/go-multiaddr" ) +// discoveryNATPeriod is the period at which we try to discover NATs. +var discoveryNATPeriod = 3 * time.Second + +// discoveryTry is the number of times we try to discover NATs. +const discoveryTry = 5 + // NATManager is a simple interface to manage NAT devices. type NATManager interface { // NAT gets the NAT device managed by the NAT manager. @@ -88,11 +94,20 @@ func (nmgr *natManager) background(ctx context.Context) { discoverCtx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() + tryCount := 0 natInstance, err := inat.DiscoverNAT(discoverCtx) - if err != nil { + tryCount++ + for err != nil { log.Info("DiscoverNAT error:", err) - close(nmgr.ready) - return + if tryCount > discoveryTry { + log.Info("DiscoverNAT failed after ", tryCount, " tries") + return + } + time.Sleep(discoveryNATPeriod) + discoverCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + natInstance, err = inat.DiscoverNAT(discoverCtx) + tryCount++ } nmgr.natMx.Lock() diff --git a/p2p/host/blank/blank.go b/p2p/host/blank/blank.go index 9f3daeff2..9b23dabef 100644 --- a/p2p/host/blank/blank.go +++ b/p2p/host/blank/blank.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "net" "github.com/libp2p/go-libp2p/core/connmgr" "github.com/libp2p/go-libp2p/core/event" @@ -227,3 +228,7 @@ func (bh *BlankHost) ConnManager() connmgr.ConnManager { func (bh *BlankHost) EventBus() event.Bus { return bh.eventbus } + +func (bh *BlankHost) MapPort(protocol string, internalPort int) (net.Addr, int, error) { + return nil, 0, nil +} diff --git a/p2p/host/routed/routed.go b/p2p/host/routed/routed.go index eb8e58ee7..80c708430 100644 --- a/p2p/host/routed/routed.go +++ b/p2p/host/routed/routed.go @@ -3,6 +3,7 @@ package routedhost import ( "context" "fmt" + "net" "time" "github.com/libp2p/go-libp2p/core/connmgr" @@ -219,4 +220,8 @@ func (rh *RoutedHost) ConnManager() connmgr.ConnManager { return rh.host.ConnManager() } +func (rh *RoutedHost) MapPort(protocol string, internalPort int) (net.Addr, int, error) { + return rh.host.MapPort(protocol, internalPort) +} + var _ (host.Host) = (*RoutedHost)(nil)