-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
🍮 conn: rewrite socket cmsg handling
Add support for UDP GRO and GSO on Linux and Windows.
- Loading branch information
1 parent
d37e6c8
commit cab0972
Showing
13 changed files
with
416 additions
and
127 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package conn | ||
|
||
import "net/netip" | ||
|
||
// SocketControlMessageBufferSize specifies the buffer size for receiving socket control messages. | ||
const SocketControlMessageBufferSize = socketControlMessageBufferSize | ||
|
||
// SocketControlMessage contains information that can be parsed from or put into socket control messages. | ||
type SocketControlMessage struct { | ||
// PktinfoAddr is the IP address of the network interface the packet was received from. | ||
PktinfoAddr netip.Addr | ||
|
||
// PktinfoIfindex is the index of the network interface the packet was received from. | ||
PktinfoIfindex uint32 | ||
|
||
// SegmentSize is the UDP GRO/GSO segment size. | ||
SegmentSize uint32 | ||
} | ||
|
||
// ParseSocketControlMessage parses a sequence of socket control messages and returns the parsed information. | ||
func ParseSocketControlMessage(cmsg []byte) (m SocketControlMessage, err error) { | ||
return parseSocketControlMessage(cmsg) | ||
} | ||
|
||
// AppendTo appends the socket control message to the buffer. | ||
func (m SocketControlMessage) AppendTo(b []byte) []byte { | ||
return m.appendTo(b) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
package conn | ||
|
||
import ( | ||
"fmt" | ||
"net/netip" | ||
"unsafe" | ||
|
||
"github.com/database64128/swgp-go/slicehelper" | ||
"golang.org/x/sys/unix" | ||
) | ||
|
||
const socketControlMessageBufferSize = unix.SizeofCmsghdr + alignedSizeofInet6Pktinfo | ||
|
||
const cmsgAlignTo = 4 | ||
|
||
func cmsgAlign(n uint32) uint32 { | ||
return (n + cmsgAlignTo - 1) & ^uint32(cmsgAlignTo-1) | ||
} | ||
|
||
func parseSocketControlMessage(cmsg []byte) (m SocketControlMessage, err error) { | ||
for len(cmsg) >= unix.SizeofCmsghdr { | ||
cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(cmsg))) | ||
msgSize := cmsgAlign(cmsghdr.Len) | ||
if cmsghdr.Len < unix.SizeofCmsghdr || int(msgSize) > len(cmsg) { | ||
return m, fmt.Errorf("invalid control message length %d", cmsghdr.Len) | ||
} | ||
|
||
switch { | ||
case cmsghdr.Level == unix.IPPROTO_IP && cmsghdr.Type == unix.IP_PKTINFO: | ||
if len(cmsg) < unix.SizeofCmsghdr+unix.SizeofInet4Pktinfo { | ||
return m, fmt.Errorf("invalid IP_PKTINFO control message length %d", cmsghdr.Len) | ||
} | ||
var pktinfo unix.Inet4Pktinfo | ||
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet4Pktinfo), cmsg[unix.SizeofCmsghdr:]) | ||
m.PktinfoAddr = netip.AddrFrom4(pktinfo.Spec_dst) | ||
m.PktinfoIfindex = pktinfo.Ifindex | ||
|
||
case cmsghdr.Level == unix.IPPROTO_IPV6 && cmsghdr.Type == unix.IPV6_PKTINFO: | ||
if len(cmsg) < unix.SizeofCmsghdr+unix.SizeofInet6Pktinfo { | ||
return m, fmt.Errorf("invalid IPV6_PKTINFO control message length %d", cmsghdr.Len) | ||
} | ||
var pktinfo unix.Inet6Pktinfo | ||
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet6Pktinfo), cmsg[unix.SizeofCmsghdr:]) | ||
m.PktinfoAddr = netip.AddrFrom16(pktinfo.Addr) | ||
m.PktinfoIfindex = pktinfo.Ifindex | ||
} | ||
|
||
cmsg = cmsg[msgSize:] | ||
} | ||
|
||
return m, nil | ||
} | ||
|
||
const ( | ||
alignedSizeofInet4Pktinfo = (unix.SizeofInet4Pktinfo + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1) | ||
alignedSizeofInet6Pktinfo = (unix.SizeofInet6Pktinfo + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1) | ||
) | ||
|
||
func (m SocketControlMessage) appendTo(b []byte) []byte { | ||
switch { | ||
case m.PktinfoAddr.Is4(): | ||
var msgBuf []byte | ||
b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofInet4Pktinfo) | ||
cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) | ||
*cmsghdr = unix.Cmsghdr{ | ||
Len: unix.SizeofCmsghdr + unix.SizeofInet4Pktinfo, | ||
Level: unix.IPPROTO_IP, | ||
Type: unix.IP_PKTINFO, | ||
} | ||
pktinfo := unix.Inet4Pktinfo{ | ||
Ifindex: m.PktinfoIfindex, | ||
Spec_dst: m.PktinfoAddr.As4(), | ||
} | ||
_ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet4Pktinfo)) | ||
|
||
case m.PktinfoAddr.Is6(): | ||
var msgBuf []byte | ||
b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofInet6Pktinfo) | ||
cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) | ||
*cmsghdr = unix.Cmsghdr{ | ||
Len: unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo, | ||
Level: unix.IPPROTO_IPV6, | ||
Type: unix.IPV6_PKTINFO, | ||
} | ||
pktinfo := unix.Inet6Pktinfo{ | ||
Addr: m.PktinfoAddr.As16(), | ||
Ifindex: m.PktinfoIfindex, | ||
} | ||
_ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet6Pktinfo)) | ||
} | ||
|
||
return b | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
package conn | ||
|
||
import ( | ||
"fmt" | ||
"net/netip" | ||
"unsafe" | ||
|
||
"github.com/database64128/swgp-go/slicehelper" | ||
"golang.org/x/sys/unix" | ||
) | ||
|
||
const socketControlMessageBufferSize = unix.SizeofCmsghdr + alignedSizeofInet6Pktinfo + | ||
unix.SizeofCmsghdr + alignedSizeofGROSegmentSize | ||
|
||
const sizeofGROSegmentSize = int(unsafe.Sizeof(uint16(0))) | ||
|
||
func cmsgAlign(n uint64) uint64 { | ||
return (n + unix.SizeofPtr - 1) & ^uint64(unix.SizeofPtr-1) | ||
} | ||
|
||
func parseSocketControlMessage(cmsg []byte) (m SocketControlMessage, err error) { | ||
for len(cmsg) >= unix.SizeofCmsghdr { | ||
cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(cmsg))) | ||
msgSize := cmsgAlign(cmsghdr.Len) | ||
if cmsghdr.Len < unix.SizeofCmsghdr || int(msgSize) > len(cmsg) { | ||
return m, fmt.Errorf("invalid control message length %d", cmsghdr.Len) | ||
} | ||
|
||
switch { | ||
case cmsghdr.Level == unix.IPPROTO_IP && cmsghdr.Type == unix.IP_PKTINFO: | ||
if len(cmsg) < unix.SizeofCmsghdr+unix.SizeofInet4Pktinfo { | ||
return m, fmt.Errorf("invalid IP_PKTINFO control message length %d", cmsghdr.Len) | ||
} | ||
var pktinfo unix.Inet4Pktinfo | ||
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet4Pktinfo), cmsg[unix.SizeofCmsghdr:]) | ||
m.PktinfoAddr = netip.AddrFrom4(pktinfo.Spec_dst) | ||
m.PktinfoIfindex = uint32(pktinfo.Ifindex) | ||
|
||
case cmsghdr.Level == unix.IPPROTO_IPV6 && cmsghdr.Type == unix.IPV6_PKTINFO: | ||
if len(cmsg) < unix.SizeofCmsghdr+unix.SizeofInet6Pktinfo { | ||
return m, fmt.Errorf("invalid IPV6_PKTINFO control message length %d", cmsghdr.Len) | ||
} | ||
var pktinfo unix.Inet6Pktinfo | ||
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet6Pktinfo), cmsg[unix.SizeofCmsghdr:]) | ||
m.PktinfoAddr = netip.AddrFrom16(pktinfo.Addr) | ||
m.PktinfoIfindex = pktinfo.Ifindex | ||
|
||
case cmsghdr.Level == unix.IPPROTO_UDP && cmsghdr.Type == unix.UDP_GRO: | ||
if len(cmsg) < unix.SizeofCmsghdr+sizeofGROSegmentSize { | ||
return m, fmt.Errorf("invalid UDP_GRO control message length %d", cmsghdr.Len) | ||
} | ||
var segmentSize uint16 | ||
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&segmentSize)), sizeofGROSegmentSize), cmsg[unix.SizeofCmsghdr:]) | ||
m.SegmentSize = uint32(segmentSize) | ||
} | ||
|
||
cmsg = cmsg[msgSize:] | ||
} | ||
|
||
return m, nil | ||
} | ||
|
||
const ( | ||
alignedSizeofInet4Pktinfo = (unix.SizeofInet4Pktinfo + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1) | ||
alignedSizeofInet6Pktinfo = (unix.SizeofInet6Pktinfo + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1) | ||
alignedSizeofGROSegmentSize = (sizeofGROSegmentSize + unix.SizeofPtr - 1) & ^(unix.SizeofPtr - 1) | ||
) | ||
|
||
func (m SocketControlMessage) appendTo(b []byte) []byte { | ||
switch { | ||
case m.PktinfoAddr.Is4(): | ||
var msgBuf []byte | ||
b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofInet4Pktinfo) | ||
cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) | ||
*cmsghdr = unix.Cmsghdr{ | ||
Len: unix.SizeofCmsghdr + unix.SizeofInet4Pktinfo, | ||
Level: unix.IPPROTO_IP, | ||
Type: unix.IP_PKTINFO, | ||
} | ||
pktinfo := unix.Inet4Pktinfo{ | ||
Ifindex: int32(m.PktinfoIfindex), | ||
Spec_dst: m.PktinfoAddr.As4(), | ||
} | ||
_ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet4Pktinfo)) | ||
|
||
case m.PktinfoAddr.Is6(): | ||
var msgBuf []byte | ||
b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofInet6Pktinfo) | ||
cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) | ||
*cmsghdr = unix.Cmsghdr{ | ||
Len: unix.SizeofCmsghdr + unix.SizeofInet6Pktinfo, | ||
Level: unix.IPPROTO_IPV6, | ||
Type: unix.IPV6_PKTINFO, | ||
} | ||
pktinfo := unix.Inet6Pktinfo{ | ||
Addr: m.PktinfoAddr.As16(), | ||
Ifindex: m.PktinfoIfindex, | ||
} | ||
_ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), unix.SizeofInet6Pktinfo)) | ||
} | ||
|
||
if m.SegmentSize > 0 { | ||
var msgBuf []byte | ||
b, msgBuf = slicehelper.Extend(b, unix.SizeofCmsghdr+alignedSizeofGROSegmentSize) | ||
cmsghdr := (*unix.Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) | ||
*cmsghdr = unix.Cmsghdr{ | ||
Len: unix.SizeofCmsghdr + uint64(sizeofGROSegmentSize), | ||
Level: unix.IPPROTO_UDP, | ||
Type: unix.UDP_GRO, | ||
} | ||
segmentSize := uint16(m.SegmentSize) | ||
_ = copy(msgBuf[unix.SizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&segmentSize)), sizeofGROSegmentSize)) | ||
} | ||
|
||
return b | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
//go:build !darwin && !linux && !windows | ||
|
||
package conn | ||
|
||
const socketControlMessageBufferSize = 0 | ||
|
||
func parseSocketControlMessage(_ []byte) (SocketControlMessage, error) { | ||
return SocketControlMessage{}, nil | ||
} | ||
|
||
func (SocketControlMessage) appendTo(b []byte) []byte { | ||
return b | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
package conn | ||
|
||
import ( | ||
"fmt" | ||
"net/netip" | ||
"unsafe" | ||
|
||
"github.com/database64128/swgp-go/slicehelper" | ||
"golang.org/x/sys/windows" | ||
) | ||
|
||
const socketControlMessageBufferSize = sizeofCmsghdr + alignedSizeofInet6Pktinfo + | ||
sizeofCmsghdr + alignedSizeofCoalescedInfo | ||
|
||
const ( | ||
sizeofPtr = int(unsafe.Sizeof(uintptr(0))) | ||
sizeofCmsghdr = int(unsafe.Sizeof(Cmsghdr{})) | ||
sizeofInet4Pktinfo = int(unsafe.Sizeof(Inet4Pktinfo{})) | ||
sizeofInet6Pktinfo = int(unsafe.Sizeof(Inet6Pktinfo{})) | ||
sizeofCoalescedInfo = int(unsafe.Sizeof(uint32(0))) | ||
) | ||
|
||
// Structure CMSGHDR from ws2def.h | ||
type Cmsghdr struct { | ||
Len uintptr | ||
Level int32 | ||
Type int32 | ||
} | ||
|
||
// Structure IN_PKTINFO from ws2ipdef.h | ||
type Inet4Pktinfo struct { | ||
Addr [4]byte | ||
Ifindex uint32 | ||
} | ||
|
||
// Structure IN6_PKTINFO from ws2ipdef.h | ||
type Inet6Pktinfo struct { | ||
Addr [16]byte | ||
Ifindex uint32 | ||
} | ||
|
||
func cmsgAlign(n int) int { | ||
return (n + sizeofPtr - 1) & ^(sizeofPtr - 1) | ||
} | ||
|
||
func parseSocketControlMessage(cmsg []byte) (m SocketControlMessage, err error) { | ||
for len(cmsg) >= sizeofCmsghdr { | ||
cmsghdr := (*Cmsghdr)(unsafe.Pointer(unsafe.SliceData(cmsg))) | ||
msgLen := int(cmsghdr.Len) | ||
msgSize := cmsgAlign(msgLen) | ||
if msgLen < sizeofCmsghdr || msgSize > len(cmsg) { | ||
return m, fmt.Errorf("invalid control message length %d", cmsghdr.Len) | ||
} | ||
|
||
switch { | ||
case cmsghdr.Level == windows.IPPROTO_IP && cmsghdr.Type == windows.IP_PKTINFO: | ||
if len(cmsg) < sizeofCmsghdr+sizeofInet4Pktinfo { | ||
return m, fmt.Errorf("invalid IP_PKTINFO control message length %d", cmsghdr.Len) | ||
} | ||
var pktinfo Inet4Pktinfo | ||
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), sizeofInet4Pktinfo), cmsg[sizeofCmsghdr:]) | ||
m.PktinfoAddr = netip.AddrFrom4(pktinfo.Addr) | ||
m.PktinfoIfindex = pktinfo.Ifindex | ||
|
||
case cmsghdr.Level == windows.IPPROTO_IPV6 && cmsghdr.Type == windows.IPV6_PKTINFO: | ||
if len(cmsg) < sizeofCmsghdr+sizeofInet6Pktinfo { | ||
return m, fmt.Errorf("invalid IPV6_PKTINFO control message length %d", cmsghdr.Len) | ||
} | ||
var pktinfo Inet6Pktinfo | ||
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), sizeofInet6Pktinfo), cmsg[sizeofCmsghdr:]) | ||
m.PktinfoAddr = netip.AddrFrom16(pktinfo.Addr) | ||
m.PktinfoIfindex = pktinfo.Ifindex | ||
|
||
case cmsghdr.Level == windows.IPPROTO_UDP && cmsghdr.Type == windows.UDP_COALESCED_INFO: | ||
if len(cmsg) < sizeofCmsghdr+sizeofCoalescedInfo { | ||
return m, fmt.Errorf("invalid UDP_COALESCED_INFO control message length %d", cmsghdr.Len) | ||
} | ||
_ = copy(unsafe.Slice((*byte)(unsafe.Pointer(&m.SegmentSize)), sizeofCoalescedInfo), cmsg[sizeofCmsghdr:]) | ||
} | ||
|
||
cmsg = cmsg[msgSize:] | ||
} | ||
|
||
return m, nil | ||
} | ||
|
||
const ( | ||
alignedSizeofInet4Pktinfo = (sizeofInet4Pktinfo + sizeofPtr - 1) & ^(sizeofPtr - 1) | ||
alignedSizeofInet6Pktinfo = (sizeofInet6Pktinfo + sizeofPtr - 1) & ^(sizeofPtr - 1) | ||
alignedSizeofCoalescedInfo = (sizeofCoalescedInfo + sizeofPtr - 1) & ^(sizeofPtr - 1) | ||
) | ||
|
||
func (m SocketControlMessage) appendTo(b []byte) []byte { | ||
switch { | ||
case m.PktinfoAddr.Is4(): | ||
var msgBuf []byte | ||
b, msgBuf = slicehelper.Extend(b, sizeofCmsghdr+alignedSizeofInet4Pktinfo) | ||
cmsghdr := (*Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) | ||
*cmsghdr = Cmsghdr{ | ||
Len: uintptr(sizeofCmsghdr + sizeofInet4Pktinfo), | ||
Level: windows.IPPROTO_IP, | ||
Type: windows.IP_PKTINFO, | ||
} | ||
pktinfo := Inet4Pktinfo{ | ||
Addr: m.PktinfoAddr.As4(), | ||
Ifindex: m.PktinfoIfindex, | ||
} | ||
_ = copy(msgBuf[sizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), sizeofInet4Pktinfo)) | ||
|
||
case m.PktinfoAddr.Is6(): | ||
var msgBuf []byte | ||
b, msgBuf = slicehelper.Extend(b, sizeofCmsghdr+alignedSizeofInet6Pktinfo) | ||
cmsghdr := (*Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) | ||
*cmsghdr = Cmsghdr{ | ||
Len: uintptr(sizeofCmsghdr + sizeofInet6Pktinfo), | ||
Level: windows.IPPROTO_IPV6, | ||
Type: windows.IPV6_PKTINFO, | ||
} | ||
pktinfo := Inet6Pktinfo{ | ||
Addr: m.PktinfoAddr.As16(), | ||
Ifindex: m.PktinfoIfindex, | ||
} | ||
_ = copy(msgBuf[sizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&pktinfo)), sizeofInet6Pktinfo)) | ||
} | ||
|
||
if m.SegmentSize > 0 { | ||
var msgBuf []byte | ||
b, msgBuf = slicehelper.Extend(b, sizeofCmsghdr+alignedSizeofCoalescedInfo) | ||
cmsghdr := (*Cmsghdr)(unsafe.Pointer(unsafe.SliceData(msgBuf))) | ||
*cmsghdr = Cmsghdr{ | ||
Len: uintptr(sizeofCmsghdr + sizeofCoalescedInfo), | ||
Level: windows.IPPROTO_UDP, | ||
Type: windows.UDP_COALESCED_INFO, | ||
} | ||
_ = copy(msgBuf[sizeofCmsghdr:], unsafe.Slice((*byte)(unsafe.Pointer(&m.SegmentSize)), sizeofCoalescedInfo)) | ||
} | ||
|
||
return b | ||
} |
Oops, something went wrong.