Skip to content

Commit

Permalink
fix(netstack): TCPConn must not implement net.PacketConn (#40)
Browse files Browse the repository at this point in the history
Otherwise, miekg/dns cannot distinguish between TCP and UDP.

Fix by using composition rather than embedding.
  • Loading branch information
bassosimone authored Nov 27, 2024
1 parent 376adf5 commit 2270ccf
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 16 deletions.
51 changes: 39 additions & 12 deletions netsim/netstack/tcpconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,16 @@ import (
type TCPConn struct {
buf bytes.Buffer
initonce sync.Once
*Port
rlock sync.Mutex
p *Port
rlock sync.Mutex
}

// NewTCPConn creates a new TCP connection.
func NewTCPConn(p *Port) *TCPConn {
return &TCPConn{
buf: bytes.Buffer{},
initonce: sync.Once{},
Port: p,
p: p,
rlock: sync.Mutex{},
}
}
Expand All @@ -41,7 +41,7 @@ func (c *TCPConn) Accept() (err error) {
c.initonce.Do(func() {
c.SetDeadline(time.Now().Add(time.Second))
defer c.SetDeadline(time.Time{})
err = c.Port.WritePacket(nil, TCPFlagSYN|TCPFlagACK, netip.AddrPort{})
err = c.p.WritePacket(nil, TCPFlagSYN|TCPFlagACK, netip.AddrPort{})
})
return
}
Expand All @@ -53,12 +53,12 @@ func (c *TCPConn) Connect(ctx context.Context) (err error) {
c.SetDeadline(d)
defer c.SetDeadline(time.Time{})
}
err = c.Port.WritePacket(nil, TCPFlagSYN, netip.AddrPort{})
err = c.p.WritePacket(nil, TCPFlagSYN, netip.AddrPort{})
if err != nil {
return
}
var pkt *Packet
pkt, err = c.Port.ReadPacket()
pkt, err = c.p.ReadPacket()
if err != nil {
return
}
Expand All @@ -74,9 +74,6 @@ func (c *TCPConn) Connect(ctx context.Context) (err error) {
return
}

// Ensure [*TCPConn] implements [net.PacketConn].
var _ net.PacketConn = &TCPConn{}

// Ensure [*TCPConn] implements [net.Conn].
var _ net.Conn = &TCPConn{}

Expand All @@ -94,7 +91,7 @@ func (c *TCPConn) Read(buf []byte) (int, error) {
}

// otherwise, attempt to read the next packet
pkt, err := c.Port.ReadPacket()
pkt, err := c.p.ReadPacket()
if err != nil {
return 0, err
}
Expand All @@ -119,6 +116,36 @@ func (c *TCPConn) Read(buf []byte) (int, error) {

// Close implements [net.Conn].
func (c *TCPConn) Close() error {
c.Port.WritePacket(nil, TCPFlagFIN, netip.AddrPort{})
return c.Port.Close()
c.p.WritePacket(nil, TCPFlagFIN, netip.AddrPort{})
return c.p.Close()
}

// LocalAddr implements [net.Conn].
func (c *TCPConn) LocalAddr() net.Addr {
return c.p.LocalAddr()
}

// RemoteAddr implements [net.Conn].
func (c *TCPConn) RemoteAddr() net.Addr {
return c.p.RemoteAddr()
}

// SetDeadline implements [net.Conn].
func (c *TCPConn) SetDeadline(t time.Time) error {
return c.p.SetDeadline(t)
}

// SetReadDeadline implements [net.Conn].
func (c *TCPConn) SetReadDeadline(t time.Time) error {
return c.p.SetReadDeadline(t)
}

// SetWriteDeadline implements [net.Conn].
func (c *TCPConn) SetWriteDeadline(t time.Time) error {
return c.p.SetWriteDeadline(t)
}

// Write implements [net.Conn].
func (c *TCPConn) Write(data []byte) (int, error) {
return c.p.Write(data)
}
56 changes: 52 additions & 4 deletions netsim/netstack/udpconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,76 @@

package netstack

import "net"
import (
"net"
"time"
)

// UDPConn is a UDP connection.
//
// The zero value is invalid; construct using [NewUDPConn].
type UDPConn struct {
*Port
p *Port
}

// NewUDPConn creates a new UDP connection.
func NewUDPConn(p *Port) *UDPConn {
return &UDPConn{Port: p}
return &UDPConn{p: p}
}

// Ensure [*UDPConn] implements [net.PacketConn].
var _ net.PacketConn = &UDPConn{}

// Close implements [net.PacketConn].
func (c *UDPConn) Close() error {
return c.p.Close()
}

// LocalAddr implements [net.PacketConn].
func (c *UDPConn) LocalAddr() net.Addr {
return c.p.LocalAddr()
}

// ReadFrom implements [net.PacketConn].
func (c *UDPConn) ReadFrom(buf []byte) (int, net.Addr, error) {
return c.p.ReadFrom(buf)
}

// SetDeadline implements [net.PacketConn].
func (c *UDPConn) SetDeadline(t time.Time) error {
return c.p.SetDeadline(t)
}

// SetReadDeadline implements [net.PacketConn].
func (c *UDPConn) SetReadDeadline(t time.Time) error {
return c.p.SetReadDeadline(t)
}

// SetWriteDeadline implements net.PacketConn.
func (c *UDPConn) SetWriteDeadline(t time.Time) error {
return c.p.SetWriteDeadline(t)
}

// WriteTo implements net.PacketConn.
func (c *UDPConn) WriteTo(pkt []byte, addr net.Addr) (int, error) {
return c.p.WriteTo(pkt, addr)
}

// Ensure [*UDPConn] implements [net.Conn].
var _ net.Conn = &UDPConn{}

// Read implements [net.Conn].
func (c *UDPConn) Read(buf []byte) (int, error) {
count, _, err := c.Port.ReadFrom(buf)
count, _, err := c.p.ReadFrom(buf)
return count, err
}

// RemoteAddr implements [net.Conn].
func (c *UDPConn) RemoteAddr() net.Addr {
return c.p.RemoteAddr()
}

// Write implements [net.Conn].
func (c *UDPConn) Write(data []byte) (int, error) {
return c.p.Write(data)
}

0 comments on commit 2270ccf

Please sign in to comment.