Skip to content

Commit

Permalink
🪇 windows: use linkname on netFD
Browse files Browse the repository at this point in the history
- We no longer need to call some of the low-level netpoll functions.
- We can now attach a finalizer like std does.
  • Loading branch information
database64128 committed Feb 1, 2024
1 parent 0d9a898 commit 15055b2
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 46 deletions.
68 changes: 35 additions & 33 deletions netpoll_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,13 @@ import (
"net"
"sync"
"syscall"
"time"
_ "unsafe"

"golang.org/x/sys/windows"
)

//go:linkname sockaddrToTCP net.sockaddrToTCP
func sockaddrToTCP(sa syscall.Sockaddr) net.Addr

//go:linkname runtime_pollServerInit internal/poll.runtime_pollServerInit
func runtime_pollServerInit()

//go:linkname runtime_pollOpen internal/poll.runtime_pollOpen
func runtime_pollOpen(fd uintptr) (uintptr, int)

// Copied from src/internal/poll/fd_poll_runtime.go
var serverInit sync.Once

//go:linkname execIO internal/poll.execIO
func execIO(o *operation, submit func(o *operation) error) (int, error)

Expand Down Expand Up @@ -73,22 +63,6 @@ type pFD struct {
kind byte
}

func (fd *pFD) init() error {
serverInit.Do(runtime_pollServerInit)
ctx, errno := runtime_pollOpen(uintptr(fd.Sysfd))
if errno != 0 {
return syscall.Errno(errno)
}
fd.pd = ctx
fd.rop.mode = 'r'
fd.wop.mode = 'w'
fd.rop.fd = fd
fd.wop.fd = fd
fd.rop.runtimeCtx = fd.pd
fd.wop.runtimeCtx = fd.pd
return nil
}

func (fd *pFD) ConnectEx(ra syscall.Sockaddr, b []byte) (n int, err error) {
fd.wop.sa = ra
n, err = execIO(&fd.wop, func(o *operation) error {
Expand All @@ -112,15 +86,43 @@ type netFD struct {
raddr net.Addr
}

//go:linkname newFD net.newFD
func newFD(sysfd syscall.Handle, family, sotype int, net string) (*netFD, error)

//go:linkname netFDInit net.(*netFD).init
func netFDInit(fd *netFD) error

//go:linkname netFDClose net.(*netFD).Close
func netFDClose(fd *netFD) error

//go:linkname netFDCtrlNetwork net.(*netFD).ctrlNetwork
func netFDCtrlNetwork(fd *netFD) string

//go:linkname netFDWrite net.(*netFD).Write
func netFDWrite(fd *netFD, p []byte) (int, error)

//go:linkname netFDSetWriteDeadline net.(*netFD).SetWriteDeadline
func netFDSetWriteDeadline(fd *netFD, t time.Time) error

func (fd *netFD) init() error {
return netFDInit(fd)
}

func (fd *netFD) Close() error {
return netFDClose(fd)
}

func (fd *netFD) ctrlNetwork() string {
if fd.net == "tcp4" || fd.family == windows.AF_INET {
return "tcp4"
}
return "tcp6"
return netFDCtrlNetwork(fd)
}

//go:linkname newFD net.newFD
func newFD(sysfd syscall.Handle, family, sotype int, net string) (*netFD, error)
func (fd *netFD) Write(p []byte) (int, error) {
return netFDWrite(fd, p)
}

func (fd *netFD) SetWriteDeadline(t time.Time) error {
return netFDSetWriteDeadline(fd, t)
}

// Copied from src/net/rawconn.go
type rawConn struct {
Expand Down
26 changes: 13 additions & 13 deletions tfo_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"net"
"os"
"runtime"
"syscall"
"unsafe"

Expand Down Expand Up @@ -67,44 +68,42 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n
return nil, err
}

tc := (*net.TCPConn)(unsafe.Pointer(&fd))

if err = setIPv6Only(handle, family, ipv6only); err != nil {
tc.Close()
fd.Close()
return nil, wrapSyscallError("setsockopt(IPV6_V6ONLY)", err)
}

if err = setNoDelay(handle, 1); err != nil {
tc.Close()
fd.Close()
return nil, wrapSyscallError("setsockopt(TCP_NODELAY)", err)
}

if err = setTFODialer(uintptr(handle)); err != nil {
if !d.Fallback || !errors.Is(err, errors.ErrUnsupported) {
tc.Close()
fd.Close()
return nil, wrapSyscallError("setsockopt(TCP_FASTOPEN)", err)
}
runtimeDialTFOSupport.storeNone()
}

if ctrlCtxFn != nil {
if err = ctrlCtxFn(ctx, fd.ctrlNetwork(), raddr.String(), newRawConn(fd)); err != nil {
tc.Close()
fd.Close()
return nil, err
}
}

if err = syscall.Bind(syscall.Handle(handle), lsa); err != nil {
tc.Close()
fd.Close()
return nil, wrapSyscallError("bind", err)
}

if err = fd.pfd.init(); err != nil {
tc.Close()
if err = fd.init(); err != nil {
fd.Close()
return nil, err
}

if err = connWriteFunc(ctx, tc, func(c *net.TCPConn) error {
if err = connWriteFunc(ctx, fd, func(fd *netFD) error {
n, err := fd.pfd.ConnectEx(rsa, b)
if err != nil {
return os.NewSyscallError("connectex", err)
Expand All @@ -127,16 +126,17 @@ func (d *Dialer) dialSingle(ctx context.Context, network string, laddr, raddr *n
fd.raddr = sockaddrToTCP(rsa)

if n < len(b) {
if _, err = tc.Write(b[n:]); err != nil {
if _, err = fd.Write(b[n:]); err != nil {
return err
}
}

return nil
}); err != nil {
tc.Close()
fd.Close()
return nil, err
}

return tc, nil
runtime.SetFinalizer(fd, netFDClose)
return (*net.TCPConn)(unsafe.Pointer(&fd)), nil
}

0 comments on commit 15055b2

Please sign in to comment.