Skip to content

Commit

Permalink
🐹 all: use github.com/database64128/netx-go to reduce linkname usage
Browse files Browse the repository at this point in the history
  • Loading branch information
database64128 committed Sep 4, 2024
1 parent e79d3e3 commit 296510d
Show file tree
Hide file tree
Showing 12 changed files with 151 additions and 169 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@ module github.com/database64128/tfo-go/v2
go 1.21.0

require golang.org/x/sys v0.24.1-0.20240828075529-ed67b1566aaf

require github.com/database64128/netx-go v0.0.0-20240904075656-1efc34d35e1a
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
github.com/database64128/netx-go v0.0.0-20240904075656-1efc34d35e1a h1:EFUgNuSxsGOE6zP49HlhhXx+w0SZs6ADuqVEMaKSFFU=
github.com/database64128/netx-go v0.0.0-20240904075656-1efc34d35e1a/go.mod h1:uMBPfZT3hyBlp6X8qIToro7wX+zymQTMe1bxfqUsbIs=
golang.org/x/sys v0.24.1-0.20240828075529-ed67b1566aaf h1:q2Cx0keWwW5HecyZeIyA3DCuupo8A/zjDqsOQK0+Z80=
golang.org/x/sys v0.24.1-0.20240828075529-ed67b1566aaf/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
15 changes: 6 additions & 9 deletions netpoll_windows_checklinkname0.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,11 @@ package tfo
import (
"net"
"sync"
"syscall"
"time"
_ "unsafe"
)

//go:linkname sockaddrToTCP net.sockaddrToTCP
func sockaddrToTCP(sa syscall.Sockaddr) net.Addr
"golang.org/x/sys/windows"
)

