diff --git a/go.mod b/go.mod index 6f807eb..dd79044 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,4 @@ module github.com/database64128/tfo-go/v2 go 1.21.0 -require golang.org/x/sys v0.24.0 +require golang.org/x/sys v0.24.1-0.20240828075529-ed67b1566aaf diff --git a/go.sum b/go.sum index d88e7bd..09b5eab 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,2 @@ -golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= -golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.24.1-0.20240828075529-ed67b1566aaf h1:q2Cx0keWwW5HecyZeIyA3DCuupo8A/zjDqsOQK0+Z80= +golang.org/x/sys v0.24.1-0.20240828075529-ed67b1566aaf/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/syscall_darwin.go b/syscall_darwin.go deleted file mode 100644 index b29e931..0000000 --- a/syscall_darwin.go +++ /dev/null @@ -1,157 +0,0 @@ -package tfo - -import ( - "syscall" - "unsafe" - - "golang.org/x/sys/unix" -) - -// Do the interface allocations only once for common -// Errno values. -var ( - errEAGAIN error = syscall.EAGAIN - errEINVAL error = syscall.EINVAL - errENOENT error = syscall.ENOENT -) - -// errnoErr returns common boxed Errno values, to prevent -// allocations at runtime. -func errnoErr(e syscall.Errno) error { - switch e { - case 0: - return nil - case unix.EAGAIN: - return errEAGAIN - case unix.EINVAL: - return errEINVAL - case unix.ENOENT: - return errENOENT - } - return e -} - -func sockaddrp(sa syscall.Sockaddr) (unsafe.Pointer, uint32, error) { - switch sa := sa.(type) { - case nil: - return nil, 0, nil - case *syscall.SockaddrInet4: - return (*sockaddrInet4)(unsafe.Pointer(sa)).sockaddr() - case *syscall.SockaddrInet6: - return (*sockaddrInet6)(unsafe.Pointer(sa)).sockaddr() - default: - return nil, 0, syscall.EAFNOSUPPORT - } -} - -// Copied from src/syscall/syscall_unix.go -type sockaddrInet4 struct { - Port int - Addr [4]byte - raw syscall.RawSockaddrInet4 -} - -// Copied from src/syscall/syscall_unix.go -type sockaddrInet6 struct { - Port int - ZoneId uint32 - Addr [16]byte - raw syscall.RawSockaddrInet6 -} - -func (sa *sockaddrInet4) sockaddr() (unsafe.Pointer, uint32, error) { - if sa.Port < 0 || sa.Port > 0xFFFF { - return nil, 0, syscall.EINVAL - } - sa.raw.Len = syscall.SizeofSockaddrInet4 - sa.raw.Family = syscall.AF_INET - p := (*[2]byte)(unsafe.Pointer(&sa.raw.Port)) - p[0] = byte(sa.Port >> 8) - p[1] = byte(sa.Port) - sa.raw.Addr = sa.Addr - return unsafe.Pointer(&sa.raw), uint32(sa.raw.Len), nil -} - -func (sa *sockaddrInet6) sockaddr() (unsafe.Pointer, uint32, error) { - if sa.Port < 0 || sa.Port > 0xFFFF { - return nil, 0, syscall.EINVAL - } - sa.raw.Len = syscall.SizeofSockaddrInet6 - sa.raw.Family = syscall.AF_INET6 - p := (*[2]byte)(unsafe.Pointer(&sa.raw.Port)) - p[0] = byte(sa.Port >> 8) - p[1] = byte(sa.Port) - sa.raw.Scope_id = sa.ZoneId - sa.raw.Addr = sa.Addr - return unsafe.Pointer(&sa.raw), uint32(sa.raw.Len), nil -} - -type sa_endpoints_t struct { - sae_srcif uint - sae_srcaddr unsafe.Pointer - sae_srcaddrlen uint32 - sae_dstaddr unsafe.Pointer - sae_dstaddrlen uint32 -} - -const ( - SAE_ASSOCID_ANY = 0 - CONNECT_RESUME_ON_READ_WRITE = 0x1 - CONNECT_DATA_IDEMPOTENT = 0x2 - CONNECT_DATA_AUTHENTICATED = 0x4 -) - -// Connectx enables TFO if a non-empty buf is passed. -// If an empty buf is passed, TFO is not enabled. -func Connectx(s int, srcif uint, from syscall.Sockaddr, to syscall.Sockaddr, buf []byte) (uint, error) { - from_ptr, from_n, err := sockaddrp(from) - if err != nil { - return 0, err - } - - to_ptr, to_n, err := sockaddrp(to) - if err != nil { - return 0, err - } - - sae := sa_endpoints_t{ - sae_srcif: srcif, - sae_srcaddr: from_ptr, - sae_srcaddrlen: from_n, - sae_dstaddr: to_ptr, - sae_dstaddrlen: to_n, - } - - var ( - flags uint - iov *unix.Iovec - iovcnt uint - ) - - if len(buf) > 0 { - flags = CONNECT_DATA_IDEMPOTENT - iov = &unix.Iovec{ - Base: &buf[0], - Len: uint64(len(buf)), - } - iovcnt = 1 - } - - var bytesSent uint - - r1, _, e1 := unix.Syscall9(unix.SYS_CONNECTX, - uintptr(s), - uintptr(unsafe.Pointer(&sae)), - SAE_ASSOCID_ANY, - uintptr(flags), - uintptr(unsafe.Pointer(iov)), - uintptr(iovcnt), - uintptr(unsafe.Pointer(&bytesSent)), - 0, - 0) - ret := int(r1) - if ret == -1 { - err = errnoErr(e1) - } - return bytesSent, err -} diff --git a/tfo_bsd+linux.go b/tfo_bsd+linux.go index f395971..11a2241 100644 --- a/tfo_bsd+linux.go +++ b/tfo_bsd+linux.go @@ -96,13 +96,18 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n } } + rusa, err := unixSockaddrFromSyscallSockaddr(rsa) + if err != nil { + return nil, err + } + var ( n int canFallback bool ) if err = connWriteFunc(ctx, f, func(f *os.File) (err error) { - n, canFallback, err = connect(rawConn, rsa, b) + n, canFallback, err = connect(rawConn, rusa, b) return err }); err != nil { if d.Fallback && canFallback { @@ -127,7 +132,27 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n return c.(*net.TCPConn), err } -func connect(rawConn syscall.RawConn, rsa syscall.Sockaddr, b []byte) (n int, canFallback bool, err error) { +func unixSockaddrFromSyscallSockaddr(sa syscall.Sockaddr) (unix.Sockaddr, error) { + if sa == nil { + return nil, nil + } + switch sa := sa.(type) { + case *syscall.SockaddrInet4: + return &unix.SockaddrInet4{ + Port: sa.Port, + Addr: sa.Addr, + }, nil + case *syscall.SockaddrInet6: + return &unix.SockaddrInet6{ + Port: sa.Port, + ZoneId: sa.ZoneId, + Addr: sa.Addr, + }, nil + } + return nil, errors.New("unsupported sockaddr type") +} + +func connect(rawConn syscall.RawConn, rsa unix.Sockaddr, b []byte) (n int, canFallback bool, err error) { var done bool if perr := rawConn.Write(func(fd uintptr) bool { diff --git a/tfo_darwin.go b/tfo_darwin.go index c2f579e..d31de8c 100644 --- a/tfo_darwin.go +++ b/tfo_darwin.go @@ -99,7 +99,20 @@ const setTFODialerFromSocketSockoptName = "TCP_FASTOPEN_FORCE_ENABLE" const connectSyscallName = "connectx" -func doConnect(fd uintptr, rsa syscall.Sockaddr, b []byte) (int, error) { - n, err := Connectx(int(fd), 0, nil, rsa, b) +func doConnect(fd uintptr, rsa unix.Sockaddr, b []byte) (int, error) { + var ( + flags uint32 + iov []unix.Iovec + ) + if len(b) > 0 { + flags = unix.CONNECT_DATA_IDEMPOTENT + iov = []unix.Iovec{ + { + Base: &b[0], + Len: uint64(len(b)), + }, + } + } + n, err := unix.Connectx(int(fd), 0, nil, rsa, unix.SAE_ASSOCID_ANY, flags, iov, nil) return int(n), err } diff --git a/tfo_freebsd+linux.go b/tfo_freebsd+linux.go index 21dab1f..da0185a 100644 --- a/tfo_freebsd+linux.go +++ b/tfo_freebsd+linux.go @@ -3,8 +3,6 @@ package tfo import ( - "syscall" - "golang.org/x/sys/unix" ) @@ -18,6 +16,6 @@ func (*Dialer) setIPv6Only(fd int, family int, ipv6only bool) error { const connectSyscallName = "sendmsg" -func doConnect(fd uintptr, rsa syscall.Sockaddr, b []byte) (int, error) { - return syscall.SendmsgN(int(fd), b, nil, rsa, sendtoImplicitConnectFlag|unix.MSG_NOSIGNAL) +func doConnect(fd uintptr, rsa unix.Sockaddr, b []byte) (int, error) { + return unix.SendmsgN(int(fd), b, nil, rsa, sendtoImplicitConnectFlag|unix.MSG_NOSIGNAL) }