diff --git a/stash.go b/stash.go index 59de2b5..7a9aee1 100644 --- a/stash.go +++ b/stash.go @@ -15,10 +15,8 @@ const ( BrokenPipeError = "broken pipe" ) -var ( - // CRLF (Carriage Return and Line Feed in ASCII code) - CRLF = []byte{13, 10} -) +// CRLF (Carriage Return and Line Feed in ASCII code) +var CRLF = []byte{13, 10} func addCRLF(data []byte) []byte { return append(data, CRLF...) @@ -38,10 +36,12 @@ type Option func(*options) type options struct { dialer *net.Dialer + protocol string useTLS bool skipVerify bool readTimeout time.Duration writeTimeout time.Duration + dialTimeout time.Duration tlsConfig *tls.Config } @@ -73,6 +73,13 @@ func SetWriteTimeout(writeTimeout time.Duration) Option { } } +// SetDialTimeout Option func +func SetDialTimeout(dialTimeout time.Duration) Option { + return func(o *options) { + o.dialTimeout = dialTimeout + } +} + // SetKeepAlive Option func func SetKeepAlive(keepAlive time.Duration) Option { return func(o *options) { @@ -87,14 +94,18 @@ func SetTLSConfig(config *tls.Config) Option { } } -func (s *Stash) dial(address string, o *options) error { - addr, err := net.ResolveTCPAddr("tcp", address) - if err != nil { - return err +// SetProtocolConn Option func +// set protocol connection between logstash : `tcp` or `udp` +func SetProtocolConn(protocol string) Option { + return func(o *options) { + o.protocol = protocol } +} - conn, err := net.DialTCP("tcp", nil, addr) +func (s *Stash) dial(address string) error { + conn, err := s.o.dialer.Dial(s.o.protocol, address) if err != nil { + conn.Close() return err } @@ -102,12 +113,12 @@ func (s *Stash) dial(address string, o *options) error { // if useTLS true // Force stash to use TLS - if o.useTLS { + if s.o.useTLS { var tlsConfig *tls.Config - if o.tlsConfig == nil { - tlsConfig = &tls.Config{InsecureSkipVerify: o.skipVerify} + if s.o.tlsConfig == nil { + tlsConfig = &tls.Config{InsecureSkipVerify: s.o.skipVerify} } else { - tlsConfig = o.tlsConfig + tlsConfig = s.o.tlsConfig } if tlsConfig.ServerName == "" { @@ -141,7 +152,9 @@ func Connect(host string, port uint64, opts ...Option) (*Stash, error) { o := &options{ dialer: &net.Dialer{ KeepAlive: time.Minute * 5, + Timeout: 30 * time.Second, }, + protocol: "tcp", writeTimeout: 30 * time.Second, readTimeout: 30 * time.Second, } @@ -150,7 +163,7 @@ func Connect(host string, port uint64, opts ...Option) (*Stash, error) { } s.o = o - if err := s.dial(address, o); err != nil { + if err := s.dial(address); err != nil { return nil, err } @@ -179,7 +192,7 @@ func (s *Stash) Write(data []byte) (int, error) { if strings.Contains(err.Error(), BrokenPipeError) { log.Printf("go-stash: %s | do re dial\n", err.Error()) // re dial ignore error - err = s.dial(s.address, s.o) + err = s.dial(s.address) if err != nil { log.Printf("go-stash: %s | do re dial\n", err.Error()) }