From a26c2a94e85d7a1c53d80e3999f7274392e1fec8 Mon Sep 17 00:00:00 2001 From: Forrest <30576607+fspmarshall@users.noreply.github.com> Date: Thu, 9 Jan 2025 14:30:37 -0800 Subject: [PATCH 1/8] add ssh identity object (#50787) --- lib/auth/auth.go | 68 ++-- lib/auth/keygen/keygen.go | 176 +++------- lib/auth/keygen/keygen_test.go | 34 +- lib/auth/test/suite.go | 158 +++++---- lib/auth/testauthority/testauthority.go | 3 +- lib/client/client_store_test.go | 23 +- lib/client/cluster_client_test.go | 10 +- lib/client/identityfile/identity_test.go | 10 +- lib/client/keyagent_test.go | 23 +- lib/reversetunnel/srv_test.go | 13 +- lib/services/authority.go | 99 ------ lib/srv/authhandlers_test.go | 30 +- lib/sshca/identity.go | 392 +++++++++++++++++++++++ lib/sshca/identity_test.go | 97 ++++++ lib/sshca/sshca.go | 37 ++- 15 files changed, 770 insertions(+), 403 deletions(-) create mode 100644 lib/sshca/identity.go create mode 100644 lib/sshca/identity_test.go diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 82bd49e68befb..aef1a77ed2564 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -3241,39 +3241,41 @@ func generateCert(ctx context.Context, a *Server, req certRequest, caType types. return nil, trace.Wrap(err) } - params := services.UserCertParams{ - CASigner: sshSigner, - PublicUserKey: req.sshPublicKey, - Username: req.user.GetName(), - Impersonator: req.impersonator, - AllowedLogins: allowedLogins, - TTL: sessionTTL, - Roles: req.checker.RoleNames(), - CertificateFormat: certificateFormat, - PermitPortForwarding: req.checker.CanPortForward(), - PermitAgentForwarding: req.checker.CanForwardAgents(), - PermitX11Forwarding: req.checker.PermitX11Forwarding(), - RouteToCluster: req.routeToCluster, - Traits: req.traits, - ActiveRequests: req.activeRequests, - MFAVerified: req.mfaVerified, - PreviousIdentityExpires: req.previousIdentityExpires, - LoginIP: req.loginIP, - PinnedIP: pinnedIP, - DisallowReissue: req.disallowReissue, - Renewable: req.renewable, - Generation: req.generation, - BotName: req.botName, - BotInstanceID: req.botInstanceID, - CertificateExtensions: req.checker.CertificateExtensions(), - AllowedResourceIDs: requestedResourcesStr, - ConnectionDiagnosticID: req.connectionDiagnosticID, - PrivateKeyPolicy: attestedKeyPolicy, - DeviceID: req.deviceExtensions.DeviceID, - DeviceAssetTag: req.deviceExtensions.AssetTag, - DeviceCredentialID: req.deviceExtensions.CredentialID, - GitHubUserID: githubUserID, - GitHubUsername: githubUsername, + params := sshca.UserCertificateRequest{ + CASigner: sshSigner, + PublicUserKey: req.sshPublicKey, + TTL: sessionTTL, + CertificateFormat: certificateFormat, + Identity: sshca.Identity{ + Username: req.user.GetName(), + Impersonator: req.impersonator, + AllowedLogins: allowedLogins, + Roles: req.checker.RoleNames(), + PermitPortForwarding: req.checker.CanPortForward(), + PermitAgentForwarding: req.checker.CanForwardAgents(), + PermitX11Forwarding: req.checker.PermitX11Forwarding(), + RouteToCluster: req.routeToCluster, + Traits: req.traits, + ActiveRequests: req.activeRequests, + MFAVerified: req.mfaVerified, + PreviousIdentityExpires: req.previousIdentityExpires, + LoginIP: req.loginIP, + PinnedIP: pinnedIP, + DisallowReissue: req.disallowReissue, + Renewable: req.renewable, + Generation: req.generation, + BotName: req.botName, + BotInstanceID: req.botInstanceID, + CertificateExtensions: req.checker.CertificateExtensions(), + AllowedResourceIDs: requestedResourcesStr, + ConnectionDiagnosticID: req.connectionDiagnosticID, + PrivateKeyPolicy: attestedKeyPolicy, + DeviceID: req.deviceExtensions.DeviceID, + DeviceAssetTag: req.deviceExtensions.AssetTag, + DeviceCredentialID: req.deviceExtensions.CredentialID, + GitHubUserID: githubUserID, + GitHubUsername: githubUsername, + }, } signedSSHCert, err = a.GenerateUserCert(params) if err != nil { diff --git a/lib/auth/keygen/keygen.go b/lib/auth/keygen/keygen.go index cd6bb0acb28ee..5f47b3a90ac16 100644 --- a/lib/auth/keygen/keygen.go +++ b/lib/auth/keygen/keygen.go @@ -23,7 +23,6 @@ import ( "crypto/rand" "fmt" "log/slog" - "strings" "time" "github.com/gravitational/trace" @@ -31,12 +30,11 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/types/wrappers" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/utils" ) @@ -129,164 +127,70 @@ func (k *Keygen) GenerateHostCertWithoutValidation(c services.HostCertParams) ([ // GenerateUserCert generates a user ssh certificate with the passed in parameters. // The private key of the CA to sign the certificate must be provided. -func (k *Keygen) GenerateUserCert(c services.UserCertParams) ([]byte, error) { - if err := c.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err, "error validating UserCertParams") +func (k *Keygen) GenerateUserCert(req sshca.UserCertificateRequest) ([]byte, error) { + if err := req.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err, "error validating user certificate request") } - return k.GenerateUserCertWithoutValidation(c) + return k.GenerateUserCertWithoutValidation(req) } // GenerateUserCertWithoutValidation generates a user ssh certificate with the // passed in parameters without validating them. -func (k *Keygen) GenerateUserCertWithoutValidation(c services.UserCertParams) ([]byte, error) { - pubKey, _, _, _, err := ssh.ParseAuthorizedKey(c.PublicUserKey) +func (k *Keygen) GenerateUserCertWithoutValidation(req sshca.UserCertificateRequest) ([]byte, error) { + pubKey, _, _, _, err := ssh.ParseAuthorizedKey(req.PublicUserKey) if err != nil { return nil, trace.Wrap(err) } - validBefore := uint64(ssh.CertTimeInfinity) - if c.TTL != 0 { - b := k.clock.Now().UTC().Add(c.TTL) - validBefore = uint64(b.Unix()) + + // create shallow copy of identity since we want to make some local changes + ident := req.Identity + + // since this method ignores the supplied values for ValidBefore/ValidAfter, avoid confusing by + // rejecting identities where they are set. + if ident.ValidBefore != 0 { + return nil, trace.BadParameter("ValidBefore should not be set in calls to GenerateUserCert") + } + if ident.ValidAfter != 0 { + return nil, trace.BadParameter("ValidAfter should not be set in calls to GenerateUserCert") + } + + // calculate ValidBefore based on the outer request TTL + ident.ValidBefore = uint64(ssh.CertTimeInfinity) + if req.TTL != 0 { + b := k.clock.Now().UTC().Add(req.TTL) + ident.ValidBefore = uint64(b.Unix()) slog.DebugContext( context.TODO(), "Generated user key with expiry.", - "allowed_logins", c.AllowedLogins, - "valid_before_unix_ts", validBefore, + "allowed_logins", ident.AllowedLogins, + "valid_before_unix_ts", ident.ValidBefore, "valid_before", b, ) } - cert := &ssh.Certificate{ - // we have to use key id to identify teleport user - KeyId: c.Username, - ValidPrincipals: c.AllowedLogins, - Key: pubKey, - ValidAfter: uint64(k.clock.Now().UTC().Add(-1 * time.Minute).Unix()), - ValidBefore: validBefore, - CertType: ssh.UserCert, - } - cert.Permissions.Extensions = map[string]string{ - teleport.CertExtensionPermitPTY: "", - } - if c.PermitX11Forwarding { - cert.Permissions.Extensions[teleport.CertExtensionPermitX11Forwarding] = "" - } - if c.PermitAgentForwarding { - cert.Permissions.Extensions[teleport.CertExtensionPermitAgentForwarding] = "" - } - if c.PermitPortForwarding { - cert.Permissions.Extensions[teleport.CertExtensionPermitPortForwarding] = "" - } - if c.MFAVerified != "" { - cert.Permissions.Extensions[teleport.CertExtensionMFAVerified] = c.MFAVerified - } - if !c.PreviousIdentityExpires.IsZero() { - cert.Permissions.Extensions[teleport.CertExtensionPreviousIdentityExpires] = c.PreviousIdentityExpires.Format(time.RFC3339) - } - if c.LoginIP != "" { - cert.Permissions.Extensions[teleport.CertExtensionLoginIP] = c.LoginIP - } - if c.Impersonator != "" { - cert.Permissions.Extensions[teleport.CertExtensionImpersonator] = c.Impersonator - } - if c.DisallowReissue { - cert.Permissions.Extensions[teleport.CertExtensionDisallowReissue] = "" - } - if c.Renewable { - cert.Permissions.Extensions[teleport.CertExtensionRenewable] = "" - } - if c.Generation > 0 { - cert.Permissions.Extensions[teleport.CertExtensionGeneration] = fmt.Sprint(c.Generation) - } - if c.BotName != "" { - cert.Permissions.Extensions[teleport.CertExtensionBotName] = c.BotName - } - if c.BotInstanceID != "" { - cert.Permissions.Extensions[teleport.CertExtensionBotInstanceID] = c.BotInstanceID - } - if c.AllowedResourceIDs != "" { - cert.Permissions.Extensions[teleport.CertExtensionAllowedResources] = c.AllowedResourceIDs - } - if c.ConnectionDiagnosticID != "" { - cert.Permissions.Extensions[teleport.CertExtensionConnectionDiagnosticID] = c.ConnectionDiagnosticID - } - if c.PrivateKeyPolicy != "" { - cert.Permissions.Extensions[teleport.CertExtensionPrivateKeyPolicy] = string(c.PrivateKeyPolicy) - } - if devID := c.DeviceID; devID != "" { - cert.Permissions.Extensions[teleport.CertExtensionDeviceID] = devID - } - if assetTag := c.DeviceAssetTag; assetTag != "" { - cert.Permissions.Extensions[teleport.CertExtensionDeviceAssetTag] = assetTag - } - if credID := c.DeviceCredentialID; credID != "" { - cert.Permissions.Extensions[teleport.CertExtensionDeviceCredentialID] = credID - } - if c.GitHubUserID != "" { - cert.Permissions.Extensions[teleport.CertExtensionGitHubUserID] = c.GitHubUserID - } - if c.GitHubUsername != "" { - cert.Permissions.Extensions[teleport.CertExtensionGitHubUsername] = c.GitHubUsername - } - if c.PinnedIP != "" { + // set ValidAfter to be 1 minute in the past + ident.ValidAfter = uint64(k.clock.Now().UTC().Add(-1 * time.Minute).Unix()) + + // if the provided identity is attempting to perform IP pinning, make sure modules are enforced + if ident.PinnedIP != "" { if modules.GetModules().BuildType() != modules.BuildEnterprise { return nil, trace.AccessDenied("source IP pinning is only supported in Teleport Enterprise") } - if cert.CriticalOptions == nil { - cert.CriticalOptions = make(map[string]string) - } - // IPv4, all bits matter - ip := c.PinnedIP + "/32" - if strings.Contains(c.PinnedIP, ":") { - // IPv6 - ip = c.PinnedIP + "/128" - } - cert.CriticalOptions[teleport.CertCriticalOptionSourceAddress] = ip } - for _, extension := range c.CertificateExtensions { - // TODO(lxea): update behavior when non ssh, non extensions are supported. - if extension.Mode != types.CertExtensionMode_EXTENSION || - extension.Type != types.CertExtensionType_SSH { - continue - } - cert.Extensions[extension.Name] = extension.Value + // encode the identity into a certificate + cert, err := ident.Encode(req.CertificateFormat) + if err != nil { + return nil, trace.Wrap(err) } - // Add roles, traits, and route to cluster in the certificate extensions if - // the standard format was requested. Certificate extensions are not included - // legacy SSH certificates due to a bug in OpenSSH <= OpenSSH 7.1: - // https://bugzilla.mindrot.org/show_bug.cgi?id=2387 - if c.CertificateFormat == constants.CertificateFormatStandard { - traits, err := wrappers.MarshalTraits(&c.Traits) - if err != nil { - return nil, trace.Wrap(err) - } - if len(traits) > 0 { - cert.Permissions.Extensions[teleport.CertExtensionTeleportTraits] = string(traits) - } - if len(c.Roles) != 0 { - roles, err := services.MarshalCertRoles(c.Roles) - if err != nil { - return nil, trace.Wrap(err) - } - cert.Permissions.Extensions[teleport.CertExtensionTeleportRoles] = roles - } - if c.RouteToCluster != "" { - cert.Permissions.Extensions[teleport.CertExtensionTeleportRouteToCluster] = c.RouteToCluster - } - if !c.ActiveRequests.IsEmpty() { - requests, err := c.ActiveRequests.Marshal() - if err != nil { - return nil, trace.Wrap(err) - } - cert.Permissions.Extensions[teleport.CertExtensionTeleportActiveRequests] = string(requests) - } - } + // set the public key of the certificate + cert.Key = pubKey - if err := cert.SignCert(rand.Reader, c.CASigner); err != nil { + if err := cert.SignCert(rand.Reader, req.CASigner); err != nil { return nil, trace.Wrap(err) } + return ssh.MarshalAuthorizedKey(cert), nil } diff --git a/lib/auth/keygen/keygen_test.go b/lib/auth/keygen/keygen_test.go index e2d68d91a923e..d6c243b3ee986 100644 --- a/lib/auth/keygen/keygen_test.go +++ b/lib/auth/keygen/keygen_test.go @@ -38,6 +38,7 @@ import ( "github.com/gravitational/teleport/lib/auth/test" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" ) type nativeContext struct { @@ -226,23 +227,24 @@ func TestUserCertCompatibility(t *testing.T) { for i, tc := range tests { comment := fmt.Sprintf("Test %v", i) - userCertificateBytes, err := tt.suite.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: ssh.MarshalAuthorizedKey(caSigner.PublicKey()), - Username: "user", - AllowedLogins: []string{"centos", "root"}, - TTL: time.Hour, - Roles: []string{"foo"}, - CertificateExtensions: []*types.CertExtension{{ - Type: types.CertExtensionType_SSH, - Mode: types.CertExtensionMode_EXTENSION, - Name: "login@github.com", - Value: "hello", + userCertificateBytes, err := tt.suite.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: ssh.MarshalAuthorizedKey(caSigner.PublicKey()), + TTL: time.Hour, + CertificateFormat: tc.inCompatibility, + Identity: sshca.Identity{ + Username: "user", + AllowedLogins: []string{"centos", "root"}, + Roles: []string{"foo"}, + CertificateExtensions: []*types.CertExtension{{ + Type: types.CertExtensionType_SSH, + Mode: types.CertExtensionMode_EXTENSION, + Name: "login@github.com", + Value: "hello", + }}, + PermitAgentForwarding: true, + PermitPortForwarding: true, }, - }, - CertificateFormat: tc.inCompatibility, - PermitAgentForwarding: true, - PermitPortForwarding: true, }) require.NoError(t, err, comment) diff --git a/lib/auth/test/suite.go b/lib/auth/test/suite.go index 3e97874d8802e..14d22f8265647 100644 --- a/lib/auth/test/suite.go +++ b/lib/auth/test/suite.go @@ -95,15 +95,17 @@ func (s *AuthSuite) GenerateUserCert(t *testing.T) { caSigner, err := ssh.ParsePrivateKey(priv) require.NoError(t, err) - cert, err := s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: pub, - Username: "user", - AllowedLogins: []string{"centos", "root"}, - TTL: time.Hour, - PermitAgentForwarding: true, - PermitPortForwarding: true, - CertificateFormat: constants.CertificateFormatStandard, + cert, err := s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: pub, + TTL: time.Hour, + CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: "user", + AllowedLogins: []string{"centos", "root"}, + PermitAgentForwarding: true, + PermitPortForwarding: true, + }, }) require.NoError(t, err) @@ -112,59 +114,67 @@ func (s *AuthSuite) GenerateUserCert(t *testing.T) { err = checkCertExpiry(cert, s.Clock.Now().Add(-1*time.Minute), s.Clock.Now().Add(1*time.Hour)) require.NoError(t, err) - cert, err = s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: pub, - Username: "user", - AllowedLogins: []string{"root"}, - TTL: -20, - PermitAgentForwarding: true, - PermitPortForwarding: true, - CertificateFormat: constants.CertificateFormatStandard, + cert, err = s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: pub, + TTL: -20, + CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: "user", + AllowedLogins: []string{"root"}, + PermitAgentForwarding: true, + PermitPortForwarding: true, + }, }) require.NoError(t, err) err = checkCertExpiry(cert, s.Clock.Now().Add(-1*time.Minute), s.Clock.Now().Add(apidefaults.MinCertDuration)) require.NoError(t, err) - _, err = s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: pub, - Username: "user", - AllowedLogins: []string{"root"}, - TTL: 0, - PermitAgentForwarding: true, - PermitPortForwarding: true, - CertificateFormat: constants.CertificateFormatStandard, + _, err = s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: pub, + TTL: 0, + CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: "user", + AllowedLogins: []string{"root"}, + PermitAgentForwarding: true, + PermitPortForwarding: true, + }, }) require.NoError(t, err) err = checkCertExpiry(cert, s.Clock.Now().Add(-1*time.Minute), s.Clock.Now().Add(apidefaults.MinCertDuration)) require.NoError(t, err) - _, err = s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: pub, - Username: "user", - AllowedLogins: []string{"root"}, - TTL: time.Hour, - PermitAgentForwarding: true, - PermitPortForwarding: true, - CertificateFormat: constants.CertificateFormatStandard, + _, err = s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: pub, + TTL: time.Hour, + CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: "user", + AllowedLogins: []string{"root"}, + PermitAgentForwarding: true, + PermitPortForwarding: true, + }, }) require.NoError(t, err) inRoles := []string{"role-1", "role-2"} impersonator := "alice" - cert, err = s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: pub, - Username: "user", - Impersonator: impersonator, - AllowedLogins: []string{"root"}, - TTL: time.Hour, - PermitAgentForwarding: true, - PermitPortForwarding: true, - CertificateFormat: constants.CertificateFormatStandard, - Roles: inRoles, + cert, err = s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: pub, + TTL: time.Hour, + CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: "user", + Impersonator: impersonator, + AllowedLogins: []string{"root"}, + PermitAgentForwarding: true, + PermitPortForwarding: true, + Roles: inRoles, + }, }) require.NoError(t, err) parsedCert, err := sshutils.ParseCertificate(cert) @@ -178,15 +188,17 @@ func (s *AuthSuite) GenerateUserCert(t *testing.T) { // Check that MFAVerified and PreviousIdentityExpires are encoded into ssh cert clock := clockwork.NewFakeClock() - cert, err = s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: pub, - Username: "user", - AllowedLogins: []string{"root"}, - TTL: time.Minute, - CertificateFormat: constants.CertificateFormatStandard, - MFAVerified: "mfa-device-id", - PreviousIdentityExpires: clock.Now().Add(time.Hour), + cert, err = s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: pub, + TTL: time.Minute, + CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: "user", + AllowedLogins: []string{"root"}, + MFAVerified: "mfa-device-id", + PreviousIdentityExpires: clock.Now().Add(time.Hour), + }, }) require.NoError(t, err) parsedCert, err = sshutils.ParseCertificate(cert) @@ -202,14 +214,16 @@ func (s *AuthSuite) GenerateUserCert(t *testing.T) { const devID = "deviceid1" const devTag = "devicetag1" const devCred = "devicecred1" - certRaw, err := s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, // Required. - PublicUserKey: pub, // Required. - Username: "llama", // Required. - AllowedLogins: []string{"llama"}, // Required. - DeviceID: devID, - DeviceAssetTag: devTag, - DeviceCredentialID: devCred, + certRaw, err := s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, // Required. + PublicUserKey: pub, // Required. + Identity: sshca.Identity{ + Username: "llama", // Required. + AllowedLogins: []string{"llama"}, // Required. + DeviceID: devID, + DeviceAssetTag: devTag, + DeviceCredentialID: devCred, + }, }) require.NoError(t, err, "GenerateUserCert failed") @@ -223,13 +237,15 @@ func (s *AuthSuite) GenerateUserCert(t *testing.T) { t.Run("github identity", func(t *testing.T) { githubUserID := "1234567" githubUsername := "github-user" - certRaw, err := s.A.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, // Required. - PublicUserKey: pub, // Required. - Username: "llama", // Required. - AllowedLogins: []string{"llama"}, // Required. - GitHubUserID: githubUserID, - GitHubUsername: githubUsername, + certRaw, err := s.A.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, // Required. + PublicUserKey: pub, // Required. + Identity: sshca.Identity{ + Username: "llama", // Required. + AllowedLogins: []string{"llama"}, // Required. + GitHubUserID: githubUserID, + GitHubUsername: githubUsername, + }, }) require.NoError(t, err, "GenerateUserCert failed") diff --git a/lib/auth/testauthority/testauthority.go b/lib/auth/testauthority/testauthority.go index 8dae039d9c1f4..b58f9ac27493d 100644 --- a/lib/auth/testauthority/testauthority.go +++ b/lib/auth/testauthority/testauthority.go @@ -29,6 +29,7 @@ import ( "github.com/gravitational/teleport/lib/auth/keygen" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" ) type Keygen struct { @@ -60,7 +61,7 @@ func (n *Keygen) GenerateHostCert(c services.HostCertParams) ([]byte, error) { return n.GenerateHostCertWithoutValidation(c) } -func (n *Keygen) GenerateUserCert(c services.UserCertParams) ([]byte, error) { +func (n *Keygen) GenerateUserCert(c sshca.UserCertificateRequest) ([]byte, error) { return n.GenerateUserCertWithoutValidation(c) } diff --git a/lib/client/client_store_test.go b/lib/client/client_store_test.go index 8090c5e664851..71239884aaaba 100644 --- a/lib/client/client_store_test.go +++ b/lib/client/client_store_test.go @@ -45,6 +45,7 @@ import ( "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" @@ -104,16 +105,18 @@ func (s *testAuthority) makeSignedKeyRing(t *testing.T, idx KeyRingIndex, makeEx caSigner, err := ssh.ParsePrivateKey(CAPriv) require.NoError(t, err) - cert, err := s.keygen.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: sshPriv.MarshalSSHPublicKey(), - Username: idx.Username, - AllowedLogins: allowedLogins, - TTL: ttl, - PermitAgentForwarding: false, - PermitPortForwarding: true, - GitHubUserID: "1234567", - GitHubUsername: "github-username", + cert, err := s.keygen.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: sshPriv.MarshalSSHPublicKey(), + TTL: ttl, + Identity: sshca.Identity{ + Username: idx.Username, + AllowedLogins: allowedLogins, + PermitAgentForwarding: false, + PermitPortForwarding: true, + GitHubUserID: "1234567", + GitHubUsername: "github-username", + }, }) require.NoError(t, err) diff --git a/lib/client/cluster_client_test.go b/lib/client/cluster_client_test.go index 7a90be3f30d80..e529b4737d1db 100644 --- a/lib/client/cluster_client_test.go +++ b/lib/client/cluster_client_test.go @@ -39,7 +39,7 @@ import ( libmfa "github.com/gravitational/teleport/lib/client/mfa" "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/observability/tracing" - "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/tlsca" ) @@ -390,13 +390,15 @@ func TestIssueUserCertsWithMFA(t *testing.T) { var sshCert, tlsCert []byte var err error if req.SSHPublicKey != nil { - sshCert, err = ca.keygen.GenerateUserCert(services.UserCertParams{ + sshCert, err = ca.keygen.GenerateUserCert(sshca.UserCertificateRequest{ CASigner: caSigner, PublicUserKey: req.SSHPublicKey, TTL: req.Expires.Sub(clock.Now()), - Username: req.Username, CertificateFormat: req.Format, - RouteToCluster: req.RouteToCluster, + Identity: sshca.Identity{ + Username: req.Username, + RouteToCluster: req.RouteToCluster, + }, }) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/client/identityfile/identity_test.go b/lib/client/identityfile/identity_test.go index 3f52aefe162db..9d8eeb62a894d 100644 --- a/lib/client/identityfile/identity_test.go +++ b/lib/client/identityfile/identity_test.go @@ -46,7 +46,7 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/kube/kubeconfig" - "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" ) @@ -108,11 +108,13 @@ func newClientKeyRing(t *testing.T, modifiers ...func(*tlsca.Identity)) *client. caSigner, err := ssh.NewSignerFromKey(signer) require.NoError(t, err) - certificate, err := keygen.GenerateUserCert(services.UserCertParams{ + certificate, err := keygen.GenerateUserCert(sshca.UserCertificateRequest{ CASigner: caSigner, PublicUserKey: ssh.MarshalAuthorizedKey(privateKey.SSHPublicKey()), - Username: "testuser", - AllowedLogins: []string{"testuser"}, + Identity: sshca.Identity{ + Username: "testuser", + AllowedLogins: []string{"testuser"}, + }, }) require.NoError(t, err) diff --git a/lib/client/keyagent_test.go b/lib/client/keyagent_test.go index 4c0c078e82293..a8dfdae28da95 100644 --- a/lib/client/keyagent_test.go +++ b/lib/client/keyagent_test.go @@ -50,6 +50,7 @@ import ( "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -751,16 +752,18 @@ func (s *KeyAgentTestSuite) makeKeyRing(t *testing.T, username, proxyHost string sshPub, err := ssh.NewPublicKey(sshKey.Public()) require.NoError(t, err) - certificate, err := testauthority.New().GenerateUserCert(services.UserCertParams{ - CertificateFormat: constants.CertificateFormatStandard, - CASigner: caSigner, - PublicUserKey: ssh.MarshalAuthorizedKey(sshPub), - Username: username, - AllowedLogins: []string{username}, - TTL: ttl, - PermitAgentForwarding: true, - PermitPortForwarding: true, - RouteToCluster: s.clusterName, + certificate, err := testauthority.New().GenerateUserCert(sshca.UserCertificateRequest{ + CertificateFormat: constants.CertificateFormatStandard, + CASigner: caSigner, + PublicUserKey: ssh.MarshalAuthorizedKey(sshPub), + TTL: ttl, + Identity: sshca.Identity{ + Username: username, + AllowedLogins: []string{username}, + PermitAgentForwarding: true, + PermitPortForwarding: true, + RouteToCluster: s.clusterName, + }, }) require.NoError(t, err) diff --git a/lib/reversetunnel/srv_test.go b/lib/reversetunnel/srv_test.go index 2477739df359a..8794a8323f0f1 100644 --- a/lib/reversetunnel/srv_test.go +++ b/lib/reversetunnel/srv_test.go @@ -39,6 +39,7 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" "github.com/gravitational/teleport/lib/utils" ) @@ -103,15 +104,17 @@ func TestServerKeyAuth(t *testing.T) { { desc: "user cert", key: func() ssh.PublicKey { - rawCert, err := ta.GenerateUserCert(services.UserCertParams{ + rawCert, err := ta.GenerateUserCert(sshca.UserCertificateRequest{ CASigner: caSigner, PublicUserKey: pub, - Username: con.User(), - AllowedLogins: []string{con.User()}, - Roles: []string{"dev", "admin"}, - RouteToCluster: "user-cluster-name", CertificateFormat: constants.CertificateFormatStandard, TTL: time.Minute, + Identity: sshca.Identity{ + Username: con.User(), + AllowedLogins: []string{con.User()}, + Roles: []string{"dev", "admin"}, + RouteToCluster: "user-cluster-name", + }, }) require.NoError(t, err) key, _, _, _, err := ssh.ParseAuthorizedKey(rawCert) diff --git a/lib/services/authority.go b/lib/services/authority.go index fb6a3efe612e6..2345342b1195b 100644 --- a/lib/services/authority.go +++ b/lib/services/authority.go @@ -32,9 +32,7 @@ import ( "github.com/jonboulle/clockwork" "golang.org/x/crypto/ssh" - apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/types/wrappers" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/jwt" @@ -321,103 +319,6 @@ func (c HostCertParams) Check() error { return nil } -// UserCertParams defines OpenSSH user certificate parameters -type UserCertParams struct { - // CASigner is the signer that will sign the public key of the user with the CA private key - CASigner ssh.Signer - // PublicUserKey is the public key of the user in SSH authorized_keys format. - PublicUserKey []byte - // TTL defines how long a certificate is valid for - TTL time.Duration - // Username is teleport username - Username string - // Impersonator is set when a user requests certificate for another user - Impersonator string - // AllowedLogins is a list of SSH principals - AllowedLogins []string - // PermitX11Forwarding permits X11 forwarding for this cert - PermitX11Forwarding bool - // PermitAgentForwarding permits agent forwarding for this cert - PermitAgentForwarding bool - // PermitPortForwarding permits port forwarding. - PermitPortForwarding bool - // PermitFileCopying permits the use of SCP/SFTP. - PermitFileCopying bool - // Roles is a list of roles assigned to this user - Roles []string - // CertificateFormat is the format of the SSH certificate. - CertificateFormat string - // RouteToCluster specifies the target cluster - // if present in the certificate, will be used - // to route the requests to - RouteToCluster string - // Traits hold claim data used to populate a role at runtime. - Traits wrappers.Traits - // ActiveRequests tracks privilege escalation requests applied during - // certificate construction. - ActiveRequests RequestIDs - // MFAVerified is the UUID of an MFA device when this Identity was - // confirmed immediately after an MFA check. - MFAVerified string - // PreviousIdentityExpires is the expiry time of the identity/cert that this - // identity/cert was derived from. It is used to determine a session's hard - // deadline in cases where both require_session_mfa and disconnect_expired_cert - // are enabled. See https://github.com/gravitational/teleport/issues/18544. - PreviousIdentityExpires time.Time - // LoginIP is an observed IP of the client on the moment of certificate creation. - LoginIP string - // PinnedIP is an IP from which client must communicate with Teleport. - PinnedIP string - // DisallowReissue flags that any attempt to request new certificates while - // authenticated with this cert should be denied. - DisallowReissue bool - // CertificateExtensions are user configured ssh key extensions - CertificateExtensions []*types.CertExtension - // Renewable indicates this certificate is renewable. - Renewable bool - // Generation counts the number of times a certificate has been renewed. - Generation uint64 - // BotName is set to the name of the bot, if the user is a Machine ID bot user. - // Empty for human users. - BotName string - // BotInstanceID is the unique identifier for the bot instance, if this is a - // Machine ID bot. It is empty for human users. - BotInstanceID string - // AllowedResourceIDs lists the resources the user should be able to access. - AllowedResourceIDs string - // ConnectionDiagnosticID references the ConnectionDiagnostic that we should use to append traces when testing a Connection. - ConnectionDiagnosticID string - // PrivateKeyPolicy is the private key policy supported by this certificate. - PrivateKeyPolicy keys.PrivateKeyPolicy - // DeviceID is the trusted device identifier. - DeviceID string - // DeviceAssetTag is the device inventory identifier. - DeviceAssetTag string - // DeviceCredentialID is the identifier for the credential used by the device - // to authenticate itself. - DeviceCredentialID string - // GitHubUserID indicates the GitHub user ID identified by the GitHub - // connector. - GitHubUserID string - // GitHubUserID indicates the GitHub username identified by the GitHub - // connector. - GitHubUsername string -} - -// CheckAndSetDefaults checks the user certificate parameters -func (c *UserCertParams) CheckAndSetDefaults() error { - if c.CASigner == nil { - return trace.BadParameter("CASigner is required") - } - if c.TTL < apidefaults.MinCertDuration { - c.TTL = apidefaults.MinCertDuration - } - if len(c.AllowedLogins) == 0 { - return trace.BadParameter("AllowedLogins are required") - } - return nil -} - // CertPoolFromCertAuthorities returns a certificate pool from the TLS certificates // set up in the certificate authorities list, as well as the number of certificates // that were added to the pool. diff --git a/lib/srv/authhandlers_test.go b/lib/srv/authhandlers_test.go index 78856817654a9..907a3db97b786 100644 --- a/lib/srv/authhandlers_test.go +++ b/lib/srv/authhandlers_test.go @@ -35,7 +35,7 @@ import ( "github.com/gravitational/teleport/lib/auth/testauthority" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/events/eventstest" - "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/sshca" ) type mockCAandAuthPrefGetter struct { @@ -213,11 +213,13 @@ func TestRBAC(t *testing.T) { privateKey, err := cryptosuites.GeneratePrivateKeyWithAlgorithm(cryptosuites.ECDSAP256) require.NoError(t, err) - c, err := keygen.GenerateUserCert(services.UserCertParams{ + c, err := keygen.GenerateUserCert(sshca.UserCertificateRequest{ CASigner: caSigner, PublicUserKey: ssh.MarshalAuthorizedKey(privateKey.SSHPublicKey()), - Username: "testuser", - AllowedLogins: []string{"testuser"}, + Identity: sshca.Identity{ + Username: "testuser", + AllowedLogins: []string{"testuser"}, + }, }) require.NoError(t, err) @@ -385,16 +387,18 @@ func TestRBACJoinMFA(t *testing.T) { require.NoError(t, err) keygen := testauthority.New() - c, err := keygen.GenerateUserCert(services.UserCertParams{ - CASigner: caSigner, - PublicUserKey: privateKey.MarshalSSHPublicKey(), - Username: username, - AllowedLogins: []string{username}, - Traits: wrappers.Traits{ - teleport.TraitInternalPrefix: []string{""}, - }, - Roles: []string{tt.role}, + c, err := keygen.GenerateUserCert(sshca.UserCertificateRequest{ + CASigner: caSigner, + PublicUserKey: privateKey.MarshalSSHPublicKey(), CertificateFormat: constants.CertificateFormatStandard, + Identity: sshca.Identity{ + Username: username, + AllowedLogins: []string{username}, + Traits: wrappers.Traits{ + teleport.TraitInternalPrefix: []string{""}, + }, + Roles: []string{tt.role}, + }, }) require.NoError(t, err) diff --git a/lib/sshca/identity.go b/lib/sshca/identity.go new file mode 100644 index 0000000000000..19f40bfdf336d --- /dev/null +++ b/lib/sshca/identity.go @@ -0,0 +1,392 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +// Package sshca specifies interfaces for SSH certificate authorities +package sshca + +import ( + "fmt" + "maps" + "strconv" + "strings" + "time" + + "github.com/gravitational/trace" + "golang.org/x/crypto/ssh" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/wrappers" + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/services" +) + +// Identity is a user identity. All identity fields map directly to an ssh certificate field. +type Identity struct { + // ValidAfter is the unix timestamp that marks the start time for when the certificate should + // be considered valid. + ValidAfter uint64 + // ValidBefore is the unix timestamp that marks the end time for when the certificate should + // be considered valid. + ValidBefore uint64 + // Username is teleport username + Username string + // Impersonator is set when a user requests certificate for another user + Impersonator string + // AllowedLogins is a list of SSH principals + AllowedLogins []string + // PermitX11Forwarding permits X11 forwarding for this cert + PermitX11Forwarding bool + // PermitAgentForwarding permits agent forwarding for this cert + PermitAgentForwarding bool + // PermitPortForwarding permits port forwarding. + PermitPortForwarding bool + // Roles is a list of roles assigned to this user + Roles []string + // RouteToCluster specifies the target cluster + // if present in the certificate, will be used + // to route the requests to + RouteToCluster string + // Traits hold claim data used to populate a role at runtime. + Traits wrappers.Traits + // ActiveRequests tracks privilege escalation requests applied during + // certificate construction. + ActiveRequests services.RequestIDs + // MFAVerified is the UUID of an MFA device when this Identity was + // confirmed immediately after an MFA check. + MFAVerified string + // PreviousIdentityExpires is the expiry time of the identity/cert that this + // identity/cert was derived from. It is used to determine a session's hard + // deadline in cases where both require_session_mfa and disconnect_expired_cert + // are enabled. See https://github.com/gravitational/teleport/issues/18544. + PreviousIdentityExpires time.Time + // LoginIP is an observed IP of the client on the moment of certificate creation. + LoginIP string + // PinnedIP is an IP from which client must communicate with Teleport. + PinnedIP string + // DisallowReissue flags that any attempt to request new certificates while + // authenticated with this cert should be denied. + DisallowReissue bool + // CertificateExtensions are user configured ssh key extensions (note: this field also + // ends up aggregating all *unknown* extensions during cert parsing, meaning that this + // can sometimes contain fields that were inserted by a newer version of teleport). + CertificateExtensions []*types.CertExtension + // Renewable indicates this certificate is renewable. + Renewable bool + // Generation counts the number of times a certificate has been renewed, with a generation of 1 + // meaning the cert has never been renewed. A generation of zero means the cert's generation is + // not being tracked. + Generation uint64 + // BotName is set to the name of the bot, if the user is a Machine ID bot user. + // Empty for human users. + BotName string + // BotInstanceID is the unique identifier for the bot instance, if this is a + // Machine ID bot. It is empty for human users. + BotInstanceID string + // AllowedResourceIDs lists the resources the user should be able to access. + AllowedResourceIDs string + // ConnectionDiagnosticID references the ConnectionDiagnostic that we should use to append traces when testing a Connection. + ConnectionDiagnosticID string + // PrivateKeyPolicy is the private key policy supported by this certificate. + PrivateKeyPolicy keys.PrivateKeyPolicy + // DeviceID is the trusted device identifier. + DeviceID string + // DeviceAssetTag is the device inventory identifier. + DeviceAssetTag string + // DeviceCredentialID is the identifier for the credential used by the device + // to authenticate itself. + DeviceCredentialID string + // GitHubUserID indicates the GitHub user ID identified by the GitHub + // connector. + GitHubUserID string + // GitHubUsername indicates the GitHub username identified by the GitHub + // connector. + GitHubUsername string +} + +// Check performs validation of certain fields in the identity. +func (i *Identity) Check() error { + if len(i.AllowedLogins) == 0 { + return trace.BadParameter("ssh user identity missing allowed logins") + } + + return nil +} + +// Encode encodes the identity into an ssh certificate. Note that the returned certificate is incomplete +// and must be have its public key set before signing. +func (i *Identity) Encode(certFormat string) (*ssh.Certificate, error) { + validBefore := i.ValidBefore + if validBefore == 0 { + validBefore = uint64(ssh.CertTimeInfinity) + } + validAfter := i.ValidAfter + if validAfter == 0 { + validAfter = uint64(time.Now().UTC().Add(-1 * time.Minute).Unix()) + } + cert := &ssh.Certificate{ + // we have to use key id to identify teleport user + KeyId: i.Username, + ValidPrincipals: i.AllowedLogins, + ValidAfter: validAfter, + ValidBefore: validBefore, + CertType: ssh.UserCert, + } + cert.Permissions.Extensions = map[string]string{ + teleport.CertExtensionPermitPTY: "", + } + + if i.PermitX11Forwarding { + cert.Permissions.Extensions[teleport.CertExtensionPermitX11Forwarding] = "" + } + if i.PermitAgentForwarding { + cert.Permissions.Extensions[teleport.CertExtensionPermitAgentForwarding] = "" + } + if i.PermitPortForwarding { + cert.Permissions.Extensions[teleport.CertExtensionPermitPortForwarding] = "" + } + if i.MFAVerified != "" { + cert.Permissions.Extensions[teleport.CertExtensionMFAVerified] = i.MFAVerified + } + if !i.PreviousIdentityExpires.IsZero() { + cert.Permissions.Extensions[teleport.CertExtensionPreviousIdentityExpires] = i.PreviousIdentityExpires.Format(time.RFC3339) + } + if i.LoginIP != "" { + cert.Permissions.Extensions[teleport.CertExtensionLoginIP] = i.LoginIP + } + if i.Impersonator != "" { + cert.Permissions.Extensions[teleport.CertExtensionImpersonator] = i.Impersonator + } + if i.DisallowReissue { + cert.Permissions.Extensions[teleport.CertExtensionDisallowReissue] = "" + } + if i.Renewable { + cert.Permissions.Extensions[teleport.CertExtensionRenewable] = "" + } + if i.Generation > 0 { + cert.Permissions.Extensions[teleport.CertExtensionGeneration] = fmt.Sprint(i.Generation) + } + if i.BotName != "" { + cert.Permissions.Extensions[teleport.CertExtensionBotName] = i.BotName + } + if i.BotInstanceID != "" { + cert.Permissions.Extensions[teleport.CertExtensionBotInstanceID] = i.BotInstanceID + } + if i.AllowedResourceIDs != "" { + cert.Permissions.Extensions[teleport.CertExtensionAllowedResources] = i.AllowedResourceIDs + } + if i.ConnectionDiagnosticID != "" { + cert.Permissions.Extensions[teleport.CertExtensionConnectionDiagnosticID] = i.ConnectionDiagnosticID + } + if i.PrivateKeyPolicy != "" { + cert.Permissions.Extensions[teleport.CertExtensionPrivateKeyPolicy] = string(i.PrivateKeyPolicy) + } + if devID := i.DeviceID; devID != "" { + cert.Permissions.Extensions[teleport.CertExtensionDeviceID] = devID + } + if assetTag := i.DeviceAssetTag; assetTag != "" { + cert.Permissions.Extensions[teleport.CertExtensionDeviceAssetTag] = assetTag + } + if credID := i.DeviceCredentialID; credID != "" { + cert.Permissions.Extensions[teleport.CertExtensionDeviceCredentialID] = credID + } + if i.GitHubUserID != "" { + cert.Permissions.Extensions[teleport.CertExtensionGitHubUserID] = i.GitHubUserID + } + if i.GitHubUsername != "" { + cert.Permissions.Extensions[teleport.CertExtensionGitHubUsername] = i.GitHubUsername + } + + if i.PinnedIP != "" { + if cert.CriticalOptions == nil { + cert.CriticalOptions = make(map[string]string) + } + // IPv4, all bits matter + ip := i.PinnedIP + "/32" + if strings.Contains(i.PinnedIP, ":") { + // IPv6 + ip = i.PinnedIP + "/128" + } + cert.CriticalOptions[teleport.CertCriticalOptionSourceAddress] = ip + } + + for _, extension := range i.CertificateExtensions { + // TODO(lxea): update behavior when non ssh, non extensions are supported. + if extension.Mode != types.CertExtensionMode_EXTENSION || + extension.Type != types.CertExtensionType_SSH { + continue + } + cert.Extensions[extension.Name] = extension.Value + } + + // Add roles, traits, and route to cluster in the certificate extensions if + // the standard format was requested. Certificate extensions are not included + // legacy SSH certificates due to a bug in OpenSSH <= OpenSSH 7.1: + // https://bugzilla.mindrot.org/show_bug.cgi?id=2387 + if certFormat == constants.CertificateFormatStandard { + traits, err := wrappers.MarshalTraits(&i.Traits) + if err != nil { + return nil, trace.Wrap(err) + } + if len(traits) > 0 { + cert.Permissions.Extensions[teleport.CertExtensionTeleportTraits] = string(traits) + } + if len(i.Roles) != 0 { + roles, err := services.MarshalCertRoles(i.Roles) + if err != nil { + return nil, trace.Wrap(err) + } + cert.Permissions.Extensions[teleport.CertExtensionTeleportRoles] = roles + } + if i.RouteToCluster != "" { + cert.Permissions.Extensions[teleport.CertExtensionTeleportRouteToCluster] = i.RouteToCluster + } + if !i.ActiveRequests.IsEmpty() { + requests, err := i.ActiveRequests.Marshal() + if err != nil { + return nil, trace.Wrap(err) + } + cert.Permissions.Extensions[teleport.CertExtensionTeleportActiveRequests] = string(requests) + } + } + + return cert, nil +} + +// DecodeIdentity decodes an ssh certificate into an identity. +func DecodeIdentity(cert *ssh.Certificate) (*Identity, error) { + if cert.CertType != ssh.UserCert { + return nil, trace.BadParameter("DecodeIdentity intended for use with user certs, got %v", cert.CertType) + } + ident := &Identity{ + Username: cert.KeyId, + AllowedLogins: cert.ValidPrincipals, + ValidAfter: cert.ValidAfter, + ValidBefore: cert.ValidBefore, + } + + // clone the extension map and remove entries from the clone as they are processed so + // that we can easily aggregate the remainder into the CertificateExtensions field. + extensions := maps.Clone(cert.Extensions) + + takeExtension := func(name string) (value string, ok bool) { + v, ok := extensions[name] + if !ok { + return "", false + } + delete(extensions, name) + return v, true + } + + takeValue := func(name string) string { + value, _ := takeExtension(name) + return value + } + + takeBool := func(name string) bool { + _, ok := takeExtension(name) + return ok + } + + // ignore the permit pty extension, it's always set + _, _ = takeExtension(teleport.CertExtensionPermitPTY) + + ident.PermitX11Forwarding = takeBool(teleport.CertExtensionPermitX11Forwarding) + ident.PermitAgentForwarding = takeBool(teleport.CertExtensionPermitAgentForwarding) + ident.PermitPortForwarding = takeBool(teleport.CertExtensionPermitPortForwarding) + ident.MFAVerified = takeValue(teleport.CertExtensionMFAVerified) + + if v, ok := takeExtension(teleport.CertExtensionPreviousIdentityExpires); ok { + t, err := time.Parse(time.RFC3339, v) + if err != nil { + return nil, trace.BadParameter("failed to parse value %q for extension %q as RFC3339 timestamp: %v", v, teleport.CertExtensionPreviousIdentityExpires, err) + } + ident.PreviousIdentityExpires = t + } + + ident.LoginIP = takeValue(teleport.CertExtensionLoginIP) + ident.Impersonator = takeValue(teleport.CertExtensionImpersonator) + ident.DisallowReissue = takeBool(teleport.CertExtensionDisallowReissue) + ident.Renewable = takeBool(teleport.CertExtensionRenewable) + + if v, ok := takeExtension(teleport.CertExtensionGeneration); ok { + i, err := strconv.ParseUint(v, 10, 64) + if err != nil { + return nil, trace.BadParameter("failed to parse value %q for extension %q as uint64: %v", v, teleport.CertExtensionGeneration, err) + } + ident.Generation = i + } + + ident.BotName = takeValue(teleport.CertExtensionBotName) + ident.BotInstanceID = takeValue(teleport.CertExtensionBotInstanceID) + ident.AllowedResourceIDs = takeValue(teleport.CertExtensionAllowedResources) + ident.ConnectionDiagnosticID = takeValue(teleport.CertExtensionConnectionDiagnosticID) + ident.PrivateKeyPolicy = keys.PrivateKeyPolicy(takeValue(teleport.CertExtensionPrivateKeyPolicy)) + ident.DeviceID = takeValue(teleport.CertExtensionDeviceID) + ident.DeviceAssetTag = takeValue(teleport.CertExtensionDeviceAssetTag) + ident.DeviceCredentialID = takeValue(teleport.CertExtensionDeviceCredentialID) + ident.GitHubUserID = takeValue(teleport.CertExtensionGitHubUserID) + ident.GitHubUsername = takeValue(teleport.CertExtensionGitHubUsername) + + if v, ok := cert.CriticalOptions[teleport.CertCriticalOptionSourceAddress]; ok { + parts := strings.Split(v, "/") + if len(parts) != 2 { + return nil, trace.BadParameter("failed to parse value %q for critical option %q as CIDR", v, teleport.CertCriticalOptionSourceAddress) + } + ident.PinnedIP = parts[0] + } + + if v, ok := takeExtension(teleport.CertExtensionTeleportTraits); ok { + var traits wrappers.Traits + if err := wrappers.UnmarshalTraits([]byte(v), &traits); err != nil { + return nil, trace.BadParameter("failed to unmarshal value %q for extension %q as traits: %v", v, teleport.CertExtensionTeleportTraits, err) + } + ident.Traits = traits + } + + if v, ok := takeExtension(teleport.CertExtensionTeleportRoles); ok { + roles, err := services.UnmarshalCertRoles(v) + if err != nil { + return nil, trace.BadParameter("failed to unmarshal value %q for extension %q as roles: %v", v, teleport.CertExtensionTeleportRoles, err) + } + ident.Roles = roles + } + + ident.RouteToCluster = takeValue(teleport.CertExtensionTeleportRouteToCluster) + + if v, ok := takeExtension(teleport.CertExtensionTeleportActiveRequests); ok { + var requests services.RequestIDs + if err := requests.Unmarshal([]byte(v)); err != nil { + return nil, trace.BadParameter("failed to unmarshal value %q for extension %q as active requests: %v", v, teleport.CertExtensionTeleportActiveRequests, err) + } + ident.ActiveRequests = requests + } + + // aggregate all remaining extensions into the CertificateExtensions field + for name, value := range extensions { + ident.CertificateExtensions = append(ident.CertificateExtensions, &types.CertExtension{ + Name: name, + Value: value, + Type: types.CertExtensionType_SSH, + Mode: types.CertExtensionMode_EXTENSION, + }) + } + + return ident, nil +} diff --git a/lib/sshca/identity_test.go b/lib/sshca/identity_test.go new file mode 100644 index 0000000000000..5c7c6db75b3e8 --- /dev/null +++ b/lib/sshca/identity_test.go @@ -0,0 +1,97 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +// Package sshca specifies interfaces for SSH certificate authorities +package sshca + +import ( + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/wrappers" + "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils/testutils" +) + +func TestIdentityConversion(t *testing.T) { + ident := &Identity{ + ValidAfter: 1, + ValidBefore: 2, + Username: "user", + Impersonator: "impersonator", + AllowedLogins: []string{"login1", "login2"}, + PermitX11Forwarding: true, + PermitAgentForwarding: true, + PermitPortForwarding: true, + Roles: []string{"role1", "role2"}, + RouteToCluster: "cluster", + Traits: wrappers.Traits{"trait1": []string{"value1"}, "trait2": []string{"value2"}}, + ActiveRequests: services.RequestIDs{ + AccessRequests: []string{uuid.NewString()}, + }, + MFAVerified: "mfa", + PreviousIdentityExpires: time.Unix(12345, 0), + LoginIP: "127.0.0.1", + PinnedIP: "127.0.0.1", + DisallowReissue: true, + CertificateExtensions: []*types.CertExtension{&types.CertExtension{ + Name: "extname", + Value: "extvalue", + Type: types.CertExtensionType_SSH, + Mode: types.CertExtensionMode_EXTENSION, + }}, + Renewable: true, + Generation: 3, + BotName: "bot", + BotInstanceID: "instance", + AllowedResourceIDs: "resource", + ConnectionDiagnosticID: "diag", + PrivateKeyPolicy: keys.PrivateKeyPolicy("policy"), + DeviceID: "device", + DeviceAssetTag: "asset", + DeviceCredentialID: "cred", + GitHubUserID: "github", + GitHubUsername: "ghuser", + } + + ignores := []string{ + "CertExtension.Type", // only currently defined enum variant is a zero value + "CertExtension.Mode", // only currently defined enum variant is a zero value + // TODO(fspmarshall): figure out a mechanism for making ignore of grpc fields more convenient + "CertExtension.XXX_NoUnkeyedLiteral", + "CertExtension.XXX_unrecognized", + "CertExtension.XXX_sizecache", + } + + require.True(t, testutils.ExhaustiveNonEmpty(ident, ignores...), "empty=%+v", testutils.FindAllEmpty(ident, ignores...)) + + cert, err := ident.Encode(constants.CertificateFormatStandard) + require.NoError(t, err) + + ident2, err := DecodeIdentity(cert) + require.NoError(t, err) + + require.Empty(t, cmp.Diff(ident, ident2)) +} diff --git a/lib/sshca/sshca.go b/lib/sshca/sshca.go index 5e9e3f548f853..15f5dcf6c1aeb 100644 --- a/lib/sshca/sshca.go +++ b/lib/sshca/sshca.go @@ -20,6 +20,12 @@ package sshca import ( + "time" + + "github.com/gravitational/trace" + "golang.org/x/crypto/ssh" + + apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/lib/services" ) @@ -33,5 +39,34 @@ type Authority interface { // GenerateUserCert generates user ssh certificate, it takes pkey as a signing // private key (user certificate authority) - GenerateUserCert(certParams services.UserCertParams) ([]byte, error) + GenerateUserCert(UserCertificateRequest) ([]byte, error) +} + +// UserCertificateRequest is a request to generate a new ssh user certificate. +type UserCertificateRequest struct { + // CASigner is the signer that will sign the public key of the user with the CA private key + CASigner ssh.Signer + // PublicUserKey is the public key of the user in SSH authorized_keys format. + PublicUserKey []byte + // TTL defines how long a certificate is valid for (if specified, ValidAfter/ValidBefore within the + // identity must not be set). + TTL time.Duration + // CertificateFormat is the format of the SSH certificate. + CertificateFormat string + // Identity is the user identity to be encoded in the certificate. + Identity Identity +} + +func (r *UserCertificateRequest) CheckAndSetDefaults() error { + if r.CASigner == nil { + return trace.BadParameter("ssh user certificate request missing ca signer") + } + if r.TTL < apidefaults.MinCertDuration { + r.TTL = apidefaults.MinCertDuration + } + if err := r.Identity.Check(); err != nil { + return trace.Wrap(err) + } + + return nil } From 4ee850ee088722ae9db80b5089f8ea74dbcdab86 Mon Sep 17 00:00:00 2001 From: Lisa Kim Date: Thu, 9 Jan 2025 15:14:45 -0800 Subject: [PATCH 2/8] Pass join token suggestedLabels to app server labels during install.sh (#50720) * Allow adding app server labels from join token for install.sh * Address CRs * Reduce label yaml space, improve test --- lib/web/join_tokens.go | 8 ++++++++ lib/web/join_tokens_test.go | 11 +++++++++++ lib/web/scripts/node-join/install.sh | 8 ++++++++ 3 files changed, 27 insertions(+) diff --git a/lib/web/join_tokens.go b/lib/web/join_tokens.go index df9896f5e1532..d54269df7c381 100644 --- a/lib/web/join_tokens.go +++ b/lib/web/join_tokens.go @@ -631,6 +631,7 @@ func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter } var buf bytes.Buffer + var appServerResourceLabels []string // If app install mode is requested but parameters are blank for some reason, // we need to return an error. if settings.appInstallMode { @@ -640,6 +641,12 @@ func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter if !appURIPattern.MatchString(settings.appURI) { return "", trace.BadParameter("appURI %q contains invalid characters", settings.appURI) } + + suggestedLabels := token.GetSuggestedLabels() + appServerResourceLabels, err = scripts.MarshalLabelsYAML(suggestedLabels, 4) + if err != nil { + return "", trace.Wrap(err) + } } if settings.discoveryInstallMode { @@ -689,6 +696,7 @@ func getJoinScript(ctx context.Context, settings scriptSettings, m nodeAPIGetter "installUpdater": strconv.FormatBool(settings.installUpdater), "version": shsprintf.EscapeDefaultContext(version), "appInstallMode": strconv.FormatBool(settings.appInstallMode), + "appServerResourceLabels": appServerResourceLabels, "appName": shsprintf.EscapeDefaultContext(settings.appName), "appURI": shsprintf.EscapeDefaultContext(settings.appURI), "joinMethod": shsprintf.EscapeDefaultContext(settings.joinMethod), diff --git a/lib/web/join_tokens_test.go b/lib/web/join_tokens_test.go index ba0b0be4ff9b1..4e0062b333ef3 100644 --- a/lib/web/join_tokens_test.go +++ b/lib/web/join_tokens_test.go @@ -761,6 +761,17 @@ func TestGetNodeJoinScript(t *testing.T) { require.Contains(t, script, fmt.Sprintf("%s=%s", types.InternalResourceIDLabel, internalResourceID)) }, }, + { + desc: "app server labels", + settings: scriptSettings{token: validToken, appInstallMode: true, appName: "app-name", appURI: "app-uri"}, + errAssert: require.NoError, + extraAssertions: func(script string) { + require.Contains(t, script, `APP_NAME='app-name'`) + require.Contains(t, script, `APP_URI='app-uri'`) + require.Contains(t, script, `public_addr`) + require.Contains(t, script, fmt.Sprintf(" labels:\n %s: %s", types.InternalResourceIDLabel, internalResourceID)) + }, + }, } { t.Run(test.desc, func(t *testing.T) { script, err := getJoinScript(context.Background(), test.settings, m) diff --git a/lib/web/scripts/node-join/install.sh b/lib/web/scripts/node-join/install.sh index 3d8403c00787d..64c7cc6b6aab2 100755 --- a/lib/web/scripts/node-join/install.sh +++ b/lib/web/scripts/node-join/install.sh @@ -441,6 +441,11 @@ get_yaml_list() { install_teleport_app_config() { log "Writing Teleport app service config to ${TELEPORT_CONFIG_PATH}" CA_PINS_CONFIG=$(get_yaml_list "ca_pin" "${CA_PIN_HASHES}" " ") + # This file is processed by `shellschek` as part of the lint step + # It detects an issue because of un-set variables - $index and $line. This check is called SC2154. + # However, that's not an issue, because those variables are replaced when we run go's text/template engine over it. + # When executing the script, those are no long variables but actual values. + # shellcheck disable=SC2154 cat << EOF > ${TELEPORT_CONFIG_PATH} version: v3 teleport: @@ -463,6 +468,9 @@ app_service: - name: "${APP_NAME}" uri: "${APP_URI}" public_addr: ${APP_PUBLIC_ADDR} + labels:{{range $index, $line := .appServerResourceLabels}} + {{$line -}} +{{end}} EOF } # installs the provided teleport config (for database service) From 66f82c2f31bfb2ee10075a182f95ac61deea7919 Mon Sep 17 00:00:00 2001 From: Marco Dinis Date: Fri, 10 Jan 2025 11:38:23 +0000 Subject: [PATCH 3/8] Fix UserTask Status not being updated (#50855) * Fix UserTask Status not being updated The Status field for UserTasks was not being correctly updated when the Spec.State was not changed. * copy the status field * use admin client instead of backend directly --- lib/auth/usertasks/usertasksv1/service.go | 25 +++++++++----- .../usertasks/usertasksv1/service_test.go | 33 ++++++++++++++++++- lib/web/usertasks_test.go | 5 ++- 3 files changed, 52 insertions(+), 11 deletions(-) diff --git a/lib/auth/usertasks/usertasksv1/service.go b/lib/auth/usertasks/usertasksv1/service.go index a383e55a70135..74223f258369c 100644 --- a/lib/auth/usertasks/usertasksv1/service.go +++ b/lib/auth/usertasks/usertasksv1/service.go @@ -19,6 +19,7 @@ package usertasksv1 import ( + "cmp" "context" "log/slog" "time" @@ -131,7 +132,7 @@ func (s *Service) CreateUserTask(ctx context.Context, req *usertasksv1.CreateUse return nil, trace.Wrap(err) } - s.updateStatus(req.UserTask) + s.updateStatus(req.UserTask, nil /* existing user task */) rsp, err := s.backend.CreateUserTask(ctx, req.UserTask) s.emitCreateAuditEvent(ctx, rsp, authCtx, err) @@ -264,10 +265,7 @@ func (s *Service) UpdateUserTask(ctx context.Context, req *usertasksv1.UpdateUse } stateChanged := existingUserTask.GetSpec().GetState() != req.GetUserTask().GetSpec().GetState() - - if stateChanged { - s.updateStatus(req.UserTask) - } + s.updateStatus(req.UserTask, existingUserTask) rsp, err := s.backend.UpdateUserTask(ctx, req.UserTask) s.emitUpdateAuditEvent(ctx, existingUserTask, req.GetUserTask(), authCtx, err) @@ -333,9 +331,7 @@ func (s *Service) UpsertUserTask(ctx context.Context, req *usertasksv1.UpsertUse stateChanged = existingUserTask.GetSpec().GetState() != req.GetUserTask().GetSpec().GetState() } - if stateChanged { - s.updateStatus(req.UserTask) - } + s.updateStatus(req.UserTask, existingUserTask) rsp, err := s.backend.UpsertUserTask(ctx, req.UserTask) s.emitUpsertAuditEvent(ctx, existingUserTask, req.GetUserTask(), authCtx, err) @@ -350,10 +346,21 @@ func (s *Service) UpsertUserTask(ctx context.Context, req *usertasksv1.UpsertUse return rsp, nil } -func (s *Service) updateStatus(ut *usertasksv1.UserTask) { +func (s *Service) updateStatus(ut *usertasksv1.UserTask, existing *usertasksv1.UserTask) { + // Default status for UserTask. ut.Status = &usertasksv1.UserTaskStatus{ LastStateChange: timestamppb.New(s.clock.Now()), } + + if existing != nil { + // Inherit everything from existing UserTask. + ut.Status.LastStateChange = cmp.Or(existing.GetStatus().GetLastStateChange(), ut.Status.LastStateChange) + + // Update specific values. + if existing.GetSpec().GetState() != ut.GetSpec().GetState() { + ut.Status.LastStateChange = timestamppb.New(s.clock.Now()) + } + } } func (s *Service) emitUpsertAuditEvent(ctx context.Context, old, new *usertasksv1.UserTask, authCtx *authz.Context, err error) { diff --git a/lib/auth/usertasks/usertasksv1/service_test.go b/lib/auth/usertasks/usertasksv1/service_test.go index d40b3740af591..1a909c278bdd8 100644 --- a/lib/auth/usertasks/usertasksv1/service_test.go +++ b/lib/auth/usertasks/usertasksv1/service_test.go @@ -153,6 +153,7 @@ func TestEvents(t *testing.T) { // LastStateChange is updated. require.Equal(t, timestamppb.New(fakeClock.Now()), createUserTaskResp.Status.LastStateChange) + expectedLastStateChange := createUserTaskResp.Status.LastStateChange ut1.Spec.DiscoverEc2.Instances["i-345"] = &usertasksv1.DiscoverEC2Instance{ InstanceId: "i-345", DiscoveryConfig: "dc01", @@ -165,7 +166,7 @@ func TestEvents(t *testing.T) { require.Len(t, testReporter.emittedEvents, 1) consumeAssertEvent(t, auditEventsSink.C(), auditEventFor(userTaskName, "update", "OPEN", "OPEN")) // LastStateChange is not updated. - require.Equal(t, createUserTaskResp.Status.LastStateChange, upsertUserTaskResp.Status.LastStateChange) + require.Equal(t, expectedLastStateChange.AsTime(), upsertUserTaskResp.Status.LastStateChange.AsTime()) ut1.Spec.State = "RESOLVED" fakeClock.Advance(1 * time.Minute) @@ -177,6 +178,36 @@ func TestEvents(t *testing.T) { // LastStateChange was updated because the state changed. require.Equal(t, timestamppb.New(fakeClock.Now()), updateUserTaskResp.Status.LastStateChange) + // Updating one of the instances. + expectedLastStateChange = updateUserTaskResp.Status.GetLastStateChange() + fakeClock.Advance(1 * time.Minute) + ut1.Spec.DiscoverEc2.Instances["i-345"] = &usertasksv1.DiscoverEC2Instance{ + InstanceId: "i-345", + DiscoveryConfig: "dc01", + DiscoveryGroup: "dg01", + SyncTime: timestamppb.New(fakeClock.Now()), + } + updateUserTaskResp, err = service.UpdateUserTask(ctx, &usertasksv1.UpdateUserTaskRequest{UserTask: ut1}) + require.NoError(t, err) + // Does not change the LastStateChange + require.Equal(t, expectedLastStateChange.AsTime(), updateUserTaskResp.Status.LastStateChange.AsTime()) + consumeAssertEvent(t, auditEventsSink.C(), auditEventFor(userTaskName, "update", "RESOLVED", "RESOLVED")) + + // Upserting one of the instances. + expectedLastStateChange = updateUserTaskResp.Status.GetLastStateChange() + fakeClock.Advance(1 * time.Minute) + ut1.Spec.DiscoverEc2.Instances["i-345"] = &usertasksv1.DiscoverEC2Instance{ + InstanceId: "i-345", + DiscoveryConfig: "dc01", + DiscoveryGroup: "dg01", + SyncTime: timestamppb.New(fakeClock.Now()), + } + upsertUserTaskResp, err = service.UpsertUserTask(ctx, &usertasksv1.UpsertUserTaskRequest{UserTask: ut1}) + require.NoError(t, err) + // Does not change the LastStateChange + require.Equal(t, expectedLastStateChange.AsTime(), upsertUserTaskResp.Status.LastStateChange.AsTime()) + consumeAssertEvent(t, auditEventsSink.C(), auditEventFor(userTaskName, "update", "RESOLVED", "RESOLVED")) + _, err = service.DeleteUserTask(ctx, &usertasksv1.DeleteUserTaskRequest{Name: userTaskName}) require.NoError(t, err) // No usage report for deleted resources. diff --git a/lib/web/usertasks_test.go b/lib/web/usertasks_test.go index 0bb2dbb9a9f9a..13e9723458090 100644 --- a/lib/web/usertasks_test.go +++ b/lib/web/usertasks_test.go @@ -31,6 +31,7 @@ import ( usertasksv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/usertasks/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/usertasks" + "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/web/ui" ) @@ -53,6 +54,8 @@ func TestUserTask(t *testing.T) { }) require.NoError(t, err) pack := env.proxies[0].authPack(t, userWithRW, []types.Role{roleRWUserTask}) + adminClient, err := env.server.NewClient(auth.TestAdmin()) + require.NoError(t, err) getAllEndpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "usertask") singleItemEndpoint := func(name string) string { @@ -90,7 +93,7 @@ func TestUserTask(t *testing.T) { }) require.NoError(t, err) - _, err = env.proxies[0].auth.Auth().CreateUserTask(ctx, userTask) + _, err = adminClient.UserTasksServiceClient().CreateUserTask(ctx, userTask) require.NoError(t, err) userTaskForTest = userTask } From 4a10f05a96d33001bb54e90a38e15905957b0117 Mon Sep 17 00:00:00 2001 From: Noah Stride Date: Fri, 10 Jan 2025 13:25:02 +0000 Subject: [PATCH 4/8] Update protoc-gen-terraform to v3.0.2 (#50943) --- .github/workflows/lint.yaml | 2 +- integrations/terraform/Makefile | 2 +- integrations/terraform/README.md | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 0cb29695f968b..ff6cdd143e7e1 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -235,7 +235,7 @@ jobs: - name: Check if Terraform resources are up to date # We have to add the current directory as a safe directory or else git commands will not work as expected. # The protoc-gen-terraform version must match the version in integrations/terraform/Makefile - run: git config --global --add safe.directory $(realpath .) && go install github.com/gravitational/protoc-gen-terraform@c91cc3ef4d7d0046c36cb96b1cd337e466c61225 && make terraform-resources-up-to-date + run: git config --global --add safe.directory $(realpath .) && go install github.com/gravitational/protoc-gen-terraform/v3@v3.0.2 && make terraform-resources-up-to-date lint-rfd: name: Lint (RFD) diff --git a/integrations/terraform/Makefile b/integrations/terraform/Makefile index 572a07d4d45dc..149aef0ed5b4b 100644 --- a/integrations/terraform/Makefile +++ b/integrations/terraform/Makefile @@ -47,7 +47,7 @@ $(BUILDDIR)/terraform-provider-teleport_%: terraform-provider-teleport-v$(VERSIO CUSTOM_IMPORTS_TMP_DIR ?= /tmp/protoc-gen-terraform/custom-imports # This version must match the version installed by .github/workflows/lint.yaml -PROTOC_GEN_TERRAFORM_VERSION ?= v3.0.0 +PROTOC_GEN_TERRAFORM_VERSION ?= v3.0.2 PROTOC_GEN_TERRAFORM_EXISTS := $(shell $(PROTOC_GEN_TERRAFORM) version 2>&1 >/dev/null | grep 'protoc-gen-terraform $(PROTOC_GEN_TERRAFORM_VERSION)') .PHONY: gen-tfschema diff --git a/integrations/terraform/README.md b/integrations/terraform/README.md index 53e752f725d41..dde74bc7b793b 100644 --- a/integrations/terraform/README.md +++ b/integrations/terraform/README.md @@ -7,9 +7,9 @@ Please, refer to [official documentation](https://goteleport.com/docs/admin-guid ## Development 1. Install [`protobuf`](https://grpc.io/docs/protoc-installation/). -2. Install [`protoc-gen-terraform`](https://github.com/gravitational/protoc-gen-terraform) @v3.0.0. +2. Install [`protoc-gen-terraform`](https://github.com/gravitational/protoc-gen-terraform) @v3.0.2. - ```go install github.com/gravitational/protoc-gen-terraform@c91cc3ef4d7d0046c36cb96b1cd337e466c61225``` + ```go install github.com/gravitational/protoc-gen-terraform/v3@v3.0.2``` 3. Install [`Terraform`](https://learn.hashicorp.com/tutorials/terraform/install-cli) v1.1.0+. Alternatively, you can use [`tfenv`](https://github.com/tfutils/tfenv). Please note that on Mac M1 you need to specify `TFENV_ARCH` (ex: `TFENV_ARCH=arm64 tfenv install 1.1.6`). From 5b5bab980ea6c6fb21b8b552045f2a5612b8f867 Mon Sep 17 00:00:00 2001 From: Hugo Shaka Date: Fri, 10 Jan 2025 10:43:54 -0500 Subject: [PATCH 5/8] Use a non-global metrics registry in Teleport (#50913) * Support a non-global registry in Teleport * lint * Update lib/service/service.go Co-authored-by: rosstimothy <39066650+rosstimothy@users.noreply.github.com> --------- Co-authored-by: rosstimothy <39066650+rosstimothy@users.noreply.github.com> --- lib/service/service.go | 48 +++++++++++++++++- lib/service/service_test.go | 86 +++++++++++++++++++++++++++++++- lib/service/servicecfg/config.go | 11 ++++ 3 files changed, 142 insertions(+), 3 deletions(-) diff --git a/lib/service/service.go b/lib/service/service.go index 7638ee5e85caf..7fd997e7234f0 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -54,6 +54,7 @@ import ( "github.com/gravitational/roundtrip" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/quic-go/quic-go" "go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc" @@ -657,6 +658,15 @@ type TeleportProcess struct { // resolver is used to identify the reverse tunnel address when connecting via // the proxy. resolver reversetunnelclient.Resolver + + // metricRegistry is the prometheus metric registry for the process. + // Every teleport service that wants to register metrics should use this + // instead of the global prometheus.DefaultRegisterer to avoid registration + // conflicts. + // + // Both the metricsRegistry and the default global registry are gathered by + // Telepeort's metric service. + metricsRegistry *prometheus.Registry } // processIndex is an internal process index @@ -1179,6 +1189,7 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { logger: cfg.Logger, cloudLabels: cloudLabels, TracingProvider: tracing.NoopProvider(), + metricsRegistry: cfg.MetricsRegistry, } process.registerExpectedServices(cfg) @@ -3405,11 +3416,46 @@ func (process *TeleportProcess) initUploaderService() error { return nil } +// promHTTPLogAdapter adapts a slog.Logger into a promhttp.Logger. +type promHTTPLogAdapter struct { + ctx context.Context + *slog.Logger +} + +// Println implements the promhttp.Logger interface. +func (l promHTTPLogAdapter) Println(v ...interface{}) { + //nolint:sloglint // msg cannot be constant + l.ErrorContext(l.ctx, fmt.Sprint(v...)) +} + // initMetricsService starts the metrics service currently serving metrics for // prometheus consumption func (process *TeleportProcess) initMetricsService() error { mux := http.NewServeMux() - mux.Handle("/metrics", promhttp.Handler()) + + // We gather metrics both from the in-process registry (preferred metrics registration method) + // and the global registry (used by some Teleport services and many dependencies). + gatherers := prometheus.Gatherers{ + process.metricsRegistry, + prometheus.DefaultGatherer, + } + + metricsHandler := promhttp.InstrumentMetricHandler( + process.metricsRegistry, promhttp.HandlerFor(gatherers, promhttp.HandlerOpts{ + // Errors can happen if metrics are registered with identical names in both the local and the global registry. + // In this case, we log the error but continue collecting metrics. The first collected metric will win + // (the one from the local metrics registry takes precedence). + // As we move more things to the local registry, especially in other tools like tbot, we will have less + // conflicts in tests. + ErrorHandling: promhttp.ContinueOnError, + ErrorLog: promHTTPLogAdapter{ + ctx: process.ExitContext(), + Logger: process.logger.With(teleport.ComponentKey, teleport.ComponentMetrics), + }, + }), + ) + + mux.Handle("/metrics", metricsHandler) logger := process.logger.With(teleport.ComponentKey, teleport.Component(teleport.ComponentMetrics, process.id)) diff --git a/lib/service/service_test.go b/lib/service/service_test.go index 52e59387ff580..4c08a87689145 100644 --- a/lib/service/service_test.go +++ b/lib/service/service_test.go @@ -23,9 +23,11 @@ import ( "crypto/tls" "errors" "fmt" + "io" "log/slog" "net" "net/http" + "net/url" "os" "path/filepath" "strings" @@ -39,6 +41,8 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -1887,7 +1891,7 @@ func TestAgentRolloutController(t *testing.T) { dataDir := makeTempDir(t) cfg := servicecfg.MakeDefaultConfig() - // We use a real clock because too many sevrices are using the clock and it's not possible to accurately wait for + // We use a real clock because too many services are using the clock and it's not possible to accurately wait for // each one of them to reach the point where they wait for the clock to advance. If we add a WaitUntil(X waiters) // check, this will break the next time we add a new waiter. cfg.Clock = clockwork.NewRealClock() @@ -1906,7 +1910,7 @@ func TestAgentRolloutController(t *testing.T) { process, err := NewTeleport(cfg) require.NoError(t, err) - // Test setup: start the Teleport auth and wait for it to beocme ready + // Test setup: start the Teleport auth and wait for it to become ready require.NoError(t, process.Start()) // Test setup: wait for every service to start @@ -1949,6 +1953,84 @@ func TestAgentRolloutController(t *testing.T) { }, 5*time.Second, 10*time.Millisecond) } +func TestMetricsService(t *testing.T) { + t.Parallel() + // Test setup: create a listener for the metrics server, get its file descriptor. + + // Note: this code is copied from integrations/helpers/NewListenerOn() to avoid including helpers in a production + // build and avoid a cyclic dependency. + metricsListener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, metricsListener.Close()) + }) + require.IsType(t, &net.TCPListener{}, metricsListener) + metricsListenerFile, err := metricsListener.(*net.TCPListener).File() + require.NoError(t, err) + + // Test setup: create a new teleport process + dataDir := makeTempDir(t) + cfg := servicecfg.MakeDefaultConfig() + cfg.DataDir = dataDir + cfg.SetAuthServerAddress(utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}) + cfg.Auth.Enabled = true + cfg.Proxy.Enabled = false + cfg.SSH.Enabled = false + cfg.DebugService.Enabled = false + cfg.Auth.StorageConfig.Params["path"] = dataDir + cfg.Auth.ListenAddr = utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"} + cfg.Metrics.Enabled = true + + // Configure the metrics server to use the listener we previously created. + cfg.Metrics.ListenAddr = &utils.NetAddr{AddrNetwork: "tcp", Addr: metricsListener.Addr().String()} + cfg.FileDescriptors = []*servicecfg.FileDescriptor{ + {Type: string(ListenerMetrics), Address: metricsListener.Addr().String(), File: metricsListenerFile}, + } + + // Create and start the Teleport service. + process, err := NewTeleport(cfg) + require.NoError(t, err) + require.NoError(t, process.Start()) + t.Cleanup(func() { + assert.NoError(t, process.Close()) + assert.NoError(t, process.Wait()) + }) + + // Test setup: create our test metrics. + nonce := strings.ReplaceAll(uuid.NewString(), "-", "") + localMetric := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "test", + Name: "local_metric_" + nonce, + }) + globalMetric := prometheus.NewGauge(prometheus.GaugeOpts{ + Namespace: "test", + Name: "global_metric_" + nonce, + }) + require.NoError(t, process.metricsRegistry.Register(localMetric)) + require.NoError(t, prometheus.Register(globalMetric)) + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + t.Cleanup(cancel) + _, err = process.WaitForEvent(ctx, MetricsReady) + require.NoError(t, err) + + // Test execution: get metrics and check the tests metrics are here. + metricsURL, err := url.Parse("http://" + metricsListener.Addr().String()) + require.NoError(t, err) + metricsURL.Path = "/metrics" + resp, err := http.Get(metricsURL.String()) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + // Test validation: check that the metrics server served both the local and global registry. + require.Contains(t, string(body), "local_metric_"+nonce) + require.Contains(t, string(body), "global_metric_"+nonce) +} + // makeTempDir makes a temp dir with a shorter name than t.TempDir() in order to // avoid https://github.com/golang/go/issues/62614. func makeTempDir(t *testing.T) string { diff --git a/lib/service/servicecfg/config.go b/lib/service/servicecfg/config.go index a89e79a8f6302..a89e29a2c7b54 100644 --- a/lib/service/servicecfg/config.go +++ b/lib/service/servicecfg/config.go @@ -34,6 +34,7 @@ import ( "github.com/ghodss/yaml" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/prometheus/client_golang/prometheus" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" @@ -264,6 +265,12 @@ type Config struct { // protocol. DatabaseREPLRegistry dbrepl.REPLRegistry + // MetricsRegistry is the prometheus metrics registry used by the Teleport process to register its metrics. + // As of today, not every Teleport metric is registered against this registry. Some Teleport services + // and Teleport dependencies are using the global registry. + // Both the MetricsRegistry and the default global registry are gathered by Teleport's metric service. + MetricsRegistry *prometheus.Registry + // token is either the token needed to join the auth server, or a path pointing to a file // that contains the token // @@ -520,6 +527,10 @@ func ApplyDefaults(cfg *Config) { cfg.LoggerLevel = new(slog.LevelVar) } + if cfg.MetricsRegistry == nil { + cfg.MetricsRegistry = prometheus.NewRegistry() + } + // Remove insecure and (borderline insecure) cryptographic primitives from // default configuration. These can still be added back in file configuration by // users, but not supported by default by Teleport. See #1856 for more From cca83feb66e777965f64162eb3f7bd7db34e4c3a Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Fri, 10 Jan 2025 11:08:42 -0500 Subject: [PATCH 6/8] Convert integrations to use slog (#50921) --- integrations/access/accesslist/app.go | 49 +++-- .../access_monitoring_rules.go | 14 +- integrations/access/accessrequest/app.go | 76 ++++--- integrations/access/common/app.go | 10 +- .../access/common/auth/token_provider.go | 23 ++- .../access/common/auth/token_provider_test.go | 7 +- integrations/access/datadog/bot.go | 2 +- integrations/access/datadog/client.go | 2 +- .../datadog/cmd/teleport-datadog/main.go | 13 +- .../access/datadog/testlib/fake_datadog.go | 3 +- integrations/access/discord/bot.go | 3 +- .../discord/cmd/teleport-discord/main.go | 13 +- .../access/discord/testlib/fake_discord.go | 3 +- integrations/access/email/app.go | 70 ++++--- integrations/access/email/client.go | 12 +- .../access/email/cmd/teleport-email/main.go | 15 +- integrations/access/email/mailers.go | 6 +- .../access/email/testlib/mock_mailgun.go | 4 +- integrations/access/jira/app.go | 109 +++++----- integrations/access/jira/client.go | 15 +- .../access/jira/cmd/teleport-jira/main.go | 13 +- integrations/access/jira/testlib/fake_jira.go | 3 +- integrations/access/jira/testlib/suite.go | 2 +- integrations/access/jira/webhook_server.go | 14 +- integrations/access/mattermost/bot.go | 20 +- .../cmd/teleport-mattermost/main.go | 13 +- .../mattermost/testlib/fake_mattermost.go | 3 +- integrations/access/msteams/app.go | 7 +- integrations/access/msteams/bot.go | 3 +- .../msteams/cmd/teleport-msteams/main.go | 3 +- .../access/msteams/testlib/fake_msteams.go | 3 +- integrations/access/msteams/uninstall.go | 17 +- integrations/access/msteams/validate.go | 7 +- integrations/access/opsgenie/app.go | 64 +++--- integrations/access/opsgenie/client.go | 10 +- .../access/opsgenie/testlib/fake_opsgenie.go | 3 +- integrations/access/pagerduty/app.go | 83 ++++---- integrations/access/pagerduty/client.go | 12 +- .../pagerduty/cmd/teleport-pagerduty/main.go | 13 +- .../pagerduty/testlib/fake_pagerduty.go | 3 +- integrations/access/servicenow/app.go | 78 ++++--- integrations/access/servicenow/client.go | 5 +- .../servicenow/testlib/fake_servicenow.go | 3 +- integrations/access/slack/bot.go | 6 +- .../access/slack/cmd/teleport-slack/main.go | 13 +- .../access/slack/testlib/fake_slack.go | 3 +- .../event-handler/fake_fluentd_test.go | 3 - integrations/event-handler/main.go | 25 +-- integrations/lib/bail.go | 8 +- integrations/lib/config.go | 10 +- integrations/lib/embeddedtbot/bot.go | 11 +- integrations/lib/http.go | 14 +- integrations/lib/logger/logger.go | 190 +++--------------- integrations/lib/signals.go | 7 +- integrations/lib/tctl/tctl.go | 24 ++- integrations/lib/testing/integration/suite.go | 4 +- integrations/lib/watcherjob/watcherjob.go | 16 +- .../crdgen/cmd/protoc-gen-crd-docs/debug.go | 27 ++- .../crdgen/cmd/protoc-gen-crd-docs/main.go | 14 +- .../crdgen/cmd/protoc-gen-crd/debug.go | 27 ++- .../crdgen/cmd/protoc-gen-crd/main.go | 14 +- integrations/terraform/go.mod | 2 +- integrations/terraform/provider/errors.go | 6 +- integrations/terraform/provider/provider.go | 43 ++-- lib/utils/log/log.go | 3 + 65 files changed, 651 insertions(+), 650 deletions(-) diff --git a/integrations/access/accesslist/app.go b/integrations/access/accesslist/app.go index 02f933baf5ecd..ba40de3abf575 100644 --- a/integrations/access/accesslist/app.go +++ b/integrations/access/accesslist/app.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/teleport/integrations/lib" "github.com/gravitational/teleport/integrations/lib/logger" pd "github.com/gravitational/teleport/integrations/lib/plugindata" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -118,7 +119,7 @@ func (a *App) run(ctx context.Context) error { log := logger.Get(ctx) - log.Info("Access list monitor is running") + log.InfoContext(ctx, "Access list monitor is running") a.job.SetReady(true) @@ -134,7 +135,7 @@ func (a *App) run(ctx context.Context) error { } timer.Reset(jitter(reminderInterval)) case <-ctx.Done(): - log.Info("Access list monitor is finished") + log.InfoContext(ctx, "Access list monitor is finished") return nil } } @@ -146,7 +147,7 @@ func (a *App) run(ctx context.Context) error { func (a *App) remindIfNecessary(ctx context.Context) error { log := logger.Get(ctx) - log.Info("Looking for Access List Review reminders") + log.InfoContext(ctx, "Looking for Access List Review reminders") var nextToken string var err error @@ -156,13 +157,14 @@ func (a *App) remindIfNecessary(ctx context.Context) error { accessLists, nextToken, err = a.apiClient.ListAccessLists(ctx, 0 /* default page size */, nextToken) if err != nil { if trace.IsNotImplemented(err) { - log.Errorf("access list endpoint is not implemented on this auth server, so the access list app is ceasing to run.") + log.ErrorContext(ctx, "access list endpoint is not implemented on this auth server, so the access list app is ceasing to run") return trace.Wrap(err) } else if trace.IsAccessDenied(err) { - log.Warnf("Slack bot does not have permissions to list access lists. Please add access_list read and list permissions " + - "to the role associated with the Slack bot.") + const msg = "Slack bot does not have permissions to list access lists. Please add access_list read and list permissions " + + "to the role associated with the Slack bot." + log.WarnContext(ctx, msg) } else { - log.Errorf("error listing access lists: %v", err) + log.ErrorContext(ctx, "error listing access lists", "error", err) } break } @@ -170,7 +172,10 @@ func (a *App) remindIfNecessary(ctx context.Context) error { for _, accessList := range accessLists { recipients, err := a.getRecipientsRequiringReminders(ctx, accessList) if err != nil { - log.WithError(err).Warnf("Error getting recipients to notify for review due for access list %q", accessList.Spec.Title) + log.WarnContext(ctx, "Error getting recipients to notify for review due for access list", + "error", err, + "access_list", accessList.Spec.Title, + ) continue } @@ -195,7 +200,7 @@ func (a *App) remindIfNecessary(ctx context.Context) error { } if len(errs) > 0 { - log.WithError(trace.NewAggregate(errs...)).Warn("Error notifying for access list reviews") + log.WarnContext(ctx, "Error notifying for access list reviews", "error", trace.NewAggregate(errs...)) } return nil @@ -213,7 +218,10 @@ func (a *App) getRecipientsRequiringReminders(ctx context.Context, accessList *a // If the current time before the notification start time, skip notifications. if now.Before(notificationStart) { - log.Debugf("Access list %s is not ready for notifications, notifications start at %s", accessList.GetName(), notificationStart.Format(time.RFC3339)) + log.DebugContext(ctx, "Access list is not ready for notifications", + "access_list", accessList.GetName(), + "notification_start_time", notificationStart.Format(time.RFC3339), + ) return nil, nil } @@ -255,12 +263,17 @@ func (a *App) fetchRecipients(ctx context.Context, accessList *accesslist.Access if err != nil { // TODO(kiosion): Remove in v18; protecting against server not having `GetAccessListOwners` func. if trace.IsNotImplemented(err) { - log.WithError(err).Warnf("Error getting nested owners for access list '%v', continuing with only explicit owners", accessList.GetName()) + log.WarnContext(ctx, "Error getting nested owners for access list, continuing with only explicit owners", + "error", err, + "access_list", accessList.GetName(), + ) for _, owner := range accessList.Spec.Owners { allOwners = append(allOwners, &owner) } } else { - log.WithError(err).Errorf("Error getting owners for access list '%v'", accessList.GetName()) + log.ErrorContext(ctx, "Error getting owners for access list", + "error", err, + "access_list", accessList.GetName()) } } @@ -270,7 +283,7 @@ func (a *App) fetchRecipients(ctx context.Context, accessList *accesslist.Access for _, owner := range allOwners { recipient, err := a.bot.FetchRecipient(ctx, owner.Name) if err != nil { - log.Debugf("error getting recipient %s", owner.Name) + log.DebugContext(ctx, "error getting recipient", "recipient", owner.Name) continue } allRecipients[owner.Name] = *recipient @@ -293,7 +306,10 @@ func (a *App) updatePluginDataAndGetRecipientsRequiringReminders(ctx context.Con // Calculate days from start. daysFromStart := now.Sub(notificationStart) / oneDay windowStart = notificationStart.Add(daysFromStart * oneDay) - log.Infof("windowStart: %s, now: %s", windowStart.String(), now.String()) + log.InfoContext(ctx, "calculating window start", + "window_start", logutils.StringerAttr(windowStart), + "now", logutils.StringerAttr(now), + ) } recipients := []common.Recipient{} @@ -304,7 +320,10 @@ func (a *App) updatePluginDataAndGetRecipientsRequiringReminders(ctx context.Con // If the notification window is before the last notification date, then this user doesn't need a notification. if !windowStart.After(lastNotification) { - log.Debugf("User %s has already been notified for access list %s", recipient.Name, accessList.GetName()) + log.DebugContext(ctx, "User has already been notified for access list", + "user", recipient.Name, + "access_list", accessList.GetName(), + ) userNotifications[recipient.Name] = lastNotification continue } diff --git a/integrations/access/accessmonitoring/access_monitoring_rules.go b/integrations/access/accessmonitoring/access_monitoring_rules.go index 3dea9ea2bf543..82c91413bff96 100644 --- a/integrations/access/accessmonitoring/access_monitoring_rules.go +++ b/integrations/access/accessmonitoring/access_monitoring_rules.go @@ -151,8 +151,10 @@ func (amrh *RuleHandler) RecipientsFromAccessMonitoringRules(ctx context.Context for _, rule := range amrh.getAccessMonitoringRules() { match, err := MatchAccessRequest(rule.Spec.Condition, req) if err != nil { - log.WithError(err).WithField("rule", rule.Metadata.Name). - Warn("Failed to parse access monitoring notification rule") + log.WarnContext(ctx, "Failed to parse access monitoring notification rule", + "error", err, + "rule", rule.Metadata.Name, + ) } if !match { continue @@ -160,7 +162,7 @@ func (amrh *RuleHandler) RecipientsFromAccessMonitoringRules(ctx context.Context for _, recipient := range rule.Spec.Notification.Recipients { rec, err := amrh.fetchRecipientCallback(ctx, recipient) if err != nil { - log.WithError(err).Warn("Failed to fetch plugin recipients based on Access monitoring rule recipients") + log.WarnContext(ctx, "Failed to fetch plugin recipients based on Access monitoring rule recipients", "error", err) continue } recipientSet.Add(*rec) @@ -176,8 +178,10 @@ func (amrh *RuleHandler) RawRecipientsFromAccessMonitoringRules(ctx context.Cont for _, rule := range amrh.getAccessMonitoringRules() { match, err := MatchAccessRequest(rule.Spec.Condition, req) if err != nil { - log.WithError(err).WithField("rule", rule.Metadata.Name). - Warn("Failed to parse access monitoring notification rule") + log.WarnContext(ctx, "Failed to parse access monitoring notification rule", + "error", err, + "rule", rule.Metadata.Name, + ) } if !match { continue diff --git a/integrations/access/accessrequest/app.go b/integrations/access/accessrequest/app.go index 8a5effc73dabd..17182ec3dc8ee 100644 --- a/integrations/access/accessrequest/app.go +++ b/integrations/access/accessrequest/app.go @@ -21,6 +21,7 @@ package accessrequest import ( "context" "fmt" + "log/slog" "slices" "strings" "time" @@ -36,6 +37,7 @@ import ( "github.com/gravitational/teleport/integrations/lib/logger" pd "github.com/gravitational/teleport/integrations/lib/plugindata" "github.com/gravitational/teleport/integrations/lib/watcherjob" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -189,16 +191,16 @@ func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error { func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error { op := event.Type reqID := event.Resource.GetName() - ctx, _ = logger.WithField(ctx, "request_id", reqID) + ctx, _ = logger.With(ctx, "request_id", reqID) switch op { case types.OpPut: - ctx, _ = logger.WithField(ctx, "request_op", "put") + ctx, _ = logger.With(ctx, "request_op", "put") req, ok := event.Resource.(types.AccessRequest) if !ok { return trace.BadParameter("unexpected resource type %T", event.Resource) } - ctx, log := logger.WithField(ctx, "request_state", req.GetState().String()) + ctx, log := logger.With(ctx, "request_state", req.GetState().String()) var err error switch { @@ -207,21 +209,29 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error case req.GetState().IsResolved(): err = a.onResolvedRequest(ctx, req) default: - log.WithField("event", event).Warn("Unknown request state") + log.WarnContext(ctx, "Unknown request state", + slog.Group("event", + slog.Any("type", logutils.StringerAttr(event.Type)), + slog.Group("resource", + "kind", event.Resource.GetKind(), + "name", event.Resource.GetName(), + ), + ), + ) return nil } if err != nil { - log.WithError(err).Errorf("Failed to process request") + log.ErrorContext(ctx, "Failed to process request", "error", err) return trace.Wrap(err) } return nil case types.OpDelete: - ctx, log := logger.WithField(ctx, "request_op", "delete") + ctx, log := logger.With(ctx, "request_op", "delete") if err := a.onDeletedRequest(ctx, reqID); err != nil { - log.WithError(err).Errorf("Failed to process deleted request") + log.ErrorContext(ctx, "Failed to process deleted request", "error", err) return trace.Wrap(err) } return nil @@ -242,7 +252,7 @@ func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) err loginsByRole, err := a.getLoginsByRole(ctx, req) if trace.IsAccessDenied(err) { - log.Warnf("Missing permissions to get logins by role. Please add role.read to the associated role. error: %s", err) + log.WarnContext(ctx, "Missing permissions to get logins by role, please add role.read to the associated role", "error", err) } else if err != nil { return trace.Wrap(err) } @@ -265,12 +275,12 @@ func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) err return trace.Wrap(err) } } else { - log.Warning("No channel to post") + log.WarnContext(ctx, "No channel to post") } // Try to approve the request if user is currently on-call. if err := a.tryApproveRequest(ctx, reqID, req); err != nil { - log.Warningf("Failed to auto approve request: %v", err) + log.WarnContext(ctx, "Failed to auto approve request", "error", err) } case trace.IsAlreadyExists(err): // The messages were already sent, nothing to do, we can update the reviews @@ -311,7 +321,7 @@ func (a *App) onResolvedRequest(ctx context.Context, req types.AccessRequest) er case types.RequestState_PROMOTED: tag = pd.ResolvedPromoted default: - logger.Get(ctx).Warningf("Unknown state %v (%s)", state, state.String()) + logger.Get(ctx).WarnContext(ctx, "Unknown state", "state", logutils.StringerAttr(state)) return replyErr } err := trace.Wrap(a.updateMessages(ctx, req.GetName(), tag, reason, req.GetReviews())) @@ -330,13 +340,13 @@ func (a *App) broadcastAccessRequestMessages(ctx context.Context, recipients []c return trace.Wrap(err) } for _, data := range sentMessages { - logger.Get(ctx).WithFields(logger.Fields{ - "channel_id": data.ChannelID, - "message_id": data.MessageID, - }).Info("Successfully posted messages") + logger.Get(ctx).InfoContext(ctx, "Successfully posted messages", + "channel_id", data.ChannelID, + "message_id", data.MessageID, + ) } if err != nil { - logger.Get(ctx).WithError(err).Error("Failed to post one or more messages") + logger.Get(ctx).ErrorContext(ctx, "Failed to post one or more messages", "error", err) } _, err = a.pluginData.Update(ctx, reqID, func(existing PluginData) (PluginData, error) { @@ -369,7 +379,7 @@ func (a *App) postReviewReplies(ctx context.Context, reqID string, reqReviews [] return existing, nil }) if trace.IsAlreadyExists(err) { - logger.Get(ctx).Debug("Failed to post reply: replies are already sent") + logger.Get(ctx).DebugContext(ctx, "Failed to post reply: replies are already sent") return nil } if err != nil { @@ -383,7 +393,7 @@ func (a *App) postReviewReplies(ctx context.Context, reqID string, reqReviews [] errors := make([]error, 0, len(slice)) for _, data := range pd.SentMessages { - ctx, _ = logger.WithFields(ctx, logger.Fields{"channel_id": data.ChannelID, "message_id": data.MessageID}) + ctx, _ = logger.With(ctx, "channel_id", data.ChannelID, "message_id", data.MessageID) for _, review := range slice { if err := a.bot.PostReviewReply(ctx, data.ChannelID, data.MessageID, review); err != nil { errors = append(errors, err) @@ -425,7 +435,7 @@ func (a *App) getMessageRecipients(ctx context.Context, req types.AccessRequest) for _, recipient := range recipients { rec, err := a.bot.FetchRecipient(ctx, recipient) if err != nil { - log.Warningf("Failed to fetch Opsgenie recipient: %v", err) + log.WarnContext(ctx, "Failed to fetch Opsgenie recipient", "error", err) continue } recipientSet.Add(*rec) @@ -436,7 +446,7 @@ func (a *App) getMessageRecipients(ctx context.Context, req types.AccessRequest) validEmailSuggReviewers := []string{} for _, reviewer := range req.GetSuggestedReviewers() { if !lib.IsEmail(reviewer) { - log.Warningf("Failed to notify a suggested reviewer: %q does not look like a valid email", reviewer) + log.WarnContext(ctx, "Failed to notify a suggested reviewer with an invalid email address", "reviewer", reviewer) continue } @@ -446,7 +456,7 @@ func (a *App) getMessageRecipients(ctx context.Context, req types.AccessRequest) for _, rawRecipient := range rawRecipients { recipient, err := a.bot.FetchRecipient(ctx, rawRecipient) if err != nil { - log.WithError(err).Warn("Failure when fetching recipient, continuing anyway") + log.WarnContext(ctx, "Failure when fetching recipient, continuing anyway", "error", err) } else { recipientSet.Add(*recipient) } @@ -476,7 +486,7 @@ func (a *App) updateMessages(ctx context.Context, reqID string, tag pd.Resolutio return existing, nil }) if trace.IsNotFound(err) { - log.Debug("Failed to update messages: plugin data is missing") + log.DebugContext(ctx, "Failed to update messages: plugin data is missing") return nil } if trace.IsAlreadyExists(err) { @@ -485,7 +495,7 @@ func (a *App) updateMessages(ctx context.Context, reqID string, tag pd.Resolutio "cannot change the resolution tag of an already resolved request, existing: %s, event: %s", pluginData.ResolutionTag, tag) } - log.Debug("Request is already resolved, ignoring event") + log.DebugContext(ctx, "Request is already resolved, ignoring event") return nil } if err != nil { @@ -496,13 +506,17 @@ func (a *App) updateMessages(ctx context.Context, reqID string, tag pd.Resolutio if err := a.bot.UpdateMessages(ctx, reqID, reqData, sentMessages, reviews); err != nil { return trace.Wrap(err) } - log.Infof("Successfully marked request as %s in all messages", tag) + + log.InfoContext(ctx, "Marked request with resolution and sent emails!", "resolution", tag) if err := a.bot.NotifyUser(ctx, reqID, reqData); err != nil && !trace.IsNotImplemented(err) { return trace.Wrap(err) } - log.Infof("Successfully notified user %s request marked as %s", reqData.User, tag) + log.InfoContext(ctx, "Successfully notified user", + "user", reqData.User, + "resolution", tag, + ) return nil } @@ -545,13 +559,11 @@ func (a *App) getResourceNames(ctx context.Context, req types.AccessRequest) ([] // tryApproveRequest attempts to automatically approve the access request if the // user is on call for the configured service/team. func (a *App) tryApproveRequest(ctx context.Context, reqID string, req types.AccessRequest) error { - log := logger.Get(ctx). - WithField("req_id", reqID). - WithField("user", req.GetUser()) + log := logger.Get(ctx).With("req_id", reqID, "user", req.GetUser()) oncallUsers, err := a.bot.FetchOncallUsers(ctx, req) if trace.IsNotImplemented(err) { - log.Debugf("Skipping auto-approval because %q bot does not support automatic approvals.", a.pluginName) + log.DebugContext(ctx, "Skipping auto-approval because bot does not support automatic approvals", "bot", a.pluginName) return nil } if err != nil { @@ -559,7 +571,7 @@ func (a *App) tryApproveRequest(ctx context.Context, reqID string, req types.Acc } if !slices.Contains(oncallUsers, req.GetUser()) { - log.Debug("Skipping approval because user is not on-call.") + log.DebugContext(ctx, "Skipping approval because user is not on-call") return nil } @@ -573,12 +585,12 @@ func (a *App) tryApproveRequest(ctx context.Context, reqID string, req types.Acc }, }); err != nil { if strings.HasSuffix(err.Error(), "has already reviewed this request") { - log.Debug("Request has already been reviewed.") + log.DebugContext(ctx, "Request has already been reviewed") return nil } return trace.Wrap(err) } - log.Info("Successfully submitted a request approval.") + log.InfoContext(ctx, "Successfully submitted a request approval") return nil } diff --git a/integrations/access/common/app.go b/integrations/access/common/app.go index 805c0dde6ef8a..6c174e1422b75 100644 --- a/integrations/access/common/app.go +++ b/integrations/access/common/app.go @@ -88,7 +88,7 @@ func (a *BaseApp) WaitReady(ctx context.Context) (bool, error) { func (a *BaseApp) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) { log := logger.Get(ctx) - log.Debug("Checking Teleport server version") + log.DebugContext(ctx, "Checking Teleport server version") pong, err := a.APIClient.Ping(ctx) if err != nil { @@ -156,9 +156,9 @@ func (a *BaseApp) run(ctx context.Context) error { a.mainJob.SetReady(allOK) if allOK { - log.Info("Plugin is ready") + log.InfoContext(ctx, "Plugin is ready") } else { - log.Error("Plugin is not ready") + log.ErrorContext(ctx, "Plugin is not ready") } for _, app := range a.apps { @@ -203,11 +203,11 @@ func (a *BaseApp) init(ctx context.Context) error { } } - log.Debug("Starting API health check...") + log.DebugContext(ctx, "Starting API health check") if err = a.Bot.CheckHealth(ctx); err != nil { return trace.Wrap(err, "API health check failed") } - log.Debug("API health check finished ok") + log.DebugContext(ctx, "API health check finished ok") return nil } diff --git a/integrations/access/common/auth/token_provider.go b/integrations/access/common/auth/token_provider.go index f4ae33936a709..e0c23b0b36427 100644 --- a/integrations/access/common/auth/token_provider.go +++ b/integrations/access/common/auth/token_provider.go @@ -20,12 +20,12 @@ package auth import ( "context" + "log/slog" "sync" "time" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/common/auth/oauth" "github.com/gravitational/teleport/integrations/access/common/auth/storage" @@ -65,7 +65,7 @@ type RotatedAccessTokenProviderConfig struct { Refresher oauth.Refresher Clock clockwork.Clock - Log *logrus.Entry + Log *slog.Logger } // CheckAndSetDefaults validates a configuration and sets default values @@ -87,7 +87,7 @@ func (c *RotatedAccessTokenProviderConfig) CheckAndSetDefaults() error { c.Clock = clockwork.NewRealClock() } if c.Log == nil { - c.Log = logrus.NewEntry(logrus.StandardLogger()) + c.Log = slog.Default() } return nil } @@ -104,7 +104,7 @@ type RotatedAccessTokenProvider struct { refresher oauth.Refresher clock clockwork.Clock - log logrus.FieldLogger + log *slog.Logger lock sync.RWMutex // protects the below fields creds *storage.Credentials @@ -153,12 +153,12 @@ func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) { timer := r.clock.NewTimer(interval) defer timer.Stop() - r.log.Infof("Will attempt token refresh in: %s", interval) + r.log.InfoContext(ctx, "Starting token refresh loop", "next_refresh", interval) for { select { case <-ctx.Done(): - r.log.Info("Shutting down") + r.log.InfoContext(ctx, "Shutting down") return case <-timer.Chan(): creds, _ := r.store.GetCredentials(ctx) @@ -174,18 +174,21 @@ func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) { interval := r.getRefreshInterval(creds) timer.Reset(interval) - r.log.Infof("Next refresh in: %s", interval) + r.log.InfoContext(ctx, "Refreshed token", "next_refresh", interval) continue } creds, err := r.refresh(ctx) if err != nil { - r.log.Errorf("Error while refreshing: %s. Will retry after: %s", err, r.retryInterval) + r.log.ErrorContext(ctx, "Error while refreshing token", + "error", err, + "retry_interval", r.retryInterval, + ) timer.Reset(r.retryInterval) } else { err := r.store.PutCredentials(ctx, creds) if err != nil { - r.log.Errorf("Error while storing the refreshed credentials: %s", err) + r.log.ErrorContext(ctx, "Error while storing the refreshed credentials", "error", err) timer.Reset(r.retryInterval) continue } @@ -196,7 +199,7 @@ func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) { interval := r.getRefreshInterval(creds) timer.Reset(interval) - r.log.Infof("Successfully refreshed credentials. Next refresh in: %s", interval) + r.log.InfoContext(ctx, "Successfully refreshed credentials", "next_refresh", interval) } } } diff --git a/integrations/access/common/auth/token_provider_test.go b/integrations/access/common/auth/token_provider_test.go index fca79776ba024..e4f02ec3d3ae5 100644 --- a/integrations/access/common/auth/token_provider_test.go +++ b/integrations/access/common/auth/token_provider_test.go @@ -20,12 +20,12 @@ package auth import ( "context" + "log/slog" "testing" "time" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/integrations/access/common/auth/oauth" @@ -57,9 +57,6 @@ func (s *mockStore) PutCredentials(ctx context.Context, creds *storage.Credentia } func TestRotatedAccessTokenProvider(t *testing.T) { - log := logrus.New() - log.Level = logrus.DebugLevel - newProvider := func(ctx context.Context, store storage.Store, refresher oauth.Refresher, clock clockwork.Clock, initialCreds *storage.Credentials) *RotatedAccessTokenProvider { return &RotatedAccessTokenProvider{ store: store, @@ -70,7 +67,7 @@ func TestRotatedAccessTokenProvider(t *testing.T) { tokenBufferInterval: 1 * time.Hour, creds: initialCreds, - log: log, + log: slog.Default(), } } diff --git a/integrations/access/datadog/bot.go b/integrations/access/datadog/bot.go index e92dbbb524a20..4e1f52a6c218d 100644 --- a/integrations/access/datadog/bot.go +++ b/integrations/access/datadog/bot.go @@ -162,7 +162,7 @@ func (b Bot) FetchOncallUsers(ctx context.Context, req types.AccessRequest) ([]s annotationKey := types.TeleportNamespace + types.ReqAnnotationApproveSchedulesLabel teamNames, err := common.GetNamesFromAnnotations(req, annotationKey) if err != nil { - log.Debug("Automatic approvals annotation is empty or unspecified.") + log.DebugContext(ctx, "Automatic approvals annotation is empty or unspecified") return nil, nil } diff --git a/integrations/access/datadog/client.go b/integrations/access/datadog/client.go index 489eb0c51a44d..2d4ebf79ea5f2 100644 --- a/integrations/access/datadog/client.go +++ b/integrations/access/datadog/client.go @@ -126,7 +126,7 @@ func onAfterDatadogResponse(sink common.StatusSink) resty.ResponseMiddleware { defer cancel() if err := sink.Emit(ctx, status); err != nil { - log.WithError(err).Errorf("Error while emitting Datadog Incident Management plugin status: %v", err) + log.ErrorContext(ctx, "Error while emitting Datadog Incident Management plugin status", "error", err) } } diff --git a/integrations/access/datadog/cmd/teleport-datadog/main.go b/integrations/access/datadog/cmd/teleport-datadog/main.go index cb9cbd1959771..84a6a14c0955f 100644 --- a/integrations/access/datadog/cmd/teleport-datadog/main.go +++ b/integrations/access/datadog/cmd/teleport-datadog/main.go @@ -22,6 +22,7 @@ import ( "context" _ "embed" "fmt" + "log/slog" "os" "github.com/alecthomas/kingpin/v2" @@ -67,12 +68,13 @@ func main() { if err := run(*path, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } } func run(configPath string, debug bool) error { + ctx := context.Background() conf, err := datadog.LoadConfig(configPath) if err != nil { return trace.Wrap(err) @@ -86,14 +88,15 @@ func run(configPath string, debug bool) error { return err } if debug { - logger.Standard().Debugf("DEBUG logging enabled") + slog.DebugContext(ctx, "DEBUG logging enabled") } app := datadog.NewDatadogApp(conf) go lib.ServeSignals(app, common.PluginShutdownTimeout) - logger.Standard().Infof("Starting Teleport Access Datadog Incident Management Plugin %s:%s", teleport.Version, teleport.Gitref) - return trace.Wrap( - app.Run(context.Background()), + slog.InfoContext(ctx, "Starting Teleport Access Datadog Incident Management Plugin", + "version", teleport.Version, + "git_ref", teleport.Gitref, ) + return trace.Wrap(app.Run(ctx)) } diff --git a/integrations/access/datadog/testlib/fake_datadog.go b/integrations/access/datadog/testlib/fake_datadog.go index 64ef2e35b93b7..5cfe8b539f454 100644 --- a/integrations/access/datadog/testlib/fake_datadog.go +++ b/integrations/access/datadog/testlib/fake_datadog.go @@ -32,7 +32,6 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/datadog" ) @@ -281,6 +280,6 @@ func (d *FakeDatadog) GetOncallTeams() (map[string][]string, bool) { func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/discord/bot.go b/integrations/access/discord/bot.go index ca231bdf83a93..576606998b23c 100644 --- a/integrations/access/discord/bot.go +++ b/integrations/access/discord/bot.go @@ -94,8 +94,7 @@ func emitStatusUpdate(resp *resty.Response, statusSink common.StatusSink) { if err := statusSink.Emit(ctx, status); err != nil { logger.Get(resp.Request.Context()). - WithError(err). - Errorf("Error while emitting Discord plugin status: %v", err) + ErrorContext(ctx, "Error while emitting Discord plugin status", "error", err) } } diff --git a/integrations/access/discord/cmd/teleport-discord/main.go b/integrations/access/discord/cmd/teleport-discord/main.go index cd19ce64591b6..f624b407742ba 100644 --- a/integrations/access/discord/cmd/teleport-discord/main.go +++ b/integrations/access/discord/cmd/teleport-discord/main.go @@ -20,6 +20,7 @@ import ( "context" _ "embed" "fmt" + "log/slog" "os" "github.com/alecthomas/kingpin/v2" @@ -65,12 +66,13 @@ func main() { if err := run(*path, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } } func run(configPath string, debug bool) error { + ctx := context.Background() conf, err := discord.LoadDiscordConfig(configPath) if err != nil { return trace.Wrap(err) @@ -84,14 +86,15 @@ func run(configPath string, debug bool) error { return trace.Wrap(err) } if debug { - logger.Standard().Debugf("DEBUG logging enabled") + slog.DebugContext(ctx, "DEBUG logging enabled") } app := discord.NewApp(conf) go lib.ServeSignals(app, common.PluginShutdownTimeout) - logger.Standard().Infof("Starting Teleport Access Discord Plugin %s:%s", teleport.Version, teleport.Gitref) - return trace.Wrap( - app.Run(context.Background()), + slog.InfoContext(ctx, "Starting Teleport Access Discord Plugin", + "version", teleport.Version, + "git_ref", teleport.Gitref, ) + return trace.Wrap(app.Run(ctx)) } diff --git a/integrations/access/discord/testlib/fake_discord.go b/integrations/access/discord/testlib/fake_discord.go index c5a176446be5b..0a059d8ac81e2 100644 --- a/integrations/access/discord/testlib/fake_discord.go +++ b/integrations/access/discord/testlib/fake_discord.go @@ -32,7 +32,6 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/discord" ) @@ -188,6 +187,6 @@ func (s *FakeDiscord) CheckMessageUpdateByResponding(ctx context.Context) (disco func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/email/app.go b/integrations/access/email/app.go index 07bb3b558080e..cae9c33ed5315 100644 --- a/integrations/access/email/app.go +++ b/integrations/access/email/app.go @@ -18,6 +18,7 @@ package email import ( "context" + "log/slog" "slices" "time" @@ -32,6 +33,7 @@ import ( "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/watcherjob" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -90,7 +92,6 @@ func (a *App) run(ctx context.Context) error { var err error log := logger.Get(ctx) - log.Infof("Starting Teleport Access Email Plugin") if err = a.init(ctx); err != nil { return trace.Wrap(err) @@ -137,9 +138,9 @@ func (a *App) run(ctx context.Context) error { a.mainJob.SetReady(ok) if ok { - log.Info("Plugin is ready") + log.InfoContext(ctx, "Plugin is ready") } else { - log.Error("Plugin is not ready") + log.ErrorContext(ctx, "Plugin is not ready") } <-watcherJob.Done() @@ -186,24 +187,24 @@ func (a *App) init(ctx context.Context) error { }, }) - log.Debug("Starting client connection health check...") + log.DebugContext(ctx, "Starting client connection health check") if err = a.client.CheckHealth(ctx); err != nil { return trace.Wrap(err, "client connection health check failed") } - log.Debug("Client connection health check finished ok") + log.DebugContext(ctx, "Client connection health check finished ok") return nil } // checkTeleportVersion checks that Teleport version is not lower than required func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) { log := logger.Get(ctx) - log.Debug("Checking Teleport server version") + log.DebugContext(ctx, "Checking Teleport server version") pong, err := a.apiClient.Ping(ctx) if err != nil { if trace.IsNotImplemented(err) { return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion) } - log.Error("Unable to get Teleport server version") + log.ErrorContext(ctx, "Unable to get Teleport server version") return pong, trace.Wrap(err) } err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion) @@ -229,16 +230,16 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error } op := event.Type reqID := event.Resource.GetName() - ctx, _ = logger.WithField(ctx, "request_id", reqID) + ctx, _ = logger.With(ctx, "request_id", reqID) switch op { case types.OpPut: - ctx, _ = logger.WithField(ctx, "request_op", "put") + ctx, _ = logger.With(ctx, "request_op", "put") req, ok := event.Resource.(types.AccessRequest) if !ok { return trace.Errorf("unexpected resource type %T", event.Resource) } - ctx, log := logger.WithField(ctx, "request_state", req.GetState().String()) + ctx, log := logger.With(ctx, "request_state", req.GetState().String()) var err error switch { @@ -249,21 +250,31 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error case req.GetState().IsDenied(): err = a.onResolvedRequest(ctx, req) default: - log.WithField("event", event).Warn("Unknown request state") + log.WarnContext(ctx, "Unknown request state", + slog.Group("event", + slog.Any("type", logutils.StringerAttr(event.Type)), + slog.Group("resource", + "kind", event.Resource.GetKind(), + "name", event.Resource.GetName(), + ), + ), + ) + + log.With("event", event).WarnContext(ctx, "Unknown request state") return nil } if err != nil { - log.WithError(err).Errorf("Failed to process request") + log.ErrorContext(ctx, "Failed to process request", "error", err) return trace.Wrap(err) } return nil case types.OpDelete: - ctx, log := logger.WithField(ctx, "request_op", "delete") + ctx, log := logger.With(ctx, "request_op", "delete") if err := a.onDeletedRequest(ctx, reqID); err != nil { - log.WithError(err).Errorf("Failed to process deleted request") + log.ErrorContext(ctx, "Failed to process deleted request", "error", err) return trace.Wrap(err) } return nil @@ -292,7 +303,7 @@ func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) err if isNew { recipients := a.getRecipients(ctx, req) if len(recipients) == 0 { - log.Warning("No recipients to send") + log.WarnContext(ctx, "No recipients to send") return nil } @@ -329,7 +340,7 @@ func (a *App) onResolvedRequest(ctx context.Context, req types.AccessRequest) er case types.RequestState_DENIED: resolution.Tag = ResolvedDenied default: - logger.Get(ctx).Warningf("Unknown state %v (%s)", state, state.String()) + logger.Get(ctx).WarnContext(ctx, "Unknown state", "state", logutils.StringerAttr(state)) return replyErr } err := trace.Wrap(a.sendResolution(ctx, req.GetName(), resolution)) @@ -359,7 +370,7 @@ func (a *App) getRecipients(ctx context.Context, req types.AccessRequest) []comm rawRecipients := a.conf.RoleToRecipients.GetRawRecipientsFor(req.GetRoles(), req.GetSuggestedReviewers()) for _, rawRecipient := range rawRecipients { if !lib.IsEmail(rawRecipient) { - log.Warningf("Failed to notify a reviewer: %q does not look like a valid email", rawRecipient) + log.WarnContext(ctx, "Failed to notify a suggested reviewer with an invalid email address", "reviewer", rawRecipient) continue } recipientSet.Add(common.Recipient{ @@ -382,7 +393,7 @@ func (a *App) sendNewThreads(ctx context.Context, recipients []common.Recipient, logSentThreads(ctx, threadsSent, "new threads") if err != nil { - logger.Get(ctx).WithError(err).Error("Failed send one or more messages") + logger.Get(ctx).ErrorContext(ctx, "Failed send one or more messages", "error", err) } _, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) { @@ -425,7 +436,7 @@ func (a *App) sendReviews(ctx context.Context, reqID string, reqData RequestData return trace.Wrap(err) } if !ok { - logger.Get(ctx).Debug("Failed to post reply: plugin data is missing") + logger.Get(ctx).DebugContext(ctx, "Failed to post reply: plugin data is missing") return nil } reviews := reqReviews[oldCount:] @@ -439,7 +450,11 @@ func (a *App) sendReviews(ctx context.Context, reqID string, reqData RequestData if err != nil { errors = append(errors, err) } - logger.Get(ctx).Infof("New review for request %v by %v is %v", reqID, review.Author, review.ProposedState.String()) + logger.Get(ctx).InfoContext(ctx, "New review for request", + "request_id", reqID, + "author", review.Author, + "state", logutils.StringerAttr(review.ProposedState), + ) logSentThreads(ctx, threadsSent, "new review") } @@ -473,7 +488,7 @@ func (a *App) sendResolution(ctx context.Context, reqID string, resolution Resol return trace.Wrap(err) } if !ok { - log.Debug("Failed to update messages: plugin data is missing") + log.DebugContext(ctx, "Failed to update messages: plugin data is missing") return nil } @@ -482,7 +497,7 @@ func (a *App) sendResolution(ctx context.Context, reqID string, resolution Resol threadsSent, err := a.client.SendResolution(ctx, threads, reqID, reqData) logSentThreads(ctx, threadsSent, "request resolved") - log.Infof("Marked request as %s and sent emails!", resolution.Tag) + log.InfoContext(ctx, "Marked request with resolution and sent emails", "resolution", resolution.Tag) if err != nil { return trace.Wrap(err) @@ -567,10 +582,11 @@ func (a *App) updatePluginData(ctx context.Context, reqID string, data PluginDat // logSentThreads logs successfully sent emails func logSentThreads(ctx context.Context, threads []EmailThread, kind string) { for _, thread := range threads { - logger.Get(ctx).WithFields(logger.Fields{ - "email": thread.Email, - "timestamp": thread.Timestamp, - "message_id": thread.MessageID, - }).Infof("Successfully sent %v!", kind) + logger.Get(ctx).InfoContext(ctx, "Successfully sent", + "email", thread.Email, + "timestamp", thread.Timestamp, + "message_id", thread.MessageID, + "kind", kind, + ) } } diff --git a/integrations/access/email/client.go b/integrations/access/email/client.go index 6ef1d2f04144e..f687f5deb0009 100644 --- a/integrations/access/email/client.go +++ b/integrations/access/email/client.go @@ -61,16 +61,16 @@ func NewClient(ctx context.Context, conf Config, clusterName, webProxyAddr strin if conf.Mailgun != nil { mailer = NewMailgunMailer(*conf.Mailgun, conf.StatusSink, conf.Delivery.Sender, clusterName, conf.RoleToRecipients[types.Wildcard]) - logger.Get(ctx).WithField("domain", conf.Mailgun.Domain).Info("Using Mailgun as email transport") + logger.Get(ctx).InfoContext(ctx, "Using Mailgun as email transport", "domain", conf.Mailgun.Domain) } if conf.SMTP != nil { mailer = NewSMTPMailer(*conf.SMTP, conf.StatusSink, conf.Delivery.Sender, clusterName) - logger.Get(ctx).WithFields(logger.Fields{ - "host": conf.SMTP.Host, - "port": conf.SMTP.Port, - "username": conf.SMTP.Username, - }).Info("Using SMTP as email transport") + logger.Get(ctx).InfoContext(ctx, "Using SMTP as email transport", + "host", conf.SMTP.Host, + "port", conf.SMTP.Port, + "username", conf.SMTP.Username, + ) } return Client{ diff --git a/integrations/access/email/cmd/teleport-email/main.go b/integrations/access/email/cmd/teleport-email/main.go index 840c80da76177..ccaec3acbed36 100644 --- a/integrations/access/email/cmd/teleport-email/main.go +++ b/integrations/access/email/cmd/teleport-email/main.go @@ -20,6 +20,7 @@ import ( "context" _ "embed" "fmt" + "log/slog" "os" "github.com/alecthomas/kingpin/v2" @@ -65,12 +66,13 @@ func main() { if err := run(*path, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } } func run(configPath string, debug bool) error { + ctx := context.Background() conf, err := email.LoadConfig(configPath) if err != nil { return trace.Wrap(err) @@ -84,11 +86,11 @@ func run(configPath string, debug bool) error { return err } if debug { - logger.Standard().Debugf("DEBUG logging enabled") + slog.DebugContext(ctx, "DEBUG logging enabled") } if conf.Delivery.Recipients != nil { - logger.Standard().Warn("The delivery.recipients config option is deprecated, set role_to_recipients[\"*\"] instead for the same functionality") + slog.WarnContext(ctx, "The delivery.recipients config option is deprecated, set role_to_recipients[\"*\"] instead for the same functionality") } app, err := email.NewApp(*conf) @@ -98,8 +100,9 @@ func run(configPath string, debug bool) error { go lib.ServeSignals(app, common.PluginShutdownTimeout) - logger.Standard().Infof("Starting Teleport Access Email Plugin %s:%s", teleport.Version, teleport.Gitref) - return trace.Wrap( - app.Run(context.Background()), + slog.InfoContext(ctx, "Starting Teleport Access Email Plugin", + "version", teleport.Version, + "git_ref", teleport.Gitref, ) + return trace.Wrap(app.Run(ctx)) } diff --git a/integrations/access/email/mailers.go b/integrations/access/email/mailers.go index 60d5b4592449f..5cbd3d98bee02 100644 --- a/integrations/access/email/mailers.go +++ b/integrations/access/email/mailers.go @@ -114,7 +114,7 @@ func (m *SMTPMailer) CheckHealth(ctx context.Context) error { return trace.Wrap(err) } if err := client.Close(); err != nil { - log.Debug("Failed to close client connection after health check") + log.DebugContext(ctx, "Failed to close client connection after health check") } return nil } @@ -191,7 +191,7 @@ func (m *SMTPMailer) emitStatus(ctx context.Context, statusErr error) { code = http.StatusInternalServerError } if err := m.sink.Emit(ctx, common.StatusFromStatusCode(code)); err != nil { - log.WithError(err).Error("Error while emitting Email plugin status") + log.ErrorContext(ctx, "Error while emitting Email plugin status", "error", err) } } @@ -252,7 +252,7 @@ func (t *statusSinkTransport) RoundTrip(req *http.Request) (*http.Response, erro status := common.StatusFromStatusCode(resp.StatusCode) if err := t.sink.Emit(ctx, status); err != nil { - log.WithError(err).Error("Error while emitting Email plugin status") + log.ErrorContext(ctx, "Error while emitting Email plugin status", "error", err) } } return resp, nil diff --git a/integrations/access/email/testlib/mock_mailgun.go b/integrations/access/email/testlib/mock_mailgun.go index 58cbbc8ebb098..7895a5cdcaefe 100644 --- a/integrations/access/email/testlib/mock_mailgun.go +++ b/integrations/access/email/testlib/mock_mailgun.go @@ -24,7 +24,6 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" ) const ( @@ -58,7 +57,8 @@ func newMockMailgunServer(concurrency int) *mockMailgunServer { s := httptest.NewUnstartedServer(func(mg *mockMailgunServer) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if err := r.ParseMultipartForm(multipartFormBufSize); err != nil { - log.Error(err) + w.WriteHeader(http.StatusInternalServerError) + return } id := uuid.New().String() diff --git a/integrations/access/jira/app.go b/integrations/access/jira/app.go index 2aab94e887f0d..c8e6c8273ec02 100644 --- a/integrations/access/jira/app.go +++ b/integrations/access/jira/app.go @@ -21,6 +21,7 @@ package jira import ( "context" "fmt" + "log/slog" "net/url" "regexp" "strings" @@ -40,6 +41,7 @@ import ( "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/watcherjob" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -125,7 +127,6 @@ func (a *App) run(ctx context.Context) error { var err error log := logger.Get(ctx) - log.Infof("Starting Teleport Jira Plugin") if err = a.init(ctx); err != nil { return trace.Wrap(err) @@ -164,9 +165,9 @@ func (a *App) run(ctx context.Context) error { ok := (a.webhookSrv == nil || httpOk) && watcherOk a.mainJob.SetReady(ok) if ok { - log.Info("Plugin is ready") + log.InfoContext(ctx, "Plugin is ready") } else { - log.Error("Plugin is not ready") + log.ErrorContext(ctx, "Plugin is not ready") } if httpJob != nil { @@ -205,11 +206,11 @@ func (a *App) init(ctx context.Context) error { return trace.Wrap(err) } - log.Debug("Starting Jira API health check...") + log.DebugContext(ctx, "Starting Jira API health check") if err = a.jira.HealthCheck(ctx); err != nil { return trace.Wrap(err, "api health check failed") } - log.Debug("Jira API health check finished ok") + log.DebugContext(ctx, "Jira API health check finished ok") if !a.conf.DisableWebhook { webhookSrv, err := NewWebhookServer(a.conf.HTTP, a.onJiraWebhook) @@ -227,13 +228,13 @@ func (a *App) init(ctx context.Context) error { func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) { log := logger.Get(ctx) - log.Debug("Checking Teleport server version") + log.DebugContext(ctx, "Checking Teleport server version") pong, err := a.teleport.Ping(ctx) if err != nil { if trace.IsNotImplemented(err) { return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion) } - log.Error("Unable to get Teleport server version") + log.ErrorContext(ctx, "Unable to get Teleport server version") return pong, trace.Wrap(err) } err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion) @@ -246,17 +247,17 @@ func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error { } op := event.Type reqID := event.Resource.GetName() - ctx, _ = logger.WithField(ctx, "request_id", reqID) + ctx, _ = logger.With(ctx, "request_id", reqID) switch op { case types.OpPut: - ctx, _ = logger.WithField(ctx, "request_op", "put") + ctx, _ = logger.With(ctx, "request_op", "put") req, ok := event.Resource.(types.AccessRequest) if !ok { return trace.Errorf("unexpected resource type %T", event.Resource) } - ctx, log := logger.WithField(ctx, "request_state", req.GetState().String()) - log.Debug("Processing watcher event") + ctx, log := logger.With(ctx, "request_state", req.GetState().String()) + log.DebugContext(ctx, "Processing watcher event") var err error switch { @@ -265,21 +266,29 @@ func (a *App) onWatcherEvent(ctx context.Context, event types.Event) error { case req.GetState().IsResolved(): err = a.onResolvedRequest(ctx, req) default: - log.WithField("event", event).Warn("Unknown request state") + log.WarnContext(ctx, "Unknown request state", + slog.Group("event", + slog.Any("type", logutils.StringerAttr(event.Type)), + slog.Group("resource", + "kind", event.Resource.GetKind(), + "name", event.Resource.GetName(), + ), + ), + ) return nil } if err != nil { - log.WithError(err).Error("Failed to process request") + log.ErrorContext(ctx, "Failed to process request", "error", err) return trace.Wrap(err) } return nil case types.OpDelete: - ctx, log := logger.WithField(ctx, "request_op", "delete") + ctx, log := logger.With(ctx, "request_op", "delete") if err := a.onDeletedRequest(ctx, reqID); err != nil { - log.WithError(err).Errorf("Failed to process deleted request") + log.ErrorContext(ctx, "Failed to process deleted request", "error", err) return trace.Wrap(err) } return nil @@ -299,10 +308,11 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error { return nil } - ctx, log := logger.WithFields(ctx, logger.Fields{ - "jira_issue_id": webhook.Issue.ID, - }) - log.Debugf("Processing incoming webhook event %q with type %q", webhookEvent, issueEventTypeName) + ctx, log := logger.With(ctx, "jira_issue_id", webhook.Issue.ID) + log.DebugContext(ctx, "Processing incoming webhook event", + "event", webhookEvent, + "event_type", issueEventTypeName, + ) if webhook.Issue == nil { return trace.Errorf("got webhook without issue info") @@ -333,20 +343,20 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error { if statusName == "" { return trace.Errorf("getting Jira issue status: %w", err) } - log.Warnf("Using most recent successful getIssue response: %v", err) + log.WarnContext(ctx, "Using most recent successful getIssue response", "error", err) } - ctx, log = logger.WithFields(ctx, logger.Fields{ - "jira_issue_id": issue.ID, - "jira_issue_key": issue.Key, - }) + ctx, log = logger.With(ctx, + "jira_issue_id", issue.ID, + "jira_issue_key", issue.Key, + ) switch { case statusName == "pending": - log.Debug("Issue has pending status, ignoring it") + log.DebugContext(ctx, "Issue has pending status, ignoring it") return nil case statusName == "expired": - log.Debug("Issue has expired status, ignoring it") + log.DebugContext(ctx, "Issue has expired status, ignoring it") return nil case statusName != "approved" && statusName != "denied": return trace.BadParameter("unknown Jira status %s", statusName) @@ -357,11 +367,11 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error { return trace.Wrap(err) } if reqID == "" { - log.Debugf("Missing %q issue property", RequestIDPropertyKey) + log.DebugContext(ctx, "Missing teleportAccessRequestId issue property") return nil } - ctx, log = logger.WithField(ctx, "request_id", reqID) + ctx, log = logger.With(ctx, "request_id", reqID) reqs, err := a.teleport.GetAccessRequests(ctx, types.AccessRequestFilter{ID: reqID}) if err != nil { @@ -382,8 +392,9 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error { return trace.Errorf("plugin data is blank") } if pluginData.IssueID != issue.ID { - log.WithField("plugin_data_issue_id", pluginData.IssueID). - Debug("plugin_data.issue_id does not match issue.id") + log.DebugContext(ctx, "plugin_data.issue_id does not match issue.id", + "plugin_data_issue_id", pluginData.IssueID, + ) return trace.Errorf("issue_id from request's plugin_data does not match") } @@ -406,17 +417,17 @@ func (a *App) onJiraWebhook(_ context.Context, webhook Webhook) error { author, reason, err := a.loadResolutionInfo(ctx, issue, statusName) if err != nil { - log.WithError(err).Error("Failed to load resolution info from the issue history") + log.ErrorContext(ctx, "Failed to load resolution info from the issue history", "error", err) } resolution.Reason = reason - ctx, _ = logger.WithFields(ctx, logger.Fields{ - "jira_user_email": author.EmailAddress, - "jira_user_name": author.DisplayName, - "request_user": req.GetUser(), - "request_roles": req.GetRoles(), - "reason": reason, - }) + ctx, _ = logger.With(ctx, + "jira_user_email", author.EmailAddress, + "jira_user_name", author.DisplayName, + "request_user", req.GetUser(), + "request_roles", req.GetRoles(), + "reason", reason, + ) if err := a.resolveRequest(ctx, reqID, author.EmailAddress, resolution); err != nil { return trace.Wrap(err) } @@ -498,11 +509,11 @@ func (a *App) createIssue(ctx context.Context, reqID string, reqData RequestData return trace.Wrap(err) } - ctx, log := logger.WithFields(ctx, logger.Fields{ - "jira_issue_id": data.IssueID, - "jira_issue_key": data.IssueKey, - }) - log.Info("Jira Issue created") + ctx, log := logger.With(ctx, + "jira_issue_id", data.IssueID, + "jira_issue_key", data.IssueKey, + ) + log.InfoContext(ctx, "Jira Issue created") // Save jira issue info in plugin data. _, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) { @@ -551,11 +562,11 @@ func (a *App) addReviewComments(ctx context.Context, reqID string, reqReviews [] } if !ok { if issueID == "" { - logger.Get(ctx).Debug("Failed to add the comment: plugin data is blank") + logger.Get(ctx).DebugContext(ctx, "Failed to add the comment: plugin data is blank") } return nil } - ctx, _ = logger.WithField(ctx, "jira_issue_id", issueID) + ctx, _ = logger.With(ctx, "jira_issue_id", issueID) slice := reqReviews[oldCount:] if len(slice) == 0 { @@ -621,7 +632,7 @@ func (a *App) resolveRequest(ctx context.Context, reqID string, userEmail string return trace.Wrap(err) } - logger.Get(ctx).Infof("Jira user %s the request", resolution.Tag) + logger.Get(ctx).InfoContext(ctx, "Jira user processed the request", "resolution", resolution.Tag) return nil } @@ -658,18 +669,18 @@ func (a *App) resolveIssue(ctx context.Context, reqID string, resolution Resolut } if !ok { if issueID == "" { - logger.Get(ctx).Debug("Failed to resolve the issue: plugin data is blank") + logger.Get(ctx).DebugContext(ctx, "Failed to resolve the issue: plugin data is blank") } // Either plugin data is missing or issue is already resolved by us, just quit. return nil } - ctx, log := logger.WithField(ctx, "jira_issue_id", issueID) + ctx, log := logger.With(ctx, "jira_issue_id", issueID) if err := a.jira.ResolveIssue(ctx, issueID, resolution); err != nil { return trace.Wrap(err) } - log.Info("Successfully resolved the issue") + log.InfoContext(ctx, "Successfully resolved the issue") return nil } diff --git a/integrations/access/jira/client.go b/integrations/access/jira/client.go index 2877966af663b..a23381e4d2666 100644 --- a/integrations/access/jira/client.go +++ b/integrations/access/jira/client.go @@ -125,7 +125,7 @@ func NewJiraClient(conf JiraConfig, clusterName, teleportProxyAddr string, statu defer cancel() if err := statusSink.Emit(ctx, status); err != nil { - log.WithError(err).Errorf("Error while emitting Jira plugin status: %v", err) + log.ErrorContext(ctx, "Error while emitting Jira plugin status", "error", err) } } @@ -199,7 +199,7 @@ func (j *Jira) HealthCheck(ctx context.Context) error { } } - log.Debug("Checking out Jira project...") + log.DebugContext(ctx, "Checking out Jira project") var project Project _, err = j.client.NewRequest(). SetContext(ctx). @@ -209,9 +209,12 @@ func (j *Jira) HealthCheck(ctx context.Context) error { if err != nil { return trace.Wrap(err) } - log.Debugf("Found project %q named %q", project.Key, project.Name) + log.DebugContext(ctx, "Found Jira project", + "project", project.Key, + "project_name", project.Name, + ) - log.Debug("Checking out Jira project permissions...") + log.DebugContext(ctx, "Checking out Jira project permissions") queryOptions, err := query.Values(GetMyPermissionsQueryOptions{ ProjectKey: j.project, Permissions: jiraRequiredPermissions, @@ -433,7 +436,7 @@ func (j *Jira) ResolveIssue(ctx context.Context, issueID string, resolution Reso if err2 := trace.Wrap(j.TransitionIssue(ctx, issue.ID, transition.ID)); err2 != nil { return trace.NewAggregate(err1, err2) } - logger.Get(ctx).Debugf("Successfully moved the issue to the status %q", toStatus) + logger.Get(ctx).DebugContext(ctx, "Successfully moved the issue to the target status", "target_status", toStatus) return trace.Wrap(err1) } @@ -457,7 +460,7 @@ func (j *Jira) AddResolutionComment(ctx context.Context, id string, resolution R SetBody(CommentInput{Body: builder.String()}). Post("rest/api/2/issue/{issueID}/comment") if err == nil { - logger.Get(ctx).Debug("Successfully added a resolution comment to the issue") + logger.Get(ctx).DebugContext(ctx, "Successfully added a resolution comment to the issue") } return trace.Wrap(err) } diff --git a/integrations/access/jira/cmd/teleport-jira/main.go b/integrations/access/jira/cmd/teleport-jira/main.go index b2c2bb0672d06..851de27473296 100644 --- a/integrations/access/jira/cmd/teleport-jira/main.go +++ b/integrations/access/jira/cmd/teleport-jira/main.go @@ -20,6 +20,7 @@ import ( "context" _ "embed" "fmt" + "log/slog" "os" "github.com/alecthomas/kingpin/v2" @@ -72,12 +73,13 @@ func main() { if err := run(*path, *insecure, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } } func run(configPath string, insecure bool, debug bool) error { + ctx := context.Background() conf, err := jira.LoadConfig(configPath) if err != nil { return trace.Wrap(err) @@ -91,7 +93,7 @@ func run(configPath string, insecure bool, debug bool) error { return err } if debug { - logger.Standard().Debugf("DEBUG logging enabled") + slog.DebugContext(ctx, "DEBUG logging enabled") } conf.HTTP.Insecure = insecure @@ -102,8 +104,9 @@ func run(configPath string, insecure bool, debug bool) error { go lib.ServeSignals(app, common.PluginShutdownTimeout) - logger.Standard().Infof("Starting Teleport Access Jira Plugin %s:%s", teleport.Version, teleport.Gitref) - return trace.Wrap( - app.Run(context.Background()), + slog.InfoContext(ctx, "Starting Teleport Access Jira Plugin", + "version", teleport.Version, + "git_ref", teleport.Gitref, ) + return trace.Wrap(app.Run(ctx)) } diff --git a/integrations/access/jira/testlib/fake_jira.go b/integrations/access/jira/testlib/fake_jira.go index 1da8c432ec3a9..9696500620aba 100644 --- a/integrations/access/jira/testlib/fake_jira.go +++ b/integrations/access/jira/testlib/fake_jira.go @@ -30,7 +30,6 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/jira" ) @@ -304,6 +303,6 @@ func (s *FakeJira) CheckIssueTransition(ctx context.Context) (jira.Issue, error) func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/jira/testlib/suite.go b/integrations/access/jira/testlib/suite.go index 38341d589fa5d..c2a3d421f442c 100644 --- a/integrations/access/jira/testlib/suite.go +++ b/integrations/access/jira/testlib/suite.go @@ -721,7 +721,7 @@ func (s *JiraSuiteOSS) TestRace() { defer cancel() var lastErr error for { - logger.Get(ctx).Infof("Trying to approve issue %q", issue.Key) + logger.Get(ctx).InfoContext(ctx, "Trying to approve issue", "issue_key", issue.Key) resp, err := s.postWebhook(ctx, s.webhookURL.String(), issue.ID, "Approved") if err != nil { if lib.IsDeadline(err) { diff --git a/integrations/access/jira/webhook_server.go b/integrations/access/jira/webhook_server.go index b83e449b992c8..e9e409959b40a 100644 --- a/integrations/access/jira/webhook_server.go +++ b/integrations/access/jira/webhook_server.go @@ -105,29 +105,31 @@ func (s *WebhookServer) processWebhook(rw http.ResponseWriter, r *http.Request, defer cancel() httpRequestID := fmt.Sprintf("%v-%v", time.Now().Unix(), atomic.AddUint64(&s.counter, 1)) - ctx, log := logger.WithField(ctx, "jira_http_id", httpRequestID) + ctx, log := logger.With(ctx, "jira_http_id", httpRequestID) var webhook Webhook body, err := io.ReadAll(io.LimitReader(r.Body, jiraWebhookPayloadLimit+1)) if err != nil { - log.WithError(err).Error("Failed to read webhook payload") + log.ErrorContext(ctx, "Failed to read webhook payload", "error", err) http.Error(rw, "", http.StatusInternalServerError) return } if len(body) > jiraWebhookPayloadLimit { - log.Error("Received a webhook larger than %d bytes", jiraWebhookPayloadLimit) + log.ErrorContext(ctx, "Received a webhook with a payload that exceeded the limit", + "payload_size", len(body), + "payload_size_limit", jiraWebhookPayloadLimit, + ) http.Error(rw, "", http.StatusRequestEntityTooLarge) } if err = json.Unmarshal(body, &webhook); err != nil { - log.WithError(err).Error("Failed to parse webhook payload") + log.ErrorContext(ctx, "Failed to parse webhook payload", "error", err) http.Error(rw, "", http.StatusBadRequest) return } if err = s.onWebhook(ctx, webhook); err != nil { - log.WithError(err).Error("Failed to process webhook") - log.Debugf("%v", trace.DebugReport(err)) + log.ErrorContext(ctx, "Failed to process webhook", "error", err) var code int switch { case lib.IsCanceled(err) || lib.IsDeadline(err): diff --git a/integrations/access/mattermost/bot.go b/integrations/access/mattermost/bot.go index c7de9d0aaae44..edf0a7e73264d 100644 --- a/integrations/access/mattermost/bot.go +++ b/integrations/access/mattermost/bot.go @@ -150,7 +150,7 @@ func NewBot(conf Config, clusterName, webProxyAddr string) (Bot, error) { ctx, cancel := context.WithTimeout(context.Background(), mmStatusEmitTimeout) defer cancel() if err := sink.Emit(ctx, status); err != nil { - log.Errorf("Error while emitting plugin status: %v", err) + log.ErrorContext(ctx, "Error while emitting plugin status", "error", err) } }() @@ -463,14 +463,14 @@ func (b Bot) buildPostText(reqID string, reqData pd.AccessRequestData) (string, } func (b Bot) tryLookupDirectChannel(ctx context.Context, userEmail string) string { - log := logger.Get(ctx).WithField("mm_user_email", userEmail) + log := logger.Get(ctx).With("mm_user_email", userEmail) channel, err := b.LookupDirectChannel(ctx, userEmail) if err != nil { var errResult *ErrorResult if errors.As(trace.Unwrap(err), &errResult) { - log.Warningf("Failed to lookup direct channel info: %q", errResult.Message) + log.WarnContext(ctx, "Failed to lookup direct channel info", "error", errResult.Message) } else { - log.WithError(err).Error("Failed to lookup direct channel info") + log.ErrorContext(ctx, "Failed to lookup direct channel info", "error", err) } return "" } @@ -478,17 +478,17 @@ func (b Bot) tryLookupDirectChannel(ctx context.Context, userEmail string) strin } func (b Bot) tryLookupChannel(ctx context.Context, team, name string) string { - log := logger.Get(ctx).WithFields(logger.Fields{ - "mm_team": team, - "mm_channel": name, - }) + log := logger.Get(ctx).With( + "mm_team", team, + "mm_channel", name, + ) channel, err := b.LookupChannel(ctx, team, name) if err != nil { var errResult *ErrorResult if errors.As(trace.Unwrap(err), &errResult) { - log.Warningf("Failed to lookup channel info: %q", errResult.Message) + log.WarnContext(ctx, "Failed to lookup channel info", "error", errResult.Message) } else { - log.WithError(err).Error("Failed to lookup channel info") + log.ErrorContext(ctx, "Failed to lookup channel info", "error", err) } return "" } diff --git a/integrations/access/mattermost/cmd/teleport-mattermost/main.go b/integrations/access/mattermost/cmd/teleport-mattermost/main.go index 7c4777b26655b..0c67abb62ef86 100644 --- a/integrations/access/mattermost/cmd/teleport-mattermost/main.go +++ b/integrations/access/mattermost/cmd/teleport-mattermost/main.go @@ -20,6 +20,7 @@ import ( "context" _ "embed" "fmt" + "log/slog" "os" "github.com/alecthomas/kingpin/v2" @@ -65,12 +66,13 @@ func main() { if err := run(*path, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } } func run(configPath string, debug bool) error { + ctx := context.Background() conf, err := mattermost.LoadConfig(configPath) if err != nil { return trace.Wrap(err) @@ -84,14 +86,15 @@ func run(configPath string, debug bool) error { return err } if debug { - logger.Standard().Debugf("DEBUG logging enabled") + slog.DebugContext(ctx, "DEBUG logging enabled") } app := mattermost.NewMattermostApp(conf) go lib.ServeSignals(app, common.PluginShutdownTimeout) - logger.Standard().Infof("Starting Teleport Access Mattermost Plugin %s:%s", teleport.Version, teleport.Gitref) - return trace.Wrap( - app.Run(context.Background()), + slog.InfoContext(ctx, "Starting Teleport Access Mattermost Plugin", + "version", teleport.Version, + "git_ref", teleport.Gitref, ) + return trace.Wrap(app.Run(ctx)) } diff --git a/integrations/access/mattermost/testlib/fake_mattermost.go b/integrations/access/mattermost/testlib/fake_mattermost.go index 10cc048e743bd..b2c28287c6153 100644 --- a/integrations/access/mattermost/testlib/fake_mattermost.go +++ b/integrations/access/mattermost/testlib/fake_mattermost.go @@ -31,7 +31,6 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/mattermost" ) @@ -387,6 +386,6 @@ func (s *FakeMattermost) CheckPostUpdate(ctx context.Context) (mattermost.Post, func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/msteams/app.go b/integrations/access/msteams/app.go index 306be091ca8b0..b18c96ba3f4a3 100644 --- a/integrations/access/msteams/app.go +++ b/integrations/access/msteams/app.go @@ -62,14 +62,9 @@ type App struct { // NewApp initializes a new teleport-msteams app and returns it. func NewApp(conf Config) (*App, error) { - log, err := conf.Log.NewSLogLogger() - if err != nil { - return nil, trace.Wrap(err) - } - app := &App{ conf: conf, - log: log.With("plugin", pluginName), + log: slog.With("plugin", pluginName), } app.mainJob = lib.NewServiceJob(app.run) diff --git a/integrations/access/msteams/bot.go b/integrations/access/msteams/bot.go index c0598c1f4d24f..4292f856dba90 100644 --- a/integrations/access/msteams/bot.go +++ b/integrations/access/msteams/bot.go @@ -30,7 +30,6 @@ import ( "github.com/gravitational/teleport/integrations/access/common" "github.com/gravitational/teleport/integrations/access/msteams/msapi" "github.com/gravitational/teleport/integrations/lib" - "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/plugindata" ) @@ -469,7 +468,7 @@ func (b *Bot) CheckHealth(ctx context.Context) error { Code: status, ErrorMessage: message, }); err != nil { - logger.Get(ctx).Errorf("Error while emitting ms teams plugin status: %v", err) + b.log.ErrorContext(ctx, "Error while emitting ms teams plugin status", "error", err) } } return trace.Wrap(err) diff --git a/integrations/access/msteams/cmd/teleport-msteams/main.go b/integrations/access/msteams/cmd/teleport-msteams/main.go index 970df1ac98db4..75e66a46b7cf7 100644 --- a/integrations/access/msteams/cmd/teleport-msteams/main.go +++ b/integrations/access/msteams/cmd/teleport-msteams/main.go @@ -16,6 +16,7 @@ package main import ( "context" + "log/slog" "os" "time" @@ -99,7 +100,7 @@ func main() { if err := run(*startConfigPath, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } diff --git a/integrations/access/msteams/testlib/fake_msteams.go b/integrations/access/msteams/testlib/fake_msteams.go index ceb1a3edc2d41..f3e4d4c5550c2 100644 --- a/integrations/access/msteams/testlib/fake_msteams.go +++ b/integrations/access/msteams/testlib/fake_msteams.go @@ -30,7 +30,6 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/msteams/msapi" ) @@ -326,6 +325,6 @@ func (s *FakeTeams) CheckMessageUpdate(ctx context.Context) (Msg, error) { func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/msteams/uninstall.go b/integrations/access/msteams/uninstall.go index e60a9ce0c8ddd..22aa9e6961ab1 100644 --- a/integrations/access/msteams/uninstall.go +++ b/integrations/access/msteams/uninstall.go @@ -18,7 +18,8 @@ import ( "context" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" + + "github.com/gravitational/teleport/integrations/lib/logger" ) func Uninstall(ctx context.Context, configPath string) error { @@ -26,11 +27,13 @@ func Uninstall(ctx context.Context, configPath string) error { if err != nil { return trace.Wrap(err) } - err = checkApp(ctx, b) - if err != nil { + + if err := checkApp(ctx, b); err != nil { return trace.Wrap(err) } + log := logger.Get(ctx) + var errs []error for _, recipient := range c.Recipients.GetAllRawRecipients() { _, isChannel := b.checkChannelURL(recipient) @@ -38,11 +41,11 @@ func Uninstall(ctx context.Context, configPath string) error { errs = append(errs, b.UninstallAppForUser(ctx, recipient)) } } - err = trace.NewAggregate(errs...) - if err != nil { - log.Errorln("The following error(s) happened when uninstalling the Teams App:") + + if trace.NewAggregate(errs...) != nil { + log.ErrorContext(ctx, "Encountered error(s) when uninstalling the Teams App", "error", err) return err } - log.Info("Successfully uninstalled app for all recipients") + log.InfoContext(ctx, "Successfully uninstalled app for all recipients") return nil } diff --git a/integrations/access/msteams/validate.go b/integrations/access/msteams/validate.go index 61d9d25f635e8..7969d7edebe0d 100644 --- a/integrations/access/msteams/validate.go +++ b/integrations/access/msteams/validate.go @@ -17,6 +17,7 @@ package msteams import ( "context" "fmt" + "log/slog" "time" cards "github.com/DanielTitkov/go-adaptive-cards" @@ -142,11 +143,7 @@ func loadConfig(configPath string) (*Bot, *Config, error) { fmt.Printf(" - Checking application %v status...\n", c.MSAPI.TeamsAppID) - log, err := c.Log.NewSLogLogger() - if err != nil { - return nil, nil, trace.Wrap(err) - } - b, err := NewBot(c, "local", "", log) + b, err := NewBot(c, "local", "", slog.Default()) if err != nil { return nil, nil, trace.Wrap(err) } diff --git a/integrations/access/opsgenie/app.go b/integrations/access/opsgenie/app.go index 132389ad5b5a3..60950f31fa4b1 100644 --- a/integrations/access/opsgenie/app.go +++ b/integrations/access/opsgenie/app.go @@ -22,6 +22,7 @@ import ( "context" "errors" "fmt" + "log/slog" "strings" "time" @@ -39,6 +40,7 @@ import ( "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/watcherjob" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -115,7 +117,7 @@ func (a *App) run(ctx context.Context) error { var err error log := logger.Get(ctx) - log.Infof("Starting Teleport Access Opsgenie Plugin") + log.InfoContext(ctx, "Starting Teleport Access Opsgenie Plugin") if err = a.init(ctx); err != nil { return trace.Wrap(err) @@ -147,9 +149,9 @@ func (a *App) run(ctx context.Context) error { a.mainJob.SetReady(ok) if ok { - log.Info("Plugin is ready") + log.InfoContext(ctx, "Plugin is ready") } else { - log.Error("Plugin is not ready") + log.ErrorContext(ctx, "Plugin is not ready") } <-watcherJob.Done() @@ -177,24 +179,24 @@ func (a *App) init(ctx context.Context) error { } log := logger.Get(ctx) - log.Debug("Starting API health check...") + log.DebugContext(ctx, "Starting API health check") if err = a.opsgenie.CheckHealth(ctx); err != nil { return trace.Wrap(err, "API health check failed") } - log.Debug("API health check finished ok") + log.DebugContext(ctx, "API health check finished ok") return nil } func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) { log := logger.Get(ctx) - log.Debug("Checking Teleport server version") + log.DebugContext(ctx, "Checking Teleport server version") pong, err := a.teleport.Ping(ctx) if err != nil { if trace.IsNotImplemented(err) { return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion) } - log.Error("Unable to get Teleport server version") + log.ErrorContext(ctx, "Unable to get Teleport server version") return pong, trace.Wrap(err) } err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion) @@ -219,16 +221,16 @@ func (a *App) handleAcessRequest(ctx context.Context, event types.Event) error { } op := event.Type reqID := event.Resource.GetName() - ctx, _ = logger.WithField(ctx, "request_id", reqID) + ctx, _ = logger.With(ctx, "request_id", reqID) switch op { case types.OpPut: - ctx, _ = logger.WithField(ctx, "request_op", "put") + ctx, _ = logger.With(ctx, "request_op", "put") req, ok := event.Resource.(types.AccessRequest) if !ok { return trace.Errorf("unexpected resource type %T", event.Resource) } - ctx, log := logger.WithField(ctx, "request_state", req.GetState().String()) + ctx, log := logger.With(ctx, "request_state", req.GetState().String()) var err error switch { @@ -237,21 +239,29 @@ func (a *App) handleAcessRequest(ctx context.Context, event types.Event) error { case req.GetState().IsResolved(): err = a.onResolvedRequest(ctx, req) default: - log.WithField("event", event).Warn("Unknown request state") + log.WarnContext(ctx, "Unknown request state", + slog.Group("event", + slog.Any("type", logutils.StringerAttr(event.Type)), + slog.Group("resource", + "kind", event.Resource.GetKind(), + "name", event.Resource.GetName(), + ), + ), + ) return nil } if err != nil { - log.WithError(err).Error("Failed to process request") + log.ErrorContext(ctx, "Failed to process request", "error", err) return trace.Wrap(err) } return nil case types.OpDelete: - ctx, log := logger.WithField(ctx, "request_op", "delete") + ctx, log := logger.With(ctx, "request_op", "delete") if err := a.onDeletedRequest(ctx, reqID); err != nil { - log.WithError(err).Error("Failed to process deleted request") + log.ErrorContext(ctx, "Failed to process deleted request", "error", err) return trace.Wrap(err) } return nil @@ -310,13 +320,13 @@ func (a *App) getNotifySchedulesAndTeams(ctx context.Context, req types.AccessRe scheduleAnnotationKey := types.TeleportNamespace + types.ReqAnnotationNotifySchedulesLabel schedules, err = common.GetNamesFromAnnotations(req, scheduleAnnotationKey) if err != nil { - log.Debugf("No schedules to notify in %s", scheduleAnnotationKey) + log.DebugContext(ctx, "No schedules to notify", "schedule", scheduleAnnotationKey) } teamAnnotationKey := types.TeleportNamespace + types.ReqAnnotationTeamsLabel teams, err = common.GetNamesFromAnnotations(req, teamAnnotationKey) if err != nil { - log.Debugf("No teams to notify in %s", teamAnnotationKey) + log.DebugContext(ctx, "No teams to notify", "teams", teamAnnotationKey) } if len(schedules) == 0 && len(teams) == 0 { @@ -336,7 +346,7 @@ func (a *App) tryNotifyService(ctx context.Context, req types.AccessRequest) (bo recipientSchedules, recipientTeams, err := a.getMessageRecipients(ctx, req) if err != nil { - log.Debugf("Skipping the notification: %s", err) + log.DebugContext(ctx, "Skipping notification", "error", err) return false, trace.Wrap(errMissingAnnotation) } @@ -434,8 +444,8 @@ func (a *App) createAlert(ctx context.Context, reqID string, reqData RequestData if err != nil { return trace.Wrap(err) } - ctx, log := logger.WithField(ctx, "opsgenie_alert_id", data.AlertID) - log.Info("Successfully created Opsgenie alert") + ctx, log := logger.With(ctx, "opsgenie_alert_id", data.AlertID) + log.InfoContext(ctx, "Successfully created Opsgenie alert") // Save opsgenie alert info in plugin data. _, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) { @@ -479,10 +489,10 @@ func (a *App) postReviewNotes(ctx context.Context, reqID string, reqReviews []ty return trace.Wrap(err) } if !ok { - logger.Get(ctx).Debug("Failed to post the note: plugin data is missing") + logger.Get(ctx).DebugContext(ctx, "Failed to post the note: plugin data is missing") return nil } - ctx, _ = logger.WithField(ctx, "opsgenie_alert_id", data.AlertID) + ctx, _ = logger.With(ctx, "opsgenie_alert_id", data.AlertID) slice := reqReviews[oldCount:] if len(slice) == 0 { @@ -504,7 +514,7 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er serviceNames, err := a.getOnCallServiceNames(req) if err != nil { - logger.Get(ctx).Debugf("Skipping the approval: %s", err) + logger.Get(ctx).DebugContext(ctx, "Skipping approval", "error", err) return nil } @@ -537,14 +547,14 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er }, }); err != nil { if strings.HasSuffix(err.Error(), "has already reviewed this request") { - log.Debug("Already reviewed the request") + log.DebugContext(ctx, "Already reviewed the request") return nil } return trace.Wrap(err, "submitting access request") } } - log.Info("Successfully submitted a request approval") + log.InfoContext(ctx, "Successfully submitted a request approval") return nil } @@ -576,15 +586,15 @@ func (a *App) resolveAlert(ctx context.Context, reqID string, resolution Resolut return trace.Wrap(err) } if !ok { - logger.Get(ctx).Debug("Failed to resolve the alert: plugin data is missing") + logger.Get(ctx).DebugContext(ctx, "Failed to resolve the alert: plugin data is missing") return nil } - ctx, log := logger.WithField(ctx, "opsgenie_alert_id", alertID) + ctx, log := logger.With(ctx, "opsgenie_alert_id", alertID) if err := a.opsgenie.ResolveAlert(ctx, alertID, resolution); err != nil { return trace.Wrap(err) } - log.Info("Successfully resolved the alert") + log.InfoContext(ctx, "Successfully resolved the alert") return nil } diff --git a/integrations/access/opsgenie/client.go b/integrations/access/opsgenie/client.go index 2619c6ed6f7a9..2c8cdaec09a33 100644 --- a/integrations/access/opsgenie/client.go +++ b/integrations/access/opsgenie/client.go @@ -185,10 +185,10 @@ func (og Client) tryGetAlertRequestResult(ctx context.Context, reqID string) (Ge for { alertRequestResult, err := og.getAlertRequestResult(ctx, reqID) if err == nil { - logger.Get(ctx).Debugf("Got alert request result: %+v", alertRequestResult) + logger.Get(ctx).DebugContext(ctx, "Got alert request result", "alert_id", alertRequestResult.Data.AlertID) return alertRequestResult, nil } - logger.Get(ctx).Debug("Failed to get alert request result:", err) + logger.Get(ctx).DebugContext(ctx, "Failed to get alert request result", "error", err) if err := backoff.Do(ctx); err != nil { return GetAlertRequestResult{}, trace.Wrap(err) } @@ -344,8 +344,10 @@ func (og Client) CheckHealth(ctx context.Context) error { code = types.PluginStatusCode_OTHER_ERROR } if err := og.StatusSink.Emit(ctx, &types.PluginStatusV1{Code: code}); err != nil { - logger.Get(resp.Request.Context()).WithError(err). - WithField("code", resp.StatusCode()).Errorf("Error while emitting servicenow plugin status: %v", err) + logger.Get(resp.Request.Context()).ErrorContext(ctx, "Error while emitting servicenow plugin status", + "error", err, + "code", resp.StatusCode(), + ) } } diff --git a/integrations/access/opsgenie/testlib/fake_opsgenie.go b/integrations/access/opsgenie/testlib/fake_opsgenie.go index 9b5e6252119d1..1c124e19a75fc 100644 --- a/integrations/access/opsgenie/testlib/fake_opsgenie.go +++ b/integrations/access/opsgenie/testlib/fake_opsgenie.go @@ -32,7 +32,6 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/integrations/access/opsgenie" @@ -314,7 +313,7 @@ func (s *FakeOpsgenie) GetSchedule(scheduleName string) ([]opsgenie.Responder, b func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/pagerduty/app.go b/integrations/access/pagerduty/app.go index 5eadcc5147cd0..2351c5d2d5f02 100644 --- a/integrations/access/pagerduty/app.go +++ b/integrations/access/pagerduty/app.go @@ -22,6 +22,7 @@ import ( "context" "errors" "fmt" + "log/slog" "strings" "time" @@ -38,6 +39,7 @@ import ( "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/watcherjob" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -106,7 +108,6 @@ func (a *App) run(ctx context.Context) error { var err error log := logger.Get(ctx) - log.Infof("Starting Teleport Access PagerDuty Plugin") if err = a.init(ctx); err != nil { return trace.Wrap(err) @@ -146,9 +147,9 @@ func (a *App) run(ctx context.Context) error { a.mainJob.SetReady(ok) if ok { - log.Info("Plugin is ready") + log.InfoContext(ctx, "Plugin is ready") } else { - log.Error("Plugin is not ready") + log.ErrorContext(ctx, "Plugin is not ready") } <-watcherJob.Done() @@ -202,25 +203,25 @@ func (a *App) init(ctx context.Context) error { return trace.Wrap(err) } - log.Debug("Starting PagerDuty API health check...") + log.DebugContext(ctx, "Starting PagerDuty API health check") if err = a.pagerduty.HealthCheck(ctx); err != nil { return trace.Wrap(err, "api health check failed. check your credentials and service_id settings") } - log.Debug("PagerDuty API health check finished ok") + log.DebugContext(ctx, "PagerDuty API health check finished ok") return nil } func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) { log := logger.Get(ctx) - log.Debug("Checking Teleport server version") + log.DebugContext(ctx, "Checking Teleport server version") pong, err := a.teleport.Ping(ctx) if err != nil { if trace.IsNotImplemented(err) { return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion) } - log.Error("Unable to get Teleport server version") + log.ErrorContext(ctx, "Unable to get Teleport server version") return pong, trace.Wrap(err) } err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion) @@ -245,16 +246,16 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error } op := event.Type reqID := event.Resource.GetName() - ctx, _ = logger.WithField(ctx, "request_id", reqID) + ctx, _ = logger.With(ctx, "request_id", reqID) switch op { case types.OpPut: - ctx, _ = logger.WithField(ctx, "request_op", "put") + ctx, _ = logger.With(ctx, "request_op", "put") req, ok := event.Resource.(types.AccessRequest) if !ok { return trace.Errorf("unexpected resource type %T", event.Resource) } - ctx, log := logger.WithField(ctx, "request_state", req.GetState().String()) + ctx, log := logger.With(ctx, "request_state", req.GetState().String()) var err error switch { @@ -263,21 +264,29 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error case req.GetState().IsResolved(): err = a.onResolvedRequest(ctx, req) default: - log.WithField("event", event).Warn("Unknown request state") + log.WarnContext(ctx, "Unknown request state", + slog.Group("event", + slog.Any("type", logutils.StringerAttr(event.Type)), + slog.Group("resource", + "kind", event.Resource.GetKind(), + "name", event.Resource.GetName(), + ), + ), + ) return nil } if err != nil { - log.WithError(err).Error("Failed to process request") + log.ErrorContext(ctx, "Failed to process request", "error", err) return trace.Wrap(err) } return nil case types.OpDelete: - ctx, log := logger.WithField(ctx, "request_op", "delete") + ctx, log := logger.With(ctx, "request_op", "delete") if err := a.onDeletedRequest(ctx, reqID); err != nil { - log.WithError(err).Error("Failed to process deleted request") + log.ErrorContext(ctx, "Failed to process deleted request", "error", err) return trace.Wrap(err) } return nil @@ -288,7 +297,7 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) error { if len(req.GetSystemAnnotations()) == 0 { - logger.Get(ctx).Debug("Cannot proceed further. Request is missing any annotations") + logger.Get(ctx).DebugContext(ctx, "Cannot proceed further - request is missing any annotations") return nil } @@ -370,11 +379,11 @@ func (a *App) tryNotifyService(ctx context.Context, req types.AccessRequest) (bo serviceName, err := a.getNotifyServiceName(ctx, req) if err != nil { - log.Debugf("Skipping the notification: %s", err) + log.DebugContext(ctx, "Skipping the notification", "error", err) return false, trace.Wrap(errSkip) } - ctx, _ = logger.WithField(ctx, "pd_service_name", serviceName) + ctx, _ = logger.With(ctx, "pd_service_name", serviceName) service, err := a.pagerduty.FindServiceByName(ctx, serviceName) if err != nil { return false, trace.Wrap(err, "finding pagerduty service %s", serviceName) @@ -420,8 +429,8 @@ func (a *App) createIncident(ctx context.Context, serviceID, reqID string, reqDa if err != nil { return trace.Wrap(err) } - ctx, log := logger.WithField(ctx, "pd_incident_id", data.IncidentID) - log.Info("Successfully created PagerDuty incident") + ctx, log := logger.With(ctx, "pd_incident_id", data.IncidentID) + log.InfoContext(ctx, "Successfully created PagerDuty incident") // Save pagerduty incident info in plugin data. _, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) { @@ -465,10 +474,10 @@ func (a *App) postReviewNotes(ctx context.Context, reqID string, reqReviews []ty return trace.Wrap(err) } if !ok { - logger.Get(ctx).Debug("Failed to post the note: plugin data is missing") + logger.Get(ctx).DebugContext(ctx, "Failed to post the note: plugin data is missing") return nil } - ctx, _ = logger.WithField(ctx, "pd_incident_id", data.IncidentID) + ctx, _ = logger.With(ctx, "pd_incident_id", data.IncidentID) slice := reqReviews[oldCount:] if len(slice) == 0 { @@ -490,36 +499,40 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er serviceNames, err := a.getOnCallServiceNames(req) if err != nil { - logger.Get(ctx).Debugf("Skipping the approval: %s", err) + logger.Get(ctx).DebugContext(ctx, "Skipping approval", "error", err) return nil } userName := req.GetUser() if !lib.IsEmail(userName) { - logger.Get(ctx).Warningf("Skipping the approval: %q does not look like a valid email", userName) + logger.Get(ctx).WarnContext(ctx, "Skipping approval, found invalid email", "pd_user_email", userName) return nil } user, err := a.pagerduty.FindUserByEmail(ctx, userName) if err != nil { if trace.IsNotFound(err) { - log.WithError(err).WithField("pd_user_email", userName).Debug("Skipping the approval: email is not found") + log.DebugContext(ctx, "Skipping approval, email is not found", + "error", err, + "pd_user_email", userName) return nil } return trace.Wrap(err) } - ctx, log = logger.WithFields(ctx, logger.Fields{ - "pd_user_email": user.Email, - "pd_user_name": user.Name, - }) + ctx, log = logger.With(ctx, + "pd_user_email", user.Email, + "pd_user_name", user.Name, + ) services, err := a.pagerduty.FindServicesByNames(ctx, serviceNames) if err != nil { return trace.Wrap(err) } if len(services) == 0 { - log.WithField("pd_service_names", serviceNames).Warning("Failed to find any service") + log.WarnContext(ctx, "Failed to find any service", + "pd_service_names", serviceNames, + ) return nil } @@ -536,7 +549,7 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er return trace.Wrap(err) } if len(escalationPolicyIDs) == 0 { - log.Debug("Skipping the approval: user is not on call") + log.DebugContext(ctx, "Skipping the approval: user is not on call") return nil } @@ -561,13 +574,13 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er }, }); err != nil { if strings.HasSuffix(err.Error(), "has already reviewed this request") { - log.Debug("Already reviewed the request") + log.DebugContext(ctx, "Already reviewed the request") return nil } return trace.Wrap(err, "submitting access request") } - log.Info("Successfully submitted a request approval") + log.InfoContext(ctx, "Successfully submitted a request approval") return nil } @@ -599,15 +612,15 @@ func (a *App) resolveIncident(ctx context.Context, reqID string, resolution Reso return trace.Wrap(err) } if !ok { - logger.Get(ctx).Debug("Failed to resolve the incident: plugin data is missing") + logger.Get(ctx).DebugContext(ctx, "Failed to resolve the incident: plugin data is missing") return nil } - ctx, log := logger.WithField(ctx, "pd_incident_id", incidentID) + ctx, log := logger.With(ctx, "pd_incident_id", incidentID) if err := a.pagerduty.ResolveIncident(ctx, incidentID, resolution); err != nil { return trace.Wrap(err) } - log.Info("Successfully resolved the incident") + log.InfoContext(ctx, "Successfully resolved the incident") return nil } diff --git a/integrations/access/pagerduty/client.go b/integrations/access/pagerduty/client.go index 51adfb38f5aed..fd42876a154ca 100644 --- a/integrations/access/pagerduty/client.go +++ b/integrations/access/pagerduty/client.go @@ -122,7 +122,7 @@ func onAfterPagerDutyResponse(sink common.StatusSink) resty.ResponseMiddleware { defer cancel() if err := sink.Emit(ctx, status); err != nil { - log.WithError(err).Errorf("Error while emitting PagerDuty plugin status: %v", err) + log.ErrorContext(ctx, "Error while emitting PagerDuty plugin status", "error", err) } if resp.IsError() { @@ -288,7 +288,7 @@ func (p *Pagerduty) FindUserByEmail(ctx context.Context, userEmail string) (User } if len(result.Users) > 0 && result.More { - logger.Get(ctx).Warningf("PagerDuty returned too many results when querying by email %q", userEmail) + logger.Get(ctx).WarnContext(ctx, "PagerDuty returned too many results when querying user email", "email", userEmail) } return User{}, trace.NotFound("failed to find pagerduty user by email %s", userEmail) @@ -387,10 +387,10 @@ func (p *Pagerduty) FilterOnCallPolicies(ctx context.Context, userID string, esc if len(filteredIDSet) == 0 { if anyData { - logger.Get(ctx).WithFields(logger.Fields{ - "pd_user_id": userID, - "pd_escalation_policy_ids": escalationPolicyIDs, - }).Warningf("PagerDuty returned some oncalls array but none of them matched the query") + logger.Get(ctx).WarnContext(ctx, "PagerDuty returned some oncalls array but none of them matched the query", + "pd_user_id", userID, + "pd_escalation_policy_ids", escalationPolicyIDs, + ) } return nil, nil diff --git a/integrations/access/pagerduty/cmd/teleport-pagerduty/main.go b/integrations/access/pagerduty/cmd/teleport-pagerduty/main.go index aa4a8ba96eb32..58cfa27248d56 100644 --- a/integrations/access/pagerduty/cmd/teleport-pagerduty/main.go +++ b/integrations/access/pagerduty/cmd/teleport-pagerduty/main.go @@ -20,6 +20,7 @@ import ( "context" _ "embed" "fmt" + "log/slog" "os" "github.com/alecthomas/kingpin/v2" @@ -65,12 +66,13 @@ func main() { if err := run(*path, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } } func run(configPath string, debug bool) error { + ctx := context.Background() conf, err := pagerduty.LoadConfig(configPath) if err != nil { return trace.Wrap(err) @@ -84,7 +86,7 @@ func run(configPath string, debug bool) error { return err } if debug { - logger.Standard().Debugf("DEBUG logging enabled") + slog.DebugContext(ctx, "DEBUG logging enabled") } app, err := pagerduty.NewApp(*conf) @@ -94,8 +96,9 @@ func run(configPath string, debug bool) error { go lib.ServeSignals(app, common.PluginShutdownTimeout) - logger.Standard().Infof("Starting Teleport Access PagerDuty Plugin %s:%s", teleport.Version, teleport.Gitref) - return trace.Wrap( - app.Run(context.Background()), + slog.InfoContext(ctx, "Starting Teleport Access PagerDuty Plugin", + "version", teleport.Version, + "git_ref", teleport.Gitref, ) + return trace.Wrap(app.Run(ctx)) } diff --git a/integrations/access/pagerduty/testlib/fake_pagerduty.go b/integrations/access/pagerduty/testlib/fake_pagerduty.go index 18a2a6ae24361..eee358f022458 100644 --- a/integrations/access/pagerduty/testlib/fake_pagerduty.go +++ b/integrations/access/pagerduty/testlib/fake_pagerduty.go @@ -32,7 +32,6 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/pagerduty" "github.com/gravitational/teleport/integrations/lib/stringset" @@ -565,6 +564,6 @@ func (s *FakePagerduty) CheckNewIncidentNote(ctx context.Context) (FakeIncidentN func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/servicenow/app.go b/integrations/access/servicenow/app.go index 3d56f4fc97a8b..07248b488d872 100644 --- a/integrations/access/servicenow/app.go +++ b/integrations/access/servicenow/app.go @@ -21,6 +21,7 @@ package servicenow import ( "context" "fmt" + "log/slog" "net/url" "slices" "strings" @@ -41,6 +42,7 @@ import ( "github.com/gravitational/teleport/integrations/lib/logger" "github.com/gravitational/teleport/integrations/lib/watcherjob" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -116,7 +118,7 @@ func (a *App) WaitReady(ctx context.Context) (bool, error) { func (a *App) run(ctx context.Context) error { log := logger.Get(ctx) - log.Infof("Starting Teleport Access Servicenow Plugin") + log.InfoContext(ctx, "Starting Teleport Access Servicenow Plugin") if err := a.init(ctx); err != nil { return trace.Wrap(err) @@ -153,9 +155,9 @@ func (a *App) run(ctx context.Context) error { } a.mainJob.SetReady(ok) if ok { - log.Info("ServiceNow plugin is ready") + log.InfoContext(ctx, "ServiceNow plugin is ready") } else { - log.Error("ServiceNow plugin is not ready") + log.ErrorContext(ctx, "ServiceNow plugin is not ready") } <-watcherJob.Done() @@ -190,25 +192,25 @@ func (a *App) init(ctx context.Context) error { return trace.Wrap(err) } - log.Debug("Starting API health check...") + log.DebugContext(ctx, "Starting API health check") if err = a.serviceNow.CheckHealth(ctx); err != nil { return trace.Wrap(err, "API health check failed") } - log.Debug("API health check finished ok") + log.DebugContext(ctx, "API health check finished ok") return nil } func (a *App) checkTeleportVersion(ctx context.Context) (proto.PingResponse, error) { log := logger.Get(ctx) - log.Debug("Checking Teleport server version") + log.DebugContext(ctx, "Checking Teleport server version") pong, err := a.teleport.Ping(ctx) if err != nil { if trace.IsNotImplemented(err) { return pong, trace.Wrap(err, "server version must be at least %s", minServerVersion) } - log.Error("Unable to get Teleport server version") + log.ErrorContext(ctx, "Unable to get Teleport server version") return pong, trace.Wrap(err) } err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion) @@ -233,16 +235,16 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error } op := event.Type reqID := event.Resource.GetName() - ctx, _ = logger.WithField(ctx, "request_id", reqID) + ctx, _ = logger.With(ctx, "request_id", reqID) switch op { case types.OpPut: - ctx, _ = logger.WithField(ctx, "request_op", "put") + ctx, _ = logger.With(ctx, "request_op", "put") req, ok := event.Resource.(types.AccessRequest) if !ok { return trace.Errorf("unexpected resource type %T", event.Resource) } - ctx, log := logger.WithField(ctx, "request_state", req.GetState().String()) + ctx, log := logger.With(ctx, "request_state", req.GetState().String()) var err error switch { @@ -251,21 +253,29 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error case req.GetState().IsResolved(): err = a.onResolvedRequest(ctx, req) default: - log.WithField("event", event).Warnf("Unknown request state: %q", req.GetState()) + log.WarnContext(ctx, "Unknown request state", + slog.Group("event", + slog.Any("type", logutils.StringerAttr(event.Type)), + slog.Group("resource", + "kind", event.Resource.GetKind(), + "name", event.Resource.GetName(), + ), + ), + ) return nil } if err != nil { - log.WithError(err).Error("Failed to process request") + log.ErrorContext(ctx, "Failed to process request", "error", err) return trace.Wrap(err) } return nil case types.OpDelete: - ctx, log := logger.WithField(ctx, "request_op", "delete") + ctx, log := logger.With(ctx, "request_op", "delete") if err := a.onDeletedRequest(ctx, reqID); err != nil { - log.WithError(err).Error("Failed to process deleted request") + log.ErrorContext(ctx, "Failed to process deleted request", "error", err) return trace.Wrap(err) } return nil @@ -276,7 +286,7 @@ func (a *App) handleAccessRequest(ctx context.Context, event types.Event) error func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) error { reqID := req.GetName() - log := logger.Get(ctx).WithField("reqId", reqID) + log := logger.Get(ctx).With("req_id", reqID) resourceNames, err := a.getResourceNames(ctx, req) if err != nil { @@ -303,7 +313,7 @@ func (a *App) onPendingRequest(ctx context.Context, req types.AccessRequest) err } if isNew { - log.Infof("Creating servicenow incident") + log.InfoContext(ctx, "Creating servicenow incident") recipientAssignee := a.accessMonitoringRules.RecipientsFromAccessMonitoringRules(ctx, req) assignees := []string{} recipientAssignee.ForEach(func(r common.Recipient) { @@ -375,8 +385,8 @@ func (a *App) createIncident(ctx context.Context, reqID string, reqData RequestD if err != nil { return trace.Wrap(err) } - ctx, log := logger.WithField(ctx, "servicenow_incident_id", data.IncidentID) - log.Info("Successfully created Servicenow incident") + ctx, log := logger.With(ctx, "servicenow_incident_id", data.IncidentID) + log.InfoContext(ctx, "Successfully created Servicenow incident") // Save servicenow incident info in plugin data. _, err = a.modifyPluginData(ctx, reqID, func(existing *PluginData) (PluginData, bool) { @@ -420,10 +430,10 @@ func (a *App) postReviewNotes(ctx context.Context, reqID string, reqReviews []ty return trace.Wrap(err) } if !ok { - logger.Get(ctx).Debug("Failed to post the note: plugin data is missing") + logger.Get(ctx).DebugContext(ctx, "Failed to post the note: plugin data is missing") return nil } - ctx, _ = logger.WithField(ctx, "servicenow_incident_id", data.IncidentID) + ctx, _ = logger.With(ctx, "servicenow_incident_id", data.IncidentID) slice := reqReviews[oldCount:] if len(slice) == 0 { @@ -445,22 +455,28 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er serviceNames, err := a.getOnCallServiceNames(req) if err != nil { - logger.Get(ctx).Debugf("Skipping the approval: %s", err) + logger.Get(ctx).DebugContext(ctx, "Skipping the approval", "error", err) return nil } - log.Debugf("Checking the following shifts to see if the requester is on-call: %s", serviceNames) + log.DebugContext(ctx, "Checking the shifts to see if the requester is on-call", "shifts", serviceNames) onCallUsers, err := a.getOnCallUsers(ctx, serviceNames) if err != nil { return trace.Wrap(err) } - log.Debugf("Users on-call are: %s", onCallUsers) + log.DebugContext(ctx, "Users on-call are", "on_call_users", onCallUsers) if userIsOnCall := slices.Contains(onCallUsers, req.GetUser()); !userIsOnCall { - log.Debugf("User %q is not on-call, not approving the request %q.", req.GetUser(), req.GetName()) + log.DebugContext(ctx, "User is not on-call, not approving the request", + "user", req.GetUser(), + "request", req.GetName(), + ) return nil } - log.Debugf("User %q is on-call. Auto-approving the request %q.", req.GetUser(), req.GetName()) + log.DebugContext(ctx, "User is on-call, auto-approving the request", + "user", req.GetUser(), + "request", req.GetName(), + ) if _, err := a.teleport.SubmitAccessReview(ctx, types.AccessReviewSubmission{ RequestID: req.GetName(), Review: types.AccessReview{ @@ -474,12 +490,12 @@ func (a *App) tryApproveRequest(ctx context.Context, req types.AccessRequest) er }, }); err != nil { if strings.HasSuffix(err.Error(), "has already reviewed this request") { - log.Debug("Already reviewed the request") + log.DebugContext(ctx, "Already reviewed the request") return nil } return trace.Wrap(err, "submitting access request") } - log.Info("Successfully submitted a request approval") + log.InfoContext(ctx, "Successfully submitted a request approval") return nil } @@ -490,7 +506,7 @@ func (a *App) getOnCallUsers(ctx context.Context, serviceNames []string) ([]stri respondersResult, err := a.serviceNow.GetOnCall(ctx, scheduleName) if err != nil { if trace.IsNotFound(err) { - log.WithError(err).Error("Failed to retrieve responder from schedule") + log.ErrorContext(ctx, "Failed to retrieve responder from schedule", "error", err) continue } return nil, trace.Wrap(err) @@ -528,15 +544,15 @@ func (a *App) resolveIncident(ctx context.Context, reqID string, resolution Reso return trace.Wrap(err) } if !ok { - logger.Get(ctx).Debug("Failed to resolve the incident: plugin data is missing") + logger.Get(ctx).DebugContext(ctx, "Failed to resolve the incident: plugin data is missing") return nil } - ctx, log := logger.WithField(ctx, "servicenow_incident_id", incidentID) + ctx, log := logger.With(ctx, "servicenow_incident_id", incidentID) if err := a.serviceNow.ResolveIncident(ctx, incidentID, resolution); err != nil { return trace.Wrap(err) } - log.Info("Successfully resolved the incident") + log.InfoContext(ctx, "Successfully resolved the incident") return nil } diff --git a/integrations/access/servicenow/client.go b/integrations/access/servicenow/client.go index 8d0fb4f62b9de..8c306c1efa4ee 100644 --- a/integrations/access/servicenow/client.go +++ b/integrations/access/servicenow/client.go @@ -287,7 +287,10 @@ func (snc *Client) CheckHealth(ctx context.Context) error { } if err := snc.StatusSink.Emit(ctx, &types.PluginStatusV1{Code: code}); err != nil { log := logger.Get(resp.Request.Context()) - log.WithError(err).WithField("code", resp.StatusCode()).Errorf("Error while emitting servicenow plugin status: %v", err) + log.ErrorContext(ctx, "Error while emitting servicenow plugin status", + "error", err, + "code", resp.StatusCode(), + ) } } diff --git a/integrations/access/servicenow/testlib/fake_servicenow.go b/integrations/access/servicenow/testlib/fake_servicenow.go index 3b2d70e82a9b2..edf3fdced5fe7 100644 --- a/integrations/access/servicenow/testlib/fake_servicenow.go +++ b/integrations/access/servicenow/testlib/fake_servicenow.go @@ -32,7 +32,6 @@ import ( "github.com/google/uuid" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/servicenow" "github.com/gravitational/teleport/integrations/lib/stringset" @@ -284,6 +283,6 @@ func (s *FakeServiceNow) getOnCall(rotationName string) []string { func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/access/slack/bot.go b/integrations/access/slack/bot.go index 9c58093cb9897..e7fefa0107163 100644 --- a/integrations/access/slack/bot.go +++ b/integrations/access/slack/bot.go @@ -29,7 +29,6 @@ import ( "github.com/go-resty/resty/v2" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/accesslist" @@ -37,6 +36,7 @@ import ( "github.com/gravitational/teleport/integrations/access/accessrequest" "github.com/gravitational/teleport/integrations/access/common" "github.com/gravitational/teleport/integrations/lib" + "github.com/gravitational/teleport/integrations/lib/logger" pd "github.com/gravitational/teleport/integrations/lib/plugindata" ) @@ -68,7 +68,7 @@ func onAfterResponseSlack(sink common.StatusSink) func(_ *resty.Client, resp *re ctx, cancel := context.WithTimeout(context.Background(), statusEmitTimeout) defer cancel() if err := sink.Emit(ctx, status); err != nil { - log.Errorf("Error while emitting plugin status: %v", err) + logger.Get(ctx).ErrorContext(ctx, "Error while emitting plugin status", "error", err) } }() @@ -139,7 +139,7 @@ func (b Bot) BroadcastAccessRequestMessage(ctx context.Context, recipients []com // the case with most SSO setups. userRecipient, err := b.FetchRecipient(ctx, reqData.User) if err != nil { - log.Warningf("Unable to find user %s in Slack, will not be able to notify.", reqData.User) + logger.Get(ctx).WarnContext(ctx, "Unable to find user in Slack, will not be able to notify", "user", reqData.User) } // Include the user in the list of recipients if it exists. diff --git a/integrations/access/slack/cmd/teleport-slack/main.go b/integrations/access/slack/cmd/teleport-slack/main.go index 1f77db5f21492..ffa73144f540b 100644 --- a/integrations/access/slack/cmd/teleport-slack/main.go +++ b/integrations/access/slack/cmd/teleport-slack/main.go @@ -20,6 +20,7 @@ import ( "context" _ "embed" "fmt" + "log/slog" "os" "github.com/alecthomas/kingpin/v2" @@ -65,12 +66,13 @@ func main() { if err := run(*path, *debug); err != nil { lib.Bail(err) } else { - logger.Standard().Info("Successfully shut down") + slog.InfoContext(context.Background(), "Successfully shut down") } } } func run(configPath string, debug bool) error { + ctx := context.Background() conf, err := slack.LoadSlackConfig(configPath) if err != nil { return trace.Wrap(err) @@ -84,14 +86,15 @@ func run(configPath string, debug bool) error { return trace.Wrap(err) } if debug { - logger.Standard().Debugf("DEBUG logging enabled") + slog.DebugContext(ctx, "DEBUG logging enabled") } app := slack.NewSlackApp(conf) go lib.ServeSignals(app, common.PluginShutdownTimeout) - logger.Standard().Infof("Starting Teleport Access Slack Plugin %s:%s", teleport.Version, teleport.Gitref) - return trace.Wrap( - app.Run(context.Background()), + slog.InfoContext(ctx, "Starting Teleport Access Slack Plugin", + "version", teleport.Version, + "git_ref", teleport.Gitref, ) + return trace.Wrap(app.Run(ctx)) } diff --git a/integrations/access/slack/testlib/fake_slack.go b/integrations/access/slack/testlib/fake_slack.go index eef81460da7f1..d18a43230c744 100644 --- a/integrations/access/slack/testlib/fake_slack.go +++ b/integrations/access/slack/testlib/fake_slack.go @@ -31,7 +31,6 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/integrations/access/slack" ) @@ -315,6 +314,6 @@ func (s *FakeSlack) CheckMessageUpdateByResponding(ctx context.Context) (slack.M func panicIf(err error) { if err != nil { - log.Panicf("%v at %v", err, string(debug.Stack())) + panic(fmt.Sprintf("%v at %v", err, string(debug.Stack()))) } } diff --git a/integrations/event-handler/fake_fluentd_test.go b/integrations/event-handler/fake_fluentd_test.go index ecf286569f12d..72a363468ba15 100644 --- a/integrations/event-handler/fake_fluentd_test.go +++ b/integrations/event-handler/fake_fluentd_test.go @@ -31,8 +31,6 @@ import ( "github.com/gravitational/trace" "github.com/stretchr/testify/require" - - "github.com/gravitational/teleport/integrations/lib/logger" ) type FakeFluentd struct { @@ -150,7 +148,6 @@ func (f *FakeFluentd) GetURL() string { func (f *FakeFluentd) Respond(w http.ResponseWriter, r *http.Request) { req, err := io.ReadAll(r.Body) if err != nil { - logger.Standard().WithError(err).Error("FakeFluentd Respond() failed to read body") fmt.Fprintln(w, "NOK") return } diff --git a/integrations/event-handler/main.go b/integrations/event-handler/main.go index 859f6544c1e06..693b5bb24e036 100644 --- a/integrations/event-handler/main.go +++ b/integrations/event-handler/main.go @@ -46,8 +46,6 @@ const ( ) func main() { - // This initializes the legacy logrus logger. This has been kept in place - // in case any of the dependencies are still using logrus. logger.Init() ctx := kong.Parse( @@ -64,17 +62,13 @@ func main() { Format: "text", } if cli.Debug { - enableLogDebug() logCfg.Severity = "debug" } - log, err := logCfg.NewSLogLogger() - if err != nil { - fmt.Println(trace.DebugReport(trace.Wrap(err, "initializing logger"))) + + if err := logger.Setup(logCfg); err != nil { + fmt.Println(trace.DebugReport(err)) os.Exit(-1) } - // Whilst this package mostly dependency injects slog, upstream dependencies - // may still use the default slog logger. - slog.SetDefault(log) switch { case ctx.Command() == "version": @@ -86,25 +80,16 @@ func main() { os.Exit(-1) } case ctx.Command() == "start": - err := start(log) + err := start(slog.Default()) if err != nil { lib.Bail(err) } else { - log.InfoContext(context.TODO(), "Successfully shut down") + slog.InfoContext(context.TODO(), "Successfully shut down") } } } -// turn on log debugging -func enableLogDebug() { - err := logger.Setup(logger.Config{Severity: "debug", Output: "stderr"}) - if err != nil { - fmt.Println(trace.DebugReport(err)) - os.Exit(-1) - } -} - // start spawns the main process func start(log *slog.Logger) error { app, err := NewApp(&cli.Start, log) diff --git a/integrations/lib/bail.go b/integrations/lib/bail.go index 72804cd0ac3c4..d1351bb05f7fe 100644 --- a/integrations/lib/bail.go +++ b/integrations/lib/bail.go @@ -19,22 +19,24 @@ package lib import ( + "context" "errors" + "log/slog" "os" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" ) // Bail exits with nonzero exit code and prints an error to a log. func Bail(err error) { + ctx := context.Background() var agg trace.Aggregate if errors.As(trace.Unwrap(err), &agg) { for i, err := range agg.Errors() { - log.WithError(err).Errorf("Terminating with fatal error [%d]...", i+1) + slog.ErrorContext(ctx, "Terminating with fatal error", "error_number", i+1, "error", err) } } else { - log.WithError(err).Error("Terminating with fatal error...") + slog.ErrorContext(ctx, "Terminating with fatal error", "error", err) } os.Exit(1) } diff --git a/integrations/lib/config.go b/integrations/lib/config.go index 24f6c981e6686..66285167e5e36 100644 --- a/integrations/lib/config.go +++ b/integrations/lib/config.go @@ -22,12 +22,12 @@ import ( "context" "errors" "io" + "log/slog" "os" "strings" "time" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" "google.golang.org/grpc" grpcbackoff "google.golang.org/grpc/backoff" @@ -137,7 +137,7 @@ func NewIdentityFileWatcher(ctx context.Context, path string, interval time.Dura } if err := dynamicCred.Reload(); err != nil { - log.WithError(err).Error("Failed to reload identity file from disk.") + slog.ErrorContext(ctx, "Failed to reload identity file from disk", "error", err) } timer.Reset(interval) } @@ -152,7 +152,7 @@ func (cfg TeleportConfig) NewClient(ctx context.Context) (*client.Client, error) case cfg.Addr != "": addr = cfg.Addr case cfg.AuthServer != "": - log.Warn("Configuration setting `auth_server` is deprecated, consider to change it to `addr`") + slog.WarnContext(ctx, "Configuration setting `auth_server` is deprecated, consider to change it to `addr`") addr = cfg.AuthServer } @@ -173,13 +173,13 @@ func (cfg TeleportConfig) NewClient(ctx context.Context) (*client.Client, error) } if validCred, err := credentials.CheckIfExpired(creds); err != nil { - log.Warn(err) + slog.WarnContext(ctx, "found expired credentials", "error", err) if !validCred { return nil, trace.BadParameter( "No valid credentials found, this likely means credentials are expired. In this case, please sign new credentials and increase their TTL if needed.", ) } - log.Info("At least one non-expired credential has been found, continuing startup") + slog.InfoContext(ctx, "At least one non-expired credential has been found, continuing startup") } bk := grpcbackoff.DefaultConfig diff --git a/integrations/lib/embeddedtbot/bot.go b/integrations/lib/embeddedtbot/bot.go index e693b40793fe5..b8ed026386114 100644 --- a/integrations/lib/embeddedtbot/bot.go +++ b/integrations/lib/embeddedtbot/bot.go @@ -26,7 +26,6 @@ import ( "time" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" @@ -106,9 +105,9 @@ func (b *EmbeddedBot) start(ctx context.Context) { go func() { err := bot.Run(botCtx) if err != nil { - log.Errorf("bot exited with error: %s", err) + slog.ErrorContext(botCtx, "bot exited with error", "error", err) } else { - log.Infof("bot exited without error") + slog.InfoContext(botCtx, "bot exited without error") } b.errCh <- trace.Wrap(err) }() @@ -142,10 +141,10 @@ func (b *EmbeddedBot) waitForCredentials(ctx context.Context, deadline time.Dura select { case <-waitCtx.Done(): - log.Warn("context canceled while waiting for the bot client") + slog.WarnContext(ctx, "context canceled while waiting for the bot client") return nil, trace.Wrap(ctx.Err()) case <-b.credential.Ready(): - log.Infof("credential ready") + slog.InfoContext(ctx, "credential ready") } return b.credential, nil @@ -177,7 +176,7 @@ func (b *EmbeddedBot) StartAndWaitForCredentials(ctx context.Context, deadline t // buildClient reads tbot's memory disttination, retrieves the certificates // and builds a new Teleport client using those certs. func (b *EmbeddedBot) buildClient(ctx context.Context) (*client.Client, error) { - log.Infof("Building a new client to connect to %s", b.cfg.AuthServer) + slog.InfoContext(ctx, "Building a new client to connect to cluster", "auth_server_address", b.cfg.AuthServer) c, err := client.New(ctx, client.Config{ Addrs: []string{b.cfg.AuthServer}, Credentials: []client.Credentials{b.credential}, diff --git a/integrations/lib/http.go b/integrations/lib/http.go index dbb279913a5bd..6f98ad957a75c 100644 --- a/integrations/lib/http.go +++ b/integrations/lib/http.go @@ -24,6 +24,7 @@ import ( "crypto/x509" "errors" "fmt" + "log/slog" "net" "net/http" "net/url" @@ -33,7 +34,8 @@ import ( "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - log "github.com/sirupsen/logrus" + + logutils "github.com/gravitational/teleport/lib/utils/log" ) // TLSConfig stores TLS configuration for a http service @@ -178,7 +180,7 @@ func NewHTTP(config HTTPConfig) (*HTTP, error) { if verify := config.TLS.VerifyClientCertificateFunc; verify != nil { tlsConfig.VerifyPeerCertificate = func(_ [][]byte, chains [][]*x509.Certificate) error { if err := verify(chains); err != nil { - log.WithError(err).Error("HTTPS client certificate verification failed") + slog.ErrorContext(context.Background(), "HTTPS client certificate verification failed", "error", err) return err } return nil @@ -217,7 +219,7 @@ func BuildURLPath(args ...interface{}) string { // ListenAndServe runs a http(s) server on a provided port. func (h *HTTP) ListenAndServe(ctx context.Context) error { - defer log.Debug("HTTP server terminated") + defer slog.DebugContext(ctx, "HTTP server terminated") var err error h.server.BaseContext = func(_ net.Listener) context.Context { @@ -256,10 +258,10 @@ func (h *HTTP) ListenAndServe(ctx context.Context) error { } if h.Insecure { - log.Debugf("Starting insecure HTTP server on %s", addr) + slog.DebugContext(ctx, "Starting insecure HTTP server", "listen_addr", logutils.StringerAttr(addr)) err = h.server.Serve(listener) } else { - log.Debugf("Starting secure HTTPS server on %s", addr) + slog.DebugContext(ctx, "Starting secure HTTPS server", "listen_addr", logutils.StringerAttr(addr)) err = h.server.ServeTLS(listener, h.CertFile, h.KeyFile) } if errors.Is(err, http.ErrServerClosed) { @@ -288,7 +290,7 @@ func (h *HTTP) ServiceJob() ServiceJob { return NewServiceJob(func(ctx context.Context) error { MustGetProcess(ctx).OnTerminate(func(ctx context.Context) error { if err := h.ShutdownWithTimeout(ctx, time.Second*5); err != nil { - log.Error("HTTP server graceful shutdown failed") + slog.ErrorContext(ctx, "HTTP server graceful shutdown failed") return err } return nil diff --git a/integrations/lib/logger/logger.go b/integrations/lib/logger/logger.go index 7422f03ff906c..a1ce5bf7275ed 100644 --- a/integrations/lib/logger/logger.go +++ b/integrations/lib/logger/logger.go @@ -20,16 +20,11 @@ package logger import ( "context" - "io" - "io/fs" "log/slog" "os" - "strings" "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" - "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/utils" logutils "github.com/gravitational/teleport/lib/utils/log" ) @@ -41,8 +36,6 @@ type Config struct { Format string `toml:"format"` } -type Fields = log.Fields - type contextKey struct{} var extraFields = []string{logutils.LevelField, logutils.ComponentField, logutils.CallerField} @@ -50,179 +43,50 @@ var extraFields = []string{logutils.LevelField, logutils.ComponentField, logutil // Init sets up logger for a typical daemon scenario until configuration // file is parsed func Init() { - formatter := &logutils.TextFormatter{ - EnableColors: utils.IsTerminal(os.Stderr), - ComponentPadding: 1, // We don't use components so strip the padding - ExtraFields: extraFields, - } - - log.SetOutput(os.Stderr) - if err := formatter.CheckAndSetDefaults(); err != nil { - log.WithError(err).Error("unable to create text log formatter") - return - } - - log.SetFormatter(formatter) + enableColors := utils.IsTerminal(os.Stderr) + logutils.Initialize(logutils.Config{ + Severity: slog.LevelInfo.String(), + Format: "text", + ExtraFields: extraFields, + EnableColors: enableColors, + Padding: 1, + }) } func Setup(conf Config) error { + var enableColors bool switch conf.Output { case "stderr", "error", "2": - log.SetOutput(os.Stderr) + enableColors = utils.IsTerminal(os.Stderr) case "", "stdout", "out", "1": - log.SetOutput(os.Stdout) + enableColors = utils.IsTerminal(os.Stdout) default: - // assume it's a file path: - logFile, err := os.Create(conf.Output) - if err != nil { - return trace.Wrap(err, "failed to create the log file") - } - log.SetOutput(logFile) } - switch strings.ToLower(conf.Severity) { - case "info": - log.SetLevel(log.InfoLevel) - case "err", "error": - log.SetLevel(log.ErrorLevel) - case "debug": - log.SetLevel(log.DebugLevel) - case "warn", "warning": - log.SetLevel(log.WarnLevel) - case "trace": - log.SetLevel(log.TraceLevel) - default: - return trace.BadParameter("unsupported logger severity: '%v'", conf.Severity) - } - - return nil + _, _, err := logutils.Initialize(logutils.Config{ + Output: conf.Output, + Severity: conf.Severity, + Format: conf.Format, + ExtraFields: extraFields, + EnableColors: enableColors, + Padding: 1, + }) + return trace.Wrap(err) } -// NewSLogLogger builds a slog.Logger from the logger.Config. -// TODO(tross): Defer logging initialization to logutils.Initialize and use the -// global slog loggers once integrations has been updated to use slog. -func (conf Config) NewSLogLogger() (*slog.Logger, error) { - const ( - // logFileDefaultMode is the preferred permissions mode for log file. - logFileDefaultMode fs.FileMode = 0o644 - // logFileDefaultFlag is the preferred flags set to log file. - logFileDefaultFlag = os.O_WRONLY | os.O_CREATE | os.O_APPEND - ) - - var w io.Writer - switch conf.Output { - case "": - w = logutils.NewSharedWriter(os.Stderr) - case "stderr", "error", "2": - w = logutils.NewSharedWriter(os.Stderr) - case "stdout", "out", "1": - w = logutils.NewSharedWriter(os.Stdout) - case teleport.Syslog: - w = os.Stderr - sw, err := logutils.NewSyslogWriter() - if err != nil { - slog.Default().ErrorContext(context.Background(), "Failed to switch logging to syslog", "error", err) - break - } - - // If syslog output has been configured and is supported by the operating system, - // then the shared writer is not needed because the syslog writer is already - // protected with a mutex. - w = sw - default: - // Assume this is a file path. - sharedWriter, err := logutils.NewFileSharedWriter(conf.Output, logFileDefaultFlag, logFileDefaultMode) - if err != nil { - return nil, trace.Wrap(err, "failed to init the log file shared writer") - } - w = logutils.NewWriterFinalizer[*logutils.FileSharedWriter](sharedWriter) - if err := sharedWriter.RunWatcherReopen(context.Background()); err != nil { - return nil, trace.Wrap(err) - } - } - - level := new(slog.LevelVar) - switch strings.ToLower(conf.Severity) { - case "", "info": - level.Set(slog.LevelInfo) - case "err", "error": - level.Set(slog.LevelError) - case teleport.DebugLevel: - level.Set(slog.LevelDebug) - case "warn", "warning": - level.Set(slog.LevelWarn) - case "trace": - level.Set(logutils.TraceLevel) - default: - return nil, trace.BadParameter("unsupported logger severity: %q", conf.Severity) - } - - configuredFields, err := logutils.ValidateFields(extraFields) - if err != nil { - return nil, trace.Wrap(err) - } - - var slogLogger *slog.Logger - switch strings.ToLower(conf.Format) { - case "": - fallthrough // not set. defaults to 'text' - case "text": - enableColors := utils.IsTerminal(os.Stderr) - slogLogger = slog.New(logutils.NewSlogTextHandler(w, logutils.SlogTextHandlerConfig{ - Level: level, - EnableColors: enableColors, - ConfiguredFields: configuredFields, - })) - slog.SetDefault(slogLogger) - case "json": - slogLogger = slog.New(logutils.NewSlogJSONHandler(w, logutils.SlogJSONHandlerConfig{ - Level: level, - ConfiguredFields: configuredFields, - })) - slog.SetDefault(slogLogger) - default: - return nil, trace.BadParameter("unsupported log output format : %q", conf.Format) - } - - return slogLogger, nil -} - -func WithLogger(ctx context.Context, logger log.FieldLogger) context.Context { - return withLogger(ctx, logger) -} - -func withLogger(ctx context.Context, logger log.FieldLogger) context.Context { +func WithLogger(ctx context.Context, logger *slog.Logger) context.Context { return context.WithValue(ctx, contextKey{}, logger) } -func WithField(ctx context.Context, key string, value interface{}) (context.Context, log.FieldLogger) { - logger := Get(ctx).WithField(key, value) - return withLogger(ctx, logger), logger +func With(ctx context.Context, args ...any) (context.Context, *slog.Logger) { + logger := Get(ctx).With(args...) + return WithLogger(ctx, logger), logger } -func WithFields(ctx context.Context, logFields Fields) (context.Context, log.FieldLogger) { - logger := Get(ctx).WithFields(logFields) - return withLogger(ctx, logger), logger -} - -func SetField(ctx context.Context, key string, value interface{}) context.Context { - ctx, _ = WithField(ctx, key, value) - return ctx -} - -func SetFields(ctx context.Context, logFields Fields) context.Context { - ctx, _ = WithFields(ctx, logFields) - return ctx -} - -func Get(ctx context.Context) log.FieldLogger { - if logger, ok := ctx.Value(contextKey{}).(log.FieldLogger); ok && logger != nil { +func Get(ctx context.Context) *slog.Logger { + if logger, ok := ctx.Value(contextKey{}).(*slog.Logger); ok && logger != nil { return logger } - return Standard() -} - -func Standard() log.FieldLogger { - return log.StandardLogger() + return slog.Default() } diff --git a/integrations/lib/signals.go b/integrations/lib/signals.go index 4774915a6271b..4702455dfc7ca 100644 --- a/integrations/lib/signals.go +++ b/integrations/lib/signals.go @@ -20,12 +20,11 @@ package lib import ( "context" + "log/slog" "os" "os/signal" "syscall" "time" - - log "github.com/sirupsen/logrus" ) type Terminable interface { @@ -48,9 +47,9 @@ func ServeSignals(app Terminable, shutdownTimeout time.Duration) { gracefulShutdown := func() { tctx, tcancel := context.WithTimeout(ctx, shutdownTimeout) defer tcancel() - log.Infof("Attempting graceful shutdown...") + slog.InfoContext(tctx, "Attempting graceful shutdown") if err := app.Shutdown(tctx); err != nil { - log.Infof("Graceful shutdown failed. Trying fast shutdown...") + slog.InfoContext(tctx, "Graceful shutdown failed, attempting fast shutdown") app.Close() } } diff --git a/integrations/lib/tctl/tctl.go b/integrations/lib/tctl/tctl.go index 25e7e5e95e0da..5fa0a3252b45b 100644 --- a/integrations/lib/tctl/tctl.go +++ b/integrations/lib/tctl/tctl.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/integrations/lib/logger" + logutils "github.com/gravitational/teleport/lib/utils/log" ) var regexpStatusCAPin = regexp.MustCompile(`CA pin +(sha256:[a-zA-Z0-9]+)`) @@ -59,10 +60,14 @@ func (tctl Tctl) Sign(ctx context.Context, username, format, outPath string) err outPath, ) cmd := exec.CommandContext(ctx, tctl.cmd(), args...) - log.Debugf("Running %s", cmd) + log.DebugContext(ctx, "Running tctl auth sign", "command", logutils.StringerAttr(cmd)) output, err := cmd.CombinedOutput() if err != nil { - log.WithError(err).WithField("args", args).Debug("tctl auth sign failed:", string(output)) + log.DebugContext(ctx, "tctl auth sign failed", + "error", err, + "args", args, + "command_output", string(output), + ) return trace.Wrap(err, "tctl auth sign failed") } return nil @@ -73,7 +78,7 @@ func (tctl Tctl) Create(ctx context.Context, resources []types.Resource) error { log := logger.Get(ctx) args := append(tctl.baseArgs(), "create") cmd := exec.CommandContext(ctx, tctl.cmd(), args...) - log.Debugf("Running %s", cmd) + log.DebugContext(ctx, "Running tctl create", "command", logutils.StringerAttr(cmd)) stdinPipe, err := cmd.StdinPipe() if err != nil { return trace.Wrap(err, "failed to get stdin pipe") @@ -81,16 +86,19 @@ func (tctl Tctl) Create(ctx context.Context, resources []types.Resource) error { go func() { defer func() { if err := stdinPipe.Close(); err != nil { - log.WithError(trace.Wrap(err)).Error("Failed to close stdin pipe") + log.ErrorContext(ctx, "Failed to close stdin pipe", "error", err) } }() if err := writeResourcesYAML(stdinPipe, resources); err != nil { - log.WithError(trace.Wrap(err)).Error("Failed to serialize resources stdin") + log.ErrorContext(ctx, "Failed to serialize resources stdin", "error", err) } }() output, err := cmd.CombinedOutput() if err != nil { - log.WithError(err).Debug("tctl create failed:", string(output)) + log.DebugContext(ctx, "tctl create failed", + "error", err, + "command_output", string(output), + ) return trace.Wrap(err, "tctl create failed") } return nil @@ -102,7 +110,7 @@ func (tctl Tctl) GetAll(ctx context.Context, query string) ([]types.Resource, er args := append(tctl.baseArgs(), "get", query) cmd := exec.CommandContext(ctx, tctl.cmd(), args...) - log.Debugf("Running %s", cmd) + log.DebugContext(ctx, "Running tctl get", "command", logutils.StringerAttr(cmd)) stdoutPipe, err := cmd.StdoutPipe() if err != nil { return nil, trace.Wrap(err, "failed to get stdout") @@ -140,7 +148,7 @@ func (tctl Tctl) GetCAPin(ctx context.Context) (string, error) { args := append(tctl.baseArgs(), "status") cmd := exec.CommandContext(ctx, tctl.cmd(), args...) - log.Debugf("Running %s", cmd) + log.DebugContext(ctx, "Running tctl status", "command", logutils.StringerAttr(cmd)) output, err := cmd.Output() if err != nil { return "", trace.Wrap(err, "failed to get auth status") diff --git a/integrations/lib/testing/integration/suite.go b/integrations/lib/testing/integration/suite.go index 22c0754f66a3b..c0f03c647ef75 100644 --- a/integrations/lib/testing/integration/suite.go +++ b/integrations/lib/testing/integration/suite.go @@ -93,7 +93,7 @@ func (s *Suite) initContexts(oldT *testing.T, newT *testing.T) { } else { baseCtx = context.Background() } - baseCtx, _ = logger.WithField(baseCtx, "test", newT.Name()) + baseCtx, _ = logger.With(baseCtx, "test", newT.Name()) baseCtx, cancel := context.WithCancel(baseCtx) newT.Cleanup(cancel) @@ -163,7 +163,7 @@ func (s *Suite) StartApp(app AppI) { if err := app.Run(ctx); err != nil { // We're in a goroutine so we can't just require.NoError(t, err). // All we can do is to log an error. - logger.Get(ctx).WithError(err).Error("Application failed") + logger.Get(ctx).ErrorContext(ctx, "Application failed", "error", err) } }() diff --git a/integrations/lib/watcherjob/watcherjob.go b/integrations/lib/watcherjob/watcherjob.go index 2999b86aaad0b..a7d2d14482ae6 100644 --- a/integrations/lib/watcherjob/watcherjob.go +++ b/integrations/lib/watcherjob/watcherjob.go @@ -130,23 +130,23 @@ func newJobWithEvents(events types.Events, config Config, fn EventFunc, watchIni if config.FailFast { return trace.WrapWithMessage(err, "Connection problem detected. Exiting as fail fast is on.") } - log.WithError(err).Error("Connection problem detected. Attempting to reconnect.") + log.ErrorContext(ctx, "Connection problem detected, attempting to reconnect", "error", err) case errors.Is(err, io.EOF): if config.FailFast { return trace.WrapWithMessage(err, "Watcher stream closed. Exiting as fail fast is on.") } - log.WithError(err).Error("Watcher stream closed. Attempting to reconnect.") + log.ErrorContext(ctx, "Watcher stream closed attempting to reconnect", "error", err) case lib.IsCanceled(err): - log.Debug("Watcher context is canceled") + log.DebugContext(ctx, "Watcher context is canceled") return trace.Wrap(err) default: - log.WithError(err).Error("Watcher event loop failed") + log.ErrorContext(ctx, "Watcher event loop failed", "error", err) return trace.Wrap(err) } // To mitigate a potentially aggressive retry loop, we wait if err := bk.Do(ctx); err != nil { - log.Debug("Watcher context was canceled while waiting before a reconnection") + log.DebugContext(ctx, "Watcher context was canceled while waiting before a reconnection") return trace.Wrap(err) } } @@ -162,7 +162,7 @@ func (job job) watchEvents(ctx context.Context) error { } defer func() { if err := watcher.Close(); err != nil { - logger.Get(ctx).WithError(err).Error("Failed to close a watcher") + logger.Get(ctx).ErrorContext(ctx, "Failed to close a watcher", "error", err) } }() @@ -170,7 +170,7 @@ func (job job) watchEvents(ctx context.Context) error { return trace.Wrap(err) } - logger.Get(ctx).Debug("Watcher connected") + logger.Get(ctx).DebugContext(ctx, "Watcher connected") job.SetReady(true) for { @@ -253,7 +253,7 @@ func (job job) eventLoop(ctx context.Context) error { event := *eventPtr resource := event.Resource if resource == nil { - log.Error("received an event with empty resource field") + log.ErrorContext(ctx, "received an event with empty resource field") } key := eventKey{kind: resource.GetKind(), name: resource.GetName()} if queue, loaded := queues[key]; loaded { diff --git a/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/debug.go b/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/debug.go index b1c7c7339c4ba..585c82058d5fb 100644 --- a/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/debug.go +++ b/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/debug.go @@ -21,38 +21,37 @@ package main import ( + "context" + "log/slog" "os" - "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" - crdgen "github.com/gravitational/teleport/integrations/operator/crdgen" + logutils "github.com/gravitational/teleport/lib/utils/log" ) func main() { - log.SetLevel(log.DebugLevel) - log.SetOutput(os.Stderr) + slog.SetDefault(slog.New(logutils.NewSlogTextHandler(os.Stderr, + logutils.SlogTextHandlerConfig{ + Level: slog.LevelDebug, + }, + ))) + ctx := context.Background() inputPath := os.Getenv(crdgen.PluginInputPathEnvironment) if inputPath == "" { - log.Error( - trace.BadParameter( - "When built with the 'debug' tag, the input path must be set through the environment variable: %s", - crdgen.PluginInputPathEnvironment, - ), - ) + slog.ErrorContext(ctx, "When built with the 'debug' tag, the input path must be set through the TELEPORT_PROTOC_READ_FILE environment variable") os.Exit(-1) } - log.Infof("This is a debug build, the protoc request is read from the file: '%s'", inputPath) + slog.InfoContext(ctx, "This is a debug build, the protoc request is read from the file", "input_path", inputPath) req, err := crdgen.ReadRequestFromFile(inputPath) if err != nil { - log.WithError(err).Error("error reading request from file") + slog.ErrorContext(ctx, "error reading request from file", "error", err) os.Exit(-1) } if err := crdgen.HandleDocsRequest(req); err != nil { - log.WithError(err).Error("Failed to generate docs") + slog.ErrorContext(ctx, "Failed to generate docs", "error", err) os.Exit(-1) } } diff --git a/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/main.go b/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/main.go index e091e5a8c1d0f..ac1be771b0bf0 100644 --- a/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/main.go +++ b/integrations/operator/crdgen/cmd/protoc-gen-crd-docs/main.go @@ -21,20 +21,26 @@ package main import ( + "context" + "log/slog" "os" "github.com/gogo/protobuf/vanity/command" - log "github.com/sirupsen/logrus" crdgen "github.com/gravitational/teleport/integrations/operator/crdgen" + logutils "github.com/gravitational/teleport/lib/utils/log" ) func main() { - log.SetLevel(log.DebugLevel) - log.SetOutput(os.Stderr) + slog.SetDefault(slog.New(logutils.NewSlogTextHandler(os.Stderr, + logutils.SlogTextHandlerConfig{ + Level: slog.LevelDebug, + }, + ))) + req := command.Read() if err := crdgen.HandleDocsRequest(req); err != nil { - log.WithError(err).Error("Failed to generate schema") + slog.ErrorContext(context.Background(), "Failed to generate schema", "error", err) os.Exit(-1) } } diff --git a/integrations/operator/crdgen/cmd/protoc-gen-crd/debug.go b/integrations/operator/crdgen/cmd/protoc-gen-crd/debug.go index bf19cf7eaca87..2da3e47ab9ec8 100644 --- a/integrations/operator/crdgen/cmd/protoc-gen-crd/debug.go +++ b/integrations/operator/crdgen/cmd/protoc-gen-crd/debug.go @@ -21,38 +21,37 @@ package main import ( + "context" + "log/slog" "os" - "github.com/gravitational/trace" - log "github.com/sirupsen/logrus" - crdgen "github.com/gravitational/teleport/integrations/operator/crdgen" + logutils "github.com/gravitational/teleport/lib/utils/log" ) func main() { - log.SetLevel(log.DebugLevel) - log.SetOutput(os.Stderr) + slog.SetDefault(slog.New(logutils.NewSlogTextHandler(os.Stderr, + logutils.SlogTextHandlerConfig{ + Level: slog.LevelDebug, + }, + ))) + ctx := context.Background() inputPath := os.Getenv(crdgen.PluginInputPathEnvironment) if inputPath == "" { - log.Error( - trace.BadParameter( - "When built with the 'debug' tag, the input path must be set through the environment variable: %s", - crdgen.PluginInputPathEnvironment, - ), - ) + slog.ErrorContext(ctx, "When built with the 'debug' tag, the input path must be set through the TELEPORT_PROTOC_READ_FILE environment variable") os.Exit(-1) } - log.Infof("This is a debug build, the protoc request is read from the file: '%s'", inputPath) + slog.InfoContext(ctx, "This is a debug build, the protoc request is read from the file", "input_path", inputPath) req, err := crdgen.ReadRequestFromFile(inputPath) if err != nil { - log.WithError(err).Error("error reading request from file") + slog.ErrorContext(ctx, "error reading request from file", "error", err) os.Exit(-1) } if err := crdgen.HandleCRDRequest(req); err != nil { - log.WithError(err).Error("Failed to generate schema") + slog.ErrorContext(ctx, "Failed to generate schema", "error", err) os.Exit(-1) } } diff --git a/integrations/operator/crdgen/cmd/protoc-gen-crd/main.go b/integrations/operator/crdgen/cmd/protoc-gen-crd/main.go index 863af95862505..a557993626415 100644 --- a/integrations/operator/crdgen/cmd/protoc-gen-crd/main.go +++ b/integrations/operator/crdgen/cmd/protoc-gen-crd/main.go @@ -21,20 +21,26 @@ package main import ( + "context" + "log/slog" "os" "github.com/gogo/protobuf/vanity/command" - log "github.com/sirupsen/logrus" crdgen "github.com/gravitational/teleport/integrations/operator/crdgen" + logutils "github.com/gravitational/teleport/lib/utils/log" ) func main() { - log.SetLevel(log.DebugLevel) - log.SetOutput(os.Stderr) + slog.SetDefault(slog.New(logutils.NewSlogTextHandler(os.Stderr, + logutils.SlogTextHandlerConfig{ + Level: slog.LevelDebug, + }, + ))) + req := command.Read() if err := crdgen.HandleCRDRequest(req); err != nil { - log.WithError(err).Error("Failed to generate schema") + slog.ErrorContext(context.Background(), "Failed to generate schema", "error", err) os.Exit(-1) } } diff --git a/integrations/terraform/go.mod b/integrations/terraform/go.mod index d3240ffff8135..5222dc914a105 100644 --- a/integrations/terraform/go.mod +++ b/integrations/terraform/go.mod @@ -21,7 +21,6 @@ require ( github.com/hashicorp/terraform-plugin-log v0.9.0 github.com/hashicorp/terraform-plugin-sdk/v2 v2.10.1 github.com/jonboulle/clockwork v0.4.0 - github.com/sirupsen/logrus v1.9.3 github.com/stretchr/testify v1.10.0 google.golang.org/grpc v1.69.2 google.golang.org/protobuf v1.36.2 @@ -307,6 +306,7 @@ require ( github.com/shirou/gopsutil/v4 v4.24.12 // indirect github.com/shopspring/decimal v1.4.0 // indirect github.com/sijms/go-ora/v2 v2.8.22 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/cast v1.7.0 // indirect github.com/spf13/cobra v1.8.1 // indirect github.com/spf13/pflag v1.0.5 // indirect diff --git a/integrations/terraform/provider/errors.go b/integrations/terraform/provider/errors.go index d31715366d192..6c0f838b474bf 100644 --- a/integrations/terraform/provider/errors.go +++ b/integrations/terraform/provider/errors.go @@ -17,9 +17,11 @@ limitations under the License. package provider import ( + "context" + "log/slog" + "github.com/gravitational/trace" "github.com/hashicorp/terraform-plugin-framework/diag" - log "github.com/sirupsen/logrus" ) // diagFromWrappedErr wraps error with additional information @@ -43,7 +45,7 @@ func diagFromWrappedErr(summary string, err error, kind string) diag.Diagnostic // diagFromErr converts error to diag.Diagnostics. If logging level is debug, provides trace.DebugReport instead of short text. func diagFromErr(summary string, err error) diag.Diagnostic { - if log.GetLevel() >= log.DebugLevel { + if slog.Default().Enabled(context.Background(), slog.LevelDebug) { return diag.NewErrorDiagnostic(err.Error(), trace.DebugReport(err)) } diff --git a/integrations/terraform/provider/provider.go b/integrations/terraform/provider/provider.go index 13b20d20c434f..99d460a49f806 100644 --- a/integrations/terraform/provider/provider.go +++ b/integrations/terraform/provider/provider.go @@ -19,6 +19,7 @@ package provider import ( "context" "fmt" + "log/slog" "net" "os" "strconv" @@ -29,13 +30,13 @@ import ( "github.com/hashicorp/terraform-plugin-framework/diag" "github.com/hashicorp/terraform-plugin-framework/tfsdk" "github.com/hashicorp/terraform-plugin-framework/types" - log "github.com/sirupsen/logrus" "google.golang.org/grpc" "google.golang.org/grpc/grpclog" "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" ) const ( @@ -305,7 +306,7 @@ func (p *Provider) Configure(ctx context.Context, req tfsdk.ConfigureProviderReq return } - log.WithFields(log.Fields{"addr": addr}).Debug("Using Teleport address") + slog.DebugContext(ctx, "Using Teleport address", "addr", addr) dialTimeoutDuration, err := time.ParseDuration(dialTimeoutDurationStr) if err != nil { @@ -393,7 +394,7 @@ func (p *Provider) Configure(ctx context.Context, req tfsdk.ConfigureProviderReq // checkTeleportVersion ensures that Teleport version is at least minServerVersion func (p *Provider) checkTeleportVersion(ctx context.Context, client *client.Client, resp *tfsdk.ConfigureProviderResponse) bool { - log.Debug("Checking Teleport server version") + slog.DebugContext(ctx, "Checking Teleport server version") pong, err := client.Ping(ctx) if err != nil { if trace.IsNotImplemented(err) { @@ -403,13 +404,13 @@ func (p *Provider) checkTeleportVersion(ctx context.Context, client *client.Clie ) return false } - log.WithError(err).Debug("Teleport version check error!") + slog.DebugContext(ctx, "Teleport version check error", "error", err) resp.Diagnostics.AddError("Unable to get Teleport server version!", "Unable to get Teleport server version!") return false } err = utils.CheckMinVersion(pong.ServerVersion, minServerVersion) if err != nil { - log.WithError(err).Debug("Teleport version check error!") + slog.DebugContext(ctx, "Teleport version check error", "error", err) resp.Diagnostics.AddError("Teleport version check error!", err.Error()) return false } @@ -447,7 +448,7 @@ func (p *Provider) validateAddr(addr string, resp *tfsdk.ConfigureProviderRespon _, _, err := net.SplitHostPort(addr) if err != nil { - log.WithField("addr", addr).WithError(err).Debug("Teleport address format error!") + slog.DebugContext(context.Background(), "Teleport address format error", "error", err, "addr", addr) resp.Diagnostics.AddError( "Invalid Teleport address format", fmt.Sprintf("Teleport address must be specified as host:port. Got %q", addr), @@ -461,20 +462,32 @@ func (p *Provider) validateAddr(addr string, resp *tfsdk.ConfigureProviderRespon // configureLog configures logging func (p *Provider) configureLog() { + level := slog.LevelError // Get Terraform log level - level, err := log.ParseLevel(os.Getenv("TF_LOG")) - if err != nil { - log.SetLevel(log.ErrorLevel) - } else { - log.SetLevel(level) + switch strings.ToLower(os.Getenv("TF_LOG")) { + case "panic", "fatal", "error": + level = slog.LevelError + case "warn", "warning": + level = slog.LevelWarn + case "info": + level = slog.LevelInfo + case "debug": + level = slog.LevelDebug + case "trace": + level = logutils.TraceLevel } - log.SetFormatter(&log.TextFormatter{}) + _, _, err := logutils.Initialize(logutils.Config{ + Severity: level.String(), + Format: "text", + }) + if err != nil { + return + } // Show GRPC debug logs only if TF_LOG=DEBUG - if log.GetLevel() >= log.DebugLevel { - l := grpclog.NewLoggerV2(log.StandardLogger().Out, log.StandardLogger().Out, log.StandardLogger().Out) - grpclog.SetLoggerV2(l) + if level <= slog.LevelDebug { + grpclog.SetLoggerV2(grpclog.NewLoggerV2(os.Stderr, os.Stderr, os.Stderr)) } } diff --git a/lib/utils/log/log.go b/lib/utils/log/log.go index 2f16b902e3df6..d8aadb75146bf 100644 --- a/lib/utils/log/log.go +++ b/lib/utils/log/log.go @@ -42,6 +42,8 @@ type Config struct { ExtraFields []string // EnableColors dictates if output should be colored. EnableColors bool + // Padding to use for various components. + Padding int } // Initialize configures the default global logger based on the @@ -112,6 +114,7 @@ func Initialize(loggerConfig Config) (*slog.Logger, *slog.LevelVar, error) { Level: level, EnableColors: loggerConfig.EnableColors, ConfiguredFields: configuredFields, + Padding: loggerConfig.Padding, })) slog.SetDefault(logger) case "json": From 883c53a53cadf5885fb8eafec3e0f0f89c2f65a7 Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Fri, 10 Jan 2025 11:58:56 -0500 Subject: [PATCH 7/8] Remove logrus dependency (#50930) This is the _last_ step required to migrate from logrus to slog. All components in the repository have been migrated to use slog allowing the logrus formatter to be deleted. The slog handler tests that validate the output have been updated to assert the format directly instead of comparing it to the output from the logrus formatter. --- .golangci.yml | 8 - go.mod | 2 +- lib/client/api.go | 2 +- lib/srv/desktop/rdp/rdpclient/client.go | 2 +- lib/utils/cli.go | 152 +-------- lib/utils/log/formatter_test.go | 420 +++++++++-------------- lib/utils/log/logrus_formatter.go | 427 ------------------------ lib/utils/log/slog.go | 20 -- lib/utils/log/slog_text_handler.go | 92 ++--- lib/utils/log/writer.go | 45 --- 10 files changed, 232 insertions(+), 938 deletions(-) delete mode 100644 lib/utils/log/logrus_formatter.go delete mode 100644 lib/utils/log/writer.go diff --git a/.golangci.yml b/.golangci.yml index 98859bad6c7d9..ecc5e7c8e253f 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -115,14 +115,6 @@ linters-settings: desc: 'use "crypto" or "x/crypto" instead' # Prevent importing any additional logging libraries. logging: - files: - # Integrations are still allowed to use logrus becuase they haven't - # been converted to slog yet. Once they use slog, remove this exception. - - '!**/integrations/**' - # The log package still contains the logrus formatter consumed by the integrations. - # Remove this exception when said formatter is deleted. - - '!**/lib/utils/log/**' - - '!**/lib/utils/cli.go' deny: - pkg: github.com/sirupsen/logrus desc: 'use "log/slog" instead' diff --git a/go.mod b/go.mod index 3c35132910093..78f04732806b6 100644 --- a/go.mod +++ b/go.mod @@ -179,7 +179,6 @@ require ( github.com/sigstore/cosign/v2 v2.4.1 github.com/sigstore/sigstore v1.8.11 github.com/sijms/go-ora/v2 v2.8.22 - github.com/sirupsen/logrus v1.9.3 github.com/snowflakedb/gosnowflake v1.12.1 github.com/spf13/cobra v1.8.1 github.com/spiffe/go-spiffe/v2 v2.4.0 @@ -501,6 +500,7 @@ require ( github.com/sigstore/protobuf-specs v0.3.2 // indirect github.com/sigstore/rekor v1.3.6 // indirect github.com/sigstore/timestamp-authority v1.2.2 // indirect + github.com/sirupsen/logrus v1.9.3 // indirect github.com/sourcegraph/conc v0.3.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.7.0 // indirect diff --git a/lib/client/api.go b/lib/client/api.go index ed94462aa9c73..8b4c317265573 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -2853,7 +2853,7 @@ type execResult struct { // sharedWriter is an [io.Writer] implementation that protects // writes with a mutex. This allows a single [io.Writer] to be shared -// by both logrus and slog without their output clobbering each other. +// by multiple command runners. type sharedWriter struct { mu sync.Mutex io.Writer diff --git a/lib/srv/desktop/rdp/rdpclient/client.go b/lib/srv/desktop/rdp/rdpclient/client.go index 821408d2208fa..534644e6be1df 100644 --- a/lib/srv/desktop/rdp/rdpclient/client.go +++ b/lib/srv/desktop/rdp/rdpclient/client.go @@ -93,7 +93,7 @@ func init() { var rustLogLevel string // initialize the Rust logger by setting $RUST_LOG based - // on the logrus log level + // on the slog log level // (unless RUST_LOG is already explicitly set, then we // assume the user knows what they want) rl := os.Getenv("RUST_LOG") diff --git a/lib/utils/cli.go b/lib/utils/cli.go index e79c0bc2aa8f0..648cf7095352f 100644 --- a/lib/utils/cli.go +++ b/lib/utils/cli.go @@ -26,7 +26,6 @@ import ( "flag" "fmt" "io" - stdlog "log" "log/slog" "os" "runtime" @@ -38,7 +37,6 @@ import ( "github.com/alecthomas/kingpin/v2" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" "golang.org/x/term" "github.com/gravitational/teleport" @@ -100,59 +98,18 @@ func InitLogger(purpose LoggingPurpose, level slog.Level, opts ...LoggerOption) opt(&o) } - logrus.StandardLogger().ReplaceHooks(make(logrus.LevelHooks)) - logrus.SetLevel(logutils.SlogLevelToLogrusLevel(level)) - - var ( - w io.Writer - enableColors bool - ) - switch purpose { - case LoggingForCLI: - // If debug logging was asked for on the CLI, then write logs to stderr. - // Otherwise, discard all logs. - if level == slog.LevelDebug { - enableColors = IsTerminal(os.Stderr) - w = logutils.NewSharedWriter(os.Stderr) - } else { - w = io.Discard - enableColors = false - } - case LoggingForDaemon: - enableColors = IsTerminal(os.Stderr) - w = logutils.NewSharedWriter(os.Stderr) - } - - var ( - formatter logrus.Formatter - handler slog.Handler - ) - switch o.format { - case LogFormatText, "": - textFormatter := logutils.NewDefaultTextFormatter(enableColors) - - // Calling CheckAndSetDefaults enables the timestamp field to - // be included in the output. The error returned is ignored - // because the default formatter cannot be invalid. - if purpose == LoggingForCLI && level == slog.LevelDebug { - _ = textFormatter.CheckAndSetDefaults() - } - - formatter = textFormatter - handler = logutils.NewSlogTextHandler(w, logutils.SlogTextHandlerConfig{ - Level: level, - EnableColors: enableColors, - }) - case LogFormatJSON: - formatter = &logutils.JSONFormatter{} - handler = logutils.NewSlogJSONHandler(w, logutils.SlogJSONHandlerConfig{ - Level: level, - }) + // If debug or trace logging is not enabled for CLIs, + // then discard all log output. + if purpose == LoggingForCLI && level > slog.LevelDebug { + slog.SetDefault(slog.New(logutils.DiscardHandler{})) + return } - logrus.SetFormatter(formatter) - logrus.SetOutput(w) - slog.SetDefault(slog.New(handler)) + logutils.Initialize(logutils.Config{ + Severity: level.String(), + Format: o.format, + EnableColors: IsTerminal(os.Stderr), + }) } var initTestLoggerOnce = sync.Once{} @@ -163,56 +120,24 @@ func InitLoggerForTests() { // Parse flags to check testing.Verbose(). flag.Parse() - level := slog.LevelWarn - w := io.Discard - if testing.Verbose() { - level = slog.LevelDebug - w = os.Stderr + if !testing.Verbose() { + slog.SetDefault(slog.New(logutils.DiscardHandler{})) + return } - logger := logrus.StandardLogger() - logger.SetFormatter(logutils.NewTestJSONFormatter()) - logger.SetLevel(logutils.SlogLevelToLogrusLevel(level)) - - output := logutils.NewSharedWriter(w) - logger.SetOutput(output) - slog.SetDefault(slog.New(logutils.NewSlogJSONHandler(output, logutils.SlogJSONHandlerConfig{Level: level}))) + logutils.Initialize(logutils.Config{ + Severity: slog.LevelDebug.String(), + Format: LogFormatJSON, + }) }) } -// NewLoggerForTests creates a new logrus logger for test environments. -func NewLoggerForTests() *logrus.Logger { - InitLoggerForTests() - return logrus.StandardLogger() -} - // NewSlogLoggerForTests creates a new slog logger for test environments. func NewSlogLoggerForTests() *slog.Logger { InitLoggerForTests() return slog.Default() } -// WrapLogger wraps an existing logger entry and returns -// a value satisfying the Logger interface -func WrapLogger(logger *logrus.Entry) Logger { - return &logWrapper{Entry: logger} -} - -// NewLogger creates a new empty logrus logger. -func NewLogger() *logrus.Logger { - return logrus.StandardLogger() -} - -// Logger describes a logger value -type Logger interface { - logrus.FieldLogger - // GetLevel specifies the level at which this logger - // value is logging - GetLevel() logrus.Level - // SetLevel sets the logger's level to the specified value - SetLevel(level logrus.Level) -} - // FatalError is for CLI front-ends: it detects gravitational/trace debugging // information, sends it to the logger, strips it off and prints a clean message to stderr func FatalError(err error) { @@ -231,7 +156,7 @@ func GetIterations() int { if err != nil { panic(err) } - logrus.Debugf("Starting tests with %v iterations.", iter) + slog.DebugContext(context.Background(), "Running tests multiple times due to presence of ITERATIONS environment variable", "iterations", iter) return iter } @@ -484,47 +409,6 @@ func AllowWhitespace(s string) string { return sb.String() } -// NewStdlogger creates a new stdlib logger that uses the specified leveled logger -// for output and the given component as a logging prefix. -func NewStdlogger(logger LeveledOutputFunc, component string) *stdlog.Logger { - return stdlog.New(&stdlogAdapter{ - log: logger, - }, component, stdlog.LstdFlags) -} - -// Write writes the specified buffer p to the underlying leveled logger. -// Implements io.Writer -func (r *stdlogAdapter) Write(p []byte) (n int, err error) { - r.log(string(p)) - return len(p), nil -} - -// stdlogAdapter is an io.Writer that writes into an instance -// of logrus.Logger -type stdlogAdapter struct { - log LeveledOutputFunc -} - -// LeveledOutputFunc describes a function that emits given -// arguments at a specific level to an underlying logger -type LeveledOutputFunc func(args ...interface{}) - -// GetLevel returns the level of the underlying logger -func (r *logWrapper) GetLevel() logrus.Level { - return r.Entry.Logger.GetLevel() -} - -// SetLevel sets the logging level to the given value -func (r *logWrapper) SetLevel(level logrus.Level) { - r.Entry.Logger.SetLevel(level) -} - -// logWrapper wraps a log entry. -// Implements Logger -type logWrapper struct { - *logrus.Entry -} - // needsQuoting returns true if any non-printable characters are found. func needsQuoting(text string) bool { for _, r := range text { diff --git a/lib/utils/log/formatter_test.go b/lib/utils/log/formatter_test.go index 9abb0310ba0be..aff0ec8be3a74 100644 --- a/lib/utils/log/formatter_test.go +++ b/lib/utils/log/formatter_test.go @@ -22,7 +22,6 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "log/slog" @@ -38,7 +37,6 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -48,7 +46,7 @@ import ( const message = "Adding diagnostic debugging handlers.\t To connect with profiler, use go tool pprof diag_addr." var ( - logErr = errors.New("the quick brown fox jumped really high") + logErr = &trace.BadParameterError{Message: "the quick brown fox jumped really high"} addr = fakeAddr{addr: "127.0.0.1:1234"} fields = map[string]any{ @@ -72,6 +70,10 @@ func (a fakeAddr) String() string { return a.addr } +func (a fakeAddr) MarshalText() (text []byte, err error) { + return []byte(a.addr), nil +} + func TestOutput(t *testing.T) { loc, err := time.LoadLocation("Africa/Cairo") require.NoError(t, err, "failed getting timezone") @@ -89,58 +91,50 @@ func TestOutput(t *testing.T) { // 4) the caller outputRegex := regexp.MustCompile(`(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z)(\s+.*)(".*diag_addr\.")(.*)(\slog/formatter_test.go:\d{3})`) + expectedFields := map[string]string{ + "local": addr.String(), + "remote": addr.String(), + "login": "llama", + "teleportUser": "user", + "id": "1234", + "test": "123", + "animal": `"llama\n"`, + "error": "[" + trace.DebugReport(logErr) + "]", + "diag_addr": addr.String(), + } + tests := []struct { - name string - logrusLevel logrus.Level - slogLevel slog.Level + name string + slogLevel slog.Level }{ { - name: "trace", - logrusLevel: logrus.TraceLevel, - slogLevel: TraceLevel, + name: "trace", + slogLevel: TraceLevel, }, { - name: "debug", - logrusLevel: logrus.DebugLevel, - slogLevel: slog.LevelDebug, + name: "debug", + slogLevel: slog.LevelDebug, }, { - name: "info", - logrusLevel: logrus.InfoLevel, - slogLevel: slog.LevelInfo, + name: "info", + slogLevel: slog.LevelInfo, }, { - name: "warn", - logrusLevel: logrus.WarnLevel, - slogLevel: slog.LevelWarn, + name: "warn", + slogLevel: slog.LevelWarn, }, { - name: "error", - logrusLevel: logrus.ErrorLevel, - slogLevel: slog.LevelError, + name: "error", + slogLevel: slog.LevelError, }, { - name: "fatal", - logrusLevel: logrus.FatalLevel, - slogLevel: slog.LevelError + 1, + name: "fatal", + slogLevel: slog.LevelError + 1, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - // Create a logrus logger using the custom formatter which outputs to a local buffer. - var logrusOutput bytes.Buffer - formatter := NewDefaultTextFormatter(true) - formatter.timestampEnabled = true - require.NoError(t, formatter.CheckAndSetDefaults()) - - logrusLogger := logrus.New() - logrusLogger.SetFormatter(formatter) - logrusLogger.SetOutput(&logrusOutput) - logrusLogger.ReplaceHooks(logrus.LevelHooks{}) - logrusLogger.SetLevel(test.logrusLevel) - entry := logrusLogger.WithField(teleport.ComponentKey, "test").WithTime(clock.Now().UTC()) - // Create a slog logger using the custom handler which outputs to a local buffer. var slogOutput bytes.Buffer slogConfig := SlogTextHandlerConfig{ @@ -155,13 +149,6 @@ func TestOutput(t *testing.T) { } slogLogger := slog.New(NewSlogTextHandler(&slogOutput, slogConfig)).With(teleport.ComponentKey, "test") - // Add some fields and output the message at the desired log level via logrus. - l := entry.WithField("test", 123).WithField("animal", "llama\n").WithField("error", logErr) - logrusTestLogLineNumber := func() int { - l.WithField("diag_addr", &addr).WithField(teleport.ComponentFields, fields).Log(test.logrusLevel, message) - return getCallerLineNumber() - 1 // Get the line number of this call, and assume the log call is right above it - }() - // Add some fields and output the message at the desired log level via slog. l2 := slogLogger.With("test", 123).With("animal", "llama\n").With("error", logErr) slogTestLogLineNumber := func() int { @@ -169,163 +156,144 @@ func TestOutput(t *testing.T) { return getCallerLineNumber() - 1 // Get the line number of this call, and assume the log call is right above it }() - // Validate that both loggers produces the same output. The added complexity comes from the fact that - // our custom slog handler does NOT sort the additional fields like our logrus formatter does. - logrusMatches := outputRegex.FindStringSubmatch(logrusOutput.String()) - require.NotEmpty(t, logrusMatches, "logrus output was in unexpected format: %s", logrusOutput.String()) + // Validate the logger output. The added complexity comes from the fact that + // our custom slog handler does NOT sort the additional fields. slogMatches := outputRegex.FindStringSubmatch(slogOutput.String()) require.NotEmpty(t, slogMatches, "slog output was in unexpected format: %s", slogOutput.String()) // The first match is the timestamp: 2023-10-31T10:09:06+02:00 - logrusTime, err := time.Parse(time.RFC3339, logrusMatches[1]) - assert.NoError(t, err, "invalid logrus timestamp found %s", logrusMatches[1]) - slogTime, err := time.Parse(time.RFC3339, slogMatches[1]) assert.NoError(t, err, "invalid slog timestamp found %s", slogMatches[1]) - - assert.InDelta(t, logrusTime.Unix(), slogTime.Unix(), 10) + assert.InDelta(t, clock.Now().Unix(), slogTime.Unix(), 10) // Match level, and component: DEBU [TEST] - assert.Empty(t, cmp.Diff(logrusMatches[2], slogMatches[2]), "level, and component to be identical") - // Match the log message: "Adding diagnostic debugging handlers.\t To connect with profiler, use go tool pprof diag_addr.\n" - assert.Empty(t, cmp.Diff(logrusMatches[3], slogMatches[3]), "expected output messages to be identical") + expectedLevel := formatLevel(test.slogLevel, true) + expectedComponent := formatComponent(slog.StringValue("test"), defaultComponentPadding) + expectedMatch := " " + expectedLevel + " " + expectedComponent + " " + assert.Equal(t, expectedMatch, slogMatches[2], "level, and component to be identical") + // Match the log message + assert.Equal(t, `"Adding diagnostic debugging handlers.\t To connect with profiler, use go tool pprof diag_addr."`, slogMatches[3], "expected output messages to be identical") // The last matches are the caller information - assert.Equal(t, fmt.Sprintf(" log/formatter_test.go:%d", logrusTestLogLineNumber), logrusMatches[5]) assert.Equal(t, fmt.Sprintf(" log/formatter_test.go:%d", slogTestLogLineNumber), slogMatches[5]) // The third matches are the fields which will be key value pairs(animal:llama) separated by a space. Since - // logrus sorts the fields and slog doesn't we can't just assert equality and instead build a map of the key + // slog doesn't sort the fields, we can't assert equality and instead build a map of the key // value pairs to ensure they are all present and accounted for. - logrusFieldMatches := fieldsRegex.FindAllStringSubmatch(logrusMatches[4], -1) slogFieldMatches := fieldsRegex.FindAllStringSubmatch(slogMatches[4], -1) // The first match is the key, the second match is the value - logrusFields := map[string]string{} - for _, match := range logrusFieldMatches { - logrusFields[strings.TrimSpace(match[1])] = strings.TrimSpace(match[2]) - } - slogFields := map[string]string{} for _, match := range slogFieldMatches { slogFields[strings.TrimSpace(match[1])] = strings.TrimSpace(match[2]) } - assert.Equal(t, slogFields, logrusFields) + require.Empty(t, + cmp.Diff( + expectedFields, + slogFields, + cmpopts.SortMaps(func(a, b string) bool { return a < b }), + ), + ) }) } }) t.Run("json", func(t *testing.T) { tests := []struct { - name string - logrusLevel logrus.Level - slogLevel slog.Level + name string + slogLevel slog.Level }{ { - name: "trace", - logrusLevel: logrus.TraceLevel, - slogLevel: TraceLevel, + name: "trace", + slogLevel: TraceLevel, }, { - name: "debug", - logrusLevel: logrus.DebugLevel, - slogLevel: slog.LevelDebug, + name: "debug", + slogLevel: slog.LevelDebug, }, { - name: "info", - logrusLevel: logrus.InfoLevel, - slogLevel: slog.LevelInfo, + name: "info", + slogLevel: slog.LevelInfo, }, { - name: "warn", - logrusLevel: logrus.WarnLevel, - slogLevel: slog.LevelWarn, + name: "warn", + slogLevel: slog.LevelWarn, }, { - name: "error", - logrusLevel: logrus.ErrorLevel, - slogLevel: slog.LevelError, + name: "error", + slogLevel: slog.LevelError, }, { - name: "fatal", - logrusLevel: logrus.FatalLevel, - slogLevel: slog.LevelError + 1, + name: "fatal", + slogLevel: slog.LevelError + 1, + }, + } + + expectedFields := map[string]any{ + "trace.fields": map[string]any{ + "teleportUser": "user", + "id": float64(1234), + "local": addr.String(), + "login": "llama", + "remote": addr.String(), }, + "test": float64(123), + "animal": `llama`, + "error": logErr.Error(), + "diag_addr": addr.String(), + "component": "test", + "message": message, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - // Create a logrus logger using the custom formatter which outputs to a local buffer. - var logrusOut bytes.Buffer - formatter := &JSONFormatter{ - ExtraFields: nil, - callerEnabled: true, - } - require.NoError(t, formatter.CheckAndSetDefaults()) - - logrusLogger := logrus.New() - logrusLogger.SetFormatter(formatter) - logrusLogger.SetOutput(&logrusOut) - logrusLogger.ReplaceHooks(logrus.LevelHooks{}) - logrusLogger.SetLevel(test.logrusLevel) - entry := logrusLogger.WithField(teleport.ComponentKey, "test") - // Create a slog logger using the custom formatter which outputs to a local buffer. var slogOutput bytes.Buffer slogLogger := slog.New(NewSlogJSONHandler(&slogOutput, SlogJSONHandlerConfig{Level: test.slogLevel})).With(teleport.ComponentKey, "test") - // Add some fields and output the message at the desired log level via logrus. - l := entry.WithField("test", 123).WithField("animal", "llama").WithField("error", trace.Wrap(logErr)) - logrusTestLogLineNumber := func() int { - l.WithField("diag_addr", addr.String()).Log(test.logrusLevel, message) - return getCallerLineNumber() - 1 // Get the line number of this call, and assume the log call is right above it - }() - // Add some fields and output the message at the desired log level via slog. l2 := slogLogger.With("test", 123).With("animal", "llama").With("error", trace.Wrap(logErr)) slogTestLogLineNumber := func() int { - l2.Log(context.Background(), test.slogLevel, message, "diag_addr", &addr) + l2.With(teleport.ComponentFields, fields).Log(context.Background(), test.slogLevel, message, "diag_addr", &addr) return getCallerLineNumber() - 1 // Get the line number of this call, and assume the log call is right above it }() - // The order of the fields emitted by the two loggers is different, so comparing the output directly - // for equality won't work. Instead, a map is built with all the key value pairs, excluding the caller - // and that map is compared to ensure all items are present and match. - var logrusData map[string]any - require.NoError(t, json.Unmarshal(logrusOut.Bytes(), &logrusData), "invalid logrus output format") - var slogData map[string]any require.NoError(t, json.Unmarshal(slogOutput.Bytes(), &slogData), "invalid slog output format") - logrusCaller, ok := logrusData["caller"].(string) - delete(logrusData, "caller") - assert.True(t, ok, "caller was missing from logrus output") - assert.Equal(t, fmt.Sprintf("log/formatter_test.go:%d", logrusTestLogLineNumber), logrusCaller) - slogCaller, ok := slogData["caller"].(string) delete(slogData, "caller") assert.True(t, ok, "caller was missing from slog output") assert.Equal(t, fmt.Sprintf("log/formatter_test.go:%d", slogTestLogLineNumber), slogCaller) - logrusTimestamp, ok := logrusData["timestamp"].(string) - delete(logrusData, "timestamp") - assert.True(t, ok, "time was missing from logrus output") + slogLevel, ok := slogData["level"].(string) + delete(slogData, "level") + assert.True(t, ok, "level was missing from slog output") + var expectedLevel string + switch test.slogLevel { + case TraceLevel: + expectedLevel = "trace" + case slog.LevelWarn: + expectedLevel = "warning" + case slog.LevelError + 1: + expectedLevel = "fatal" + default: + expectedLevel = test.slogLevel.String() + } + assert.Equal(t, strings.ToLower(expectedLevel), slogLevel) slogTimestamp, ok := slogData["timestamp"].(string) delete(slogData, "timestamp") assert.True(t, ok, "time was missing from slog output") - logrusTime, err := time.Parse(time.RFC3339, logrusTimestamp) - assert.NoError(t, err, "invalid logrus timestamp %s", logrusTimestamp) - slogTime, err := time.Parse(time.RFC3339, slogTimestamp) assert.NoError(t, err, "invalid slog timestamp %s", slogTimestamp) - assert.InDelta(t, logrusTime.Unix(), slogTime.Unix(), 10) + assert.InDelta(t, clock.Now().Unix(), slogTime.Unix(), 10) require.Empty(t, cmp.Diff( - logrusData, + expectedFields, slogData, cmpopts.SortMaps(func(a, b string) bool { return a < b }), ), @@ -347,38 +315,6 @@ func getCallerLineNumber() int { func BenchmarkFormatter(b *testing.B) { ctx := context.Background() b.ReportAllocs() - b.Run("logrus", func(b *testing.B) { - b.Run("text", func(b *testing.B) { - formatter := NewDefaultTextFormatter(true) - require.NoError(b, formatter.CheckAndSetDefaults()) - logger := logrus.New() - logger.SetFormatter(formatter) - logger.SetOutput(io.Discard) - b.ResetTimer() - - entry := logger.WithField(teleport.ComponentKey, "test") - for i := 0; i < b.N; i++ { - l := entry.WithField("test", 123).WithField("animal", "llama\n").WithField("error", logErr) - l.WithField("diag_addr", &addr).WithField(teleport.ComponentFields, fields).Info(message) - } - }) - - b.Run("json", func(b *testing.B) { - formatter := &JSONFormatter{} - require.NoError(b, formatter.CheckAndSetDefaults()) - logger := logrus.New() - logger.SetFormatter(formatter) - logger.SetOutput(io.Discard) - logger.ReplaceHooks(logrus.LevelHooks{}) - b.ResetTimer() - - entry := logger.WithField(teleport.ComponentKey, "test") - for i := 0; i < b.N; i++ { - l := entry.WithField("test", 123).WithField("animal", "llama\n").WithField("error", logErr) - l.WithField("diag_addr", &addr).WithField(teleport.ComponentFields, fields).Info(message) - } - }) - }) b.Run("slog", func(b *testing.B) { b.Run("default_text", func(b *testing.B) { @@ -430,47 +366,26 @@ func BenchmarkFormatter(b *testing.B) { } func TestConcurrentOutput(t *testing.T) { - t.Run("logrus", func(t *testing.T) { - debugFormatter := NewDefaultTextFormatter(true) - require.NoError(t, debugFormatter.CheckAndSetDefaults()) - logrus.SetFormatter(debugFormatter) - logrus.SetOutput(os.Stdout) - - logger := logrus.WithField(teleport.ComponentKey, "test") - - var wg sync.WaitGroup - for i := 0; i < 1000; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - logger.Infof("Detected Teleport component %d is running in a degraded state.", i) - }(i) - } - wg.Wait() - }) + logger := slog.New(NewSlogTextHandler(os.Stdout, SlogTextHandlerConfig{ + EnableColors: true, + })).With(teleport.ComponentKey, "test") - t.Run("slog", func(t *testing.T) { - logger := slog.New(NewSlogTextHandler(os.Stdout, SlogTextHandlerConfig{ - EnableColors: true, - })).With(teleport.ComponentKey, "test") - - var wg sync.WaitGroup - ctx := context.Background() - for i := 0; i < 1000; i++ { - wg.Add(1) - go func(i int) { - defer wg.Done() - logger.InfoContext(ctx, "Teleport component entered degraded state", - slog.Int("component", i), - slog.Group("group", - slog.String("test", "123"), - slog.String("animal", "llama"), - ), - ) - }(i) - } - wg.Wait() - }) + var wg sync.WaitGroup + ctx := context.Background() + for i := 0; i < 1000; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + logger.InfoContext(ctx, "Teleport component entered degraded state", + slog.Int("component", i), + slog.Group("group", + slog.String("test", "123"), + slog.String("animal", "llama"), + ), + ) + }(i) + } + wg.Wait() } // allPossibleSubsets returns all combinations of subsets for the @@ -493,58 +408,34 @@ func allPossibleSubsets(in []string) [][]string { return subsets } -// TestExtraFields validates that the output is identical for the -// logrus formatter and slog handler based on the configured extra -// fields. +// TestExtraFields validates that the output is expected for the +// slog handler based on the configured extra fields. func TestExtraFields(t *testing.T) { // Capture a fake time that all output will use. now := clockwork.NewFakeClock().Now() // Capture the caller information to be injected into all messages. pc, _, _, _ := runtime.Caller(0) - fs := runtime.CallersFrames([]uintptr{pc}) - f, _ := fs.Next() - callerTrace := &trace.Trace{ - Func: f.Function, - Path: f.File, - Line: f.Line, - } const message = "testing 123" - // Test against every possible configured combination of allowed format fields. - fields := allPossibleSubsets(defaultFormatFields) - t.Run("text", func(t *testing.T) { - for _, configuredFields := range fields { + // Test against every possible configured combination of allowed format fields. + for _, configuredFields := range allPossibleSubsets(defaultFormatFields) { name := "not configured" if len(configuredFields) > 0 { name = strings.Join(configuredFields, " ") } t.Run(name, func(t *testing.T) { - logrusFormatter := TextFormatter{ - ExtraFields: configuredFields, - } - // Call CheckAndSetDefaults to exercise the extra fields logic. Since - // FormatCaller is always overridden within CheckAndSetDefaults, it is - // explicitly set afterward so the caller points to our fake call site. - require.NoError(t, logrusFormatter.CheckAndSetDefaults()) - logrusFormatter.FormatCaller = callerTrace.String - - var slogOutput bytes.Buffer - var slogHandler slog.Handler = NewSlogTextHandler(&slogOutput, SlogTextHandlerConfig{ConfiguredFields: configuredFields}) - - entry := &logrus.Entry{ - Data: logrus.Fields{"animal": "llama", "vegetable": "carrot", teleport.ComponentKey: "test"}, - Time: now, - Level: logrus.DebugLevel, - Caller: &f, - Message: message, - } - - logrusOut, err := logrusFormatter.Format(entry) - require.NoError(t, err) + replaced := map[string]struct{}{} + var slogHandler slog.Handler = NewSlogTextHandler(io.Discard, SlogTextHandlerConfig{ + ConfiguredFields: configuredFields, + ReplaceAttr: func(groups []string, a slog.Attr) slog.Attr { + replaced[a.Key] = struct{}{} + return a + }, + }) record := slog.Record{ Time: now, @@ -557,42 +448,29 @@ func TestExtraFields(t *testing.T) { require.NoError(t, slogHandler.Handle(context.Background(), record)) - require.Equal(t, string(logrusOut), slogOutput.String()) + for k := range replaced { + delete(replaced, k) + } + + require.Empty(t, replaced, replaced) }) } }) t.Run("json", func(t *testing.T) { - for _, configuredFields := range fields { + // Test against every possible configured combination of allowed format fields. + // Note, the json handler limits the allowed fields to a subset of those allowed + // by the text handler. + for _, configuredFields := range allPossibleSubsets([]string{CallerField, ComponentField, TimestampField}) { name := "not configured" if len(configuredFields) > 0 { name = strings.Join(configuredFields, " ") } t.Run(name, func(t *testing.T) { - logrusFormatter := JSONFormatter{ - ExtraFields: configuredFields, - } - // Call CheckAndSetDefaults to exercise the extra fields logic. Since - // FormatCaller is always overridden within CheckAndSetDefaults, it is - // explicitly set afterward so the caller points to our fake call site. - require.NoError(t, logrusFormatter.CheckAndSetDefaults()) - logrusFormatter.FormatCaller = callerTrace.String - var slogOutput bytes.Buffer var slogHandler slog.Handler = NewSlogJSONHandler(&slogOutput, SlogJSONHandlerConfig{ConfiguredFields: configuredFields}) - entry := &logrus.Entry{ - Data: logrus.Fields{"animal": "llama", "vegetable": "carrot", teleport.ComponentKey: "test"}, - Time: now, - Level: logrus.DebugLevel, - Caller: &f, - Message: message, - } - - logrusOut, err := logrusFormatter.Format(entry) - require.NoError(t, err) - record := slog.Record{ Time: now, Message: message, @@ -604,11 +482,31 @@ func TestExtraFields(t *testing.T) { require.NoError(t, slogHandler.Handle(context.Background(), record)) - var slogData, logrusData map[string]any - require.NoError(t, json.Unmarshal(logrusOut, &logrusData)) + var slogData map[string]any require.NoError(t, json.Unmarshal(slogOutput.Bytes(), &slogData)) - require.Equal(t, slogData, logrusData) + delete(slogData, "animal") + delete(slogData, "vegetable") + delete(slogData, "message") + delete(slogData, "level") + + var expectedLen int + expectedFields := configuredFields + switch l := len(configuredFields); l { + case 0: + // The level field was removed above, but is included in the default fields + expectedLen = len(defaultFormatFields) - 1 + expectedFields = defaultFormatFields + default: + expectedLen = l + } + require.Len(t, slogData, expectedLen, slogData) + + for _, f := range expectedFields { + delete(slogData, f) + } + + require.Empty(t, slogData, slogData) }) } }) diff --git a/lib/utils/log/logrus_formatter.go b/lib/utils/log/logrus_formatter.go deleted file mode 100644 index 14ad8441da7cc..0000000000000 --- a/lib/utils/log/logrus_formatter.go +++ /dev/null @@ -1,427 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package log - -import ( - "fmt" - "regexp" - "runtime" - "slices" - "strconv" - "strings" - - "github.com/gravitational/trace" - "github.com/sirupsen/logrus" - - "github.com/gravitational/teleport" -) - -// TextFormatter is a [logrus.Formatter] that outputs messages in -// a textual format. -type TextFormatter struct { - // ComponentPadding is a padding to pick when displaying - // and formatting component field, defaults to DefaultComponentPadding - ComponentPadding int - // EnableColors enables colored output - EnableColors bool - // FormatCaller is a function to return (part) of source file path for output. - // Defaults to filePathAndLine() if unspecified - FormatCaller func() (caller string) - // ExtraFields represent the extra fields that will be added to the log message - ExtraFields []string - // TimestampEnabled specifies if timestamp is enabled in logs - timestampEnabled bool - // CallerEnabled specifies if caller is enabled in logs - callerEnabled bool -} - -type writer struct { - b *buffer -} - -func newWriter() *writer { - return &writer{b: &buffer{}} -} - -func (w *writer) Len() int { - return len(*w.b) -} - -func (w *writer) WriteString(s string) (int, error) { - return w.b.WriteString(s) -} - -func (w *writer) WriteByte(c byte) error { - return w.b.WriteByte(c) -} - -func (w *writer) Bytes() []byte { - return *w.b -} - -// NewDefaultTextFormatter creates a TextFormatter with -// the default options set. -func NewDefaultTextFormatter(enableColors bool) *TextFormatter { - return &TextFormatter{ - ComponentPadding: defaultComponentPadding, - FormatCaller: formatCallerWithPathAndLine, - ExtraFields: defaultFormatFields, - EnableColors: enableColors, - callerEnabled: true, - timestampEnabled: false, - } -} - -// CheckAndSetDefaults checks and sets log format configuration. -func (tf *TextFormatter) CheckAndSetDefaults() error { - // set padding - if tf.ComponentPadding == 0 { - tf.ComponentPadding = defaultComponentPadding - } - // set caller - tf.FormatCaller = formatCallerWithPathAndLine - - // set log formatting - if tf.ExtraFields == nil { - tf.timestampEnabled = true - tf.callerEnabled = true - tf.ExtraFields = defaultFormatFields - return nil - } - - if slices.Contains(tf.ExtraFields, TimestampField) { - tf.timestampEnabled = true - } - - if slices.Contains(tf.ExtraFields, CallerField) { - tf.callerEnabled = true - } - - return nil -} - -// Format formats each log line as configured in teleport config file. -func (tf *TextFormatter) Format(e *logrus.Entry) ([]byte, error) { - caller := tf.FormatCaller() - w := newWriter() - - // write timestamp first if enabled - if tf.timestampEnabled { - *w.b = appendRFC3339Millis(*w.b, e.Time.Round(0)) - } - - for _, field := range tf.ExtraFields { - switch field { - case LevelField: - var color int - var level string - switch e.Level { - case logrus.TraceLevel: - level = "TRACE" - color = gray - case logrus.DebugLevel: - level = "DEBUG" - color = gray - case logrus.InfoLevel: - level = "INFO" - color = blue - case logrus.WarnLevel: - level = "WARN" - color = yellow - case logrus.ErrorLevel: - level = "ERROR" - color = red - case logrus.FatalLevel: - level = "FATAL" - color = red - default: - color = blue - level = strings.ToUpper(e.Level.String()) - } - - if !tf.EnableColors { - color = noColor - } - - w.writeField(padMax(level, defaultLevelPadding), color) - case ComponentField: - padding := defaultComponentPadding - if tf.ComponentPadding != 0 { - padding = tf.ComponentPadding - } - if w.Len() > 0 { - w.WriteByte(' ') - } - component, ok := e.Data[teleport.ComponentKey].(string) - if ok && component != "" { - component = fmt.Sprintf("[%v]", component) - } - component = strings.ToUpper(padMax(component, padding)) - if component[len(component)-1] != ' ' { - component = component[:len(component)-1] + "]" - } - - w.WriteString(component) - default: - if _, ok := knownFormatFields[field]; !ok { - return nil, trace.BadParameter("invalid log format key: %v", field) - } - } - } - - // always use message - if e.Message != "" { - w.writeField(e.Message, noColor) - } - - if len(e.Data) > 0 { - w.writeMap(e.Data) - } - - // write caller last if enabled - if tf.callerEnabled && caller != "" { - w.writeField(caller, noColor) - } - - w.WriteByte('\n') - return w.Bytes(), nil -} - -// JSONFormatter implements the [logrus.Formatter] interface and adds extra -// fields to log entries. -type JSONFormatter struct { - logrus.JSONFormatter - - ExtraFields []string - // FormatCaller is a function to return (part) of source file path for output. - // Defaults to filePathAndLine() if unspecified - FormatCaller func() (caller string) - - callerEnabled bool - componentEnabled bool -} - -// CheckAndSetDefaults checks and sets log format configuration. -func (j *JSONFormatter) CheckAndSetDefaults() error { - // set log formatting - if j.ExtraFields == nil { - j.ExtraFields = defaultFormatFields - } - // set caller - j.FormatCaller = formatCallerWithPathAndLine - - if slices.Contains(j.ExtraFields, CallerField) { - j.callerEnabled = true - } - - if slices.Contains(j.ExtraFields, ComponentField) { - j.componentEnabled = true - } - - // rename default fields - j.JSONFormatter = logrus.JSONFormatter{ - FieldMap: logrus.FieldMap{ - logrus.FieldKeyTime: TimestampField, - logrus.FieldKeyLevel: LevelField, - logrus.FieldKeyMsg: messageField, - }, - DisableTimestamp: !slices.Contains(j.ExtraFields, TimestampField), - } - - return nil -} - -// Format formats each log line as configured in teleport config file. -func (j *JSONFormatter) Format(e *logrus.Entry) ([]byte, error) { - if j.callerEnabled { - path := j.FormatCaller() - e.Data[CallerField] = path - } - - if j.componentEnabled { - e.Data[ComponentField] = e.Data[teleport.ComponentKey] - } - - delete(e.Data, teleport.ComponentKey) - - return j.JSONFormatter.Format(e) -} - -// NewTestJSONFormatter creates a JSONFormatter that is -// configured for output in tests. -func NewTestJSONFormatter() *JSONFormatter { - formatter := &JSONFormatter{} - if err := formatter.CheckAndSetDefaults(); err != nil { - panic(err) - } - return formatter -} - -func (w *writer) writeError(value interface{}) { - switch err := value.(type) { - case trace.Error: - *w.b = fmt.Appendf(*w.b, "[%v]", err.DebugReport()) - default: - *w.b = fmt.Appendf(*w.b, "[%v]", value) - } -} - -func (w *writer) writeField(value interface{}, color int) { - if w.Len() > 0 { - w.WriteByte(' ') - } - w.writeValue(value, color) -} - -func (w *writer) writeKeyValue(key string, value interface{}) { - if w.Len() > 0 { - w.WriteByte(' ') - } - w.WriteString(key) - w.WriteByte(':') - if key == logrus.ErrorKey { - w.writeError(value) - return - } - w.writeValue(value, noColor) -} - -func (w *writer) writeValue(value interface{}, color int) { - if s, ok := value.(string); ok { - if color != noColor { - *w.b = fmt.Appendf(*w.b, "\u001B[%dm", color) - } - - if needsQuoting(s) { - *w.b = strconv.AppendQuote(*w.b, s) - } else { - *w.b = fmt.Append(*w.b, s) - } - - if color != noColor { - *w.b = fmt.Append(*w.b, "\u001B[0m") - } - return - } - - if color != noColor { - *w.b = fmt.Appendf(*w.b, "\x1b[%dm%v\x1b[0m", color, value) - return - } - - *w.b = fmt.Appendf(*w.b, "%v", value) -} - -func (w *writer) writeMap(m map[string]any) { - if len(m) == 0 { - return - } - keys := make([]string, 0, len(m)) - for key := range m { - keys = append(keys, key) - } - slices.Sort(keys) - for _, key := range keys { - if key == teleport.ComponentKey { - continue - } - switch value := m[key].(type) { - case map[string]any: - w.writeMap(value) - case logrus.Fields: - w.writeMap(value) - default: - w.writeKeyValue(key, value) - } - } -} - -type frameCursor struct { - // current specifies the current stack frame. - // if omitted, rest contains the complete stack - current *runtime.Frame - // rest specifies the rest of stack frames to explore - rest *runtime.Frames - // n specifies the total number of stack frames - n int -} - -// formatCallerWithPathAndLine formats the caller in the form path/segment: -// for output in the log -func formatCallerWithPathAndLine() (path string) { - if cursor := findFrame(); cursor != nil { - t := newTraceFromFrames(*cursor, nil) - return t.Loc() - } - return "" -} - -var frameIgnorePattern = regexp.MustCompile(`github\.com/sirupsen/logrus`) - -// findFrames positions the stack pointer to the first -// function that does not match the frameIngorePattern -// and returns the rest of the stack frames -func findFrame() *frameCursor { - var buf [32]uintptr - // Skip enough frames to start at user code. - // This number is a mere hint to the following loop - // to start as close to user code as possible and getting it right is not mandatory. - // The skip count might need to get updated if the call to findFrame is - // moved up/down the call stack - n := runtime.Callers(4, buf[:]) - pcs := buf[:n] - frames := runtime.CallersFrames(pcs) - for i := 0; i < n; i++ { - frame, _ := frames.Next() - if !frameIgnorePattern.MatchString(frame.Function) { - return &frameCursor{ - current: &frame, - rest: frames, - n: n, - } - } - } - return nil -} - -func newTraceFromFrames(cursor frameCursor, err error) *trace.TraceErr { - traces := make(trace.Traces, 0, cursor.n) - if cursor.current != nil { - traces = append(traces, frameToTrace(*cursor.current)) - } - for { - frame, more := cursor.rest.Next() - traces = append(traces, frameToTrace(frame)) - if !more { - break - } - } - return &trace.TraceErr{ - Err: err, - Traces: traces, - } -} - -func frameToTrace(frame runtime.Frame) trace.Trace { - return trace.Trace{ - Func: frame.Function, - Path: frame.File, - Line: frame.Line, - } -} diff --git a/lib/utils/log/slog.go b/lib/utils/log/slog.go index 46f0e13627b3e..bfb34f4a94114 100644 --- a/lib/utils/log/slog.go +++ b/lib/utils/log/slog.go @@ -27,7 +27,6 @@ import ( "unicode" "github.com/gravitational/trace" - "github.com/sirupsen/logrus" oteltrace "go.opentelemetry.io/otel/trace" ) @@ -68,25 +67,6 @@ var SupportedLevelsText = []string{ slog.LevelError.String(), } -// SlogLevelToLogrusLevel converts a [slog.Level] to its equivalent -// [logrus.Level]. -func SlogLevelToLogrusLevel(level slog.Level) logrus.Level { - switch level { - case TraceLevel: - return logrus.TraceLevel - case slog.LevelDebug: - return logrus.DebugLevel - case slog.LevelInfo: - return logrus.InfoLevel - case slog.LevelWarn: - return logrus.WarnLevel - case slog.LevelError: - return logrus.ErrorLevel - default: - return logrus.FatalLevel - } -} - // DiscardHandler is a [slog.Handler] that discards all messages. It // is more efficient than a [slog.Handler] which outputs to [io.Discard] since // it performs zero formatting. diff --git a/lib/utils/log/slog_text_handler.go b/lib/utils/log/slog_text_handler.go index 7f93a388977bb..612615ba8582d 100644 --- a/lib/utils/log/slog_text_handler.go +++ b/lib/utils/log/slog_text_handler.go @@ -150,45 +150,12 @@ func (s *SlogTextHandler) Handle(ctx context.Context, r slog.Record) error { // Processing fields in this manner allows users to // configure the level and component position in the output. - // This matches the behavior of the original logrus. All other + // This matches the behavior of the original logrus formatter. All other // fields location in the output message are static. for _, field := range s.cfg.ConfiguredFields { switch field { case LevelField: - var color int - var level string - switch r.Level { - case TraceLevel: - level = "TRACE" - color = gray - case slog.LevelDebug: - level = "DEBUG" - color = gray - case slog.LevelInfo: - level = "INFO" - color = blue - case slog.LevelWarn: - level = "WARN" - color = yellow - case slog.LevelError: - level = "ERROR" - color = red - case slog.LevelError + 1: - level = "FATAL" - color = red - default: - color = blue - level = r.Level.String() - } - - if !s.cfg.EnableColors { - color = noColor - } - - level = padMax(level, defaultLevelPadding) - if color != noColor { - level = fmt.Sprintf("\u001B[%dm%s\u001B[0m", color, level) - } + level := formatLevel(r.Level, s.cfg.EnableColors) if rep == nil { state.appendKey(slog.LevelKey) @@ -211,12 +178,8 @@ func (s *SlogTextHandler) Handle(ctx context.Context, r slog.Record) error { if attr.Key != teleport.ComponentKey { return true } - component = fmt.Sprintf("[%v]", attr.Value) - component = strings.ToUpper(padMax(component, s.cfg.Padding)) - if component[len(component)-1] != ' ' { - component = component[:len(component)-1] + "]" - } + component = formatComponent(attr.Value, s.cfg.Padding) return false }) @@ -271,6 +234,55 @@ func (s *SlogTextHandler) Handle(ctx context.Context, r slog.Record) error { return err } +func formatLevel(value slog.Level, enableColors bool) string { + var color int + var level string + switch value { + case TraceLevel: + level = "TRACE" + color = gray + case slog.LevelDebug: + level = "DEBUG" + color = gray + case slog.LevelInfo: + level = "INFO" + color = blue + case slog.LevelWarn: + level = "WARN" + color = yellow + case slog.LevelError: + level = "ERROR" + color = red + case slog.LevelError + 1: + level = "FATAL" + color = red + default: + color = blue + level = value.String() + } + + if !enableColors { + color = noColor + } + + level = padMax(level, defaultLevelPadding) + if color != noColor { + level = fmt.Sprintf("\u001B[%dm%s\u001B[0m", color, level) + } + + return level +} + +func formatComponent(value slog.Value, padding int) string { + component := fmt.Sprintf("[%v]", value) + component = strings.ToUpper(padMax(component, padding)) + if component[len(component)-1] != ' ' { + component = component[:len(component)-1] + "]" + } + + return component +} + func (s *SlogTextHandler) clone() *SlogTextHandler { // We can't use assignment because we can't copy the mutex. return &SlogTextHandler{ diff --git a/lib/utils/log/writer.go b/lib/utils/log/writer.go deleted file mode 100644 index 77cf3037a8b66..0000000000000 --- a/lib/utils/log/writer.go +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package log - -import ( - "io" - "sync" -) - -// SharedWriter is an [io.Writer] implementation that protects -// writes with a mutex. This allows a single [io.Writer] to be shared -// by both logrus and slog without their output clobbering each other. -type SharedWriter struct { - mu sync.Mutex - io.Writer -} - -func (s *SharedWriter) Write(p []byte) (int, error) { - s.mu.Lock() - defer s.mu.Unlock() - - return s.Writer.Write(p) -} - -// NewSharedWriter wraps the provided [io.Writer] in a writer that -// is thread safe. -func NewSharedWriter(w io.Writer) *SharedWriter { - return &SharedWriter{Writer: w} -} From f63a099ca797ac6c82b86190504c8e2fbe795ef4 Mon Sep 17 00:00:00 2001 From: Paul Gottschling Date: Fri, 10 Jan 2025 13:01:35 -0500 Subject: [PATCH 8/8] Add Access Monitoring compatibility docs warning (#50571) Closes #48745 Add a warning to the External Audit Storage page that this feature is not compatible with Access Monitoring on Teleport Enterprise (Cloud), complementing the warning on the Access Monitoring page. --- .../admin-guides/access-controls/access-monitoring.mdx | 2 +- .../admin-guides/management/external-audit-storage.mdx | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/pages/admin-guides/access-controls/access-monitoring.mdx b/docs/pages/admin-guides/access-controls/access-monitoring.mdx index 7f5a7b2a0a864..25797cf3e89d3 100644 --- a/docs/pages/admin-guides/access-controls/access-monitoring.mdx +++ b/docs/pages/admin-guides/access-controls/access-monitoring.mdx @@ -17,7 +17,7 @@ Users are able to write their own custom access monitoring queries by querying t Access Monitoring is not currently supported with External Audit Storage - in Teleport Enterprise (cloud-hosted). This functionality will be + in Teleport Enterprise (Cloud). This functionality will be enabled in a future Teleport release. diff --git a/docs/pages/admin-guides/management/external-audit-storage.mdx b/docs/pages/admin-guides/management/external-audit-storage.mdx index 6aa2fcc0368b8..587bb7ffebe56 100644 --- a/docs/pages/admin-guides/management/external-audit-storage.mdx +++ b/docs/pages/admin-guides/management/external-audit-storage.mdx @@ -21,6 +21,12 @@ External Audit Storage is based on Teleport's available on Teleport Enterprise Cloud clusters running Teleport v14.2.1 or above. + +On Teleport Enterprise (Cloud), External Audit +Storage is not currently supported for users who have Access Monitoring enabled. +This functionality will be enabled in a future Teleport release. + + ## Prerequisites 1. A Teleport Enterprise Cloud account. If you do not have one, [sign