Skip to content

Commit

Permalink
Replace juju/ratelimit with golang.org/x/time/rate
Browse files Browse the repository at this point in the history
- Manual, partial cherry-pick of dde8c33
  • Loading branch information
rod-hynes committed Jul 5, 2024
1 parent debc2b6 commit eea9ace
Show file tree
Hide file tree
Showing 16 changed files with 713 additions and 775 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ require (
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e
github.com/google/gopacket v1.1.19
github.com/grafov/m3u8 v0.0.0-20171211212457-6ab8f28ed427
github.com/juju/ratelimit v1.0.2
github.com/marusama/semaphore v0.0.0-20171214154724-565ffd8e868a
github.com/miekg/dns v1.1.44-0.20210804161652-ab67aa642300
github.com/mitchellh/panicwrap v0.0.0-20170106182340-fce601fe5557
Expand All @@ -47,6 +46,7 @@ require (
golang.org/x/sync v0.2.0
golang.org/x/sys v0.19.0
golang.org/x/term v0.19.0
golang.org/x/time v0.5.0
)

require (
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,6 @@ github.com/jsimonetti/rtnetlink v0.0.0-20210212075122-66c871082f2b/go.mod h1:8w9
github.com/jsimonetti/rtnetlink v0.0.0-20210525051524-4cc836578190/go.mod h1:NmKSdU4VGSiv1bMsdqNALI4RSvvjtz65tTMCnD05qLo=
github.com/jsimonetti/rtnetlink v0.0.0-20210721205614-4cc3c1489576 h1:dH/k0qzR1oouF25AoMwH6FXOr16zV4WZFcYnZGpqro0=
github.com/jsimonetti/rtnetlink v0.0.0-20210721205614-4cc3c1489576/go.mod h1:qdKhcKUxYn3/QvneOvPWXXMPqktEBHnCW98wUTA3rmA=
github.com/juju/ratelimit v1.0.2 h1:sRxmtRiajbvrcLQT7S+JbqU0ntsb9W2yhSdNN8tWfaI=
github.com/juju/ratelimit v1.0.2/go.mod h1:qapgC/Gy+xNh9UxzV13HGGl/6UXNN+ct+vwSgWNm/qk=
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 h1:iQTw/8FWTuc7uiaSepXwyf3o52HaUYcV+Tu66S3F5GA=
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0/go.mod h1:1NbS8ALrpOvjt0rHPNLyCIeMtbizbir8U//inJ+zuB8=
github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I=
Expand Down Expand Up @@ -339,6 +337,8 @@ golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
Expand Down
147 changes: 115 additions & 32 deletions psiphon/common/throttled.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
"time"

"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
"github.com/juju/ratelimit"
"golang.org/x/time/rate"
)

// RateLimits specify the rate limits for a ThrottledConn.
Expand Down Expand Up @@ -72,20 +72,28 @@ type ThrottledConn struct {
writeBytesPerSecond int64
closeAfterExhausted int32
readLock sync.Mutex
readRateLimiter *ratelimit.Bucket
readRateLimiter *rate.Limiter
readDelayTimer *time.Timer
writeLock sync.Mutex
writeRateLimiter *ratelimit.Bucket
writeRateLimiter *rate.Limiter
writeDelayTimer *time.Timer
isClosed int32
stopBroadcast chan struct{}
isStream bool
net.Conn
}

// NewThrottledConn initializes a new ThrottledConn.
func NewThrottledConn(conn net.Conn, limits RateLimits) *ThrottledConn {
//
// Set isStreamConn to true when conn is stream-oriented, such as TCP, and
// false when the conn is packet-oriented, such as UDP. When conn is a
// stream, reads and writes may be split to accomodate rate limits.
func NewThrottledConn(
conn net.Conn, isStream bool, limits RateLimits) *ThrottledConn {

throttledConn := &ThrottledConn{
Conn: conn,
isStream: isStream,
stopBroadcast: make(chan struct{}),
}
throttledConn.SetLimits(limits)
Expand Down Expand Up @@ -137,10 +145,8 @@ func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
conn.readLock.Lock()
defer conn.readLock.Unlock()

select {
case <-conn.stopBroadcast:
if atomic.LoadInt32(&conn.isClosed) == 1 {
return 0, errors.TraceNew("throttled conn closed")
default:
}

// Use the base conn until the unthrottled count is
Expand All @@ -158,34 +164,68 @@ func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
return 0, errors.TraceNew("throttled conn exhausted")
}

rate := atomic.SwapInt64(&conn.readBytesPerSecond, -1)
readRate := atomic.SwapInt64(&conn.readBytesPerSecond, -1)

if rate != -1 {
if readRate != -1 {
// SetLimits has been called and a new rate limiter
// must be initialized. When no limit is specified,
// the reader/writer is simply the base conn.
// No state is retained from the previous rate limiter,
// so a pending I/O throttle sleep may be skipped when
// the old and new rate are similar.
if rate == 0 {
if readRate == 0 {
conn.readRateLimiter = nil
} else {
conn.readRateLimiter =
ratelimit.NewBucketWithRate(float64(rate), rate)
rate.NewLimiter(rate.Limit(readRate), int(readRate))
}
}

// The number of bytes read cannot exceed the rate limiter burst size,
// which is enforced by rate.Limiter.ReserveN. Reduce any read buffer
// size to be at most the burst size.
//
// Read should still return as soon as read bytes are available; and the
// number of bytes that will be received is unknown; so there is no loop
// here to read more bytes. Reducing the read buffer size minimizes
// latency for the up-to-burst-size bytes read, whereas allowing a full
// read followed by multiple ReserveN calls and sleeps would increase
// latency.
//
// In practise, with Psiphon tunnels, throttling is not applied until
// after the Psiphon API handshake, so read buffer reductions won't
// impact early obfuscation traffic shaping; and reads are on the order
// of one SSH "packet", up to 32K, unlikely to be split for all but the
// most restrictive of rate limits.

if conn.readRateLimiter != nil {
burst := conn.readRateLimiter.Burst()
if len(buffer) > burst {
if !conn.isStream {
return 0, errors.TraceNew("non-stream read buffer exceeds burst")
}
buffer = buffer[:burst]
}
}

n, err := conn.Conn.Read(buffer)

// Sleep to enforce the rate limit. This is the same logic as implemented in
// ratelimit.Reader, but using a timer and a close signal instead of an
// uninterruptible time.Sleep.
//
// The readDelayTimer is always expired/stopped and drained after this code
// block and is ready to be Reset on the next call.
if n > 0 && conn.readRateLimiter != nil {

// While rate.Limiter.WaitN would be simpler to use, internally Wait
// creates a new timer for every call which must sleep, which is
// expected to be most calls. Instead, call ReserveN to get the sleep
// time and reuse one timer without allocation.
//
// TODO: avoid allocation: ReserveN allocates a *Reservation; while
// the internal reserveN returns a struct, not a pointer.

if n >= 0 && conn.readRateLimiter != nil {
sleepDuration := conn.readRateLimiter.Take(int64(n))
reservation := conn.readRateLimiter.ReserveN(time.Now(), n)
if !reservation.OK() {
// This error is not expected, given the buffer size adjustment.
return 0, errors.TraceNew("burst size exceeded")
}
sleepDuration := reservation.Delay()
if sleepDuration > 0 {
if conn.readDelayTimer == nil {
conn.readDelayTimer = time.NewTimer(sleepDuration)
Expand All @@ -202,7 +242,8 @@ func (conn *ThrottledConn) Read(buffer []byte) (int, error) {
}
}

return n, errors.Trace(err)
// Don't wrap I/O errors
return n, err
}

func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
Expand All @@ -212,10 +253,8 @@ func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
conn.writeLock.Lock()
defer conn.writeLock.Unlock()

select {
case <-conn.stopBroadcast:
if atomic.LoadInt32(&conn.isClosed) == 1 {
return 0, errors.TraceNew("throttled conn closed")
default:
}

if atomic.LoadInt64(&conn.writeUnthrottledBytes) > 0 {
Expand All @@ -229,19 +268,58 @@ func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
return 0, errors.TraceNew("throttled conn exhausted")
}

rate := atomic.SwapInt64(&conn.writeBytesPerSecond, -1)
writeRate := atomic.SwapInt64(&conn.writeBytesPerSecond, -1)

if rate != -1 {
if rate == 0 {
if writeRate != -1 {
if writeRate == 0 {
conn.writeRateLimiter = nil
} else {
conn.writeRateLimiter =
ratelimit.NewBucketWithRate(float64(rate), rate)
rate.NewLimiter(rate.Limit(writeRate), int(writeRate))
}
}

if len(buffer) >= 0 && conn.writeRateLimiter != nil {
sleepDuration := conn.writeRateLimiter.Take(int64(len(buffer)))
if conn.writeRateLimiter == nil {
n, err := conn.Conn.Write(buffer)
// Don't wrap I/O errors
return n, err
}

// The number of bytes written cannot exceed the rate limiter burst size,
// which is enforced by rate.Limiter.ReserveN. Split writes to be at most
// the burst size.
//
// Splitting writes may have some effect on the shape of TCP packets sent
// on the network.
//
// In practise, with Psiphon tunnels, throttling is not applied until
// after the Psiphon API handshake, so write splits won't impact early
// obfuscation traffic shaping; and writes are on the order of one
// SSH "packet", up to 32K, unlikely to be split for all but the most
// restrictive of rate limits.

burst := conn.writeRateLimiter.Burst()
if !conn.isStream && len(buffer) > burst {
return 0, errors.TraceNew("non-stream write exceeds burst")
}
totalWritten := 0
for i := 0; i < len(buffer); i += burst {

j := i + burst
if j > len(buffer) {
j = len(buffer)
}
b := buffer[i:j]

// See comment in Read regarding rate.Limiter.ReserveN vs.
// rate.Limiter.WaitN.

reservation := conn.writeRateLimiter.ReserveN(time.Now(), len(b))
if !reservation.OK() {
// This error is not expected, given the write split adjustments.
return 0, errors.TraceNew("burst size exceeded")
}
sleepDuration := reservation.Delay()
if sleepDuration > 0 {
if conn.writeDelayTimer == nil {
conn.writeDelayTimer = time.NewTimer(sleepDuration)
Expand All @@ -256,11 +334,16 @@ func (conn *ThrottledConn) Write(buffer []byte) (int, error) {
}
}
}
}

n, err := conn.Conn.Write(buffer)
n, err := conn.Conn.Write(b)
totalWritten += n
if err != nil {
// Don't wrap I/O errors
return totalWritten, err
}
}

return n, errors.Trace(err)
return totalWritten, nil
}

func (conn *ThrottledConn) Close() error {
Expand Down
38 changes: 28 additions & 10 deletions psiphon/common/throttled_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ package common
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"math"
"net"
Expand Down Expand Up @@ -113,7 +114,7 @@ func runRateLimitsTest(t *testing.T, rateLimits RateLimits) {
if err != nil {
return conn, err
}
return NewThrottledConn(conn, rateLimits), nil
return NewThrottledConn(conn, true, rateLimits), nil
}

client := &http.Client{
Expand Down Expand Up @@ -204,27 +205,27 @@ func TestThrottledConnClose(t *testing.T) {
n := 4
b := make([]byte, n+1)

throttledConn := NewThrottledConn(&testConn{}, rateLimits)
throttledConn := NewThrottledConn(&testConn{}, true, rateLimits)

now := time.Now()
_, err := throttledConn.Read(b)
_, err := io.ReadFull(throttledConn, b)
elapsed := time.Since(now)
if err != nil || elapsed < time.Duration(n)*time.Second {
t.Errorf("unexpected interrupted read: %s, %v", elapsed, err)
}

now = time.Now()
go func() {
go func(conn net.Conn) {
time.Sleep(500 * time.Millisecond)
throttledConn.Close()
}()
conn.Close()
}(throttledConn)
_, err = throttledConn.Read(b)
elapsed = time.Since(now)
if elapsed > 1*time.Second {
t.Errorf("unexpected uninterrupted read: %s, %v", elapsed, err)
}

throttledConn = NewThrottledConn(&testConn{}, rateLimits)
throttledConn = NewThrottledConn(&testConn{}, true, rateLimits)

now = time.Now()
_, err = throttledConn.Write(b)
Expand All @@ -234,17 +235,34 @@ func TestThrottledConnClose(t *testing.T) {
}

now = time.Now()
go func() {
go func(conn net.Conn) {
time.Sleep(500 * time.Millisecond)
throttledConn.Close()
}()
conn.Close()
}(throttledConn)
_, err = throttledConn.Write(b)
elapsed = time.Since(now)
if elapsed > 1*time.Second {
t.Errorf("unexpected uninterrupted write: %s, %v", elapsed, err)
}
}

func TestNonStreamThrottledConn(t *testing.T) {

MTU := int64(1500)

rateLimits := RateLimits{
ReadBytesPerSecond: MTU - 1,
WriteBytesPerSecond: MTU - 1,
}

throttledConn := NewThrottledConn(&testConn{}, false, rateLimits)

_, err := throttledConn.Write(make([]byte, MTU))
if err == nil {
t.Errorf("unexpected split write")
}
}

type testConn struct {
}

Expand Down
Loading

0 comments on commit eea9ace

Please sign in to comment.