Skip to content

Commit

Permalink
add sessions functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
clintpurser committed Jul 14, 2023
1 parent 537f048 commit bc2022e
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 26 deletions.
5 changes: 1 addition & 4 deletions lib/src/di/di_clients/di_web_rtc_client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,5 @@ Future<WebRtcClientChannel> _getWebRtcClient(
final webRtcPeerConnection = WebRtcPeerConnection(webRtcDirectDataSource);
await webRtcPeerConnection.createConnection();

return WebRtcClientChannel(
webRtcPeerConnection.peerConnection,
webRtcPeerConnection.dataChannel,
);
return WebRtcClientChannel(webRtcPeerConnection.peerConnection, webRtcPeerConnection.dataChannel, () => '');
}
13 changes: 12 additions & 1 deletion lib/src/domain/web_rtc/web_rtc_client/web_rtc_client.dart
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import 'package:flutter_webrtc/flutter_webrtc.dart';
import 'package:grpc/grpc.dart';
import 'package:grpc/grpc_connection_interface.dart';

import '../../../robot/sessions_client.dart';
import 'web_rtc_client_connection.dart';

class WebRtcClientChannel extends ClientChannelBase {
final RTCPeerConnection rtcPeerConnection;
final RTCDataChannel dataChannel;
final String Function() _sessionId;
final List<Function(RTCDataChannelMessage data)> onMessageListeners = [];

WebRtcClientChannel(this.rtcPeerConnection, this.dataChannel) {
WebRtcClientChannel(this.rtcPeerConnection, this.dataChannel, this._sessionId) {
dataChannel.onMessage = (data) {
onMessageListeners.forEach((listener) => listener(data));
};
Expand All @@ -22,4 +25,12 @@ class WebRtcClientChannel extends ClientChannelBase {

@override
ClientConnection createConnection() => WebRtcClientConnection(this);

@override
ClientCall<Q, R> createCall<Q, R>(ClientMethod<Q, R> method, Stream<Q> requests, CallOptions options) {
if (!SessionsClient.unallowedMethods.contains(method.path)) {
options = options.mergedWith(CallOptions(metadata: {SessionsClient.sessionMetadataKey: _sessionId()}));
}
return super.createCall(method, requests, options);
}
}
17 changes: 13 additions & 4 deletions lib/src/robot/client.dart
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import 'dart:async';

import 'package:grpc/grpc_connection_interface.dart';
import 'package:logger/logger.dart';
import 'package:viam_sdk/src/robot/sessions_client.dart';

import '../domain/web_rtc/web_rtc_client/web_rtc_client.dart';
import '../gen/common/v1/common.pb.dart';
Expand All @@ -26,6 +27,9 @@ class RobotClientOptions {
/// The frequency (in seconds) at which to attempt to reconnect a disconnected robot. 0 (zero) signifies no reconnection attempts
final attemptReconnectInterval = 1;

/// Whether sessions are enabled
final enableSessions = true;

RobotClientOptions() : dialOptions = DialOptions();

/// Convenience initializer for creating options with specified [DialOptions]
Expand All @@ -43,9 +47,10 @@ class RobotClientOptions {
class RobotClient {
bool _connected = true;
late String _address;
late DialOptions _options;
late RobotClientOptions _options;
late ClientChannelBase _channel;
late RobotServiceClient _client;
late SessionsClient _sessionsClient;
List<ResourceName> resourceNames = [];
ResourceManager _manager = ResourceManager();
late final StreamManager _streamManager;
Expand All @@ -56,8 +61,9 @@ class RobotClient {
static Future<RobotClient> atAddress(String url, RobotClientOptions options) async {
final client = RobotClient._();
client._address = url;
client._options = options.dialOptions;
client._channel = await dial(url, options.dialOptions);
client._options = options;
client._channel = await dial(url, options.dialOptions, () => client._sessionsClient.metadata());
client._sessionsClient = SessionsClient(client._channel, options.enableSessions);
client._client = RobotServiceClient(client._channel);
client._streamManager = StreamManager(client._channel as WebRtcClientChannel);
await client.refresh();
Expand Down Expand Up @@ -145,19 +151,22 @@ class RobotClient {
.d('Attempting to reconnect to the robot at $_address every $reconnectInterval ${(reconnectInterval > 1) ? 'seconds' : 'second'}');

while (!_connected) {
_sessionsClient.reset();
try {
final channel = await dial(_address, _options);
final channel = await dial(_address, _options.dialOptions, () => _sessionsClient.metadata());
final client = RobotServiceClient(channel);
await client.resourceNames(ResourceNamesRequest());

_channel = channel;
_streamManager.channel = _channel as WebRtcClientChannel;
_client = client;
_sessionsClient = SessionsClient(_channel, _options.enableSessions);
await refresh();
_connected = true;
_logger.d('Successfully reconnected robot');
} catch (e) {
await _channel.shutdown();
_sessionsClient.reset();
_logger.d('Failed to reconnect, trying again in $reconnectInterval ${(reconnectInterval > 1) ? 'seconds' : 'second'}');
await Future.delayed(Duration(seconds: reconnectInterval));
}
Expand Down
98 changes: 98 additions & 0 deletions lib/src/robot/sessions_client.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import 'dart:async';

import 'package:grpc/grpc_connection_interface.dart';
import 'package:logger/logger.dart';
import 'package:viam_sdk/protos/robot/robot.dart';

import '../resource/base.dart';

final _logger = Logger();

/// A Session allows a client to express that it is actively connected
/// and supports stopping actuating components when it's not.
class SessionsClient implements ResourceRPCClient {
static const sessionMetadataKey = 'viam-sid';
static const unallowedMethods = [
'/grpc.reflection.v1alpha.ServerReflection/ServerReflectionInfo',
'/proto.rpc.webrtc.v1.SignalingService/Call',
'/proto.rpc.webrtc.v1.SignalingService/CallUpdate',
'/proto.rpc.webrtc.v1.SignalingService/OptionalWebRTCConfig',
'/proto.rpc.v1.AuthService/Authenticate',
'/viam.robot.v1.RobotService/ResourceNames',
'/viam.robot.v1.RobotService/ResourceRPCSubtypes',
'/viam.robot.v1.RobotService/StartSession',
'/viam.robot.v1.RobotService/SendSessionHeartbeat',
];

@override
ClientChannelBase channel;

@override
RobotServiceClient get client => RobotServiceClient(channel);

String _currentId = '';
final bool _enabled;
bool _supported = false;
late Duration _heartbeatInterval;

SessionsClient(this.channel, this._enabled) {
metadata();
}

String metadata() {
if (!_enabled) return '';

if (_currentId != '') return _currentId;

final request = StartSessionRequest();
try {
final future = client.startSession(request);

future.then((response) {
_supported = true;
_currentId = response.id;

// We send heartbeats slightly faster than the interval window to
// ensure that we don't fall outside of it and expire the session.
_heartbeatInterval = Duration(
seconds: response.heartbeatWindow.seconds.toInt() ~/ 5,
microseconds: response.heartbeatWindow.nanos ~/ 5,
);

_heartbeatTask();

return _currentId;
});
} catch (e) {
_logger.e('error starting session: $e');
}

return '';
}

void reset() {
_logger.d('resetting session');
_currentId = '';
_supported = false;
}

Future<void> _heartbeatTask() async {
while (_supported) {
_heartbeatTick();
await Future.delayed(_heartbeatInterval);
}
}

void _heartbeatTick() {
if (!_supported) return;

final request = SendSessionHeartbeatRequest()..id = _currentId;

try {
client.sendSessionHeartbeat(request);
} on GrpcError catch (e) {
_logger.d('Session terminated: $e');
reset();
}
}
}
54 changes: 37 additions & 17 deletions lib/src/rpc/dial.dart
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import 'package:grpc/grpc.dart';
import 'package:grpc/grpc_connection_interface.dart';
import 'package:grpc/grpc_or_grpcweb.dart';
import 'package:logger/logger.dart';
import 'package:viam_sdk/src/robot/sessions_client.dart';

import '../domain/web_rtc/web_rtc_client/web_rtc_client.dart';
import '../gen/proto/rpc/v1/auth.pb.dart' as pb;
Expand Down Expand Up @@ -96,34 +97,34 @@ class DialWebRtcOptions {
}

/// Connect to a robot at the provided address with the given options
Future<ClientChannelBase> dial(String address, DialOptions? options) async {
Future<ClientChannelBase> dial(String address, DialOptions? options, String Function() sessionCallback) async {
_logger.i('Connecting to Robot at $address');
final opts = options ?? DialOptions();
bool disableWebRtc = opts.webRtcOptions?.disable ?? false;
if (address.contains('.local.') || address.contains('localhost')) {
disableWebRtc = true;
}
if (disableWebRtc) {
return _dialDirectGrpc(address, opts);
return _dialDirectGrpc(address, opts, sessionCallback);
}
return _dialWebRtc(address, opts);
return _dialWebRtc(address, opts, sessionCallback);
}

Future<ClientChannelBase> _dialDirectGrpc(String address, DialOptions options) async {
Future<ClientChannelBase> _dialDirectGrpc(String address, DialOptions options, String Function() sessionCallback) async {
_logger.d('Dialing direct GRPC to $address');
if (options.credentials == null) {
final host = _hostAndPort(address, options.insecure);
return ClientChannel(host.host,
return GrpcOrGrpcWebClientChannel.grpc(host.host,
port: host.port,
options: ChannelOptions(
credentials: options.insecure ? const ChannelCredentials.insecure() : const ChannelCredentials.secure(),
codecRegistry: CodecRegistry(codecs: const [GzipCodec(), IdentityCodec()]),
));
}
return _authenticatedChannel(address, options);
return _authenticatedChannel(address, options, sessionCallback);
}

Future<ClientChannelBase> _dialWebRtc(String address, DialOptions options) async {
Future<ClientChannelBase> _dialWebRtc(String address, DialOptions options, String Function() sessionCallback) async {
_logger.d('Dialing WebRTC to $address');
if (options.authEntity.isNullOrEmpty) {
if (options.externalAuthAddress.isNullOrEmpty) {
Expand All @@ -137,7 +138,7 @@ Future<ClientChannelBase> _dialWebRtc(String address, DialOptions options) async

final signalingServer = options.webRtcOptions?.signalingServerAddress ?? 'app.viam.com';
_logger.d('Connecting to signaling server: $signalingServer');
final signalingChannel = await _dialDirectGrpc(signalingServer, options);
final signalingChannel = await _dialDirectGrpc(signalingServer, options, sessionCallback);
_logger.d('Connected to signaling server: $signalingServer');
final signalingClient = SignalingServiceClient(signalingChannel, options: CallOptions(metadata: {'rpc-host': address}));
WebRTCConfig config;
Expand Down Expand Up @@ -306,9 +307,9 @@ Future<ClientChannelBase> _dialWebRtc(String address, DialOptions options) async
await didConnect.future;
} catch (error, st) {
_logger.i('Could not connect via WebRTC, attempting direct gRPC connection', error, st);
return _dialDirectGrpc(address, options);
return _dialDirectGrpc(address, options, sessionCallback);
}
return WebRtcClientChannel(peerConnection, dataChannel);
return WebRtcClientChannel(peerConnection, dataChannel, sessionCallback);
}

String _convertSDPtoJsonString(RTCSessionDescription? sdp) {
Expand All @@ -321,19 +322,18 @@ String _encodeSDPJsonStringToBase64String(String sdp) {
return base64.encode(bytes);
}

Future<GrpcOrGrpcWebClientChannel> _authenticatedChannel(String address, DialOptions options) async {
Future<AuthenticatedChannel> _authenticatedChannel(String address, DialOptions options, String Function() sessionsCallback) async {
String accessToken = options.accessToken ?? '';
if (accessToken.isNotEmpty && options.externalAuthAddress.isNullOrEmpty && options.externalAuthToEntity.isNullOrEmpty) {
_logger.d('Received pre-authenticated access token');
final addr = _hostAndPort(address, options.insecure);
return AuthenticatedChannel(addr.host, addr.port, accessToken, options.insecure);
return AuthenticatedChannel(addr.host, addr.port, accessToken, options.insecure, sessionsCallback);
}

final addr = _hostAndPort(options.externalAuthAddress ?? address, options.insecure);
final authEntity = options.authEntity ?? address.replaceAll(RegExp(r'^(.*:\/\/)/'), '');
_logger.d('Authenticating to address: $addr, for entity: $authEntity');
GrpcOrGrpcWebClientChannel authChannel =
GrpcOrGrpcWebClientChannel.toSingleEndpoint(host: addr.host, port: addr.port, transportSecure: !options.insecure);
var authChannel = GrpcOrGrpcWebClientChannel.toSingleEndpoint(host: addr.host, port: addr.port, transportSecure: !options.insecure);
final authClient = AuthServiceClient(authChannel);
final credentials = pb.Credentials();
if (options.credentials?.type != null) {
Expand All @@ -358,7 +358,7 @@ Future<GrpcOrGrpcWebClientChannel> _authenticatedChannel(String address, DialOpt
if (options.externalAuthAddress.isNotNullNorEmpty && options.externalAuthToEntity.isNotNullNorEmpty) {
final addr = _hostAndPort(options.externalAuthAddress!, options.insecure);
_logger.d('Authenticating to external address: $addr, for entity: ${options.externalAuthToEntity}');
authChannel = AuthenticatedChannel(addr.host, addr.port, accessToken, options.insecure);
authChannel = AuthenticatedChannel(addr.host, addr.port, accessToken, options.insecure, sessionsCallback);
final extAuthClient = ExternalAuthServiceClient(authChannel);
final toRequest = pb.AuthenticateToRequest();
if (options.externalAuthToEntity != null) {
Expand All @@ -375,13 +375,14 @@ Future<GrpcOrGrpcWebClientChannel> _authenticatedChannel(String address, DialOpt
}

final actual = _hostAndPort(address, options.insecure);
return AuthenticatedChannel(actual.host, actual.port, accessToken, options.insecure);
return AuthenticatedChannel(actual.host, actual.port, accessToken, options.insecure, sessionsCallback);
}

class AuthenticatedChannel extends GrpcOrGrpcWebClientChannel {
final String accessToken;
final String Function()? _sessionId;

AuthenticatedChannel(String host, int port, this.accessToken, bool insecure)
AuthenticatedChannel(String host, int port, this.accessToken, bool insecure, [this._sessionId])
: super.toSingleEndpoint(
host: host,
port: port,
Expand All @@ -390,6 +391,10 @@ class AuthenticatedChannel extends GrpcOrGrpcWebClientChannel {

@override
ClientCall<Q, R> createCall<Q, R>(ClientMethod<Q, R> method, Stream<Q> requests, CallOptions options) {
if (!SessionsClient.unallowedMethods.contains(method.path) && _sessionId != null) {
options = options.mergedWith(CallOptions(metadata: {SessionsClient.sessionMetadataKey: _sessionId!()}));
}

options = options.mergedWith(CallOptions(metadata: {'Authorization': 'Bearer $accessToken'}));
return super.createCall(method, requests, options);
}
Expand All @@ -416,3 +421,18 @@ _HostAndPort _hostAndPort(String address, bool insecure) {
}
return _HostAndPort(host, port);
}

class ClientChannelWithSessions extends GrpcOrGrpcWebClientChannel {
final String Function() _sessionId;

ClientChannelWithSessions.toSingleEndpoint(this._sessionId, {required super.host, required super.port, required super.transportSecure})
: super.toSingleEndpoint();

@override
ClientCall<Q, R> createCall<Q, R>(ClientMethod<Q, R> method, Stream<Q> requests, CallOptions options) {
if (!SessionsClient.unallowedMethods.contains(method.path)) {
options = options.mergedWith(CallOptions(metadata: {SessionsClient.sessionMetadataKey: _sessionId()}));
}
return super.createCall(method, requests, options);
}
}

0 comments on commit bc2022e

Please sign in to comment.