diff --git a/MobileLibrary/Android/PsiphonTunnel/ca_psiphon_psiphontunnel_backup_rules.xml b/MobileLibrary/Android/PsiphonTunnel/ca_psiphon_psiphontunnel_backup_rules.xml
index 29077a85d..f766196da 100644
--- a/MobileLibrary/Android/PsiphonTunnel/ca_psiphon_psiphontunnel_backup_rules.xml
+++ b/MobileLibrary/Android/PsiphonTunnel/ca_psiphon_psiphontunnel_backup_rules.xml
@@ -1,4 +1,4 @@
-
+
diff --git a/psiphon/dialParameters.go b/psiphon/dialParameters.go
index 8bdc07171..3fa9d72ad 100644
--- a/psiphon/dialParameters.go
+++ b/psiphon/dialParameters.go
@@ -773,7 +773,7 @@ func MakeDialParameters(
case protocol.TUNNEL_PROTOCOL_MARIONETTE_OBFUSCATED_SSH:
- // Note: port comes from marionnete "format"
+ // Note: port comes from marionette "format"
dialParams.DirectDialAddress = serverEntry.IpAddress
default:
diff --git a/psiphon/server/listener.go b/psiphon/server/listener.go
index 03021c2ba..71a8ff31e 100644
--- a/psiphon/server/listener.go
+++ b/psiphon/server/listener.go
@@ -25,14 +25,15 @@ import (
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
- "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
)
// TacticsListener wraps a net.Listener and applies server-side implementation
// of certain tactics parameters to accepted connections. Tactics filtering is
-// limited to GeoIP attributes as the client has not yet sent API paramaters.
+// limited to GeoIP attributes as the client has not yet sent API parameters.
+// GeoIP uses the immediate peer IP, and so TacticsListener is suitable only
+// for tactics that do not require the original client GeoIP when fronted.
type TacticsListener struct {
net.Listener
support *SupportServices
@@ -77,11 +78,14 @@ func (listener *TacticsListener) accept() (net.Conn, error) {
return nil, err
}
+ // Limitation: RemoteAddr is the immediate peer IP, which is not the original
+ // client IP in the case of fronting.
geoIPData := listener.geoIPLookup(
common.IPAddressFromAddr(conn.RemoteAddr()))
p, err := listener.support.ServerTacticsParametersCache.Get(geoIPData)
if err != nil {
+ conn.Close()
return nil, errors.Trace(err)
}
@@ -90,34 +94,6 @@ func (listener *TacticsListener) accept() (net.Conn, error) {
return conn, nil
}
- // Disconnect immediately if the clients tactics restricts usage of the
- // fronting provider ID. The probability may be used to influence usage of a
- // given fronting provider; but when only that provider works for a given
- // client, and the probability is less than 1.0, the client can retry until
- // it gets a successful coin flip.
- //
- // Clients will also skip candidates with restricted fronting provider IDs.
- // The client-side probability, RestrictFrontingProviderIDsClientProbability,
- // is applied independently of the server-side coin flip here.
- //
- //
- // At this stage, GeoIP tactics filters are active, but handshake API
- // parameters are not.
- //
- // See the comment in server.LoadConfig regarding fronting provider ID
- // limitations.
-
- if protocol.TunnelProtocolUsesFrontedMeek(listener.tunnelProtocol) &&
- common.Contains(
- p.Strings(parameters.RestrictFrontingProviderIDs),
- listener.support.Config.GetFrontingProviderID()) {
- if p.WeightedCoinFlip(
- parameters.RestrictFrontingProviderIDsServerProbability) {
- conn.Close()
- return nil, nil
- }
- }
-
// Server-side fragmentation may be synchronized with client-side in two ways.
//
// In the OSSH case, replay is always activated and it is seeded using the
diff --git a/psiphon/server/listener_test.go b/psiphon/server/listener_test.go
index 0d5a05f84..2447a78c3 100644
--- a/psiphon/server/listener_test.go
+++ b/psiphon/server/listener_test.go
@@ -28,7 +28,6 @@ import (
"time"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor"
- "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
)
@@ -37,8 +36,6 @@ func TestListener(t *testing.T) {
tunnelProtocol := protocol.TUNNEL_PROTOCOL_FRONTED_MEEK
- frontingProviderID := prng.HexString(8)
-
tacticsConfigJSONFormat := `
{
"RequestPublicKey" : "%s",
@@ -65,19 +62,6 @@ func TestListener(t *testing.T) {
"FragmentorDownstreamMaxWriteBytes" : 1
}
}
- },
- {
- "Filter" : {
- "Regions": ["R3"],
- "ISPs": ["I3"],
- "Cities": ["C3"]
- },
- "Tactics" : {
- "Parameters" : {
- "RestrictFrontingProviderIDs" : ["%s"],
- "RestrictFrontingProviderIDsServerProbability" : 1.0
- }
- }
}
]
}
@@ -92,7 +76,7 @@ func TestListener(t *testing.T) {
tacticsConfigJSON := fmt.Sprintf(
tacticsConfigJSONFormat,
tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey,
- tunnelProtocol, frontingProviderID)
+ tunnelProtocol)
tacticsConfigFilename := filepath.Join(testDataDirName, "tactics_config.json")
@@ -122,12 +106,6 @@ func TestListener(t *testing.T) {
listenerUnfragmentedGeoIPWrongCity := func(string) GeoIPData {
return GeoIPData{Country: "R1", ISP: "I1", City: "C2"}
}
- listenerRestrictedFrontingProviderIDGeoIP := func(string) GeoIPData {
- return GeoIPData{Country: "R3", ISP: "I3", City: "C3"}
- }
- listenerUnrestrictedFrontingProviderIDWrongRegion := func(string) GeoIPData {
- return GeoIPData{Country: "R2", ISP: "I3", City: "C3"}
- }
listenerTestCases := []struct {
description string
@@ -159,18 +137,6 @@ func TestListener(t *testing.T) {
false,
true,
},
- {
- "restricted",
- listenerRestrictedFrontingProviderIDGeoIP,
- false,
- false,
- },
- {
- "unrestricted-region",
- listenerUnrestrictedFrontingProviderIDWrongRegion,
- false,
- true,
- },
}
for _, testCase := range listenerTestCases {
@@ -182,7 +148,7 @@ func TestListener(t *testing.T) {
}
support := &SupportServices{
- Config: &Config{frontingProviderID: frontingProviderID},
+ Config: &Config{},
TacticsServer: tacticsServer,
}
support.ReplayCache = NewReplayCache(support)
diff --git a/psiphon/server/meek.go b/psiphon/server/meek.go
index 0042b35f9..8663fdd6e 100644
--- a/psiphon/server/meek.go
+++ b/psiphon/server/meek.go
@@ -43,6 +43,7 @@ import (
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/obfuscator"
+ "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/values"
@@ -344,8 +345,9 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
session,
underlyingConn,
endPoint,
- clientIP,
+ endPointGeoIPData,
err := server.getSessionOrEndpoint(request, meekCookie)
+
if err != nil {
// Debug since session cookie errors commonly occur during
// normal operation.
@@ -359,9 +361,8 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request
// Endpoint mode. Currently, this means it's handled by the tactics
// request handler.
- geoIPData := server.support.GeoIPService.Lookup(clientIP)
handled := server.support.TacticsServer.HandleEndPoint(
- endPoint, common.GeoIPData(geoIPData), responseWriter, request)
+ endPoint, common.GeoIPData(*endPointGeoIPData), responseWriter, request)
if !handled {
log.WithTraceFields(LogFields{"endPoint": endPoint}).Info("unhandled endpoint")
common.TerminateHTTPConnection(responseWriter, request)
@@ -587,7 +588,7 @@ func checkRangeHeader(request *http.Request) (int, bool) {
// mode; or the endpoint is returned when the meek cookie indicates endpoint
// mode.
func (server *MeekServer) getSessionOrEndpoint(
- request *http.Request, meekCookie *http.Cookie) (string, *meekSession, net.Conn, string, string, error) {
+ request *http.Request, meekCookie *http.Cookie) (string, *meekSession, net.Conn, string, *GeoIPData, error) {
underlyingConn := request.Context().Value(meekNetConnContextKey).(net.Conn)
@@ -601,7 +602,7 @@ func (server *MeekServer) getSessionOrEndpoint(
// TODO: can multiple http client connections using same session cookie
// cause race conditions on session struct?
session.touch()
- return existingSessionID, session, underlyingConn, "", "", nil
+ return existingSessionID, session, underlyingConn, "", nil, nil
}
// Determine the client remote address, which is used for geolocation
@@ -610,6 +611,8 @@ func (server *MeekServer) getSessionOrEndpoint(
// headers such as X-Forwarded-For.
clientIP := strings.Split(request.RemoteAddr, ":")[0]
+ usedProxyForwardedForHeader := false
+ var geoIPData GeoIPData
if len(server.support.Config.MeekProxyForwardedForHeaders) > 0 {
for _, header := range server.support.Config.MeekProxyForwardedForHeaders {
@@ -619,23 +622,29 @@ func (server *MeekServer) getSessionOrEndpoint(
// list of IPs (each proxy in a chain). The first IP should be
// the client IP.
proxyClientIP := strings.Split(value, ",")[0]
- if net.ParseIP(proxyClientIP) != nil &&
- server.support.GeoIPService.Lookup(
- proxyClientIP).Country != GEOIP_UNKNOWN_VALUE {
-
- clientIP = proxyClientIP
- break
+ if net.ParseIP(proxyClientIP) != nil {
+ proxyClientGeoIPData := server.support.GeoIPService.Lookup(proxyClientIP)
+ if proxyClientGeoIPData.Country != GEOIP_UNKNOWN_VALUE {
+ usedProxyForwardedForHeader = true
+ clientIP = proxyClientIP
+ geoIPData = proxyClientGeoIPData
+ break
+ }
}
}
}
}
+ if !usedProxyForwardedForHeader {
+ geoIPData = server.support.GeoIPService.Lookup(clientIP)
+ }
+
// The session is new (or expired). Treat the cookie value as a new meek
// cookie, extract the payload, and create a new session.
payloadJSON, err := server.getMeekCookiePayload(clientIP, meekCookie.Value)
if err != nil {
- return "", nil, nil, "", "", errors.Trace(err)
+ return "", nil, nil, "", nil, errors.Trace(err)
}
// Note: this meek server ignores legacy values PsiphonClientSessionId
@@ -644,7 +653,7 @@ func (server *MeekServer) getSessionOrEndpoint(
err = json.Unmarshal(payloadJSON, &clientSessionData)
if err != nil {
- return "", nil, nil, "", "", errors.Trace(err)
+ return "", nil, nil, "", nil, errors.Trace(err)
}
tunnelProtocol := server.listenerTunnelProtocol
@@ -656,7 +665,7 @@ func (server *MeekServer) getSessionOrEndpoint(
server.listenerTunnelProtocol,
server.support.Config.GetRunningProtocols()) {
- return "", nil, nil, "", "", errors.Tracef(
+ return "", nil, nil, "", nil, errors.Tracef(
"invalid client tunnel protocol: %s", clientSessionData.ClientTunnelProtocol)
}
@@ -669,8 +678,8 @@ func (server *MeekServer) getSessionOrEndpoint(
// rate limit is primarily intended to limit memory resource consumption and
// not the overhead incurred by cookie validation.
- if server.rateLimit(clientIP, tunnelProtocol) {
- return "", nil, nil, "", "", errors.TraceNew("rate limit exceeded")
+ if server.rateLimit(clientIP, geoIPData, tunnelProtocol) {
+ return "", nil, nil, "", nil, errors.TraceNew("rate limit exceeded")
}
// Handle endpoints before enforcing CheckEstablishTunnels.
@@ -678,7 +687,7 @@ func (server *MeekServer) getSessionOrEndpoint(
// handled by servers which would otherwise reject new tunnels.
if clientSessionData.EndPoint != "" {
- return "", nil, nil, clientSessionData.EndPoint, clientIP, nil
+ return "", nil, nil, clientSessionData.EndPoint, &geoIPData, nil
}
// Don't create new sessions when not establishing. A subsequent SSH handshake
@@ -686,7 +695,42 @@ func (server *MeekServer) getSessionOrEndpoint(
if server.support.TunnelServer != nil &&
!server.support.TunnelServer.CheckEstablishTunnels() {
- return "", nil, nil, "", "", errors.TraceNew("not establishing tunnels")
+ return "", nil, nil, "", nil, errors.TraceNew("not establishing tunnels")
+ }
+
+ // Disconnect immediately if the tactics for the client restricts usage of
+ // the fronting provider ID. The probability may be used to influence
+ // usage of a given fronting provider; but when only that provider works
+ // for a given client, and the probability is less than 1.0, the client
+ // can retry until it gets a successful coin flip.
+ //
+ // Clients will also skip candidates with restricted fronting provider IDs.
+ // The client-side probability, RestrictFrontingProviderIDsClientProbability,
+ // is applied independently of the server-side coin flip here.
+ //
+ // At this stage, GeoIP tactics filters are active, but handshake API
+ // parameters are not.
+ //
+ // See the comment in server.LoadConfig regarding fronting provider ID
+ // limitations.
+
+ if protocol.TunnelProtocolUsesFrontedMeek(server.listenerTunnelProtocol) &&
+ server.support.ServerTacticsParametersCache != nil {
+
+ p, err := server.support.ServerTacticsParametersCache.Get(geoIPData)
+ if err != nil {
+ return "", nil, nil, "", nil, errors.Trace(err)
+ }
+
+ if !p.IsNil() &&
+ common.Contains(
+ p.Strings(parameters.RestrictFrontingProviderIDs),
+ server.support.Config.GetFrontingProviderID()) {
+ if p.WeightedCoinFlip(
+ parameters.RestrictFrontingProviderIDsServerProbability) {
+ return "", nil, nil, "", nil, errors.TraceNew("restricted fronting provider")
+ }
+ }
}
// Create a new session
@@ -736,7 +780,7 @@ func (server *MeekServer) getSessionOrEndpoint(
if clientSessionData.MeekProtocolVersion >= MEEK_PROTOCOL_VERSION_2 {
sessionID, err = makeMeekSessionID()
if err != nil {
- return "", nil, nil, "", "", errors.Trace(err)
+ return "", nil, nil, "", nil, errors.Trace(err)
}
}
@@ -748,10 +792,11 @@ func (server *MeekServer) getSessionOrEndpoint(
// will close when session.delete calls Close() on the meekConn.
server.clientHandler(clientSessionData.ClientTunnelProtocol, session.clientConn)
- return sessionID, session, underlyingConn, "", "", nil
+ return sessionID, session, underlyingConn, "", nil, nil
}
-func (server *MeekServer) rateLimit(clientIP string, tunnelProtocol string) bool {
+func (server *MeekServer) rateLimit(
+ clientIP string, geoIPData GeoIPData, tunnelProtocol string) bool {
historySize,
thresholdSeconds,
@@ -774,9 +819,6 @@ func (server *MeekServer) rateLimit(clientIP string, tunnelProtocol string) bool
if len(regions) > 0 || len(ISPs) > 0 || len(cities) > 0 {
- // TODO: avoid redundant GeoIP lookups?
- geoIPData := server.support.GeoIPService.Lookup(clientIP)
-
if len(regions) > 0 {
if !common.Contains(regions, geoIPData.Country) {
return false
diff --git a/psiphon/server/meek_test.go b/psiphon/server/meek_test.go
index 8f2a15a42..656f2ff42 100755
--- a/psiphon/server/meek_test.go
+++ b/psiphon/server/meek_test.go
@@ -25,8 +25,10 @@ import (
crypto_rand "crypto/rand"
"encoding/base64"
"fmt"
+ "io/ioutil"
"math/rand"
"net"
+ "path/filepath"
"sync"
"sync/atomic"
"syscall"
@@ -38,6 +40,7 @@ import (
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
+ "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
"golang.org/x/crypto/nacl/box"
)
@@ -245,6 +248,7 @@ func TestMeekResiliency(t *testing.T) {
},
TrafficRulesSet: &TrafficRulesSet{},
}
+ mockSupport.GeoIPService, _ = NewGeoIPService([]string{})
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
@@ -401,19 +405,73 @@ func (interruptor *fileDescriptorInterruptor) BindToDevice(fileDescriptor int) (
}
func TestMeekRateLimiter(t *testing.T) {
- runTestMeekRateLimiter(t, true)
- runTestMeekRateLimiter(t, false)
+ runTestMeekAccessControl(t, true, false)
+ runTestMeekAccessControl(t, false, false)
}
-func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
+func TestMeekRestrictFrontingProviders(t *testing.T) {
+ runTestMeekAccessControl(t, false, true)
+ runTestMeekAccessControl(t, false, false)
+}
+
+func runTestMeekAccessControl(t *testing.T, rateLimit, restrictProvider bool) {
attempts := 10
allowedConnections := 5
+
if !rateLimit {
allowedConnections = 10
}
+ if restrictProvider {
+ allowedConnections = 0
+ }
+
+ // Configure tactics
+
+ frontingProviderID := prng.HexString(8)
+
+ tacticsConfigJSONFormat := `
+ {
+ "RequestPublicKey" : "%s",
+ "RequestPrivateKey" : "%s",
+ "RequestObfuscatedKey" : "%s",
+ "DefaultTactics" : {
+ "TTL" : "60s",
+ "Probability" : 1.0,
+ "Parameters" : {
+ "RestrictFrontingProviderIDs" : ["%s"],
+ "RestrictFrontingProviderIDsServerProbability" : 1.0
+ }
+ }
+ }
+ `
+
+ tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey, err :=
+ tactics.GenerateKeys()
+ if err != nil {
+ t.Fatalf("error generating tactics keys: %s", err)
+ }
+
+ restrictFrontingProviderID := ""
+
+ if restrictProvider {
+ restrictFrontingProviderID = frontingProviderID
+ }
+
+ tacticsConfigJSON := fmt.Sprintf(
+ tacticsConfigJSONFormat,
+ tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey,
+ restrictFrontingProviderID)
+
+ tacticsConfigFilename := filepath.Join(testDataDirName, "tactics_config.json")
+
+ err = ioutil.WriteFile(tacticsConfigFilename, []byte(tacticsConfigJSON), 0600)
+ if err != nil {
+ t.Fatalf("error paving tactics config file: %s", err)
+ }
+
// Run meek server
rawMeekCookieEncryptionPublicKey, rawMeekCookieEncryptionPrivateKey, err := box.GenerateKey(crypto_rand.Reader)
@@ -424,11 +482,11 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
meekCookieEncryptionPrivateKey := base64.StdEncoding.EncodeToString(rawMeekCookieEncryptionPrivateKey[:])
meekObfuscatedKey := prng.HexString(SSH_OBFUSCATED_KEY_BYTE_LENGTH)
- tunnelProtocol := protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK
+ tunnelProtocol := protocol.TUNNEL_PROTOCOL_FRONTED_MEEK
meekRateLimiterTunnelProtocols := []string{tunnelProtocol}
if !rateLimit {
- meekRateLimiterTunnelProtocols = []string{protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS}
+ meekRateLimiterTunnelProtocols = []string{protocol.TUNNEL_PROTOCOL_FRONTED_MEEK}
}
mockSupport := &SupportServices{
@@ -436,6 +494,7 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
MeekObfuscatedKey: meekObfuscatedKey,
MeekCookieEncryptionPrivateKey: meekCookieEncryptionPrivateKey,
TunnelProtocolPorts: map[string]int{tunnelProtocol: 0},
+ frontingProviderID: frontingProviderID,
},
TrafficRulesSet: &TrafficRulesSet{
MeekRateLimiterHistorySize: allowedConnections,
@@ -445,6 +504,15 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
MeekRateLimiterReapHistoryFrequencySeconds: 1,
},
}
+ mockSupport.GeoIPService, _ = NewGeoIPService([]string{})
+
+ tacticsServer, err := tactics.NewServer(nil, nil, nil, tacticsConfigFilename)
+ if err != nil {
+ t.Fatalf("tactics.NewServer failed: %s", err)
+ }
+
+ mockSupport.TacticsServer = tacticsServer
+ mockSupport.ServerTacticsParametersCache = NewServerTacticsParametersCache(mockSupport)
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
@@ -567,8 +635,8 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) {
totalFailures != attempts-totalConnections {
t.Fatalf(
- "Unexpected results: %d connections, %d failures",
- totalConnections, totalFailures)
+ "Unexpected results: %d connections, %d failures, %d allowed",
+ totalConnections, totalFailures, allowedConnections)
}
// Graceful shutdown
diff --git a/psiphon/server/server_test.go b/psiphon/server/server_test.go
index 25232babb..4e6b89d5a 100644
--- a/psiphon/server/server_test.go
+++ b/psiphon/server/server_test.go
@@ -1338,6 +1338,10 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
expectServerBPFField := ServerBPFEnabled() && doServerTactics
expectServerPacketManipulationField := runConfig.doPacketManipulation
expectBurstFields := runConfig.doBurstMonitor
+ expectTCPPortForwardDial := runConfig.doTunneledWebRequest
+ expectTCPDataTransfer := runConfig.doTunneledWebRequest && !expectTrafficFailure && !runConfig.doSplitTunnel
+ // Even with expectTrafficFailure, DNS port forwards will succeed
+ expectUDPDataTransfer := runConfig.doTunneledNTPRequest
select {
case logFields := <-serverTunnelLog:
@@ -1347,6 +1351,9 @@ func runServer(t *testing.T, runConfig *runServerConfig) {
expectServerBPFField,
expectServerPacketManipulationField,
expectBurstFields,
+ expectTCPPortForwardDial,
+ expectTCPDataTransfer,
+ expectUDPDataTransfer,
logFields)
if err != nil {
t.Fatalf("invalid server tunnel log fields: %s", err)
@@ -1404,6 +1411,9 @@ func checkExpectedServerTunnelLogFields(
expectServerBPFField bool,
expectServerPacketManipulationField bool,
expectBurstFields bool,
+ expectTCPPortForwardDial bool,
+ expectTCPDataTransfer bool,
+ expectUDPDataTransfer bool,
fields map[string]interface{}) error {
// Limitations:
@@ -1649,6 +1659,66 @@ func checkExpectedServerTunnelLogFields(
return fmt.Errorf("unexpected network_type '%s'", fields["network_type"])
}
+ var checkTCPMetric func(float64) bool
+ if expectTCPPortForwardDial {
+ checkTCPMetric = func(f float64) bool { return f > 0 }
+ } else {
+ checkTCPMetric = func(f float64) bool { return f == 0 }
+ }
+
+ for _, name := range []string{
+ "peak_concurrent_dialing_port_forward_count_tcp",
+ } {
+ if fields[name] == nil {
+ return fmt.Errorf("missing expected field '%s'", name)
+ }
+ if !checkTCPMetric(fields[name].(float64)) {
+ return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name])
+ }
+ }
+
+ if expectTCPDataTransfer {
+ checkTCPMetric = func(f float64) bool { return f > 0 }
+ } else {
+ checkTCPMetric = func(f float64) bool { return f == 0 }
+ }
+
+ for _, name := range []string{
+ "bytes_up_tcp",
+ "bytes_down_tcp",
+ "peak_concurrent_port_forward_count_tcp",
+ "total_port_forward_count_tcp",
+ } {
+ if fields[name] == nil {
+ return fmt.Errorf("missing expected field '%s'", name)
+ }
+ if !checkTCPMetric(fields[name].(float64)) {
+ return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name])
+ }
+ }
+
+ var checkUDPMetric func(float64) bool
+ if expectUDPDataTransfer {
+ checkUDPMetric = func(f float64) bool { return f > 0 }
+ } else {
+ checkUDPMetric = func(f float64) bool { return f == 0 }
+ }
+
+ for _, name := range []string{
+ "bytes_up_udp",
+ "bytes_down_udp",
+ "peak_concurrent_port_forward_count_udp",
+ "total_port_forward_count_udp",
+ "total_udpgw_channel_count",
+ } {
+ if fields[name] == nil {
+ return fmt.Errorf("missing expected field '%s'", name)
+ }
+ if !checkUDPMetric(fields[name].(float64)) {
+ return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name])
+ }
+ }
+
return nil
}
diff --git a/psiphon/server/tunnelServer.go b/psiphon/server/tunnelServer.go
index dfbc0450b..fc493c6dc 100644
--- a/psiphon/server/tunnelServer.go
+++ b/psiphon/server/tunnelServer.go
@@ -1260,8 +1260,10 @@ type sshClient struct {
isFirstTunnelInSession bool
supportsServerRequests bool
handshakeState handshakeState
- udpChannel ssh.Channel
+ udpgwChannelHandler *udpgwPortForwardMultiplexer
+ totalUdpgwChannelCount int
packetTunnelChannel ssh.Channel
+ totalPacketTunnelChannelCount int
trafficRules TrafficRules
tcpTrafficState trafficState
udpTrafficState trafficState
@@ -2495,11 +2497,11 @@ func (sshClient *sshClient) handleNewTCPPortForwardChannel(
// Intercept TCP port forwards to a specified udpgw server and handle directly.
// TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
- isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
+ isUdpgwChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
- if isUDPChannel {
+ if isUdpgwChannel {
// Dispatch immediately. handleUDPChannel runs the udpgw protocol in its
// own worker goroutine.
@@ -2507,7 +2509,7 @@ func (sshClient *sshClient) handleNewTCPPortForwardChannel(
waitGroup.Add(1)
go func(channel ssh.NewChannel) {
defer waitGroup.Done()
- sshClient.handleUDPChannel(channel)
+ sshClient.handleUdpgwChannel(channel)
}(newChannel)
} else {
@@ -2558,20 +2560,39 @@ func (sshClient *sshClient) setPacketTunnelChannel(channel ssh.Channel) {
sshClient.packetTunnelChannel.Close()
}
sshClient.packetTunnelChannel = channel
+ sshClient.totalPacketTunnelChannelCount += 1
sshClient.Unlock()
}
-// setUDPChannel sets the single UDP channel for this sshClient.
-// Each sshClient may have only one concurrent UDP channel. Each
-// UDP channel multiplexes many UDP port forwards via the udpgw
-// protocol. Any existing UDP channel is closed.
-func (sshClient *sshClient) setUDPChannel(channel ssh.Channel) {
+// setUdpgwChannelHandler sets the single udpgw channel handler for this
+// sshClient. Each sshClient may have only one concurrent udpgw
+// channel/handler. Each udpgw channel multiplexes many UDP port forwards via
+// the udpgw protocol. Any existing udpgw channel/handler is closed.
+func (sshClient *sshClient) setUdpgwChannelHandler(udpgwChannelHandler *udpgwPortForwardMultiplexer) bool {
sshClient.Lock()
- if sshClient.udpChannel != nil {
- sshClient.udpChannel.Close()
+ if sshClient.udpgwChannelHandler != nil {
+ previousHandler := sshClient.udpgwChannelHandler
+ sshClient.udpgwChannelHandler = nil
+
+ // stop must be run without holding the sshClient mutex lock, as the
+ // udpgw goroutines may attempt to lock the same mutex. For example,
+ // udpgwPortForwardMultiplexer.run calls sshClient.establishedPortForward
+ // which calls sshClient.allocatePortForward.
+ sshClient.Unlock()
+ previousHandler.stop()
+ sshClient.Lock()
+
+ // In case some other channel has set the sshClient.udpgwChannelHandler
+ // in the meantime, fail. The caller should discard this channel/handler.
+ if sshClient.udpgwChannelHandler != nil {
+ sshClient.Unlock()
+ return false
+ }
}
- sshClient.udpChannel = channel
+ sshClient.udpgwChannelHandler = udpgwChannelHandler
+ sshClient.totalUdpgwChannelCount += 1
sshClient.Unlock()
+ return true
}
var serverTunnelStatParams = append(
@@ -2616,6 +2637,8 @@ func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
// sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
+ logFields["total_udpgw_channel_count"] = sshClient.totalUdpgwChannelCount
+ logFields["total_packet_tunnel_channel_count"] = sshClient.totalPacketTunnelChannelCount
logFields["pre_handshake_random_stream_count"] = sshClient.preHandshakeRandomStreamMetrics.count
logFields["pre_handshake_random_stream_upstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.upstreamBytes
diff --git a/psiphon/server/udp.go b/psiphon/server/udp.go
index f58c2ba19..2279e0d31 100644
--- a/psiphon/server/udp.go
+++ b/psiphon/server/udp.go
@@ -25,7 +25,6 @@ import (
"fmt"
"io"
"net"
- "runtime/debug"
"sync"
"sync/atomic"
@@ -35,7 +34,7 @@ import (
"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
)
-// handleUDPChannel implements UDP port forwarding. A single UDP
+// handleUdpgwChannel implements UDP port forwarding. A single UDP
// SSH channel follows the udpgw protocol, which multiplexes many
// UDP port forwards.
//
@@ -43,10 +42,10 @@ import (
// Copyright (c) 2009, Ambroz Bizjak
// https://github.com/ambrop72/badvpn
//
-func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
+func (sshClient *sshClient) handleUdpgwChannel(newChannel ssh.NewChannel) {
// Accept this channel immediately. This channel will replace any
- // previously existing UDP channel for this client.
+ // previously existing udpgw channel for this client.
sshChannel, requests, err := newChannel.Accept()
if err != nil {
@@ -58,33 +57,81 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) {
go ssh.DiscardRequests(requests)
defer sshChannel.Close()
- sshClient.setUDPChannel(sshChannel)
-
- multiplexer := &udpPortForwardMultiplexer{
+ multiplexer := &udpgwPortForwardMultiplexer{
sshClient: sshClient,
sshChannel: sshChannel,
- portForwards: make(map[uint16]*udpPortForward),
+ portForwards: make(map[uint16]*udpgwPortForward),
portForwardLRU: common.NewLRUConns(),
relayWaitGroup: new(sync.WaitGroup),
+ runWaitGroup: new(sync.WaitGroup),
+ }
+
+ multiplexer.runWaitGroup.Add(1)
+
+ // setUdpgwChannelHandler will close any existing
+ // udpgwPortForwardMultiplexer, waiting for all run/relayDownstream
+ // goroutines to first terminate and all UDP socket resources to be
+ // cleaned up.
+ //
+ // This synchronous shutdown also ensures that the
+ // concurrentPortForwardCount is reduced to 0 before installing the new
+ // udpgwPortForwardMultiplexer and its LRU object. If the older handler
+ // were to dangle with open port forwards, and concurrentPortForwardCount
+ // were to hit the max, the wrong LRU, the new one, would be used to
+ // close the LRU port forward.
+ //
+ // Call setUdpgwChannelHandler only after runWaitGroup is initialized, to ensure
+ // runWaitGroup.Wait() cannot be invoked (by some subsequent new udpgw
+ // channel) before initialized.
+
+ if !sshClient.setUdpgwChannelHandler(multiplexer) {
+ // setUdpgwChannelHandler returns false if some other SSH channel
+ // calls setUdpgwChannelHandler in the middle of this call. In that
+ // case, discard this channel: the client's latest udpgw channel is
+ // retained.
+ return
}
+
multiplexer.run()
+ multiplexer.runWaitGroup.Done()
}
-type udpPortForwardMultiplexer struct {
+type udpgwPortForwardMultiplexer struct {
sshClient *sshClient
sshChannelWriteMutex sync.Mutex
sshChannel ssh.Channel
portForwardsMutex sync.Mutex
- portForwards map[uint16]*udpPortForward
+ portForwards map[uint16]*udpgwPortForward
portForwardLRU *common.LRUConns
relayWaitGroup *sync.WaitGroup
+ runWaitGroup *sync.WaitGroup
+}
+
+func (mux *udpgwPortForwardMultiplexer) stop() {
+
+ // udpgwPortForwardMultiplexer must be initialized by handleUdpgwChannel.
+ //
+ // stop closes the udpgw SSH channel, which will cause the run goroutine
+ // to exit its message read loop and await closure of all relayDownstream
+ // goroutines. Closing all port forward UDP conns will cause all
+ // relayDownstream to exit.
+
+ _ = mux.sshChannel.Close()
+
+ mux.portForwardsMutex.Lock()
+ for _, portForward := range mux.portForwards {
+ _ = portForward.conn.Close()
+ }
+ mux.portForwardsMutex.Unlock()
+
+ mux.runWaitGroup.Wait()
}
-func (mux *udpPortForwardMultiplexer) run() {
+func (mux *udpgwPortForwardMultiplexer) run() {
- // In a loop, read udpgw messages from the client to this channel. Each message is
- // a UDP packet to send upstream either via a new port forward, or on an existing
- // port forward.
+ // In a loop, read udpgw messages from the client to this channel. Each
+ // message contains a UDP packet to send upstream either via a new port
+ // forward, or on an existing port forward.
//
// A goroutine is run to read downstream packets for each UDP port forward. All read
// packets are encapsulated in udpgw protocol and sent down the channel to the client.
@@ -92,16 +139,6 @@ func (mux *udpPortForwardMultiplexer) run() {
// When the client disconnects or the server shuts down, the channel will close and
// readUdpgwMessage will exit with EOF.
- // Recover from and log any unexpected panics caused by udpgw input handling bugs.
- // Note: this covers the run() goroutine only and not relayDownstream() goroutines.
- defer func() {
- if e := recover(); e != nil {
- err := errors.Tracef(
- "udpPortForwardMultiplexer panic: %s: %s", e, debug.Stack())
- log.WithTraceFields(LogFields{"error": err}).Warning("run failed")
- }
- }()
-
buffer := make([]byte, udpgwProtocolMaxMessageSize)
for {
// Note: message.packet points to the reusable memory in "buffer".
@@ -119,27 +156,37 @@ func (mux *udpPortForwardMultiplexer) run() {
portForward := mux.portForwards[message.connID]
mux.portForwardsMutex.Unlock()
- if portForward != nil && message.discardExistingConn {
+ // In the udpgw protocol, an existing port forward is closed when
+ // either the discard flag is set or the remote address has changed.
+
+ if portForward != nil &&
+ (message.discardExistingConn ||
+ !bytes.Equal(portForward.remoteIP, message.remoteIP) ||
+ portForward.remotePort != message.remotePort) {
+
// The port forward's goroutine will complete cleanup, including
// tallying stats and calling sshClient.closedPortForward.
// portForward.conn.Close() will signal this shutdown.
- // TODO: wait for goroutine to exit before proceeding?
portForward.conn.Close()
- portForward = nil
- }
-
- if portForward != nil {
- // Verify that portForward remote address matches latest message
+	// Synchronously await the termination of the relayDownstream
+	// goroutine. This ensures that, after this point, the previous
+	// goroutine cannot invoke removePortForward with the connID that
+	// is about to be reused for the new port forward.
+ //
+ // Limitation: this synchronous shutdown cannot prevent a "wrong
+ // remote address" error on the badvpn udpgw client, which occurs
+ // when the client recycles a port forward (setting discard) but
+ // receives, from the server, a udpgw message containing the old
+ // remote address for the previous port forward with the same
+ // conn ID. That downstream message from the server may be in
+ // flight in the SSH channel when the client discard message arrives.
+ portForward.relayWaitGroup.Wait()
- if !bytes.Equal(portForward.remoteIP, message.remoteIP) ||
- portForward.remotePort != message.remotePort {
-
- log.WithTrace().Warning("UDP port forward remote address mismatch")
- continue
- }
+ portForward = nil
+ }
- } else {
+ if portForward == nil {
// Create a new port forward
@@ -237,17 +284,18 @@ func (mux *udpPortForwardMultiplexer) run() {
continue
}
- portForward = &udpPortForward{
- connID: message.connID,
- preambleSize: message.preambleSize,
- remoteIP: message.remoteIP,
- remotePort: message.remotePort,
- dialIP: dialIP,
- conn: conn,
- lruEntry: lruEntry,
- bytesUp: 0,
- bytesDown: 0,
- mux: mux,
+ portForward = &udpgwPortForward{
+ connID: message.connID,
+ preambleSize: message.preambleSize,
+ remoteIP: message.remoteIP,
+ remotePort: message.remotePort,
+ dialIP: dialIP,
+ conn: conn,
+ lruEntry: lruEntry,
+ bytesUp: 0,
+ bytesDown: 0,
+ relayWaitGroup: new(sync.WaitGroup),
+ mux: mux,
}
if message.forwardDNS {
@@ -258,6 +306,7 @@ func (mux *udpPortForwardMultiplexer) run() {
mux.portForwards[portForward.connID] = portForward
mux.portForwardsMutex.Unlock()
+ portForward.relayWaitGroup.Add(1)
mux.relayWaitGroup.Add(1)
go portForward.relayDownstream()
}
@@ -276,7 +325,7 @@ func (mux *udpPortForwardMultiplexer) run() {
atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet)))
}
- // Cleanup all UDP port forward workers when exiting
+ // Cleanup all udpgw port forward workers when exiting
mux.portForwardsMutex.Lock()
for _, portForward := range mux.portForwards {
@@ -288,13 +337,13 @@ func (mux *udpPortForwardMultiplexer) run() {
mux.relayWaitGroup.Wait()
}
-func (mux *udpPortForwardMultiplexer) removePortForward(connID uint16) {
+func (mux *udpgwPortForwardMultiplexer) removePortForward(connID uint16) {
mux.portForwardsMutex.Lock()
delete(mux.portForwards, connID)
mux.portForwardsMutex.Unlock()
}
-type udpPortForward struct {
+type udpgwPortForward struct {
// Note: 64-bit ints used with atomic operations are placed
// at the start of struct to ensure 64-bit alignment.
// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
@@ -309,10 +358,12 @@ type udpPortForward struct {
dialIP net.IP
conn net.Conn
lruEntry *common.LRUConnsEntry
- mux *udpPortForwardMultiplexer
+ relayWaitGroup *sync.WaitGroup
+ mux *udpgwPortForwardMultiplexer
}
-func (portForward *udpPortForward) relayDownstream() {
+func (portForward *udpgwPortForward) relayDownstream() {
+ defer portForward.relayWaitGroup.Done()
defer portForward.mux.relayWaitGroup.Done()
// Downstream UDP packets are read into the reusable memory