diff --git a/stash.go b/stash.go index 14a1bab..7a9aee1 100644 --- a/stash.go +++ b/stash.go @@ -41,6 +41,7 @@ type options struct { skipVerify bool readTimeout time.Duration writeTimeout time.Duration + dialTimeout time.Duration tlsConfig *tls.Config } @@ -72,6 +73,13 @@ func SetWriteTimeout(writeTimeout time.Duration) Option { } } +// SetDialTimeout Option func +func SetDialTimeout(dialTimeout time.Duration) Option { + return func(o *options) { + o.dialTimeout = dialTimeout + } +} + // SetKeepAlive Option func func SetKeepAlive(keepAlive time.Duration) Option { return func(o *options) { @@ -94,9 +102,10 @@ func SetProtocolConn(protocol string) Option { } } -func (s *Stash) dial(address string, o *options) error { - conn, err := net.Dial(o.protocol, address) +func (s *Stash) dial(address string) error { + conn, err := s.o.dialer.Dial(s.o.protocol, address) if err != nil { + conn.Close() return err } @@ -104,12 +113,12 @@ func (s *Stash) dial(address string, o *options) error { // if useTLS true // Force stash to use TLS - if o.useTLS { + if s.o.useTLS { var tlsConfig *tls.Config - if o.tlsConfig == nil { - tlsConfig = &tls.Config{InsecureSkipVerify: o.skipVerify} + if s.o.tlsConfig == nil { + tlsConfig = &tls.Config{InsecureSkipVerify: s.o.skipVerify} } else { - tlsConfig = o.tlsConfig + tlsConfig = s.o.tlsConfig } if tlsConfig.ServerName == "" { @@ -143,6 +152,7 @@ func Connect(host string, port uint64, opts ...Option) (*Stash, error) { o := &options{ dialer: &net.Dialer{ KeepAlive: time.Minute * 5, + Timeout: 30 * time.Second, }, protocol: "tcp", writeTimeout: 30 * time.Second, @@ -153,7 +163,7 @@ func Connect(host string, port uint64, opts ...Option) (*Stash, error) { } s.o = o - if err := s.dial(address, o); err != nil { + if err := s.dial(address); err != nil { return nil, err } @@ -182,7 +192,7 @@ func (s *Stash) Write(data []byte) (int, error) { if strings.Contains(err.Error(), BrokenPipeError) { log.Printf("go-stash: %s | do re dial\n", err.Error()) // re dial ignore error - err = s.dial(s.address, s.o) + err = s.dial(s.address) if err != nil { log.Printf("go-stash: %s | do re dial\n", err.Error()) }