Skip to content

Commit

Permalink
feat: experimental opt-in for mediator live mode (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
cristianIOHK authored Apr 25, 2024
1 parent c54b5b9 commit 54f23cc
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import kotlin.jvm.Throws
* @property castor The instance of the Castor interface used for working with DIDs.
* @property pluto The instance of the Pluto interface used for storing messages and connection information.
* @property mediationHandler The instance of the MediationHandler interface used for handling mediation.
* @property experimentLiveModeOptIn Flag to opt in or out of the experimental feature mediator live mode, using websockets.
* @property pairings The mutable list of DIDPair representing the connections managed by the ConnectionManager.
*/
class ConnectionManager(
Expand All @@ -44,6 +45,7 @@ class ConnectionManager(
internal val mediationHandler: MediationHandler,
private var pairings: MutableList<DIDPair>,
private val pollux: Pollux,
private val experimentLiveModeOptIn: Boolean = false,
private val scope: CoroutineScope = CoroutineScope(Dispatchers.IO)
) : ConnectionsManager, DIDCommConnection {

Expand All @@ -66,22 +68,23 @@ class ConnectionManager(
// Resolve the DID document for the mediator
val mediatorDidDoc = castor.resolveDID(currentMediatorDID.toString())
var serviceEndpoint: String? = null

// Loop through the services in the DID document to find a WebSocket endpoint
mediatorDidDoc.services.forEach {
if (it.serviceEndpoint.uri.contains("wss://") || it.serviceEndpoint.uri.contains("ws://")) {
serviceEndpoint = it.serviceEndpoint.uri
return@forEach // Exit loop once the WebSocket endpoint is found
if (experimentLiveModeOptIn) {
// Loop through the services in the DID document to find a WebSocket endpoint
mediatorDidDoc.services.forEach {
if (it.serviceEndpoint.uri.contains("wss://") || it.serviceEndpoint.uri.contains("ws://")) {
serviceEndpoint = it.serviceEndpoint.uri
return@forEach // Exit loop once the WebSocket endpoint is found
}
}
}

// If a WebSocket service endpoint is found
serviceEndpoint?.let { serviceEndpointUrl ->
// Listen for unread messages on the WebSocket endpoint
mediationHandler.listenUnreadMessages(
serviceEndpointUrl
) { arrayMessages ->
processMessages(arrayMessages)
// If a WebSocket service endpoint is found
serviceEndpoint?.let { serviceEndpointUrl ->
// Listen for unread messages on the WebSocket endpoint
mediationHandler.listenUnreadMessages(
serviceEndpointUrl
) { arrayMessages ->
processMessages(arrayMessages)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ import io.iohk.atala.prism.walletsdk.logger.PrismLoggerImpl
import io.iohk.atala.prism.walletsdk.pollux.models.AnonCredential
import io.iohk.atala.prism.walletsdk.pollux.models.CredentialRequestMeta
import io.iohk.atala.prism.walletsdk.pollux.models.JWTCredential
import io.iohk.atala.prism.walletsdk.prismagent.helpers.AgentOptions
import io.iohk.atala.prism.walletsdk.prismagent.mediation.BasicMediatorHandler
import io.iohk.atala.prism.walletsdk.prismagent.mediation.MediationHandler
import io.iohk.atala.prism.walletsdk.prismagent.protocols.ProtocolType
Expand Down Expand Up @@ -119,6 +120,7 @@ class PrismAgent {
private val api: Api
private var connectionManager: ConnectionManager
private var logger: PrismLogger
private val agentOptions: AgentOptions

/**
* Initializes the PrismAgent with the given dependencies.
Expand All @@ -133,6 +135,7 @@ class PrismAgent {
* @param api An optional Api instance used by the PrismAgent if provided, otherwise a default ApiImpl will be used.
* @param logger An optional PrismLogger instance used by the PrismAgent if provided, otherwise a PrismLoggerImpl with
* LogComponent.PRISM_AGENT will be used.
* @param agentOptions Options to configure certain features with in the prism agent.
*/
@JvmOverloads
constructor(
Expand All @@ -144,7 +147,8 @@ class PrismAgent {
connectionManager: ConnectionManager,
seed: Seed?,
api: Api?,
logger: PrismLogger = PrismLoggerImpl(LogComponent.PRISM_AGENT)
logger: PrismLogger = PrismLoggerImpl(LogComponent.PRISM_AGENT),
agentOptions: AgentOptions
) {
prismAgentScope.launch {
flowState.emit(State.STOPPED)
Expand All @@ -170,6 +174,7 @@ class PrismAgent {
}
)
this.logger = logger
this.agentOptions = agentOptions
}

/**
Expand All @@ -184,6 +189,7 @@ class PrismAgent {
* @param api The instance of the API. Default is null.
* @param mediatorHandler The mediator handler.
* @param logger The logger for PrismAgent. Default is PrismLoggerImpl with LogComponent.PRISM_AGENT.
* @param agentOptions Options to configure certain features with in the prism agent.
*/
@JvmOverloads
constructor(
Expand All @@ -195,7 +201,8 @@ class PrismAgent {
seed: Seed? = null,
api: Api? = null,
mediatorHandler: MediationHandler,
logger: PrismLogger = PrismLoggerImpl(LogComponent.PRISM_AGENT)
logger: PrismLogger = PrismLoggerImpl(LogComponent.PRISM_AGENT),
agentOptions: AgentOptions
) {
prismAgentScope.launch {
flowState.emit(State.STOPPED)
Expand All @@ -220,9 +227,10 @@ class PrismAgent {
}
)
this.logger = logger
this.agentOptions = agentOptions
// Pairing will be removed in the future
this.connectionManager =
ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf(), pollux)
ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf(), pollux, agentOptions.experiments.liveMode)
}

init {
Expand Down Expand Up @@ -462,7 +470,7 @@ class PrismAgent {
fun setupMediatorHandler(mediatorHandler: MediationHandler) {
stop()
this.connectionManager =
ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf(), pollux)
ConnectionManager(mercury, castor, pluto, mediatorHandler, mutableListOf(), pollux, agentOptions.experiments.liveMode)
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package io.iohk.atala.prism.walletsdk.prismagent.helpers

data class AgentOptions(val experiments: Experiments = Experiments())

data class Experiments(val liveMode: Boolean = false)
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class ConnectionManagerTest {
mediationHandler = basicMediatorHandlerMock,
pairings = mutableListOf(),
pollux = polluxMock,
experimentLiveModeOptIn = true,
scope = CoroutineScope(testDispatcher)
)
}
Expand Down Expand Up @@ -123,6 +124,103 @@ class ConnectionManagerTest {
verify(basicMediatorHandlerMock).listenUnreadMessages(any(), any())
}

@Test
fun testStartFetchingMessages_whenServiceEndpointContainsWSSButOptInLiveModeFalse_thenRegunarlApi() = runTest {
connectionManager = ConnectionManager(
mercury = mercuryMock,
castor = castorMock,
pluto = plutoMock,
mediationHandler = basicMediatorHandlerMock,
pairings = mutableListOf(),
pollux = polluxMock,
experimentLiveModeOptIn = false,
scope = CoroutineScope(testDispatcher)
)

`when`(basicMediatorHandlerMock.mediatorDID)
.thenReturn(DID("did:prism:b6c0c33d701ac1b9a262a14454d1bbde3d127d697a76950963c5fd930605:Cj8KPRI7CgdtYXN0ZXIwEAFKLgoJc2VmsxEiECSTjyV7sUfCr_ArpN9rvCwR9fRMAhcsr_S7ZRiJk4p5k"))

val vmAuthentication = DIDDocument.VerificationMethod(
id = DIDUrl(DID("2", "1", "0")),
controller = DID("2", "2", "0"),
type = Curve.ED25519.value,
publicKeyJwk = mapOf("crv" to Curve.ED25519.value, "x" to "")
)

val vmKeyAgreement = DIDDocument.VerificationMethod(
id = DIDUrl(DID("3", "1", "0")),
controller = DID("3", "2", "0"),
type = Curve.X25519.value,
publicKeyJwk = mapOf("crv" to Curve.X25519.value, "x" to "")
)

val vmService = DIDDocument.Service(
id = UUID.randomUUID().toString(),
type = emptyArray(),
serviceEndpoint = DIDDocument.ServiceEndpoint(
uri = "wss://serviceEndpoint"
)
)

val didDoc = DIDDocument(
id = DID("did:prism:asdfasdf"),
coreProperties = arrayOf(
DIDDocument.Authentication(
urls = emptyArray(),
verificationMethods = arrayOf(vmAuthentication)
),
DIDDocument.KeyAgreement(
urls = emptyArray(),
verificationMethods = arrayOf(vmKeyAgreement)
),
DIDDocument.Services(
values = arrayOf(vmService)
)
)
)

`when`(castorMock.resolveDID(any())).thenReturn(didDoc)
val messages = arrayOf(Pair("1234", Message(piuri = "", body = "")))
`when`(basicMediatorHandlerMock.pickupUnreadMessages(any())).thenReturn(
flow {
emit(
messages
)
}
)
val attachments: Array<AttachmentDescriptor> =
arrayOf(
AttachmentDescriptor(
mediaType = "application/json",
format = CredentialType.JWT.type,
data = AttachmentBase64(base64 = "asdfasdfasdfasdfasdfasdfasdfasdfasdf".base64UrlEncoded)
)
)
val listMessages = listOf(
Message(
piuri = ProtocolType.DidcommconnectionRequest.value,
body = ""
),
Message(
piuri = ProtocolType.DidcommIssueCredential.value,
thid = UUID.randomUUID().toString(),
from = DID("did:peer:asdf897a6sdf"),
to = DID("did:peer:f706sg678ha"),
attachments = attachments,
body = """{}"""
)
)
val messageList: Flow<List<Message>> = flow {
emit(listMessages)
}
`when`(plutoMock.getAllMessages()).thenReturn(messageList)

connectionManager.startFetchingMessages()
assertNotNull(connectionManager.fetchingMessagesJob)
verify(basicMediatorHandlerMock).pickupUnreadMessages(10)
verify(basicMediatorHandlerMock).registerMessagesAsRead(arrayOf("1234"))
}

@Test
fun testStartFetchingMessages_whenServiceEndpointNotContainsWSS_thenUseAPIRequest() =
runBlockingTest {
Expand Down
Loading

0 comments on commit 54f23cc

Please sign in to comment.