Skip to content

Commit

Permalink
Add allowRejoin param to Replay, update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
fakelag committed Feb 25, 2024
1 parent cf200bf commit 850c43d
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 18 deletions.
27 changes: 21 additions & 6 deletions discordplayer/discordplayer.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,26 +228,35 @@ func (dms *DiscordMusicSession) ClearMediaQueue() bool {
return true
}

func (dms *DiscordMusicSession) Replay() error {
// Sets current media to be replayed once after its done playing, or
// enqueues+starts the last finished song if allowRejoin is true. Returned context is
// non-nil if the worker is restarted & the bot rejoins voice
func (dms *DiscordMusicSession) Replay(allowRejoin bool) (context.Context, error) {
err := dms.sendCommand(dms.chanReplayCommand)

if err == nil {
return nil
return nil, nil
}

if errors.Is(err, ErrorWorkerNotActive) {
if allowRejoin && errors.Is(err, ErrorWorkerNotActive) {
dms.mutex.RLock()
lastCompletedMedia := dms.lastCompletedMedia
dms.mutex.RUnlock()

if lastCompletedMedia == nil {
return ErrorNoMediaFound
return nil, ErrorNoMediaFound
}

return dms.EnqueueMedia(lastCompletedMedia)
err = dms.EnqueueMedia(lastCompletedMedia)

if err != nil {
return nil, err
}

return dms.Start()
}

return err
return nil, err
}

func (dms *DiscordMusicSession) SetPaused(paused bool) error {
Expand Down Expand Up @@ -371,6 +380,12 @@ func (dms *DiscordMusicSession) GetGuildID() string {
return dms.guildID
}

func (dms *DiscordMusicSession) IsWorkerActive() bool {
dms.mutex.RLock()
defer dms.mutex.RUnlock()
return dms.workerActive
}

func (dms *DiscordMusicSession) sendCommand(command chan bool) error {
dms.mutex.Lock()
defer dms.mutex.Unlock()
Expand Down
97 changes: 91 additions & 6 deletions discordplayer/discordplayer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ var _ = Describe("Discord Player", func() {
}
})

It("Repeats the current media upon receiving repeat command", func() {
It("Repeats the current media upon receiving repeat command when the worker is active", func() {
ctrl := gomock.NewController(GinkgoT())

currentMediaDone := make(chan error)
Expand All @@ -377,21 +377,25 @@ var _ = Describe("Discord Player", func() {
playerContext.mockVoiceConnection.EXPECT().IsReady().Return(true).AnyTimes()
playerContext.mockDca.EXPECT().EncodeFile(playerContext.mockMedia.FileURL(), gomock.Any()).Return(nil, nil).AnyTimes()

Eventually(func() entities.Media {
return playerContext.dms.GetCurrentlyPlayingMedia()
}).WithTimeout(failTimeout).WithPolling(50 * time.Millisecond).ShouldNot(BeNil())

// Done to current media
currentMediaDone <- nil

Eventually(func() entities.Media {
return playerContext.dms.GetCurrentlyPlayingMedia()
}).WithTimeout(failTimeout).WithPolling(50 * time.Millisecond).Should(BeNil())

err := playerContext.dms.Replay()
_, err := playerContext.dms.Replay(false)
Expect(err).NotTo(HaveOccurred())

Eventually(func() entities.Media {
return playerContext.dms.GetCurrentlyPlayingMedia()
}).WithTimeout(failTimeout).WithPolling(50 * time.Millisecond).ShouldNot(BeNil())

err = playerContext.dms.Replay()
_, err = playerContext.dms.Replay(false)
Expect(err).NotTo(HaveOccurred())

// Done done to first repeat
Expand Down Expand Up @@ -421,6 +425,77 @@ var _ = Describe("Discord Player", func() {
}
})

It("Repeats the current media upon receiving repeat command when the worker is not active", func() {
ctrl := gomock.NewController(GinkgoT())

currentMediaDone := make(chan error)
mockDcaStreamingSession := NewMockDcaStreamingSession(ctrl)
playerContext := JoinMockVoiceChannelAndPlayEx(context.TODO(), ctrl, currentMediaDone, false, mockDcaStreamingSession)
playerContext.mockVoiceConnection.EXPECT().Speaking(gomock.Any()).AnyTimes()
playerContext.mockVoiceConnection.EXPECT().IsReady().Return(true).AnyTimes()
playerContext.mockDiscordSession.EXPECT().
ChannelVoiceJoin(playerContext.guildID, playerContext.channelID, false, false).
Return(playerContext.mockVoiceConnection, nil)
playerContext.mockDca.EXPECT().EncodeFile(playerContext.mockMedia.FileURL(), gomock.Any()).Return(nil, nil).AnyTimes()
playerContext.mockVoiceConnection.EXPECT().Disconnect()

_, err := playerContext.dms.Replay(true)
Expect(err).To(MatchError(discordplayer.ErrorNoMediaFound))

Expect(playerContext.dms.EnqueueMedia(playerContext.mockMedia)).NotTo(HaveOccurred())

_, err = playerContext.dms.Start()
Expect(err).NotTo(HaveOccurred())

Eventually(func() entities.Media {
return playerContext.dms.GetCurrentlyPlayingMedia()
}).WithTimeout(failTimeout).WithPolling(50 * time.Millisecond).ShouldNot(BeNil())

currentMediaDone <- nil
Expect(playerContext.dms.Leave()).To(Succeed())

Eventually(func() bool {
return playerContext.dms.IsWorkerActive()
}).WithTimeout(failTimeout).WithPolling(50 * time.Millisecond).Should(BeFalse())

_, err = playerContext.dms.Replay(false)
Expect(err).Should(MatchError(discordplayer.ErrorWorkerNotActive))

newCtx, err := playerContext.dms.Replay(true)
Expect(err).ShouldNot(HaveOccurred())
Expect(newCtx).NotTo(BeNil())

Eventually(func() bool {
return playerContext.dms.IsWorkerActive()
}).WithTimeout(failTimeout).WithPolling(50 * time.Millisecond).Should(BeTrue())

Eventually(func() entities.Media {
return playerContext.dms.GetCurrentlyPlayingMedia()
}).WithTimeout(failTimeout).WithPolling(50 * time.Millisecond).ShouldNot(BeNil())

currentMediaDone <- nil

Eventually(func() entities.Media {
return playerContext.dms.GetCurrentlyPlayingMedia()
}).WithTimeout(failTimeout).WithPolling(50 * time.Millisecond).Should(BeNil())

c := make(chan struct{})

playerContext.mockVoiceConnection.EXPECT().Disconnect().Do(func() {
close(currentMediaDone)
close(c)
})

Expect(playerContext.dms.Leave()).To(Succeed())

select {
case <-c:
return
case <-time.After(20 * time.Second):
Fail("Voice worker timed out")
}
})

It("Pauses & unpauses the current music streaming session upon receiving the pause command", func() {
ctrl := gomock.NewController(GinkgoT())

Expand Down Expand Up @@ -1597,15 +1672,25 @@ var _ = Describe("Discord Player", func() {
When("Using the API invalidly", func() {
It("Returns a sensible error if attempting to Start() without a voice channel", func() {
dms, err := discordplayer.NewDiscordMusicSession(context.TODO(), nil, &discordplayer.DiscordMusicSessionOptions{
GuildID: gID,
VoiceChannelID: "",
MediaQueueMaxSize: 10,
GuildID: gID,
VoiceChannelID: "",
})

Expect(err).NotTo(HaveOccurred())
Expect(dms).NotTo(BeNil())
_, err = dms.Start()
Expect(err).To(MatchError(discordplayer.ErrorNoVoiceChannelSet))
})

It("Returns a sensible error if attempting to EnqueueMedia() with a nil interface", func() {
dms, err := discordplayer.NewDiscordMusicSession(context.TODO(), nil, &discordplayer.DiscordMusicSessionOptions{
GuildID: gID,
VoiceChannelID: "",
})

Expect(dms).NotTo(BeNil())
Expect(err).NotTo(HaveOccurred())
Expect(dms.EnqueueMedia(nil)).To(MatchError(discordplayer.ErrorInvalidMedia))
})
})
})
7 changes: 1 addition & 6 deletions discordplayer/voiceworker.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ workerloop:
dms.mutex.RUnlock()

if repeatMedia != nil {
// TODO repeat from from start of the queue
dms.EnqueueMedia(repeatMedia)
}
default:
Expand Down Expand Up @@ -385,12 +386,6 @@ func (dms *DiscordMusicSession) disconnectAndExitWorker() {
}
}

func (dms *DiscordMusicSession) isWorkerActive() bool {
dms.mutex.RLock()
defer dms.mutex.RUnlock()
return dms.workerActive
}

func (dms *DiscordMusicSession) setCurrentlyPlayingMediaAndSession(media entities.Media, session *DcaMediaSession) {
dms.mutex.Lock()
defer dms.mutex.Unlock()
Expand Down

0 comments on commit 850c43d

Please sign in to comment.