diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 0cb29695f968b..ff6cdd143e7e1 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -235,7 +235,7 @@ jobs: - name: Check if Terraform resources are up to date # We have to add the current directory as a safe directory or else git commands will not work as expected. # The protoc-gen-terraform version must match the version in integrations/terraform/Makefile - run: git config --global --add safe.directory $(realpath .) && go install github.com/gravitational/protoc-gen-terraform@c91cc3ef4d7d0046c36cb96b1cd337e466c61225 && make terraform-resources-up-to-date + run: git config --global --add safe.directory $(realpath .) && go install github.com/gravitational/protoc-gen-terraform/v3@v3.0.2 && make terraform-resources-up-to-date lint-rfd: name: Lint (RFD) diff --git a/.golangci.yml b/.golangci.yml index 98859bad6c7d9..ecc5e7c8e253f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -115,14 +115,6 @@ linters-settings: desc: 'use "crypto" or "x/crypto" instead' # Prevent importing any additional logging libraries. logging: - files: - # Integrations are still allowed to use logrus becuase they haven't - # been converted to slog yet. Once they use slog, remove this exception. - - '!**/integrations/**' - # The log package still contains the logrus formatter consumed by the integrations. - # Remove this exception when said formatter is deleted. - - '!**/lib/utils/log/**' - - '!**/lib/utils/cli.go' deny: - pkg: github.com/sirupsen/logrus desc: 'use "log/slog" instead' diff --git a/docs/pages/admin-guides/access-controls/access-monitoring.mdx b/docs/pages/admin-guides/access-controls/access-monitoring.mdx index 7f5a7b2a0a864..25797cf3e89d3 100644 --- a/docs/pages/admin-guides/access-controls/access-monitoring.mdx +++ b/docs/pages/admin-guides/access-controls/access-monitoring.mdx @@ -17,7 +17,7 @@ Users are able to write their own custom access monitoring queries by querying t Access Monitoring is not currently supported with External Audit Storage - in Teleport Enterprise (cloud-hosted). This functionality will be + in Teleport Enterprise (Cloud). This functionality will be enabled in a future Teleport release. diff --git a/docs/pages/admin-guides/management/external-audit-storage.mdx b/docs/pages/admin-guides/management/external-audit-storage.mdx index 6aa2fcc0368b8..587bb7ffebe56 100644 --- a/docs/pages/admin-guides/management/external-audit-storage.mdx +++ b/docs/pages/admin-guides/management/external-audit-storage.mdx @@ -21,6 +21,12 @@ External Audit Storage is based on Teleport's available on Teleport Enterprise Cloud clusters running Teleport v14.2.1 or above. + +On Teleport Enterprise (Cloud), External Audit +Storage is not currently supported for users who have Access Monitoring enabled. +This functionality will be enabled in a future Teleport release. + + ## Prerequisites 1. A Teleport Enterprise Cloud account. If you do not have one, [sign diff --git a/go.mod b/go.mod index 3c35132910093..78f04732806b6 100644 --- a/go.mod +++ b/go.mod @@ -179,7 +179,6 @@ require ( github.com/sigstore/cosign/v2 v2.4.1 github.com/sigstore/sigstore v1.8.11 github.com/sijms/go-ora/v2 v2.8.22 - github.com/sirupsen/logrus v1.9.3 github.com/snowflakedb/gosnowflake v1.12.1 github.com/spf13/cobra v1.8.1 github.com/spiffe/go-spiffe/v2 v2.4.0 @@ -501,6 +500,7 @@ require ( github.com/sigstore/protobuf-specs v0.3.2 // indirect github.com/sigstore/rekor v1.3.6 // indirect github.com/sigstore/timestamp-authority v1.2.2 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.7.0 // indirect diff --git a/integrations/access/accesslist/app.go b/integrations/access/accesslist/app.go index 02f933baf5ecd..ba40de3abf575 100644 --- a/integrations/access/accesslist/app.go +++ b/integrations/access/accesslist/app.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/teleport/integrations/lib" "github.com/gravitational/teleport/integrations/lib/logger" pd "github.com/gravitational/teleport/integrations/lib/plugindata" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -118,7 +119,7 @@ func (a *App) run(ctx context.Context) error { log := logger.Get(ctx) - log.Info("Access list monitor is running") + log.InfoContext(ctx, "Access list monitor is running") a.job.SetReady(true) @@ -134,7 +135,7 @@ func (a *App) run(ctx context.Context) error { } timer.Reset(jitter(reminderInterval)) case <-ctx.Done(): - log.Info("Access list monitor is finished") + log.InfoContext(ctx, "Access list monitor is finished") return nil } } @@ -146,7 +147,7 @@ func (a *App) run(ctx context.Context) error { func (a *App) remindIfNecessary(ctx context.Context) error { log := logger.Get(ctx) - log.Info("Looking for Access List Review reminders") + log.InfoContext(ctx, "Looking for Access List Review reminders") var nextToken string var err error @@ -156,13 +157,14 @@ func (a *App) remindIfNecessary(ctx context.Context) error { accessLists, nextToken, err = a.apiClient.ListAccessLists(ctx, 0 /* default page size */, nextToken) if err != nil { if trace.IsNotImplemented(err) { - log.Errorf("access list endpoint is not implemented on this auth server, so the access list app is ceasing to run.") + log.ErrorContext(ctx, "access list endpoint is not implemented on this auth server, so the access list app is ceasing to run") return trace.Wrap(err) } else if trace.IsAccessDenied(err) { - log.Warnf("Slack bot does not have permissions to list access lists. Please add access_list read and list permissions " + - "to the role associated with the Slack bot.") + const msg = "Slack bot does not have permissions to list access lists. Please add access_list read and list permissions " + + "to the role associated with the Slack bot." + log.WarnContext(ctx, msg) } else { - log.Errorf("error listing access lists: %v", err) + log.ErrorContext(ctx, "error listing access lists", "error", err) } break } @@ -170,7 +172,10 @@ func (a *App) remindIfNecessary(ctx context.Context) error { for _, accessList := range accessLists { recipients, err := a.getRecipientsRequiringReminders(ctx, accessList) if err != nil { - log.WithError(err).Warnf("Error getting recipients to notify for review due for access list %q", accessList.Spec.Title) + log.WarnContext(ctx, "Error getting recipients to notify for review due for access list", + "error", err, + "access_list", accessList.Spec.Title, + ) continue } @@ -195,7 +200,7 @@ func (a *App) remindIfNecessary(ctx context.Context) error { } if len(errs) > 0 { - log.WithError(trace.NewAggregate(errs...)).Warn("Error notifying for access list reviews") + log.WarnContext(ctx, "Error notifying for access list reviews", "error", trace.NewAggregate(errs...)) } return nil @@ -213,7 +218,10 @@ func (a *App) getRecipientsRequiringReminders(ctx context.Context, accessList *a // If the current time before the notification start time, skip notifications. if now.Before(notificationStart) { - log.Debugf("Access list %s is not ready for notifications, notifications start at %s", accessList.GetName(), notificationStart.Format(time.RFC3339)) + log.DebugContext(ctx, "Access list is not ready for notifications", + "access_list", accessList.GetName(), + "notification_start_time", notificationStart.Format(time.RFC3339), + ) return nil, nil } @@ -255,12 +263,17 @@ func (a *App) fetchRecipients(ctx context.Context, accessList *accesslist.Access if err != nil { // TODO(kiosion): Remove in v18; protecting against server not having `GetAccessListOwners` func. if trace.IsNotImplemented(err) { - log.WithError(err).Warnf("Error getting nested owners for access list '%v', continuing with only explicit owners", accessList.GetName()) + log.WarnContext(ctx, "Error getting nested owners for access list, continuing with only explicit owners", + "error", err, + "access_list", accessList.GetName(), + ) for _, owner := range accessList.Spec.Owners { allOwners = append(allOwners, &owner) } } else { - log.WithError(err).Errorf("Error getting owners for access list '%v'", accessList.GetName()) + log.ErrorContext(ctx, "Error getting owners for access list", + "error", err, + "access_list", accessList.GetName()) } } @@ -270,7 +283,7 @@ func (a *App) fetchRecipients(ctx context.Context, accessList *accesslist.Access for _, owner := range allOwners { recipient, err := a.bot.FetchRecipient(ctx, owner.Name) if err != nil { - log.Debugf("error getting recipient %s", owner.Name) + log.DebugContext(ctx, "error getting recipient", "recipient", owner.Name) continue } allRecipients[owner.Name] = *recipient @@ -293,7 +306,10 @@ func (a *App) updatePluginDataAndGetRecipientsRequiringReminders(ctx context.Con // Calculate days from start. daysFromStart := now.Sub(notificationStart) / oneDay windowStart = notificationStart.Add(daysFromStart * oneDay) - log.Infof("windowStart: %s, now: %s", windowStart.String(), now.String()) + log.InfoContext(ctx, "calculating window start", + "window_start", logutils.StringerAttr(windowStart), + "now", logutils.StringerAttr(now), + ) } recipients := []common.Recipient{} @@ -304,7 +320,10 @@ func (a *App) updatePluginDataAndGetRecipientsRequiringReminders(ctx context.Con // If the notification window is before the last notification date, then this user doesn't need a notification. if !windowStart.After(lastNotification) { - log.Debugf("User %s has already been notified for access list %s", recipient.Name, accessList.GetName()) + log.DebugContext(ctx, "User has already been notified for access list", + "user", recipient.Name, + "access_list", accessList.GetName(), + ) userNotifications[recipient.Name] = lastNotification continue } diff --git a/integrations/access/accessmonitoring/access_monitoring_rules.go b/integrations/access/accessmonitoring/access_monitoring_rules.go index 3dea9ea2bf543..82c91413bff96 100644 --- a/integrations/access/accessmonitoring/access_monitoring_rules.go +++ b/integrations/access/accessmonitoring/access_monitoring_rules.go @@ -151,8 +151,10 @@ func (amrh *RuleHandler) RecipientsFromAccessMonitoringRules(ctx context.Context for _, rule := range amrh.getAccessMonitoringRules() { match, err := MatchAccessRequest(rule.Spec.Condition, req) if err != nil { - log.WithError(err).WithField("rule", rule.Metadata.Name). - Warn("Failed to parse access monitoring notification rule") + log.WarnContext(ctx, "Failed to parse access monitoring notification rule", + "error", err, + "rule", rule.Metadata.Name, + ) } if !match { continue @@ -160,7 +162,7 @@ func (amrh *RuleHandler) RecipientsFromAccessMonitoringRules(ctx context.Context for _, recipient := range rule.Spec.Notification.Recipients { rec, err := amrh.fetchRecipientCallback(ctx, recipient) if err != nil { - log.WithError(err).Warn("Failed to fetch plugin recipients based on Access monitoring rule recipients") + log.WarnContext(ctx, "Failed to fetch plugin recipients based on Access monitoring rule recipients", "error", err) continue } recipientSet.Add(*rec) @@ -176,8 +178,10 @@ func (amrh *RuleHandler) RawRecipientsFromAccessMonitoringRules(ctx context.Cont for _, rule := range amrh.getAccessMonitoringRules() { match, err := MatchAccessRequest(rule.Spec.Condition, req) if err != nil { - log.WithError(err).WithField("rule", rule.Metadata.Name). - Warn("Failed to parse access monitoring notification rule") + log.WarnContext(ctx, "Failed to parse access monitoring notification rule", + "error", err, + "rule", rule.Metadata.Name, + ) } if !match { continue diff --git a/integrations/access/accessrequest/app.go b/integrations/access/accessrequest/app.go index 8a5effc73dabd..17182ec3dc8ee 100644 --- a/integrations/access/accessrequest/app.go +++ b/integrations/access/accessrequest/app.go @@ -21,6 +21,7 @@ package accessrequest import ( "context" "fmt" + "log/slog" "slices" "strings" "time" @@ -36,6 +37,7 @@ import ( "github.com/gravitational/teleport/integrations/lib/logger" pd "github.com/gravitational/teleport/integrations/lib/plugindata" "github.com/gravitational/teleport/integrations/lib/watcherjob" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -189,16 +191,16 @@ func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error { func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error { op := event.Type reqID := event.Resource.GetName() - ctx, _ = logger.WithField(ctx, "request_id", reqID) + ctx, _ = logger.With(ctx, "request_id", reqID) switch op { case types.OpPut: - ctx, _ = logger.WithField(ctx, "request_op", "put") + ctx, _ = logger.With(ctx, "request_op", "put") req, ok := event.Resource.(types.AccessRequest) if !ok { return trace.BadParameter("unexpected resource type %T", event.Resource) } - ctx, log := logger.WithField(ctx, "request_state", req.GetState().String()) + ctx, log := logger.With(ctx, "request_state", req.GetState().String()) var err error switch { @@ -207,21 +209,29 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error case req.GetState().IsResolved(): err = a.onResolvedRequest(ctx, req) default: - log.WithField("event", event).Warn("Unknown request state") + log.WarnContext(ctx, "Unknown request state", + slog.Group("event", + slog.Any("type", logutils.StringerAttr(event.Type)), + slog.Group("resource", + "kind", event.Resource.GetKind(), + "name", event.Resource.GetName(), + ), + ), + ) return nil } if err != nil { - log.WithError(err).Errorf("Failed to process request") + log.ErrorContext(ctx, "Failed to process request", "error", err) return trace.Wrap(err) } return nil case types.OpDelete: - ctx, log := logger.WithField(ctx, "request_op", "delete") + ctx, log := logger.With(ctx, "request_op", "delete") if err := a.onDeletedRequest(ctx, reqID); err != nil { - log.WithError(err).Errorf("Failed to process deleted request") + log.ErrorContext(ctx, "Failed to process deleted request", "error", err) return trace.Wrap(err) } return nil @@ -242,7 +252,7 @@ func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) err loginsByRole, err := a.getLoginsByRole(ctx, req) if trace.IsAccessDenied(err) { - log.Warnf("Missing permissions to get logins by role. Please add role.read to the associated role. error: %s", err) + log.WarnContext(ctx, "Missing permissions to get logins by role, please add role.read to the associated role", "error", err) } else if err != nil { return trace.Wrap(err) } @@ -265,12 +275,12 @@ func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) err return trace.Wrap(err) } } else { - log.Warning("No channel to post") + log.WarnContext(ctx, "No channel to post") } // Try to approve the request if user is currently on-call. if err := a.tryApproveRequest(ctx, reqID, req); err != nil { - log.Warningf("Failed to auto approve request: %v", err) + log.WarnContext(ctx, "Failed to auto approve request", "error", err) } case trace.IsAlreadyExists(err): // The messages were already sent, nothing to do, we can update the reviews @@ -311,7 +321,7 @@ func (a *App) onResolvedRequest(ctx context.Context, req types.AccessRequest) er case types.RequestState_PROMOTED: tag = pd.ResolvedPromoted default: - logger.Get(ctx).Warningf("Unknown state %v (%s)", state, state.String()) + logger.Get(ctx).WarnContext(ctx, "Unknown state", "state", logutils.StringerAttr(state)) return replyErr } err := trace.Wrap(a.updateMessages(ctx, req.GetName(), tag, reason, req.GetReviews())) @@ -330,13 +340,13 @@ func (a *App) broadcastAccessRequestMessages(ctx context.Context, recipients []c return trace.Wrap(err) } for _, data := range sentMessages { - logger.Get(ctx).WithFields(logger.Fields{ - "channel_id": data.ChannelID, - "message_id": data.MessageID, - }).Info("Successfully posted messages") + logger.Get(ctx).InfoContext(ctx, "Successfully posted messages", + "channel_id", data.ChannelID, + "message_id", data.MessageID, + ) } if err != nil { - logger.Get(ctx).WithError(err).Error("Failed to post one or more messages") + logger.Get(ctx).ErrorContext(ctx, "Failed to post one or more messages", "error", err) } _, err = a.pluginData.Update(ctx, reqID, func(existing PluginData) (PluginData, error) { @@ -369,7 +379,7 @@ func (a *App) postReviewReplies(ctx context.Context, reqID string, reqReviews [] return existing, nil }) if trace.IsAlreadyExists(err) { - logger.Get(ctx).Debug("Failed to post reply: replies are already sent") + logger.Get(ctx).DebugContext(ctx, "Failed to post reply: replies are already sent") return nil } if err != nil { @@ -383,7 +393,7 @@ func (a *App) postReviewReplies(ctx context.Context, reqID string, reqReviews [] errors := make([]error, 0, len(slice)) for _, data := range pd.SentMessages { - ctx, _ = logger.WithFields(ctx, logger.Fields{"channel_id": data.ChannelID, "message_id": data.MessageID}) + ctx, _ = logger.With(ctx, "channel_id", data.ChannelID, "message_id", data.MessageID) for _, review := range slice { if err := a.bot.PostReviewReply(ctx, data.ChannelID, data.MessageID, review); err != nil { errors = append(errors, err) @@ -425,7 +435,7 @@ func (a *App) getMessageRecipients(ctx context.Context, req types.AccessRequest) for _, recipient := range recipients { rec, err := a.bot.FetchRecipient(ctx, recipient) if err != nil { - log.Warningf("Failed to fetch Opsgenie recipient: %v", err) + log.WarnContext(ctx, "Failed to fetch Opsgenie recipient", "error", err) continue } recipientSet.Add(*rec) @@ -436,7 +446,7 @@ func (a *App) getMessageRecipients(ctx context.Context, req types.AccessRequest) validEmailSuggReviewers := []string{} for _, reviewer := range req.GetSuggestedReviewers() { if !lib.IsEmail(reviewer) { - log.Warningf("Failed to notify a suggested reviewer: %q does not look like a valid email", reviewer) + log.WarnContext(ctx, "Failed to notify a suggested reviewer with an invalid email address", "reviewer", reviewer) continue } @@ -446,7 +456,7 @@ func (a *App) getMessageRecipients(ctx context.Context, req types.AccessRequest) for _, rawRecipient := range rawRecipients { recipient, err := a.bot.FetchRecipient(ctx, rawRecipient) if err != nil { - log.WithError(err).Warn("Failure when fetching recipient, continuing anyway") + log.WarnContext(ctx, "Failure when fetching recipient, continuing anyway", "error", err) } else { recipientSet.Add(*recipient) } @@ -476,7 +486,7 @@ func (a *App) updateMessages(ctx context.Context, reqID string, tag pd.Resolutio return existing, nil }) if trace.IsNotFound(err) { - log.Debug("Failed to update messages: plugin data is missing") + log.DebugContext(ctx, "Failed to update messages: plugin data is missing") return nil } if trace.IsAlreadyExists(err) { @@ -485,7 +495,7 @@ func (a *App) updateMessages(ctx context.Context, reqID string, tag pd.Resolutio "cannot change the resolution tag of an already resolved request, existing: %s, event: %s", pluginData.ResolutionTag, tag) } - log.Debug("Request is already resolved, ignoring event") + log.DebugContext(ctx, "Request is already resolved, ignoring event") return nil } if err != nil { @@ -496,13 +506,17 @@ func (a *App) updateMessages(ctx context.Context, reqID string, tag pd.Resolutio if err := a.bot.UpdateMessages(ctx, reqID, reqData, sentMessages, reviews); err != nil { return trace.Wrap(err) } - log.Infof("Successfully marked request as %s in all messages", tag) + + log.InfoContext(ctx, "Marked request with resolution and sent emails!", "resolution", tag) if err := a.bot.NotifyUser(ctx, reqID, reqData); err != nil && !trace.IsNotImplemented(err) { return trace.Wrap(err) } - log.Infof("Successfully notified user %s request marked as %s", reqData.User, tag) + log.InfoContext(ctx, "Successfully notified user", + "user", reqData.User, + "resolution", tag, + ) return nil } @@ -545,13 +559,11 @@ func (a *App) getResourceNames(ctx context.Context, req types.AccessRequest) ([] // tryApproveRequest attempts to automatically approve the access request if the // user is on call for the configured service/team. func (a *App) tryApproveRequest(ctx context.Context, reqID string, req types.AccessRequest) error { - log := logger.Get(ctx). - WithField("req_id", reqID). - WithField("user", req.GetUser()) + log := logger.Get(ctx).With("req_id", reqID, "user", req.GetUser()) oncallUsers, err := a.bot.FetchOncallUsers(ctx, req) if trace.IsNotImplemented(err) { - log.Debugf("Skipping auto-approval because %q bot does not support automatic approvals.", a.pluginName) + log.DebugContext(ctx, "Skipping auto-approval because bot does not support automatic approvals", "bot", a.pluginName) return nil } if err != nil { @@ -559,7 +571,7 @@ func (a *App) tryApproveRequest(ctx context.Context, reqID string, req types.Acc } if !slices.Contains(oncallUsers, req.GetUser()) { - log.Debug("Skipping approval because user is not on-call.") + log.DebugContext(ctx, "Skipping approval because user is not on-call") return nil } @@ -573,12 +585,12 @@ func (a *App) tryApproveRequest(ctx context.Context, reqID string, req types.Acc }, }); err != nil { if strings.HasSuffix(err.Error(), "has already reviewed this request") { - log.Debug("Request has already been reviewed.") + log.DebugContext(ctx, "Request has already been reviewed") return nil } return trace.Wrap(err) } - log.Info("Successfully submitted a request approval.") + log.InfoContext(ctx, "Successfully submitted a request approval") return nil } diff --git a/integrations/access/common/app.go b/integrations/access/common/app.go index 805c0dde6ef8a..6c174e1422b75 100644 --- a/integrations/access/common/app.go +++ b/integrations/access/common/app.go @@ -88,7 +88,7 @@ func (a *BaseApp) WaitReady(ctx context.Context) (bool, error) { func (a *BaseApp) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) { log := logger.Get(ctx) - log.Debug("Checking Teleport server version") + log.DebugContext(ctx, "Checking Teleport server version") pong, err := a.APIClient.Ping(ctx) if err != nil { @@ -156,9 +156,9 @@ func (a *BaseApp) run(ctx context.Context) error { a.mainJob.SetReady(allOK) if allOK { - log.Info("Plugin is ready") + log.InfoContext(ctx, "Plugin is ready") } else { - log.Error("Plugin is not ready") + log.ErrorContext(ctx, "Plugin is not ready") } for _, app := range a.apps { @@ -203,11 +203,11 @@ func (a *BaseApp) init(ctx context.Context) error { } } - log.Debug("Starting API health check...") + log.DebugContext(ctx, "Starting API health check") if err = a.Bot.CheckHealth(ctx); err != nil { return trace.Wrap(err, "API health check failed") } - log.Debug("API health check finished ok") + log.DebugContext(ctx, "API health check finished ok") return nil } diff --git a/integrations/access/common/auth/token_provider.go b/integrations/access/common/auth/token_provider.go index f4ae33936a709..e0c23b0b36427 100644 --- a/integrations/access/common/auth/token_provider.go +++ b/integrations/access/common/auth/token_provider.go @@ -20,12 +20,12 @@ package auth import ( "context" + "log/slog" "sync" "time" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/common/auth/oauth" "github.com/gravitational/teleport/integrations/access/common/auth/storage" @@ -65,7 +65,7 @@ type RotatedAccessTokenProviderConfig struct { Refresher oauth.Refresher Clock clockwork.Clock - Log *logrus.Entry + Log *slog.Logger } // CheckAndSetDefaults validates a configuration and sets default values @@ -87,7 +87,7 @@ func (c *RotatedAccessTokenProviderConfig) CheckAndSetDefaults() error { c.Clock = clockwork.NewRealClock() } if c.Log == nil { - c.Log = logrus.NewEntry(logrus.StandardLogger()) + c.Log = slog.Default() } return nil } @@ -104,7 +104,7 @@ type RotatedAccessTokenProvider struct { refresher oauth.Refresher clock clockwork.Clock - log logrus.FieldLogger + log *slog.Logger lock sync.RWMutex // protects the below fields creds *storage.Credentials @@ -153,12 +153,12 @@ func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) { timer := r.clock.NewTimer(interval) defer timer.Stop() - r.log.Infof("Will attempt token refresh in: %s", interval) + r.log.InfoContext(ctx, "Starting token refresh loop", "next_refresh", interval) for { select { case <-ctx.Done(): - r.log.Info("Shutting down") + r.log.InfoContext(ctx, "Shutting down") return case <-timer.Chan(): creds, _ := r.store.GetCredentials(ctx) @@ -174,18 +174,21 @@ func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) { interval := r.getRefreshInterval(creds) timer.Reset(interval) - r.log.Infof("Next refresh in: %s", interval) + r.log.InfoContext(ctx, "Refreshed token", "next_refresh", interval) continue } creds, err := r.refresh(ctx) if err != nil { - r.log.Errorf("Error while refreshing: %s. Will retry after: %s", err, r.retryInterval) + r.log.ErrorContext(ctx, "Error while refreshing token", + "error", err, + "retry_interval", r.retryInterval, + ) timer.Reset(r.retryInterval) } else { err := r.store.PutCredentials(ctx, creds) if err != nil { - r.log.Errorf("Error while storing the refreshed credentials: %s", err) + r.log.ErrorContext(ctx, "Error while storing the refreshed credentials", "error", err) timer.Reset(r.retryInterval) continue } @@ -196,7 +199,7 @@ func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) { interval := r.getRefreshInterval(creds) timer.Reset(interval) - r.log.Infof("Successfully refreshed credentials. Next refresh in: %s", interval) + r.log.InfoContext(ctx, "Successfully refreshed credentials", "next_refresh", interval) } } } diff --git a/integrations/access/common/auth/token_provider_test.go b/integrations/access/common/auth/token_provider_test.go index fca79776ba024..e4f02ec3d3ae5 100644 --- a/integrations/access/common/auth/token_provider_test.go +++ b/integrations/access/common/auth/token_provider_test.go @@ -20,12 +20,12 @@ package auth import ( "context" + "log/slog" "testing" "time" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/integrations/access/common/auth/oauth" @@ -57,9 +57,6 @@ func (s *mockStore) PutCredentials(ctx context.Context, creds *storage.Credentia } func TestRotatedAccessTokenProvider(t *testing.T) { - log := logrus.New() - log.Level = logrus.DebugLevel - newProvider := func(ctx context.Context, store storage.Store, refresher oauth.Refresher, clock clockwork.Clock, initialCreds *storage.Credentials) *RotatedAccessTokenProvider { return &RotatedAccessTokenProvider{ store: store, @@ -70,7 +67,7 @@ func TestRotatedAccessTokenProvider(t *testing.T) { tokenBufferInterval: 1 * time.Hour, creds: initialCreds, - log: log, + log: slog.Default(), } } diff --git a/integrations/access/datadog/bot.go b/integrations/access/datadog/bot.go index e92dbbb524a20..4e1f52a6c218d 100644 --- a/integrations/access/datadog/bot.go +++ b/integrations/access/datadog/bot.go @@ -162,7 +162,7 @@ func (b Bot) FetchOncallUsers(ctx context.Context, req types.AccessRequest) ([]s annotationKey := types.TeleportNamespace + types.ReqAnnotationApproveSchedulesLabel teamNames, err := common.GetNamesFromAnnotations(req, annotationKey) if err != nil { - log.Debug("Automatic approvals annotation is empty or unspecified.") + log.DebugContext(ctx, "Automatic approvals annotation is empty or unspecified") return nil, nil } diff --git a/integrations/access/datadog/client.go b/integrations/access/datadog/client.go index 489eb0c51a44d..2d4ebf79ea5f2 100644 --- a/integrations/access/datadog/client.go +++ b/integrations/access/datadog/client.go @@ -126,7 +126,7 @@ func onAfterDatadogResponse(sink common.StatusSink) resty.ResponseMiddleware { defer cancel() if err := sink.Emit(ctx, status); err != nil { - log.WithError(err).Errorf("Error while emitting Datadog Incident Management plugin status: %v", err) + log.ErrorContext(ctx, "Error while emitting Datadog Incident Management plugin status", "error", err) } } diff --git a/integrations/access/datadog/cmd/teleport-datadog/main.go b/integrations/access/datadog/cmd/teleport-datadog/main.go index cb9cbd1959771..84a6a14c0955f 100644 --- a/integrations/access/datadog/cmd/teleport-datadog/main.go +++ b/integrations/access/datadog/cmd/teleport-datadog/main.go @@ -22,6 +22,7 @@ import ( "context" _ "embed" "fmt" + "log/slog" "os" "github.com/alecthomas/kingpin/v2" @@ -67,12 +68,13 @@ func main() { if err := run(*path, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } } func run(configPath string, debug bool) error { + ctx := context.Background() conf, err := datadog.LoadConfig(configPath) if err != nil { return trace.Wrap(err) @@ -86,14 +88,15 @@ func run(configPath string, debug bool) error { return err } if debug { - logger.Standard().Debugf("DEBUG logging enabled") + slog.DebugContext(ctx, "DEBUG logging enabled") } app := datadog.NewDatadogApp(conf) go lib.ServeSignals(app, common.PluginShutdownTimeout) - logger.Standard().Infof("Starting Teleport Access Datadog Incident Management Plugin %s:%s", teleport.Version, teleport.Gitref) - return trace.Wrap( - app.Run(context.Background()), + slog.InfoContext(ctx, "Starting Teleport Access Datadog Incident Management Plugin", + "version", teleport.Version, + "git_ref", teleport.Gitref, ) + return trace.Wrap(app.Run(ctx)) } diff --git a/integrations/access/datadog/testlib/fake_datadog.go b/integrations/access/datadog/testlib/fake_datadog.go index 64ef2e35b93b7..5cfe8b539f454 100644 --- a/integrations/access/datadog/testlib/fake_datadog.go +++ b/integrations/access/datadog/testlib/fake_datadog.go @@ -32,7 +32,6 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/datadog" ) @@ -281,6 +280,6 @@ func (d *FakeDatadog) GetOncallTeams() (map[string][]string, bool) { func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/discord/bot.go b/integrations/access/discord/bot.go index ca231bdf83a93..576606998b23c 100644 --- a/integrations/access/discord/bot.go +++ b/integrations/access/discord/bot.go @@ -94,8 +94,7 @@ func emitStatusUpdate(resp *resty.Response, statusSink common.StatusSink) { if err := statusSink.Emit(ctx, status); err != nil { logger.Get(resp.Request.Context()). - WithError(err). - Errorf("Error while emitting Discord plugin status: %v", err) + ErrorContext(ctx, "Error while emitting Discord plugin status", "error", err) } } diff --git a/integrations/access/discord/cmd/teleport-discord/main.go b/integrations/access/discord/cmd/teleport-discord/main.go index cd19ce64591b6..f624b407742ba 100644 --- a/integrations/access/discord/cmd/teleport-discord/main.go +++ b/integrations/access/discord/cmd/teleport-discord/main.go @@ -20,6 +20,7 @@ import ( "context" _ "embed" "fmt" + "log/slog" "os" "github.com/alecthomas/kingpin/v2" @@ -65,12 +66,13 @@ func main() { if err := run(*path, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } } func run(configPath string, debug bool) error { + ctx := context.Background() conf, err := discord.LoadDiscordConfig(configPath) if err != nil { return trace.Wrap(err) @@ -84,14 +86,15 @@ func run(configPath string, debug bool) error { return trace.Wrap(err) } if debug { - logger.Standard().Debugf("DEBUG logging enabled") + slog.DebugContext(ctx, "DEBUG logging enabled") } app := discord.NewApp(conf) go lib.ServeSignals(app, common.PluginShutdownTimeout) - logger.Standard().Infof("Starting Teleport Access Discord Plugin %s:%s", teleport.Version, teleport.Gitref) - return trace.Wrap( - app.Run(context.Background()), + slog.InfoContext(ctx, "Starting Teleport Access Discord Plugin", + "version", teleport.Version, + "git_ref", teleport.Gitref, ) + return trace.Wrap(app.Run(ctx)) } diff --git a/integrations/access/discord/testlib/fake_discord.go b/integrations/access/discord/testlib/fake_discord.go index c5a176446be5b..0a059d8ac81e2 100644 --- a/integrations/access/discord/testlib/fake_discord.go +++ b/integrations/access/discord/testlib/fake_discord.go @@ -32,7 +32,6 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/discord" ) @@ -188,6 +187,6 @@ func (s *FakeDiscord) CheckMessageUpdateByResponding(ctx context.Context) (disco func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/email/app.go b/integrations/access/email/app.go index 07bb3b558080e..cae9c33ed5315 100644 --- a/integrations/access/email/app.go +++ b/integrations/access/email/app.go @@ -18,6 +18,7 @@ package email import ( "context" + "log/slog" "slices" "time" @@ -32,6 +33,7 @@ import ( "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/watcherjob" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -90,7 +92,6 @@ func (a *App) run(ctx context.Context) error { var err error log := logger.Get(ctx) - log.Infof("Starting Teleport Access Email Plugin") if err = a.init(ctx); err != nil { return trace.Wrap(err) @@ -137,9 +138,9 @@ func (a *App) run(ctx context.Context) error { a.mainJob.SetReady(ok) if ok { - log.Info("Plugin is ready") + log.InfoContext(ctx, "Plugin is ready") } else { - log.Error("Plugin is not ready") + log.ErrorContext(ctx, "Plugin is not ready") } <-watcherJob.Done() @@ -186,24 +187,24 @@ func (a *App) init(ctx context.Context) error { }, }) - log.Debug("Starting client connection health check...") + log.DebugContext(ctx, "Starting client connection health check") if err = a.client.CheckHealth(ctx); err != nil { return trace.Wrap(err, "client connection health check failed") } - log.Debug("Client connection health check finished ok") + log.DebugContext(ctx, "Client connection health check finished ok") return nil } // checkTeleportVersion checks that Teleport version is not lower than required func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) { log := logger.Get(ctx) - log.Debug("Checking Teleport server version") + log.DebugContext(ctx, "Checking Teleport server version") pong, err := a.apiClient.Ping(ctx) if err != nil { if trace.IsNotImplemented(err) { return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion) } - log.Error("Unable to get Teleport server version") + log.ErrorContext(ctx, "Unable to get Teleport server version") return pong, trace.Wrap(err) } err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion) @@ -229,16 +230,16 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error } op := event.Type reqID := event.Resource.GetName() - ctx, _ = logger.WithField(ctx, "request_id", reqID) + ctx, _ = logger.With(ctx, "request_id", reqID) switch op { case types.OpPut: - ctx, _ = logger.WithField(ctx, "request_op", "put") + ctx, _ = logger.With(ctx, "request_op", "put") req, ok := event.Resource.(types.AccessRequest) if !ok { return trace.Errorf("unexpected resource type %T", event.Resource) } - ctx, log := logger.WithField(ctx, "request_state", req.GetState().String()) + ctx, log := logger.With(ctx, "request_state", req.GetState().String()) var err error switch { @@ -249,21 +250,31 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error case req.GetState().IsDenied(): err = a.onResolvedRequest(ctx, req) default: - log.WithField("event", event).Warn("Unknown request state") + log.WarnContext(ctx, "Unknown request state", + slog.Group("event", + slog.Any("type", logutils.StringerAttr(event.Type)), + slog.Group("resource", + "kind", event.Resource.GetKind(), + "name", event.Resource.GetName(), + ), + ), + ) + + log.With("event", event).WarnContext(ctx, "Unknown request state") return nil } if err != nil { - log.WithError(err).Errorf("Failed to process request") + log.ErrorContext(ctx, "Failed to process request", "error", err) return trace.Wrap(err) } return nil case types.OpDelete: - ctx, log := logger.WithField(ctx, "request_op", "delete") + ctx, log := logger.With(ctx, "request_op", "delete") if err := a.onDeletedRequest(ctx, reqID); err != nil { - log.WithError(err).Errorf("Failed to process deleted request") + log.ErrorContext(ctx, "Failed to process deleted request", "error", err) return trace.Wrap(err) } return nil @@ -292,7 +303,7 @@ func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) err if isNew { recipients := a.getRecipients(ctx, req) if len(recipients) == 0 { - log.Warning("No recipients to send") + log.WarnContext(ctx, "No recipients to send") return nil } @@ -329,7 +340,7 @@ func (a *App) onResolvedRequest(ctx context.Context, req types.AccessRequest) er case types.RequestState_DENIED: resolution.Tag = ResolvedDenied default: - logger.Get(ctx).Warningf("Unknown state %v (%s)", state, state.String()) + logger.Get(ctx).WarnContext(ctx, "Unknown state", "state", logutils.StringerAttr(state)) return replyErr } err := trace.Wrap(a.sendResolution(ctx, req.GetName(), resolution)) @@ -359,7 +370,7 @@ func (a *App) getRecipients(ctx context.Context, req types.AccessRequest) []comm rawRecipients := a.conf.RoleToRecipients.GetRawRecipientsFor(req.GetRoles(), req.GetSuggestedReviewers()) for _, rawRecipient := range rawRecipients { if !lib.IsEmail(rawRecipient) { - log.Warningf("Failed to notify a reviewer: %q does not look like a valid email", rawRecipient) + log.WarnContext(ctx, "Failed to notify a suggested reviewer with an invalid email address", "reviewer", rawRecipient) continue } recipientSet.Add(common.Recipient{ @@ -382,7 +393,7 @@ func (a *App) sendNewThreads(ctx context.Context, recipients []common.Recipient, logSentThreads(ctx, threadsSent, "new threads") if err != nil { - logger.Get(ctx).WithError(err).Error("Failed send one or more messages") + logger.Get(ctx).ErrorContext(ctx, "Failed send one or more messages", "error", err) } _, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) { @@ -425,7 +436,7 @@ func (a *App) sendReviews(ctx context.Context, reqID string, reqData RequestData return trace.Wrap(err) } if !ok { - logger.Get(ctx).Debug("Failed to post reply: plugin data is missing") + logger.Get(ctx).DebugContext(ctx, "Failed to post reply: plugin data is missing") return nil } reviews := reqReviews[oldCount:] @@ -439,7 +450,11 @@ func (a *App) sendReviews(ctx context.Context, reqID string, reqData RequestData if err != nil { errors = append(errors, err) } - logger.Get(ctx).Infof("New review for request %v by %v is %v", reqID, review.Author, review.ProposedState.String()) + logger.Get(ctx).InfoContext(ctx, "New review for request", + "request_id", reqID, + "author", review.Author, + "state", logutils.StringerAttr(review.ProposedState), + ) logSentThreads(ctx, threadsSent, "new review") } @@ -473,7 +488,7 @@ func (a *App) sendResolution(ctx context.Context, reqID string, resolution Resol return trace.Wrap(err) } if !ok { - log.Debug("Failed to update messages: plugin data is missing") + log.DebugContext(ctx, "Failed to update messages: plugin data is missing") return nil } @@ -482,7 +497,7 @@ func (a *App) sendResolution(ctx context.Context, reqID string, resolution Resol threadsSent, err := a.client.SendResolution(ctx, threads, reqID, reqData) logSentThreads(ctx, threadsSent, "request resolved") - log.Infof("Marked request as %s and sent emails!", resolution.Tag) + log.InfoContext(ctx, "Marked request with resolution and sent emails", "resolution", resolution.Tag) if err != nil { return trace.Wrap(err) @@ -567,10 +582,11 @@ func (a *App) updatePluginData(ctx context.Context, reqID string, data PluginDat // logSentThreads logs successfully sent emails func logSentThreads(ctx context.Context, threads []EmailThread, kind string) { for _, thread := range threads { - logger.Get(ctx).WithFields(logger.Fields{ - "email": thread.Email, - "timestamp": thread.Timestamp, - "message_id": thread.MessageID, - }).Infof("Successfully sent %v!", kind) + logger.Get(ctx).InfoContext(ctx, "Successfully sent", + "email", thread.Email, + "timestamp", thread.Timestamp, + "message_id", thread.MessageID, + "kind", kind, + ) } } diff --git a/integrations/access/email/client.go b/integrations/access/email/client.go index 6ef1d2f04144e..f687f5deb0009 100644 --- a/integrations/access/email/client.go +++ b/integrations/access/email/client.go @@ -61,16 +61,16 @@ func NewClient(ctx context.Context, conf Config, clusterName, webProxyAddr strin if conf.Mailgun != nil { mailer = NewMailgunMailer(*conf.Mailgun, conf.StatusSink, conf.Delivery.Sender, clusterName, conf.RoleToRecipients[types.Wildcard]) - logger.Get(ctx).WithField("domain", conf.Mailgun.Domain).Info("Using Mailgun as email transport") + logger.Get(ctx).InfoContext(ctx, "Using Mailgun as email transport", "domain", conf.Mailgun.Domain) } if conf.SMTP != nil { mailer = NewSMTPMailer(*conf.SMTP, conf.StatusSink, conf.Delivery.Sender, clusterName) - logger.Get(ctx).WithFields(logger.Fields{ - "host": conf.SMTP.Host, - "port": conf.SMTP.Port, - "username": conf.SMTP.Username, - }).Info("Using SMTP as email transport") + logger.Get(ctx).InfoContext(ctx, "Using SMTP as email transport", + "host", conf.SMTP.Host, + "port", conf.SMTP.Port, + "username", conf.SMTP.Username, + ) } return Client{ diff --git a/integrations/access/email/cmd/teleport-email/main.go b/integrations/access/email/cmd/teleport-email/main.go index 840c80da76177..ccaec3acbed36 100644 --- a/integrations/access/email/cmd/teleport-email/main.go +++ b/integrations/access/email/cmd/teleport-email/main.go @@ -20,6 +20,7 @@ import ( "context" _ "embed" "fmt" + "log/slog" "os" "github.com/alecthomas/kingpin/v2" @@ -65,12 +66,13 @@ func main() { if err := run(*path, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } } func run(configPath string, debug bool) error { + ctx := context.Background() conf, err := email.LoadConfig(configPath) if err != nil { return trace.Wrap(err) @@ -84,11 +86,11 @@ func run(configPath string, debug bool) error { return err } if debug { - logger.Standard().Debugf("DEBUG logging enabled") + slog.DebugContext(ctx, "DEBUG logging enabled") } if conf.Delivery.Recipients != nil { - logger.Standard().Warn("The delivery.recipients config option is deprecated, set role_to_recipients[\"*\"] instead for the same functionality") + slog.WarnContext(ctx, "The delivery.recipients config option is deprecated, set role_to_recipients[\"*\"] instead for the same functionality") } app, err := email.NewApp(*conf) @@ -98,8 +100,9 @@ func run(configPath string, debug bool) error { go lib.ServeSignals(app, common.PluginShutdownTimeout) - logger.Standard().Infof("Starting Teleport Access Email Plugin %s:%s", teleport.Version, teleport.Gitref) - return trace.Wrap( - app.Run(context.Background()), + slog.InfoContext(ctx, "Starting Teleport Access Email Plugin", + "version", teleport.Version, + "git_ref", teleport.Gitref, ) + return trace.Wrap(app.Run(ctx)) } diff --git a/integrations/access/email/mailers.go b/integrations/access/email/mailers.go index 60d5b4592449f..5cbd3d98bee02 100644 --- a/integrations/access/email/mailers.go +++ b/integrations/access/email/mailers.go @@ -114,7 +114,7 @@ func (m *SMTPMailer) CheckHealth(ctx context.Context) error { return trace.Wrap(err) } if err := client.Close(); err != nil { - log.Debug("Failed to close client connection after health check") + log.DebugContext(ctx, "Failed to close client connection after health check") } return nil } @@ -191,7 +191,7 @@ func (m *SMTPMailer) emitStatus(ctx context.Context, statusErr error) { code = http.StatusInternalServerError } if err := m.sink.Emit(ctx, common.StatusFromStatusCode(code)); err != nil { - log.WithError(err).Error("Error while emitting Email plugin status") + log.ErrorContext(ctx, "Error while emitting Email plugin status", "error", err) } } @@ -252,7 +252,7 @@ func (t *statusSinkTransport) RoundTrip(req *http.Request) (*http.Response, erro status := common.StatusFromStatusCode(resp.StatusCode) if err := t.sink.Emit(ctx, status); err != nil { - log.WithError(err).Error("Error while emitting Email plugin status") + log.ErrorContext(ctx, "Error while emitting Email plugin status", "error", err) } } return resp, nil diff --git a/integrations/access/email/testlib/mock_mailgun.go b/integrations/access/email/testlib/mock_mailgun.go index 58cbbc8ebb098..7895a5cdcaefe 100644 --- a/integrations/access/email/testlib/mock_mailgun.go +++ b/integrations/access/email/testlib/mock_mailgun.go @@ -24,7 +24,6 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" ) const ( @@ -58,7 +57,8 @@ func newMockMailgunServer(concurrency int) *mockMailgunServer { s := httptest.NewUnstartedServer(func(mg *mockMailgunServer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if err := r.ParseMultipartForm(multipartFormBufSize); err != nil { - log.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return } id := uuid.New().String() diff --git a/integrations/access/jira/app.go b/integrations/access/jira/app.go index 2aab94e887f0d..c8e6c8273ec02 100644 --- a/integrations/access/jira/app.go +++ b/integrations/access/jira/app.go @@ -21,6 +21,7 @@ package jira import ( "context" "fmt" + "log/slog" "net/url" "regexp" "strings" @@ -40,6 +41,7 @@ import ( "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/watcherjob" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -125,7 +127,6 @@ func (a *App) run(ctx context.Context) error { var err error log := logger.Get(ctx) - log.Infof("Starting Teleport Jira Plugin") if err = a.init(ctx); err != nil { return trace.Wrap(err) @@ -164,9 +165,9 @@ func (a *App) run(ctx context.Context) error { ok := (a.webhookSrv == nil || httpOk) && watcherOk a.mainJob.SetReady(ok) if ok { - log.Info("Plugin is ready") + log.InfoContext(ctx, "Plugin is ready") } else { - log.Error("Plugin is not ready") + log.ErrorContext(ctx, "Plugin is not ready") } if httpJob != nil { @@ -205,11 +206,11 @@ func (a *App) init(ctx context.Context) error { return trace.Wrap(err) } - log.Debug("Starting Jira API health check...") + log.DebugContext(ctx, "Starting Jira API health check") if err = a.jira.HealthCheck(ctx); err != nil { return trace.Wrap(err, "api health check failed") } - log.Debug("Jira API health check finished ok") + log.DebugContext(ctx, "Jira API health check finished ok") if !a.conf.DisableWebhook { webhookSrv, err := NewWebhookServer(a.conf.HTTP, a.onJiraWebhook) @@ -227,13 +228,13 @@ func (a *App) init(ctx context.Context) error { func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) { log := logger.Get(ctx) - log.Debug("Checking Teleport server version") + log.DebugContext(ctx, "Checking Teleport server version") pong, err := a.teleport.Ping(ctx) if err != nil { if trace.IsNotImplemented(err) { return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion) } - log.Error("Unable to get Teleport server version") + log.ErrorContext(ctx, "Unable to get Teleport server version") return pong, trace.Wrap(err) } err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion) @@ -246,17 +247,17 @@ func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error { } op := event.Type reqID := event.Resource.GetName() - ctx, _ = logger.WithField(ctx, "request_id", reqID) + ctx, _ = logger.With(ctx, "request_id", reqID) switch op { case types.OpPut: - ctx, _ = logger.WithField(ctx, "request_op", "put") + ctx, _ = logger.With(ctx, "request_op", "put") req, ok := event.Resource.(types.AccessRequest) if !ok { return trace.Errorf("unexpected resource type %T", event.Resource) } - ctx, log := logger.WithField(ctx, "request_state", req.GetState().String()) - log.Debug("Processing watcher event") + ctx, log := logger.With(ctx, "request_state", req.GetState().String()) + log.DebugContext(ctx, "Processing watcher event") var err error switch { @@ -265,21 +266,29 @@ func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error { case req.GetState().IsResolved(): err = a.onResolvedRequest(ctx, req) default: - log.WithField("event", event).Warn("Unknown request state") + log.WarnContext(ctx, "Unknown request state", + slog.Group("event", + slog.Any("type", logutils.StringerAttr(event.Type)), + slog.Group("resource", + "kind", event.Resource.GetKind(), + "name", event.Resource.GetName(), + ), + ), + ) return nil } if err != nil { - log.WithError(err).Error("Failed to process request") + log.ErrorContext(ctx, "Failed to process request", "error", err) return trace.Wrap(err) } return nil case types.OpDelete: - ctx, log := logger.WithField(ctx, "request_op", "delete") + ctx, log := logger.With(ctx, "request_op", "delete") if err := a.onDeletedRequest(ctx, reqID); err != nil { - log.WithError(err).Errorf("Failed to process deleted request") + log.ErrorContext(ctx, "Failed to process deleted request", "error", err) return trace.Wrap(err) } return nil @@ -299,10 +308,11 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error { return nil } - ctx, log := logger.WithFields(ctx, logger.Fields{ - "jira_issue_id": webhook.Issue.ID, - }) - log.Debugf("Processing incoming webhook event %q with type %q", webhookEvent, issueEventTypeName) + ctx, log := logger.With(ctx, "jira_issue_id", webhook.Issue.ID) + log.DebugContext(ctx, "Processing incoming webhook event", + "event", webhookEvent, + "event_type", issueEventTypeName, + ) if webhook.Issue == nil { return trace.Errorf("got webhook without issue info") @@ -333,20 +343,20 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error { if statusName == "" { return trace.Errorf("getting Jira issue status: %w", err) } - log.Warnf("Using most recent successful getIssue response: %v", err) + log.WarnContext(ctx, "Using most recent successful getIssue response", "error", err) } - ctx, log = logger.WithFields(ctx, logger.Fields{ - "jira_issue_id": issue.ID, - "jira_issue_key": issue.Key, - }) + ctx, log = logger.With(ctx, + "jira_issue_id", issue.ID, + "jira_issue_key", issue.Key, + ) switch { case statusName == "pending": - log.Debug("Issue has pending status, ignoring it") + log.DebugContext(ctx, "Issue has pending status, ignoring it") return nil case statusName == "expired": - log.Debug("Issue has expired status, ignoring it") + log.DebugContext(ctx, "Issue has expired status, ignoring it") return nil case statusName != "approved" && statusName != "denied": return trace.BadParameter("unknown Jira status %s", statusName) @@ -357,11 +367,11 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error { return trace.Wrap(err) } if reqID == "" { - log.Debugf("Missing %q issue property", RequestIDPropertyKey) + log.DebugContext(ctx, "Missing teleportAccessRequestId issue property") return nil } - ctx, log = logger.WithField(ctx, "request_id", reqID) + ctx, log = logger.With(ctx, "request_id", reqID) reqs, err := a.teleport.GetAccessRequests(ctx, types.AccessRequestFilter{ID: reqID}) if err != nil { @@ -382,8 +392,9 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error { return trace.Errorf("plugin data is blank") } if pluginData.IssueID != issue.ID { - log.WithField("plugin_data_issue_id", pluginData.IssueID). - Debug("plugin_data.issue_id does not match issue.id") + log.DebugContext(ctx, "plugin_data.issue_id does not match issue.id", + "plugin_data_issue_id", pluginData.IssueID, + ) return trace.Errorf("issue_id from request's plugin_data does not match") } @@ -406,17 +417,17 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error { author, reason, err := a.loadResolutionInfo(ctx, issue, statusName) if err != nil { - log.WithError(err).Error("Failed to load resolution info from the issue history") + log.ErrorContext(ctx, "Failed to load resolution info from the issue history", "error", err) } resolution.Reason = reason - ctx, _ = logger.WithFields(ctx, logger.Fields{ - "jira_user_email": author.EmailAddress, - "jira_user_name": author.DisplayName, - "request_user": req.GetUser(), - "request_roles": req.GetRoles(), - "reason": reason, - }) + ctx, _ = logger.With(ctx, + "jira_user_email", author.EmailAddress, + "jira_user_name", author.DisplayName, + "request_user", req.GetUser(), + "request_roles", req.GetRoles(), + "reason", reason, + ) if err := a.resolveRequest(ctx, reqID, author.EmailAddress, resolution); err != nil { return trace.Wrap(err) } @@ -498,11 +509,11 @@ func (a *App) createIssue(ctx context.Context, reqID string, reqData RequestData return trace.Wrap(err) } - ctx, log := logger.WithFields(ctx, logger.Fields{ - "jira_issue_id": data.IssueID, - "jira_issue_key": data.IssueKey, - }) - log.Info("Jira Issue created") + ctx, log := logger.With(ctx, + "jira_issue_id", data.IssueID, + "jira_issue_key", data.IssueKey, + ) + log.InfoContext(ctx, "Jira Issue created") // Save jira issue info in plugin data. _, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) { @@ -551,11 +562,11 @@ func (a *App) addReviewComments(ctx context.Context, reqID string, reqReviews [] } if !ok { if issueID == "" { - logger.Get(ctx).Debug("Failed to add the comment: plugin data is blank") + logger.Get(ctx).DebugContext(ctx, "Failed to add the comment: plugin data is blank") } return nil } - ctx, _ = logger.WithField(ctx, "jira_issue_id", issueID) + ctx, _ = logger.With(ctx, "jira_issue_id", issueID) slice := reqReviews[oldCount:] if len(slice) == 0 { @@ -621,7 +632,7 @@ func (a *App) resolveRequest(ctx context.Context, reqID string, userEmail string return trace.Wrap(err) } - logger.Get(ctx).Infof("Jira user %s the request", resolution.Tag) + logger.Get(ctx).InfoContext(ctx, "Jira user processed the request", "resolution", resolution.Tag) return nil } @@ -658,18 +669,18 @@ func (a *App) resolveIssue(ctx context.Context, reqID string, resolution Resolut } if !ok { if issueID == "" { - logger.Get(ctx).Debug("Failed to resolve the issue: plugin data is blank") + logger.Get(ctx).DebugContext(ctx, "Failed to resolve the issue: plugin data is blank") } // Either plugin data is missing or issue is already resolved by us, just quit. return nil } - ctx, log := logger.WithField(ctx, "jira_issue_id", issueID) + ctx, log := logger.With(ctx, "jira_issue_id", issueID) if err := a.jira.ResolveIssue(ctx, issueID, resolution); err != nil { return trace.Wrap(err) } - log.Info("Successfully resolved the issue") + log.InfoContext(ctx, "Successfully resolved the issue") return nil } diff --git a/integrations/access/jira/client.go b/integrations/access/jira/client.go index 2877966af663b..a23381e4d2666 100644 --- a/integrations/access/jira/client.go +++ b/integrations/access/jira/client.go @@ -125,7 +125,7 @@ func NewJiraClient(conf JiraConfig, clusterName, teleportProxyAddr string, statu defer cancel() if err := statusSink.Emit(ctx, status); err != nil { - log.WithError(err).Errorf("Error while emitting Jira plugin status: %v", err) + log.ErrorContext(ctx, "Error while emitting Jira plugin status", "error", err) } } @@ -199,7 +199,7 @@ func (j *Jira) HealthCheck(ctx context.Context) error { } } - log.Debug("Checking out Jira project...") + log.DebugContext(ctx, "Checking out Jira project") var project Project _, err = j.client.NewRequest(). SetContext(ctx). @@ -209,9 +209,12 @@ func (j *Jira) HealthCheck(ctx context.Context) error { if err != nil { return trace.Wrap(err) } - log.Debugf("Found project %q named %q", project.Key, project.Name) + log.DebugContext(ctx, "Found Jira project", + "project", project.Key, + "project_name", project.Name, + ) - log.Debug("Checking out Jira project permissions...") + log.DebugContext(ctx, "Checking out Jira project permissions") queryOptions, err := query.Values(GetMyPermissionsQueryOptions{ ProjectKey: j.project, Permissions: jiraRequiredPermissions, @@ -433,7 +436,7 @@ func (j *Jira) ResolveIssue(ctx context.Context, issueID string, resolution Reso if err2 := trace.Wrap(j.TransitionIssue(ctx, issue.ID, transition.ID)); err2 != nil { return trace.NewAggregate(err1, err2) } - logger.Get(ctx).Debugf("Successfully moved the issue to the status %q", toStatus) + logger.Get(ctx).DebugContext(ctx, "Successfully moved the issue to the target status", "target_status", toStatus) return trace.Wrap(err1) } @@ -457,7 +460,7 @@ func (j *Jira) AddResolutionComment(ctx context.Context, id string, resolution R SetBody(CommentInput{Body: builder.String()}). Post("rest/api/2/issue/{issueID}/comment") if err == nil { - logger.Get(ctx).Debug("Successfully added a resolution comment to the issue") + logger.Get(ctx).DebugContext(ctx, "Successfully added a resolution comment to the issue") } return trace.Wrap(err) } diff --git a/integrations/access/jira/cmd/teleport-jira/main.go b/integrations/access/jira/cmd/teleport-jira/main.go index b2c2bb0672d06..851de27473296 100644 --- a/integrations/access/jira/cmd/teleport-jira/main.go +++ b/integrations/access/jira/cmd/teleport-jira/main.go @@ -20,6 +20,7 @@ import ( "context" _ "embed" "fmt" + "log/slog" "os" "github.com/alecthomas/kingpin/v2" @@ -72,12 +73,13 @@ func main() { if err := run(*path, *insecure, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } } func run(configPath string, insecure bool, debug bool) error { + ctx := context.Background() conf, err := jira.LoadConfig(configPath) if err != nil { return trace.Wrap(err) @@ -91,7 +93,7 @@ func run(configPath string, insecure bool, debug bool) error { return err } if debug { - logger.Standard().Debugf("DEBUG logging enabled") + slog.DebugContext(ctx, "DEBUG logging enabled") } conf.HTTP.Insecure = insecure @@ -102,8 +104,9 @@ func run(configPath string, insecure bool, debug bool) error { go lib.ServeSignals(app, common.PluginShutdownTimeout) - logger.Standard().Infof("Starting Teleport Access Jira Plugin %s:%s", teleport.Version, teleport.Gitref) - return trace.Wrap( - app.Run(context.Background()), + slog.InfoContext(ctx, "Starting Teleport Access Jira Plugin", + "version", teleport.Version, + "git_ref", teleport.Gitref, ) + return trace.Wrap(app.Run(ctx)) } diff --git a/integrations/access/jira/testlib/fake_jira.go b/integrations/access/jira/testlib/fake_jira.go index 1da8c432ec3a9..9696500620aba 100644 --- a/integrations/access/jira/testlib/fake_jira.go +++ b/integrations/access/jira/testlib/fake_jira.go @@ -30,7 +30,6 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/jira" ) @@ -304,6 +303,6 @@ func (s *FakeJira) CheckIssueTransition(ctx context.Context) (jira.Issue, error) func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/jira/testlib/suite.go b/integrations/access/jira/testlib/suite.go index 38341d589fa5d..c2a3d421f442c 100644 --- a/integrations/access/jira/testlib/suite.go +++ b/integrations/access/jira/testlib/suite.go @@ -721,7 +721,7 @@ func (s *JiraSuiteOSS) TestRace() { defer cancel() var lastErr error for { - logger.Get(ctx).Infof("Trying to approve issue %q", issue.Key) + logger.Get(ctx).InfoContext(ctx, "Trying to approve issue", "issue_key", issue.Key) resp, err := s.postWebhook(ctx, s.webhookURL.String(), issue.ID, "Approved") if err != nil { if lib.IsDeadline(err) { diff --git a/integrations/access/jira/webhook_server.go b/integrations/access/jira/webhook_server.go index b83e449b992c8..e9e409959b40a 100644 --- a/integrations/access/jira/webhook_server.go +++ b/integrations/access/jira/webhook_server.go @@ -105,29 +105,31 @@ func (s *WebhookServer) processWebhook(rw http.ResponseWriter, r *http.Request, defer cancel() httpRequestID := fmt.Sprintf("%v-%v", time.Now().Unix(), atomic.AddUint64(&s.counter, 1)) - ctx, log := logger.WithField(ctx, "jira_http_id", httpRequestID) + ctx, log := logger.With(ctx, "jira_http_id", httpRequestID) var webhook Webhook body, err := io.ReadAll(io.LimitReader(r.Body, jiraWebhookPayloadLimit+1)) if err != nil { - log.WithError(err).Error("Failed to read webhook payload") + log.ErrorContext(ctx, "Failed to read webhook payload", "error", err) http.Error(rw, "", http.StatusInternalServerError) return } if len(body) > jiraWebhookPayloadLimit { - log.Error("Received a webhook larger than %d bytes", jiraWebhookPayloadLimit) + log.ErrorContext(ctx, "Received a webhook with a payload that exceeded the limit", + "payload_size", len(body), + "payload_size_limit", jiraWebhookPayloadLimit, + ) http.Error(rw, "", http.StatusRequestEntityTooLarge) } if err = json.Unmarshal(body, &webhook); err != nil { - log.WithError(err).Error("Failed to parse webhook payload") + log.ErrorContext(ctx, "Failed to parse webhook payload", "error", err) http.Error(rw, "", http.StatusBadRequest) return } if err = s.onWebhook(ctx, webhook); err != nil { - log.WithError(err).Error("Failed to process webhook") - log.Debugf("%v", trace.DebugReport(err)) + log.ErrorContext(ctx, "Failed to process webhook", "error", err) var code int switch { case lib.IsCanceled(err) || lib.IsDeadline(err): diff --git a/integrations/access/mattermost/bot.go b/integrations/access/mattermost/bot.go index c7de9d0aaae44..edf0a7e73264d 100644 --- a/integrations/access/mattermost/bot.go +++ b/integrations/access/mattermost/bot.go @@ -150,7 +150,7 @@ func NewBot(conf Config, clusterName, webProxyAddr string) (Bot, error) { ctx, cancel := context.WithTimeout(context.Background(), mmStatusEmitTimeout) defer cancel() if err := sink.Emit(ctx, status); err != nil { - log.Errorf("Error while emitting plugin status: %v", err) + log.ErrorContext(ctx, "Error while emitting plugin status", "error", err) } }() @@ -463,14 +463,14 @@ func (b Bot) buildPostText(reqID string, reqData pd.AccessRequestData) (string, } func (b Bot) tryLookupDirectChannel(ctx context.Context, userEmail string) string { - log := logger.Get(ctx).WithField("mm_user_email", userEmail) + log := logger.Get(ctx).With("mm_user_email", userEmail) channel, err := b.LookupDirectChannel(ctx, userEmail) if err != nil { var errResult *ErrorResult if errors.As(trace.Unwrap(err), &errResult) { - log.Warningf("Failed to lookup direct channel info: %q", errResult.Message) + log.WarnContext(ctx, "Failed to lookup direct channel info", "error", errResult.Message) } else { - log.WithError(err).Error("Failed to lookup direct channel info") + log.ErrorContext(ctx, "Failed to lookup direct channel info", "error", err) } return "" } @@ -478,17 +478,17 @@ func (b Bot) tryLookupDirectChannel(ctx context.Context, userEmail string) strin } func (b Bot) tryLookupChannel(ctx context.Context, team, name string) string { - log := logger.Get(ctx).WithFields(logger.Fields{ - "mm_team": team, - "mm_channel": name, - }) + log := logger.Get(ctx).With( + "mm_team", team, + "mm_channel", name, + ) channel, err := b.LookupChannel(ctx, team, name) if err != nil { var errResult *ErrorResult if errors.As(trace.Unwrap(err), &errResult) { - log.Warningf("Failed to lookup channel info: %q", errResult.Message) + log.WarnContext(ctx, "Failed to lookup channel info", "error", errResult.Message) } else { - log.WithError(err).Error("Failed to lookup channel info") + log.ErrorContext(ctx, "Failed to lookup channel info", "error", err) } return "" } diff --git a/integrations/access/mattermost/cmd/teleport-mattermost/main.go b/integrations/access/mattermost/cmd/teleport-mattermost/main.go index 7c4777b26655b..0c67abb62ef86 100644 --- a/integrations/access/mattermost/cmd/teleport-mattermost/main.go +++ b/integrations/access/mattermost/cmd/teleport-mattermost/main.go @@ -20,6 +20,7 @@ import ( "context" _ "embed" "fmt" + "log/slog" "os" "github.com/alecthomas/kingpin/v2" @@ -65,12 +66,13 @@ func main() { if err := run(*path, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } } func run(configPath string, debug bool) error { + ctx := context.Background() conf, err := mattermost.LoadConfig(configPath) if err != nil { return trace.Wrap(err) @@ -84,14 +86,15 @@ func run(configPath string, debug bool) error { return err } if debug { - logger.Standard().Debugf("DEBUG logging enabled") + slog.DebugContext(ctx, "DEBUG logging enabled") } app := mattermost.NewMattermostApp(conf) go lib.ServeSignals(app, common.PluginShutdownTimeout) - logger.Standard().Infof("Starting Teleport Access Mattermost Plugin %s:%s", teleport.Version, teleport.Gitref) - return trace.Wrap( - app.Run(context.Background()), + slog.InfoContext(ctx, "Starting Teleport Access Mattermost Plugin", + "version", teleport.Version, + "git_ref", teleport.Gitref, ) + return trace.Wrap(app.Run(ctx)) } diff --git a/integrations/access/mattermost/testlib/fake_mattermost.go b/integrations/access/mattermost/testlib/fake_mattermost.go index 10cc048e743bd..b2c28287c6153 100644 --- a/integrations/access/mattermost/testlib/fake_mattermost.go +++ b/integrations/access/mattermost/testlib/fake_mattermost.go @@ -31,7 +31,6 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/mattermost" ) @@ -387,6 +386,6 @@ func (s *FakeMattermost) CheckPostUpdate(ctx context.Context) (mattermost.Post, func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/msteams/app.go b/integrations/access/msteams/app.go index 306be091ca8b0..b18c96ba3f4a3 100644 --- a/integrations/access/msteams/app.go +++ b/integrations/access/msteams/app.go @@ -62,14 +62,9 @@ type App struct { // NewApp initializes a new teleport-msteams app and returns it. func NewApp(conf Config) (*App, error) { - log, err := conf.Log.NewSLogLogger() - if err != nil { - return nil, trace.Wrap(err) - } - app := &App{ conf: conf, - log: log.With("plugin", pluginName), + log: slog.With("plugin", pluginName), } app.mainJob = lib.NewServiceJob(app.run) diff --git a/integrations/access/msteams/bot.go b/integrations/access/msteams/bot.go index c0598c1f4d24f..4292f856dba90 100644 --- a/integrations/access/msteams/bot.go +++ b/integrations/access/msteams/bot.go @@ -30,7 +30,6 @@ import ( "github.com/gravitational/teleport/integrations/access/common" "github.com/gravitational/teleport/integrations/access/msteams/msapi" "github.com/gravitational/teleport/integrations/lib" - "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/plugindata" ) @@ -469,7 +468,7 @@ func (b *Bot) CheckHealth(ctx context.Context) error { Code: status, ErrorMessage: message, }); err != nil { - logger.Get(ctx).Errorf("Error while emitting ms teams plugin status: %v", err) + b.log.ErrorContext(ctx, "Error while emitting ms teams plugin status", "error", err) } } return trace.Wrap(err) diff --git a/integrations/access/msteams/cmd/teleport-msteams/main.go b/integrations/access/msteams/cmd/teleport-msteams/main.go index 970df1ac98db4..75e66a46b7cf7 100644 --- a/integrations/access/msteams/cmd/teleport-msteams/main.go +++ b/integrations/access/msteams/cmd/teleport-msteams/main.go @@ -16,6 +16,7 @@ package main import ( "context" + "log/slog" "os" "time" @@ -99,7 +100,7 @@ func main() { if err := run(*startConfigPath, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } diff --git a/integrations/access/msteams/testlib/fake_msteams.go b/integrations/access/msteams/testlib/fake_msteams.go index ceb1a3edc2d41..f3e4d4c5550c2 100644 --- a/integrations/access/msteams/testlib/fake_msteams.go +++ b/integrations/access/msteams/testlib/fake_msteams.go @@ -30,7 +30,6 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/msteams/msapi" ) @@ -326,6 +325,6 @@ func (s *FakeTeams) CheckMessageUpdate(ctx context.Context) (Msg, error) { func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/msteams/uninstall.go b/integrations/access/msteams/uninstall.go index e60a9ce0c8ddd..22aa9e6961ab1 100644 --- a/integrations/access/msteams/uninstall.go +++ b/integrations/access/msteams/uninstall.go @@ -18,7 +18,8 @@ import ( "context" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" + + "github.com/gravitational/teleport/integrations/lib/logger" ) func Uninstall(ctx context.Context, configPath string) error { @@ -26,11 +27,13 @@ func Uninstall(ctx context.Context, configPath string) error { if err != nil { return trace.Wrap(err) } - err = checkApp(ctx, b) - if err != nil { + + if err := checkApp(ctx, b); err != nil { return trace.Wrap(err) } + log := logger.Get(ctx) + var errs []error for _, recipient := range c.Recipients.GetAllRawRecipients() { _, isChannel := b.checkChannelURL(recipient) @@ -38,11 +41,11 @@ func Uninstall(ctx context.Context, configPath string) error { errs = append(errs, b.UninstallAppForUser(ctx, recipient)) } } - err = trace.NewAggregate(errs...) - if err != nil { - log.Errorln("The following error(s) happened when uninstalling the Teams App:") + + if trace.NewAggregate(errs...) != nil { + log.ErrorContext(ctx, "Encountered error(s) when uninstalling the Teams App", "error", err) return err } - log.Info("Successfully uninstalled app for all recipients") + log.InfoContext(ctx, "Successfully uninstalled app for all recipients") return nil } diff --git a/integrations/access/msteams/validate.go b/integrations/access/msteams/validate.go index 61d9d25f635e8..7969d7edebe0d 100644 --- a/integrations/access/msteams/validate.go +++ b/integrations/access/msteams/validate.go @@ -17,6 +17,7 @@ package msteams import ( "context" "fmt" + "log/slog" "time" cards "github.com/DanielTitkov/go-adaptive-cards" @@ -142,11 +143,7 @@ func loadConfig(configPath string) (*Bot, *Config, error) { fmt.Printf(" - Checking application %v status...\n", c.MSAPI.TeamsAppID) - log, err := c.Log.NewSLogLogger() - if err != nil { - return nil, nil, trace.Wrap(err) - } - b, err := NewBot(c, "local", "", log) + b, err := NewBot(c, "local", "", slog.Default()) if err != nil { return nil, nil, trace.Wrap(err) } diff --git a/integrations/access/opsgenie/app.go b/integrations/access/opsgenie/app.go index 132389ad5b5a3..60950f31fa4b1 100644 --- a/integrations/access/opsgenie/app.go +++ b/integrations/access/opsgenie/app.go @@ -22,6 +22,7 @@ import ( "context" "errors" "fmt" + "log/slog" "strings" "time" @@ -39,6 +40,7 @@ import ( "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/watcherjob" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -115,7 +117,7 @@ func (a *App) run(ctx context.Context) error { var err error log := logger.Get(ctx) - log.Infof("Starting Teleport Access Opsgenie Plugin") + log.InfoContext(ctx, "Starting Teleport Access Opsgenie Plugin") if err = a.init(ctx); err != nil { return trace.Wrap(err) @@ -147,9 +149,9 @@ func (a *App) run(ctx context.Context) error { a.mainJob.SetReady(ok) if ok { - log.Info("Plugin is ready") + log.InfoContext(ctx, "Plugin is ready") } else { - log.Error("Plugin is not ready") + log.ErrorContext(ctx, "Plugin is not ready") } <-watcherJob.Done() @@ -177,24 +179,24 @@ func (a *App) init(ctx context.Context) error { } log := logger.Get(ctx) - log.Debug("Starting API health check...") + log.DebugContext(ctx, "Starting API health check") if err = a.opsgenie.CheckHealth(ctx); err != nil { return trace.Wrap(err, "API health check failed") } - log.Debug("API health check finished ok") + log.DebugContext(ctx, "API health check finished ok") return nil } func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) { log := logger.Get(ctx) - log.Debug("Checking Teleport server version") + log.DebugContext(ctx, "Checking Teleport server version") pong, err := a.teleport.Ping(ctx) if err != nil { if trace.IsNotImplemented(err) { return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion) } - log.Error("Unable to get Teleport server version") + log.ErrorContext(ctx, "Unable to get Teleport server version") return pong, trace.Wrap(err) } err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion) @@ -219,16 +221,16 @@ func (a *App) handleAcessRequest(ctx context.Context, event types.Event) error { } op := event.Type reqID := event.Resource.GetName() - ctx, _ = logger.WithField(ctx, "request_id", reqID) + ctx, _ = logger.With(ctx, "request_id", reqID) switch op { case types.OpPut: - ctx, _ = logger.WithField(ctx, "request_op", "put") + ctx, _ = logger.With(ctx, "request_op", "put") req, ok := event.Resource.(types.AccessRequest) if !ok { return trace.Errorf("unexpected resource type %T", event.Resource) } - ctx, log := logger.WithField(ctx, "request_state", req.GetState().String()) + ctx, log := logger.With(ctx, "request_state", req.GetState().String()) var err error switch { @@ -237,21 +239,29 @@ func (a *App) handleAcessRequest(ctx context.Context, event types.Event) error { case req.GetState().IsResolved(): err = a.onResolvedRequest(ctx, req) default: - log.WithField("event", event).Warn("Unknown request state") + log.WarnContext(ctx, "Unknown request state", + slog.Group("event", + slog.Any("type", logutils.StringerAttr(event.Type)), + slog.Group("resource", + "kind", event.Resource.GetKind(), + "name", event.Resource.GetName(), + ), + ), + ) return nil } if err != nil { - log.WithError(err).Error("Failed to process request") + log.ErrorContext(ctx, "Failed to process request", "error", err) return trace.Wrap(err) } return nil case types.OpDelete: - ctx, log := logger.WithField(ctx, "request_op", "delete") + ctx, log := logger.With(ctx, "request_op", "delete") if err := a.onDeletedRequest(ctx, reqID); err != nil { - log.WithError(err).Error("Failed to process deleted request") + log.ErrorContext(ctx, "Failed to process deleted request", "error", err) return trace.Wrap(err) } return nil @@ -310,13 +320,13 @@ func (a *App) getNotifySchedulesAndTeams(ctx context.Context, req types.AccessRe scheduleAnnotationKey := types.TeleportNamespace + types.ReqAnnotationNotifySchedulesLabel schedules, err = common.GetNamesFromAnnotations(req, scheduleAnnotationKey) if err != nil { - log.Debugf("No schedules to notify in %s", scheduleAnnotationKey) + log.DebugContext(ctx, "No schedules to notify", "schedule", scheduleAnnotationKey) } teamAnnotationKey := types.TeleportNamespace + types.ReqAnnotationTeamsLabel teams, err = common.GetNamesFromAnnotations(req, teamAnnotationKey) if err != nil { - log.Debugf("No teams to notify in %s", teamAnnotationKey) + log.DebugContext(ctx, "No teams to notify", "teams", teamAnnotationKey) } if len(schedules) == 0 && len(teams) == 0 { @@ -336,7 +346,7 @@ func (a *App) tryNotifyService(ctx context.Context, req types.AccessRequest) (bo recipientSchedules, recipientTeams, err := a.getMessageRecipients(ctx, req) if err != nil { - log.Debugf("Skipping the notification: %s", err) + log.DebugContext(ctx, "Skipping notification", "error", err) return false, trace.Wrap(errMissingAnnotation) } @@ -434,8 +444,8 @@ func (a *App) createAlert(ctx context.Context, reqID string, reqData RequestData if err != nil { return trace.Wrap(err) } - ctx, log := logger.WithField(ctx, "opsgenie_alert_id", data.AlertID) - log.Info("Successfully created Opsgenie alert") + ctx, log := logger.With(ctx, "opsgenie_alert_id", data.AlertID) + log.InfoContext(ctx, "Successfully created Opsgenie alert") // Save opsgenie alert info in plugin data. _, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) { @@ -479,10 +489,10 @@ func (a *App) postReviewNotes(ctx context.Context, reqID string, reqReviews []ty return trace.Wrap(err) } if !ok { - logger.Get(ctx).Debug("Failed to post the note: plugin data is missing") + logger.Get(ctx).DebugContext(ctx, "Failed to post the note: plugin data is missing") return nil } - ctx, _ = logger.WithField(ctx, "opsgenie_alert_id", data.AlertID) + ctx, _ = logger.With(ctx, "opsgenie_alert_id", data.AlertID) slice := reqReviews[oldCount:] if len(slice) == 0 { @@ -504,7 +514,7 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er serviceNames, err := a.getOnCallServiceNames(req) if err != nil { - logger.Get(ctx).Debugf("Skipping the approval: %s", err) + logger.Get(ctx).DebugContext(ctx, "Skipping approval", "error", err) return nil } @@ -537,14 +547,14 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er }, }); err != nil { if strings.HasSuffix(err.Error(), "has already reviewed this request") { - log.Debug("Already reviewed the request") + log.DebugContext(ctx, "Already reviewed the request") return nil } return trace.Wrap(err, "submitting access request") } } - log.Info("Successfully submitted a request approval") + log.InfoContext(ctx, "Successfully submitted a request approval") return nil } @@ -576,15 +586,15 @@ func (a *App) resolveAlert(ctx context.Context, reqID string, resolution Resolut return trace.Wrap(err) } if !ok { - logger.Get(ctx).Debug("Failed to resolve the alert: plugin data is missing") + logger.Get(ctx).DebugContext(ctx, "Failed to resolve the alert: plugin data is missing") return nil } - ctx, log := logger.WithField(ctx, "opsgenie_alert_id", alertID) + ctx, log := logger.With(ctx, "opsgenie_alert_id", alertID) if err := a.opsgenie.ResolveAlert(ctx, alertID, resolution); err != nil { return trace.Wrap(err) } - log.Info("Successfully resolved the alert") + log.InfoContext(ctx, "Successfully resolved the alert") return nil } diff --git a/integrations/access/opsgenie/client.go b/integrations/access/opsgenie/client.go index 2619c6ed6f7a9..2c8cdaec09a33 100644 --- a/integrations/access/opsgenie/client.go +++ b/integrations/access/opsgenie/client.go @@ -185,10 +185,10 @@ func (og Client) tryGetAlertRequestResult(ctx context.Context, reqID string) (Ge for { alertRequestResult, err := og.getAlertRequestResult(ctx, reqID) if err == nil { - logger.Get(ctx).Debugf("Got alert request result: %+v", alertRequestResult) + logger.Get(ctx).DebugContext(ctx, "Got alert request result", "alert_id", alertRequestResult.Data.AlertID) return alertRequestResult, nil } - logger.Get(ctx).Debug("Failed to get alert request result:", err) + logger.Get(ctx).DebugContext(ctx, "Failed to get alert request result", "error", err) if err := backoff.Do(ctx); err != nil { return GetAlertRequestResult{}, trace.Wrap(err) } @@ -344,8 +344,10 @@ func (og Client) CheckHealth(ctx context.Context) error { code = types.PluginStatusCode_OTHER_ERROR } if err := og.StatusSink.Emit(ctx, &types.PluginStatusV1{Code: code}); err != nil { - logger.Get(resp.Request.Context()).WithError(err). - WithField("code", resp.StatusCode()).Errorf("Error while emitting servicenow plugin status: %v", err) + logger.Get(resp.Request.Context()).ErrorContext(ctx, "Error while emitting servicenow plugin status", + "error", err, + "code", resp.StatusCode(), + ) } } diff --git a/integrations/access/opsgenie/testlib/fake_opsgenie.go b/integrations/access/opsgenie/testlib/fake_opsgenie.go index 9b5e6252119d1..1c124e19a75fc 100644 --- a/integrations/access/opsgenie/testlib/fake_opsgenie.go +++ b/integrations/access/opsgenie/testlib/fake_opsgenie.go @@ -32,7 +32,6 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/integrations/access/opsgenie" @@ -314,7 +313,7 @@ func (s *FakeOpsgenie) GetSchedule(scheduleName string) ([]opsgenie.Responder, b func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/pagerduty/app.go b/integrations/access/pagerduty/app.go index 5eadcc5147cd0..2351c5d2d5f02 100644 --- a/integrations/access/pagerduty/app.go +++ b/integrations/access/pagerduty/app.go @@ -22,6 +22,7 @@ import ( "context" "errors" "fmt" + "log/slog" "strings" "time" @@ -38,6 +39,7 @@ import ( "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/watcherjob" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -106,7 +108,6 @@ func (a *App) run(ctx context.Context) error { var err error log := logger.Get(ctx) - log.Infof("Starting Teleport Access PagerDuty Plugin") if err = a.init(ctx); err != nil { return trace.Wrap(err) @@ -146,9 +147,9 @@ func (a *App) run(ctx context.Context) error { a.mainJob.SetReady(ok) if ok { - log.Info("Plugin is ready") + log.InfoContext(ctx, "Plugin is ready") } else { - log.Error("Plugin is not ready") + log.ErrorContext(ctx, "Plugin is not ready") } <-watcherJob.Done() @@ -202,25 +203,25 @@ func (a *App) init(ctx context.Context) error { return trace.Wrap(err) } - log.Debug("Starting PagerDuty API health check...") + log.DebugContext(ctx, "Starting PagerDuty API health check") if err = a.pagerduty.HealthCheck(ctx); err != nil { return trace.Wrap(err, "api health check failed. check your credentials and service_id settings") } - log.Debug("PagerDuty API health check finished ok") + log.DebugContext(ctx, "PagerDuty API health check finished ok") return nil } func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) { log := logger.Get(ctx) - log.Debug("Checking Teleport server version") + log.DebugContext(ctx, "Checking Teleport server version") pong, err := a.teleport.Ping(ctx) if err != nil { if trace.IsNotImplemented(err) { return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion) } - log.Error("Unable to get Teleport server version") + log.ErrorContext(ctx, "Unable to get Teleport server version") return pong, trace.Wrap(err) } err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion) @@ -245,16 +246,16 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error } op := event.Type reqID := event.Resource.GetName() - ctx, _ = logger.WithField(ctx, "request_id", reqID) + ctx, _ = logger.With(ctx, "request_id", reqID) switch op { case types.OpPut: - ctx, _ = logger.WithField(ctx, "request_op", "put") + ctx, _ = logger.With(ctx, "request_op", "put") req, ok := event.Resource.(types.AccessRequest) if !ok { return trace.Errorf("unexpected resource type %T", event.Resource) } - ctx, log := logger.WithField(ctx, "request_state", req.GetState().String()) + ctx, log := logger.With(ctx, "request_state", req.GetState().String()) var err error switch { @@ -263,21 +264,29 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error case req.GetState().IsResolved(): err = a.onResolvedRequest(ctx, req) default: - log.WithField("event", event).Warn("Unknown request state") + log.WarnContext(ctx, "Unknown request state", + slog.Group("event", + slog.Any("type", logutils.StringerAttr(event.Type)), + slog.Group("resource", + "kind", event.Resource.GetKind(), + "name", event.Resource.GetName(), + ), + ), + ) return nil } if err != nil { - log.WithError(err).Error("Failed to process request") + log.ErrorContext(ctx, "Failed to process request", "error", err) return trace.Wrap(err) } return nil case types.OpDelete: - ctx, log := logger.WithField(ctx, "request_op", "delete") + ctx, log := logger.With(ctx, "request_op", "delete") if err := a.onDeletedRequest(ctx, reqID); err != nil { - log.WithError(err).Error("Failed to process deleted request") + log.ErrorContext(ctx, "Failed to process deleted request", "error", err) return trace.Wrap(err) } return nil @@ -288,7 +297,7 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) error { if len(req.GetSystemAnnotations()) == 0 { - logger.Get(ctx).Debug("Cannot proceed further. Request is missing any annotations") + logger.Get(ctx).DebugContext(ctx, "Cannot proceed further - request is missing any annotations") return nil } @@ -370,11 +379,11 @@ func (a *App) tryNotifyService(ctx context.Context, req types.AccessRequest) (bo serviceName, err := a.getNotifyServiceName(ctx, req) if err != nil { - log.Debugf("Skipping the notification: %s", err) + log.DebugContext(ctx, "Skipping the notification", "error", err) return false, trace.Wrap(errSkip) } - ctx, _ = logger.WithField(ctx, "pd_service_name", serviceName) + ctx, _ = logger.With(ctx, "pd_service_name", serviceName) service, err := a.pagerduty.FindServiceByName(ctx, serviceName) if err != nil { return false, trace.Wrap(err, "finding pagerduty service %s", serviceName) @@ -420,8 +429,8 @@ func (a *App) createIncident(ctx context.Context, serviceID, reqID string, reqDa if err != nil { return trace.Wrap(err) } - ctx, log := logger.WithField(ctx, "pd_incident_id", data.IncidentID) - log.Info("Successfully created PagerDuty incident") + ctx, log := logger.With(ctx, "pd_incident_id", data.IncidentID) + log.InfoContext(ctx, "Successfully created PagerDuty incident") // Save pagerduty incident info in plugin data. _, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) { @@ -465,10 +474,10 @@ func (a *App) postReviewNotes(ctx context.Context, reqID string, reqReviews []ty return trace.Wrap(err) } if !ok { - logger.Get(ctx).Debug("Failed to post the note: plugin data is missing") + logger.Get(ctx).DebugContext(ctx, "Failed to post the note: plugin data is missing") return nil } - ctx, _ = logger.WithField(ctx, "pd_incident_id", data.IncidentID) + ctx, _ = logger.With(ctx, "pd_incident_id", data.IncidentID) slice := reqReviews[oldCount:] if len(slice) == 0 { @@ -490,36 +499,40 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er serviceNames, err := a.getOnCallServiceNames(req) if err != nil { - logger.Get(ctx).Debugf("Skipping the approval: %s", err) + logger.Get(ctx).DebugContext(ctx, "Skipping approval", "error", err) return nil } userName := req.GetUser() if !lib.IsEmail(userName) { - logger.Get(ctx).Warningf("Skipping the approval: %q does not look like a valid email", userName) + logger.Get(ctx).WarnContext(ctx, "Skipping approval, found invalid email", "pd_user_email", userName) return nil } user, err := a.pagerduty.FindUserByEmail(ctx, userName) if err != nil { if trace.IsNotFound(err) { - log.WithError(err).WithField("pd_user_email", userName).Debug("Skipping the approval: email is not found") + log.DebugContext(ctx, "Skipping approval, email is not found", + "error", err, + "pd_user_email", userName) return nil } return trace.Wrap(err) } - ctx, log = logger.WithFields(ctx, logger.Fields{ - "pd_user_email": user.Email, - "pd_user_name": user.Name, - }) + ctx, log = logger.With(ctx, + "pd_user_email", user.Email, + "pd_user_name", user.Name, + ) services, err := a.pagerduty.FindServicesByNames(ctx, serviceNames) if err != nil { return trace.Wrap(err) } if len(services) == 0 { - log.WithField("pd_service_names", serviceNames).Warning("Failed to find any service") + log.WarnContext(ctx, "Failed to find any service", + "pd_service_names", serviceNames, + ) return nil } @@ -536,7 +549,7 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er return trace.Wrap(err) } if len(escalationPolicyIDs) == 0 { - log.Debug("Skipping the approval: user is not on call") + log.DebugContext(ctx, "Skipping the approval: user is not on call") return nil } @@ -561,13 +574,13 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er }, }); err != nil { if strings.HasSuffix(err.Error(), "has already reviewed this request") { - log.Debug("Already reviewed the request") + log.DebugContext(ctx, "Already reviewed the request") return nil } return trace.Wrap(err, "submitting access request") } - log.Info("Successfully submitted a request approval") + log.InfoContext(ctx, "Successfully submitted a request approval") return nil } @@ -599,15 +612,15 @@ func (a *App) resolveIncident(ctx context.Context, reqID string, resolution Reso return trace.Wrap(err) } if !ok { - logger.Get(ctx).Debug("Failed to resolve the incident: plugin data is missing") + logger.Get(ctx).DebugContext(ctx, "Failed to resolve the incident: plugin data is missing") return nil } - ctx, log := logger.WithField(ctx, "pd_incident_id", incidentID) + ctx, log := logger.With(ctx, "pd_incident_id", incidentID) if err := a.pagerduty.ResolveIncident(ctx, incidentID, resolution); err != nil { return trace.Wrap(err) } - log.Info("Successfully resolved the incident") + log.InfoContext(ctx, "Successfully resolved the incident") return nil } diff --git a/integrations/access/pagerduty/client.go b/integrations/access/pagerduty/client.go index 51adfb38f5aed..fd42876a154ca 100644 --- a/integrations/access/pagerduty/client.go +++ b/integrations/access/pagerduty/client.go @@ -122,7 +122,7 @@ func onAfterPagerDutyResponse(sink common.StatusSink) resty.ResponseMiddleware { defer cancel() if err := sink.Emit(ctx, status); err != nil { - log.WithError(err).Errorf("Error while emitting PagerDuty plugin status: %v", err) + log.ErrorContext(ctx, "Error while emitting PagerDuty plugin status", "error", err) } if resp.IsError() { @@ -288,7 +288,7 @@ func (p *Pagerduty) FindUserByEmail(ctx context.Context, userEmail string) (User } if len(result.Users) > 0 && result.More { - logger.Get(ctx).Warningf("PagerDuty returned too many results when querying by email %q", userEmail) + logger.Get(ctx).WarnContext(ctx, "PagerDuty returned too many results when querying user email", "email", userEmail) } return User{}, trace.NotFound("failed to find pagerduty user by email %s", userEmail) @@ -387,10 +387,10 @@ func (p *Pagerduty) FilterOnCallPolicies(ctx context.Context, userID string, esc if len(filteredIDSet) == 0 { if anyData { - logger.Get(ctx).WithFields(logger.Fields{ - "pd_user_id": userID, - "pd_escalation_policy_ids": escalationPolicyIDs, - }).Warningf("PagerDuty returned some oncalls array but none of them matched the query") + logger.Get(ctx).WarnContext(ctx, "PagerDuty returned some oncalls array but none of them matched the query", + "pd_user_id", userID, + "pd_escalation_policy_ids", escalationPolicyIDs, + ) } return nil, nil diff --git a/integrations/access/pagerduty/cmd/teleport-pagerduty/main.go b/integrations/access/pagerduty/cmd/teleport-pagerduty/main.go index aa4a8ba96eb32..58cfa27248d56 100644 --- a/integrations/access/pagerduty/cmd/teleport-pagerduty/main.go +++ b/integrations/access/pagerduty/cmd/teleport-pagerduty/main.go @@ -20,6 +20,7 @@ import ( "context" _ "embed" "fmt" + "log/slog" "os" "github.com/alecthomas/kingpin/v2" @@ -65,12 +66,13 @@ func main() { if err := run(*path, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } } func run(configPath string, debug bool) error { + ctx := context.Background() conf, err := pagerduty.LoadConfig(configPath) if err != nil { return trace.Wrap(err) @@ -84,7 +86,7 @@ func run(configPath string, debug bool) error { return err } if debug { - logger.Standard().Debugf("DEBUG logging enabled") + slog.DebugContext(ctx, "DEBUG logging enabled") } app, err := pagerduty.NewApp(*conf) @@ -94,8 +96,9 @@ func run(configPath string, debug bool) error { go lib.ServeSignals(app, common.PluginShutdownTimeout) - logger.Standard().Infof("Starting Teleport Access PagerDuty Plugin %s:%s", teleport.Version, teleport.Gitref) - return trace.Wrap( - app.Run(context.Background()), + slog.InfoContext(ctx, "Starting Teleport Access PagerDuty Plugin", + "version", teleport.Version, + "git_ref", teleport.Gitref, ) + return trace.Wrap(app.Run(ctx)) } diff --git a/integrations/access/pagerduty/testlib/fake_pagerduty.go b/integrations/access/pagerduty/testlib/fake_pagerduty.go index 18a2a6ae24361..eee358f022458 100644 --- a/integrations/access/pagerduty/testlib/fake_pagerduty.go +++ b/integrations/access/pagerduty/testlib/fake_pagerduty.go @@ -32,7 +32,6 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/pagerduty" "github.com/gravitational/teleport/integrations/lib/stringset" @@ -565,6 +564,6 @@ func (s *FakePagerduty) CheckNewIncidentNote(ctx context.Context) (FakeIncidentN func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/servicenow/app.go b/integrations/access/servicenow/app.go index 3d56f4fc97a8b..07248b488d872 100644 --- a/integrations/access/servicenow/app.go +++ b/integrations/access/servicenow/app.go @@ -21,6 +21,7 @@ package servicenow import ( "context" "fmt" + "log/slog" "net/url" "slices" "strings" @@ -41,6 +42,7 @@ import ( "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/watcherjob" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -116,7 +118,7 @@ func (a *App) WaitReady(ctx context.Context) (bool, error) { func (a *App) run(ctx context.Context) error { log := logger.Get(ctx) - log.Infof("Starting Teleport Access Servicenow Plugin") + log.InfoContext(ctx, "Starting Teleport Access Servicenow Plugin") if err := a.init(ctx); err != nil { return trace.Wrap(err) @@ -153,9 +155,9 @@ func (a *App) run(ctx context.Context) error { } a.mainJob.SetReady(ok) if ok { - log.Info("ServiceNow plugin is ready") + log.InfoContext(ctx, "ServiceNow plugin is ready") } else { - log.Error("ServiceNow plugin is not ready") + log.ErrorContext(ctx, "ServiceNow plugin is not ready") } <-watcherJob.Done() @@ -190,25 +192,25 @@ func (a *App) init(ctx context.Context) error { return trace.Wrap(err) } - log.Debug("Starting API health check...") + log.DebugContext(ctx, "Starting API health check") if err = a.serviceNow.CheckHealth(ctx); err != nil { return trace.Wrap(err, "API health check failed") } - log.Debug("API health check finished ok") + log.DebugContext(ctx, "API health check finished ok") return nil } func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) { log := logger.Get(ctx) - log.Debug("Checking Teleport server version") + log.DebugContext(ctx, "Checking Teleport server version") pong, err := a.teleport.Ping(ctx) if err != nil { if trace.IsNotImplemented(err) { return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion) } - log.Error("Unable to get Teleport server version") + log.ErrorContext(ctx, "Unable to get Teleport server version") return pong, trace.Wrap(err) } err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion) @@ -233,16 +235,16 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error } op := event.Type reqID := event.Resource.GetName() - ctx, _ = logger.WithField(ctx, "request_id", reqID) + ctx, _ = logger.With(ctx, "request_id", reqID) switch op { case types.OpPut: - ctx, _ = logger.WithField(ctx, "request_op", "put") + ctx, _ = logger.With(ctx, "request_op", "put") req, ok := event.Resource.(types.AccessRequest) if !ok { return trace.Errorf("unexpected resource type %T", event.Resource) } - ctx, log := logger.WithField(ctx, "request_state", req.GetState().String()) + ctx, log := logger.With(ctx, "request_state", req.GetState().String()) var err error switch { @@ -251,21 +253,29 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error case req.GetState().IsResolved(): err = a.onResolvedRequest(ctx, req) default: - log.WithField("event", event).Warnf("Unknown request state: %q", req.GetState()) + log.WarnContext(ctx, "Unknown request state", + slog.Group("event", + slog.Any("type", logutils.StringerAttr(event.Type)), + slog.Group("resource", + "kind", event.Resource.GetKind(), + "name", event.Resource.GetName(), + ), + ), + ) return nil } if err != nil { - log.WithError(err).Error("Failed to process request") + log.ErrorContext(ctx, "Failed to process request", "error", err) return trace.Wrap(err) } return nil case types.OpDelete: - ctx, log := logger.WithField(ctx, "request_op", "delete") + ctx, log := logger.With(ctx, "request_op", "delete") if err := a.onDeletedRequest(ctx, reqID); err != nil { - log.WithError(err).Error("Failed to process deleted request") + log.ErrorContext(ctx, "Failed to process deleted request", "error", err) return trace.Wrap(err) } return nil @@ -276,7 +286,7 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) error { reqID := req.GetName() - log := logger.Get(ctx).WithField("reqId", reqID) + log := logger.Get(ctx).With("req_id", reqID) resourceNames, err := a.getResourceNames(ctx, req) if err != nil { @@ -303,7 +313,7 @@ func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) err } if isNew { - log.Infof("Creating servicenow incident") + log.InfoContext(ctx, "Creating servicenow incident") recipientAssignee := a.accessMonitoringRules.RecipientsFromAccessMonitoringRules(ctx, req) assignees := []string{} recipientAssignee.ForEach(func(r common.Recipient) { @@ -375,8 +385,8 @@ func (a *App) createIncident(ctx context.Context, reqID string, reqData RequestD if err != nil { return trace.Wrap(err) } - ctx, log := logger.WithField(ctx, "servicenow_incident_id", data.IncidentID) - log.Info("Successfully created Servicenow incident") + ctx, log := logger.With(ctx, "servicenow_incident_id", data.IncidentID) + log.InfoContext(ctx, "Successfully created Servicenow incident") // Save servicenow incident info in plugin data. _, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) { @@ -420,10 +430,10 @@ func (a *App) postReviewNotes(ctx context.Context, reqID string, reqReviews []ty return trace.Wrap(err) } if !ok { - logger.Get(ctx).Debug("Failed to post the note: plugin data is missing") + logger.Get(ctx).DebugContext(ctx, "Failed to post the note: plugin data is missing") return nil } - ctx, _ = logger.WithField(ctx, "servicenow_incident_id", data.IncidentID) + ctx, _ = logger.With(ctx, "servicenow_incident_id", data.IncidentID) slice := reqReviews[oldCount:] if len(slice) == 0 { @@ -445,22 +455,28 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er serviceNames, err := a.getOnCallServiceNames(req) if err != nil { - logger.Get(ctx).Debugf("Skipping the approval: %s", err) + logger.Get(ctx).DebugContext(ctx, "Skipping the approval", "error", err) return nil } - log.Debugf("Checking the following shifts to see if the requester is on-call: %s", serviceNames) + log.DebugContext(ctx, "Checking the shifts to see if the requester is on-call", "shifts", serviceNames) onCallUsers, err := a.getOnCallUsers(ctx, serviceNames) if err != nil { return trace.Wrap(err) } - log.Debugf("Users on-call are: %s", onCallUsers) + log.DebugContext(ctx, "Users on-call are", "on_call_users", onCallUsers) if userIsOnCall := slices.Contains(onCallUsers, req.GetUser()); !userIsOnCall { - log.Debugf("User %q is not on-call, not approving the request %q.", req.GetUser(), req.GetName()) + log.DebugContext(ctx, "User is not on-call, not approving the request", + "user", req.GetUser(), + "request", req.GetName(), + ) return nil } - log.Debugf("User %q is on-call. Auto-approving the request %q.", req.GetUser(), req.GetName()) + log.DebugContext(ctx, "User is on-call, auto-approving the request", + "user", req.GetUser(), + "request", req.GetName(), + ) if _, err := a.teleport.SubmitAccessReview(ctx, types.AccessReviewSubmission{ RequestID: req.GetName(), Review: types.AccessReview{ @@ -474,12 +490,12 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er }, }); err != nil { if strings.HasSuffix(err.Error(), "has already reviewed this request") { - log.Debug("Already reviewed the request") + log.DebugContext(ctx, "Already reviewed the request") return nil } return trace.Wrap(err, "submitting access request") } - log.Info("Successfully submitted a request approval") + log.InfoContext(ctx, "Successfully submitted a request approval") return nil } @@ -490,7 +506,7 @@ func (a *App) getOnCallUsers(ctx context.Context, serviceNames []string) ([]stri respondersResult, err := a.serviceNow.GetOnCall(ctx, scheduleName) if err != nil { if trace.IsNotFound(err) { - log.WithError(err).Error("Failed to retrieve responder from schedule") + log.ErrorContext(ctx, "Failed to retrieve responder from schedule", "error", err) continue } return nil, trace.Wrap(err) @@ -528,15 +544,15 @@ func (a *App) resolveIncident(ctx context.Context, reqID string, resolution Reso return trace.Wrap(err) } if !ok { - logger.Get(ctx).Debug("Failed to resolve the incident: plugin data is missing") + logger.Get(ctx).DebugContext(ctx, "Failed to resolve the incident: plugin data is missing") return nil } - ctx, log := logger.WithField(ctx, "servicenow_incident_id", incidentID) + ctx, log := logger.With(ctx, "servicenow_incident_id", incidentID) if err := a.serviceNow.ResolveIncident(ctx, incidentID, resolution); err != nil { return trace.Wrap(err) } - log.Info("Successfully resolved the incident") + log.InfoContext(ctx, "Successfully resolved the incident") return nil } diff --git a/integrations/access/servicenow/client.go b/integrations/access/servicenow/client.go index 8d0fb4f62b9de..8c306c1efa4ee 100644 --- a/integrations/access/servicenow/client.go +++ b/integrations/access/servicenow/client.go @@ -287,7 +287,10 @@ func (snc *Client) CheckHealth(ctx context.Context) error { } if err := snc.StatusSink.Emit(ctx, &types.PluginStatusV1{Code: code}); err != nil { log := logger.Get(resp.Request.Context()) - log.WithError(err).WithField("code", resp.StatusCode()).Errorf("Error while emitting servicenow plugin status: %v", err) + log.ErrorContext(ctx, "Error while emitting servicenow plugin status", + "error", err, + "code", resp.StatusCode(), + ) } } diff --git a/integrations/access/servicenow/testlib/fake_servicenow.go b/integrations/access/servicenow/testlib/fake_servicenow.go index 3b2d70e82a9b2..edf3fdced5fe7 100644 --- a/integrations/access/servicenow/testlib/fake_servicenow.go +++ b/integrations/access/servicenow/testlib/fake_servicenow.go @@ -32,7 +32,6 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/servicenow" "github.com/gravitational/teleport/integrations/lib/stringset" @@ -284,6 +283,6 @@ func (s *FakeServiceNow) getOnCall(rotationName string) []string { func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/slack/bot.go b/integrations/access/slack/bot.go index 9c58093cb9897..e7fefa0107163 100644 --- a/integrations/access/slack/bot.go +++ b/integrations/access/slack/bot.go @@ -29,7 +29,6 @@ import ( "github.com/go-resty/resty/v2" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/accesslist" @@ -37,6 +36,7 @@ import ( "github.com/gravitational/teleport/integrations/access/accessrequest" "github.com/gravitational/teleport/integrations/access/common" "github.com/gravitational/teleport/integrations/lib" + "github.com/gravitational/teleport/integrations/lib/logger" pd "github.com/gravitational/teleport/integrations/lib/plugindata" ) @@ -68,7 +68,7 @@ func onAfterResponseSlack(sink common.StatusSink) func(_ *resty.Client, resp *re ctx, cancel := context.WithTimeout(context.Background(), statusEmitTimeout) defer cancel() if err := sink.Emit(ctx, status); err != nil { - log.Errorf("Error while emitting plugin status: %v", err) + logger.Get(ctx).ErrorContext(ctx, "Error while emitting plugin status", "error", err) } }() @@ -139,7 +139,7 @@ func (b Bot) BroadcastAccessRequestMessage(ctx context.Context, recipients []com // the case with most SSO setups. userRecipient, err := b.FetchRecipient(ctx, reqData.User) if err != nil { - log.Warningf("Unable to find user %s in Slack, will not be able to notify.", reqData.User) + logger.Get(ctx).WarnContext(ctx, "Unable to find user in Slack, will not be able to notify", "user", reqData.User) } // Include the user in the list of recipients if it exists. diff --git a/integrations/access/slack/cmd/teleport-slack/main.go b/integrations/access/slack/cmd/teleport-slack/main.go index 1f77db5f21492..ffa73144f540b 100644 --- a/integrations/access/slack/cmd/teleport-slack/main.go +++ b/integrations/access/slack/cmd/teleport-slack/main.go @@ -20,6 +20,7 @@ import ( "context" _ "embed" "fmt" + "log/slog" "os" "github.com/alecthomas/kingpin/v2" @@ -65,12 +66,13 @@ func main() { if err := run(*path, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } } func run(configPath string, debug bool) error { + ctx := context.Background() conf, err := slack.LoadSlackConfig(configPath) if err != nil { return trace.Wrap(err) @@ -84,14 +86,15 @@ func run(configPath string, debug bool) error { return trace.Wrap(err) } if debug { - logger.Standard().Debugf("DEBUG logging enabled") + slog.DebugContext(ctx, "DEBUG logging enabled") } app := slack.NewSlackApp(conf) go lib.ServeSignals(app, common.PluginShutdownTimeout) - logger.Standard().Infof("Starting Teleport Access Slack Plugin %s:%s", teleport.Version, teleport.Gitref) - return trace.Wrap( - app.Run(context.Background()), + slog.InfoContext(ctx, "Starting Teleport Access Slack Plugin", + "version", teleport.Version, + "git_ref", teleport.Gitref, ) + return trace.Wrap(app.Run(ctx)) } diff --git a/integrations/access/slack/testlib/fake_slack.go b/integrations/access/slack/testlib/fake_slack.go index eef81460da7f1..d18a43230c744 100644 --- a/integrations/access/slack/testlib/fake_slack.go +++ b/integrations/access/slack/testlib/fake_slack.go @@ -31,7 +31,6 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/slack" ) @@ -315,6 +314,6 @@ func (s *FakeSlack) CheckMessageUpdateByResponding(ctx context.Context) (slack.M func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/event-handler/fake_fluentd_test.go b/integrations/event-handler/fake_fluentd_test.go index ecf286569f12d..72a363468ba15 100644 --- a/integrations/event-handler/fake_fluentd_test.go +++ b/integrations/event-handler/fake_fluentd_test.go @@ -31,8 +31,6 @@ import ( "github.com/gravitational/trace" "github.com/stretchr/testify/require" - - "github.com/gravitational/teleport/integrations/lib/logger" ) type FakeFluentd struct { @@ -150,7 +148,6 @@ func (f *FakeFluentd) GetURL() string { func (f *FakeFluentd) Respond(w http.ResponseWriter, r *http.Request) { req, err := io.ReadAll(r.Body) if err != nil { - logger.Standard().WithError(err).Error("FakeFluentd Respond() failed to read body") fmt.Fprintln(w, "NOK") return } diff --git a/integrations/event-handler/main.go b/integrations/event-handler/main.go index 859f6544c1e06..693b5bb24e036 100644 --- a/integrations/event-handler/main.go +++ b/integrations/event-handler/main.go @@ -46,8 +46,6 @@ const ( ) func main() { - // This initializes the legacy logrus logger. This has been kept in place - // in case any of the dependencies are still using logrus. logger.Init() ctx := kong.Parse( @@ -64,17 +62,13 @@ func main() { Format: "text", } if cli.Debug { - enableLogDebug() logCfg.Severity = "debug" } - log, err := logCfg.NewSLogLogger() - if err != nil { - fmt.Println(trace.DebugReport(trace.Wrap(err, "initializing logger"))) + + if err := logger.Setup(logCfg); err != nil { + fmt.Println(trace.DebugReport(err)) os.Exit(-1) } - // Whilst this package mostly dependency injects slog, upstream dependencies - // may still use the default slog logger. - slog.SetDefault(log) switch { case ctx.Command() == "version": @@ -86,25 +80,16 @@ func main() { os.Exit(-1) } case ctx.Command() == "start": - err := start(log) + err := start(slog.Default()) if err != nil { lib.Bail(err) } else { - log.InfoContext(context.TODO(), "Successfully shut down") + slog.InfoContext(context.TODO(), "Successfully shut down") } } } -// turn on log debugging -func enableLogDebug() { - err := logger.Setup(logger.Config{Severity: "debug", Output: "stderr"}) - if err != nil { - fmt.Println(trace.DebugReport(err)) - os.Exit(-1) - } -} - // start spawns the main process func start(log *slog.Logger) error { app, err := NewApp(&cli.Start, log) diff --git a/integrations/lib/bail.go b/integrations/lib/bail.go index 72804cd0ac3c4..d1351bb05f7fe 100644 --- a/integrations/lib/bail.go +++ b/integrations/lib/bail.go @@ -19,22 +19,24 @@ package lib import ( + "context" "errors" + "log/slog" "os" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" ) // Bail exits with nonzero exit code and prints an error to a log. func Bail(err error) { + ctx := context.Background() var agg trace.Aggregate if errors.As(trace.Unwrap(err), &agg) { for i, err := range agg.Errors() { - log.WithError(err).Errorf("Terminating with fatal error [%d]...", i+1) + slog.ErrorContext(ctx, "Terminating with fatal error", "error_number", i+1, "error", err) } } else { - log.WithError(err).Error("Terminating with fatal error...") + slog.ErrorContext(ctx, "Terminating with fatal error", "error", err) } os.Exit(1) } diff --git a/integrations/lib/config.go b/integrations/lib/config.go index 24f6c981e6686..66285167e5e36 100644 --- a/integrations/lib/config.go +++ b/integrations/lib/config.go @@ -22,12 +22,12 @@ import ( "context" "errors" "io" + "log/slog" "os" "strings" "time" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" "google.golang.org/grpc" grpcbackoff "google.golang.org/grpc/backoff" @@ -137,7 +137,7 @@ func NewIdentityFileWatcher(ctx context.Context, path string, interval time.Dura } if err := dynamicCred.Reload(); err != nil { - log.WithError(err).Error("Failed to reload identity file from disk.") + slog.ErrorContext(ctx, "Failed to reload identity file from disk", "error", err) } timer.Reset(interval) } @@ -152,7 +152,7 @@ func (cfg TeleportConfig) NewClient(ctx context.Context) (*client.Client, error) case cfg.Addr != "": addr = cfg.Addr case cfg.AuthServer != "": - log.Warn("Configuration setting `auth_server` is deprecated, consider to change it to `addr`") + slog.WarnContext(ctx, "Configuration setting `auth_server` is deprecated, consider to change it to `addr`") addr = cfg.AuthServer } @@ -173,13 +173,13 @@ func (cfg TeleportConfig) NewClient(ctx context.Context) (*client.Client, error) } if validCred, err := credentials.CheckIfExpired(creds); err != nil { - log.Warn(err) + slog.WarnContext(ctx, "found expired credentials", "error", err) if !validCred { return nil, trace.BadParameter( "No valid credentials found, this likely means credentials are expired. In this case, please sign new credentials and increase their TTL if needed.", ) } - log.Info("At least one non-expired credential has been found, continuing startup") + slog.InfoContext(ctx, "At least one non-expired credential has been found, continuing startup") } bk := grpcbackoff.DefaultConfig diff --git a/integrations/lib/embeddedtbot/bot.go b/integrations/lib/embeddedtbot/bot.go index e693b40793fe5..b8ed026386114 100644 --- a/integrations/lib/embeddedtbot/bot.go +++ b/integrations/lib/embeddedtbot/bot.go @@ -26,7 +26,6 @@ import ( "time" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" @@ -106,9 +105,9 @@ func (b *EmbeddedBot) start(ctx context.Context) { go func() { err := bot.Run(botCtx) if err != nil { - log.Errorf("bot exited with error: %s", err) + slog.ErrorContext(botCtx, "bot exited with error", "error", err) } else { - log.Infof("bot exited without error") + slog.InfoContext(botCtx, "bot exited without error") } b.errCh <- trace.Wrap(err) }() @@ -142,10 +141,10 @@ func (b *EmbeddedBot) waitForCredentials(ctx context.Context, deadline time.Dura select { case <-waitCtx.Done(): - log.Warn("context canceled while waiting for the bot client") + slog.WarnContext(ctx, "context canceled while waiting for the bot client") return nil, trace.Wrap(ctx.Err()) case <-b.credential.Ready(): - log.Infof("credential ready") + slog.InfoContext(ctx, "credential ready") } return b.credential, nil @@ -177,7 +176,7 @@ func (b *EmbeddedBot) StartAndWaitForCredentials(ctx context.Context, deadline t // buildClient reads tbot's memory disttination, retrieves the certificates // and builds a new Teleport client using those certs. func (b *EmbeddedBot) buildClient(ctx context.Context) (*client.Client, error) { - log.Infof("Building a new client to connect to %s", b.cfg.AuthServer) + slog.InfoContext(ctx, "Building a new client to connect to cluster", "auth_server_address", b.cfg.AuthServer) c, err := client.New(ctx, client.Config{ Addrs: []string{b.cfg.AuthServer}, Credentials: []client.Credentials{b.credential}, diff --git a/integrations/lib/http.go b/integrations/lib/http.go index dbb279913a5bd..6f98ad957a75c 100644 --- a/integrations/lib/http.go +++ b/integrations/lib/http.go @@ -24,6 +24,7 @@ import ( "crypto/x509" "errors" "fmt" + "log/slog" "net" "net/http" "net/url" @@ -33,7 +34,8 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" + + logutils "github.com/gravitational/teleport/lib/utils/log" ) // TLSConfig stores TLS configuration for a http service @@ -178,7 +180,7 @@ func NewHTTP(config HTTPConfig) (*HTTP, error) { if verify := config.TLS.VerifyClientCertificateFunc; verify != nil { tlsConfig.VerifyPeerCertificate = func(_ [][]byte, chains [][]*x509.Certificate) error { if err := verify(chains); err != nil { - log.WithError(err).Error("HTTPS client certificate verification failed") + slog.ErrorContext(context.Background(), "HTTPS client certificate verification failed", "error", err) return err } return nil @@ -217,7 +219,7 @@ func BuildURLPath(args ...interface{}) string { // ListenAndServe runs a http(s) server on a provided port. func (h *HTTP) ListenAndServe(ctx context.Context) error { - defer log.Debug("HTTP server terminated") + defer slog.DebugContext(ctx, "HTTP server terminated") var err error h.server.BaseContext = func(_ net.Listener) context.Context { @@ -256,10 +258,10 @@ func (h *HTTP) ListenAndServe(ctx context.Context) error { } if h.Insecure { - log.Debugf("Starting insecure HTTP server on %s", addr) + slog.DebugContext(ctx, "Starting insecure HTTP server", "listen_addr", logutils.StringerAttr(addr)) err = h.server.Serve(listener) } else { - log.Debugf("Starting secure HTTPS server on %s", addr) + slog.DebugContext(ctx, "Starting secure HTTPS server", "listen_addr", logutils.StringerAttr(addr)) err = h.server.ServeTLS(listener, h.CertFile, h.KeyFile) } if errors.Is(err, http.ErrServerClosed) { @@ -288,7 +290,7 @@ func (h *HTTP) ServiceJob() ServiceJob { return NewServiceJob(func(ctx context.Context) error { MustGetProcess(ctx).OnTerminate(func(ctx context.Context) error { if err := h.ShutdownWithTimeout(ctx, time.Second*5); err != nil { - log.Error("HTTP server graceful shutdown failed") + slog.ErrorContext(ctx, "HTTP server graceful shutdown failed") return err } return nil diff --git a/integrations/lib/logger/logger.go b/integrations/lib/logger/logger.go index 7422f03ff906c..a1ce5bf7275ed 100644 --- a/integrations/lib/logger/logger.go +++ b/integrations/lib/logger/logger.go @@ -20,16 +20,11 @@ package logger import ( "context" - "io" - "io/fs" "log/slog" "os" - "strings" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" - "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/utils" logutils "github.com/gravitational/teleport/lib/utils/log" ) @@ -41,8 +36,6 @@ type Config struct { Format string `toml:"format"` } -type Fields = log.Fields - type contextKey struct{} var extraFields = []string{logutils.LevelField, logutils.ComponentField, logutils.CallerField} @@ -50,179 +43,50 @@ var extraFields = []string{logutils.LevelField, logutils.ComponentField, logutil // Init sets up logger for a typical daemon scenario until configuration // file is parsed func Init() { - formatter := &logutils.TextFormatter{ - EnableColors: utils.IsTerminal(os.Stderr), - ComponentPadding: 1, // We don't use components so strip the padding - ExtraFields: extraFields, - } - - log.SetOutput(os.Stderr) - if err := formatter.CheckAndSetDefaults(); err != nil { - log.WithError(err).Error("unable to create text log formatter") - return - } - - log.SetFormatter(formatter) + enableColors := utils.IsTerminal(os.Stderr) + logutils.Initialize(logutils.Config{ + Severity: slog.LevelInfo.String(), + Format: "text", + ExtraFields: extraFields, + EnableColors: enableColors, + Padding: 1, + }) } func Setup(conf Config) error { + var enableColors bool switch conf.Output { case "stderr", "error", "2": - log.SetOutput(os.Stderr) + enableColors = utils.IsTerminal(os.Stderr) case "", "stdout", "out", "1": - log.SetOutput(os.Stdout) + enableColors = utils.IsTerminal(os.Stdout) default: - // assume it's a file path: - logFile, err := os.Create(conf.Output) - if err != nil { - return trace.Wrap(err, "failed to create the log file") - } - log.SetOutput(logFile) } - switch strings.ToLower(conf.Severity) { - case "info": - log.SetLevel(log.InfoLevel) - case "err", "error": - log.SetLevel(log.ErrorLevel) - case "debug": - log.SetLevel(log.DebugLevel) - case "warn", "warning": - log.SetLevel(log.WarnLevel) - case "trace": - log.SetLevel(log.TraceLevel) - default: - return trace.BadParameter("unsupported logger severity: '%v'", conf.Severity) - } - - return nil + _, _, err := logutils.Initialize(logutils.Config{ + Output: conf.Output, + Severity: conf.Severity, + Format: conf.Format, + ExtraFields: extraFields, + EnableColors: enableColors, + Padding: 1, + }) + return trace.Wrap(err) } -// NewSLogLogger builds a slog.Logger from the logger.Config. -// TODO(tross): Defer logging initialization to logutils.Initialize and use the -// global slog loggers once integrations has been updated to use slog. -func (conf Config) NewSLogLogger() (*slog.Logger, error) { - const ( - // logFileDefaultMode is the preferred permissions mode for log file. - logFileDefaultMode fs.FileMode = 0o644 - // logFileDefaultFlag is the preferred flags set to log file. - logFileDefaultFlag = os.O_WRONLY | os.O_CREATE | os.O_APPEND - ) - - var w io.Writer - switch conf.Output { - case "": - w = logutils.NewSharedWriter(os.Stderr) - case "stderr", "error", "2": - w = logutils.NewSharedWriter(os.Stderr) - case "stdout", "out", "1": - w = logutils.NewSharedWriter(os.Stdout) - case teleport.Syslog: - w = os.Stderr - sw, err := logutils.NewSyslogWriter() - if err != nil { - slog.Default().ErrorContext(context.Background(), "Failed to switch logging to syslog", "error", err) - break - } - - // If syslog output has been configured and is supported by the operating system, - // then the shared writer is not needed because the syslog writer is already - // protected with a mutex. - w = sw - default: - // Assume this is a file path. - sharedWriter, err := logutils.NewFileSharedWriter(conf.Output, logFileDefaultFlag, logFileDefaultMode) - if err != nil { - return nil, trace.Wrap(err, "failed to init the log file shared writer") - } - w = logutils.NewWriterFinalizer[*logutils.FileSharedWriter](sharedWriter) - if err := sharedWriter.RunWatcherReopen(context.Background()); err != nil { - return nil, trace.Wrap(err) - } - } - - level := new(slog.LevelVar) - switch strings.ToLower(conf.Severity) { - case "", "info": - level.Set(slog.LevelInfo) - case "err", "error": - level.Set(slog.LevelError) - case teleport.DebugLevel: - level.Set(slog.LevelDebug) - case "warn", "warning": - level.Set(slog.LevelWarn) - case "trace": - level.Set(logutils.TraceLevel) - default: - return nil, trace.BadParameter("unsupported logger severity: %q", conf.Severity) - } - - configuredFields, err := logutils.ValidateFields(extraFields) - if err != nil { - return nil, trace.Wrap(err) - } - - var slogLogger *slog.Logger - switch strings.ToLower(conf.Format) { - case "": - fallthrough // not set. defaults to 'text' - case "text": - enableColors := utils.IsTerminal(os.Stderr) - slogLogger = slog.New(logutils.NewSlogTextHandler(w, logutils.SlogTextHandlerConfig{ - Level: level, - EnableColors: enableColors, - ConfiguredFields: configuredFields, - })) - slog.SetDefault(slogLogger) - case "json": - slogLogger = slog.New(logutils.NewSlogJSONHandler(w, logutils.SlogJSONHandlerConfig{ - Level: level, - ConfiguredFields: configuredFields, - })) - slog.SetDefault(slogLogger) - default: - return nil, trace.BadParameter("unsupported log output format : %q", conf.Format) - } - - return slogLogger, nil -} - -func WithLogger(ctx context.Context, logger log.FieldLogger) context.Context { - return withLogger(ctx, logger) -} - -func withLogger(ctx context.Context, logger log.FieldLogger) context.Context { +func WithLogger(ctx context.Context, logger *slog.Logger) context.Context { return context.WithValue(ctx, contextKey{}, logger) } -func WithField(ctx context.Context, key string, value interface{}) (context.Context, log.FieldLogger) { - logger := Get(ctx).WithField(key, value) - return withLogger(ctx, logger), logger +func With(ctx context.Context, args ...any) (context.Context, *slog.Logger) { + logger := Get(ctx).With(args...) + return WithLogger(ctx, logger), logger } -func WithFields(ctx context.Context, logFields Fields) (context.Context, log.FieldLogger) { - logger := Get(ctx).WithFields(logFields) - return withLogger(ctx, logger), logger -} - -func SetField(ctx context.Context, key string, value interface{}) context.Context { - ctx, _ = WithField(ctx, key, value) - return ctx -} - -func SetFields(ctx context.Context, logFields Fields) context.Context { - ctx, _ = WithFields(ctx, logFields) - return ctx -} - -func Get(ctx context.Context) log.FieldLogger { - if logger, ok := ctx.Value(contextKey{}).(log.FieldLogger); ok && logger != nil { +func Get(ctx context.Context) *slog.Logger { + if logger, ok := ctx.Value(contextKey{}).(*slog.Logger); ok && logger != nil { return logger } - return Standard() -} - -func Standard() log.FieldLogger { - return log.StandardLogger() + return slog.Default() } diff --git a/integrations/lib/signals.go b/integrations/lib/signals.go index 4774915a6271b..4702455dfc7ca 100644 --- a/integrations/lib/signals.go +++ b/integrations/lib/signals.go @@ -20,12 +20,11 @@ package lib import ( "context" + "log/slog" "os" "os/signal" "syscall" "time" - - log "github.com/sirupsen/logrus" ) type Terminable interface { @@ -48,9 +47,9 @@ func ServeSignals(app Terminable, shutdownTimeout time.Duration) { gracefulShutdown := func() { tctx, tcancel := context.WithTimeout(ctx, shutdownTimeout) defer tcancel() - log.Infof("Attempting graceful shutdown...") + slog.InfoContext(tctx, "Attempting graceful shutdown") if err := app.Shutdown(tctx); err != nil { - log.Infof("Graceful shutdown failed. Trying fast shutdown...") + slog.InfoContext(tctx, "Graceful shutdown failed, attempting fast shutdown") app.Close() } } diff --git a/integrations/lib/tctl/tctl.go b/integrations/lib/tctl/tctl.go index 25e7e5e95e0da..5fa0a3252b45b 100644 --- a/integrations/lib/tctl/tctl.go +++ b/integrations/lib/tctl/tctl.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/integrations/lib/logger" + logutils "github.com/gravitational/teleport/lib/utils/log" ) var regexpStatusCAPin = regexp.MustCompile(`CA pin +(sha256:[a-zA-Z0-9]+)`) @@ -59,10 +60,14 @@ func (tctl Tctl) Sign(ctx context.Context, username, format, outPath string) err outPath, ) cmd := exec.CommandContext(ctx, tctl.cmd(), args...) - log.Debugf("Running %s", cmd) + log.DebugContext(ctx, "Running tctl auth sign", "command", logutils.StringerAttr(cmd)) output, err := cmd.CombinedOutput() if err != nil { - log.WithError(err).WithField("args", args).Debug("tctl auth sign failed:", string(output)) + log.DebugContext(ctx, "tctl auth sign failed", + "error", err, + "args", args, + "command_output", string(output), + ) return trace.Wrap(err, "tctl auth sign failed") } return nil @@ -73,7 +78,7 @@ func (tctl Tctl) Create(ctx context.Context, resources []types.Resource) error { log := logger.Get(ctx) args := append(tctl.baseArgs(), "create") cmd := exec.CommandContext(ctx, tctl.cmd(), args...) - log.Debugf("Running %s", cmd) + log.DebugContext(ctx, "Running tctl create", "command", logutils.StringerAttr(cmd)) stdinPipe, err := cmd.StdinPipe() if err != nil { return trace.Wrap(err, "failed to get stdin pipe") @@ -81,16 +86,19 @@ func (tctl Tctl) Create(ctx context.Context, resources []types.Resource) error { go func() { defer func() { if err := stdinPipe.Close(); err != nil { - log.WithError(trace.Wrap(err)).Error("Failed to close stdin pipe") + log.ErrorContext(ctx, "Failed to close stdin pipe", "error", err) } }() if err := writeResourcesYAML(stdinPipe, resources); err != nil { - log.WithError(trace.Wrap(err)).Error("Failed to serialize resources stdin") + log.ErrorContext(ctx, "Failed to serialize resources stdin", "error", err) } }() output, err := cmd.CombinedOutput() if err != nil { - log.WithError(err).Debug("tctl create failed:", string(output)) + log.DebugContext(ctx, "tctl create failed", + "error", err, + "command_output", string(output), + ) return trace.Wrap(err, "tctl create failed") } return nil @@ -102,7 +110,7 @@ func (tctl Tctl) GetAll(ctx context.Context, query string) ([]types.Resource, er args := append(tctl.baseArgs(), "get", query) cmd := exec.CommandContext(ctx, tctl.cmd(), args...) - log.Debugf("Running %s", cmd) + log.DebugContext(ctx, "Running tctl get", "command", logutils.StringerAttr(cmd)) stdoutPipe, err := cmd.StdoutPipe() if err != nil { return nil, trace.Wrap(err, "failed to get stdout") @@ -140,7 +148,7 @@ func (tctl Tctl) GetCAPin(ctx context.Context) (string, error) { args := append(tctl.baseArgs(), "status") cmd := exec.CommandContext(ctx, tctl.cmd(), args...) - log.Debugf("Running %s", cmd) + log.DebugContext(ctx, "Running tctl status", "command", logutils.StringerAttr(cmd)) output, err := cmd.Output() if err != nil { return "", trace.Wrap(err, "failed to get auth status") diff --git a/integrations/lib/testing/integration/suite.go b/integrations/lib/testing/integration/suite.go index 22c0754f66a3b..c0f03c647ef75 100644 --- a/integrations/lib/testing/integration/suite.go +++ b/integrations/lib/testing/integration/suite.go @@ -93,7 +93,7 @@ func (s *Suite) initContexts(oldT *testing.T, newT *testing.T) { } else { baseCtx = context.Background() } - baseCtx, _ = logger.WithField(baseCtx, "test", newT.Name()) + baseCtx, _ = logger.With(baseCtx, "test", newT.Name()) baseCtx, cancel := context.WithCancel(baseCtx) newT.Cleanup(cancel) @@ -163,7 +163,7 @@ func (s *Suite) StartApp(app AppI) { if err := app.Run(ctx); err != nil { // We're in a goroutine so we can't just require.NoError(t, err). // All we can do is to log an error. - logger.Get(ctx).WithError(err).Error("Application failed") + logger.Get(ctx).ErrorContext(ctx, "Application failed", "error", err) } }() diff --git a/integrations/lib/watcherjob/watcherjob.go b/integrations/lib/watcherjob/watcherjob.go index 2999b86aaad0b..a7d2d14482ae6 100644 --- a/integrations/lib/watcherjob/watcherjob.go +++ b/integrations/lib/watcherjob/watcherjob.go @@ -130,23 +130,23 @@ func newJobWithEvents(events types.Events, config Config, fn EventFunc, watchIni if config.FailFast { return trace.WrapWithMessage(err, "Connection problem detected. Exiting as fail fast is on.") } - log.WithError(err).Error("Connection problem detected. Attempting to reconnect.") + log.ErrorContext(ctx, "Connection problem detected, attempting to reconnect", "error", err) case errors.Is(err, io.EOF): if config.FailFast { return trace.WrapWithMessage(err, "Watcher stream closed. Exiting as fail fast is on.") } - log.WithError(err).Error("Watcher stream closed. Attempting to reconnect.") + log.ErrorContext(ctx, "Watcher stream closed attempting to reconnect", "error", err) case lib.IsCanceled(err): - log.Debug("Watcher context is canceled") + log.DebugContext(ctx, "Watcher context is canceled") return trace.Wrap(err) default: - log.WithError(err).Error("Watcher event loop failed") + log.ErrorContext(ctx, "Watcher event loop failed", "error", err) return trace.Wrap(err) } // To mitigate a potentially aggressive retry loop, we wait if err := bk.Do(ctx); err != nil { - log.Debug("Watcher context was canceled while waiting before a reconnection") + log.DebugContext(ctx, "Watcher context was canceled while waiting before a reconnection") return trace.Wrap(err) } } @@ -162,7 +162,7 @@ func (job job) watchEvents(ctx context.Context) error { } defer func() { if err := watcher.Close(); err != nil { - logger.Get(ctx).WithError(err).Error("Failed to close a watcher") + logger.Get(ctx).ErrorContext(ctx, "Failed to close a watcher", "error", err) } }() @@ -170,7 +170,7 @@ func (job job) watchEvents(ctx context.Context) error { return trace.Wrap(err) } - logger.Get(ctx).Debug("Watcher connected") + logger.Get(ctx).DebugContext(ctx, "Watcher connected") job.SetReady(true) for { @@ -253,7 +253,7 @@ func (job job) eventLoop(ctx context.Context) error { event := *eventPtr resource := event.Resource if resource == nil { - log.Error("received an event with empty resource field") + log.ErrorContext(ctx, "received an event with empty resource field") } key := eventKey{kind: resource.GetKind(), name: resource.GetName()} if queue, loaded := queues[key]; loaded { diff --git a/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/debug.go b/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/debug.go index b1c7c7339c4ba..585c82058d5fb 100644 --- a/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/debug.go +++ b/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/debug.go @@ -21,38 +21,37 @@ package main import ( + "context" + "log/slog" "os" - "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" - crdgen "github.com/gravitational/teleport/integrations/operator/crdgen" + logutils "github.com/gravitational/teleport/lib/utils/log" ) func main() { - log.SetLevel(log.DebugLevel) - log.SetOutput(os.Stderr) + slog.SetDefault(slog.New(logutils.NewSlogTextHandler(os.Stderr, + logutils.SlogTextHandlerConfig{ + Level: slog.LevelDebug, + }, + ))) + ctx := context.Background() inputPath := os.Getenv(crdgen.PluginInputPathEnvironment) if inputPath == "" { - log.Error( - trace.BadParameter( - "When built with the 'debug' tag, the input path must be set through the environment variable: %s", - crdgen.PluginInputPathEnvironment, - ), - ) + slog.ErrorContext(ctx, "When built with the 'debug' tag, the input path must be set through the TELEPORT_PROTOC_READ_FILE environment variable") os.Exit(-1) } - log.Infof("This is a debug build, the protoc request is read from the file: '%s'", inputPath) + slog.InfoContext(ctx, "This is a debug build, the protoc request is read from the file", "input_path", inputPath) req, err := crdgen.ReadRequestFromFile(inputPath) if err != nil { - log.WithError(err).Error("error reading request from file") + slog.ErrorContext(ctx, "error reading request from file", "error", err) os.Exit(-1) } if err := crdgen.HandleDocsRequest(req); err != nil { - log.WithError(err).Error("Failed to generate docs") + slog.ErrorContext(ctx, "Failed to generate docs", "error", err) os.Exit(-1) } } diff --git a/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/main.go b/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/main.go index e091e5a8c1d0f..ac1be771b0bf0 100644 --- a/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/main.go +++ b/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/main.go @@ -21,20 +21,26 @@ package main import ( + "context" + "log/slog" "os" "github.com/gogo/protobuf/vanity/command" - log "github.com/sirupsen/logrus" crdgen "github.com/gravitational/teleport/integrations/operator/crdgen" + logutils "github.com/gravitational/teleport/lib/utils/log" ) func main() { - log.SetLevel(log.DebugLevel) - log.SetOutput(os.Stderr) + slog.SetDefault(slog.New(logutils.NewSlogTextHandler(os.Stderr, + logutils.SlogTextHandlerConfig{ + Level: slog.LevelDebug, + }, + ))) + req := command.Read() if err := crdgen.HandleDocsRequest(req); err != nil { - log.WithError(err).Error("Failed to generate schema") + slog.ErrorContext(context.Background(), "Failed to generate schema", "error", err) os.Exit(-1) } } diff --git a/integrations/operator/crdgen/cmd/protoc-gen-crd/debug.go b/integrations/operator/crdgen/cmd/protoc-gen-crd/debug.go index bf19cf7eaca87..2da3e47ab9ec8 100644 --- a/integrations/operator/crdgen/cmd/protoc-gen-crd/debug.go +++ b/integrations/operator/crdgen/cmd/protoc-gen-crd/debug.go @@ -21,38 +21,37 @@ package main import ( + "context" + "log/slog" "os" - "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" - crdgen "github.com/gravitational/teleport/integrations/operator/crdgen" + logutils "github.com/gravitational/teleport/lib/utils/log" ) func main() { - log.SetLevel(log.DebugLevel) - log.SetOutput(os.Stderr) + slog.SetDefault(slog.New(logutils.NewSlogTextHandler(os.Stderr, + logutils.SlogTextHandlerConfig{ + Level: slog.LevelDebug, + }, + ))) + ctx := context.Background() inputPath := os.Getenv(crdgen.PluginInputPathEnvironment) if inputPath == "" { - log.Error( - trace.BadParameter( - "When built with the 'debug' tag, the input path must be set through the environment variable: %s", - crdgen.PluginInputPathEnvironment, - ), - ) + slog.ErrorContext(ctx, "When built with the 'debug' tag, the input path must be set through the TELEPORT_PROTOC_READ_FILE environment variable") os.Exit(-1) } - log.Infof("This is a debug build, the protoc request is read from the file: '%s'", inputPath) + slog.InfoContext(ctx, "This is a debug build, the protoc request is read from the file", "input_path", inputPath) req, err := crdgen.ReadRequestFromFile(inputPath) if err != nil { - log.WithError(err).Error("error reading request from file") + slog.ErrorContext(ctx, "error reading request from file", "error", err) os.Exit(-1) } if err := crdgen.HandleCRDRequest(req); err != nil { - log.WithError(err).Error("Failed to generate schema") + slog.ErrorContext(ctx, "Failed to generate schema", "error", err) os.Exit(-1) } } diff --git a/integrations/operator/crdgen/cmd/protoc-gen-crd/main.go b/integrations/operator/crdgen/cmd/protoc-gen-crd/main.go index 863af95862505..a557993626415 100644 --- a/integrations/operator/crdgen/cmd/protoc-gen-crd/main.go +++ b/integrations/operator/crdgen/cmd/protoc-gen-crd/main.go @@ -21,20 +21,26 @@ package main import ( + "context" + "log/slog" "os" "github.com/gogo/protobuf/vanity/command" - log "github.com/sirupsen/logrus" crdgen "github.com/gravitational/teleport/integrations/operator/crdgen" + logutils "github.com/gravitational/teleport/lib/utils/log" ) func main() { - log.SetLevel(log.DebugLevel) - log.SetOutput(os.Stderr) + slog.SetDefault(slog.New(logutils.NewSlogTextHandler(os.Stderr, + logutils.SlogTextHandlerConfig{ + Level: slog.LevelDebug, + }, + ))) + req := command.Read() if err := crdgen.HandleCRDRequest(req); err != nil { - log.WithError(err).Error("Failed to generate schema") + slog.ErrorContext(context.Background(), "Failed to generate schema", "error", err) os.Exit(-1) } } diff --git a/integrations/terraform/Makefile b/integrations/terraform/Makefile index 572a07d4d45dc..149aef0ed5b4b 100644 --- a/integrations/terraform/Makefile +++ b/integrations/terraform/Makefile @@ -47,7 +47,7 @@ $(BUILDDIR)/terraform-provider-teleport_%: terraform-provider-teleport-v$(VERSIO CUSTOM_IMPORTS_TMP_DIR ?= /tmp/protoc-gen-terraform/custom-imports # This version must match the version installed by .github/workflows/lint.yaml -PROTOC_GEN_TERRAFORM_VERSION ?= v3.0.0 +PROTOC_GEN_TERRAFORM_VERSION ?= v3.0.2 PROTOC_GEN_TERRAFORM_EXISTS := $(shell $(PROTOC_GEN_TERRAFORM) version 2>&1 >/dev/null | grep 'protoc-gen-terraform $(PROTOC_GEN_TERRAFORM_VERSION)') .PHONY: gen-tfschema diff --git a/integrations/terraform/README.md b/integrations/terraform/README.md index 53e752f725d41..dde74bc7b793b 100644 --- a/integrations/terraform/README.md +++ b/integrations/terraform/README.md @@ -7,9 +7,9 @@ Please, refer to [official documentation](https://goteleport.com/docs/admin-guid ## Development 1. Install [`protobuf`](https://grpc.io/docs/protoc-installation/). -2. Install [`protoc-gen-terraform`](https://github.com/gravitational/protoc-gen-terraform) @v3.0.0. +2. Install [`protoc-gen-terraform`](https://github.com/gravitational/protoc-gen-terraform) @v3.0.2. - ```go install github.com/gravitational/protoc-gen-terraform@c91cc3ef4d7d0046c36cb96b1cd337e466c61225``` + ```go install github.com/gravitational/protoc-gen-terraform/v3@v3.0.2``` 3. Install [`Terraform`](https://learn.hashicorp.com/tutorials/terraform/install-cli) v1.1.0+. Alternatively, you can use [`tfenv`](https://github.com/tfutils/tfenv). Please note that on Mac M1 you need to specify `TFENV_ARCH` (ex: `TFENV_ARCH=arm64 tfenv install 1.1.6`). diff --git a/integrations/terraform/go.mod b/integrations/terraform/go.mod index d3240ffff8135..5222dc914a105 100644 --- a/integrations/terraform/go.mod +++ b/integrations/terraform/go.mod @@ -21,7 +21,6 @@ require ( github.com/hashicorp/terraform-plugin-log v0.9.0 github.com/hashicorp/terraform-plugin-sdk/v2 v2.10.1 github.com/jonboulle/clockwork v0.4.0 - github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.10.0 google.golang.org/grpc v1.69.2 google.golang.org/protobuf v1.36.2 @@ -307,6 +306,7 @@ require ( github.com/shirou/gopsutil/v4 v4.24.12 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/sijms/go-ora/v2 v2.8.22 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/cast v1.7.0 // indirect github.com/spf13/cobra v1.8.1 // indirect github.com/spf13/pflag v1.0.5 // indirect diff --git a/integrations/terraform/provider/errors.go b/integrations/terraform/provider/errors.go index d31715366d192..6c0f838b474bf 100644 --- a/integrations/terraform/provider/errors.go +++ b/integrations/terraform/provider/errors.go @@ -17,9 +17,11 @@ limitations under the License. package provider import ( + "context" + "log/slog" + "github.com/gravitational/trace" "github.com/hashicorp/terraform-plugin-framework/diag" - log "github.com/sirupsen/logrus" ) // diagFromWrappedErr wraps error with additional information @@ -43,7 +45,7 @@ func diagFromWrappedErr(summary string, err error, kind string) diag.Diagnostic // diagFromErr converts error to diag.Diagnostics. If logging level is debug, provides trace.DebugReport instead of short text. func diagFromErr(summary string, err error) diag.Diagnostic { - if log.GetLevel() >= log.DebugLevel { + if slog.Default().Enabled(context.Background(), slog.LevelDebug) { return diag.NewErrorDiagnostic(err.Error(), trace.DebugReport(err)) } diff --git a/integrations/terraform/provider/provider.go b/integrations/terraform/provider/provider.go index 13b20d20c434f..99d460a49f806 100644 --- a/integrations/terraform/provider/provider.go +++ b/integrations/terraform/provider/provider.go @@ -19,6 +19,7 @@ package provider import ( "context" "fmt" + "log/slog" "net" "os" "strconv" @@ -29,13 +30,13 @@ import ( "github.com/hashicorp/terraform-plugin-framework/diag" "github.com/hashicorp/terraform-plugin-framework/tfsdk" "github.com/hashicorp/terraform-plugin-framework/types" - log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/grpclog" "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -305,7 +306,7 @@ func (p *Provider) Configure(ctx context.Context, req tfsdk.ConfigureProviderReq return } - log.WithFields(log.Fields{"addr": addr}).Debug("Using Teleport address") + slog.DebugContext(ctx, "Using Teleport address", "addr", addr) dialTimeoutDuration, err := time.ParseDuration(dialTimeoutDurationStr) if err != nil { @@ -393,7 +394,7 @@ func (p *Provider) Configure(ctx context.Context, req tfsdk.ConfigureProviderReq // checkTeleportVersion ensures that Teleport version is at least minServerVersion func (p *Provider) checkTeleportVersion(ctx context.Context, client *client.Client, resp *tfsdk.ConfigureProviderResponse) bool { - log.Debug("Checking Teleport server version") + slog.DebugContext(ctx, "Checking Teleport server version") pong, err := client.Ping(ctx) if err != nil { if trace.IsNotImplemented(err) { @@ -403,13 +404,13 @@ func (p *Provider) checkTeleportVersion(ctx context.Context, client *client.Clie ) return false } - log.WithError(err).Debug("Teleport version check error!") + slog.DebugContext(ctx, "Teleport version check error", "error", err) resp.Diagnostics.AddError("Unable to get Teleport server version!", "Unable to get Teleport server version!") return false } err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion) if err != nil { - log.WithError(err).Debug("Teleport version check error!") + slog.DebugContext(ctx, "Teleport version check error", "error", err) resp.Diagnostics.AddError("Teleport version check error!", err.Error()) return false } @@ -447,7 +448,7 @@ func (p *Provider) validateAddr(addr string, resp *tfsdk.ConfigureProviderRespon _, _, err := net.SplitHostPort(addr) if err != nil { - log.WithField("addr", addr).WithError(err).Debug("Teleport address format error!") + slog.DebugContext(context.Background(), "Teleport address format error", "error", err, "addr", addr) resp.Diagnostics.AddError( "Invalid Teleport address format", fmt.Sprintf("Teleport address must be specified as host:port. Got %q", addr), @@ -461,20 +462,32 @@ func (p *Provider) validateAddr(addr string, resp *tfsdk.ConfigureProviderRespon // configureLog configures logging func (p *Provider) configureLog() { + level := slog.LevelError // Get Terraform log level - level, err := log.ParseLevel(os.Getenv("TF_LOG")) - if err != nil { - log.SetLevel(log.ErrorLevel) - } else { - log.SetLevel(level) + switch strings.ToLower(os.Getenv("TF_LOG")) { + case "panic", "fatal", "error": + level = slog.LevelError + case "warn", "warning": + level = slog.LevelWarn + case "info": + level = slog.LevelInfo + case "debug": + level = slog.LevelDebug + case "trace": + level = logutils.TraceLevel } - log.SetFormatter(&log.TextFormatter{}) + _, _, err := logutils.Initialize(logutils.Config{ + Severity: level.String(), + Format: "text", + }) + if err != nil { + return + } // Show GRPC debug logs only if TF_LOG=DEBUG - if log.GetLevel() >= log.DebugLevel { - l := grpclog.NewLoggerV2(log.StandardLogger().Out, log.StandardLogger().Out, log.StandardLogger().Out) - grpclog.SetLoggerV2(l) + if level <= slog.LevelDebug { + grpclog.SetLoggerV2(grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr)) } } diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 82bd49e68befb..aef1a77ed2564 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -3241,39 +3241,41 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types. return nil, trace.Wrap(err) } - params := services.UserCertParams{ - CASigner: sshSigner, - PublicUserKey: req.sshPublicKey, - Username: req.user.GetName(), - Impersonator: req.impersonator, - AllowedLogins: allowedLogins, - TTL: sessionTTL, - Roles: req.checker.RoleNames(), - CertificateFormat: certificateFormat, - PermitPortForwarding: req.checker.CanPortForward(), - PermitAgentForwarding: req.checker.CanForwardAgents(), - PermitX11Forwarding: req.checker.PermitX11Forwarding(), - RouteToCluster: req.routeToCluster, - Traits: req.traits, - ActiveRequests: req.activeRequests, - MFAVerified: req.mfaVerified, - PreviousIdentityExpires: req.previousIdentityExpires, - LoginIP: req.loginIP, - PinnedIP: pinnedIP, - DisallowReissue: req.disallowReissue, - Renewable: req.renewable, - Generation: req.generation, - BotName: req.botName, - BotInstanceID: req.botInstanceID, - CertificateExtensions: req.checker.CertificateExtensions(), - AllowedResourceIDs: requestedResourcesStr, - ConnectionDiagnosticID: req.connectionDiagnosticID, - PrivateKeyPolicy: attestedKeyPolicy, - DeviceID: req.deviceExtensions.DeviceID, - DeviceAssetTag: req.deviceExtensions.AssetTag, - DeviceCredentialID: req.deviceExtensions.CredentialID, - GitHubUserID: githubUserID, - GitHubUsername: githubUsername, + params := sshca.UserCertificateRequest{ + CASigner: sshSigner, + PublicUserKey: req.sshPublicKey, + TTL: sessionTTL, + CertificateFormat: certificateFormat, + Identity: sshca.Identity{ + Username: req.user.GetName(), + Impersonator: req.impersonator, + AllowedLogins: allowedLogins, + Roles: req.checker.RoleNames(), + PermitPortForwarding: req.checker.CanPortForward(), + PermitAgentForwarding: req.checker.CanForwardAgents(), + PermitX11Forwarding: req.checker.PermitX11Forwarding(), + RouteToCluster: req.routeToCluster, + Traits: req.traits, + ActiveRequests: req.activeRequests, + MFAVerified: req.mfaVerified, + PreviousIdentityExpires: req.previousIdentityExpires, + LoginIP: req.loginIP, + PinnedIP: pinnedIP, + DisallowReissue: req.disallowReissue, + Renewable: req.renewable, + Generation: req.generation, + BotName: req.botName, + BotInstanceID: req.botInstanceID, + CertificateExtensions: req.checker.CertificateExtensions(), + AllowedResourceIDs: requestedResourcesStr, + ConnectionDiagnosticID: req.connectionDiagnosticID, + PrivateKeyPolicy: attestedKeyPolicy, + DeviceID: req.deviceExtensions.DeviceID, + DeviceAssetTag: req.deviceExtensions.AssetTag, + DeviceCredentialID: req.deviceExtensions.CredentialID, + GitHubUserID: githubUserID, + GitHubUsername: githubUsername, + }, } signedSSHCert, err = a.GenerateUserCert(params) if err != nil { diff --git a/lib/auth/keygen/keygen.go b/lib/auth/keygen/keygen.go index cd6bb0acb28ee..5f47b3a90ac16 100644 --- a/lib/auth/keygen/keygen.go +++ b/lib/auth/keygen/keygen.go @@ -23,7 +23,6 @@ import ( "crypto/rand" "fmt" "log/slog" - "strings" "time" "github.com/gravitational/trace" @@ -31,12 +30,11 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/types/wrappers" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/utils" ) @@ -129,164 +127,70 @@ func (k *Keygen) GenerateHostCertWithoutValidation(c services.HostCertParams) ([ // GenerateUserCert generates a user ssh certificate with the passed in parameters. // The private key of the CA to sign the certificate must be provided. -func (k *Keygen) GenerateUserCert(c services.UserCertParams) ([]byte, error) { - if err := c.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err, "error validating UserCertParams") +func (k *Keygen) GenerateUserCert(req sshca.UserCertificateRequest) ([]byte, error) { + if err := req.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err, "error validating user certificate request") } - return k.GenerateUserCertWithoutValidation(c) + return k.GenerateUserCertWithoutValidation(req) } // GenerateUserCertWithoutValidation generates a user ssh certificate with the // passed in parameters without validating them. -func (k *Keygen) GenerateUserCertWithoutValidation(c services.UserCertParams) ([]byte, error) { - pubKey, _, _, _, err := ssh.ParseAuthorizedKey(c.PublicUserKey) +func (k *Keygen) GenerateUserCertWithoutValidation(req sshca.UserCertificateRequest) ([]byte, error) { + pubKey, _, _, _, err := ssh.ParseAuthorizedKey(req.PublicUserKey) if err != nil { return nil, trace.Wrap(err) } - validBefore := uint64(ssh.CertTimeInfinity) - if c.TTL != 0 { - b := k.clock.Now().UTC().Add(c.TTL) - validBefore = uint64(b.Unix()) + + // create shallow copy of identity since we want to make some local changes + ident := req.Identity + + // since this method ignores the supplied values for ValidBefore/ValidAfter, avoid confusing by + // rejecting identities where they are set. + if ident.ValidBefore != 0 { + return nil, trace.BadParameter("ValidBefore should not be set in calls to GenerateUserCert") + } + if ident.ValidAfter != 0 { + return nil, trace.BadParameter("ValidAfter should not be set in calls to GenerateUserCert") + } + + // calculate ValidBefore based on the outer request TTL + ident.ValidBefore = uint64(ssh.CertTimeInfinity) + if req.TTL != 0 { + b := k.clock.Now().UTC().Add(req.TTL) + ident.ValidBefore = uint64(b.Unix()) slog.DebugContext( context.TODO(), "Generated user key with expiry.", - "allowed_logins", c.AllowedLogins, - "valid_before_unix_ts", validBefore, + "allowed_logins", ident.AllowedLogins, + "valid_before_unix_ts", ident.ValidBefore, "valid_before", b, ) } - cert := &ssh.Certificate{ - // we have to use key id to identify teleport user - KeyId: c.Username, - ValidPrincipals: c.AllowedLogins, - Key: pubKey, - ValidAfter: uint64(k.clock.Now().UTC().Add(-1 * time.Minute).Unix()), - ValidBefore: validBefore, - CertType: ssh.UserCert, - } - cert.Permissions.Extensions = map[string]string{ - teleport.CertExtensionPermitPTY: "", - } - if c.PermitX11Forwarding { - cert.Permissions.Extensions[teleport.CertExtensionPermitX11Forwarding] = "" - } - if c.PermitAgentForwarding { - cert.Permissions.Extensions[teleport.CertExtensionPermitAgentForwarding] = "" - } - if c.PermitPortForwarding { - cert.Permissions.Extensions[teleport.CertExtensionPermitPortForwarding] = "" - } - if c.MFAVerified != "" { - cert.Permissions.Extensions[teleport.CertExtensionMFAVerified] = c.MFAVerified - } - if !c.PreviousIdentityExpires.IsZero() { - cert.Permissions.Extensions[teleport.CertExtensionPreviousIdentityExpires] = c.PreviousIdentityExpires.Format(time.RFC3339) - } - if c.LoginIP != "" { - cert.Permissions.Extensions[teleport.CertExtensionLoginIP] = c.LoginIP - } - if c.Impersonator != "" { - cert.Permissions.Extensions[teleport.CertExtensionImpersonator] = c.Impersonator - } - if c.DisallowReissue { - cert.Permissions.Extensions[teleport.CertExtensionDisallowReissue] = "" - } - if c.Renewable { - cert.Permissions.Extensions[teleport.CertExtensionRenewable] = "" - } - if c.Generation > 0 { - cert.Permissions.Extensions[teleport.CertExtensionGeneration] = fmt.Sprint(c.Generation) - } - if c.BotName != "" { - cert.Permissions.Extensions[teleport.CertExtensionBotName] = c.BotName - } - if c.BotInstanceID != "" { - cert.Permissions.Extensions[teleport.CertExtensionBotInstanceID] = c.BotInstanceID - } - if c.AllowedResourceIDs != "" { - cert.Permissions.Extensions[teleport.CertExtensionAllowedResources] = c.AllowedResourceIDs - } - if c.ConnectionDiagnosticID != "" { - cert.Permissions.Extensions[teleport.CertExtensionConnectionDiagnosticID] = c.ConnectionDiagnosticID - } - if c.PrivateKeyPolicy != "" { - cert.Permissions.Extensions[teleport.CertExtensionPrivateKeyPolicy] = string(c.PrivateKeyPolicy) - } - if devID := c.DeviceID; devID != "" { - cert.Permissions.Extensions[teleport.CertExtensionDeviceID] = devID - } - if assetTag := c.DeviceAssetTag; assetTag != "" { - cert.Permissions.Extensions[teleport.CertExtensionDeviceAssetTag] = assetTag - } - if credID := c.DeviceCredentialID; credID != "" { - cert.Permissions.Extensions[teleport.CertExtensionDeviceCredentialID] = credID - } - if c.GitHubUserID != "" { - cert.Permissions.Extensions[teleport.CertExtensionGitHubUserID] = c.GitHubUserID - } - if c.GitHubUsername != "" { - cert.Permissions.Extensions[teleport.CertExtensionGitHubUsername] = c.GitHubUsername - } - if c.PinnedIP != "" { + // set ValidAfter to be 1 minute in the past + ident.ValidAfter = uint64(k.clock.Now().UTC().Add(-1 * time.Minute).Unix()) + + // if the provided identity is attempting to perform IP pinning, make sure modules are enforced + if ident.PinnedIP != "" { if modules.GetModules().BuildType() != modules.BuildEnterprise { return nil, trace.AccessDenied("source IP pinning is only supported in Teleport Enterprise") } - if cert.CriticalOptions == nil { - cert.CriticalOptions = make(map[string]string) - } - // IPv4, all bits matter - ip := c.PinnedIP + "/32" - if strings.Contains(c.PinnedIP, ":") { - // IPv6 - ip = c.PinnedIP + "/128" - } - cert.CriticalOptions[teleport.CertCriticalOptionSourceAddress] = ip } - for _, extension := range c.CertificateExtensions { - // TODO(lxea): update behavior when non ssh, non extensions are supported. - if extension.Mode != types.CertExtensionMode_EXTENSION || - extension.Type != types.CertExtensionType_SSH { - continue - } - cert.Extensions[extension.Name] = extension.Value + // encode the identity into a certificate + cert, err := ident.Encode(req.CertificateFormat) + if err != nil { + return nil, trace.Wrap(err) } - // Add roles, traits, and route to cluster in the certificate extensions if - // the standard format was requested. Certificate extensions are not included - // legacy SSH certificates due to a bug in OpenSSH <= OpenSSH 7.1: - // https://bugzilla.mindrot.org/show_bug.cgi?id=2387 - if c.CertificateFormat == constants.CertificateFormatStandard { - traits, err := wrappers.MarshalTraits(&c.Traits) - if err != nil { - return nil, trace.Wrap(err) - } - if len(traits) > 0 { - cert.Permissions.Extensions[teleport.CertExtensionTeleportTraits] = string(traits) - } - if len(c.Roles) != 0 { - roles, err := services.MarshalCertRoles(c.Roles) - if err != nil { - return nil, trace.Wrap(err) - } - cert.Permissions.Extensions[teleport.CertExtensionTeleportRoles] = roles - } - if c.RouteToCluster != "" { - cert.Permissions.Extensions[teleport.CertExtensionTeleportRouteToCluster] = c.RouteToCluster - } - if !c.ActiveRequests.IsEmpty() { - requests, err := c.ActiveRequests.Marshal() - if err != nil { - return nil, trace.Wrap(err) - } - cert.Permissions.Extensions[teleport.CertExtensionTeleportActiveRequests] = string(requests) - } - } + // set the public key of the certificate + cert.Key = pubKey - if err := cert.SignCert(rand.Reader, c.CASigner); err != nil { + if err := cert.SignCert(rand.Reader, req.CASigner); err != nil { return nil, trace.Wrap(err) } + return ssh.MarshalAuthorizedKey(cert), nil } diff --git a/lib/auth/keygen/keygen_test.go b/lib/auth/keygen/keygen_test.go index e2d68d91a923e..d6c243b3ee986 100644 --- a/lib/auth/keygen/keygen_test.go +++ b/lib/auth/keygen/keygen_test.go @@ -38,6 +38,7 @@ import ( "github.com/gravitational/teleport/lib/auth/test" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" ) type nativeContext struct { @@ -226,23 +227,24 @@ func TestUserCertCompatibility(t *testing.T) { for i, tc := range tests { comment := fmt.Sprintf("Test %v", i) - userCertificateBytes, err := tt.suite.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: ssh.MarshalAuthorizedKey(caSigner.PublicKey()), - Username: "user", - AllowedLogins: []string{"centos", "root"}, - TTL: time.Hour, - Roles: []string{"foo"}, - CertificateExtensions: []*types.CertExtension{{ - Type: types.CertExtensionType_SSH, - Mode: types.CertExtensionMode_EXTENSION, - Name: "login@github.com", - Value: "hello", + userCertificateBytes, err := tt.suite.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: ssh.MarshalAuthorizedKey(caSigner.PublicKey()), + TTL: time.Hour, + CertificateFormat: tc.inCompatibility, + Identity: sshca.Identity{ + Username: "user", + AllowedLogins: []string{"centos", "root"}, + Roles: []string{"foo"}, + CertificateExtensions: []*types.CertExtension{{ + Type: types.CertExtensionType_SSH, + Mode: types.CertExtensionMode_EXTENSION, + Name: "login@github.com", + Value: "hello", + }}, + PermitAgentForwarding: true, + PermitPortForwarding: true, }, - }, - CertificateFormat: tc.inCompatibility, - PermitAgentForwarding: true, - PermitPortForwarding: true, }) require.NoError(t, err, comment) diff --git a/lib/auth/test/suite.go b/lib/auth/test/suite.go index 3e97874d8802e..14d22f8265647 100644 --- a/lib/auth/test/suite.go +++ b/lib/auth/test/suite.go @@ -95,15 +95,17 @@ func (s *AuthSuite) GenerateUserCert(t *testing.T) { caSigner, err := ssh.ParsePrivateKey(priv) require.NoError(t, err) - cert, err := s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: pub, - Username: "user", - AllowedLogins: []string{"centos", "root"}, - TTL: time.Hour, - PermitAgentForwarding: true, - PermitPortForwarding: true, - CertificateFormat: constants.CertificateFormatStandard, + cert, err := s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: pub, + TTL: time.Hour, + CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: "user", + AllowedLogins: []string{"centos", "root"}, + PermitAgentForwarding: true, + PermitPortForwarding: true, + }, }) require.NoError(t, err) @@ -112,59 +114,67 @@ func (s *AuthSuite) GenerateUserCert(t *testing.T) { err = checkCertExpiry(cert, s.Clock.Now().Add(-1*time.Minute), s.Clock.Now().Add(1*time.Hour)) require.NoError(t, err) - cert, err = s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: pub, - Username: "user", - AllowedLogins: []string{"root"}, - TTL: -20, - PermitAgentForwarding: true, - PermitPortForwarding: true, - CertificateFormat: constants.CertificateFormatStandard, + cert, err = s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: pub, + TTL: -20, + CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: "user", + AllowedLogins: []string{"root"}, + PermitAgentForwarding: true, + PermitPortForwarding: true, + }, }) require.NoError(t, err) err = checkCertExpiry(cert, s.Clock.Now().Add(-1*time.Minute), s.Clock.Now().Add(apidefaults.MinCertDuration)) require.NoError(t, err) - _, err = s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: pub, - Username: "user", - AllowedLogins: []string{"root"}, - TTL: 0, - PermitAgentForwarding: true, - PermitPortForwarding: true, - CertificateFormat: constants.CertificateFormatStandard, + _, err = s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: pub, + TTL: 0, + CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: "user", + AllowedLogins: []string{"root"}, + PermitAgentForwarding: true, + PermitPortForwarding: true, + }, }) require.NoError(t, err) err = checkCertExpiry(cert, s.Clock.Now().Add(-1*time.Minute), s.Clock.Now().Add(apidefaults.MinCertDuration)) require.NoError(t, err) - _, err = s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: pub, - Username: "user", - AllowedLogins: []string{"root"}, - TTL: time.Hour, - PermitAgentForwarding: true, - PermitPortForwarding: true, - CertificateFormat: constants.CertificateFormatStandard, + _, err = s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: pub, + TTL: time.Hour, + CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: "user", + AllowedLogins: []string{"root"}, + PermitAgentForwarding: true, + PermitPortForwarding: true, + }, }) require.NoError(t, err) inRoles := []string{"role-1", "role-2"} impersonator := "alice" - cert, err = s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: pub, - Username: "user", - Impersonator: impersonator, - AllowedLogins: []string{"root"}, - TTL: time.Hour, - PermitAgentForwarding: true, - PermitPortForwarding: true, - CertificateFormat: constants.CertificateFormatStandard, - Roles: inRoles, + cert, err = s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: pub, + TTL: time.Hour, + CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: "user", + Impersonator: impersonator, + AllowedLogins: []string{"root"}, + PermitAgentForwarding: true, + PermitPortForwarding: true, + Roles: inRoles, + }, }) require.NoError(t, err) parsedCert, err := sshutils.ParseCertificate(cert) @@ -178,15 +188,17 @@ func (s *AuthSuite) GenerateUserCert(t *testing.T) { // Check that MFAVerified and PreviousIdentityExpires are encoded into ssh cert clock := clockwork.NewFakeClock() - cert, err = s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: pub, - Username: "user", - AllowedLogins: []string{"root"}, - TTL: time.Minute, - CertificateFormat: constants.CertificateFormatStandard, - MFAVerified: "mfa-device-id", - PreviousIdentityExpires: clock.Now().Add(time.Hour), + cert, err = s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: pub, + TTL: time.Minute, + CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: "user", + AllowedLogins: []string{"root"}, + MFAVerified: "mfa-device-id", + PreviousIdentityExpires: clock.Now().Add(time.Hour), + }, }) require.NoError(t, err) parsedCert, err = sshutils.ParseCertificate(cert) @@ -202,14 +214,16 @@ func (s *AuthSuite) GenerateUserCert(t *testing.T) { const devID = "deviceid1" const devTag = "devicetag1" const devCred = "devicecred1" - certRaw, err := s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, // Required. - PublicUserKey: pub, // Required. - Username: "llama", // Required. - AllowedLogins: []string{"llama"}, // Required. - DeviceID: devID, - DeviceAssetTag: devTag, - DeviceCredentialID: devCred, + certRaw, err := s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, // Required. + PublicUserKey: pub, // Required. + Identity: sshca.Identity{ + Username: "llama", // Required. + AllowedLogins: []string{"llama"}, // Required. + DeviceID: devID, + DeviceAssetTag: devTag, + DeviceCredentialID: devCred, + }, }) require.NoError(t, err, "GenerateUserCert failed") @@ -223,13 +237,15 @@ func (s *AuthSuite) GenerateUserCert(t *testing.T) { t.Run("github identity", func(t *testing.T) { githubUserID := "1234567" githubUsername := "github-user" - certRaw, err := s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, // Required. - PublicUserKey: pub, // Required. - Username: "llama", // Required. - AllowedLogins: []string{"llama"}, // Required. - GitHubUserID: githubUserID, - GitHubUsername: githubUsername, + certRaw, err := s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, // Required. + PublicUserKey: pub, // Required. + Identity: sshca.Identity{ + Username: "llama", // Required. + AllowedLogins: []string{"llama"}, // Required. + GitHubUserID: githubUserID, + GitHubUsername: githubUsername, + }, }) require.NoError(t, err, "GenerateUserCert failed") diff --git a/lib/auth/testauthority/testauthority.go b/lib/auth/testauthority/testauthority.go index 8dae039d9c1f4..b58f9ac27493d 100644 --- a/lib/auth/testauthority/testauthority.go +++ b/lib/auth/testauthority/testauthority.go @@ -29,6 +29,7 @@ import ( "github.com/gravitational/teleport/lib/auth/keygen" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" ) type Keygen struct { @@ -60,7 +61,7 @@ func (n *Keygen) GenerateHostCert(c services.HostCertParams) ([]byte, error) { return n.GenerateHostCertWithoutValidation(c) } -func (n *Keygen) GenerateUserCert(c services.UserCertParams) ([]byte, error) { +func (n *Keygen) GenerateUserCert(c sshca.UserCertificateRequest) ([]byte, error) { return n.GenerateUserCertWithoutValidation(c) } diff --git a/lib/auth/usertasks/usertasksv1/service.go b/lib/auth/usertasks/usertasksv1/service.go index a383e55a70135..74223f258369c 100644 --- a/lib/auth/usertasks/usertasksv1/service.go +++ b/lib/auth/usertasks/usertasksv1/service.go @@ -19,6 +19,7 @@ package usertasksv1 import ( + "cmp" "context" "log/slog" "time" @@ -131,7 +132,7 @@ func (s *Service) CreateUserTask(ctx context.Context, req *usertasksv1.CreateUse return nil, trace.Wrap(err) } - s.updateStatus(req.UserTask) + s.updateStatus(req.UserTask, nil /* existing user task */) rsp, err := s.backend.CreateUserTask(ctx, req.UserTask) s.emitCreateAuditEvent(ctx, rsp, authCtx, err) @@ -264,10 +265,7 @@ func (s *Service) UpdateUserTask(ctx context.Context, req *usertasksv1.UpdateUse } stateChanged := existingUserTask.GetSpec().GetState() != req.GetUserTask().GetSpec().GetState() - - if stateChanged { - s.updateStatus(req.UserTask) - } + s.updateStatus(req.UserTask, existingUserTask) rsp, err := s.backend.UpdateUserTask(ctx, req.UserTask) s.emitUpdateAuditEvent(ctx, existingUserTask, req.GetUserTask(), authCtx, err) @@ -333,9 +331,7 @@ func (s *Service) UpsertUserTask(ctx context.Context, req *usertasksv1.UpsertUse stateChanged = existingUserTask.GetSpec().GetState() != req.GetUserTask().GetSpec().GetState() } - if stateChanged { - s.updateStatus(req.UserTask) - } + s.updateStatus(req.UserTask, existingUserTask) rsp, err := s.backend.UpsertUserTask(ctx, req.UserTask) s.emitUpsertAuditEvent(ctx, existingUserTask, req.GetUserTask(), authCtx, err) @@ -350,10 +346,21 @@ func (s *Service) UpsertUserTask(ctx context.Context, req *usertasksv1.UpsertUse return rsp, nil } -func (s *Service) updateStatus(ut *usertasksv1.UserTask) { +func (s *Service) updateStatus(ut *usertasksv1.UserTask, existing *usertasksv1.UserTask) { + // Default status for UserTask. ut.Status = &usertasksv1.UserTaskStatus{ LastStateChange: timestamppb.New(s.clock.Now()), } + + if existing != nil { + // Inherit everything from existing UserTask. + ut.Status.LastStateChange = cmp.Or(existing.GetStatus().GetLastStateChange(), ut.Status.LastStateChange) + + // Update specific values. + if existing.GetSpec().GetState() != ut.GetSpec().GetState() { + ut.Status.LastStateChange = timestamppb.New(s.clock.Now()) + } + } } func (s *Service) emitUpsertAuditEvent(ctx context.Context, old, new *usertasksv1.UserTask, authCtx *authz.Context, err error) { diff --git a/lib/auth/usertasks/usertasksv1/service_test.go b/lib/auth/usertasks/usertasksv1/service_test.go index d40b3740af591..1a909c278bdd8 100644 --- a/lib/auth/usertasks/usertasksv1/service_test.go +++ b/lib/auth/usertasks/usertasksv1/service_test.go @@ -153,6 +153,7 @@ func TestEvents(t *testing.T) { // LastStateChange is updated. require.Equal(t, timestamppb.New(fakeClock.Now()), createUserTaskResp.Status.LastStateChange) + expectedLastStateChange := createUserTaskResp.Status.LastStateChange ut1.Spec.DiscoverEc2.Instances["i-345"] = &usertasksv1.DiscoverEC2Instance{ InstanceId: "i-345", DiscoveryConfig: "dc01", @@ -165,7 +166,7 @@ func TestEvents(t *testing.T) { require.Len(t, testReporter.emittedEvents, 1) consumeAssertEvent(t, auditEventsSink.C(), auditEventFor(userTaskName, "update", "OPEN", "OPEN")) // LastStateChange is not updated. - require.Equal(t, createUserTaskResp.Status.LastStateChange, upsertUserTaskResp.Status.LastStateChange) + require.Equal(t, expectedLastStateChange.AsTime(), upsertUserTaskResp.Status.LastStateChange.AsTime()) ut1.Spec.State = "RESOLVED" fakeClock.Advance(1 * time.Minute) @@ -177,6 +178,36 @@ func TestEvents(t *testing.T) { // LastStateChange was updated because the state changed. require.Equal(t, timestamppb.New(fakeClock.Now()), updateUserTaskResp.Status.LastStateChange) + // Updating one of the instances. + expectedLastStateChange = updateUserTaskResp.Status.GetLastStateChange() + fakeClock.Advance(1 * time.Minute) + ut1.Spec.DiscoverEc2.Instances["i-345"] = &usertasksv1.DiscoverEC2Instance{ + InstanceId: "i-345", + DiscoveryConfig: "dc01", + DiscoveryGroup: "dg01", + SyncTime: timestamppb.New(fakeClock.Now()), + } + updateUserTaskResp, err = service.UpdateUserTask(ctx, &usertasksv1.UpdateUserTaskRequest{UserTask: ut1}) + require.NoError(t, err) + // Does not change the LastStateChange + require.Equal(t, expectedLastStateChange.AsTime(), updateUserTaskResp.Status.LastStateChange.AsTime()) + consumeAssertEvent(t, auditEventsSink.C(), auditEventFor(userTaskName, "update", "RESOLVED", "RESOLVED")) + + // Upserting one of the instances. + expectedLastStateChange = updateUserTaskResp.Status.GetLastStateChange() + fakeClock.Advance(1 * time.Minute) + ut1.Spec.DiscoverEc2.Instances["i-345"] = &usertasksv1.DiscoverEC2Instance{ + InstanceId: "i-345", + DiscoveryConfig: "dc01", + DiscoveryGroup: "dg01", + SyncTime: timestamppb.New(fakeClock.Now()), + } + upsertUserTaskResp, err = service.UpsertUserTask(ctx, &usertasksv1.UpsertUserTaskRequest{UserTask: ut1}) + require.NoError(t, err) + // Does not change the LastStateChange + require.Equal(t, expectedLastStateChange.AsTime(), upsertUserTaskResp.Status.LastStateChange.AsTime()) + consumeAssertEvent(t, auditEventsSink.C(), auditEventFor(userTaskName, "update", "RESOLVED", "RESOLVED")) + _, err = service.DeleteUserTask(ctx, &usertasksv1.DeleteUserTaskRequest{Name: userTaskName}) require.NoError(t, err) // No usage report for deleted resources. diff --git a/lib/client/api.go b/lib/client/api.go index ed94462aa9c73..8b4c317265573 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2853,7 +2853,7 @@ type execResult struct { // sharedWriter is an [io.Writer] implementation that protects // writes with a mutex. This allows a single [io.Writer] to be shared -// by both logrus and slog without their output clobbering each other. +// by multiple command runners. type sharedWriter struct { mu sync.Mutex io.Writer diff --git a/lib/client/client_store_test.go b/lib/client/client_store_test.go index 8090c5e664851..71239884aaaba 100644 --- a/lib/client/client_store_test.go +++ b/lib/client/client_store_test.go @@ -45,6 +45,7 @@ import ( "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" @@ -104,16 +105,18 @@ func (s *testAuthority) makeSignedKeyRing(t *testing.T, idx KeyRingIndex, makeEx caSigner, err := ssh.ParsePrivateKey(CAPriv) require.NoError(t, err) - cert, err := s.keygen.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: sshPriv.MarshalSSHPublicKey(), - Username: idx.Username, - AllowedLogins: allowedLogins, - TTL: ttl, - PermitAgentForwarding: false, - PermitPortForwarding: true, - GitHubUserID: "1234567", - GitHubUsername: "github-username", + cert, err := s.keygen.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: sshPriv.MarshalSSHPublicKey(), + TTL: ttl, + Identity: sshca.Identity{ + Username: idx.Username, + AllowedLogins: allowedLogins, + PermitAgentForwarding: false, + PermitPortForwarding: true, + GitHubUserID: "1234567", + GitHubUsername: "github-username", + }, }) require.NoError(t, err) diff --git a/lib/client/cluster_client_test.go b/lib/client/cluster_client_test.go index 7a90be3f30d80..e529b4737d1db 100644 --- a/lib/client/cluster_client_test.go +++ b/lib/client/cluster_client_test.go @@ -39,7 +39,7 @@ import ( libmfa "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/observability/tracing" - "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/tlsca" ) @@ -390,13 +390,15 @@ func TestIssueUserCertsWithMFA(t *testing.T) { var sshCert, tlsCert []byte var err error if req.SSHPublicKey != nil { - sshCert, err = ca.keygen.GenerateUserCert(services.UserCertParams{ + sshCert, err = ca.keygen.GenerateUserCert(sshca.UserCertificateRequest{ CASigner: caSigner, PublicUserKey: req.SSHPublicKey, TTL: req.Expires.Sub(clock.Now()), - Username: req.Username, CertificateFormat: req.Format, - RouteToCluster: req.RouteToCluster, + Identity: sshca.Identity{ + Username: req.Username, + RouteToCluster: req.RouteToCluster, + }, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/client/identityfile/identity_test.go b/lib/client/identityfile/identity_test.go index 3f52aefe162db..9d8eeb62a894d 100644 --- a/lib/client/identityfile/identity_test.go +++ b/lib/client/identityfile/identity_test.go @@ -46,7 +46,7 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/kube/kubeconfig" - "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" ) @@ -108,11 +108,13 @@ func newClientKeyRing(t *testing.T, modifiers ...func(*tlsca.Identity)) *client. caSigner, err := ssh.NewSignerFromKey(signer) require.NoError(t, err) - certificate, err := keygen.GenerateUserCert(services.UserCertParams{ + certificate, err := keygen.GenerateUserCert(sshca.UserCertificateRequest{ CASigner: caSigner, PublicUserKey: ssh.MarshalAuthorizedKey(privateKey.SSHPublicKey()), - Username: "testuser", - AllowedLogins: []string{"testuser"}, + Identity: sshca.Identity{ + Username: "testuser", + AllowedLogins: []string{"testuser"}, + }, }) require.NoError(t, err) diff --git a/lib/client/keyagent_test.go b/lib/client/keyagent_test.go index 4c0c078e82293..a8dfdae28da95 100644 --- a/lib/client/keyagent_test.go +++ b/lib/client/keyagent_test.go @@ -50,6 +50,7 @@ import ( "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -751,16 +752,18 @@ func (s *KeyAgentTestSuite) makeKeyRing(t *testing.T, username, proxyHost string sshPub, err := ssh.NewPublicKey(sshKey.Public()) require.NoError(t, err) - certificate, err := testauthority.New().GenerateUserCert(services.UserCertParams{ - CertificateFormat: constants.CertificateFormatStandard, - CASigner: caSigner, - PublicUserKey: ssh.MarshalAuthorizedKey(sshPub), - Username: username, - AllowedLogins: []string{username}, - TTL: ttl, - PermitAgentForwarding: true, - PermitPortForwarding: true, - RouteToCluster: s.clusterName, + certificate, err := testauthority.New().GenerateUserCert(sshca.UserCertificateRequest{ + CertificateFormat: constants.CertificateFormatStandard, + CASigner: caSigner, + PublicUserKey: ssh.MarshalAuthorizedKey(sshPub), + TTL: ttl, + Identity: sshca.Identity{ + Username: username, + AllowedLogins: []string{username}, + PermitAgentForwarding: true, + PermitPortForwarding: true, + RouteToCluster: s.clusterName, + }, }) require.NoError(t, err) diff --git a/lib/reversetunnel/srv_test.go b/lib/reversetunnel/srv_test.go index 2477739df359a..8794a8323f0f1 100644 --- a/lib/reversetunnel/srv_test.go +++ b/lib/reversetunnel/srv_test.go @@ -39,6 +39,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/utils" ) @@ -103,15 +104,17 @@ func TestServerKeyAuth(t *testing.T) { { desc: "user cert", key: func() ssh.PublicKey { - rawCert, err := ta.GenerateUserCert(services.UserCertParams{ + rawCert, err := ta.GenerateUserCert(sshca.UserCertificateRequest{ CASigner: caSigner, PublicUserKey: pub, - Username: con.User(), - AllowedLogins: []string{con.User()}, - Roles: []string{"dev", "admin"}, - RouteToCluster: "user-cluster-name", CertificateFormat: constants.CertificateFormatStandard, TTL: time.Minute, + Identity: sshca.Identity{ + Username: con.User(), + AllowedLogins: []string{con.User()}, + Roles: []string{"dev", "admin"}, + RouteToCluster: "user-cluster-name", + }, }) require.NoError(t, err) key, _, _, _, err := ssh.ParseAuthorizedKey(rawCert) diff --git a/lib/service/service.go b/lib/service/service.go index 7638ee5e85caf..7fd997e7234f0 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -54,6 +54,7 @@ import ( "github.com/gravitational/roundtrip" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/quic-go/quic-go" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" @@ -657,6 +658,15 @@ type TeleportProcess struct { // resolver is used to identify the reverse tunnel address when connecting via // the proxy. resolver reversetunnelclient.Resolver + + // metricRegistry is the prometheus metric registry for the process. + // Every teleport service that wants to register metrics should use this + // instead of the global prometheus.DefaultRegisterer to avoid registration + // conflicts. + // + // Both the metricsRegistry and the default global registry are gathered by + // Telepeort's metric service. + metricsRegistry *prometheus.Registry } // processIndex is an internal process index @@ -1179,6 +1189,7 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { logger: cfg.Logger, cloudLabels: cloudLabels, TracingProvider: tracing.NoopProvider(), + metricsRegistry: cfg.MetricsRegistry, } process.registerExpectedServices(cfg) @@ -3405,11 +3416,46 @@ func (process *TeleportProcess) initUploaderService() error { return nil } +// promHTTPLogAdapter adapts a slog.Logger into a promhttp.Logger. +type promHTTPLogAdapter struct { + ctx context.Context + *slog.Logger +} + +// Println implements the promhttp.Logger interface. +func (l promHTTPLogAdapter) Println(v ...interface{}) { + //nolint:sloglint // msg cannot be constant + l.ErrorContext(l.ctx, fmt.Sprint(v...)) +} + // initMetricsService starts the metrics service currently serving metrics for // prometheus consumption func (process *TeleportProcess) initMetricsService() error { mux := http.NewServeMux() - mux.Handle("/metrics", promhttp.Handler()) + + // We gather metrics both from the in-process registry (preferred metrics registration method) + // and the global registry (used by some Teleport services and many dependencies). + gatherers := prometheus.Gatherers{ + process.metricsRegistry, + prometheus.DefaultGatherer, + } + + metricsHandler := promhttp.InstrumentMetricHandler( + process.metricsRegistry, promhttp.HandlerFor(gatherers, promhttp.HandlerOpts{ + // Errors can happen if metrics are registered with identical names in both the local and the global registry. + // In this case, we log the error but continue collecting metrics. The first collected metric will win + // (the one from the local metrics registry takes precedence). + // As we move more things to the local registry, especially in other tools like tbot, we will have less + // conflicts in tests. + ErrorHandling: promhttp.ContinueOnError, + ErrorLog: promHTTPLogAdapter{ + ctx: process.ExitContext(), + Logger: process.logger.With(teleport.ComponentKey, teleport.ComponentMetrics), + }, + }), + ) + + mux.Handle("/metrics", metricsHandler) logger := process.logger.With(teleport.ComponentKey, teleport.Component(teleport.ComponentMetrics, process.id)) diff --git a/lib/service/service_test.go b/lib/service/service_test.go index 52e59387ff580..4c08a87689145 100644 --- a/lib/service/service_test.go +++ b/lib/service/service_test.go @@ -23,9 +23,11 @@ import ( "crypto/tls" "errors" "fmt" + "io" "log/slog" "net" "net/http" + "net/url" "os" "path/filepath" "strings" @@ -39,6 +41,8 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -1887,7 +1891,7 @@ func TestAgentRolloutController(t *testing.T) { dataDir := makeTempDir(t) cfg := servicecfg.MakeDefaultConfig() - // We use a real clock because too many sevrices are using the clock and it's not possible to accurately wait for + // We use a real clock because too many services are using the clock and it's not possible to accurately wait for // each one of them to reach the point where they wait for the clock to advance. If we add a WaitUntil(X waiters) // check, this will break the next time we add a new waiter. cfg.Clock = clockwork.NewRealClock() @@ -1906,7 +1910,7 @@ func TestAgentRolloutController(t *testing.T) { process, err := NewTeleport(cfg) require.NoError(t, err) - // Test setup: start the Teleport auth and wait for it to beocme ready + // Test setup: start the Teleport auth and wait for it to become ready require.NoError(t, process.Start()) // Test setup: wait for every service to start @@ -1949,6 +1953,84 @@ func TestAgentRolloutController(t *testing.T) { }, 5*time.Second, 10*time.Millisecond) } +func TestMetricsService(t *testing.T) { + t.Parallel() + // Test setup: create a listener for the metrics server, get its file descriptor. + + // Note: this code is copied from integrations/helpers/NewListenerOn() to avoid including helpers in a production + // build and avoid a cyclic dependency. + metricsListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, metricsListener.Close()) + }) + require.IsType(t, &net.TCPListener{}, metricsListener) + metricsListenerFile, err := metricsListener.(*net.TCPListener).File() + require.NoError(t, err) + + // Test setup: create a new teleport process + dataDir := makeTempDir(t) + cfg := servicecfg.MakeDefaultConfig() + cfg.DataDir = dataDir + cfg.SetAuthServerAddress(utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}) + cfg.Auth.Enabled = true + cfg.Proxy.Enabled = false + cfg.SSH.Enabled = false + cfg.DebugService.Enabled = false + cfg.Auth.StorageConfig.Params["path"] = dataDir + cfg.Auth.ListenAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"} + cfg.Metrics.Enabled = true + + // Configure the metrics server to use the listener we previously created. + cfg.Metrics.ListenAddr = &utils.NetAddr{AddrNetwork: "tcp", Addr: metricsListener.Addr().String()} + cfg.FileDescriptors = []*servicecfg.FileDescriptor{ + {Type: string(ListenerMetrics), Address: metricsListener.Addr().String(), File: metricsListenerFile}, + } + + // Create and start the Teleport service. + process, err := NewTeleport(cfg) + require.NoError(t, err) + require.NoError(t, process.Start()) + t.Cleanup(func() { + assert.NoError(t, process.Close()) + assert.NoError(t, process.Wait()) + }) + + // Test setup: create our test metrics. + nonce := strings.ReplaceAll(uuid.NewString(), "-", "") + localMetric := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "test", + Name: "local_metric_" + nonce, + }) + globalMetric := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "test", + Name: "global_metric_" + nonce, + }) + require.NoError(t, process.metricsRegistry.Register(localMetric)) + require.NoError(t, prometheus.Register(globalMetric)) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + t.Cleanup(cancel) + _, err = process.WaitForEvent(ctx, MetricsReady) + require.NoError(t, err) + + // Test execution: get metrics and check the tests metrics are here. + metricsURL, err := url.Parse("http://" + metricsListener.Addr().String()) + require.NoError(t, err) + metricsURL.Path = "/metrics" + resp, err := http.Get(metricsURL.String()) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + // Test validation: check that the metrics server served both the local and global registry. + require.Contains(t, string(body), "local_metric_"+nonce) + require.Contains(t, string(body), "global_metric_"+nonce) +} + // makeTempDir makes a temp dir with a shorter name than t.TempDir() in order to // avoid https://github.com/golang/go/issues/62614. func makeTempDir(t *testing.T) string { diff --git a/lib/service/servicecfg/config.go b/lib/service/servicecfg/config.go index a89e79a8f6302..a89e29a2c7b54 100644 --- a/lib/service/servicecfg/config.go +++ b/lib/service/servicecfg/config.go @@ -34,6 +34,7 @@ import ( "github.com/ghodss/yaml" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" @@ -264,6 +265,12 @@ type Config struct { // protocol. DatabaseREPLRegistry dbrepl.REPLRegistry + // MetricsRegistry is the prometheus metrics registry used by the Teleport process to register its metrics. + // As of today, not every Teleport metric is registered against this registry. Some Teleport services + // and Teleport dependencies are using the global registry. + // Both the MetricsRegistry and the default global registry are gathered by Teleport's metric service. + MetricsRegistry *prometheus.Registry + // token is either the token needed to join the auth server, or a path pointing to a file // that contains the token // @@ -520,6 +527,10 @@ func ApplyDefaults(cfg *Config) { cfg.LoggerLevel = new(slog.LevelVar) } + if cfg.MetricsRegistry == nil { + cfg.MetricsRegistry = prometheus.NewRegistry() + } + // Remove insecure and (borderline insecure) cryptographic primitives from // default configuration. These can still be added back in file configuration by // users, but not supported by default by Teleport. See #1856 for more diff --git a/lib/services/authority.go b/lib/services/authority.go index fb6a3efe612e6..2345342b1195b 100644 --- a/lib/services/authority.go +++ b/lib/services/authority.go @@ -32,9 +32,7 @@ import ( "github.com/jonboulle/clockwork" "golang.org/x/crypto/ssh" - apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/types/wrappers" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/jwt" @@ -321,103 +319,6 @@ func (c HostCertParams) Check() error { return nil } -// UserCertParams defines OpenSSH user certificate parameters -type UserCertParams struct { - // CASigner is the signer that will sign the public key of the user with the CA private key - CASigner ssh.Signer - // PublicUserKey is the public key of the user in SSH authorized_keys format. - PublicUserKey []byte - // TTL defines how long a certificate is valid for - TTL time.Duration - // Username is teleport username - Username string - // Impersonator is set when a user requests certificate for another user - Impersonator string - // AllowedLogins is a list of SSH principals - AllowedLogins []string - // PermitX11Forwarding permits X11 forwarding for this cert - PermitX11Forwarding bool - // PermitAgentForwarding permits agent forwarding for this cert - PermitAgentForwarding bool - // PermitPortForwarding permits port forwarding. - PermitPortForwarding bool - // PermitFileCopying permits the use of SCP/SFTP. - PermitFileCopying bool - // Roles is a list of roles assigned to this user - Roles []string - // CertificateFormat is the format of the SSH certificate. - CertificateFormat string - // RouteToCluster specifies the target cluster - // if present in the certificate, will be used - // to route the requests to - RouteToCluster string - // Traits hold claim data used to populate a role at runtime. - Traits wrappers.Traits - // ActiveRequests tracks privilege escalation requests applied during - // certificate construction. - ActiveRequests RequestIDs - // MFAVerified is the UUID of an MFA device when this Identity was - // confirmed immediately after an MFA check. - MFAVerified string - // PreviousIdentityExpires is the expiry time of the identity/cert that this - // identity/cert was derived from. It is used to determine a session's hard - // deadline in cases where both require_session_mfa and disconnect_expired_cert - // are enabled. See https://github.com/gravitational/teleport/issues/18544. - PreviousIdentityExpires time.Time - // LoginIP is an observed IP of the client on the moment of certificate creation. - LoginIP string - // PinnedIP is an IP from which client must communicate with Teleport. - PinnedIP string - // DisallowReissue flags that any attempt to request new certificates while - // authenticated with this cert should be denied. - DisallowReissue bool - // CertificateExtensions are user configured ssh key extensions - CertificateExtensions []*types.CertExtension - // Renewable indicates this certificate is renewable. - Renewable bool - // Generation counts the number of times a certificate has been renewed. - Generation uint64 - // BotName is set to the name of the bot, if the user is a Machine ID bot user. - // Empty for human users. - BotName string - // BotInstanceID is the unique identifier for the bot instance, if this is a - // Machine ID bot. It is empty for human users. - BotInstanceID string - // AllowedResourceIDs lists the resources the user should be able to access. - AllowedResourceIDs string - // ConnectionDiagnosticID references the ConnectionDiagnostic that we should use to append traces when testing a Connection. - ConnectionDiagnosticID string - // PrivateKeyPolicy is the private key policy supported by this certificate. - PrivateKeyPolicy keys.PrivateKeyPolicy - // DeviceID is the trusted device identifier. - DeviceID string - // DeviceAssetTag is the device inventory identifier. - DeviceAssetTag string - // DeviceCredentialID is the identifier for the credential used by the device - // to authenticate itself. - DeviceCredentialID string - // GitHubUserID indicates the GitHub user ID identified by the GitHub - // connector. - GitHubUserID string - // GitHubUserID indicates the GitHub username identified by the GitHub - // connector. - GitHubUsername string -} - -// CheckAndSetDefaults checks the user certificate parameters -func (c *UserCertParams) CheckAndSetDefaults() error { - if c.CASigner == nil { - return trace.BadParameter("CASigner is required") - } - if c.TTL < apidefaults.MinCertDuration { - c.TTL = apidefaults.MinCertDuration - } - if len(c.AllowedLogins) == 0 { - return trace.BadParameter("AllowedLogins are required") - } - return nil -} - // CertPoolFromCertAuthorities returns a certificate pool from the TLS certificates // set up in the certificate authorities list, as well as the number of certificates // that were added to the pool. diff --git a/lib/srv/authhandlers_test.go b/lib/srv/authhandlers_test.go index 78856817654a9..907a3db97b786 100644 --- a/lib/srv/authhandlers_test.go +++ b/lib/srv/authhandlers_test.go @@ -35,7 +35,7 @@ import ( "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/events/eventstest" - "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" ) type mockCAandAuthPrefGetter struct { @@ -213,11 +213,13 @@ func TestRBAC(t *testing.T) { privateKey, err := cryptosuites.GeneratePrivateKeyWithAlgorithm(cryptosuites.ECDSAP256) require.NoError(t, err) - c, err := keygen.GenerateUserCert(services.UserCertParams{ + c, err := keygen.GenerateUserCert(sshca.UserCertificateRequest{ CASigner: caSigner, PublicUserKey: ssh.MarshalAuthorizedKey(privateKey.SSHPublicKey()), - Username: "testuser", - AllowedLogins: []string{"testuser"}, + Identity: sshca.Identity{ + Username: "testuser", + AllowedLogins: []string{"testuser"}, + }, }) require.NoError(t, err) @@ -385,16 +387,18 @@ func TestRBACJoinMFA(t *testing.T) { require.NoError(t, err) keygen := testauthority.New() - c, err := keygen.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: privateKey.MarshalSSHPublicKey(), - Username: username, - AllowedLogins: []string{username}, - Traits: wrappers.Traits{ - teleport.TraitInternalPrefix: []string{""}, - }, - Roles: []string{tt.role}, + c, err := keygen.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: privateKey.MarshalSSHPublicKey(), CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: username, + AllowedLogins: []string{username}, + Traits: wrappers.Traits{ + teleport.TraitInternalPrefix: []string{""}, + }, + Roles: []string{tt.role}, + }, }) require.NoError(t, err) diff --git a/lib/srv/desktop/rdp/rdpclient/client.go b/lib/srv/desktop/rdp/rdpclient/client.go index 821408d2208fa..534644e6be1df 100644 --- a/lib/srv/desktop/rdp/rdpclient/client.go +++ b/lib/srv/desktop/rdp/rdpclient/client.go @@ -93,7 +93,7 @@ func init() { var rustLogLevel string // initialize the Rust logger by setting $RUST_LOG based - // on the logrus log level + // on the slog log level // (unless RUST_LOG is already explicitly set, then we // assume the user knows what they want) rl := os.Getenv("RUST_LOG") diff --git a/lib/sshca/identity.go b/lib/sshca/identity.go new file mode 100644 index 0000000000000..19f40bfdf336d --- /dev/null +++ b/lib/sshca/identity.go @@ -0,0 +1,392 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +// Package sshca specifies interfaces for SSH certificate authorities +package sshca + +import ( + "fmt" + "maps" + "strconv" + "strings" + "time" + + "github.com/gravitational/trace" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/wrappers" + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/services" +) + +// Identity is a user identity. All identity fields map directly to an ssh certificate field. +type Identity struct { + // ValidAfter is the unix timestamp that marks the start time for when the certificate should + // be considered valid. + ValidAfter uint64 + // ValidBefore is the unix timestamp that marks the end time for when the certificate should + // be considered valid. + ValidBefore uint64 + // Username is teleport username + Username string + // Impersonator is set when a user requests certificate for another user + Impersonator string + // AllowedLogins is a list of SSH principals + AllowedLogins []string + // PermitX11Forwarding permits X11 forwarding for this cert + PermitX11Forwarding bool + // PermitAgentForwarding permits agent forwarding for this cert + PermitAgentForwarding bool + // PermitPortForwarding permits port forwarding. + PermitPortForwarding bool + // Roles is a list of roles assigned to this user + Roles []string + // RouteToCluster specifies the target cluster + // if present in the certificate, will be used + // to route the requests to + RouteToCluster string + // Traits hold claim data used to populate a role at runtime. + Traits wrappers.Traits + // ActiveRequests tracks privilege escalation requests applied during + // certificate construction. + ActiveRequests services.RequestIDs + // MFAVerified is the UUID of an MFA device when this Identity was + // confirmed immediately after an MFA check. + MFAVerified string + // PreviousIdentityExpires is the expiry time of the identity/cert that this + // identity/cert was derived from. It is used to determine a session's hard + // deadline in cases where both require_session_mfa and disconnect_expired_cert + // are enabled. See https://github.com/gravitational/teleport/issues/18544. + PreviousIdentityExpires time.Time + // LoginIP is an observed IP of the client on the moment of certificate creation. + LoginIP string + // PinnedIP is an IP from which client must communicate with Teleport. + PinnedIP string + // DisallowReissue flags that any attempt to request new certificates while + // authenticated with this cert should be denied. + DisallowReissue bool + // CertificateExtensions are user configured ssh key extensions (note: this field also + // ends up aggregating all *unknown* extensions during cert parsing, meaning that this + // can sometimes contain fields that were inserted by a newer version of teleport). + CertificateExtensions []*types.CertExtension + // Renewable indicates this certificate is renewable. + Renewable bool + // Generation counts the number of times a certificate has been renewed, with a generation of 1 + // meaning the cert has never been renewed. A generation of zero means the cert's generation is + // not being tracked. + Generation uint64 + // BotName is set to the name of the bot, if the user is a Machine ID bot user. + // Empty for human users. + BotName string + // BotInstanceID is the unique identifier for the bot instance, if this is a + // Machine ID bot. It is empty for human users. + BotInstanceID string + // AllowedResourceIDs lists the resources the user should be able to access. + AllowedResourceIDs string + // ConnectionDiagnosticID references the ConnectionDiagnostic that we should use to append traces when testing a Connection. + ConnectionDiagnosticID string + // PrivateKeyPolicy is the private key policy supported by this certificate. + PrivateKeyPolicy keys.PrivateKeyPolicy + // DeviceID is the trusted device identifier. + DeviceID string + // DeviceAssetTag is the device inventory identifier. + DeviceAssetTag string + // DeviceCredentialID is the identifier for the credential used by the device + // to authenticate itself. + DeviceCredentialID string + // GitHubUserID indicates the GitHub user ID identified by the GitHub + // connector. + GitHubUserID string + // GitHubUsername indicates the GitHub username identified by the GitHub + // connector. + GitHubUsername string +} + +// Check performs validation of certain fields in the identity. +func (i *Identity) Check() error { + if len(i.AllowedLogins) == 0 { + return trace.BadParameter("ssh user identity missing allowed logins") + } + + return nil +} + +// Encode encodes the identity into an ssh certificate. Note that the returned certificate is incomplete +// and must be have its public key set before signing. +func (i *Identity) Encode(certFormat string) (*ssh.Certificate, error) { + validBefore := i.ValidBefore + if validBefore == 0 { + validBefore = uint64(ssh.CertTimeInfinity) + } + validAfter := i.ValidAfter + if validAfter == 0 { + validAfter = uint64(time.Now().UTC().Add(-1 * time.Minute).Unix()) + } + cert := &ssh.Certificate{ + // we have to use key id to identify teleport user + KeyId: i.Username, + ValidPrincipals: i.AllowedLogins, + ValidAfter: validAfter, + ValidBefore: validBefore, + CertType: ssh.UserCert, + } + cert.Permissions.Extensions = map[string]string{ + teleport.CertExtensionPermitPTY: "", + } + + if i.PermitX11Forwarding { + cert.Permissions.Extensions[teleport.CertExtensionPermitX11Forwarding] = "" + } + if i.PermitAgentForwarding { + cert.Permissions.Extensions[teleport.CertExtensionPermitAgentForwarding] = "" + } + if i.PermitPortForwarding { + cert.Permissions.Extensions[teleport.CertExtensionPermitPortForwarding] = "" + } + if i.MFAVerified != "" { + cert.Permissions.Extensions[teleport.CertExtensionMFAVerified] = i.MFAVerified + } + if !i.PreviousIdentityExpires.IsZero() { + cert.Permissions.Extensions[teleport.CertExtensionPreviousIdentityExpires] = i.PreviousIdentityExpires.Format(time.RFC3339) + } + if i.LoginIP != "" { + cert.Permissions.Extensions[teleport.CertExtensionLoginIP] = i.LoginIP + } + if i.Impersonator != "" { + cert.Permissions.Extensions[teleport.CertExtensionImpersonator] = i.Impersonator + } + if i.DisallowReissue { + cert.Permissions.Extensions[teleport.CertExtensionDisallowReissue] = "" + } + if i.Renewable { + cert.Permissions.Extensions[teleport.CertExtensionRenewable] = "" + } + if i.Generation > 0 { + cert.Permissions.Extensions[teleport.CertExtensionGeneration] = fmt.Sprint(i.Generation) + } + if i.BotName != "" { + cert.Permissions.Extensions[teleport.CertExtensionBotName] = i.BotName + } + if i.BotInstanceID != "" { + cert.Permissions.Extensions[teleport.CertExtensionBotInstanceID] = i.BotInstanceID + } + if i.AllowedResourceIDs != "" { + cert.Permissions.Extensions[teleport.CertExtensionAllowedResources] = i.AllowedResourceIDs + } + if i.ConnectionDiagnosticID != "" { + cert.Permissions.Extensions[teleport.CertExtensionConnectionDiagnosticID] = i.ConnectionDiagnosticID + } + if i.PrivateKeyPolicy != "" { + cert.Permissions.Extensions[teleport.CertExtensionPrivateKeyPolicy] = string(i.PrivateKeyPolicy) + } + if devID := i.DeviceID; devID != "" { + cert.Permissions.Extensions[teleport.CertExtensionDeviceID] = devID + } + if assetTag := i.DeviceAssetTag; assetTag != "" { + cert.Permissions.Extensions[teleport.CertExtensionDeviceAssetTag] = assetTag + } + if credID := i.DeviceCredentialID; credID != "" { + cert.Permissions.Extensions[teleport.CertExtensionDeviceCredentialID] = credID + } + if i.GitHubUserID != "" { + cert.Permissions.Extensions[teleport.CertExtensionGitHubUserID] = i.GitHubUserID + } + if i.GitHubUsername != "" { + cert.Permissions.Extensions[teleport.CertExtensionGitHubUsername] = i.GitHubUsername + } + + if i.PinnedIP != "" { + if cert.CriticalOptions == nil { + cert.CriticalOptions = make(map[string]string) + } + // IPv4, all bits matter + ip := i.PinnedIP + "/32" + if strings.Contains(i.PinnedIP, ":") { + // IPv6 + ip = i.PinnedIP + "/128" + } + cert.CriticalOptions[teleport.CertCriticalOptionSourceAddress] = ip + } + + for _, extension := range i.CertificateExtensions { + // TODO(lxea): update behavior when non ssh, non extensions are supported. + if extension.Mode != types.CertExtensionMode_EXTENSION || + extension.Type != types.CertExtensionType_SSH { + continue + } + cert.Extensions[extension.Name] = extension.Value + } + + // Add roles, traits, and route to cluster in the certificate extensions if + // the standard format was requested. Certificate extensions are not included + // legacy SSH certificates due to a bug in OpenSSH <= OpenSSH 7.1: + // https://bugzilla.mindrot.org/show_bug.cgi?id=2387 + if certFormat == constants.CertificateFormatStandard { + traits, err := wrappers.MarshalTraits(&i.Traits) + if err != nil { + return nil, trace.Wrap(err) + } + if len(traits) > 0 { + cert.Permissions.Extensions[teleport.CertExtensionTeleportTraits] = string(traits) + } + if len(i.Roles) != 0 { + roles, err := services.MarshalCertRoles(i.Roles) + if err != nil { + return nil, trace.Wrap(err) + } + cert.Permissions.Extensions[teleport.CertExtensionTeleportRoles] = roles + } + if i.RouteToCluster != "" { + cert.Permissions.Extensions[teleport.CertExtensionTeleportRouteToCluster] = i.RouteToCluster + } + if !i.ActiveRequests.IsEmpty() { + requests, err := i.ActiveRequests.Marshal() + if err != nil { + return nil, trace.Wrap(err) + } + cert.Permissions.Extensions[teleport.CertExtensionTeleportActiveRequests] = string(requests) + } + } + + return cert, nil +} + +// DecodeIdentity decodes an ssh certificate into an identity. +func DecodeIdentity(cert *ssh.Certificate) (*Identity, error) { + if cert.CertType != ssh.UserCert { + return nil, trace.BadParameter("DecodeIdentity intended for use with user certs, got %v", cert.CertType) + } + ident := &Identity{ + Username: cert.KeyId, + AllowedLogins: cert.ValidPrincipals, + ValidAfter: cert.ValidAfter, + ValidBefore: cert.ValidBefore, + } + + // clone the extension map and remove entries from the clone as they are processed so + // that we can easily aggregate the remainder into the CertificateExtensions field. + extensions := maps.Clone(cert.Extensions) + + takeExtension := func(name string) (value string, ok bool) { + v, ok := extensions[name] + if !ok { + return "", false + } + delete(extensions, name) + return v, true + } + + takeValue := func(name string) string { + value, _ := takeExtension(name) + return value + } + + takeBool := func(name string) bool { + _, ok := takeExtension(name) + return ok + } + + // ignore the permit pty extension, it's always set + _, _ = takeExtension(teleport.CertExtensionPermitPTY) + + ident.PermitX11Forwarding = takeBool(teleport.CertExtensionPermitX11Forwarding) + ident.PermitAgentForwarding = takeBool(teleport.CertExtensionPermitAgentForwarding) + ident.PermitPortForwarding = takeBool(teleport.CertExtensionPermitPortForwarding) + ident.MFAVerified = takeValue(teleport.CertExtensionMFAVerified) + + if v, ok := takeExtension(teleport.CertExtensionPreviousIdentityExpires); ok { + t, err := time.Parse(time.RFC3339, v) + if err != nil { + return nil, trace.BadParameter("failed to parse value %q for extension %q as RFC3339 timestamp: %v", v, teleport.CertExtensionPreviousIdentityExpires, err) + } + ident.PreviousIdentityExpires = t + } + + ident.LoginIP = takeValue(teleport.CertExtensionLoginIP) + ident.Impersonator = takeValue(teleport.CertExtensionImpersonator) + ident.DisallowReissue = takeBool(teleport.CertExtensionDisallowReissue) + ident.Renewable = takeBool(teleport.CertExtensionRenewable) + + if v, ok := takeExtension(teleport.CertExtensionGeneration); ok { + i, err := strconv.ParseUint(v, 10, 64) + if err != nil { + return nil, trace.BadParameter("failed to parse value %q for extension %q as uint64: %v", v, teleport.CertExtensionGeneration, err) + } + ident.Generation = i + } + + ident.BotName = takeValue(teleport.CertExtensionBotName) + ident.BotInstanceID = takeValue(teleport.CertExtensionBotInstanceID) + ident.AllowedResourceIDs = takeValue(teleport.CertExtensionAllowedResources) + ident.ConnectionDiagnosticID = takeValue(teleport.CertExtensionConnectionDiagnosticID) + ident.PrivateKeyPolicy = keys.PrivateKeyPolicy(takeValue(teleport.CertExtensionPrivateKeyPolicy)) + ident.DeviceID = takeValue(teleport.CertExtensionDeviceID) + ident.DeviceAssetTag = takeValue(teleport.CertExtensionDeviceAssetTag) + ident.DeviceCredentialID = takeValue(teleport.CertExtensionDeviceCredentialID) + ident.GitHubUserID = takeValue(teleport.CertExtensionGitHubUserID) + ident.GitHubUsername = takeValue(teleport.CertExtensionGitHubUsername) + + if v, ok := cert.CriticalOptions[teleport.CertCriticalOptionSourceAddress]; ok { + parts := strings.Split(v, "/") + if len(parts) != 2 { + return nil, trace.BadParameter("failed to parse value %q for critical option %q as CIDR", v, teleport.CertCriticalOptionSourceAddress) + } + ident.PinnedIP = parts[0] + } + + if v, ok := takeExtension(teleport.CertExtensionTeleportTraits); ok { + var traits wrappers.Traits + if err := wrappers.UnmarshalTraits([]byte(v), &traits); err != nil { + return nil, trace.BadParameter("failed to unmarshal value %q for extension %q as traits: %v", v, teleport.CertExtensionTeleportTraits, err) + } + ident.Traits = traits + } + + if v, ok := takeExtension(teleport.CertExtensionTeleportRoles); ok { + roles, err := services.UnmarshalCertRoles(v) + if err != nil { + return nil, trace.BadParameter("failed to unmarshal value %q for extension %q as roles: %v", v, teleport.CertExtensionTeleportRoles, err) + } + ident.Roles = roles + } + + ident.RouteToCluster = takeValue(teleport.CertExtensionTeleportRouteToCluster) + + if v, ok := takeExtension(teleport.CertExtensionTeleportActiveRequests); ok { + var requests services.RequestIDs + if err := requests.Unmarshal([]byte(v)); err != nil { + return nil, trace.BadParameter("failed to unmarshal value %q for extension %q as active requests: %v", v, teleport.CertExtensionTeleportActiveRequests, err) + } + ident.ActiveRequests = requests + } + + // aggregate all remaining extensions into the CertificateExtensions field + for name, value := range extensions { + ident.CertificateExtensions = append(ident.CertificateExtensions, &types.CertExtension{ + Name: name, + Value: value, + Type: types.CertExtensionType_SSH, + Mode: types.CertExtensionMode_EXTENSION, + }) + } + + return ident, nil +} diff --git a/lib/sshca/identity_test.go b/lib/sshca/identity_test.go new file mode 100644 index 0000000000000..5c7c6db75b3e8 --- /dev/null +++ b/lib/sshca/identity_test.go @@ -0,0 +1,97 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +// Package sshca specifies interfaces for SSH certificate authorities +package sshca + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/wrappers" + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils/testutils" +) + +func TestIdentityConversion(t *testing.T) { + ident := &Identity{ + ValidAfter: 1, + ValidBefore: 2, + Username: "user", + Impersonator: "impersonator", + AllowedLogins: []string{"login1", "login2"}, + PermitX11Forwarding: true, + PermitAgentForwarding: true, + PermitPortForwarding: true, + Roles: []string{"role1", "role2"}, + RouteToCluster: "cluster", + Traits: wrappers.Traits{"trait1": []string{"value1"}, "trait2": []string{"value2"}}, + ActiveRequests: services.RequestIDs{ + AccessRequests: []string{uuid.NewString()}, + }, + MFAVerified: "mfa", + PreviousIdentityExpires: time.Unix(12345, 0), + LoginIP: "127.0.0.1", + PinnedIP: "127.0.0.1", + DisallowReissue: true, + CertificateExtensions: []*types.CertExtension{&types.CertExtension{ + Name: "extname", + Value: "extvalue", + Type: types.CertExtensionType_SSH, + Mode: types.CertExtensionMode_EXTENSION, + }}, + Renewable: true, + Generation: 3, + BotName: "bot", + BotInstanceID: "instance", + AllowedResourceIDs: "resource", + ConnectionDiagnosticID: "diag", + PrivateKeyPolicy: keys.PrivateKeyPolicy("policy"), + DeviceID: "device", + DeviceAssetTag: "asset", + DeviceCredentialID: "cred", + GitHubUserID: "github", + GitHubUsername: "ghuser", + } + + ignores := []string{ + "CertExtension.Type", // only currently defined enum variant is a zero value + "CertExtension.Mode", // only currently defined enum variant is a zero value + // TODO(fspmarshall): figure out a mechanism for making ignore of grpc fields more convenient + "CertExtension.XXX_NoUnkeyedLiteral", + "CertExtension.XXX_unrecognized", + "CertExtension.XXX_sizecache", + } + + require.True(t, testutils.ExhaustiveNonEmpty(ident, ignores...), "empty=%+v", testutils.FindAllEmpty(ident, ignores...)) + + cert, err := ident.Encode(constants.CertificateFormatStandard) + require.NoError(t, err) + + ident2, err := DecodeIdentity(cert) + require.NoError(t, err) + + require.Empty(t, cmp.Diff(ident, ident2)) +} diff --git a/lib/sshca/sshca.go b/lib/sshca/sshca.go index 5e9e3f548f853..15f5dcf6c1aeb 100644 --- a/lib/sshca/sshca.go +++ b/lib/sshca/sshca.go @@ -20,6 +20,12 @@ package sshca import ( + "time" + + "github.com/gravitational/trace" + "golang.org/x/crypto/ssh" + + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/lib/services" ) @@ -33,5 +39,34 @@ type Authority interface { // GenerateUserCert generates user ssh certificate, it takes pkey as a signing // private key (user certificate authority) - GenerateUserCert(certParams services.UserCertParams) ([]byte, error) + GenerateUserCert(UserCertificateRequest) ([]byte, error) +} + +// UserCertificateRequest is a request to generate a new ssh user certificate. +type UserCertificateRequest struct { + // CASigner is the signer that will sign the public key of the user with the CA private key + CASigner ssh.Signer + // PublicUserKey is the public key of the user in SSH authorized_keys format. + PublicUserKey []byte + // TTL defines how long a certificate is valid for (if specified, ValidAfter/ValidBefore within the + // identity must not be set). + TTL time.Duration + // CertificateFormat is the format of the SSH certificate. + CertificateFormat string + // Identity is the user identity to be encoded in the certificate. + Identity Identity +} + +func (r *UserCertificateRequest) CheckAndSetDefaults() error { + if r.CASigner == nil { + return trace.BadParameter("ssh user certificate request missing ca signer") + } + if r.TTL < apidefaults.MinCertDuration { + r.TTL = apidefaults.MinCertDuration + } + if err := r.Identity.Check(); err != nil { + return trace.Wrap(err) + } + + return nil } diff --git a/lib/utils/cli.go b/lib/utils/cli.go index e79c0bc2aa8f0..648cf7095352f 100644 --- a/lib/utils/cli.go +++ b/lib/utils/cli.go @@ -26,7 +26,6 @@ import ( "flag" "fmt" "io" - stdlog "log" "log/slog" "os" "runtime" @@ -38,7 +37,6 @@ import ( "github.com/alecthomas/kingpin/v2" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "golang.org/x/term" "github.com/gravitational/teleport" @@ -100,59 +98,18 @@ func InitLogger(purpose LoggingPurpose, level slog.Level, opts ...LoggerOption) opt(&o) } - logrus.StandardLogger().ReplaceHooks(make(logrus.LevelHooks)) - logrus.SetLevel(logutils.SlogLevelToLogrusLevel(level)) - - var ( - w io.Writer - enableColors bool - ) - switch purpose { - case LoggingForCLI: - // If debug logging was asked for on the CLI, then write logs to stderr. - // Otherwise, discard all logs. - if level == slog.LevelDebug { - enableColors = IsTerminal(os.Stderr) - w = logutils.NewSharedWriter(os.Stderr) - } else { - w = io.Discard - enableColors = false - } - case LoggingForDaemon: - enableColors = IsTerminal(os.Stderr) - w = logutils.NewSharedWriter(os.Stderr) - } - - var ( - formatter logrus.Formatter - handler slog.Handler - ) - switch o.format { - case LogFormatText, "": - textFormatter := logutils.NewDefaultTextFormatter(enableColors) - - // Calling CheckAndSetDefaults enables the timestamp field to - // be included in the output. The error returned is ignored - // because the default formatter cannot be invalid. - if purpose == LoggingForCLI && level == slog.LevelDebug { - _ = textFormatter.CheckAndSetDefaults() - } - - formatter = textFormatter - handler = logutils.NewSlogTextHandler(w, logutils.SlogTextHandlerConfig{ - Level: level, - EnableColors: enableColors, - }) - case LogFormatJSON: - formatter = &logutils.JSONFormatter{} - handler = logutils.NewSlogJSONHandler(w, logutils.SlogJSONHandlerConfig{ - Level: level, - }) + // If debug or trace logging is not enabled for CLIs, + // then discard all log output. + if purpose == LoggingForCLI && level > slog.LevelDebug { + slog.SetDefault(slog.New(logutils.DiscardHandler{})) + return } - logrus.SetFormatter(formatter) - logrus.SetOutput(w) - slog.SetDefault(slog.New(handler)) + logutils.Initialize(logutils.Config{ + Severity: level.String(), + Format: o.format, + EnableColors: IsTerminal(os.Stderr), + }) } var initTestLoggerOnce = sync.Once{} @@ -163,56 +120,24 @@ func InitLoggerForTests() { // Parse flags to check testing.Verbose(). flag.Parse() - level := slog.LevelWarn - w := io.Discard - if testing.Verbose() { - level = slog.LevelDebug - w = os.Stderr + if !testing.Verbose() { + slog.SetDefault(slog.New(logutils.DiscardHandler{})) + return } - logger := logrus.StandardLogger() - logger.SetFormatter(logutils.NewTestJSONFormatter()) - logger.SetLevel(logutils.SlogLevelToLogrusLevel(level)) - - output := logutils.NewSharedWriter(w) - logger.SetOutput(output) - slog.SetDefault(slog.New(logutils.NewSlogJSONHandler(output, logutils.SlogJSONHandlerConfig{Level: level}))) + logutils.Initialize(logutils.Config{ + Severity: slog.LevelDebug.String(), + Format: LogFormatJSON, + }) }) } -// NewLoggerForTests creates a new logrus logger for test environments. -func NewLoggerForTests() *logrus.Logger { - InitLoggerForTests() - return logrus.StandardLogger() -} - // NewSlogLoggerForTests creates a new slog logger for test environments. func NewSlogLoggerForTests() *slog.Logger { InitLoggerForTests() return slog.Default() } -// WrapLogger wraps an existing logger entry and returns -// a value satisfying the Logger interface -func WrapLogger(logger *logrus.Entry) Logger { - return &logWrapper{Entry: logger} -} - -// NewLogger creates a new empty logrus logger. -func NewLogger() *logrus.Logger { - return logrus.StandardLogger() -} - -// Logger describes a logger value -type Logger interface { - logrus.FieldLogger - // GetLevel specifies the level at which this logger - // value is logging - GetLevel() logrus.Level - // SetLevel sets the logger's level to the specified value - SetLevel(level logrus.Level) -} - // FatalError is for CLI front-ends: it detects gravitational/trace debugging // information, sends it to the logger, strips it off and prints a clean message to stderr func FatalError(err error) { @@ -231,7 +156,7 @@ func GetIterations() int { if err != nil { panic(err) } - logrus.Debugf("Starting tests with %v iterations.", iter) + slog.DebugContext(context.Background(), "Running tests multiple times due to presence of ITERATIONS environment variable", "iterations", iter) return iter } @@ -484,47 +409,6 @@ func AllowWhitespace(s string) string { return sb.String() } -// NewStdlogger creates a new stdlib logger that uses the specified leveled logger -// for output and the given component as a logging prefix. -func NewStdlogger(logger LeveledOutputFunc, component string) *stdlog.Logger { - return stdlog.New(&stdlogAdapter{ - log: logger, - }, component, stdlog.LstdFlags) -} - -// Write writes the specified buffer p to the underlying leveled logger. -// Implements io.Writer -func (r *stdlogAdapter) Write(p []byte) (n int, err error) { - r.log(string(p)) - return len(p), nil -} - -// stdlogAdapter is an io.Writer that writes into an instance -// of logrus.Logger -type stdlogAdapter struct { - log LeveledOutputFunc -} - -// LeveledOutputFunc describes a function that emits given -// arguments at a specific level to an underlying logger -type LeveledOutputFunc func(args ...interface{}) - -// GetLevel returns the level of the underlying logger -func (r *logWrapper) GetLevel() logrus.Level { - return r.Entry.Logger.GetLevel() -} - -// SetLevel sets the logging level to the given value -func (r *logWrapper) SetLevel(level logrus.Level) { - r.Entry.Logger.SetLevel(level) -} - -// logWrapper wraps a log entry. -// Implements Logger -type logWrapper struct { - *logrus.Entry -} - // needsQuoting returns true if any non-printable characters are found. func needsQuoting(text string) bool { for _, r := range text { diff --git a/lib/utils/log/formatter_test.go b/lib/utils/log/formatter_test.go index 9abb0310ba0be..aff0ec8be3a74 100644 --- a/lib/utils/log/formatter_test.go +++ b/lib/utils/log/formatter_test.go @@ -22,7 +22,6 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "log/slog" @@ -38,7 +37,6 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -48,7 +46,7 @@ import ( const message = "Adding diagnostic debugging handlers.\t To connect with profiler, use go tool pprof diag_addr." var ( - logErr = errors.New("the quick brown fox jumped really high") + logErr = &trace.BadParameterError{Message: "the quick brown fox jumped really high"} addr = fakeAddr{addr: "127.0.0.1:1234"} fields = map[string]any{ @@ -72,6 +70,10 @@ func (a fakeAddr) String() string { return a.addr } +func (a fakeAddr) MarshalText() (text []byte, err error) { + return []byte(a.addr), nil +} + func TestOutput(t *testing.T) { loc, err := time.LoadLocation("Africa/Cairo") require.NoError(t, err, "failed getting timezone") @@ -89,58 +91,50 @@ func TestOutput(t *testing.T) { // 4) the caller outputRegex := regexp.MustCompile(`(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z)(\s+.*)(".*diag_addr\.")(.*)(\slog/formatter_test.go:\d{3})`) + expectedFields := map[string]string{ + "local": addr.String(), + "remote": addr.String(), + "login": "llama", + "teleportUser": "user", + "id": "1234", + "test": "123", + "animal": `"llama\n"`, + "error": "[" + trace.DebugReport(logErr) + "]", + "diag_addr": addr.String(), + } + tests := []struct { - name string - logrusLevel logrus.Level - slogLevel slog.Level + name string + slogLevel slog.Level }{ { - name: "trace", - logrusLevel: logrus.TraceLevel, - slogLevel: TraceLevel, + name: "trace", + slogLevel: TraceLevel, }, { - name: "debug", - logrusLevel: logrus.DebugLevel, - slogLevel: slog.LevelDebug, + name: "debug", + slogLevel: slog.LevelDebug, }, { - name: "info", - logrusLevel: logrus.InfoLevel, - slogLevel: slog.LevelInfo, + name: "info", + slogLevel: slog.LevelInfo, }, { - name: "warn", - logrusLevel: logrus.WarnLevel, - slogLevel: slog.LevelWarn, + name: "warn", + slogLevel: slog.LevelWarn, }, { - name: "error", - logrusLevel: logrus.ErrorLevel, - slogLevel: slog.LevelError, + name: "error", + slogLevel: slog.LevelError, }, { - name: "fatal", - logrusLevel: logrus.FatalLevel, - slogLevel: slog.LevelError + 1, + name: "fatal", + slogLevel: slog.LevelError + 1, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - // Create a logrus logger using the custom formatter which outputs to a local buffer. - var logrusOutput bytes.Buffer - formatter := NewDefaultTextFormatter(true) - formatter.timestampEnabled = true - require.NoError(t, formatter.CheckAndSetDefaults()) - - logrusLogger := logrus.New() - logrusLogger.SetFormatter(formatter) - logrusLogger.SetOutput(&logrusOutput) - logrusLogger.ReplaceHooks(logrus.LevelHooks{}) - logrusLogger.SetLevel(test.logrusLevel) - entry := logrusLogger.WithField(teleport.ComponentKey, "test").WithTime(clock.Now().UTC()) - // Create a slog logger using the custom handler which outputs to a local buffer. var slogOutput bytes.Buffer slogConfig := SlogTextHandlerConfig{ @@ -155,13 +149,6 @@ func TestOutput(t *testing.T) { } slogLogger := slog.New(NewSlogTextHandler(&slogOutput, slogConfig)).With(teleport.ComponentKey, "test") - // Add some fields and output the message at the desired log level via logrus. - l := entry.WithField("test", 123).WithField("animal", "llama\n").WithField("error", logErr) - logrusTestLogLineNumber := func() int { - l.WithField("diag_addr", &addr).WithField(teleport.ComponentFields, fields).Log(test.logrusLevel, message) - return getCallerLineNumber() - 1 // Get the line number of this call, and assume the log call is right above it - }() - // Add some fields and output the message at the desired log level via slog. l2 := slogLogger.With("test", 123).With("animal", "llama\n").With("error", logErr) slogTestLogLineNumber := func() int { @@ -169,163 +156,144 @@ func TestOutput(t *testing.T) { return getCallerLineNumber() - 1 // Get the line number of this call, and assume the log call is right above it }() - // Validate that both loggers produces the same output. The added complexity comes from the fact that - // our custom slog handler does NOT sort the additional fields like our logrus formatter does. - logrusMatches := outputRegex.FindStringSubmatch(logrusOutput.String()) - require.NotEmpty(t, logrusMatches, "logrus output was in unexpected format: %s", logrusOutput.String()) + // Validate the logger output. The added complexity comes from the fact that + // our custom slog handler does NOT sort the additional fields. slogMatches := outputRegex.FindStringSubmatch(slogOutput.String()) require.NotEmpty(t, slogMatches, "slog output was in unexpected format: %s", slogOutput.String()) // The first match is the timestamp: 2023-10-31T10:09:06+02:00 - logrusTime, err := time.Parse(time.RFC3339, logrusMatches[1]) - assert.NoError(t, err, "invalid logrus timestamp found %s", logrusMatches[1]) - slogTime, err := time.Parse(time.RFC3339, slogMatches[1]) assert.NoError(t, err, "invalid slog timestamp found %s", slogMatches[1]) - - assert.InDelta(t, logrusTime.Unix(), slogTime.Unix(), 10) + assert.InDelta(t, clock.Now().Unix(), slogTime.Unix(), 10) // Match level, and component: DEBU [TEST] - assert.Empty(t, cmp.Diff(logrusMatches[2], slogMatches[2]), "level, and component to be identical") - // Match the log message: "Adding diagnostic debugging handlers.\t To connect with profiler, use go tool pprof diag_addr.\n" - assert.Empty(t, cmp.Diff(logrusMatches[3], slogMatches[3]), "expected output messages to be identical") + expectedLevel := formatLevel(test.slogLevel, true) + expectedComponent := formatComponent(slog.StringValue("test"), defaultComponentPadding) + expectedMatch := " " + expectedLevel + " " + expectedComponent + " " + assert.Equal(t, expectedMatch, slogMatches[2], "level, and component to be identical") + // Match the log message + assert.Equal(t, `"Adding diagnostic debugging handlers.\t To connect with profiler, use go tool pprof diag_addr."`, slogMatches[3], "expected output messages to be identical") // The last matches are the caller information - assert.Equal(t, fmt.Sprintf(" log/formatter_test.go:%d", logrusTestLogLineNumber), logrusMatches[5]) assert.Equal(t, fmt.Sprintf(" log/formatter_test.go:%d", slogTestLogLineNumber), slogMatches[5]) // The third matches are the fields which will be key value pairs(animal:llama) separated by a space. Since - // logrus sorts the fields and slog doesn't we can't just assert equality and instead build a map of the key + // slog doesn't sort the fields, we can't assert equality and instead build a map of the key // value pairs to ensure they are all present and accounted for. - logrusFieldMatches := fieldsRegex.FindAllStringSubmatch(logrusMatches[4], -1) slogFieldMatches := fieldsRegex.FindAllStringSubmatch(slogMatches[4], -1) // The first match is the key, the second match is the value - logrusFields := map[string]string{} - for _, match := range logrusFieldMatches { - logrusFields[strings.TrimSpace(match[1])] = strings.TrimSpace(match[2]) - } - slogFields := map[string]string{} for _, match := range slogFieldMatches { slogFields[strings.TrimSpace(match[1])] = strings.TrimSpace(match[2]) } - assert.Equal(t, slogFields, logrusFields) + require.Empty(t, + cmp.Diff( + expectedFields, + slogFields, + cmpopts.SortMaps(func(a, b string) bool { return a < b }), + ), + ) }) } }) t.Run("json", func(t *testing.T) { tests := []struct { - name string - logrusLevel logrus.Level - slogLevel slog.Level + name string + slogLevel slog.Level }{ { - name: "trace", - logrusLevel: logrus.TraceLevel, - slogLevel: TraceLevel, + name: "trace", + slogLevel: TraceLevel, }, { - name: "debug", - logrusLevel: logrus.DebugLevel, - slogLevel: slog.LevelDebug, + name: "debug", + slogLevel: slog.LevelDebug, }, { - name: "info", - logrusLevel: logrus.InfoLevel, - slogLevel: slog.LevelInfo, + name: "info", + slogLevel: slog.LevelInfo, }, { - name: "warn", - logrusLevel: logrus.WarnLevel, - slogLevel: slog.LevelWarn, + name: "warn", + slogLevel: slog.LevelWarn, }, { - name: "error", - logrusLevel: logrus.ErrorLevel, - slogLevel: slog.LevelError, + name: "error", + slogLevel: slog.LevelError, }, { - name: "fatal", - logrusLevel: logrus.FatalLevel, - slogLevel: slog.LevelError + 1, + name: "fatal", + slogLevel: slog.LevelError + 1, + }, + } + + expectedFields := map[string]any{ + "trace.fields": map[string]any{ + "teleportUser": "user", + "id": float64(1234), + "local": addr.String(), + "login": "llama", + "remote": addr.String(), }, + "test": float64(123), + "animal": `llama`, + "error": logErr.Error(), + "diag_addr": addr.String(), + "component": "test", + "message": message, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - // Create a logrus logger using the custom formatter which outputs to a local buffer. - var logrusOut bytes.Buffer - formatter := &JSONFormatter{ - ExtraFields: nil, - callerEnabled: true, - } - require.NoError(t, formatter.CheckAndSetDefaults()) - - logrusLogger := logrus.New() - logrusLogger.SetFormatter(formatter) - logrusLogger.SetOutput(&logrusOut) - logrusLogger.ReplaceHooks(logrus.LevelHooks{}) - logrusLogger.SetLevel(test.logrusLevel) - entry := logrusLogger.WithField(teleport.ComponentKey, "test") - // Create a slog logger using the custom formatter which outputs to a local buffer. var slogOutput bytes.Buffer slogLogger := slog.New(NewSlogJSONHandler(&slogOutput, SlogJSONHandlerConfig{Level: test.slogLevel})).With(teleport.ComponentKey, "test") - // Add some fields and output the message at the desired log level via logrus. - l := entry.WithField("test", 123).WithField("animal", "llama").WithField("error", trace.Wrap(logErr)) - logrusTestLogLineNumber := func() int { - l.WithField("diag_addr", addr.String()).Log(test.logrusLevel, message) - return getCallerLineNumber() - 1 // Get the line number of this call, and assume the log call is right above it - }() - // Add some fields and output the message at the desired log level via slog. l2 := slogLogger.With("test", 123).With("animal", "llama").With("error", trace.Wrap(logErr)) slogTestLogLineNumber := func() int { - l2.Log(context.Background(), test.slogLevel, message, "diag_addr", &addr) + l2.With(teleport.ComponentFields, fields).Log(context.Background(), test.slogLevel, message, "diag_addr", &addr) return getCallerLineNumber() - 1 // Get the line number of this call, and assume the log call is right above it }() - // The order of the fields emitted by the two loggers is different, so comparing the output directly - // for equality won't work. Instead, a map is built with all the key value pairs, excluding the caller - // and that map is compared to ensure all items are present and match. - var logrusData map[string]any - require.NoError(t, json.Unmarshal(logrusOut.Bytes(), &logrusData), "invalid logrus output format") - var slogData map[string]any require.NoError(t, json.Unmarshal(slogOutput.Bytes(), &slogData), "invalid slog output format") - logrusCaller, ok := logrusData["caller"].(string) - delete(logrusData, "caller") - assert.True(t, ok, "caller was missing from logrus output") - assert.Equal(t, fmt.Sprintf("log/formatter_test.go:%d", logrusTestLogLineNumber), logrusCaller) - slogCaller, ok := slogData["caller"].(string) delete(slogData, "caller") assert.True(t, ok, "caller was missing from slog output") assert.Equal(t, fmt.Sprintf("log/formatter_test.go:%d", slogTestLogLineNumber), slogCaller) - logrusTimestamp, ok := logrusData["timestamp"].(string) - delete(logrusData, "timestamp") - assert.True(t, ok, "time was missing from logrus output") + slogLevel, ok := slogData["level"].(string) + delete(slogData, "level") + assert.True(t, ok, "level was missing from slog output") + var expectedLevel string + switch test.slogLevel { + case TraceLevel: + expectedLevel = "trace" + case slog.LevelWarn: + expectedLevel = "warning" + case slog.LevelError + 1: + expectedLevel = "fatal" + default: + expectedLevel = test.slogLevel.String() + } + assert.Equal(t, strings.ToLower(expectedLevel), slogLevel) slogTimestamp, ok := slogData["timestamp"].(string) delete(slogData, "timestamp") assert.True(t, ok, "time was missing from slog output") - logrusTime, err := time.Parse(time.RFC3339, logrusTimestamp) - assert.NoError(t, err, "invalid logrus timestamp %s", logrusTimestamp) - slogTime, err := time.Parse(time.RFC3339, slogTimestamp) assert.NoError(t, err, "invalid slog timestamp %s", slogTimestamp) - assert.InDelta(t, logrusTime.Unix(), slogTime.Unix(), 10) + assert.InDelta(t, clock.Now().Unix(), slogTime.Unix(), 10) require.Empty(t, cmp.Diff( - logrusData, + expectedFields, slogData, cmpopts.SortMaps(func(a, b string) bool { return a < b }), ), @@ -347,38 +315,6 @@ func getCallerLineNumber() int { func BenchmarkFormatter(b *testing.B) { ctx := context.Background() b.ReportAllocs() - b.Run("logrus", func(b *testing.B) { - b.Run("text", func(b *testing.B) { - formatter := NewDefaultTextFormatter(true) - require.NoError(b, formatter.CheckAndSetDefaults()) - logger := logrus.New() - logger.SetFormatter(formatter) - logger.SetOutput(io.Discard) - b.ResetTimer() - - entry := logger.WithField(teleport.ComponentKey, "test") - for i := 0; i < b.N; i++ { - l := entry.WithField("test", 123).WithField("animal", "llama\n").WithField("error", logErr) - l.WithField("diag_addr", &addr).WithField(teleport.ComponentFields, fields).Info(message) - } - }) - - b.Run("json", func(b *testing.B) { - formatter := &JSONFormatter{} - require.NoError(b, formatter.CheckAndSetDefaults()) - logger := logrus.New() - logger.SetFormatter(formatter) - logger.SetOutput(io.Discard) - logger.ReplaceHooks(logrus.LevelHooks{}) - b.ResetTimer() - - entry := logger.WithField(teleport.ComponentKey, "test") - for i := 0; i < b.N; i++ { - l := entry.WithField("test", 123).WithField("animal", "llama\n").WithField("error", logErr) - l.WithField("diag_addr", &addr).WithField(teleport.ComponentFields, fields).Info(message) - } - }) - }) b.Run("slog", func(b *testing.B) { b.Run("default_text", func(b *testing.B) { @@ -430,47 +366,26 @@ func BenchmarkFormatter(b *testing.B) { } func TestConcurrentOutput(t *testing.T) { - t.Run("logrus", func(t *testing.T) { - debugFormatter := NewDefaultTextFormatter(true) - require.NoError(t, debugFormatter.CheckAndSetDefaults()) - logrus.SetFormatter(debugFormatter) - logrus.SetOutput(os.Stdout) - - logger := logrus.WithField(teleport.ComponentKey, "test") - - var wg sync.WaitGroup - for i := 0; i < 1000; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - logger.Infof("Detected Teleport component %d is running in a degraded state.", i) - }(i) - } - wg.Wait() - }) + logger := slog.New(NewSlogTextHandler(os.Stdout, SlogTextHandlerConfig{ + EnableColors: true, + })).With(teleport.ComponentKey, "test") - t.Run("slog", func(t *testing.T) { - logger := slog.New(NewSlogTextHandler(os.Stdout, SlogTextHandlerConfig{ - EnableColors: true, - })).With(teleport.ComponentKey, "test") - - var wg sync.WaitGroup - ctx := context.Background() - for i := 0; i < 1000; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - logger.InfoContext(ctx, "Teleport component entered degraded state", - slog.Int("component", i), - slog.Group("group", - slog.String("test", "123"), - slog.String("animal", "llama"), - ), - ) - }(i) - } - wg.Wait() - }) + var wg sync.WaitGroup + ctx := context.Background() + for i := 0; i < 1000; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + logger.InfoContext(ctx, "Teleport component entered degraded state", + slog.Int("component", i), + slog.Group("group", + slog.String("test", "123"), + slog.String("animal", "llama"), + ), + ) + }(i) + } + wg.Wait() } // allPossibleSubsets returns all combinations of subsets for the @@ -493,58 +408,34 @@ func allPossibleSubsets(in []string) [][]string { return subsets } -// TestExtraFields validates that the output is identical for the -// logrus formatter and slog handler based on the configured extra -// fields. +// TestExtraFields validates that the output is expected for the +// slog handler based on the configured extra fields. func TestExtraFields(t *testing.T) { // Capture a fake time that all output will use. now := clockwork.NewFakeClock().Now() // Capture the caller information to be injected into all messages. pc, _, _, _ := runtime.Caller(0) - fs := runtime.CallersFrames([]uintptr{pc}) - f, _ := fs.Next() - callerTrace := &trace.Trace{ - Func: f.Function, - Path: f.File, - Line: f.Line, - } const message = "testing 123" - // Test against every possible configured combination of allowed format fields. - fields := allPossibleSubsets(defaultFormatFields) - t.Run("text", func(t *testing.T) { - for _, configuredFields := range fields { + // Test against every possible configured combination of allowed format fields. + for _, configuredFields := range allPossibleSubsets(defaultFormatFields) { name := "not configured" if len(configuredFields) > 0 { name = strings.Join(configuredFields, " ") } t.Run(name, func(t *testing.T) { - logrusFormatter := TextFormatter{ - ExtraFields: configuredFields, - } - // Call CheckAndSetDefaults to exercise the extra fields logic. Since - // FormatCaller is always overridden within CheckAndSetDefaults, it is - // explicitly set afterward so the caller points to our fake call site. - require.NoError(t, logrusFormatter.CheckAndSetDefaults()) - logrusFormatter.FormatCaller = callerTrace.String - - var slogOutput bytes.Buffer - var slogHandler slog.Handler = NewSlogTextHandler(&slogOutput, SlogTextHandlerConfig{ConfiguredFields: configuredFields}) - - entry := &logrus.Entry{ - Data: logrus.Fields{"animal": "llama", "vegetable": "carrot", teleport.ComponentKey: "test"}, - Time: now, - Level: logrus.DebugLevel, - Caller: &f, - Message: message, - } - - logrusOut, err := logrusFormatter.Format(entry) - require.NoError(t, err) + replaced := map[string]struct{}{} + var slogHandler slog.Handler = NewSlogTextHandler(io.Discard, SlogTextHandlerConfig{ + ConfiguredFields: configuredFields, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + replaced[a.Key] = struct{}{} + return a + }, + }) record := slog.Record{ Time: now, @@ -557,42 +448,29 @@ func TestExtraFields(t *testing.T) { require.NoError(t, slogHandler.Handle(context.Background(), record)) - require.Equal(t, string(logrusOut), slogOutput.String()) + for k := range replaced { + delete(replaced, k) + } + + require.Empty(t, replaced, replaced) }) } }) t.Run("json", func(t *testing.T) { - for _, configuredFields := range fields { + // Test against every possible configured combination of allowed format fields. + // Note, the json handler limits the allowed fields to a subset of those allowed + // by the text handler. + for _, configuredFields := range allPossibleSubsets([]string{CallerField, ComponentField, TimestampField}) { name := "not configured" if len(configuredFields) > 0 { name = strings.Join(configuredFields, " ") } t.Run(name, func(t *testing.T) { - logrusFormatter := JSONFormatter{ - ExtraFields: configuredFields, - } - // Call CheckAndSetDefaults to exercise the extra fields logic. Since - // FormatCaller is always overridden within CheckAndSetDefaults, it is - // explicitly set afterward so the caller points to our fake call site. - require.NoError(t, logrusFormatter.CheckAndSetDefaults()) - logrusFormatter.FormatCaller = callerTrace.String - var slogOutput bytes.Buffer var slogHandler slog.Handler = NewSlogJSONHandler(&slogOutput, SlogJSONHandlerConfig{ConfiguredFields: configuredFields}) - entry := &logrus.Entry{ - Data: logrus.Fields{"animal": "llama", "vegetable": "carrot", teleport.ComponentKey: "test"}, - Time: now, - Level: logrus.DebugLevel, - Caller: &f, - Message: message, - } - - logrusOut, err := logrusFormatter.Format(entry) - require.NoError(t, err) - record := slog.Record{ Time: now, Message: message, @@ -604,11 +482,31 @@ func TestExtraFields(t *testing.T) { require.NoError(t, slogHandler.Handle(context.Background(), record)) - var slogData, logrusData map[string]any - require.NoError(t, json.Unmarshal(logrusOut, &logrusData)) + var slogData map[string]any require.NoError(t, json.Unmarshal(slogOutput.Bytes(), &slogData)) - require.Equal(t, slogData, logrusData) + delete(slogData, "animal") + delete(slogData, "vegetable") + delete(slogData, "message") + delete(slogData, "level") + + var expectedLen int + expectedFields := configuredFields + switch l := len(configuredFields); l { + case 0: + // The level field was removed above, but is included in the default fields + expectedLen = len(defaultFormatFields) - 1 + expectedFields = defaultFormatFields + default: + expectedLen = l + } + require.Len(t, slogData, expectedLen, slogData) + + for _, f := range expectedFields { + delete(slogData, f) + } + + require.Empty(t, slogData, slogData) }) } }) diff --git a/lib/utils/log/log.go b/lib/utils/log/log.go index 2f16b902e3df6..d8aadb75146bf 100644 --- a/lib/utils/log/log.go +++ b/lib/utils/log/log.go @@ -42,6 +42,8 @@ type Config struct { ExtraFields []string // EnableColors dictates if output should be colored. EnableColors bool + // Padding to use for various components. + Padding int } // Initialize configures the default global logger based on the @@ -112,6 +114,7 @@ func Initialize(loggerConfig Config) (*slog.Logger, *slog.LevelVar, error) { Level: level, EnableColors: loggerConfig.EnableColors, ConfiguredFields: configuredFields, + Padding: loggerConfig.Padding, })) slog.SetDefault(logger) case "json": diff --git a/lib/utils/log/logrus_formatter.go b/lib/utils/log/logrus_formatter.go deleted file mode 100644 index 14ad8441da7cc..0000000000000 --- a/lib/utils/log/logrus_formatter.go +++ /dev/null @@ -1,427 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package log - -import ( - "fmt" - "regexp" - "runtime" - "slices" - "strconv" - "strings" - - "github.com/gravitational/trace" - "github.com/sirupsen/logrus" - - "github.com/gravitational/teleport" -) - -// TextFormatter is a [logrus.Formatter] that outputs messages in -// a textual format. -type TextFormatter struct { - // ComponentPadding is a padding to pick when displaying - // and formatting component field, defaults to DefaultComponentPadding - ComponentPadding int - // EnableColors enables colored output - EnableColors bool - // FormatCaller is a function to return (part) of source file path for output. - // Defaults to filePathAndLine() if unspecified - FormatCaller func() (caller string) - // ExtraFields represent the extra fields that will be added to the log message - ExtraFields []string - // TimestampEnabled specifies if timestamp is enabled in logs - timestampEnabled bool - // CallerEnabled specifies if caller is enabled in logs - callerEnabled bool -} - -type writer struct { - b *buffer -} - -func newWriter() *writer { - return &writer{b: &buffer{}} -} - -func (w *writer) Len() int { - return len(*w.b) -} - -func (w *writer) WriteString(s string) (int, error) { - return w.b.WriteString(s) -} - -func (w *writer) WriteByte(c byte) error { - return w.b.WriteByte(c) -} - -func (w *writer) Bytes() []byte { - return *w.b -} - -// NewDefaultTextFormatter creates a TextFormatter with -// the default options set. -func NewDefaultTextFormatter(enableColors bool) *TextFormatter { - return &TextFormatter{ - ComponentPadding: defaultComponentPadding, - FormatCaller: formatCallerWithPathAndLine, - ExtraFields: defaultFormatFields, - EnableColors: enableColors, - callerEnabled: true, - timestampEnabled: false, - } -} - -// CheckAndSetDefaults checks and sets log format configuration. -func (tf *TextFormatter) CheckAndSetDefaults() error { - // set padding - if tf.ComponentPadding == 0 { - tf.ComponentPadding = defaultComponentPadding - } - // set caller - tf.FormatCaller = formatCallerWithPathAndLine - - // set log formatting - if tf.ExtraFields == nil { - tf.timestampEnabled = true - tf.callerEnabled = true - tf.ExtraFields = defaultFormatFields - return nil - } - - if slices.Contains(tf.ExtraFields, TimestampField) { - tf.timestampEnabled = true - } - - if slices.Contains(tf.ExtraFields, CallerField) { - tf.callerEnabled = true - } - - return nil -} - -// Format formats each log line as configured in teleport config file. -func (tf *TextFormatter) Format(e *logrus.Entry) ([]byte, error) { - caller := tf.FormatCaller() - w := newWriter() - - // write timestamp first if enabled - if tf.timestampEnabled { - *w.b = appendRFC3339Millis(*w.b, e.Time.Round(0)) - } - - for _, field := range tf.ExtraFields { - switch field { - case LevelField: - var color int - var level string - switch e.Level { - case logrus.TraceLevel: - level = "TRACE" - color = gray - case logrus.DebugLevel: - level = "DEBUG" - color = gray - case logrus.InfoLevel: - level = "INFO" - color = blue - case logrus.WarnLevel: - level = "WARN" - color = yellow - case logrus.ErrorLevel: - level = "ERROR" - color = red - case logrus.FatalLevel: - level = "FATAL" - color = red - default: - color = blue - level = strings.ToUpper(e.Level.String()) - } - - if !tf.EnableColors { - color = noColor - } - - w.writeField(padMax(level, defaultLevelPadding), color) - case ComponentField: - padding := defaultComponentPadding - if tf.ComponentPadding != 0 { - padding = tf.ComponentPadding - } - if w.Len() > 0 { - w.WriteByte(' ') - } - component, ok := e.Data[teleport.ComponentKey].(string) - if ok && component != "" { - component = fmt.Sprintf("[%v]", component) - } - component = strings.ToUpper(padMax(component, padding)) - if component[len(component)-1] != ' ' { - component = component[:len(component)-1] + "]" - } - - w.WriteString(component) - default: - if _, ok := knownFormatFields[field]; !ok { - return nil, trace.BadParameter("invalid log format key: %v", field) - } - } - } - - // always use message - if e.Message != "" { - w.writeField(e.Message, noColor) - } - - if len(e.Data) > 0 { - w.writeMap(e.Data) - } - - // write caller last if enabled - if tf.callerEnabled && caller != "" { - w.writeField(caller, noColor) - } - - w.WriteByte('\n') - return w.Bytes(), nil -} - -// JSONFormatter implements the [logrus.Formatter] interface and adds extra -// fields to log entries. -type JSONFormatter struct { - logrus.JSONFormatter - - ExtraFields []string - // FormatCaller is a function to return (part) of source file path for output. - // Defaults to filePathAndLine() if unspecified - FormatCaller func() (caller string) - - callerEnabled bool - componentEnabled bool -} - -// CheckAndSetDefaults checks and sets log format configuration. -func (j *JSONFormatter) CheckAndSetDefaults() error { - // set log formatting - if j.ExtraFields == nil { - j.ExtraFields = defaultFormatFields - } - // set caller - j.FormatCaller = formatCallerWithPathAndLine - - if slices.Contains(j.ExtraFields, CallerField) { - j.callerEnabled = true - } - - if slices.Contains(j.ExtraFields, ComponentField) { - j.componentEnabled = true - } - - // rename default fields - j.JSONFormatter = logrus.JSONFormatter{ - FieldMap: logrus.FieldMap{ - logrus.FieldKeyTime: TimestampField, - logrus.FieldKeyLevel: LevelField, - logrus.FieldKeyMsg: messageField, - }, - DisableTimestamp: !slices.Contains(j.ExtraFields, TimestampField), - } - - return nil -} - -// Format formats each log line as configured in teleport config file. -func (j *JSONFormatter) Format(e *logrus.Entry) ([]byte, error) { - if j.callerEnabled { - path := j.FormatCaller() - e.Data[CallerField] = path - } - - if j.componentEnabled { - e.Data[ComponentField] = e.Data[teleport.ComponentKey] - } - - delete(e.Data, teleport.ComponentKey) - - return j.JSONFormatter.Format(e) -} - -// NewTestJSONFormatter creates a JSONFormatter that is -// configured for output in tests. -func NewTestJSONFormatter() *JSONFormatter { - formatter := &JSONFormatter{} - if err := formatter.CheckAndSetDefaults(); err != nil { - panic(err) - } - return formatter -} - -func (w *writer) writeError(value interface{}) { - switch err := value.(type) { - case trace.Error: - *w.b = fmt.Appendf(*w.b, "[%v]", err.DebugReport()) - default: - *w.b = fmt.Appendf(*w.b, "[%v]", value) - } -} - -func (w *writer) writeField(value interface{}, color int) { - if w.Len() > 0 { - w.WriteByte(' ') - } - w.writeValue(value, color) -} - -func (w *writer) writeKeyValue(key string, value interface{}) { - if w.Len() > 0 { - w.WriteByte(' ') - } - w.WriteString(key) - w.WriteByte(':') - if key == logrus.ErrorKey { - w.writeError(value) - return - } - w.writeValue(value, noColor) -} - -func (w *writer) writeValue(value interface{}, color int) { - if s, ok := value.(string); ok { - if color != noColor { - *w.b = fmt.Appendf(*w.b, "\u001B[%dm", color) - } - - if needsQuoting(s) { - *w.b = strconv.AppendQuote(*w.b, s) - } else { - *w.b = fmt.Append(*w.b, s) - } - - if color != noColor { - *w.b = fmt.Append(*w.b, "\u001B[0m") - } - return - } - - if color != noColor { - *w.b = fmt.Appendf(*w.b, "\x1b[%dm%v\x1b[0m", color, value) - return - } - - *w.b = fmt.Appendf(*w.b, "%v", value) -} - -func (w *writer) writeMap(m map[string]any) { - if len(m) == 0 { - return - } - keys := make([]string, 0, len(m)) - for key := range m { - keys = append(keys, key) - } - slices.Sort(keys) - for _, key := range keys { - if key == teleport.ComponentKey { - continue - } - switch value := m[key].(type) { - case map[string]any: - w.writeMap(value) - case logrus.Fields: - w.writeMap(value) - default: - w.writeKeyValue(key, value) - } - } -} - -type frameCursor struct { - // current specifies the current stack frame. - // if omitted, rest contains the complete stack - current *runtime.Frame - // rest specifies the rest of stack frames to explore - rest *runtime.Frames - // n specifies the total number of stack frames - n int -} - -// formatCallerWithPathAndLine formats the caller in the form path/segment: -// for output in the log -func formatCallerWithPathAndLine() (path string) { - if cursor := findFrame(); cursor != nil { - t := newTraceFromFrames(*cursor, nil) - return t.Loc() - } - return "" -} - -var frameIgnorePattern = regexp.MustCompile(`github\.com/sirupsen/logrus`) - -// findFrames positions the stack pointer to the first -// function that does not match the frameIngorePattern -// and returns the rest of the stack frames -func findFrame() *frameCursor { - var buf [32]uintptr - // Skip enough frames to start at user code. - // This number is a mere hint to the following loop - // to start as close to user code as possible and getting it right is not mandatory. - // The skip count might need to get updated if the call to findFrame is - // moved up/down the call stack - n := runtime.Callers(4, buf[:]) - pcs := buf[:n] - frames := runtime.CallersFrames(pcs) - for i := 0; i < n; i++ { - frame, _ := frames.Next() - if !frameIgnorePattern.MatchString(frame.Function) { - return &frameCursor{ - current: &frame, - rest: frames, - n: n, - } - } - } - return nil -} - -func newTraceFromFrames(cursor frameCursor, err error) *trace.TraceErr { - traces := make(trace.Traces, 0, cursor.n) - if cursor.current != nil { - traces = append(traces, frameToTrace(*cursor.current)) - } - for { - frame, more := cursor.rest.Next() - traces = append(traces, frameToTrace(frame)) - if !more { - break - } - } - return &trace.TraceErr{ - Err: err, - Traces: traces, - } -} - -func frameToTrace(frame runtime.Frame) trace.Trace { - return trace.Trace{ - Func: frame.Function, - Path: frame.File, - Line: frame.Line, - } -} diff --git a/lib/utils/log/slog.go b/lib/utils/log/slog.go index 46f0e13627b3e..bfb34f4a94114 100644 --- a/lib/utils/log/slog.go +++ b/lib/utils/log/slog.go @@ -27,7 +27,6 @@ import ( "unicode" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" oteltrace "go.opentelemetry.io/otel/trace" ) @@ -68,25 +67,6 @@ var SupportedLevelsText = []string{ slog.LevelError.String(), } -// SlogLevelToLogrusLevel converts a [slog.Level] to its equivalent -// [logrus.Level]. -func SlogLevelToLogrusLevel(level slog.Level) logrus.Level { - switch level { - case TraceLevel: - return logrus.TraceLevel - case slog.LevelDebug: - return logrus.DebugLevel - case slog.LevelInfo: - return logrus.InfoLevel - case slog.LevelWarn: - return logrus.WarnLevel - case slog.LevelError: - return logrus.ErrorLevel - default: - return logrus.FatalLevel - } -} - // DiscardHandler is a [slog.Handler] that discards all messages. It // is more efficient than a [slog.Handler] which outputs to [io.Discard] since // it performs zero formatting. diff --git a/lib/utils/log/slog_text_handler.go b/lib/utils/log/slog_text_handler.go index 7f93a388977bb..612615ba8582d 100644 --- a/lib/utils/log/slog_text_handler.go +++ b/lib/utils/log/slog_text_handler.go @@ -150,45 +150,12 @@ func (s *SlogTextHandler) Handle(ctx context.Context, r slog.Record) error { // Processing fields in this manner allows users to // configure the level and component position in the output. - // This matches the behavior of the original logrus. All other + // This matches the behavior of the original logrus formatter. All other // fields location in the output message are static. for _, field := range s.cfg.ConfiguredFields { switch field { case LevelField: - var color int - var level string - switch r.Level { - case TraceLevel: - level = "TRACE" - color = gray - case slog.LevelDebug: - level = "DEBUG" - color = gray - case slog.LevelInfo: - level = "INFO" - color = blue - case slog.LevelWarn: - level = "WARN" - color = yellow - case slog.LevelError: - level = "ERROR" - color = red - case slog.LevelError + 1: - level = "FATAL" - color = red - default: - color = blue - level = r.Level.String() - } - - if !s.cfg.EnableColors { - color = noColor - } - - level = padMax(level, defaultLevelPadding) - if color != noColor { - level = fmt.Sprintf("\u001B[%dm%s\u001B[0m", color, level) - } + level := formatLevel(r.Level, s.cfg.EnableColors) if rep == nil { state.appendKey(slog.LevelKey) @@ -211,12 +178,8 @@ func (s *SlogTextHandler) Handle(ctx context.Context, r slog.Record) error { if attr.Key != teleport.ComponentKey { return true } - component = fmt.Sprintf("[%v]", attr.Value) - component = strings.ToUpper(padMax(component, s.cfg.Padding)) - if component[len(component)-1] != ' ' { - component = component[:len(component)-1] + "]" - } + component = formatComponent(attr.Value, s.cfg.Padding) return false }) @@ -271,6 +234,55 @@ func (s *SlogTextHandler) Handle(ctx context.Context, r slog.Record) error { return err } +func formatLevel(value slog.Level, enableColors bool) string { + var color int + var level string + switch value { + case TraceLevel: + level = "TRACE" + color = gray + case slog.LevelDebug: + level = "DEBUG" + color = gray + case slog.LevelInfo: + level = "INFO" + color = blue + case slog.LevelWarn: + level = "WARN" + color = yellow + case slog.LevelError: + level = "ERROR" + color = red + case slog.LevelError + 1: + level = "FATAL" + color = red + default: + color = blue + level = value.String() + } + + if !enableColors { + color = noColor + } + + level = padMax(level, defaultLevelPadding) + if color != noColor { + level = fmt.Sprintf("\u001B[%dm%s\u001B[0m", color, level) + } + + return level +} + +func formatComponent(value slog.Value, padding int) string { + component := fmt.Sprintf("[%v]", value) + component = strings.ToUpper(padMax(component, padding)) + if component[len(component)-1] != ' ' { + component = component[:len(component)-1] + "]" + } + + return component +} + func (s *SlogTextHandler) clone() *SlogTextHandler { // We can't use assignment because we can't copy the mutex. return &SlogTextHandler{ diff --git a/lib/utils/log/writer.go b/lib/utils/log/writer.go deleted file mode 100644 index 77cf3037a8b66..0000000000000 --- a/lib/utils/log/writer.go +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package log - -import ( - "io" - "sync" -) - -// SharedWriter is an [io.Writer] implementation that protects -// writes with a mutex. This allows a single [io.Writer] to be shared -// by both logrus and slog without their output clobbering each other. -type SharedWriter struct { - mu sync.Mutex - io.Writer -} - -func (s *SharedWriter) Write(p []byte) (int, error) { - s.mu.Lock() - defer s.mu.Unlock() - - return s.Writer.Write(p) -} - -// NewSharedWriter wraps the provided [io.Writer] in a writer that -// is thread safe. -func NewSharedWriter(w io.Writer) *SharedWriter { - return &SharedWriter{Writer: w} -} diff --git a/lib/web/join_tokens.go b/lib/web/join_tokens.go index df9896f5e1532..d54269df7c381 100644 --- a/lib/web/join_tokens.go +++ b/lib/web/join_tokens.go @@ -631,6 +631,7 @@ func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter } var buf bytes.Buffer + var appServerResourceLabels []string // If app install mode is requested but parameters are blank for some reason, // we need to return an error. if settings.appInstallMode { @@ -640,6 +641,12 @@ func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter if !appURIPattern.MatchString(settings.appURI) { return "", trace.BadParameter("appURI %q contains invalid characters", settings.appURI) } + + suggestedLabels := token.GetSuggestedLabels() + appServerResourceLabels, err = scripts.MarshalLabelsYAML(suggestedLabels, 4) + if err != nil { + return "", trace.Wrap(err) + } } if settings.discoveryInstallMode { @@ -689,6 +696,7 @@ func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter "installUpdater": strconv.FormatBool(settings.installUpdater), "version": shsprintf.EscapeDefaultContext(version), "appInstallMode": strconv.FormatBool(settings.appInstallMode), + "appServerResourceLabels": appServerResourceLabels, "appName": shsprintf.EscapeDefaultContext(settings.appName), "appURI": shsprintf.EscapeDefaultContext(settings.appURI), "joinMethod": shsprintf.EscapeDefaultContext(settings.joinMethod), diff --git a/lib/web/join_tokens_test.go b/lib/web/join_tokens_test.go index ba0b0be4ff9b1..4e0062b333ef3 100644 --- a/lib/web/join_tokens_test.go +++ b/lib/web/join_tokens_test.go @@ -761,6 +761,17 @@ func TestGetNodeJoinScript(t *testing.T) { require.Contains(t, script, fmt.Sprintf("%s=%s", types.InternalResourceIDLabel, internalResourceID)) }, }, + { + desc: "app server labels", + settings: scriptSettings{token: validToken, appInstallMode: true, appName: "app-name", appURI: "app-uri"}, + errAssert: require.NoError, + extraAssertions: func(script string) { + require.Contains(t, script, `APP_NAME='app-name'`) + require.Contains(t, script, `APP_URI='app-uri'`) + require.Contains(t, script, `public_addr`) + require.Contains(t, script, fmt.Sprintf(" labels:\n %s: %s", types.InternalResourceIDLabel, internalResourceID)) + }, + }, } { t.Run(test.desc, func(t *testing.T) { script, err := getJoinScript(context.Background(), test.settings, m) diff --git a/lib/web/scripts/node-join/install.sh b/lib/web/scripts/node-join/install.sh index 3d8403c00787d..64c7cc6b6aab2 100755 --- a/lib/web/scripts/node-join/install.sh +++ b/lib/web/scripts/node-join/install.sh @@ -441,6 +441,11 @@ get_yaml_list() { install_teleport_app_config() { log "Writing Teleport app service config to ${TELEPORT_CONFIG_PATH}" CA_PINS_CONFIG=$(get_yaml_list "ca_pin" "${CA_PIN_HASHES}" " ") + # This file is processed by `shellschek` as part of the lint step + # It detects an issue because of un-set variables - $index and $line. This check is called SC2154. + # However, that's not an issue, because those variables are replaced when we run go's text/template engine over it. + # When executing the script, those are no long variables but actual values. + # shellcheck disable=SC2154 cat << EOF > ${TELEPORT_CONFIG_PATH} version: v3 teleport: @@ -463,6 +468,9 @@ app_service: - name: "${APP_NAME}" uri: "${APP_URI}" public_addr: ${APP_PUBLIC_ADDR} + labels:{{range $index, $line := .appServerResourceLabels}} + {{$line -}} +{{end}} EOF } # installs the provided teleport config (for database service) diff --git a/lib/web/usertasks_test.go b/lib/web/usertasks_test.go index 0bb2dbb9a9f9a..13e9723458090 100644 --- a/lib/web/usertasks_test.go +++ b/lib/web/usertasks_test.go @@ -31,6 +31,7 @@ import ( usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/usertasks" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/web/ui" ) @@ -53,6 +54,8 @@ func TestUserTask(t *testing.T) { }) require.NoError(t, err) pack := env.proxies[0].authPack(t, userWithRW, []types.Role{roleRWUserTask}) + adminClient, err := env.server.NewClient(auth.TestAdmin()) + require.NoError(t, err) getAllEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "usertask") singleItemEndpoint := func(name string) string { @@ -90,7 +93,7 @@ func TestUserTask(t *testing.T) { }) require.NoError(t, err) - _, err = env.proxies[0].auth.Auth().CreateUserTask(ctx, userTask) + _, err = adminClient.UserTasksServiceClient().CreateUserTask(ctx, userTask) require.NoError(t, err) userTaskForTest = userTask }