Skip to content

Commit

Permalink
Add retry logic to fetching user profile (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
streamer45 authored Jul 12, 2024
1 parent 136952f commit ab408d0
Show file tree
Hide file tree
Showing 13 changed files with 387 additions and 50 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,8 @@ dist/
# temporary track files
tracks/

# go tooling
bin/

.config.env*
*.vim
9 changes: 9 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
quiet: False
with-expecter: true
dir: "cmd/transcriber/mocks/{{.PackagePath}}"
packages:
github.com/mattermost/calls-transcriber/cmd/transcriber/call:
config:
interfaces:
APIClient:

9 changes: 9 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ GO_TEST_OPTS += -mod=readonly -failfast -race
# Temporary folder to output compiled binaries artifacts
GO_OUT_BIN_DIR := ./dist

# We need to export GOBIN to allow it to be set
# for processes spawned from the Makefile
export GOBIN ?= $(PWD)/bin

## Github Variables
# A github access token that provides access to upload artifacts under releases
GITHUB_TOKEN ?= a_token
Expand Down Expand Up @@ -414,3 +418,8 @@ clean: ## to clean-up
@$(INFO) cleaning /${GO_OUT_BIN_DIR} folder...
$(AT)rm -rf ${GO_OUT_BIN_DIR} || ${FAIL}
@$(OK) cleaning /${GO_OUT_BIN_DIR} folder

.PHONY: mocks
mocks: ## Create mock files
$(GO) install github.com/vektra/mockery/v2/...@v2.40.3
$(GOBIN)/mockery
2 changes: 1 addition & 1 deletion cmd/transcriber/call/job.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

func (t *Transcriber) postJobStatus(status public.JobStatus) error {
apiURL := fmt.Sprintf("%s/plugins/%s/bot/calls/%s/jobs/%s/status",
t.apiClient.URL, pluginID, t.cfg.CallID, t.cfg.TranscriptionID)
t.apiURL, pluginID, t.cfg.CallID, t.cfg.TranscriptionID)

payload, err := json.Marshal(&status)
if err != nil {
Expand Down
1 change: 0 additions & 1 deletion cmd/transcriber/call/live_captions.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,6 @@ func (t *Transcriber) processLiveCaptionsForTrack(ctx trackContext, pktPayloadsC
}
if err := t.client.SendWS(wsEvCaption, public.CaptionMsg{
SessionID: ctx.sessionID,
UserID: ctx.user.Id,
Text: text,
NewAudioLenMs: float64(newAudioLenMs),
}, false); err != nil {
Expand Down
19 changes: 10 additions & 9 deletions cmd/transcriber/call/tracks.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,29 @@ func (t *Transcriber) handleTrack(ctx any) error {
return nil
}

user, err := t.getUserForSession(sessionID)
if err != nil {
return fmt.Errorf("failed to get user for session: %w", err)
}

t.liveTracksWg.Add(1)
go t.processLiveTrack(track, sessionID, user)
go t.processLiveTrack(track, sessionID)

return nil
}

// processLiveTrack saves the content of a voice track to a file for later processing.
// This involves muxing the raw Opus packets into a OGG file with the
// timings adjusted to account for any potential gaps due to mute/unmute sequences.
func (t *Transcriber) processLiveTrack(track trackRemote, sessionID string, user *model.User) {
func (t *Transcriber) processLiveTrack(track trackRemote, sessionID string) {
ctx := trackContext{
trackID: track.ID(),
sessionID: sessionID,
user: user,
filename: filepath.Join(getDataDir(), fmt.Sprintf("%s_%s.ogg", user.Id, track.ID())),
}

user, err := t.getUserForSession(ctx.sessionID)
if err != nil {
slog.Error("failed to get user for session", slog.String("err", err.Error()), slog.String("trackID", ctx.trackID))
return
}
ctx.user = user
ctx.filename = filepath.Join(getDataDir(), fmt.Sprintf("%s_%s.ogg", user.Id, track.ID()))

slog.Debug("processing voice track",
slog.String("username", user.Username),
slog.String("sessionID", sessionID),
Expand Down
28 changes: 17 additions & 11 deletions cmd/transcriber/call/transcriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ package call
import (
"context"
"fmt"
"io"
"log/slog"
"net/http"
"sync"
"sync/atomic"
"time"
Expand All @@ -21,11 +23,18 @@ const (
maxTracksContexes = 256
)

type APIClient interface {
DoAPIRequest(ctx context.Context, method, url, data, etag string) (*http.Response, error)
DoAPIRequestBytes(ctx context.Context, method, url string, data []byte, etag string) (*http.Response, error)
DoAPIRequestReader(ctx context.Context, method, url string, data io.Reader, headers map[string]string) (*http.Response, error)
}

type Transcriber struct {
cfg config.CallTranscriberConfig

client *client.Client
apiClient *model.Client4
apiClient APIClient
apiURL string

errCh chan error
doneCh chan struct{}
Expand All @@ -50,6 +59,7 @@ func NewTranscriber(cfg config.CallTranscriberConfig) (t *Transcriber, retErr er
t = &Transcriber{
cfg: cfg,
apiClient: apiClient,
apiURL: apiClient.URL,
}

defer func() {
Expand All @@ -75,16 +85,12 @@ func NewTranscriber(cfg config.CallTranscriberConfig) (t *Transcriber, retErr er
return t, err
}

t = &Transcriber{
cfg: cfg,
client: rtcdClient,
apiClient: apiClient,
errCh: make(chan error, 1),
doneCh: make(chan struct{}),
trackCtxs: make(chan trackContext, maxTracksContexes),
captionsPoolQueueCh: make(chan captionPackage, transcriberQueueChBuffer),
captionsPoolDoneCh: make(chan struct{}),
}
t.client = rtcdClient
t.errCh = make(chan error, 1)
t.doneCh = make(chan struct{})
t.trackCtxs = make(chan trackContext, maxTracksContexes)
t.captionsPoolQueueCh = make(chan captionPackage, transcriberQueueChBuffer)
t.captionsPoolDoneCh = make(chan struct{})

return
}
Expand Down
81 changes: 72 additions & 9 deletions cmd/transcriber/call/transcriber_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,24 @@ import (
"fmt"
"io"
"log/slog"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"

"github.com/mattermost/calls-transcriber/cmd/transcriber/config"
"github.com/mattermost/calls-transcriber/cmd/transcriber/ogg"

mocks "github.com/mattermost/calls-transcriber/cmd/transcriber/mocks/github.com/mattermost/calls-transcriber/cmd/transcriber/call"

"github.com/mattermost/mattermost/server/public/model"

"github.com/pion/interceptor"
"github.com/pion/rtp"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -104,6 +110,17 @@ func TestProcessLiveTrack(t *testing.T) {
t.Run("empty payloads", func(t *testing.T) {
tr := setupTranscriberForTest(t)

mockClient := &mocks.MockAPIClient{}
tr.apiClient = mockClient

defer mockClient.AssertExpectations(t)

mockClient.On("DoAPIRequest", mock.Anything, http.MethodGet,
"http://localhost:8065/plugins/com.mattermost.calls/bot/calls/8w8jorhr7j83uqr6y1st894hqe/sessions/sessionID/profile", "", "").
Return(&http.Response{
Body: io.NopCloser(strings.NewReader(`{"id": "userID", "username": "testuser"}`)),
}, nil).Once()

track := &trackRemoteMock{
id: "trackID",
}
Expand Down Expand Up @@ -158,19 +175,18 @@ func TestProcessLiveTrack(t *testing.T) {
}

sessionID := "sessionID"
user := &model.User{Id: "userID", Username: "testuser"}

dataDir := os.Getenv("DATA_DIR")
os.Setenv("DATA_DIR", os.TempDir())
defer os.Setenv("DATA_DIR", dataDir)

tr.liveTracksWg.Add(1)
tr.startTime.Store(newTimeP(time.Now().Add(-time.Second)))
tr.processLiveTrack(track, sessionID, user)
tr.processLiveTrack(track, sessionID)
close(tr.trackCtxs)
require.Len(t, tr.trackCtxs, 1)

trackFile, err := os.Open(filepath.Join(getDataDir(), fmt.Sprintf("%s_%s.ogg", user.Id, track.id)))
trackFile, err := os.Open(filepath.Join(getDataDir(), fmt.Sprintf("userID_%s.ogg", track.id)))
defer trackFile.Close()
require.NoError(t, err)

Expand Down Expand Up @@ -206,6 +222,17 @@ func TestProcessLiveTrack(t *testing.T) {
t.Run("out of order packets", func(t *testing.T) {
tr := setupTranscriberForTest(t)

mockClient := &mocks.MockAPIClient{}
tr.apiClient = mockClient

defer mockClient.AssertExpectations(t)

mockClient.On("DoAPIRequest", mock.Anything, http.MethodGet,
"http://localhost:8065/plugins/com.mattermost.calls/bot/calls/8w8jorhr7j83uqr6y1st894hqe/sessions/sessionID/profile", "", "").
Return(&http.Response{
Body: io.NopCloser(strings.NewReader(`{"id": "userID", "username": "testuser"}`)),
}, nil).Once()

track := &trackRemoteMock{
id: "trackID",
}
Expand Down Expand Up @@ -253,19 +280,18 @@ func TestProcessLiveTrack(t *testing.T) {
}

sessionID := "sessionID"
user := &model.User{Id: "userID", Username: "testuser"}

dataDir := os.Getenv("DATA_DIR")
os.Setenv("DATA_DIR", os.TempDir())
defer os.Setenv("DATA_DIR", dataDir)

tr.liveTracksWg.Add(1)
tr.startTime.Store(newTimeP(time.Now().Add(-time.Second)))
tr.processLiveTrack(track, sessionID, user)
tr.processLiveTrack(track, sessionID)
close(tr.trackCtxs)
require.Len(t, tr.trackCtxs, 1)

trackFile, err := os.Open(filepath.Join(getDataDir(), fmt.Sprintf("%s_%s.ogg", user.Id, track.id)))
trackFile, err := os.Open(filepath.Join(getDataDir(), fmt.Sprintf("userID_%s.ogg", track.id)))
defer trackFile.Close()
require.NoError(t, err)

Expand Down Expand Up @@ -300,6 +326,17 @@ func TestProcessLiveTrack(t *testing.T) {
t.Run("timestamp wrap around", func(t *testing.T) {
tr := setupTranscriberForTest(t)

mockClient := &mocks.MockAPIClient{}
tr.apiClient = mockClient

defer mockClient.AssertExpectations(t)

mockClient.On("DoAPIRequest", mock.Anything, http.MethodGet,
"http://localhost:8065/plugins/com.mattermost.calls/bot/calls/8w8jorhr7j83uqr6y1st894hqe/sessions/sessionID/profile", "", "").
Return(&http.Response{
Body: io.NopCloser(strings.NewReader(`{"id": "userID", "username": "testuser"}`)),
}, nil).Once()

track := &trackRemoteMock{
id: "trackID",
}
Expand Down Expand Up @@ -347,19 +384,18 @@ func TestProcessLiveTrack(t *testing.T) {
}

sessionID := "sessionID"
user := &model.User{Id: "userID", Username: "testuser"}

dataDir := os.Getenv("DATA_DIR")
os.Setenv("DATA_DIR", os.TempDir())
defer os.Setenv("DATA_DIR", dataDir)

tr.liveTracksWg.Add(1)
tr.startTime.Store(newTimeP(time.Now().Add(-time.Second)))
tr.processLiveTrack(track, sessionID, user)
tr.processLiveTrack(track, sessionID)
close(tr.trackCtxs)
require.Len(t, tr.trackCtxs, 1)

trackFile, err := os.Open(filepath.Join(getDataDir(), fmt.Sprintf("%s_%s.ogg", user.Id, track.id)))
trackFile, err := os.Open(filepath.Join(getDataDir(), fmt.Sprintf("userID_%s.ogg", track.id)))
defer trackFile.Close()
require.NoError(t, err)

Expand Down Expand Up @@ -395,4 +431,31 @@ func TestProcessLiveTrack(t *testing.T) {
require.Equal(t, io.EOF, err)
})
})

t.Run("should reattempt getUserForSession on failure", func(t *testing.T) {
tr := setupTranscriberForTest(t)

mockClient := &mocks.MockAPIClient{}
tr.apiClient = mockClient

defer mockClient.AssertExpectations(t)

mockClient.On("DoAPIRequest", mock.Anything, http.MethodGet,
"http://localhost:8065/plugins/com.mattermost.calls/bot/calls/8w8jorhr7j83uqr6y1st894hqe/sessions/sessionID/profile", "", "").
Return(nil, fmt.Errorf("failed")).Once()

mockClient.On("DoAPIRequest", mock.Anything, http.MethodGet,
"http://localhost:8065/plugins/com.mattermost.calls/bot/calls/8w8jorhr7j83uqr6y1st894hqe/sessions/sessionID/profile", "", "").
Return(&http.Response{
Body: io.NopCloser(strings.NewReader(`{"id": "userID", "username": "testuser"}`)),
}, nil).Once()

tr.liveTracksWg.Add(1)
tr.startTime.Store(newTimeP(time.Now().Add(-time.Second)))
tr.processLiveTrack(&trackRemoteMock{
id: "trackID",
}, "sessionID")
close(tr.trackCtxs)
require.Len(t, tr.trackCtxs, 1)
})
}
Loading

0 comments on commit ab408d0

Please sign in to comment.