Skip to content

Commit

Permalink
🦭 Refactor handlers
Browse files Browse the repository at this point in the history
- Zero-overhead handler: Conceal all-zero MAC2 by using XChaCha20-Poly1305 to encrypt part of handshake packets. This is a breaking change.
- Handler interface: Drop unnecessary maxPacketLen.
- Better handler tests.
  • Loading branch information
database64128 committed May 28, 2022
1 parent 35ed532 commit 1ddb677
Show file tree
Hide file tree
Showing 12 changed files with 316 additions and 156 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

### 1. Zero overhead

Simply AES encrypt the first 16 bytes of all packets. Handshake packets (message type 1, 2, 3) are also randomly padded to look like normal traffic.
The first 16 bytes of all packets are encrypted using an AES block cipher. The remainder of handshake packets (message type 1, 2, 3) are also randomly padded and encrypted using an XChaCha20-Poly1305 AEAD cipher to blend into normal traffic.

### 2. Paranoid

Expand Down
52 changes: 34 additions & 18 deletions packet/handler.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Package packet contains types and methods that transform WireGuard packets.
package packet

import "errors"

const (
WireGuardMessageTypeHandshakeInitiation = 1
WireGuardMessageTypeHandshakeResponse = 2
Expand All @@ -12,6 +14,27 @@ const (
WireGuardMessageLengthHandshakeCookieReply = 64
)

var (
ErrPacketSize = errors.New("packet is too big or too small to be processed")
ErrPayloadLength = errors.New("payload length field value is out of range")
)

type HandlerErr struct {
Err error
Message string
}

func (e *HandlerErr) Unwrap() error {
return e.Err
}

func (e *HandlerErr) Error() string {
if e.Message == "" {
return e.Err.Error()
}
return e.Message
}

// Handler encrypts WireGuard packets and decrypts swgp packets.
type Handler interface {
// FrontOverhead returns the headroom to reserve in buffer before payload.
Expand All @@ -20,25 +43,18 @@ type Handler interface {
// RearOverhead returns the headroom to reserve in buffer after payload.
RearOverhead() int

// EncryptZeroCopy encrypts a WireGuard packet and returns a swgp packet
// without copying or incurring any allocations.
//
// buf must have at least FrontOverhead() bytes before and RearOverhead() bytes
// after the WireGuard packet.
//
// In other words, start must not be less than FrontOverhead(),
// len(buf) - start - max(length, maxPacketLen) must not be less than RearOverhead().
// EncryptZeroCopy encrypts a WireGuard packet and returns a swgp packet without copying or incurring any allocations.
//
// length is allowed to be greater than maxPacketLen (IP fragmentation).
// The WireGuard packet starts at buf[wgPacketStart] and its length is specified by wgPacketLength.
// The returned swgp packet starts at buf[swgpPacketStart] and its length is specified by swgpPacketLength.
//
// maxPacketLen is the maximum payload length of a single unfragmented UDP packet.
//
// For IPv4, maxPacketLen = MTU - 20 (IPv4 header) - 8 (UDP header).
//
// For IPv6, maxPacketLen = MTU - 40 (IPv6 header) - 8 (UDP header).
EncryptZeroCopy(buf []byte, start, length, maxPacketLen int) (swgpPacket []byte, err error)
// buf must have at least FrontOverhead() bytes before and RearOverhead() bytes after the WireGuard packet.
// In other words, start must not be less than FrontOverhead(), len(buf) must not be less than start + length + RearOverhead().
EncryptZeroCopy(buf []byte, wgPacketStart, wgPacketLength int) (swgpPacketStart, swgpPacketLength int, err error)

// DecryptZeroCopy decrypts a swgp packet and returns a WireGuard packet
// without copying or incurring any allocations.
DecryptZeroCopy(swgpPacket []byte) (wgPacket []byte, err error)
// DecryptZeroCopy decrypts a swgp packet and returns a WireGuard packet without copying or incurring any allocations.
//
// The swgp packet starts at buf[swgpPacketStart] and its length is specified by swgpPacketLength.
// The returned WireGuard packet starts at buf[wgPacketStart] and its length is specified by wgPacketLength.
DecryptZeroCopy(buf []byte, swgpPacketStart, swgpPacketLength int) (wgPacketStart, wgPacketLength int, err error)
}
42 changes: 33 additions & 9 deletions packet/handler_test.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,36 @@
package packet

import "testing"
import (
"crypto/rand"
"errors"
"testing"
)

func testHandler(t *testing.T, msgType byte, length int, h Handler, verifyFunc func(t *testing.T, wgPacket, swgpPacket, decryptedWgPacket []byte)) {
// Reserve a minimum of 24 bytes to ensure code can handle extra headroom.
frontHeadroom, rearHeadroom := 24, 24
func testHandler(
t *testing.T,
msgType byte,
length, extraFrontHeadroom, extraRearHeadroom int,
h Handler,
expectedEncryptErr, expectedDecryptErr error,
verifyFunc func(t *testing.T, wgPacket, swgpPacket, decryptedWgPacket []byte),
) {
var frontHeadroom, rearHeadroom int
frontOverhead, rearOverhead := h.FrontOverhead(), h.RearOverhead()
if frontOverhead > frontHeadroom {
frontHeadroom = frontOverhead
}
frontHeadroom += extraFrontHeadroom
if rearOverhead > rearHeadroom {
rearHeadroom = rearOverhead
}
rearHeadroom += extraRearHeadroom

// Prepare buffer.
buf := make([]byte, frontHeadroom+length+rearHeadroom)
_, err := rand.Read(buf)
if err != nil {
t.Fatal(err)
}
buf[frontHeadroom] = msgType

var wgPacket, swgpPacket, decryptedWgPacket []byte
Expand All @@ -22,19 +39,26 @@ func testHandler(t *testing.T, msgType byte, length int, h Handler, verifyFunc f
wgPacket = append(wgPacket, buf[frontHeadroom:frontHeadroom+length]...)

// Encrypt.
pkt, err := h.EncryptZeroCopy(buf, frontHeadroom, length, length+rearHeadroom-rearOverhead)
swgpPacketStart, swgpPacketLength, err := h.EncryptZeroCopy(buf, frontHeadroom, length)
if !errors.Is(err, expectedEncryptErr) {
t.Fatalf("Expected encryption error: %s\nGot: %s", expectedEncryptErr, err)
}
if err != nil {
t.Fatal(err)
return
}

// Save encrypted packet.
swgpPacket = append(swgpPacket, pkt...)
swgpPacket = append(swgpPacket, buf[swgpPacketStart:swgpPacketStart+swgpPacketLength]...)

// Decrypt.
decryptedWgPacket, err = h.DecryptZeroCopy(pkt)
wgPacketStart, wgPacketLength, err := h.DecryptZeroCopy(buf, swgpPacketStart, swgpPacketLength)
if !errors.Is(err, expectedDecryptErr) {
t.Fatalf("Expected decryption error: %s\nGot: %s", expectedDecryptErr, err)
}
if err != nil {
t.Fatal(err)
return
}
decryptedWgPacket = buf[wgPacketStart : wgPacketStart+wgPacketLength]

verifyFunc(t, wgPacket, swgpPacket, decryptedWgPacket)
}
60 changes: 36 additions & 24 deletions packet/paranoid.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ import (
// All packets, irrespective of message type, are padded up to the maximum packet length
// to hide any possible characteristics.
//
// swgpPacket := 24B nonce + AEAD_Seal(2B payload length + payload + padding)
// swgpPacket := 24B nonce + AEAD_Seal(u16be payload length + payload + padding)
//
// paranoidHandler implements the Handler interface.
type paranoidHandler struct {
aead cipher.AEAD
}
Expand Down Expand Up @@ -44,58 +46,68 @@ func (h *paranoidHandler) RearOverhead() int {
}

// EncryptZeroCopy implements the Handler EncryptZeroCopy method.
func (h *paranoidHandler) EncryptZeroCopy(buf []byte, start, length, maxPacketLen int) (swgpPacket []byte, err error) {
if length > math.MaxUint16 {
return nil, fmt.Errorf("payload too long: %d is greater than 65535", length)
func (h *paranoidHandler) EncryptZeroCopy(buf []byte, wgPacketStart, wgPacketLength int) (swgpPacketStart, swgpPacketLength int, err error) {
if wgPacketLength > math.MaxUint16 {
err = &HandlerErr{ErrPacketSize, fmt.Sprintf("wg packet (length %d) is too large (greater than %d)", wgPacketLength, math.MaxUint16)}
return
}

// Determine padding length.
rearHeadroom := len(buf) - wgPacketStart - wgPacketLength
paddingHeadroom := rearHeadroom - chacha20poly1305.Overhead
var paddingLen int

if maxPaddingLen := maxPacketLen - h.FrontOverhead() - length - h.RearOverhead(); maxPaddingLen > 0 {
paddingLen = mrand.Intn(maxPaddingLen + 1)
if paddingHeadroom > 0 {
paddingLen = mrand.Intn(paddingHeadroom) + 1
}

nonce := buf[start-chacha20poly1305.NonceSizeX-2 : start-2]
payloadLength := buf[start-2 : start]
plaintext := buf[start-2 : start+length+paddingLen]
// Calculate offsets.
swgpPacketStart = wgPacketStart - 2 - chacha20poly1305.NonceSizeX
swgpPacketLength = chacha20poly1305.NonceSizeX + 2 + wgPacketLength + paddingLen + chacha20poly1305.Overhead

nonce := buf[swgpPacketStart : wgPacketStart-2]
payloadLength := buf[wgPacketStart-2 : wgPacketStart]
plaintext := buf[wgPacketStart-2 : wgPacketStart+wgPacketLength+paddingLen]

// Write random nonce.
_, err = rand.Read(nonce)
if err != nil {
return nil, err
return
}

// Write payload length.
binary.BigEndian.PutUint16(payloadLength, uint16(length))
binary.BigEndian.PutUint16(payloadLength, uint16(wgPacketLength))

// AEAD seal.
swgpPacket = h.aead.Seal(nonce, nonce, plaintext, nil)
h.aead.Seal(nonce, nonce, plaintext, nil)

return
}

// DecryptZeroCopy implements the Handler DecryptZeroCopy method.
func (h *paranoidHandler) DecryptZeroCopy(swgpPacket []byte) (wgPacket []byte, err error) {
if len(swgpPacket) < chacha20poly1305.NonceSizeX {
return nil, fmt.Errorf("bad swgpPacket length: %d", len(swgpPacket))
func (h *paranoidHandler) DecryptZeroCopy(buf []byte, swgpPacketStart, swgpPacketLength int) (wgPacketStart, wgPacketLength int, err error) {
if swgpPacketLength < chacha20poly1305.NonceSizeX+2+1+chacha20poly1305.Overhead {
err = &HandlerErr{ErrPacketSize, fmt.Sprintf("swgp packet (length %d) is too short", swgpPacketLength)}
return
}

nonce := swgpPacket[:chacha20poly1305.NonceSizeX]
ciphertext := swgpPacket[chacha20poly1305.NonceSizeX:]
nonce := buf[swgpPacketStart : swgpPacketStart+chacha20poly1305.NonceSizeX]
ciphertext := buf[swgpPacketStart+chacha20poly1305.NonceSizeX : swgpPacketStart+swgpPacketLength]

// AEAD open.
plaintext, err := h.aead.Open(ciphertext[:0], nonce, ciphertext, nil)
if err != nil {
return nil, err
return
}

// Read and validate payload length.
payloadLength := plaintext[:2]
length := int(binary.BigEndian.Uint16(payloadLength))
if 2+length > len(plaintext) {
return nil, fmt.Errorf("payload length %d is greater than plaintext length %d", length, len(plaintext))
payloadLengthBuf := plaintext[:2]
payloadLength := int(binary.BigEndian.Uint16(payloadLengthBuf))
if payloadLength > len(plaintext)-2 {
err = &HandlerErr{ErrPayloadLength, fmt.Sprintf("payload length field value %d is out of range", payloadLength)}
return
}

wgPacket = plaintext[2 : 2+length]
wgPacketStart = swgpPacketStart + chacha20poly1305.NonceSizeX + 2
wgPacketLength = payloadLength
return
}
25 changes: 8 additions & 17 deletions packet/paranoid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,21 @@ func testNewParanoidHandler(t *testing.T) Handler {

func testParanoidVerifyPacket(t *testing.T, wgPacket, swgpPacket, decryptedWgPacket []byte) {
if len(swgpPacket) < chacha20poly1305.NonceSizeX+2+len(wgPacket)+chacha20poly1305.Overhead {
t.Error("bad swgpPacket length")
t.Error("Bad swgpPacket length.")
}

if !bytes.Equal(wgPacket, decryptedWgPacket) {
t.Error("Decrypted packet is different from original packet.")
}
}

func TestParanoidHandleWireGuardHandshakeInitiationPacket(t *testing.T) {
func TestParanoidHandlePacket(t *testing.T) {
h := testNewParanoidHandler(t)
testHandler(t, WireGuardMessageTypeHandshakeInitiation, WireGuardMessageLengthHandshakeInitiation, h, testParanoidVerifyPacket)
}

func TestParanoidHandleWireGuardHandshakeResponsePacket(t *testing.T) {
h := testNewParanoidHandler(t)
testHandler(t, WireGuardMessageTypeHandshakeResponse, WireGuardMessageLengthHandshakeResponse, h, testParanoidVerifyPacket)
}

func TestParanoidHandleWireGuardHandshakeCookieReplyPacket(t *testing.T) {
h := testNewParanoidHandler(t)
testHandler(t, WireGuardMessageTypeHandshakeCookieReply, WireGuardMessageLengthHandshakeCookieReply, h, testParanoidVerifyPacket)
}

func TestParanoidHandleWireGuardDataPacket(t *testing.T) {
h := testNewParanoidHandler(t)
testHandler(t, WireGuardMessageTypeData, 1452, h, testParanoidVerifyPacket)
for i := 1; i < 128; i++ {
testHandler(t, WireGuardMessageTypeHandshakeInitiation, i, 0, 0, h, nil, nil, testParanoidVerifyPacket)
testHandler(t, WireGuardMessageTypeHandshakeResponse, i, 0, 0, h, nil, nil, testParanoidVerifyPacket)
testHandler(t, WireGuardMessageTypeHandshakeCookieReply, i, 0, 0, h, nil, nil, testParanoidVerifyPacket)
testHandler(t, WireGuardMessageTypeData, i, 0, 0, h, nil, nil, testParanoidVerifyPacket)
}
}
Loading

0 comments on commit 1ddb677

Please sign in to comment.