diff --git a/config.go b/config.go index 2e2695c..e442788 100644 --- a/config.go +++ b/config.go @@ -46,7 +46,7 @@ type Config struct { NoDelay bool `json:"nodelay"` // (optional) client-side keep alive interval in seconds, default to 25 (every 25s) KeepAlive int `json:"keepalive"` - // (optional) server-side keep alive interval in seconds, default to 0 (disabled) + // (optional) server-side keep alive interval in seconds, default to 300 (every 5min) ServerKeepAlive int `json:"serverkeepalive"` // (optional) soft limit of concurrent unauthenticated connections, default to 10 StartupLimitStart int `json:"startuplimitstart"` @@ -56,6 +56,8 @@ type Config struct { StartupLimitFull int `json:"startuplimitfull"` // (optional) max concurrent streams, default to 4096 MaxConn int `json:"maxconn"` + // (optional) max concurrent sessions, default to 128 + MaxSessions int `json:"maxsessions"` // (optional) mux accept backlog, default to 16, you may not want to change this AcceptBacklog int `json:"backlog"` // (optional) stream window size in bytes, default to 256KiB, increase this on long fat networks @@ -79,6 +81,7 @@ var DefaultConfig = Config{ StartupLimitRate: 30, StartupLimitFull: 60, MaxConn: 4096, + MaxSessions: 128, AcceptBacklog: 16, StreamWindow: 256 * 1024, // 256 KiB RequestTimeout: 30, diff --git a/handler.go b/handler.go index b7a9faf..ede02df 100644 --- a/handler.go +++ b/handler.go @@ -9,6 +9,7 @@ import ( "github.com/hashicorp/yamux" "github.com/hexian000/tlswrapper/formats" + "github.com/hexian000/tlswrapper/hlistener" "github.com/hexian000/tlswrapper/meter" "github.com/hexian000/tlswrapper/proto" "github.com/hexian000/tlswrapper/slog" @@ -26,8 +27,11 @@ type TLSHandler struct { unauthorized atomic.Uint32 } -func (h *TLSHandler) Unauthorized() uint32 { - return h.unauthorized.Load() +func (h *TLSHandler) Stats() hlistener.ServerStats { + return hlistener.ServerStats{ + Sessions: uint32(h.s.NumSessions()), + HalfOpen: h.unauthorized.Load(), + } } func (h *TLSHandler) Serve(ctx context.Context, conn net.Conn) { diff --git a/hlistener/hlistener.go b/hlistener/hlistener.go index e3aefe6..93a5a91 100644 --- a/hlistener/hlistener.go +++ b/hlistener/hlistener.go @@ -6,10 +6,16 @@ import ( "sync/atomic" ) +type ServerStats struct { + Sessions uint32 + HalfOpen uint32 +} + type Config struct { - Start, Full uint32 - Rate float64 - Unauthorized func() uint32 + Start, Full uint32 + Rate float64 + MaxSessions uint32 + Stats func() ServerStats } type Listener struct { @@ -21,6 +27,20 @@ type Listener struct { } } +func (l *Listener) isLimited() bool { + stats := l.c.Stats() + if l.c.MaxSessions > 0 && stats.Sessions >= l.c.MaxSessions { + return true + } + if stats.HalfOpen >= l.c.Full { + return true + } + if stats.HalfOpen >= l.c.Start { + return rand.Float64() < l.c.Rate + } + return false +} + func (l *Listener) Accept() (net.Conn, error) { for { conn, err := l.l.Accept() @@ -28,16 +48,7 @@ func (l *Listener) Accept() (net.Conn, error) { return conn, err } l.stats.Accepted.Add(1) - n := l.c.Unauthorized() - refuse := false - if n >= l.c.Start { - if n >= l.c.Full { - refuse = true - } else { - refuse = rand.Float64() < l.c.Rate - } - } - if refuse { + if l.isLimited() { _ = conn.Close() continue } diff --git a/tunnel.go b/tunnel.go index fbff48c..7c0effe 100644 --- a/tunnel.go +++ b/tunnel.go @@ -47,10 +47,10 @@ func (t *Tunnel) Start() error { h := &TLSHandler{s: t.s, t: t} c := t.s.getConfig() t.l = hlistener.Wrap(l, &hlistener.Config{ - Start: uint32(c.StartupLimitStart), - Full: uint32(c.StartupLimitFull), - Rate: float64(c.StartupLimitRate) / 100.0, - Unauthorized: h.Unauthorized, + Start: uint32(c.StartupLimitStart), + Full: uint32(c.StartupLimitFull), + Rate: float64(c.StartupLimitRate) / 100.0, + Stats: h.Stats, }) l = t.l if err := t.s.g.Go(func() {