Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate on disconnect #31

Merged
merged 8 commits into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion core/src/commonMain/kotlin/dev/schlaubi/lavakord/Options.kt
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,16 @@ public interface LavaKordOptions {
* Configuration for Links and Nodes.
*
* @property autoReconnect Whether to auto-reconnect links or not
* @property autoMigrateOnDisconnect Whether to try to migrate links from a disconnected node onto a new one.
* This option has no effect if [autoReconnect] is false. If the node is trying to resume, the migration will only
* take place after the node gives up on resuming as per [retry].
* @property resumeTimeout amount of seconds Lavalink will wait to kill all players if the client fails to resume it's connection
* @property retry retry strategy (See [Retry] and [LinearRetry])
* @property showTrace whether [RestError.trace] should be populated
*/
public interface LinkConfig {
public val autoReconnect: Boolean
public val autoMigrateOnDisconnect: Boolean
public val resumeTimeout: Int
public val retry: Retry
public val showTrace: Boolean
Expand Down Expand Up @@ -130,12 +134,13 @@ public data class MutableLavaKordOptions(
*/
public data class LinkConfig(
override var autoReconnect: Boolean = true,
override var autoMigrateOnDisconnect: Boolean = true,
override var resumeTimeout: Int = 60,
override var retry: Retry = LinearRetry(2.seconds, 60.seconds, 10),
override val showTrace: Boolean = false
) : LavaKordOptions.LinkConfig {
internal fun seal(): LavaKordOptions.LinkConfig =
ImmutableLavaKordOptions.LinkConfig(autoReconnect, resumeTimeout, retry, showTrace)
ImmutableLavaKordOptions.LinkConfig(autoReconnect, autoMigrateOnDisconnect, resumeTimeout, retry, showTrace)

/**
* Creates a linear [Retry] strategy.
Expand Down Expand Up @@ -199,6 +204,7 @@ private data class ImmutableLavaKordOptions(
*/
data class LinkConfig(
override val autoReconnect: Boolean,
override val autoMigrateOnDisconnect: Boolean,
override val resumeTimeout: Int,
override val retry: Retry,
override val showTrace: Boolean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public interface Link {

/**
* Called internally when this link is connected or reconnected to a new node without resuming, thereby creating a
* new session.
* new session. This function may also be used if the link is moved to a new session.
* @param node The node that was connected to, which may be potentially different from the previously connected node
*/
public suspend fun onNewSession(node: Node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@ import io.ktor.http.*
import io.ktor.serialization.kotlinx.*
import io.ktor.serialization.kotlinx.json.*
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.joinAll
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.flow.asSharedFlow
import kotlinx.coroutines.launch
import kotlinx.coroutines.newCoroutineContext
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.modules.contextual
import kotlinx.serialization.modules.plus
Expand Down Expand Up @@ -140,7 +142,7 @@ public abstract class AbstractLavakord internal constructor(

override fun getLink(guildId: ULong): Link {
return linksMap.getOrPut(guildId) {
val node = loadBalancer.determineBestNode(guildId) as NodeImpl
val node = loadBalancer.determineBestNode(guildId) ?: error("No nodes are available")
buildNewLink(guildId, node)
}
}
Expand Down Expand Up @@ -169,7 +171,7 @@ public abstract class AbstractLavakord internal constructor(

override fun removeNode(name: String) {
val node = nodesMap.remove(name)
requireNotNull(node) { "There is no node with that name" }
requireNotNull(node) { "There is no node with name $name" }
node.close()
}

Expand All @@ -189,6 +191,15 @@ public abstract class AbstractLavakord internal constructor(
*/
protected abstract fun buildNewLink(guildId: ULong, node: Node): Link

internal suspend fun migrateFromDisconnectedNode(disconnectedNode: NodeImpl) {
linksMap.filterValues { it.node == disconnectedNode }.mapNotNull { (guild, link) ->
val newNode = loadBalancer.determineBestNode(guild) ?: return@mapNotNull null
launch {
link.onNewSession(newNode)
}
}.joinAll()
}

/** Called on websocket connect without resuming */
internal suspend fun onNewSession(node: Node) {
if (!options.link.autoReconnect) return
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import dev.arbjerg.lavalink.protocol.v4.PlayerUpdate
import dev.arbjerg.lavalink.protocol.v4.VoiceState
import dev.arbjerg.lavalink.protocol.v4.toOmissible
import dev.schlaubi.lavakord.audio.Link
import dev.schlaubi.lavakord.audio.Link.State
import dev.schlaubi.lavakord.audio.Node
import dev.schlaubi.lavakord.audio.player.Player
import dev.schlaubi.lavakord.audio.player.node
Expand All @@ -21,34 +22,43 @@ public abstract class AbstractLink(node: Node, final override val guildId: ULong
override val player: Player = WebsocketPlayer(node as NodeImpl, guildId)
abstract override val lavakord: AbstractLavakord
override var lastChannelId: ULong? = null
override var state: Link.State = Link.State.NOT_CONNECTED
override var state: State = State.NOT_CONNECTED
set(value) {
if (field == value) return
LOG.debug { "$this: $state -> $value" }
field = value
}
private var cachedVoiceState: VoiceState? = null

override suspend fun onDisconnected() {
state = Link.State.NOT_CONNECTED
node.destroyPlayer(guildId)
state = State.NOT_CONNECTED
cachedVoiceState = null
}

override suspend fun onNewSession(node: Node) {
this.node = node
player.node
val voiceState = cachedVoiceState

state = if (voiceState != null) State.CONNECTING else State.NOT_CONNECTED

cachedVoiceState?.let {
node.updatePlayer(guildId, request = PlayerUpdate(voice = it.toOmissible()))
try {
(player as WebsocketPlayer).recreatePlayer(node as NodeImpl, voiceState)
LOG.debug { "$this: recreated player on $node" }
} catch (e: Exception) {
state = State.NOT_CONNECTED
throw e
}
(player as WebsocketPlayer).recreatePlayer(node as NodeImpl)
}

override suspend fun destroy() {
val shouldDisconnect = state !== Link.State.DISCONNECTING && state !== Link.State.NOT_CONNECTED
state = Link.State.DESTROYING
val shouldDisconnect = state !== State.DISCONNECTING && state !== State.NOT_CONNECTED
state = State.DESTROYING
if (shouldDisconnect) {
disconnectAudio()
}
node.destroyPlayer(guildId)
lavakord.removeDestroyedLink(this)
state = Link.State.DESTROYED
state = State.DESTROYED
}

internal suspend fun onVoiceServerUpdate(update: VoiceState) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,12 @@ internal class LoadBalancer(
private val lavakord: LavaKord
) {

fun determineBestNode(guildId: ULong): Node {
val leastPenalty = lavakord.nodes
.asSequence()
.filter(Node::available)
.minByOrNull { calculatePenalties(it, penaltyProviders, guildId) }
fun determineBestNode(guildId: ULong): Node? = lavakord.nodes
.asSequence()
.filter(Node::available)
.minByOrNull { calculatePenalties(it, penaltyProviders, guildId) }

checkNotNull(leastPenalty) { "No nodes available" }

return leastPenalty
}

// Inspired by: https://github.com/Frederikam/Lavalink-Client/blob/master/src/main/java/lavalink/client/io/LavalinkLoadBalancer.java#L111
// Inspired by: https://github.com/freyacodes/Lavalink-Client/blob/master/src/main/java/lavalink/client/io/LavalinkLoadBalancer.java#L111
private fun calculatePenalties(
node: Node,
penaltyProviders: List<PenaltyProvider>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import dev.arbjerg.lavalink.protocol.v4.Stats
import dev.arbjerg.lavalink.protocol.v4.toOmissible
import dev.schlaubi.lavakord.Plugin
import dev.schlaubi.lavakord.audio.Event
import dev.schlaubi.lavakord.audio.Link
import dev.schlaubi.lavakord.audio.Node
import dev.schlaubi.lavakord.rest.getInfo
import dev.schlaubi.lavakord.rest.getVersion
Expand Down Expand Up @@ -33,6 +34,7 @@ import kotlinx.serialization.json.decodeFromJsonElement
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import mu.KotlinLogging
import kotlin.concurrent.Volatile
import kotlin.properties.ReadWriteProperty
import kotlin.reflect.KProperty

Expand All @@ -58,7 +60,9 @@ internal class NodeImpl(
private val retry = lavakord.options.link.retry

override var sessionId: String by SessionIdContainer()
override var available: Boolean = true

@Volatile
override var available: Boolean = false
override var lastStatsEvent: Stats? = null
private var eventPublisher: MutableSharedFlow<Event> =
MutableSharedFlow(extraBufferCapacity = Channel.UNLIMITED)
Expand Down Expand Up @@ -135,9 +139,12 @@ internal class NodeImpl(
val reason = session.closeReason.await()
val resumeAgain = resume && reason?.knownReason != CloseReason.Codes.NORMAL
if (resumeAgain) {
LOG.warn { "Disconnected from websocket for: $reason. Music will continue playing if we can reconnect within the next $resumeTimeout seconds" }
LOG.warn { "$name disconnected from websocket for: $reason. Music will continue playing if we can reconnect within the next $resumeTimeout seconds" }
} else {
LOG.warn { "Disconnected from websocket for: $reason. Not resuming." }
LOG.warn { "$name disconnected from websocket for: $reason. Not resuming." }
if (lavakord.options.link.autoReconnect && lavakord.options.link.autoMigrateOnDisconnect) {
lavakord.migrateFromDisconnectedNode(this)
}
}
reconnect(resume = resumeAgain)
}
Expand Down Expand Up @@ -178,14 +185,29 @@ internal class NodeImpl(
LOG.warn(e) {"Could not parse event"}
}
when (event) {
is Message.PlayerUpdateEvent -> (lavakord.getLink(event.guildId).player as WebsocketPlayer)
.provideState(event.state)
is Message.PlayerUpdateEvent -> {
val link = lavakord.getLink(event.guildId) as AbstractLink

if (event.state.connected && link.state == Link.State.CONNECTING) {
link.state = Link.State.CONNECTING
} else if (!event.state.connected && link.state == Link.State.DISCONNECTING) {
link.state = Link.State.NOT_CONNECTED
}

(link.player as WebsocketPlayer).provideState(event.state)
}

is Message.EmittedEvent.WebSocketClosedEvent -> {
// These codes represent an invalid session
// See https://discord.com/developers/docs/topics/opcodes-and-status-codes#voice-voice-close-event-codes
if (event.code == 4004 || event.code == 4006 || event.code == 4009 || event.code == 4014) {
lavakord.getLink(event.guildId).onDisconnected()
try {
if (event.code == 4004 || event.code == 4006 || event.code == 4009 || event.code == 4014) {
LOG.debug { "Node $name received close code ${event.code} for guild ${event.guildId}" }
lavakord.getLink(event.guildId).onDisconnected()
}
} finally {
// Must still be emitted
eventPublisher.tryEmit(event.toEvent())
}
}

Expand Down Expand Up @@ -214,8 +236,12 @@ internal class NodeImpl(
}

override fun close() {
available = false
lavakord.launch {
session.close(CloseReason(CloseReason.Codes.NORMAL, "Close requested by client"))
if (lavakord.options.link.autoReconnect && lavakord.options.link.autoMigrateOnDisconnect) {
lavakord.migrateFromDisconnectedNode(this@NodeImpl)
}
}
}

Expand All @@ -226,4 +252,6 @@ internal class NodeImpl(
href(resources.resourcesFormat, V4Api.WebSocket(), this)
}
}

override fun toString() = "Node($name)"
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import dev.schlaubi.lavakord.audio.player.Player
import dev.schlaubi.lavakord.rest.models.FiltersObject
import dev.schlaubi.lavakord.rest.models.toLavalink
import dev.schlaubi.lavakord.rest.updatePlayer
import kotlinx.atomicfu.AtomicBoolean
import kotlinx.atomicfu.atomic
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.filter
Expand All @@ -38,6 +40,7 @@ internal class WebsocketPlayer(node: NodeImpl, internal val guildId: ULong) : Pl
return (lastPosition + elapsedSinceUpdate).coerceAtMost(trackLength)
}
private var specifiedEndTime: Duration? = null
private val isRecreating = atomic(false)

override val volume: Int
get() = ((filters.volume ?: 1.0f) * 100).toInt()
Expand Down Expand Up @@ -90,7 +93,7 @@ internal class WebsocketPlayer(node: NodeImpl, internal val guildId: ULong) : Pl
private fun handleNewTrack(event: TrackStartEvent) {
updateTime = Clock.System.now()
val track = event.track
lastPosition = 0.milliseconds
lastPosition = event.track.info.position.milliseconds
playingTrack = track
}

Expand Down Expand Up @@ -127,22 +130,33 @@ internal class WebsocketPlayer(node: NodeImpl, internal val guildId: ULong) : Pl
}

internal fun provideState(state: PlayerState) {
// After migrating the player to a new node, the new node may send a position of 0 as we are starting a new track.
// This may cause a race condition where the migrated track starts at close to 0:00 even if the start time should
// be later. Ignoring the first player update if it is zero fixes this issue.
if (isRecreating.getAndSet(true) && state.position == 0L) return
updateTime = Instant.fromEpochMilliseconds(state.time)
lastPosition = state.position.milliseconds
}

internal suspend fun recreatePlayer(node: NodeImpl) {
internal suspend fun recreatePlayer(node: NodeImpl, voiceState: VoiceState?) {
this.node = node
val position = if (playingTrack == null) Omissible.omitted() else positionDuration.inWholeMilliseconds.toOmissible()
node.updatePlayer(guildId, noReplace = false, PlayerUpdate(
val position = if (playingTrack == null) null else positionDuration.inWholeMilliseconds

isRecreating.value = true
node.updatePlayer(
guildId, noReplace = false, PlayerUpdate(
encodedTrack = playingTrack?.encoded.toOmissible(),
identifier = Omissible.omitted(),
position = position,
position = position.toOmissible(),
endTime = specifiedEndTime?.inWholeMilliseconds.toOmissible(),
volume = volume.toOmissible(),
paused = paused.toOmissible(),
filters = filters.toLavalink().toOmissible()
filters = filters.toLavalink().toOmissible(),
voice = voiceState.toOmissible()
)
)

updateTime = Clock.System.now()
lastPosition = position?.milliseconds ?: 0.milliseconds
}
}