From 1ddb6779d258fd3e89bc6a7a389bad57f367ca1b Mon Sep 17 00:00:00 2001 From: database64128 Date: Sat, 28 May 2022 21:05:30 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=A6=AD=20Refactor=20handlers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Zero-overhead handler: Conceal all-zero MAC2 by using XChaCha20-Poly1305 to encrypt part of handshake packets. This is a breaking change. - Handler interface: Drop unnecessary maxPacketLen. - Better handler tests. --- README.md | 2 +- packet/handler.go | 52 ++++++++----- packet/handler_test.go | 42 ++++++++--- packet/paranoid.go | 60 +++++++++------ packet/paranoid_test.go | 25 ++----- packet/zerooverhead.go | 145 +++++++++++++++++++++++++++--------- packet/zerooverhead_test.go | 61 ++++++++++++--- service/client.go | 20 ++--- service/client_linux.go | 17 ++--- service/server.go | 25 +++---- service/server_linux.go | 16 ++-- service/service.go | 7 ++ 12 files changed, 316 insertions(+), 156 deletions(-) diff --git a/README.md b/README.md index c6e1c67..f024685 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ ### 1. Zero overhead -Simply AES encrypt the first 16 bytes of all packets. Handshake packets (message type 1, 2, 3) are also randomly padded to look like normal traffic. +The first 16 bytes of all packets are encrypted using an AES block cipher. The remainder of handshake packets (message type 1, 2, 3) are also randomly padded and encrypted using an XChaCha20-Poly1305 AEAD cipher to blend into normal traffic. ### 2. Paranoid diff --git a/packet/handler.go b/packet/handler.go index 2daa31e..5a43b99 100644 --- a/packet/handler.go +++ b/packet/handler.go @@ -1,6 +1,8 @@ // Package packet contains types and methods that transform WireGuard packets. package packet +import "errors" + const ( WireGuardMessageTypeHandshakeInitiation = 1 WireGuardMessageTypeHandshakeResponse = 2 @@ -12,6 +14,27 @@ const ( WireGuardMessageLengthHandshakeCookieReply = 64 ) +var ( + ErrPacketSize = errors.New("packet is too big or too small to be processed") + ErrPayloadLength = errors.New("payload length field value is out of range") +) + +type HandlerErr struct { + Err error + Message string +} + +func (e *HandlerErr) Unwrap() error { + return e.Err +} + +func (e *HandlerErr) Error() string { + if e.Message == "" { + return e.Err.Error() + } + return e.Message +} + // Handler encrypts WireGuard packets and decrypts swgp packets. type Handler interface { // FrontOverhead returns the headroom to reserve in buffer before payload. @@ -20,25 +43,18 @@ type Handler interface { // RearOverhead returns the headroom to reserve in buffer after payload. RearOverhead() int - // EncryptZeroCopy encrypts a WireGuard packet and returns a swgp packet - // without copying or incurring any allocations. - // - // buf must have at least FrontOverhead() bytes before and RearOverhead() bytes - // after the WireGuard packet. - // - // In other words, start must not be less than FrontOverhead(), - // len(buf) - start - max(length, maxPacketLen) must not be less than RearOverhead(). + // EncryptZeroCopy encrypts a WireGuard packet and returns a swgp packet without copying or incurring any allocations. // - // length is allowed to be greater than maxPacketLen (IP fragmentation). + // The WireGuard packet starts at buf[wgPacketStart] and its length is specified by wgPacketLength. + // The returned swgp packet starts at buf[swgpPacketStart] and its length is specified by swgpPacketLength. // - // maxPacketLen is the maximum payload length of a single unfragmented UDP packet. - // - // For IPv4, maxPacketLen = MTU - 20 (IPv4 header) - 8 (UDP header). - // - // For IPv6, maxPacketLen = MTU - 40 (IPv6 header) - 8 (UDP header). - EncryptZeroCopy(buf []byte, start, length, maxPacketLen int) (swgpPacket []byte, err error) + // buf must have at least FrontOverhead() bytes before and RearOverhead() bytes after the WireGuard packet. + // In other words, start must not be less than FrontOverhead(), len(buf) must not be less than start + length + RearOverhead(). + EncryptZeroCopy(buf []byte, wgPacketStart, wgPacketLength int) (swgpPacketStart, swgpPacketLength int, err error) - // DecryptZeroCopy decrypts a swgp packet and returns a WireGuard packet - // without copying or incurring any allocations. - DecryptZeroCopy(swgpPacket []byte) (wgPacket []byte, err error) + // DecryptZeroCopy decrypts a swgp packet and returns a WireGuard packet without copying or incurring any allocations. + // + // The swgp packet starts at buf[swgpPacketStart] and its length is specified by swgpPacketLength. + // The returned WireGuard packet starts at buf[wgPacketStart] and its length is specified by wgPacketLength. + DecryptZeroCopy(buf []byte, swgpPacketStart, swgpPacketLength int) (wgPacketStart, wgPacketLength int, err error) } diff --git a/packet/handler_test.go b/packet/handler_test.go index d088117..49bc575 100644 --- a/packet/handler_test.go +++ b/packet/handler_test.go @@ -1,19 +1,36 @@ package packet -import "testing" +import ( + "crypto/rand" + "errors" + "testing" +) -func testHandler(t *testing.T, msgType byte, length int, h Handler, verifyFunc func(t *testing.T, wgPacket, swgpPacket, decryptedWgPacket []byte)) { - // Reserve a minimum of 24 bytes to ensure code can handle extra headroom. - frontHeadroom, rearHeadroom := 24, 24 +func testHandler( + t *testing.T, + msgType byte, + length, extraFrontHeadroom, extraRearHeadroom int, + h Handler, + expectedEncryptErr, expectedDecryptErr error, + verifyFunc func(t *testing.T, wgPacket, swgpPacket, decryptedWgPacket []byte), +) { + var frontHeadroom, rearHeadroom int frontOverhead, rearOverhead := h.FrontOverhead(), h.RearOverhead() if frontOverhead > frontHeadroom { frontHeadroom = frontOverhead } + frontHeadroom += extraFrontHeadroom if rearOverhead > rearHeadroom { rearHeadroom = rearOverhead } + rearHeadroom += extraRearHeadroom + // Prepare buffer. buf := make([]byte, frontHeadroom+length+rearHeadroom) + _, err := rand.Read(buf) + if err != nil { + t.Fatal(err) + } buf[frontHeadroom] = msgType var wgPacket, swgpPacket, decryptedWgPacket []byte @@ -22,19 +39,26 @@ func testHandler(t *testing.T, msgType byte, length int, h Handler, verifyFunc f wgPacket = append(wgPacket, buf[frontHeadroom:frontHeadroom+length]...) // Encrypt. - pkt, err := h.EncryptZeroCopy(buf, frontHeadroom, length, length+rearHeadroom-rearOverhead) + swgpPacketStart, swgpPacketLength, err := h.EncryptZeroCopy(buf, frontHeadroom, length) + if !errors.Is(err, expectedEncryptErr) { + t.Fatalf("Expected encryption error: %s\nGot: %s", expectedEncryptErr, err) + } if err != nil { - t.Fatal(err) + return } // Save encrypted packet. - swgpPacket = append(swgpPacket, pkt...) + swgpPacket = append(swgpPacket, buf[swgpPacketStart:swgpPacketStart+swgpPacketLength]...) // Decrypt. - decryptedWgPacket, err = h.DecryptZeroCopy(pkt) + wgPacketStart, wgPacketLength, err := h.DecryptZeroCopy(buf, swgpPacketStart, swgpPacketLength) + if !errors.Is(err, expectedDecryptErr) { + t.Fatalf("Expected decryption error: %s\nGot: %s", expectedDecryptErr, err) + } if err != nil { - t.Fatal(err) + return } + decryptedWgPacket = buf[wgPacketStart : wgPacketStart+wgPacketLength] verifyFunc(t, wgPacket, swgpPacket, decryptedWgPacket) } diff --git a/packet/paranoid.go b/packet/paranoid.go index 76c7c52..aff2428 100644 --- a/packet/paranoid.go +++ b/packet/paranoid.go @@ -15,7 +15,9 @@ import ( // All packets, irrespective of message type, are padded up to the maximum packet length // to hide any possible characteristics. // -// swgpPacket := 24B nonce + AEAD_Seal(2B payload length + payload + padding) +// swgpPacket := 24B nonce + AEAD_Seal(u16be payload length + payload + padding) +// +// paranoidHandler implements the Handler interface. type paranoidHandler struct { aead cipher.AEAD } @@ -44,58 +46,68 @@ func (h *paranoidHandler) RearOverhead() int { } // EncryptZeroCopy implements the Handler EncryptZeroCopy method. -func (h *paranoidHandler) EncryptZeroCopy(buf []byte, start, length, maxPacketLen int) (swgpPacket []byte, err error) { - if length > math.MaxUint16 { - return nil, fmt.Errorf("payload too long: %d is greater than 65535", length) +func (h *paranoidHandler) EncryptZeroCopy(buf []byte, wgPacketStart, wgPacketLength int) (swgpPacketStart, swgpPacketLength int, err error) { + if wgPacketLength > math.MaxUint16 { + err = &HandlerErr{ErrPacketSize, fmt.Sprintf("wg packet (length %d) is too large (greater than %d)", wgPacketLength, math.MaxUint16)} + return } + // Determine padding length. + rearHeadroom := len(buf) - wgPacketStart - wgPacketLength + paddingHeadroom := rearHeadroom - chacha20poly1305.Overhead var paddingLen int - - if maxPaddingLen := maxPacketLen - h.FrontOverhead() - length - h.RearOverhead(); maxPaddingLen > 0 { - paddingLen = mrand.Intn(maxPaddingLen + 1) + if paddingHeadroom > 0 { + paddingLen = mrand.Intn(paddingHeadroom) + 1 } - nonce := buf[start-chacha20poly1305.NonceSizeX-2 : start-2] - payloadLength := buf[start-2 : start] - plaintext := buf[start-2 : start+length+paddingLen] + // Calculate offsets. + swgpPacketStart = wgPacketStart - 2 - chacha20poly1305.NonceSizeX + swgpPacketLength = chacha20poly1305.NonceSizeX + 2 + wgPacketLength + paddingLen + chacha20poly1305.Overhead + + nonce := buf[swgpPacketStart : wgPacketStart-2] + payloadLength := buf[wgPacketStart-2 : wgPacketStart] + plaintext := buf[wgPacketStart-2 : wgPacketStart+wgPacketLength+paddingLen] // Write random nonce. _, err = rand.Read(nonce) if err != nil { - return nil, err + return } // Write payload length. - binary.BigEndian.PutUint16(payloadLength, uint16(length)) + binary.BigEndian.PutUint16(payloadLength, uint16(wgPacketLength)) // AEAD seal. - swgpPacket = h.aead.Seal(nonce, nonce, plaintext, nil) + h.aead.Seal(nonce, nonce, plaintext, nil) return } // DecryptZeroCopy implements the Handler DecryptZeroCopy method. -func (h *paranoidHandler) DecryptZeroCopy(swgpPacket []byte) (wgPacket []byte, err error) { - if len(swgpPacket) < chacha20poly1305.NonceSizeX { - return nil, fmt.Errorf("bad swgpPacket length: %d", len(swgpPacket)) +func (h *paranoidHandler) DecryptZeroCopy(buf []byte, swgpPacketStart, swgpPacketLength int) (wgPacketStart, wgPacketLength int, err error) { + if swgpPacketLength < chacha20poly1305.NonceSizeX+2+1+chacha20poly1305.Overhead { + err = &HandlerErr{ErrPacketSize, fmt.Sprintf("swgp packet (length %d) is too short", swgpPacketLength)} + return } - nonce := swgpPacket[:chacha20poly1305.NonceSizeX] - ciphertext := swgpPacket[chacha20poly1305.NonceSizeX:] + nonce := buf[swgpPacketStart : swgpPacketStart+chacha20poly1305.NonceSizeX] + ciphertext := buf[swgpPacketStart+chacha20poly1305.NonceSizeX : swgpPacketStart+swgpPacketLength] // AEAD open. plaintext, err := h.aead.Open(ciphertext[:0], nonce, ciphertext, nil) if err != nil { - return nil, err + return } // Read and validate payload length. - payloadLength := plaintext[:2] - length := int(binary.BigEndian.Uint16(payloadLength)) - if 2+length > len(plaintext) { - return nil, fmt.Errorf("payload length %d is greater than plaintext length %d", length, len(plaintext)) + payloadLengthBuf := plaintext[:2] + payloadLength := int(binary.BigEndian.Uint16(payloadLengthBuf)) + if payloadLength > len(plaintext)-2 { + err = &HandlerErr{ErrPayloadLength, fmt.Sprintf("payload length field value %d is out of range", payloadLength)} + return } - wgPacket = plaintext[2 : 2+length] + wgPacketStart = swgpPacketStart + chacha20poly1305.NonceSizeX + 2 + wgPacketLength = payloadLength return } diff --git a/packet/paranoid_test.go b/packet/paranoid_test.go index fc4a475..e2ad22a 100644 --- a/packet/paranoid_test.go +++ b/packet/paranoid_test.go @@ -24,7 +24,7 @@ func testNewParanoidHandler(t *testing.T) Handler { func testParanoidVerifyPacket(t *testing.T, wgPacket, swgpPacket, decryptedWgPacket []byte) { if len(swgpPacket) < chacha20poly1305.NonceSizeX+2+len(wgPacket)+chacha20poly1305.Overhead { - t.Error("bad swgpPacket length") + t.Error("Bad swgpPacket length.") } if !bytes.Equal(wgPacket, decryptedWgPacket) { @@ -32,22 +32,13 @@ func testParanoidVerifyPacket(t *testing.T, wgPacket, swgpPacket, decryptedWgPac } } -func TestParanoidHandleWireGuardHandshakeInitiationPacket(t *testing.T) { +func TestParanoidHandlePacket(t *testing.T) { h := testNewParanoidHandler(t) - testHandler(t, WireGuardMessageTypeHandshakeInitiation, WireGuardMessageLengthHandshakeInitiation, h, testParanoidVerifyPacket) -} -func TestParanoidHandleWireGuardHandshakeResponsePacket(t *testing.T) { - h := testNewParanoidHandler(t) - testHandler(t, WireGuardMessageTypeHandshakeResponse, WireGuardMessageLengthHandshakeResponse, h, testParanoidVerifyPacket) -} - -func TestParanoidHandleWireGuardHandshakeCookieReplyPacket(t *testing.T) { - h := testNewParanoidHandler(t) - testHandler(t, WireGuardMessageTypeHandshakeCookieReply, WireGuardMessageLengthHandshakeCookieReply, h, testParanoidVerifyPacket) -} - -func TestParanoidHandleWireGuardDataPacket(t *testing.T) { - h := testNewParanoidHandler(t) - testHandler(t, WireGuardMessageTypeData, 1452, h, testParanoidVerifyPacket) + for i := 1; i < 128; i++ { + testHandler(t, WireGuardMessageTypeHandshakeInitiation, i, 0, 0, h, nil, nil, testParanoidVerifyPacket) + testHandler(t, WireGuardMessageTypeHandshakeResponse, i, 0, 0, h, nil, nil, testParanoidVerifyPacket) + testHandler(t, WireGuardMessageTypeHandshakeCookieReply, i, 0, 0, h, nil, nil, testParanoidVerifyPacket) + testHandler(t, WireGuardMessageTypeData, i, 0, 0, h, nil, nil, testParanoidVerifyPacket) + } } diff --git a/packet/zerooverhead.go b/packet/zerooverhead.go index 28b9aa6..a1dbfff 100644 --- a/packet/zerooverhead.go +++ b/packet/zerooverhead.go @@ -4,16 +4,28 @@ import ( "crypto/aes" "crypto/cipher" "crypto/rand" + "encoding/binary" + "fmt" mrand "math/rand" + + "golang.org/x/crypto/chacha20poly1305" ) -// zeroOverheadHandler encrypts and decrypts the first 16 bytes of packets -// using an AES block cipher. -// Handshake packets (message type 1, 2, 3) are randomly padded to look like normal traffic. +// zeroOverheadHandshakePacketMinimumOverhead is the minimum overhead of a handshake packet encrypted by zeroOverheadHandler. +// Additional overhead is the random-length padding. +const zeroOverheadHandshakePacketMinimumOverhead = 2 + chacha20poly1305.Overhead + chacha20poly1305.NonceSizeX + +// zeroOverheadHandler encrypts and decrypts the first 16 bytes of packets using an AES block cipher. +// The remainder of handshake packets (message type 1, 2, 3) are also randomly padded and encrypted +// using an XChaCha20-Poly1305 AEAD cipher to blend into normal traffic. +// +// swgpPacket := aes(wgDataPacket[:16]) + wgDataPacket[16:] +// swgpPacket := aes(wgHandshakePacket[:16]) + AEAD_Seal(payload + padding + u16be payload length) + 24B nonce // // zeroOverheadHandler implements the Handler interface. type zeroOverheadHandler struct { - cb cipher.Block + cb cipher.Block + aead cipher.AEAD } // NewZeroOverheadHandler creates a zero-overhead handler that @@ -24,8 +36,14 @@ func NewZeroOverheadHandler(psk []byte) (Handler, error) { return nil, err } + aead, err := chacha20poly1305.NewX(psk) + if err != nil { + return nil, err + } + return &zeroOverheadHandler{ - cb: cb, + cb: cb, + aead: aead, }, nil } @@ -40,55 +58,112 @@ func (h *zeroOverheadHandler) RearOverhead() int { } // EncryptZeroCopy implements the Handler EncryptZeroCopy method. -func (h *zeroOverheadHandler) EncryptZeroCopy(buf []byte, start, length, maxPacketLen int) (swgpPacket []byte, err error) { - var paddingLen int +func (h *zeroOverheadHandler) EncryptZeroCopy(buf []byte, wgPacketStart, wgPacketLength int) (swgpPacketStart, swgpPacketLength int, err error) { + swgpPacketStart = wgPacketStart + swgpPacketLength = wgPacketLength - // Add padding only if: - // - Packet is handshake. - // - We have room for padding. - switch buf[start] { - case WireGuardMessageTypeHandshakeInitiation, WireGuardMessageTypeHandshakeResponse, WireGuardMessageTypeHandshakeCookieReply: - if maxPacketLen > length { - paddingLen = mrand.Intn(maxPacketLen - length + 1) - } + // Skip small packets. + if wgPacketLength < 16 { + return } - swgpPacket = buf[start : start+length+paddingLen] + // Save message type. + messageType := buf[wgPacketStart] + + // Encrypt first 16 bytes. + h.cb.Encrypt(buf[wgPacketStart:], buf[wgPacketStart:]) - if length < 16 { + // We are done with non-handshake packets. + switch messageType { + case WireGuardMessageTypeHandshakeInitiation, WireGuardMessageTypeHandshakeResponse, WireGuardMessageTypeHandshakeCookieReply: + default: return } - // Encrypt first 16 bytes. - h.cb.Encrypt(swgpPacket[:16], swgpPacket[:16]) + // Return error if packet is so big that buffer has no room for AEAD overhead. + rearHeadroom := len(buf) - wgPacketStart - wgPacketLength + paddingHeadroom := rearHeadroom - 2 - chacha20poly1305.Overhead - chacha20poly1305.NonceSizeX + if paddingHeadroom < 0 { + err = &HandlerErr{ErrPacketSize, fmt.Sprintf("handshake packet (length %d) is too large to process in buffer (length %d)", wgPacketLength, len(buf))} + return + } + + var paddingLen int + if paddingHeadroom > 0 { + paddingLen = mrand.Intn(paddingHeadroom) + 1 + } + + swgpPacketLength += paddingLen + zeroOverheadHandshakePacketMinimumOverhead + + // Calculate offsets. + plaintextStart := wgPacketStart + 16 + payloadLengthBufStart := wgPacketStart + wgPacketLength + paddingLen + plaintextEnd := payloadLengthBufStart + 2 + nonceStart := plaintextEnd + chacha20poly1305.Overhead + nonceEnd := nonceStart + chacha20poly1305.NonceSizeX - // Add padding. - if paddingLen > 0 { - padding := swgpPacket[length:] - _, err = rand.Read(padding) + // Write payload length. + payloadLength := wgPacketLength - 16 + payloadLengthBuf := buf[payloadLengthBufStart:plaintextEnd] + binary.BigEndian.PutUint16(payloadLengthBuf, uint16(payloadLength)) + + plaintext := buf[plaintextStart:plaintextEnd] + nonce := buf[nonceStart:nonceEnd] + _, err = rand.Read(nonce) + if err != nil { + return } + h.aead.Seal(plaintext[:0], nonce, plaintext, nil) return } // DecryptZeroCopy implements the Handler DecryptZeroCopy method. -func (h *zeroOverheadHandler) DecryptZeroCopy(swgpPacket []byte) (wgPacket []byte, err error) { - wgPacket = swgpPacket +func (h *zeroOverheadHandler) DecryptZeroCopy(buf []byte, swgpPacketStart, swgpPacketLength int) (wgPacketStart, wgPacketLength int, err error) { + wgPacketStart = swgpPacketStart + wgPacketLength = swgpPacketLength + + // Skip small packets. + if swgpPacketLength < 16 { + return + } // Decrypt first 16 bytes. - if len(swgpPacket) >= 16 { - h.cb.Decrypt(swgpPacket[:16], swgpPacket[:16]) + h.cb.Decrypt(buf[swgpPacketStart:], buf[swgpPacketStart:]) + + // We are done with non-handshake and short handshake packets. + switch buf[swgpPacketStart] { + case WireGuardMessageTypeHandshakeInitiation, WireGuardMessageTypeHandshakeResponse, WireGuardMessageTypeHandshakeCookieReply: + if swgpPacketLength < 16+zeroOverheadHandshakePacketMinimumOverhead { + err = &HandlerErr{ErrPacketSize, fmt.Sprintf("swgp packet too short: %d", swgpPacketLength)} + return + } + default: + return } - // Hide padding. - switch { - case swgpPacket[0] == WireGuardMessageTypeHandshakeInitiation && len(swgpPacket) >= WireGuardMessageLengthHandshakeInitiation: - wgPacket = swgpPacket[:WireGuardMessageLengthHandshakeInitiation] - case swgpPacket[0] == WireGuardMessageTypeHandshakeResponse && len(swgpPacket) >= WireGuardMessageLengthHandshakeResponse: - wgPacket = swgpPacket[:WireGuardMessageLengthHandshakeResponse] - case swgpPacket[0] == WireGuardMessageTypeHandshakeCookieReply && len(swgpPacket) >= WireGuardMessageLengthHandshakeCookieReply: - wgPacket = swgpPacket[:WireGuardMessageLengthHandshakeCookieReply] + // Calculate offsets. + nonceEnd := swgpPacketStart + swgpPacketLength + nonceStart := nonceEnd - chacha20poly1305.NonceSizeX + plaintextEnd := nonceStart - chacha20poly1305.Overhead + payloadLengthBufStart := plaintextEnd - 2 + plaintextStart := swgpPacketStart + 16 + + ciphertext := buf[plaintextStart:nonceStart] + nonce := buf[nonceStart:nonceEnd] + _, err = h.aead.Open(ciphertext[:0], nonce, ciphertext, nil) + if err != nil { + return + } + + // Read and validate payload length. + payloadLengthBuf := buf[payloadLengthBufStart:plaintextEnd] + payloadLength := int(binary.BigEndian.Uint16(payloadLengthBuf)) + if payloadLength > payloadLengthBufStart-plaintextStart { + err = &HandlerErr{ErrPayloadLength, fmt.Sprintf("payload length field value %d is out of range", payloadLength)} + return } + wgPacketLength = 16 + payloadLength return } diff --git a/packet/zerooverhead_test.go b/packet/zerooverhead_test.go index bc820ba..9c184e8 100644 --- a/packet/zerooverhead_test.go +++ b/packet/zerooverhead_test.go @@ -20,7 +20,31 @@ func testNewZeroOverheadHandler(t *testing.T) Handler { return h } -func testZeroOverheadVerifyPacket(t *testing.T, wgPacket, swgpPacket, decryptedWgPacket []byte) { +func testZeroOverheadVerifyUnchangedPacket(t *testing.T, wgPacket, swgpPacket, decryptedWgPacket []byte) { + if !bytes.Equal(wgPacket, swgpPacket) { + t.Error("The packet should be untouched.") + } + + if !bytes.Equal(wgPacket, decryptedWgPacket) { + t.Error("Decrypted packet is different from original packet.") + } +} + +func testZeroOverheadVerifyHandshakePacket(t *testing.T, wgPacket, swgpPacket, decryptedWgPacket []byte) { + if bytes.Equal(wgPacket, swgpPacket[:len(wgPacket)]) { + t.Error("The packet is not encrypted.") + } + + if len(swgpPacket) < len(wgPacket)+zeroOverheadHandshakePacketMinimumOverhead { + t.Error("Bad swgpPacket length.") + } + + if !bytes.Equal(wgPacket, decryptedWgPacket) { + t.Error("Decrypted packet is different from original packet.") + } +} + +func testZeroOverheadVerifyDataPacket(t *testing.T, wgPacket, swgpPacket, decryptedWgPacket []byte) { if bytes.Equal(wgPacket[:16], swgpPacket[:16]) { t.Error("The first 16 bytes are not encrypted.") } @@ -34,22 +58,41 @@ func testZeroOverheadVerifyPacket(t *testing.T, wgPacket, swgpPacket, decryptedW } } -func TestZeroOverheadHandleWireGuardHandshakeInitiationPacket(t *testing.T) { +func TestZeroOverheadHandleLessThan16Bytes(t *testing.T) { h := testNewZeroOverheadHandler(t) - testHandler(t, WireGuardMessageTypeHandshakeInitiation, WireGuardMessageLengthHandshakeInitiation, h, testZeroOverheadVerifyPacket) + + for i := 0; i < 16; i++ { + testHandler(t, WireGuardMessageTypeHandshakeInitiation, i, 1, 1, h, nil, nil, testZeroOverheadVerifyUnchangedPacket) + testHandler(t, WireGuardMessageTypeHandshakeResponse, i, 1, 1, h, nil, nil, testZeroOverheadVerifyUnchangedPacket) + testHandler(t, WireGuardMessageTypeHandshakeCookieReply, i, 1, 1, h, nil, nil, testZeroOverheadVerifyUnchangedPacket) + testHandler(t, WireGuardMessageTypeData, i, 1, 1, h, nil, nil, testZeroOverheadVerifyUnchangedPacket) + } } -func TestZeroOverheadHandleWireGuardHandshakeResponsePacket(t *testing.T) { +func TestZeroOverheadHandleEncryptErrPacketSize(t *testing.T) { h := testNewZeroOverheadHandler(t) - testHandler(t, WireGuardMessageTypeHandshakeResponse, WireGuardMessageLengthHandshakeResponse, h, testZeroOverheadVerifyPacket) + + for i := 0; i < zeroOverheadHandshakePacketMinimumOverhead; i++ { + testHandler(t, WireGuardMessageTypeHandshakeInitiation, WireGuardMessageLengthHandshakeInitiation, 1, i, h, ErrPacketSize, nil, testZeroOverheadVerifyUnchangedPacket) + testHandler(t, WireGuardMessageTypeHandshakeResponse, WireGuardMessageLengthHandshakeResponse, 1, i, h, ErrPacketSize, nil, testZeroOverheadVerifyUnchangedPacket) + testHandler(t, WireGuardMessageTypeHandshakeCookieReply, WireGuardMessageLengthHandshakeCookieReply, 1, i, h, ErrPacketSize, nil, testZeroOverheadVerifyUnchangedPacket) + } } -func TestZeroOverheadHandleWireGuardHandshakeCookieReplyPacket(t *testing.T) { +func TestZeroOverheadHandleHandshakePacket(t *testing.T) { h := testNewZeroOverheadHandler(t) - testHandler(t, WireGuardMessageTypeHandshakeCookieReply, WireGuardMessageLengthHandshakeCookieReply, h, testZeroOverheadVerifyPacket) + + for i := 16; i < 128; i++ { + testHandler(t, WireGuardMessageTypeHandshakeInitiation, i, 1, zeroOverheadHandshakePacketMinimumOverhead, h, nil, nil, testZeroOverheadVerifyHandshakePacket) + testHandler(t, WireGuardMessageTypeHandshakeResponse, i, 1, zeroOverheadHandshakePacketMinimumOverhead, h, nil, nil, testZeroOverheadVerifyHandshakePacket) + testHandler(t, WireGuardMessageTypeHandshakeCookieReply, i, 1, zeroOverheadHandshakePacketMinimumOverhead, h, nil, nil, testZeroOverheadVerifyHandshakePacket) + } } -func TestZeroOverheadHandleWireGuardDataPacket(t *testing.T) { +func TestZeroOverheadHandleDataPacket(t *testing.T) { h := testNewZeroOverheadHandler(t) - testHandler(t, WireGuardMessageTypeData, 1452, h, testZeroOverheadVerifyPacket) + + for i := 16; i < 128; i++ { + testHandler(t, WireGuardMessageTypeData, i, 1, 1, h, nil, nil, testZeroOverheadVerifyDataPacket) + } } diff --git a/service/client.go b/service/client.go index a913098..2a421c8 100644 --- a/service/client.go +++ b/service/client.go @@ -28,17 +28,10 @@ type ClientConfig struct { DisableSendmmsg bool `json:"disableSendmmsg"` } -// clientQueuedPacket stores an unencrypted wg packet. -type clientQueuedPacket struct { - bufp *[]byte - start int - length int -} - type clientNatEntry struct { clientOobCache []byte proxyConn *net.UDPConn - proxyConnSendCh chan clientQueuedPacket + proxyConnSendCh chan queuedPacket } type client struct { @@ -223,7 +216,7 @@ func (c *client) Start() (err error) { natEntry = &clientNatEntry{ proxyConn: proxyConn, - proxyConnSendCh: make(chan clientQueuedPacket, sendChannelCapacity), + proxyConnSendCh: make(chan queuedPacket, sendChannelCapacity), } c.table[clientAddr] = natEntry @@ -292,7 +285,7 @@ func (c *client) Start() (err error) { } select { - case natEntry.proxyConnSendCh <- clientQueuedPacket{packetBufp, frontOverhead, n}: + case natEntry.proxyConnSendCh <- queuedPacket{packetBufp, frontOverhead, n}: default: c.logger.Debug("swgpPacket dropped due to full send channel", zap.Stringer("service", c), @@ -326,7 +319,7 @@ func (c *client) relayWgToProxyGeneric(clientAddr netip.AddrPort, natEntry *clie packetBuf := *queuedPacket.bufp - swgpPacket, err := c.handler.EncryptZeroCopy(packetBuf, queuedPacket.start, queuedPacket.length, c.maxProxyPacketSize) + swgpPacketStart, swgpPacketLength, err := c.handler.EncryptZeroCopy(packetBuf, queuedPacket.start, queuedPacket.length) if err != nil { c.logger.Warn("Failed to encrypt WireGuard packet", zap.Stringer("service", c), @@ -337,6 +330,7 @@ func (c *client) relayWgToProxyGeneric(clientAddr netip.AddrPort, natEntry *clie c.packetBufPool.Put(queuedPacket.bufp) continue } + swgpPacket := packetBuf[swgpPacketStart : swgpPacketStart+swgpPacketLength] _, _, err = natEntry.proxyConn.WriteMsgUDPAddrPort(swgpPacket, nil, c.proxyAddr) if err != nil { @@ -394,8 +388,7 @@ func (c *client) relayProxyToWgGeneric(clientAddr netip.AddrPort, natEntry *clie continue } - swgpPacket := packetBuf[:n] - wgPacket, err := c.handler.DecryptZeroCopy(swgpPacket) + wgPacketStart, wgPacketLength, err := c.handler.DecryptZeroCopy(packetBuf, 0, n) if err != nil { c.logger.Warn("Failed to decrypt swgpPacket", zap.Stringer("service", c), @@ -406,6 +399,7 @@ func (c *client) relayProxyToWgGeneric(clientAddr netip.AddrPort, natEntry *clie ) continue } + wgPacket := packetBuf[wgPacketStart : wgPacketStart+wgPacketLength] _, _, err = c.wgConn.WriteMsgUDPAddrPort(wgPacket, natEntry.clientOobCache, clientAddr) if err != nil { diff --git a/service/client_linux.go b/service/client_linux.go index b5fe53f..2332462 100644 --- a/service/client_linux.go +++ b/service/client_linux.go @@ -21,7 +21,7 @@ func (c *client) getRelayWgToProxyFunc(disableSendmmsg bool) func(clientAddr net func (c *client) relayWgToProxySendmmsg(clientAddr netip.AddrPort, natEntry *clientNatEntry) { name, namelen := conn.AddrPortToSockaddr(c.proxyAddr) - dequeuedPackets := make([]clientQueuedPacket, 0, conn.UIO_MAXIOV) + dequeuedPackets := make([]queuedPacket, 0, conn.UIO_MAXIOV) iovec := make([]unix.Iovec, 0, conn.UIO_MAXIOV) msgvec := make([]conn.Mmsghdr, 0, conn.UIO_MAXIOV) @@ -29,7 +29,7 @@ func (c *client) relayWgToProxySendmmsg(clientAddr netip.AddrPort, natEntry *cli // Dequeue packets and append to dequeuedPackets. var ( - dequeuedPacket clientQueuedPacket + dequeuedPacket queuedPacket ok bool ) @@ -63,7 +63,7 @@ func (c *client) relayWgToProxySendmmsg(clientAddr netip.AddrPort, natEntry *cli for i, packet := range dequeuedPackets { packetBuf := *packet.bufp - swgpPacket, err := c.handler.EncryptZeroCopy(packetBuf, packet.start, packet.length, c.maxProxyPacketSize) + swgpPacketStart, swgpPacketLength, err := c.handler.EncryptZeroCopy(packetBuf, packet.start, packet.length) if err != nil { c.logger.Warn("Failed to encrypt WireGuard packet", zap.Stringer("service", c), @@ -74,8 +74,8 @@ func (c *client) relayWgToProxySendmmsg(clientAddr netip.AddrPort, natEntry *cli goto cleanup } - iovec[i].Base = &swgpPacket[0] - iovec[i].SetLen(len(swgpPacket)) + iovec[i].Base = &packetBuf[swgpPacketStart] + iovec[i].SetLen(swgpPacketLength) msgvec[i].Msghdr.Name = name msgvec[i].Msghdr.Namelen = namelen @@ -193,8 +193,7 @@ func (c *client) relayProxyToWgSendmmsg(clientAddr netip.AddrPort, natEntry *cli } packetBuf := unsafe.Slice(msg.Msghdr.Iov.Base, c.maxProxyPacketSize) - swgpPacket := packetBuf[:msg.Msglen] - wgPacket, err := c.handler.DecryptZeroCopy(swgpPacket) + wgPacketStart, wgPacketLength, err := c.handler.DecryptZeroCopy(packetBuf, 0, int(msg.Msglen)) if err != nil { c.logger.Warn("Failed to decrypt swgpPacket", zap.Stringer("service", c), @@ -206,8 +205,8 @@ func (c *client) relayProxyToWgSendmmsg(clientAddr netip.AddrPort, natEntry *cli continue } - smsgvec[ns].Msghdr.Iov.Base = &wgPacket[0] - smsgvec[ns].Msghdr.Iov.SetLen(len(wgPacket)) + smsgvec[ns].Msghdr.Iov.Base = &packetBuf[wgPacketStart] + smsgvec[ns].Msghdr.Iov.SetLen(wgPacketLength) if smsgControlLen > 0 { smsgvec[ns].Msghdr.Control = &smsgControl[0] smsgvec[ns].Msghdr.SetControllen(smsgControlLen) diff --git a/service/server.go b/service/server.go index 405b3d0..b068625 100644 --- a/service/server.go +++ b/service/server.go @@ -28,16 +28,10 @@ type ServerConfig struct { DisableSendmmsg bool `json:"disableSendmmsg"` } -// serverQueuedPacket stores a decrypted wg packet. -type serverQueuedPacket struct { - bufp *[]byte - wgPacket []byte -} - type serverNatEntry struct { clientOobCache []byte wgConn *net.UDPConn - wgConnSendCh chan serverQueuedPacket + wgConnSendCh chan queuedPacket maxProxyPacketSize int } @@ -174,8 +168,7 @@ func (s *server) Start() (err error) { continue } - swgpPacket := packetBuf[:n] - wgPacket, err := s.handler.DecryptZeroCopy(swgpPacket) + wgPacketStart, wgPacketLength, err := s.handler.DecryptZeroCopy(packetBuf, 0, n) if err != nil { s.logger.Warn("Failed to decrypt swgpPacket", zap.Stringer("service", s), @@ -229,7 +222,7 @@ func (s *server) Start() (err error) { natEntry = &serverNatEntry{ wgConn: wgConn, - wgConnSendCh: make(chan serverQueuedPacket, sendChannelCapacity), + wgConnSendCh: make(chan queuedPacket, sendChannelCapacity), } if addr := clientAddr.Addr(); addr.Is4() || addr.Is4In6() { @@ -275,7 +268,7 @@ func (s *server) Start() (err error) { ) // Update wgConn read deadline when a handshake initiation/response message is received. - switch wgPacket[0] { + switch packetBuf[wgPacketStart] { case packet.WireGuardMessageTypeHandshakeInitiation, packet.WireGuardMessageTypeHandshakeResponse: err = natEntry.wgConn.SetReadDeadline(time.Now().Add(RejectAfterTime)) if err != nil { @@ -305,7 +298,7 @@ func (s *server) Start() (err error) { } select { - case natEntry.wgConnSendCh <- serverQueuedPacket{packetBufp, wgPacket}: + case natEntry.wgConnSendCh <- queuedPacket{packetBufp, wgPacketStart, wgPacketLength}: default: s.logger.Debug("wgPacket dropped due to full send channel", zap.Stringer("service", s), @@ -337,7 +330,10 @@ func (s *server) relayProxyToWgGeneric(clientAddr netip.AddrPort, natEntry *serv break } - _, _, err := natEntry.wgConn.WriteMsgUDPAddrPort(queuedPacket.wgPacket, nil, s.wgAddr) + packetBuf := *queuedPacket.bufp + wgPacket := packetBuf[queuedPacket.start : queuedPacket.start+queuedPacket.length] + + _, _, err := natEntry.wgConn.WriteMsgUDPAddrPort(wgPacket, nil, s.wgAddr) if err != nil { s.logger.Warn("Failed to write wgPacket to wgConn", zap.Stringer("service", s), @@ -397,7 +393,7 @@ func (s *server) relayWgToProxyGeneric(clientAddr netip.AddrPort, natEntry *serv continue } - swgpPacket, err := s.handler.EncryptZeroCopy(packetBuf, frontOverhead, n, natEntry.maxProxyPacketSize) + swgpPacketStart, swgpPacketLength, err := s.handler.EncryptZeroCopy(packetBuf, frontOverhead, n) if err != nil { s.logger.Warn("Failed to encrypt WireGuard packet", zap.Stringer("service", s), @@ -408,6 +404,7 @@ func (s *server) relayWgToProxyGeneric(clientAddr netip.AddrPort, natEntry *serv ) continue } + swgpPacket := packetBuf[swgpPacketStart : swgpPacketStart+swgpPacketLength] _, _, err = s.proxyConn.WriteMsgUDPAddrPort(swgpPacket, natEntry.clientOobCache, clientAddr) if err != nil { diff --git a/service/server_linux.go b/service/server_linux.go index aeee68f..c7e568a 100644 --- a/service/server_linux.go +++ b/service/server_linux.go @@ -21,7 +21,7 @@ func (s *server) getRelayProxyToWgFunc(disableSendmmsg bool) func(clientAddr net func (s *server) relayProxyToWgSendmmsg(clientAddr netip.AddrPort, natEntry *serverNatEntry) { name, namelen := conn.AddrPortToSockaddr(s.wgAddr) - dequeuedPackets := make([]serverQueuedPacket, 0, conn.UIO_MAXIOV) + dequeuedPackets := make([]queuedPacket, 0, conn.UIO_MAXIOV) iovec := make([]unix.Iovec, 0, conn.UIO_MAXIOV) msgvec := make([]conn.Mmsghdr, 0, conn.UIO_MAXIOV) @@ -29,7 +29,7 @@ func (s *server) relayProxyToWgSendmmsg(clientAddr netip.AddrPort, natEntry *ser // Dequeue packets and append to dequeuedPackets. var ( - dequeuedPacket serverQueuedPacket + dequeuedPacket queuedPacket ok bool ) @@ -61,8 +61,10 @@ func (s *server) relayProxyToWgSendmmsg(clientAddr netip.AddrPort, natEntry *ser // Add packets to iovec and msgvec. for i, packet := range dequeuedPackets { - iovec[i].Base = &packet.wgPacket[0] - iovec[i].SetLen(len(packet.wgPacket)) + packetBuf := *packet.bufp + + iovec[i].Base = &packetBuf[packet.start] + iovec[i].SetLen(packet.length) msgvec[i].Msghdr.Name = name msgvec[i].Msghdr.Namelen = namelen @@ -184,7 +186,7 @@ func (s *server) relayWgToProxySendmmsg(clientAddr netip.AddrPort, natEntry *ser } packetBuf := unsafe.Slice((*byte)(unsafe.Add(unsafe.Pointer(msg.Msghdr.Iov.Base), -frontOverhead)), natEntry.maxProxyPacketSize) - swgpPacket, err := s.handler.EncryptZeroCopy(packetBuf, frontOverhead, int(msg.Msglen), natEntry.maxProxyPacketSize) + swgpPacketStart, swgpPacketLength, err := s.handler.EncryptZeroCopy(packetBuf, frontOverhead, int(msg.Msglen)) if err != nil { s.logger.Warn("Failed to encrypt WireGuard packet", zap.Stringer("service", s), @@ -196,8 +198,8 @@ func (s *server) relayWgToProxySendmmsg(clientAddr netip.AddrPort, natEntry *ser continue } - smsgvec[ns].Msghdr.Iov.Base = &swgpPacket[0] - smsgvec[ns].Msghdr.Iov.SetLen(len(swgpPacket)) + smsgvec[ns].Msghdr.Iov.Base = &packetBuf[swgpPacketStart] + smsgvec[ns].Msghdr.Iov.SetLen(swgpPacketLength) if smsgControlLen > 0 { smsgvec[ns].Msghdr.Control = &smsgControl[0] smsgvec[ns].Msghdr.SetControllen(smsgControlLen) diff --git a/service/service.go b/service/service.go index f6a5f43..f1e450e 100644 --- a/service/service.go +++ b/service/service.go @@ -115,3 +115,10 @@ func getPacketHandlerForProxyMode(proxyMode string, proxyPSK []byte) (handler pa } return } + +// queuedPacket is the structure used by send channels to queue packets for sending. +type queuedPacket struct { + bufp *[]byte + start int + length int +}