//go:linkname execIO internal/poll.execIO
func execIO(o *operation, submit func(o *operation) error) (int, error)
Expand All @@ -26,7 +24,7 @@ type pFD struct {
fdmuW uint32

// System file descriptor. Immutable until Close.
Sysfd syscall.Handle
Sysfd windows.Handle

// Read operation.
rop operation
Expand Down Expand Up @@ -65,10 +63,9 @@ type pFD struct {
kind byte
}

func (fd *pFD) ConnectEx(ra syscall.Sockaddr, b []byte) (n int, err error) {
fd.wop.sa = ra
func (fd *pFD) ConnectEx(ra windows.Sockaddr, b []byte) (n int, err error) {
n, err = execIO(&fd.wop, func(o *operation) error {
return syscall.ConnectEx(o.fd.Sysfd, o.sa, &b[0], uint32(len(b)), &o.qty, &o.o)
return windows.ConnectEx(o.fd.Sysfd, ra, &b[0], uint32(len(b)), &o.qty, &o.o)
})
return
}
Expand All @@ -89,7 +86,7 @@ type netFD struct {
}

//go:linkname newFD net.newFD
func newFD(sysfd syscall.Handle, family, sotype int, net string) (*netFD, error)
func newFD(sysfd windows.Handle, family, sotype int, net string) (*netFD, error)

//go:linkname netFDInit net.(*netFD).init
func netFDInit(fd *netFD) error
Expand Down
14 changes: 6 additions & 8 deletions netpoll_windows_go121.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
package tfo

import (
"syscall"

"golang.org/x/sys/windows"
)

Expand All @@ -14,7 +12,7 @@ import (
type operation struct {
// Used by IOCP interface, it must be first field
// of the struct, as our code rely on it.
o syscall.Overlapped
o windows.Overlapped

// fields used by runtime.netpoll
runtimeCtx uintptr
Expand All @@ -24,12 +22,12 @@ type operation struct {

// fields used only by net package
fd *pFD
buf syscall.WSABuf
buf windows.WSABuf
msg windows.WSAMsg
sa syscall.Sockaddr
rsa *syscall.RawSockaddrAny
sa windows.Sockaddr
rsa *windows.RawSockaddrAny
rsan int32
handle syscall.Handle
handle windows.Handle
flags uint32
bufs []syscall.WSABuf
bufs []windows.WSABuf
}
14 changes: 6 additions & 8 deletions netpoll_windows_go123_checklinkname0.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
package tfo

import (
"syscall"

"golang.org/x/sys/windows"
)

Expand All @@ -14,21 +12,21 @@ import (
type operation struct {
// Used by IOCP interface, it must be first field
// of the struct, as our code rely on it.
o syscall.Overlapped
o windows.Overlapped

// fields used by runtime.netpoll
runtimeCtx uintptr
mode int32

// fields used only by net package
fd *pFD
buf syscall.WSABuf
buf windows.WSABuf
msg windows.WSAMsg
sa syscall.Sockaddr
rsa *syscall.RawSockaddrAny
sa windows.Sockaddr
rsa *windows.RawSockaddrAny
rsan int32
handle syscall.Handle
handle windows.Handle
flags uint32
qty uint32
bufs []syscall.WSABuf
bufs []windows.WSABuf
}
11 changes: 0 additions & 11 deletions tfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ import (
"context"
"errors"
"net"
"os"
"sync/atomic"
"syscall"
"time"
)

Expand Down Expand Up @@ -228,15 +226,6 @@ func opAddr(a *net.TCPAddr) net.Addr {
return a
}

// wrapSyscallError takes an error and a syscall name. If the error is
// a syscall.Errno, it wraps it in a os.SyscallError using the syscall name.
func wrapSyscallError(name string, err error) error {
if _, ok := err.(syscall.Errno); ok {
err = os.NewSyscallError(name, err)
}
return err
}

// aLongTimeAgo is a non-zero time, far in the past, used for immediate deadlines.
var aLongTimeAgo = time.Unix(0, 0)

Expand Down
71 changes: 40 additions & 31 deletions tfo_bsd+linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"syscall"

"github.com/database64128/netx-go"
"golang.org/x/sys/unix"
)

Expand All @@ -34,19 +35,7 @@ func ctrlNetwork(network string, family int) string {
}

func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *net.TCPAddr, b []byte, ctrlCtxFn func(context.Context, string, string, syscall.RawConn) error) (*net.TCPConn, error) {
ltsa := (*tcpSockaddr)(laddr)
rtsa := (*tcpSockaddr)(raddr)
family, ipv6only := favoriteAddrFamily(network, ltsa, rtsa, "dial")

lsa, err := ltsa.sockaddr(family)
if err != nil {
return nil, err
}

rsa, err := rtsa.sockaddr(family)
if err != nil {
return nil, err
}
family, ipv6only := favoriteDialAddrFamily(network, laddr, raddr)

fd, err := d.socket(family)
if err != nil {
Expand All @@ -55,18 +44,18 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n

if err = d.setIPv6Only(fd, family, ipv6only); err != nil {
unix.Close(fd)
return nil, wrapSyscallError("setsockopt(IPV6_V6ONLY)", err)
return nil, os.NewSyscallError("setsockopt(IPV6_V6ONLY)", err)
}

if err = setNoDelay(fd, 1); err != nil {
unix.Close(fd)
return nil, wrapSyscallError("setsockopt(TCP_NODELAY)", err)
return nil, os.NewSyscallError("setsockopt(TCP_NODELAY)", err)
}

if err = setTFODialerFromSocket(uintptr(fd)); err != nil {
if !d.Fallback || !errors.Is(err, errors.ErrUnsupported) {
unix.Close(fd)
return nil, wrapSyscallError("setsockopt("+setTFODialerFromSocketSockoptName+")", err)
return nil, os.NewSyscallError("setsockopt("+setTFODialerFromSocketSockoptName+")", err)
}
runtimeDialTFOSupport.storeNone()
}
Expand All @@ -86,8 +75,13 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n
}

if laddr != nil {
lsa, err := unixSockaddrFromTCPAddr(laddr, family)
if err != nil {
return nil, err
}

if cErr := rawConn.Control(func(fd uintptr) {
err = syscall.Bind(int(fd), lsa)
err = unix.Bind(int(fd), lsa)
}); cErr != nil {
return nil, cErr
}
Expand All @@ -96,7 +90,7 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n
}
}

rusa, err := unixSockaddrFromSyscallSockaddr(rsa)
rsa, err := unixSockaddrFromTCPAddr(raddr, family)
if err != nil {
return nil, err
}
Expand All @@ -107,7 +101,7 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n
)

if err = connWriteFunc(ctx, f, func(f *os.File) (err error) {
n, canFallback, err = connect(rawConn, rusa, b)
n, canFallback, err = connect(rawConn, rsa, b)
return err
}); err != nil {
if d.Fallback && canFallback {
Expand All @@ -132,24 +126,39 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n
return c.(*net.TCPConn), err
}

func unixSockaddrFromSyscallSockaddr(sa syscall.Sockaddr) (unix.Sockaddr, error) {
if sa == nil {
func unixSockaddrFromTCPAddr(a *net.TCPAddr, family int) (unix.Sockaddr, error) {
if a == nil {
return nil, nil
}
switch sa := sa.(type) {
case *syscall.SockaddrInet4:
ip := a.IP
switch family {
case unix.AF_INET:
if len(ip) == 0 {
ip = net.IPv4zero
}
ip4 := ip.To4()
if ip4 == nil {
return nil, &net.AddrError{Err: "non-IPv4 address", Addr: ip.String()}
}
return &unix.SockaddrInet4{
Port: sa.Port,
Addr: sa.Addr,
Port: a.Port,
Addr: [4]byte(ip4),
}, nil
case *syscall.SockaddrInet6:
case unix.AF_INET6:
if len(ip) == 0 || ip.Equal(net.IPv4zero) {
ip = net.IPv6zero
}
ip6 := ip.To16()
if ip6 == nil {
return nil, &net.AddrError{Err: "non-IPv6 address", Addr: ip.String()}
}
return &unix.SockaddrInet6{
Port: sa.Port,
ZoneId: sa.ZoneId,
Addr: sa.Addr,
Port: a.Port,
ZoneId: uint32(netx.ZoneCache.Index(a.Zone)),
Addr: [16]byte(ip6),
}, nil
}
return nil, errors.New("unsupported sockaddr type")
return nil, &net.AddrError{Err: "invalid address family", Addr: ip.String()}
}

func connect(rawConn syscall.RawConn, rsa unix.Sockaddr, b []byte) (n int, canFallback bool, err error) {
Expand Down Expand Up @@ -187,7 +196,7 @@ func connect(rawConn syscall.RawConn, rsa unix.Sockaddr, b []byte) (n int, canFa
func getSocketError(fd int, call string) error {
nerr, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_ERROR)
if err != nil {
return wrapSyscallError("getsockopt", err)
return os.NewSyscallError("getsockopt", err)
}
if nerr != 0 {
return os.NewSyscallError(call, syscall.Errno(nerr))
Expand Down
Loading

0 comments on commit 296510d

Please sign in to comment.