diff --git a/Makefile b/Makefile index dd5bf2d..950ac70 100644 --- a/Makefile +++ b/Makefile @@ -106,6 +106,7 @@ GO_LDFLAGS += -X "github.com/mattermost/${APP_NAME}/service.bu GO_LDFLAGS += -X "github.com/mattermost/${APP_NAME}/service.buildVersion=$(APP_VERSION)" GO_LDFLAGS += -X "github.com/mattermost/${APP_NAME}/service.buildDate=$(BUILD_DATE)" GO_LDFLAGS += -X "github.com/mattermost/${APP_NAME}/service.goVersion=$(GO_VERSION)" +GO_LDFLAGS += -X "github.com/mattermost/${APP_NAME}/cmd/transcriber/config.inTranscriber=true" # Architectures to build for GO_BUILD_PLATFORMS ?= "linux-${ARCH}" GO_BUILD_PLATFORMS_ARTIFACTS = $(foreach cmd,$(addprefix go-build/,${APP_NAME}),$(addprefix $(cmd)-,$(GO_BUILD_PLATFORMS))) diff --git a/cmd/transcriber/call/transcriber.go b/cmd/transcriber/call/transcriber.go index 23b82f2..0cd3d13 100644 --- a/cmd/transcriber/call/transcriber.go +++ b/cmd/transcriber/call/transcriber.go @@ -16,7 +16,6 @@ import ( const ( pluginID = "com.mattermost.calls" - wsEvPrefix = "custom_" + pluginID + "_" wsEvCaption = "custom_" + pluginID + "_caption" wsEvMetric = "custom_" + pluginID + "_metric" maxTracksContexes = 256 @@ -40,27 +39,45 @@ type Transcriber struct { captionsPoolDoneCh chan struct{} } -func NewTranscriber(cfg config.CallTranscriberConfig) (*Transcriber, error) { +func NewTranscriber(cfg config.CallTranscriberConfig) (t *Transcriber, retErr error) { + if err := cfg.IsValidURL(); err != nil { + return nil, fmt.Errorf("failed to validate URL: %w", err) + } + + apiClient := model.NewAPIv4Client(cfg.SiteURL) + apiClient.SetToken(cfg.AuthToken) + + t = &Transcriber{ + cfg: cfg, + apiClient: apiClient, + } + + defer func() { + if retErr != nil && t != nil { + retErrStr := fmt.Errorf("failed to create Transcriber: %w", retErr) + if err := t.ReportJobFailure(retErrStr.Error()); err != nil { + retErr = fmt.Errorf("failed to report job failure: %s, original error: %s", err.Error(), retErrStr) + } + } + }() + if err := cfg.IsValid(); err != nil { - return nil, fmt.Errorf("failed to validate config: %w", err) + return t, err } - client, err := client.New(client.Config{ + rtcdClient, err := client.New(client.Config{ SiteURL: cfg.SiteURL, AuthToken: cfg.AuthToken, ChannelID: cfg.CallID, JobID: cfg.TranscriptionID, }) if err != nil { - return nil, fmt.Errorf("failed to create calls client: %w", err) + return t, err } - apiClient := model.NewAPIv4Client(cfg.SiteURL) - apiClient.SetToken(cfg.AuthToken) - - t := &Transcriber{ + t = &Transcriber{ cfg: cfg, - client: client, + client: rtcdClient, apiClient: apiClient, errCh: make(chan error, 1), doneCh: make(chan struct{}), @@ -68,7 +85,8 @@ func NewTranscriber(cfg config.CallTranscriberConfig) (*Transcriber, error) { captionsPoolQueueCh: make(chan captionPackage, transcriberQueueChBuffer), captionsPoolDoneCh: make(chan struct{}), } - return t, nil + + return } func (t *Transcriber) Start(ctx context.Context) error { diff --git a/cmd/transcriber/config/config.go b/cmd/transcriber/config/config.go index fe39c56..b937e07 100644 --- a/cmd/transcriber/config/config.go +++ b/cmd/transcriber/config/config.go @@ -12,7 +12,10 @@ import ( "github.com/mattermost/calls-transcriber/cmd/transcriber/transcribe" ) -var idRE = regexp.MustCompile(`^[a-z0-9]{26}$`) +var ( + inTranscriber = "false" + idRE = regexp.MustCompile(`^[a-z0-9]{26}$`) +) const ( // defaults @@ -95,10 +98,7 @@ func (a TranscribeAPI) IsValid() bool { } } -func (cfg CallTranscriberConfig) IsValid() error { - if cfg == (CallTranscriberConfig{}) { - return fmt.Errorf("config cannot be empty") - } +func (cfg CallTranscriberConfig) IsValidURL() error { if cfg.SiteURL == "" { return fmt.Errorf("SiteURL cannot be empty") } @@ -112,16 +112,28 @@ func (cfg CallTranscriberConfig) IsValid() error { return fmt.Errorf("SiteURL parsing failed: invalid path %q", u.Path) } + return nil +} + +func (cfg CallTranscriberConfig) IsValid() error { + if cfg == (CallTranscriberConfig{}) { + return fmt.Errorf("config cannot be empty") + } + + if err := cfg.IsValidURL(); err != nil { + return err + } + if cfg.CallID == "" { return fmt.Errorf("CallID cannot be empty") } else if !idRE.MatchString(cfg.CallID) { return fmt.Errorf("CallID parsing failed") } - if cfg.PostID == "" { - return fmt.Errorf("PostID cannot be empty") - } else if !idRE.MatchString(cfg.PostID) { - return fmt.Errorf("PostID parsing failed") + if cfg.TranscriptionID == "" { + return fmt.Errorf("TranscriptionID cannot be empty") + } else if !idRE.MatchString(cfg.TranscriptionID) { + return fmt.Errorf("TranscriptionID parsing failed") } if cfg.AuthToken == "" { @@ -130,10 +142,10 @@ func (cfg CallTranscriberConfig) IsValid() error { return fmt.Errorf("AuthToken parsing failed") } - if cfg.TranscriptionID == "" { - return fmt.Errorf("TranscriptionID cannot be empty") - } else if !idRE.MatchString(cfg.TranscriptionID) { - return fmt.Errorf("TranscriptionID parsing failed") + if cfg.PostID == "" { + return fmt.Errorf("PostID cannot be empty") + } else if !idRE.MatchString(cfg.PostID) { + return fmt.Errorf("PostID parsing failed") } if !cfg.TranscribeAPI.IsValid() { @@ -146,16 +158,21 @@ func (cfg CallTranscriberConfig) IsValid() error { return fmt.Errorf("OutputFormat value is not valid") } - numCPU := runtime.NumCPU() - if cfg.NumThreads < 1 || cfg.NumThreads > numCPU { - return fmt.Errorf("NumThreads should be in the range [1, %d]", numCPU) - } - if cfg.LiveCaptionsOn { - if cfg.LiveCaptionsNumTranscribers < 1 || cfg.LiveCaptionsNumThreadsPerTranscriber < 1 || - cfg.LiveCaptionsNumTranscribers*cfg.LiveCaptionsNumThreadsPerTranscriber > numCPU { - return fmt.Errorf("LiveCaptionsNumTranscribers * LiveCaptionsNumThreadsPerTranscriber should be in the range [1, %d]", numCPU) + if inTranscriber == "true" { + numCPU := runtime.NumCPU() + if cfg.NumThreads < 1 || cfg.NumThreads > numCPU { + return fmt.Errorf("NumThreads should be in the range [1, %d]", numCPU) } + if cfg.LiveCaptionsOn { + if cfg.LiveCaptionsNumTranscribers < 1 || cfg.LiveCaptionsNumThreadsPerTranscriber < 1 || + cfg.LiveCaptionsNumTranscribers*cfg.LiveCaptionsNumThreadsPerTranscriber > numCPU { + return fmt.Errorf("LiveCaptionsNumTranscribers * LiveCaptionsNumThreadsPerTranscriber should be in the range [1, %d]", numCPU) + } + } + } + + if cfg.LiveCaptionsOn { if !cfg.LiveCaptionsModelSize.IsValid() { return fmt.Errorf("LiveCaptionsModelSize value is not valid") } diff --git a/cmd/transcriber/config/config_test.go b/cmd/transcriber/config/config_test.go index 345855a..02163be 100644 --- a/cmd/transcriber/config/config_test.go +++ b/cmd/transcriber/config/config_test.go @@ -16,6 +16,7 @@ func TestConfigIsValid(t *testing.T) { tcs := []struct { name string cfg CallTranscriberConfig + inTranscriber string expectedError string }{ { @@ -38,32 +39,31 @@ func TestConfigIsValid(t *testing.T) { expectedError: "CallID cannot be empty", }, { - name: "missing PostID", + name: "missing TranscriptionID", cfg: CallTranscriberConfig{ - SiteURL: "http://localhost:8065", - CallID: "8w8jorhr7j83uqr6y1st894hqe", - AuthToken: "qj75unbsef83ik9p7ueypb6iyw", + SiteURL: "http://localhost:8065", + CallID: "8w8jorhr7j83uqr6y1st894hqe", }, - expectedError: "PostID cannot be empty", + expectedError: "TranscriptionID cannot be empty", }, { name: "missing AuthToken", cfg: CallTranscriberConfig{ - SiteURL: "http://localhost:8065", - CallID: "8w8jorhr7j83uqr6y1st894hqe", - PostID: "udzdsg7dwidbzcidx5khrf8nee", + SiteURL: "http://localhost:8065", + CallID: "8w8jorhr7j83uqr6y1st894hqe", + TranscriptionID: "on5yfih5etn5m8rfdidamc1oxa", }, expectedError: "AuthToken cannot be empty", }, { - name: "missing TranscriptionID", + name: "missing PostID", cfg: CallTranscriberConfig{ - SiteURL: "http://localhost:8065", - CallID: "8w8jorhr7j83uqr6y1st894hqe", - PostID: "udzdsg7dwidbzcidx5khrf8nee", - AuthToken: "qj75unbsef83ik9p7ueypb6iyw", + SiteURL: "http://localhost:8065", + CallID: "8w8jorhr7j83uqr6y1st894hqe", + TranscriptionID: "on5yfih5etn5m8rfdidamc1oxa", + AuthToken: "qj75unbsef83ik9p7ueypb6iyw", }, - expectedError: "TranscriptionID cannot be empty", + expectedError: "PostID cannot be empty", }, { name: "invalid TranscribeAPI", @@ -114,8 +114,24 @@ func TestConfigIsValid(t *testing.T) { ModelSize: ModelSizeMedium, OutputFormat: OutputFormatVTT, }, + inTranscriber: "true", expectedError: fmt.Sprintf("NumThreads should be in the range [1, %d]", runtime.NumCPU()), }, + { + name: "valid NumThreads if not in container", + cfg: CallTranscriberConfig{ + SiteURL: "http://localhost:8065", + CallID: "8w8jorhr7j83uqr6y1st894hqe", + PostID: "udzdsg7dwidbzcidx5khrf8nee", + AuthToken: "qj75unbsef83ik9p7ueypb6iyw", + TranscriptionID: "on5yfih5etn5m8rfdidamc1oxa", + TranscribeAPI: TranscribeAPIDefault, + ModelSize: ModelSizeMedium, + OutputFormat: OutputFormatVTT, + }, + inTranscriber: "false", + expectedError: "SilenceThresholdMs should be a positive number", + }, { name: "invalid SilenceThresholdMs", cfg: CallTranscriberConfig{ @@ -184,8 +200,58 @@ func TestConfigIsValid(t *testing.T) { }, }, }, + inTranscriber: "true", expectedError: fmt.Sprintf("LiveCaptionsNumTranscribers * LiveCaptionsNumThreadsPerTranscriber should be in the range [1, %d]", runtime.NumCPU()), }, + { + name: "valid LiveCaptionsNumTranscribers if not in a container", + cfg: CallTranscriberConfig{ + SiteURL: "http://localhost:8065", + CallID: "8w8jorhr7j83uqr6y1st894hqe", + PostID: "udzdsg7dwidbzcidx5khrf8nee", + AuthToken: "qj75unbsef83ik9p7ueypb6iyw", + TranscriptionID: "on5yfih5etn5m8rfdidamc1oxa", + TranscribeAPI: TranscribeAPIDefault, + ModelSize: ModelSizeMedium, + OutputFormat: OutputFormatVTT, + NumThreads: 1, + LiveCaptionsOn: true, + OutputOptions: OutputOptions{ + Text: transcribe.TextOptions{ + CompactOptions: transcribe.TextCompactOptions{ + SilenceThresholdMs: 2000, + MaxSegmentDurationMs: 10000, + }, + }, + }, + }, + expectedError: "LiveCaptionsModelSize value is not valid", + }, + { + name: "invalid LiveCaptionsNumTranscribers", + cfg: CallTranscriberConfig{ + SiteURL: "http://localhost:8065", + CallID: "8w8jorhr7j83uqr6y1st894hqe", + PostID: "udzdsg7dwidbzcidx5khrf8nee", + AuthToken: "qj75unbsef83ik9p7ueypb6iyw", + TranscriptionID: "on5yfih5etn5m8rfdidamc1oxa", + TranscribeAPI: TranscribeAPIDefault, + ModelSize: ModelSizeMedium, + OutputFormat: OutputFormatVTT, + NumThreads: 1, + LiveCaptionsOn: true, + OutputOptions: OutputOptions{ + Text: transcribe.TextOptions{ + CompactOptions: transcribe.TextCompactOptions{ + SilenceThresholdMs: 2000, + MaxSegmentDurationMs: 10000, + }, + }, + }, + }, + inTranscriber: "false", + expectedError: "LiveCaptionsModelSize value is not valid", + }, { name: "invalid LiveCaptionsLanguage", cfg: CallTranscriberConfig{ @@ -245,6 +311,7 @@ func TestConfigIsValid(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { + inTranscriber = tc.inTranscriber err := tc.cfg.IsValid() if tc.expectedError == "" { require.NoError(t, err) @@ -417,6 +484,8 @@ func TestCallTranscriberConfigMap(t *testing.T) { cfg.OutputOptions.WebVTT.OmitSpeaker = true cfg.SetDefaults() + inTranscriber = "true" + t.Run("default config", func(t *testing.T) { var c CallTranscriberConfig err := c.FromMap(cfg.ToMap()).IsValid()