From bc2022e14571bab998daed2169785b2a0a748fa9 Mon Sep 17 00:00:00 2001 From: Clint Purser Date: Fri, 14 Jul 2023 17:15:15 -0400 Subject: [PATCH] add sessions functionality --- lib/src/di/di_clients/di_web_rtc_client.dart | 5 +- .../web_rtc_client/web_rtc_client.dart | 13 ++- lib/src/robot/client.dart | 17 +++- lib/src/robot/sessions_client.dart | 98 +++++++++++++++++++ lib/src/rpc/dial.dart | 54 ++++++---- 5 files changed, 161 insertions(+), 26 deletions(-) create mode 100644 lib/src/robot/sessions_client.dart diff --git a/lib/src/di/di_clients/di_web_rtc_client.dart b/lib/src/di/di_clients/di_web_rtc_client.dart index 95ddb30f97..42811e5bd4 100644 --- a/lib/src/di/di_clients/di_web_rtc_client.dart +++ b/lib/src/di/di_clients/di_web_rtc_client.dart @@ -15,8 +15,5 @@ Future _getWebRtcClient( final webRtcPeerConnection = WebRtcPeerConnection(webRtcDirectDataSource); await webRtcPeerConnection.createConnection(); - return WebRtcClientChannel( - webRtcPeerConnection.peerConnection, - webRtcPeerConnection.dataChannel, - ); + return WebRtcClientChannel(webRtcPeerConnection.peerConnection, webRtcPeerConnection.dataChannel, () => ''); } diff --git a/lib/src/domain/web_rtc/web_rtc_client/web_rtc_client.dart b/lib/src/domain/web_rtc/web_rtc_client/web_rtc_client.dart index fb8b83a478..433aa8f213 100644 --- a/lib/src/domain/web_rtc/web_rtc_client/web_rtc_client.dart +++ b/lib/src/domain/web_rtc/web_rtc_client/web_rtc_client.dart @@ -1,14 +1,17 @@ import 'package:flutter_webrtc/flutter_webrtc.dart'; +import 'package:grpc/grpc.dart'; import 'package:grpc/grpc_connection_interface.dart'; +import '../../../robot/sessions_client.dart'; import 'web_rtc_client_connection.dart'; class WebRtcClientChannel extends ClientChannelBase { final RTCPeerConnection rtcPeerConnection; final RTCDataChannel dataChannel; + final String Function() _sessionId; final List onMessageListeners = []; - WebRtcClientChannel(this.rtcPeerConnection, this.dataChannel) { + WebRtcClientChannel(this.rtcPeerConnection, this.dataChannel, this._sessionId) { dataChannel.onMessage = (data) { onMessageListeners.forEach((listener) => listener(data)); }; @@ -22,4 +25,12 @@ class WebRtcClientChannel extends ClientChannelBase { @override ClientConnection createConnection() => WebRtcClientConnection(this); + + @override + ClientCall createCall(ClientMethod method, Stream requests, CallOptions options) { + if (!SessionsClient.unallowedMethods.contains(method.path)) { + options = options.mergedWith(CallOptions(metadata: {SessionsClient.sessionMetadataKey: _sessionId()})); + } + return super.createCall(method, requests, options); + } } diff --git a/lib/src/robot/client.dart b/lib/src/robot/client.dart index af5f94eb62..30818ecd1f 100644 --- a/lib/src/robot/client.dart +++ b/lib/src/robot/client.dart @@ -2,6 +2,7 @@ import 'dart:async'; import 'package:grpc/grpc_connection_interface.dart'; import 'package:logger/logger.dart'; +import 'package:viam_sdk/src/robot/sessions_client.dart'; import '../domain/web_rtc/web_rtc_client/web_rtc_client.dart'; import '../gen/common/v1/common.pb.dart'; @@ -26,6 +27,9 @@ class RobotClientOptions { /// The frequency (in seconds) at which to attempt to reconnect a disconnected robot. 0 (zero) signifies no reconnection attempts final attemptReconnectInterval = 1; + /// Whether sessions are enabled + final enableSessions = true; + RobotClientOptions() : dialOptions = DialOptions(); /// Convenience initializer for creating options with specified [DialOptions] @@ -43,9 +47,10 @@ class RobotClientOptions { class RobotClient { bool _connected = true; late String _address; - late DialOptions _options; + late RobotClientOptions _options; late ClientChannelBase _channel; late RobotServiceClient _client; + late SessionsClient _sessionsClient; List resourceNames = []; ResourceManager _manager = ResourceManager(); late final StreamManager _streamManager; @@ -56,8 +61,9 @@ class RobotClient { static Future atAddress(String url, RobotClientOptions options) async { final client = RobotClient._(); client._address = url; - client._options = options.dialOptions; - client._channel = await dial(url, options.dialOptions); + client._options = options; + client._channel = await dial(url, options.dialOptions, () => client._sessionsClient.metadata()); + client._sessionsClient = SessionsClient(client._channel, options.enableSessions); client._client = RobotServiceClient(client._channel); client._streamManager = StreamManager(client._channel as WebRtcClientChannel); await client.refresh(); @@ -145,19 +151,22 @@ class RobotClient { .d('Attempting to reconnect to the robot at $_address every $reconnectInterval ${(reconnectInterval > 1) ? 'seconds' : 'second'}'); while (!_connected) { + _sessionsClient.reset(); try { - final channel = await dial(_address, _options); + final channel = await dial(_address, _options.dialOptions, () => _sessionsClient.metadata()); final client = RobotServiceClient(channel); await client.resourceNames(ResourceNamesRequest()); _channel = channel; _streamManager.channel = _channel as WebRtcClientChannel; _client = client; + _sessionsClient = SessionsClient(_channel, _options.enableSessions); await refresh(); _connected = true; _logger.d('Successfully reconnected robot'); } catch (e) { await _channel.shutdown(); + _sessionsClient.reset(); _logger.d('Failed to reconnect, trying again in $reconnectInterval ${(reconnectInterval > 1) ? 'seconds' : 'second'}'); await Future.delayed(Duration(seconds: reconnectInterval)); } diff --git a/lib/src/robot/sessions_client.dart b/lib/src/robot/sessions_client.dart new file mode 100644 index 0000000000..7e8af6f5a1 --- /dev/null +++ b/lib/src/robot/sessions_client.dart @@ -0,0 +1,98 @@ +import 'dart:async'; + +import 'package:grpc/grpc_connection_interface.dart'; +import 'package:logger/logger.dart'; +import 'package:viam_sdk/protos/robot/robot.dart'; + +import '../resource/base.dart'; + +final _logger = Logger(); + +/// A Session allows a client to express that it is actively connected +/// and supports stopping actuating components when it's not. +class SessionsClient implements ResourceRPCClient { + static const sessionMetadataKey = 'viam-sid'; + static const unallowedMethods = [ + '/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo', + '/proto.rpc.webrtc.v1.SignalingService/Call', + '/proto.rpc.webrtc.v1.SignalingService/CallUpdate', + '/proto.rpc.webrtc.v1.SignalingService/OptionalWebRTCConfig', + '/proto.rpc.v1.AuthService/Authenticate', + '/viam.robot.v1.RobotService/ResourceNames', + '/viam.robot.v1.RobotService/ResourceRPCSubtypes', + '/viam.robot.v1.RobotService/StartSession', + '/viam.robot.v1.RobotService/SendSessionHeartbeat', + ]; + + @override + ClientChannelBase channel; + + @override + RobotServiceClient get client => RobotServiceClient(channel); + + String _currentId = ''; + final bool _enabled; + bool _supported = false; + late Duration _heartbeatInterval; + + SessionsClient(this.channel, this._enabled) { + metadata(); + } + + String metadata() { + if (!_enabled) return ''; + + if (_currentId != '') return _currentId; + + final request = StartSessionRequest(); + try { + final future = client.startSession(request); + + future.then((response) { + _supported = true; + _currentId = response.id; + + // We send heartbeats slightly faster than the interval window to + // ensure that we don't fall outside of it and expire the session. + _heartbeatInterval = Duration( + seconds: response.heartbeatWindow.seconds.toInt() ~/ 5, + microseconds: response.heartbeatWindow.nanos ~/ 5, + ); + + _heartbeatTask(); + + return _currentId; + }); + } catch (e) { + _logger.e('error starting session: $e'); + } + + return ''; + } + + void reset() { + _logger.d('resetting session'); + _currentId = ''; + _supported = false; + } + + Future _heartbeatTask() async { + while (_supported) { + _heartbeatTick(); + await Future.delayed(_heartbeatInterval); + } + } + + void _heartbeatTick() { + if (!_supported) return; + + final request = SendSessionHeartbeatRequest()..id = _currentId; + + try { + client.sendSessionHeartbeat(request); + } on GrpcError catch (e) { + _logger.d('Session terminated: $e'); + reset(); + } + } +} diff --git a/lib/src/rpc/dial.dart b/lib/src/rpc/dial.dart index 1da0c30fd8..95c36dffd9 100644 --- a/lib/src/rpc/dial.dart +++ b/lib/src/rpc/dial.dart @@ -6,6 +6,7 @@ import 'package:grpc/grpc.dart'; import 'package:grpc/grpc_connection_interface.dart'; import 'package:grpc/grpc_or_grpcweb.dart'; import 'package:logger/logger.dart'; +import 'package:viam_sdk/src/robot/sessions_client.dart'; import '../domain/web_rtc/web_rtc_client/web_rtc_client.dart'; import '../gen/proto/rpc/v1/auth.pb.dart' as pb; @@ -96,7 +97,7 @@ class DialWebRtcOptions { } /// Connect to a robot at the provided address with the given options -Future dial(String address, DialOptions? options) async { +Future dial(String address, DialOptions? options, String Function() sessionCallback) async { _logger.i('Connecting to Robot at $address'); final opts = options ?? DialOptions(); bool disableWebRtc = opts.webRtcOptions?.disable ?? false; @@ -104,26 +105,26 @@ Future dial(String address, DialOptions? options) async { disableWebRtc = true; } if (disableWebRtc) { - return _dialDirectGrpc(address, opts); + return _dialDirectGrpc(address, opts, sessionCallback); } - return _dialWebRtc(address, opts); + return _dialWebRtc(address, opts, sessionCallback); } -Future _dialDirectGrpc(String address, DialOptions options) async { +Future _dialDirectGrpc(String address, DialOptions options, String Function() sessionCallback) async { _logger.d('Dialing direct GRPC to $address'); if (options.credentials == null) { final host = _hostAndPort(address, options.insecure); - return ClientChannel(host.host, + return GrpcOrGrpcWebClientChannel.grpc(host.host, port: host.port, options: ChannelOptions( credentials: options.insecure ? const ChannelCredentials.insecure() : const ChannelCredentials.secure(), codecRegistry: CodecRegistry(codecs: const [GzipCodec(), IdentityCodec()]), )); } - return _authenticatedChannel(address, options); + return _authenticatedChannel(address, options, sessionCallback); } -Future _dialWebRtc(String address, DialOptions options) async { +Future _dialWebRtc(String address, DialOptions options, String Function() sessionCallback) async { _logger.d('Dialing WebRTC to $address'); if (options.authEntity.isNullOrEmpty) { if (options.externalAuthAddress.isNullOrEmpty) { @@ -137,7 +138,7 @@ Future _dialWebRtc(String address, DialOptions options) async final signalingServer = options.webRtcOptions?.signalingServerAddress ?? 'app.viam.com'; _logger.d('Connecting to signaling server: $signalingServer'); - final signalingChannel = await _dialDirectGrpc(signalingServer, options); + final signalingChannel = await _dialDirectGrpc(signalingServer, options, sessionCallback); _logger.d('Connected to signaling server: $signalingServer'); final signalingClient = SignalingServiceClient(signalingChannel, options: CallOptions(metadata: {'rpc-host': address})); WebRTCConfig config; @@ -306,9 +307,9 @@ Future _dialWebRtc(String address, DialOptions options) async await didConnect.future; } catch (error, st) { _logger.i('Could not connect via WebRTC, attempting direct gRPC connection', error, st); - return _dialDirectGrpc(address, options); + return _dialDirectGrpc(address, options, sessionCallback); } - return WebRtcClientChannel(peerConnection, dataChannel); + return WebRtcClientChannel(peerConnection, dataChannel, sessionCallback); } String _convertSDPtoJsonString(RTCSessionDescription? sdp) { @@ -321,19 +322,18 @@ String _encodeSDPJsonStringToBase64String(String sdp) { return base64.encode(bytes); } -Future _authenticatedChannel(String address, DialOptions options) async { +Future _authenticatedChannel(String address, DialOptions options, String Function() sessionsCallback) async { String accessToken = options.accessToken ?? ''; if (accessToken.isNotEmpty && options.externalAuthAddress.isNullOrEmpty && options.externalAuthToEntity.isNullOrEmpty) { _logger.d('Received pre-authenticated access token'); final addr = _hostAndPort(address, options.insecure); - return AuthenticatedChannel(addr.host, addr.port, accessToken, options.insecure); + return AuthenticatedChannel(addr.host, addr.port, accessToken, options.insecure, sessionsCallback); } final addr = _hostAndPort(options.externalAuthAddress ?? address, options.insecure); final authEntity = options.authEntity ?? address.replaceAll(RegExp(r'^(.*:\/\/)/'), ''); _logger.d('Authenticating to address: $addr, for entity: $authEntity'); - GrpcOrGrpcWebClientChannel authChannel = - GrpcOrGrpcWebClientChannel.toSingleEndpoint(host: addr.host, port: addr.port, transportSecure: !options.insecure); + var authChannel = GrpcOrGrpcWebClientChannel.toSingleEndpoint(host: addr.host, port: addr.port, transportSecure: !options.insecure); final authClient = AuthServiceClient(authChannel); final credentials = pb.Credentials(); if (options.credentials?.type != null) { @@ -358,7 +358,7 @@ Future _authenticatedChannel(String address, DialOpt if (options.externalAuthAddress.isNotNullNorEmpty && options.externalAuthToEntity.isNotNullNorEmpty) { final addr = _hostAndPort(options.externalAuthAddress!, options.insecure); _logger.d('Authenticating to external address: $addr, for entity: ${options.externalAuthToEntity}'); - authChannel = AuthenticatedChannel(addr.host, addr.port, accessToken, options.insecure); + authChannel = AuthenticatedChannel(addr.host, addr.port, accessToken, options.insecure, sessionsCallback); final extAuthClient = ExternalAuthServiceClient(authChannel); final toRequest = pb.AuthenticateToRequest(); if (options.externalAuthToEntity != null) { @@ -375,13 +375,14 @@ Future _authenticatedChannel(String address, DialOpt } final actual = _hostAndPort(address, options.insecure); - return AuthenticatedChannel(actual.host, actual.port, accessToken, options.insecure); + return AuthenticatedChannel(actual.host, actual.port, accessToken, options.insecure, sessionsCallback); } class AuthenticatedChannel extends GrpcOrGrpcWebClientChannel { final String accessToken; + final String Function()? _sessionId; - AuthenticatedChannel(String host, int port, this.accessToken, bool insecure) + AuthenticatedChannel(String host, int port, this.accessToken, bool insecure, [this._sessionId]) : super.toSingleEndpoint( host: host, port: port, @@ -390,6 +391,10 @@ class AuthenticatedChannel extends GrpcOrGrpcWebClientChannel { @override ClientCall createCall(ClientMethod method, Stream requests, CallOptions options) { + if (!SessionsClient.unallowedMethods.contains(method.path) && _sessionId != null) { + options = options.mergedWith(CallOptions(metadata: {SessionsClient.sessionMetadataKey: _sessionId!()})); + } + options = options.mergedWith(CallOptions(metadata: {'Authorization': 'Bearer $accessToken'})); return super.createCall(method, requests, options); } @@ -416,3 +421,18 @@ _HostAndPort _hostAndPort(String address, bool insecure) { } return _HostAndPort(host, port); } + +class ClientChannelWithSessions extends GrpcOrGrpcWebClientChannel { + final String Function() _sessionId; + + ClientChannelWithSessions.toSingleEndpoint(this._sessionId, {required super.host, required super.port, required super.transportSecure}) + : super.toSingleEndpoint(); + + @override + ClientCall createCall(ClientMethod method, Stream requests, CallOptions options) { + if (!SessionsClient.unallowedMethods.contains(method.path)) { + options = options.mergedWith(CallOptions(metadata: {SessionsClient.sessionMetadataKey: _sessionId()})); + } + return super.createCall(method, requests, options); + } +}