diff --git a/MobileLibrary/Android/PsiphonTunnel/ca_psiphon_psiphontunnel_backup_rules.xml b/MobileLibrary/Android/PsiphonTunnel/ca_psiphon_psiphontunnel_backup_rules.xml index 29077a85d..f766196da 100644 --- a/MobileLibrary/Android/PsiphonTunnel/ca_psiphon_psiphontunnel_backup_rules.xml +++ b/MobileLibrary/Android/PsiphonTunnel/ca_psiphon_psiphontunnel_backup_rules.xml @@ -1,4 +1,4 @@ - + diff --git a/psiphon/dialParameters.go b/psiphon/dialParameters.go index 8bdc07171..3fa9d72ad 100644 --- a/psiphon/dialParameters.go +++ b/psiphon/dialParameters.go @@ -773,7 +773,7 @@ func MakeDialParameters( case protocol.TUNNEL_PROTOCOL_MARIONETTE_OBFUSCATED_SSH: - // Note: port comes from marionnete "format" + // Note: port comes from marionette "format" dialParams.DirectDialAddress = serverEntry.IpAddress default: diff --git a/psiphon/server/listener.go b/psiphon/server/listener.go index 03021c2ba..71a8ff31e 100644 --- a/psiphon/server/listener.go +++ b/psiphon/server/listener.go @@ -25,14 +25,15 @@ import ( "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor" - "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol" ) // TacticsListener wraps a net.Listener and applies server-side implementation // of certain tactics parameters to accepted connections. Tactics filtering is -// limited to GeoIP attributes as the client has not yet sent API paramaters. +// limited to GeoIP attributes as the client has not yet sent API parameters. +// GeoIP uses the immediate peer IP, and so TacticsListener is suitable only +// for tactics that do not require the original client GeoIP when fronted. 
type TacticsListener struct { net.Listener support *SupportServices @@ -77,11 +78,14 @@ func (listener *TacticsListener) accept() (net.Conn, error) { return nil, err } + // Limitation: RemoteAddr is the immediate peer IP, which is not the original + // client IP in the case of fronting. geoIPData := listener.geoIPLookup( common.IPAddressFromAddr(conn.RemoteAddr())) p, err := listener.support.ServerTacticsParametersCache.Get(geoIPData) if err != nil { + conn.Close() return nil, errors.Trace(err) } @@ -90,34 +94,6 @@ func (listener *TacticsListener) accept() (net.Conn, error) { return conn, nil } - // Disconnect immediately if the clients tactics restricts usage of the - // fronting provider ID. The probability may be used to influence usage of a - // given fronting provider; but when only that provider works for a given - // client, and the probability is less than 1.0, the client can retry until - // it gets a successful coin flip. - // - // Clients will also skip candidates with restricted fronting provider IDs. - // The client-side probability, RestrictFrontingProviderIDsClientProbability, - // is applied independently of the server-side coin flip here. - // - // - // At this stage, GeoIP tactics filters are active, but handshake API - // parameters are not. - // - // See the comment in server.LoadConfig regarding fronting provider ID - // limitations. - - if protocol.TunnelProtocolUsesFrontedMeek(listener.tunnelProtocol) && - common.Contains( - p.Strings(parameters.RestrictFrontingProviderIDs), - listener.support.Config.GetFrontingProviderID()) { - if p.WeightedCoinFlip( - parameters.RestrictFrontingProviderIDsServerProbability) { - conn.Close() - return nil, nil - } - } - // Server-side fragmentation may be synchronized with client-side in two ways. 
// // In the OSSH case, replay is always activated and it is seeded using the diff --git a/psiphon/server/listener_test.go b/psiphon/server/listener_test.go index 0d5a05f84..2447a78c3 100644 --- a/psiphon/server/listener_test.go +++ b/psiphon/server/listener_test.go @@ -28,7 +28,6 @@ import ( "time" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/fragmentor" - "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics" ) @@ -37,8 +36,6 @@ func TestListener(t *testing.T) { tunnelProtocol := protocol.TUNNEL_PROTOCOL_FRONTED_MEEK - frontingProviderID := prng.HexString(8) - tacticsConfigJSONFormat := ` { "RequestPublicKey" : "%s", @@ -65,19 +62,6 @@ func TestListener(t *testing.T) { "FragmentorDownstreamMaxWriteBytes" : 1 } } - }, - { - "Filter" : { - "Regions": ["R3"], - "ISPs": ["I3"], - "Cities": ["C3"] - }, - "Tactics" : { - "Parameters" : { - "RestrictFrontingProviderIDs" : ["%s"], - "RestrictFrontingProviderIDsServerProbability" : 1.0 - } - } } ] } @@ -92,7 +76,7 @@ func TestListener(t *testing.T) { tacticsConfigJSON := fmt.Sprintf( tacticsConfigJSONFormat, tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey, - tunnelProtocol, frontingProviderID) + tunnelProtocol) tacticsConfigFilename := filepath.Join(testDataDirName, "tactics_config.json") @@ -122,12 +106,6 @@ func TestListener(t *testing.T) { listenerUnfragmentedGeoIPWrongCity := func(string) GeoIPData { return GeoIPData{Country: "R1", ISP: "I1", City: "C2"} } - listenerRestrictedFrontingProviderIDGeoIP := func(string) GeoIPData { - return GeoIPData{Country: "R3", ISP: "I3", City: "C3"} - } - listenerUnrestrictedFrontingProviderIDWrongRegion := func(string) GeoIPData { - return GeoIPData{Country: "R2", ISP: "I3", City: "C3"} - } listenerTestCases := []struct { description string @@ -159,18 +137,6 @@ func TestListener(t 
*testing.T) { false, true, }, - { - "restricted", - listenerRestrictedFrontingProviderIDGeoIP, - false, - false, - }, - { - "unrestricted-region", - listenerUnrestrictedFrontingProviderIDWrongRegion, - false, - true, - }, } for _, testCase := range listenerTestCases { @@ -182,7 +148,7 @@ func TestListener(t *testing.T) { } support := &SupportServices{ - Config: &Config{frontingProviderID: frontingProviderID}, + Config: &Config{}, TacticsServer: tacticsServer, } support.ReplayCache = NewReplayCache(support) diff --git a/psiphon/server/meek.go b/psiphon/server/meek.go index 0042b35f9..8663fdd6e 100644 --- a/psiphon/server/meek.go +++ b/psiphon/server/meek.go @@ -43,6 +43,7 @@ import ( "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/obfuscator" + "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/values" @@ -344,8 +345,9 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request session, underlyingConn, endPoint, - clientIP, + endPointGeoIPData, err := server.getSessionOrEndpoint(request, meekCookie) + if err != nil { // Debug since session cookie errors commonly occur during // normal operation. @@ -359,9 +361,8 @@ func (server *MeekServer) ServeHTTP(responseWriter http.ResponseWriter, request // Endpoint mode. Currently, this means it's handled by the tactics // request handler. 
- geoIPData := server.support.GeoIPService.Lookup(clientIP) handled := server.support.TacticsServer.HandleEndPoint( - endPoint, common.GeoIPData(geoIPData), responseWriter, request) + endPoint, common.GeoIPData(*endPointGeoIPData), responseWriter, request) if !handled { log.WithTraceFields(LogFields{"endPoint": endPoint}).Info("unhandled endpoint") common.TerminateHTTPConnection(responseWriter, request) @@ -587,7 +588,7 @@ func checkRangeHeader(request *http.Request) (int, bool) { // mode; or the endpoint is returned when the meek cookie indicates endpoint // mode. func (server *MeekServer) getSessionOrEndpoint( - request *http.Request, meekCookie *http.Cookie) (string, *meekSession, net.Conn, string, string, error) { + request *http.Request, meekCookie *http.Cookie) (string, *meekSession, net.Conn, string, *GeoIPData, error) { underlyingConn := request.Context().Value(meekNetConnContextKey).(net.Conn) @@ -601,7 +602,7 @@ func (server *MeekServer) getSessionOrEndpoint( // TODO: can multiple http client connections using same session cookie // cause race conditions on session struct? session.touch() - return existingSessionID, session, underlyingConn, "", "", nil + return existingSessionID, session, underlyingConn, "", nil, nil } // Determine the client remote address, which is used for geolocation @@ -610,6 +611,8 @@ func (server *MeekServer) getSessionOrEndpoint( // headers such as X-Forwarded-For. clientIP := strings.Split(request.RemoteAddr, ":")[0] + usedProxyForwardedForHeader := false + var geoIPData GeoIPData if len(server.support.Config.MeekProxyForwardedForHeaders) > 0 { for _, header := range server.support.Config.MeekProxyForwardedForHeaders { @@ -619,23 +622,29 @@ func (server *MeekServer) getSessionOrEndpoint( // list of IPs (each proxy in a chain). The first IP should be // the client IP. 
proxyClientIP := strings.Split(value, ",")[0] - if net.ParseIP(proxyClientIP) != nil && - server.support.GeoIPService.Lookup( - proxyClientIP).Country != GEOIP_UNKNOWN_VALUE { - - clientIP = proxyClientIP - break + if net.ParseIP(proxyClientIP) != nil { + proxyClientGeoIPData := server.support.GeoIPService.Lookup(proxyClientIP) + if proxyClientGeoIPData.Country != GEOIP_UNKNOWN_VALUE { + usedProxyForwardedForHeader = true + clientIP = proxyClientIP + geoIPData = proxyClientGeoIPData + break + } } } } } + if !usedProxyForwardedForHeader { + geoIPData = server.support.GeoIPService.Lookup(clientIP) + } + // The session is new (or expired). Treat the cookie value as a new meek // cookie, extract the payload, and create a new session. payloadJSON, err := server.getMeekCookiePayload(clientIP, meekCookie.Value) if err != nil { - return "", nil, nil, "", "", errors.Trace(err) + return "", nil, nil, "", nil, errors.Trace(err) } // Note: this meek server ignores legacy values PsiphonClientSessionId @@ -644,7 +653,7 @@ func (server *MeekServer) getSessionOrEndpoint( err = json.Unmarshal(payloadJSON, &clientSessionData) if err != nil { - return "", nil, nil, "", "", errors.Trace(err) + return "", nil, nil, "", nil, errors.Trace(err) } tunnelProtocol := server.listenerTunnelProtocol @@ -656,7 +665,7 @@ func (server *MeekServer) getSessionOrEndpoint( server.listenerTunnelProtocol, server.support.Config.GetRunningProtocols()) { - return "", nil, nil, "", "", errors.Tracef( + return "", nil, nil, "", nil, errors.Tracef( "invalid client tunnel protocol: %s", clientSessionData.ClientTunnelProtocol) } @@ -669,8 +678,8 @@ func (server *MeekServer) getSessionOrEndpoint( // rate limit is primarily intended to limit memory resource consumption and // not the overhead incurred by cookie validation. 
- if server.rateLimit(clientIP, tunnelProtocol) { - return "", nil, nil, "", "", errors.TraceNew("rate limit exceeded") + if server.rateLimit(clientIP, geoIPData, tunnelProtocol) { + return "", nil, nil, "", nil, errors.TraceNew("rate limit exceeded") } // Handle endpoints before enforcing CheckEstablishTunnels. @@ -678,7 +687,7 @@ func (server *MeekServer) getSessionOrEndpoint( // handled by servers which would otherwise reject new tunnels. if clientSessionData.EndPoint != "" { - return "", nil, nil, clientSessionData.EndPoint, clientIP, nil + return "", nil, nil, clientSessionData.EndPoint, &geoIPData, nil } // Don't create new sessions when not establishing. A subsequent SSH handshake @@ -686,7 +695,42 @@ func (server *MeekServer) getSessionOrEndpoint( if server.support.TunnelServer != nil && !server.support.TunnelServer.CheckEstablishTunnels() { - return "", nil, nil, "", "", errors.TraceNew("not establishing tunnels") + return "", nil, nil, "", nil, errors.TraceNew("not establishing tunnels") + } + + // Disconnect immediately if the tactics for the client restricts usage of + // the fronting provider ID. The probability may be used to influence + // usage of a given fronting provider; but when only that provider works + // for a given client, and the probability is less than 1.0, the client + // can retry until it gets a successful coin flip. + // + // Clients will also skip candidates with restricted fronting provider IDs. + // The client-side probability, RestrictFrontingProviderIDsClientProbability, + // is applied independently of the server-side coin flip here. + // + // At this stage, GeoIP tactics filters are active, but handshake API + // parameters are not. + // + // See the comment in server.LoadConfig regarding fronting provider ID + // limitations. 
+ + if protocol.TunnelProtocolUsesFrontedMeek(server.listenerTunnelProtocol) && + server.support.ServerTacticsParametersCache != nil { + + p, err := server.support.ServerTacticsParametersCache.Get(geoIPData) + if err != nil { + return "", nil, nil, "", nil, errors.Trace(err) + } + + if !p.IsNil() && + common.Contains( + p.Strings(parameters.RestrictFrontingProviderIDs), + server.support.Config.GetFrontingProviderID()) { + if p.WeightedCoinFlip( + parameters.RestrictFrontingProviderIDsServerProbability) { + return "", nil, nil, "", nil, errors.TraceNew("restricted fronting provider") + } + } } // Create a new session @@ -736,7 +780,7 @@ func (server *MeekServer) getSessionOrEndpoint( if clientSessionData.MeekProtocolVersion >= MEEK_PROTOCOL_VERSION_2 { sessionID, err = makeMeekSessionID() if err != nil { - return "", nil, nil, "", "", errors.Trace(err) + return "", nil, nil, "", nil, errors.Trace(err) } } @@ -748,10 +792,11 @@ func (server *MeekServer) getSessionOrEndpoint( // will close when session.delete calls Close() on the meekConn. server.clientHandler(clientSessionData.ClientTunnelProtocol, session.clientConn) - return sessionID, session, underlyingConn, "", "", nil + return sessionID, session, underlyingConn, "", nil, nil } -func (server *MeekServer) rateLimit(clientIP string, tunnelProtocol string) bool { +func (server *MeekServer) rateLimit( + clientIP string, geoIPData GeoIPData, tunnelProtocol string) bool { historySize, thresholdSeconds, @@ -774,9 +819,6 @@ func (server *MeekServer) rateLimit(clientIP string, tunnelProtocol string) bool if len(regions) > 0 || len(ISPs) > 0 || len(cities) > 0 { - // TODO: avoid redundant GeoIP lookups? 
- geoIPData := server.support.GeoIPService.Lookup(clientIP) - if len(regions) > 0 { if !common.Contains(regions, geoIPData.Country) { return false diff --git a/psiphon/server/meek_test.go b/psiphon/server/meek_test.go index 8f2a15a42..656f2ff42 100755 --- a/psiphon/server/meek_test.go +++ b/psiphon/server/meek_test.go @@ -25,8 +25,10 @@ import ( crypto_rand "crypto/rand" "encoding/base64" "fmt" + "io/ioutil" "math/rand" "net" + "path/filepath" "sync" "sync/atomic" "syscall" @@ -38,6 +40,7 @@ import ( "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol" + "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics" "golang.org/x/crypto/nacl/box" ) @@ -245,6 +248,7 @@ func TestMeekResiliency(t *testing.T) { }, TrafficRulesSet: &TrafficRulesSet{}, } + mockSupport.GeoIPService, _ = NewGeoIPService([]string{}) listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -401,19 +405,73 @@ func (interruptor *fileDescriptorInterruptor) BindToDevice(fileDescriptor int) ( } func TestMeekRateLimiter(t *testing.T) { - runTestMeekRateLimiter(t, true) - runTestMeekRateLimiter(t, false) + runTestMeekAccessControl(t, true, false) + runTestMeekAccessControl(t, false, false) } -func runTestMeekRateLimiter(t *testing.T, rateLimit bool) { +func TestMeekRestrictFrontingProviders(t *testing.T) { + runTestMeekAccessControl(t, false, true) + runTestMeekAccessControl(t, false, false) +} + +func runTestMeekAccessControl(t *testing.T, rateLimit, restrictProvider bool) { attempts := 10 allowedConnections := 5 + if !rateLimit { allowedConnections = 10 } + if restrictProvider { + allowedConnections = 0 + } + + // Configure tactics + + frontingProviderID := prng.HexString(8) + + tacticsConfigJSONFormat := ` + { + "RequestPublicKey" : "%s", + "RequestPrivateKey" : "%s", + "RequestObfuscatedKey" : "%s", + "DefaultTactics" : { + 
"TTL" : "60s", + "Probability" : 1.0, + "Parameters" : { + "RestrictFrontingProviderIDs" : ["%s"], + "RestrictFrontingProviderIDsServerProbability" : 1.0 + } + } + } + ` + + tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey, err := + tactics.GenerateKeys() + if err != nil { + t.Fatalf("error generating tactics keys: %s", err) + } + + restrictFrontingProviderID := "" + + if restrictProvider { + restrictFrontingProviderID = frontingProviderID + } + + tacticsConfigJSON := fmt.Sprintf( + tacticsConfigJSONFormat, + tacticsRequestPublicKey, tacticsRequestPrivateKey, tacticsRequestObfuscatedKey, + restrictFrontingProviderID) + + tacticsConfigFilename := filepath.Join(testDataDirName, "tactics_config.json") + + err = ioutil.WriteFile(tacticsConfigFilename, []byte(tacticsConfigJSON), 0600) + if err != nil { + t.Fatalf("error paving tactics config file: %s", err) + } + // Run meek server rawMeekCookieEncryptionPublicKey, rawMeekCookieEncryptionPrivateKey, err := box.GenerateKey(crypto_rand.Reader) @@ -424,11 +482,11 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) { meekCookieEncryptionPrivateKey := base64.StdEncoding.EncodeToString(rawMeekCookieEncryptionPrivateKey[:]) meekObfuscatedKey := prng.HexString(SSH_OBFUSCATED_KEY_BYTE_LENGTH) - tunnelProtocol := protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK + tunnelProtocol := protocol.TUNNEL_PROTOCOL_FRONTED_MEEK meekRateLimiterTunnelProtocols := []string{tunnelProtocol} if !rateLimit { - meekRateLimiterTunnelProtocols = []string{protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_HTTPS} + meekRateLimiterTunnelProtocols = []string{protocol.TUNNEL_PROTOCOL_FRONTED_MEEK} } mockSupport := &SupportServices{ @@ -436,6 +494,7 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) { MeekObfuscatedKey: meekObfuscatedKey, MeekCookieEncryptionPrivateKey: meekCookieEncryptionPrivateKey, TunnelProtocolPorts: map[string]int{tunnelProtocol: 0}, + frontingProviderID: frontingProviderID, }, TrafficRulesSet: 
&TrafficRulesSet{ MeekRateLimiterHistorySize: allowedConnections, @@ -445,6 +504,15 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) { MeekRateLimiterReapHistoryFrequencySeconds: 1, }, } + mockSupport.GeoIPService, _ = NewGeoIPService([]string{}) + + tacticsServer, err := tactics.NewServer(nil, nil, nil, tacticsConfigFilename) + if err != nil { + t.Fatalf("tactics.NewServer failed: %s", err) + } + + mockSupport.TacticsServer = tacticsServer + mockSupport.ServerTacticsParametersCache = NewServerTacticsParametersCache(mockSupport) listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { @@ -567,8 +635,8 @@ func runTestMeekRateLimiter(t *testing.T, rateLimit bool) { totalFailures != attempts-totalConnections { t.Fatalf( - "Unexpected results: %d connections, %d failures", - totalConnections, totalFailures) + "Unexpected results: %d connections, %d failures, %d allowed", + totalConnections, totalFailures, allowedConnections) } // Graceful shutdown diff --git a/psiphon/server/server_test.go b/psiphon/server/server_test.go index 25232babb..4e6b89d5a 100644 --- a/psiphon/server/server_test.go +++ b/psiphon/server/server_test.go @@ -1338,6 +1338,10 @@ func runServer(t *testing.T, runConfig *runServerConfig) { expectServerBPFField := ServerBPFEnabled() && doServerTactics expectServerPacketManipulationField := runConfig.doPacketManipulation expectBurstFields := runConfig.doBurstMonitor + expectTCPPortForwardDial := runConfig.doTunneledWebRequest + expectTCPDataTransfer := runConfig.doTunneledWebRequest && !expectTrafficFailure && !runConfig.doSplitTunnel + // Even with expectTrafficFailure, DNS port forwards will succeed + expectUDPDataTransfer := runConfig.doTunneledNTPRequest select { case logFields := <-serverTunnelLog: @@ -1347,6 +1351,9 @@ func runServer(t *testing.T, runConfig *runServerConfig) { expectServerBPFField, expectServerPacketManipulationField, expectBurstFields, + expectTCPPortForwardDial, + expectTCPDataTransfer, + 
expectUDPDataTransfer, logFields) if err != nil { t.Fatalf("invalid server tunnel log fields: %s", err) @@ -1404,6 +1411,9 @@ func checkExpectedServerTunnelLogFields( expectServerBPFField bool, expectServerPacketManipulationField bool, expectBurstFields bool, + expectTCPPortForwardDial bool, + expectTCPDataTransfer bool, + expectUDPDataTransfer bool, fields map[string]interface{}) error { // Limitations: @@ -1649,6 +1659,66 @@ func checkExpectedServerTunnelLogFields( return fmt.Errorf("unexpected network_type '%s'", fields["network_type"]) } + var checkTCPMetric func(float64) bool + if expectTCPPortForwardDial { + checkTCPMetric = func(f float64) bool { return f > 0 } + } else { + checkTCPMetric = func(f float64) bool { return f == 0 } + } + + for _, name := range []string{ + "peak_concurrent_dialing_port_forward_count_tcp", + } { + if fields[name] == nil { + return fmt.Errorf("missing expected field '%s'", name) + } + if !checkTCPMetric(fields[name].(float64)) { + return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name]) + } + } + + if expectTCPDataTransfer { + checkTCPMetric = func(f float64) bool { return f > 0 } + } else { + checkTCPMetric = func(f float64) bool { return f == 0 } + } + + for _, name := range []string{ + "bytes_up_tcp", + "bytes_down_tcp", + "peak_concurrent_port_forward_count_tcp", + "total_port_forward_count_tcp", + } { + if fields[name] == nil { + return fmt.Errorf("missing expected field '%s'", name) + } + if !checkTCPMetric(fields[name].(float64)) { + return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name]) + } + } + + var checkUDPMetric func(float64) bool + if expectUDPDataTransfer { + checkUDPMetric = func(f float64) bool { return f > 0 } + } else { + checkUDPMetric = func(f float64) bool { return f == 0 } + } + + for _, name := range []string{ + "bytes_up_udp", + "bytes_down_udp", + "peak_concurrent_port_forward_count_udp", + "total_port_forward_count_udp", + "total_udpgw_channel_count", + } { + if 
fields[name] == nil { + return fmt.Errorf("missing expected field '%s'", name) + } + if !checkUDPMetric(fields[name].(float64)) { + return fmt.Errorf("unexpected field value %s: '%v'", name, fields[name]) + } + } + return nil } diff --git a/psiphon/server/tunnelServer.go b/psiphon/server/tunnelServer.go index dfbc0450b..fc493c6dc 100644 --- a/psiphon/server/tunnelServer.go +++ b/psiphon/server/tunnelServer.go @@ -1260,8 +1260,10 @@ type sshClient struct { isFirstTunnelInSession bool supportsServerRequests bool handshakeState handshakeState - udpChannel ssh.Channel + udpgwChannelHandler *udpgwPortForwardMultiplexer + totalUdpgwChannelCount int packetTunnelChannel ssh.Channel + totalPacketTunnelChannelCount int trafficRules TrafficRules tcpTrafficState trafficState udpTrafficState trafficState @@ -2495,11 +2497,11 @@ func (sshClient *sshClient) handleNewTCPPortForwardChannel( // Intercept TCP port forwards to a specified udpgw server and handle directly. // TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type? - isUDPChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" && + isUdpgwChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" && sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress == net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect))) - if isUDPChannel { + if isUdpgwChannel { // Dispatch immediately. handleUDPChannel runs the udpgw protocol in its // own worker goroutine. 
@@ -2507,7 +2509,7 @@ func (sshClient *sshClient) handleNewTCPPortForwardChannel( waitGroup.Add(1) go func(channel ssh.NewChannel) { defer waitGroup.Done() - sshClient.handleUDPChannel(channel) + sshClient.handleUdpgwChannel(channel) }(newChannel) } else { @@ -2558,20 +2560,39 @@ func (sshClient *sshClient) setPacketTunnelChannel(channel ssh.Channel) { sshClient.packetTunnelChannel.Close() } sshClient.packetTunnelChannel = channel + sshClient.totalPacketTunnelChannelCount += 1 sshClient.Unlock() } -// setUDPChannel sets the single UDP channel for this sshClient. -// Each sshClient may have only one concurrent UDP channel. Each -// UDP channel multiplexes many UDP port forwards via the udpgw -// protocol. Any existing UDP channel is closed. -func (sshClient *sshClient) setUDPChannel(channel ssh.Channel) { +// setUdpgwChannelHandler sets the single udpgw channel handler for this +// sshClient. Each sshClient may have only one concurrent udpgw +// channel/handler. Each udpgw channel multiplexes many UDP port forwards via +// the udpgw protocol. Any existing udpgw channel/handler is closed. +func (sshClient *sshClient) setUdpgwChannelHandler(udpgwChannelHandler *udpgwPortForwardMultiplexer) bool { sshClient.Lock() - if sshClient.udpChannel != nil { - sshClient.udpChannel.Close() + if sshClient.udpgwChannelHandler != nil { + previousHandler := sshClient.udpgwChannelHandler + sshClient.udpgwChannelHandler = nil + + // stop must be run without holding the sshClient mutex lock, as the + // udpgw goroutines may attempt to lock the same mutex. For example, + // udpgwPortForwardMultiplexer.run calls sshClient.establishedPortForward + // which calls sshClient.allocatePortForward. + sshClient.Unlock() + previousHandler.stop() + sshClient.Lock() + + // In case some other channel has set the sshClient.udpgwChannelHandler + // in the meantime, fail. The caller should discard this channel/handler. 
+ if sshClient.udpgwChannelHandler != nil { + sshClient.Unlock() + return false + } } - sshClient.udpChannel = channel + sshClient.udpgwChannelHandler = udpgwChannelHandler + sshClient.totalUdpgwChannelCount += 1 sshClient.Unlock() + return true } var serverTunnelStatParams = append( @@ -2616,6 +2637,8 @@ func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) { // sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount + logFields["total_udpgw_channel_count"] = sshClient.totalUdpgwChannelCount + logFields["total_packet_tunnel_channel_count"] = sshClient.totalPacketTunnelChannelCount logFields["pre_handshake_random_stream_count"] = sshClient.preHandshakeRandomStreamMetrics.count logFields["pre_handshake_random_stream_upstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.upstreamBytes diff --git a/psiphon/server/udp.go b/psiphon/server/udp.go index f58c2ba19..2279e0d31 100644 --- a/psiphon/server/udp.go +++ b/psiphon/server/udp.go @@ -25,7 +25,6 @@ import ( "fmt" "io" "net" - "runtime/debug" "sync" "sync/atomic" @@ -35,7 +34,7 @@ import ( "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors" ) -// handleUDPChannel implements UDP port forwarding. A single UDP +// handleUdpgwChannel implements UDP port forwarding. A single UDP // SSH channel follows the udpgw protocol, which multiplexes many // UDP port forwards. // @@ -43,10 +42,10 @@ import ( // Copyright (c) 2009, Ambroz Bizjak // https://github.com/ambrop72/badvpn // -func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) { +func (sshClient *sshClient) handleUdpgwChannel(newChannel ssh.NewChannel) { // Accept this channel immediately. This channel will replace any - // previously existing UDP channel for this client. 
+ // previously existing udpgw channel for this client. sshChannel, requests, err := newChannel.Accept() if err != nil { @@ -58,33 +57,81 @@ func (sshClient *sshClient) handleUDPChannel(newChannel ssh.NewChannel) { go ssh.DiscardRequests(requests) defer sshChannel.Close() - sshClient.setUDPChannel(sshChannel) - - multiplexer := &udpPortForwardMultiplexer{ + multiplexer := &udpgwPortForwardMultiplexer{ sshClient: sshClient, sshChannel: sshChannel, - portForwards: make(map[uint16]*udpPortForward), + portForwards: make(map[uint16]*udpgwPortForward), portForwardLRU: common.NewLRUConns(), relayWaitGroup: new(sync.WaitGroup), + runWaitGroup: new(sync.WaitGroup), + } + + multiplexer.runWaitGroup.Add(1) + + // setUdpgwChannelHandler will close any existing + // udpgwPortForwardMultiplexer, waiting for all run/relayDownstream + // goroutines to first terminate and all UDP socket resources to be + // cleaned up. + // + // This synchronous shutdown also ensures that the + // concurrentPortForwardCount is reduced to 0 before installing the new + // udpgwPortForwardMultiplexer and its LRU object. If the older handler + // were to dangle with open port forwards, and concurrentPortForwardCount + // were to hit the max, the wrong LRU, the new one, would be used to + // close the LRU port forward. + // + // Call setUdpgwChannelHandler only after runWaitGroup is initialized, to + // ensure runWaitGroup.Wait() cannot be invoked (by some subsequent new + // udpgw channel) before it is initialized. + + if !sshClient.setUdpgwChannelHandler(multiplexer) { + // setUdpgwChannelHandler returns false if some other SSH channel + // calls setUdpgwChannelHandler in the middle of this call. In that + // case, discard this channel: the client's latest udpgw channel is + // retained.
+ return } + multiplexer.run() + multiplexer.runWaitGroup.Done() } -type udpPortForwardMultiplexer struct { +type udpgwPortForwardMultiplexer struct { sshClient *sshClient sshChannelWriteMutex sync.Mutex sshChannel ssh.Channel portForwardsMutex sync.Mutex - portForwards map[uint16]*udpPortForward + portForwards map[uint16]*udpgwPortForward portForwardLRU *common.LRUConns relayWaitGroup *sync.WaitGroup + runWaitGroup *sync.WaitGroup +} + +func (mux *udpgwPortForwardMultiplexer) stop() { + + // udpgwPortForwardMultiplexer must be initialized by handleUdpgwChannel. + // + // stop closes the udpgw SSH channel, which will cause the run goroutine + // to exit its message read loop and await closure of all relayDownstream + // goroutines. Closing all port forward UDP conns will cause all + // relayDownstream to exit. + + _ = mux.sshChannel.Close() + + mux.portForwardsMutex.Lock() + for _, portForward := range mux.portForwards { + _ = portForward.conn.Close() + } + mux.portForwardsMutex.Unlock() + + mux.runWaitGroup.Wait() } -func (mux *udpPortForwardMultiplexer) run() { +func (mux *udpgwPortForwardMultiplexer) run() { - // In a loop, read udpgw messages from the client to this channel. Each message is - // a UDP packet to send upstream either via a new port forward, or on an existing - // port forward. + // In a loop, read udpgw messages from the client to this channel. Each + // message contains a UDP packet to send upstream either via a new port + // forward, or on an existing port forward. // // A goroutine is run to read downstream packets for each UDP port forward. All read // packets are encapsulated in udpgw protocol and sent down the channel to the client. @@ -92,16 +139,6 @@ func (mux *udpPortForwardMultiplexer) run() { // When the client disconnects or the server shuts down, the channel will close and // readUdpgwMessage will exit with EOF. - // Recover from and log any unexpected panics caused by udpgw input handling bugs. 
- // Note: this covers the run() goroutine only and not relayDownstream() goroutines. - defer func() { - if e := recover(); e != nil { - err := errors.Tracef( - "udpPortForwardMultiplexer panic: %s: %s", e, debug.Stack()) - log.WithTraceFields(LogFields{"error": err}).Warning("run failed") - } - }() - buffer := make([]byte, udpgwProtocolMaxMessageSize) for { // Note: message.packet points to the reusable memory in "buffer". @@ -119,27 +156,37 @@ func (mux *udpPortForwardMultiplexer) run() { portForward := mux.portForwards[message.connID] mux.portForwardsMutex.Unlock() - if portForward != nil && message.discardExistingConn { + // In the udpgw protocol, an existing port forward is closed when + // either the discard flag is set or the remote address has changed. + + if portForward != nil && + (message.discardExistingConn || + !bytes.Equal(portForward.remoteIP, message.remoteIP) || + portForward.remotePort != message.remotePort) { + // The port forward's goroutine will complete cleanup, including // tallying stats and calling sshClient.closedPortForward. // portForward.conn.Close() will signal this shutdown. - // TODO: wait for goroutine to exit before proceeding? portForward.conn.Close() - portForward = nil - } - - if portForward != nil { - // Verify that portForward remote address matches latest message + // Synchronously await the termination of the relayDownstream + // goroutine. This ensures that the previous goroutine won't + // invoke removePortForward, with the connID that will be reused + // for the new port forward, after this point. + // + // Limitation: this synchronous shutdown cannot prevent a "wrong + // remote address" error on the badvpn udpgw client, which occurs + // when the client recycles a port forward (setting discard) but + // receives, from the server, a udpgw message containing the old + // remote address for the previous port forward with the same + // conn ID. 
That downstream message from the server may be in + // flight in the SSH channel when the client discard message arrives. + portForward.relayWaitGroup.Wait() - if !bytes.Equal(portForward.remoteIP, message.remoteIP) || - portForward.remotePort != message.remotePort { - - log.WithTrace().Warning("UDP port forward remote address mismatch") - continue - } + portForward = nil + } - } else { + if portForward == nil { // Create a new port forward @@ -237,17 +284,18 @@ func (mux *udpPortForwardMultiplexer) run() { continue } - portForward = &udpPortForward{ - connID: message.connID, - preambleSize: message.preambleSize, - remoteIP: message.remoteIP, - remotePort: message.remotePort, - dialIP: dialIP, - conn: conn, - lruEntry: lruEntry, - bytesUp: 0, - bytesDown: 0, - mux: mux, + portForward = &udpgwPortForward{ + connID: message.connID, + preambleSize: message.preambleSize, + remoteIP: message.remoteIP, + remotePort: message.remotePort, + dialIP: dialIP, + conn: conn, + lruEntry: lruEntry, + bytesUp: 0, + bytesDown: 0, + relayWaitGroup: new(sync.WaitGroup), + mux: mux, } if message.forwardDNS { @@ -258,6 +306,7 @@ func (mux *udpPortForwardMultiplexer) run() { mux.portForwards[portForward.connID] = portForward mux.portForwardsMutex.Unlock() + portForward.relayWaitGroup.Add(1) mux.relayWaitGroup.Add(1) go portForward.relayDownstream() } @@ -276,7 +325,7 @@ func (mux *udpPortForwardMultiplexer) run() { atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet))) } - // Cleanup all UDP port forward workers when exiting + // Cleanup all udpgw port forward workers when exiting mux.portForwardsMutex.Lock() for _, portForward := range mux.portForwards { @@ -288,13 +337,13 @@ func (mux *udpPortForwardMultiplexer) run() { mux.relayWaitGroup.Wait() } -func (mux *udpPortForwardMultiplexer) removePortForward(connID uint16) { +func (mux *udpgwPortForwardMultiplexer) removePortForward(connID uint16) { mux.portForwardsMutex.Lock() delete(mux.portForwards, connID) 
mux.portForwardsMutex.Unlock() } -type udpPortForward struct { +type udpgwPortForward struct { // Note: 64-bit ints used with atomic operations are placed // at the start of struct to ensure 64-bit alignment. // (https://golang.org/pkg/sync/atomic/#pkg-note-BUG) @@ -309,10 +358,12 @@ type udpPortForward struct { dialIP net.IP conn net.Conn lruEntry *common.LRUConnsEntry - mux *udpPortForwardMultiplexer + relayWaitGroup *sync.WaitGroup + mux *udpgwPortForwardMultiplexer } -func (portForward *udpPortForward) relayDownstream() { +func (portForward *udpgwPortForward) relayDownstream() { + defer portForward.relayWaitGroup.Done() defer portForward.mux.relayWaitGroup.Done() // Downstream UDP packets are read into the reusable memory