diff --git a/atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/ConnectionManager.kt b/atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/ConnectionManager.kt index f0a4f7394..ccfc61e49 100644 --- a/atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/ConnectionManager.kt +++ b/atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/ConnectionManager.kt @@ -35,6 +35,7 @@ import kotlin.jvm.Throws * @property castor The instance of the Castor interface used for working with DIDs. * @property pluto The instance of the Pluto interface used for storing messages and connection information. * @property mediationHandler The instance of the MediationHandler interface used for handling mediation. + * @property experimentLiveModeOptIn Flag to opt in or out of the experimental feature mediator live mode, using websockets. * @property pairings The mutable list of DIDPair representing the connections managed by the ConnectionManager. */ class ConnectionManager( @@ -44,6 +45,7 @@ class ConnectionManager( internal val mediationHandler: MediationHandler, private var pairings: MutableList, private val pollux: Pollux, + private val experimentLiveModeOptIn: Boolean = false, private val scope: CoroutineScope = CoroutineScope(Dispatchers.IO) ) : ConnectionsManager, DIDCommConnection { @@ -66,22 +68,23 @@ class ConnectionManager( // Resolve the DID document for the mediator val mediatorDidDoc = castor.resolveDID(currentMediatorDID.toString()) var serviceEndpoint: String? = null - - // Loop through the services in the DID document to find a WebSocket endpoint - mediatorDidDoc.services.forEach { - if (it.serviceEndpoint.uri.contains("wss://") || it.serviceEndpoint.uri.contains("ws://")) { - serviceEndpoint = it.serviceEndpoint.uri - return@forEach // Exit loop once the WebSocket endpoint is found + if (experimentLiveModeOptIn) { + // Loop through the services in the DID document to find a WebSocket endpoint + mediatorDidDoc.services.forEach { + if (it.serviceEndpoint.uri.contains("wss://") || it.serviceEndpoint.uri.contains("ws://")) { + serviceEndpoint = it.serviceEndpoint.uri + return@forEach // Exit loop once the WebSocket endpoint is found + } } - } - // If a WebSocket service endpoint is found - serviceEndpoint?.let { serviceEndpointUrl -> - // Listen for unread messages on the WebSocket endpoint - mediationHandler.listenUnreadMessages( - serviceEndpointUrl - ) { arrayMessages -> - processMessages(arrayMessages) + // If a WebSocket service endpoint is found + serviceEndpoint?.let { serviceEndpointUrl -> + // Listen for unread messages on the WebSocket endpoint + mediationHandler.listenUnreadMessages( + serviceEndpointUrl + ) { arrayMessages -> + processMessages(arrayMessages) + } } } diff --git a/atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/PrismAgent.kt b/atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/PrismAgent.kt index ad5d23105..488ae5b42 100644 --- a/atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/PrismAgent.kt +++ b/atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/PrismAgent.kt @@ -47,6 +47,7 @@ import io.iohk.atala.prism.walletsdk.logger.PrismLoggerImpl import io.iohk.atala.prism.walletsdk.pollux.models.AnonCredential import io.iohk.atala.prism.walletsdk.pollux.models.CredentialRequestMeta import io.iohk.atala.prism.walletsdk.pollux.models.JWTCredential +import io.iohk.atala.prism.walletsdk.prismagent.helpers.AgentOptions import io.iohk.atala.prism.walletsdk.prismagent.mediation.BasicMediatorHandler import io.iohk.atala.prism.walletsdk.prismagent.mediation.MediationHandler import io.iohk.atala.prism.walletsdk.prismagent.protocols.ProtocolType @@ -119,6 +120,7 @@ class PrismAgent { private val api: Api private var connectionManager: ConnectionManager private var logger: PrismLogger + private val agentOptions: AgentOptions /** * Initializes the PrismAgent with the given dependencies. @@ -133,6 +135,7 @@ class PrismAgent { * @param api An optional Api instance used by the PrismAgent if provided, otherwise a default ApiImpl will be used. * @param logger An optional PrismLogger instance used by the PrismAgent if provided, otherwise a PrismLoggerImpl with * LogComponent.PRISM_AGENT will be used. + * @param agentOptions Options to configure certain features with in the prism agent. */ @JvmOverloads constructor( @@ -144,7 +147,8 @@ class PrismAgent { connectionManager: ConnectionManager, seed: Seed?, api: Api?, - logger: PrismLogger = PrismLoggerImpl(LogComponent.PRISM_AGENT) + logger: PrismLogger = PrismLoggerImpl(LogComponent.PRISM_AGENT), + agentOptions: AgentOptions ) { prismAgentScope.launch { flowState.emit(State.STOPPED) @@ -170,6 +174,7 @@ class PrismAgent { } ) this.logger = logger + this.agentOptions = agentOptions } /** @@ -184,6 +189,7 @@ class PrismAgent { * @param api The instance of the API. Default is null. * @param mediatorHandler The mediator handler. * @param logger The logger for PrismAgent. Default is PrismLoggerImpl with LogComponent.PRISM_AGENT. + * @param agentOptions Options to configure certain features with in the prism agent. */ @JvmOverloads constructor( @@ -195,7 +201,8 @@ class PrismAgent { seed: Seed? = null, api: Api? = null, mediatorHandler: MediationHandler, - logger: PrismLogger = PrismLoggerImpl(LogComponent.PRISM_AGENT) + logger: PrismLogger = PrismLoggerImpl(LogComponent.PRISM_AGENT), + agentOptions: AgentOptions ) { prismAgentScope.launch { flowState.emit(State.STOPPED) @@ -220,9 +227,10 @@ class PrismAgent { } ) this.logger = logger + this.agentOptions = agentOptions // Pairing will be removed in the future this.connectionManager = - ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf(), pollux) + ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf(), pollux, agentOptions.experiments.liveMode) } init { @@ -462,7 +470,7 @@ class PrismAgent { fun setupMediatorHandler(mediatorHandler: MediationHandler) { stop() this.connectionManager = - ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf(), pollux) + ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf(), pollux, agentOptions.experiments.liveMode) } /** diff --git a/atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/helpers/AgentOptions.kt b/atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/helpers/AgentOptions.kt new file mode 100644 index 000000000..4973bdbac --- /dev/null +++ b/atala-prism-sdk/src/commonMain/kotlin/io/iohk/atala/prism/walletsdk/prismagent/helpers/AgentOptions.kt @@ -0,0 +1,5 @@ +package io.iohk.atala.prism.walletsdk.prismagent.helpers + +data class AgentOptions(val experiments: Experiments = Experiments()) + +data class Experiments(val liveMode: Boolean = false) diff --git a/atala-prism-sdk/src/commonTest/kotlin/io/iohk/atala/prism/walletsdk/prismagent/ConnectionManagerTest.kt b/atala-prism-sdk/src/commonTest/kotlin/io/iohk/atala/prism/walletsdk/prismagent/ConnectionManagerTest.kt index e8e52b5d1..9a32042a5 100644 --- a/atala-prism-sdk/src/commonTest/kotlin/io/iohk/atala/prism/walletsdk/prismagent/ConnectionManagerTest.kt +++ b/atala-prism-sdk/src/commonTest/kotlin/io/iohk/atala/prism/walletsdk/prismagent/ConnectionManagerTest.kt @@ -68,6 +68,7 @@ class ConnectionManagerTest { mediationHandler = basicMediatorHandlerMock, pairings = mutableListOf(), pollux = polluxMock, + experimentLiveModeOptIn = true, scope = CoroutineScope(testDispatcher) ) } @@ -123,6 +124,103 @@ class ConnectionManagerTest { verify(basicMediatorHandlerMock).listenUnreadMessages(any(), any()) } + @Test + fun testStartFetchingMessages_whenServiceEndpointContainsWSSButOptInLiveModeFalse_thenRegunarlApi() = runTest { + connectionManager = ConnectionManager( + mercury = mercuryMock, + castor = castorMock, + pluto = plutoMock, + mediationHandler = basicMediatorHandlerMock, + pairings = mutableListOf(), + pollux = polluxMock, + experimentLiveModeOptIn = false, + scope = CoroutineScope(testDispatcher) + ) + + `when`(basicMediatorHandlerMock.mediatorDID) + .thenReturn(DID("did:prism:b6c0c33d701ac1b9a262a14454d1bbde3d127d697a76950963c5fd930605:Cj8KPRI7CgdtYXN0ZXIwEAFKLgoJc2VmsxEiECSTjyV7sUfCr_ArpN9rvCwR9fRMAhcsr_S7ZRiJk4p5k")) + + val vmAuthentication = DIDDocument.VerificationMethod( + id = DIDUrl(DID("2", "1", "0")), + controller = DID("2", "2", "0"), + type = Curve.ED25519.value, + publicKeyJwk = mapOf("crv" to Curve.ED25519.value, "x" to "") + ) + + val vmKeyAgreement = DIDDocument.VerificationMethod( + id = DIDUrl(DID("3", "1", "0")), + controller = DID("3", "2", "0"), + type = Curve.X25519.value, + publicKeyJwk = mapOf("crv" to Curve.X25519.value, "x" to "") + ) + + val vmService = DIDDocument.Service( + id = UUID.randomUUID().toString(), + type = emptyArray(), + serviceEndpoint = DIDDocument.ServiceEndpoint( + uri = "wss://serviceEndpoint" + ) + ) + + val didDoc = DIDDocument( + id = DID("did:prism:asdfasdf"), + coreProperties = arrayOf( + DIDDocument.Authentication( + urls = emptyArray(), + verificationMethods = arrayOf(vmAuthentication) + ), + DIDDocument.KeyAgreement( + urls = emptyArray(), + verificationMethods = arrayOf(vmKeyAgreement) + ), + DIDDocument.Services( + values = arrayOf(vmService) + ) + ) + ) + + `when`(castorMock.resolveDID(any())).thenReturn(didDoc) + val messages = arrayOf(Pair("1234", Message(piuri = "", body = ""))) + `when`(basicMediatorHandlerMock.pickupUnreadMessages(any())).thenReturn( + flow { + emit( + messages + ) + } + ) + val attachments: Array = + arrayOf( + AttachmentDescriptor( + mediaType = "application/json", + format = CredentialType.JWT.type, + data = AttachmentBase64(base64 = "asdfasdfasdfasdfasdfasdfasdfasdfasdf".base64UrlEncoded) + ) + ) + val listMessages = listOf( + Message( + piuri = ProtocolType.DidcommconnectionRequest.value, + body = "" + ), + Message( + piuri = ProtocolType.DidcommIssueCredential.value, + thid = UUID.randomUUID().toString(), + from = DID("did:peer:asdf897a6sdf"), + to = DID("did:peer:f706sg678ha"), + attachments = attachments, + body = """{}""" + ) + ) + val messageList: Flow> = flow { + emit(listMessages) + } + `when`(plutoMock.getAllMessages()).thenReturn(messageList) + + connectionManager.startFetchingMessages() + assertNotNull(connectionManager.fetchingMessagesJob) + verify(basicMediatorHandlerMock).pickupUnreadMessages(10) + verify(basicMediatorHandlerMock).registerMessagesAsRead(arrayOf("1234")) + } + @Test fun testStartFetchingMessages_whenServiceEndpointNotContainsWSS_thenUseAPIRequest() = runBlockingTest { diff --git a/atala-prism-sdk/src/commonTest/kotlin/io/iohk/atala/prism/walletsdk/prismagent/PrismAgentTests.kt b/atala-prism-sdk/src/commonTest/kotlin/io/iohk/atala/prism/walletsdk/prismagent/PrismAgentTests.kt index 99dca5cc5..def1787b6 100644 --- a/atala-prism-sdk/src/commonTest/kotlin/io/iohk/atala/prism/walletsdk/prismagent/PrismAgentTests.kt +++ b/atala-prism-sdk/src/commonTest/kotlin/io/iohk/atala/prism/walletsdk/prismagent/PrismAgentTests.kt @@ -22,6 +22,7 @@ import io.iohk.atala.prism.walletsdk.logger.PrismLoggerMock import io.iohk.atala.prism.walletsdk.mercury.ApiMock import io.iohk.atala.prism.walletsdk.pollux.PolluxImpl import io.iohk.atala.prism.walletsdk.pollux.models.CredentialRequestMeta +import io.iohk.atala.prism.walletsdk.prismagent.helpers.AgentOptions import io.iohk.atala.prism.walletsdk.prismagent.protocols.ProtocolType import io.iohk.atala.prism.walletsdk.prismagent.protocols.issueCredential.CredentialPreview import io.iohk.atala.prism.walletsdk.prismagent.protocols.issueCredential.IssueCredential @@ -61,7 +62,7 @@ class PrismAgentTests { polluxMock = PolluxMock() mediationHandlerMock = MediationHandlerMock() // Pairing will be removed in the future - connectionManager = ConnectionManager(mercuryMock, castorMock, plutoMock, mediationHandlerMock, mutableListOf(), polluxMock) + connectionManager = ConnectionManager(mercuryMock, castorMock, plutoMock, mediationHandlerMock, mutableListOf(), polluxMock, true) json = Json { ignoreUnknownKeys = true prettyPrint = true @@ -83,7 +84,8 @@ class PrismAgentTests { connectionManager = connectionManager, seed = seed, api = null, - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) plutoMock.getPrismLastKeyPathIndexReturn = flow { emit(0) } val newDID = agent.createNewPrismDID() @@ -106,7 +108,8 @@ class PrismAgentTests { connectionManager, null, null, - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) val newDID = agent.createNewPeerDID(services = emptyArray(), updateMediator = false) @@ -129,7 +132,8 @@ class PrismAgentTests { connectionManager, null, null, - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) val seAccept = arrayOf("someAccepts") @@ -163,7 +167,8 @@ class PrismAgentTests { connectionManager = connectionManager, seed = null, api = ApiMock(HttpStatusCode.OK, "{\"success\":\"true\"}"), - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) val invitationString = """ { @@ -189,7 +194,8 @@ class PrismAgentTests { connectionManager = connectionManager, seed = null, api = api, - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) val invitationString = """ { @@ -215,7 +221,8 @@ class PrismAgentTests { connectionManager = connectionManager, seed = null, api = ApiMock(HttpStatusCode.OK, "{\"success\":\"true\"}"), - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) val invitationString = """ { @@ -240,7 +247,8 @@ class PrismAgentTests { connectionManager = connectionManager, seed = null, api = null, - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) plutoMock.getDIDPrivateKeysReturn = flow { emit(listOf(null)) } @@ -264,7 +272,8 @@ class PrismAgentTests { connectionManager = connectionManager, seed = null, api = null, - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) val privateKeys = listOf( @@ -293,7 +302,8 @@ class PrismAgentTests { connectionManager = connectionManager, seed = null, api = null, - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) val invitationString = """ @@ -338,7 +348,8 @@ class PrismAgentTests { connectionManager = connectionManager, seed = null, api = null, - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) val invitationString = """ @@ -368,7 +379,8 @@ class PrismAgentTests { connectionManager = connectionManager, seed = null, api = null, - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) assertEquals(PrismAgent.State.STOPPED, agent.state) agent.start() @@ -386,7 +398,8 @@ class PrismAgentTests { connectionManager = connectionManager, seed = null, api = null, - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) agent.stop() assertEquals(PrismAgent.State.STOPPED, agent.state) @@ -405,7 +418,8 @@ class PrismAgentTests { connectionManager = connectionManager, seed = null, api = null, - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) val x = agent.parseInvitation(oob) assert(x is OutOfBandInvitation) @@ -434,7 +448,8 @@ class PrismAgentTests { connectionManager = connectionManager, seed = null, api = ApiMock(HttpStatusCode.OK, "{\"success\":\"true\"}"), - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) val attachmentDescriptor = @@ -494,7 +509,8 @@ class PrismAgentTests { connectionManager = connectionManager, seed = null, api = ApiMock(HttpStatusCode.OK, "{\"success\":\"true\"}"), - logger = PrismLoggerMock() + logger = PrismLoggerMock(), + agentOptions = AgentOptions() ) val attachmentDescriptor = diff --git a/sampleapp/src/main/java/io/iohk/atala/prism/sampleapp/Sdk.kt b/sampleapp/src/main/java/io/iohk/atala/prism/sampleapp/Sdk.kt index 123c8bb37..9b4a080e6 100644 --- a/sampleapp/src/main/java/io/iohk/atala/prism/sampleapp/Sdk.kt +++ b/sampleapp/src/main/java/io/iohk/atala/prism/sampleapp/Sdk.kt @@ -21,6 +21,8 @@ import io.iohk.atala.prism.walletsdk.pluto.data.DbConnection import io.iohk.atala.prism.walletsdk.pollux.PolluxImpl import io.iohk.atala.prism.walletsdk.prismagent.PrismAgent import io.iohk.atala.prism.walletsdk.prismagent.PrismAgentError +import io.iohk.atala.prism.walletsdk.prismagent.helpers.AgentOptions +import io.iohk.atala.prism.walletsdk.prismagent.helpers.Experiments import io.iohk.atala.prism.walletsdk.prismagent.mediation.BasicMediatorHandler import io.iohk.atala.prism.walletsdk.prismagent.mediation.MediationHandler import kotlinx.coroutines.CoroutineScope @@ -137,7 +139,8 @@ class Sdk { mercury = mercury, pollux = pollux, seed = seed, - mediatorHandler = handler + mediatorHandler = handler, + agentOptions = AgentOptions(Experiments(liveMode = false)) ) }