diff --git a/internal/check/template.go b/internal/check/template.go index d92278da0..f0e5b35c0 100644 --- a/internal/check/template.go +++ b/internal/check/template.go @@ -8,7 +8,7 @@ import ( // A Func is a specific lint-check, which runs on a specific objects, and emits diagnostics if problems are found. // Checks have access to the entire LintContext, with all the objects in it, but must only report problems for the // object passed in the second argument. -type Func func(lintCtx *lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic +type Func func(lintCtx lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic // ObjectKindsDesc describes a list of supported object kinds for a check template. type ObjectKindsDesc struct { diff --git a/internal/lintcontext/context.go b/internal/lintcontext/context.go index c4f1b06a3..fb13d99c9 100644 --- a/internal/lintcontext/context.go +++ b/internal/lintcontext/context.go @@ -23,22 +23,37 @@ type InvalidObject struct { } // A LintContext represents the context for a lint run. -type LintContext struct { +type LintContext interface { + Objects() []Object + InvalidObjects() []InvalidObject +} + +type lintContextImpl struct { objects []Object invalidObjects []InvalidObject } // Objects returns the (valid) objects loaded from this LintContext. -func (l *LintContext) Objects() []Object { +func (l *lintContextImpl) Objects() []Object { return l.objects } +// addObject adds a valid object to this LintContext +func (l *lintContextImpl) addObjects(objs ...Object) { + l.objects = append(l.objects, objs...) +} + // InvalidObjects returns any objects that we attempted to load, but which were invalid. -func (l *LintContext) InvalidObjects() []InvalidObject { +func (l *lintContextImpl) InvalidObjects() []InvalidObject { return l.invalidObjects } -// New returns a ready-to-use, empty, lint context. -func New() *LintContext { - return &LintContext{} +// addInvalidObject adds an invalid object to this LintContext +func (l *lintContextImpl) addInvalidObjects(objs ...InvalidObject) { + l.invalidObjects = append(l.invalidObjects, objs...) +} + +// new returns a ready-to-use, empty, lintContextImpl. +func new() *lintContextImpl { + return &lintContextImpl{} } diff --git a/internal/lintcontext/create_contexts.go b/internal/lintcontext/create_contexts.go index bee109028..7d9355d64 100644 --- a/internal/lintcontext/create_contexts.go +++ b/internal/lintcontext/create_contexts.go @@ -20,15 +20,15 @@ var ( // Currently, each directory of Kube YAML files (or Helm charts) are treated as a separate context. // TODO: Figure out if it's useful to allow people to specify that files spanning different directories // should be treated as being in the same context. -func CreateContexts(filesOrDirs ...string) ([]*LintContext, error) { - contextsByDir := make(map[string]*LintContext) +func CreateContexts(filesOrDirs ...string) ([]LintContext, error) { + contextsByDir := make(map[string]*lintContextImpl) for _, fileOrDir := range filesOrDirs { // Stdin if fileOrDir == "-" { if _, alreadyExists := contextsByDir["-"]; alreadyExists { continue } - ctx := New() + ctx := new() if err := ctx.loadObjectsFromReader("", os.Stdin); err != nil { return nil, err } @@ -46,7 +46,7 @@ func CreateContexts(filesOrDirs ...string) ([]*LintContext, error) { if knownYAMLExtensions.Contains(strings.ToLower(filepath.Ext(currentPath))) || fileOrDir == currentPath { ctx := contextsByDir[dirName] if ctx == nil { - ctx = New() + ctx = new() contextsByDir[dirName] = ctx } if err := ctx.loadObjectsFromYAMLFile(currentPath, info); err != nil { @@ -60,7 +60,7 @@ func CreateContexts(filesOrDirs ...string) ([]*LintContext, error) { if _, alreadyExists := contextsByDir[currentPath]; alreadyExists { return nil } - ctx := New() + ctx := new() contextsByDir[currentPath] = ctx if err := ctx.loadObjectsFromHelmChart(currentPath); err != nil { return err @@ -78,7 +78,7 @@ func CreateContexts(filesOrDirs ...string) ([]*LintContext, error) { dirs = append(dirs, dir) } sort.Strings(dirs) - var contexts []*LintContext + var contexts []LintContext for _, dir := range dirs { contexts = append(contexts, contextsByDir[dir]) } diff --git a/internal/lintcontext/mocks/container.go b/internal/lintcontext/mocks/container.go new file mode 100644 index 000000000..27de38ee3 --- /dev/null +++ b/internal/lintcontext/mocks/container.go @@ -0,0 +1,29 @@ +package mocks + +import ( + "github.com/pkg/errors" + v1 "k8s.io/api/core/v1" +) + +// AddContainerToPod adds a mock container to the specified pod under context +func (l *MockLintContext) AddContainerToPod( + podName, containerName, image string, + ports []v1.ContainerPort, + env []v1.EnvVar, + sc *v1.SecurityContext, +) error { + pod, ok := l.pods[podName] + if !ok { + return errors.Errorf("pod with name %q is not found", podName) + } + // TODO: keep supporting other fields + pod.Spec.Containers = append(pod.Spec.Containers, v1.Container{ + Name: containerName, + Image: image, + Ports: ports, + Env: env, + Resources: v1.ResourceRequirements{}, + SecurityContext: sc, + }) + return nil +} diff --git a/internal/lintcontext/mocks/context.go b/internal/lintcontext/mocks/context.go new file mode 100644 index 000000000..9c0391e3c --- /dev/null +++ b/internal/lintcontext/mocks/context.go @@ -0,0 +1,30 @@ +package mocks + +import ( + "golang.stackrox.io/kube-linter/internal/lintcontext" + v1 "k8s.io/api/core/v1" +) + +// MockLintContext is mock implementation of the LintContext used in unit tests +type MockLintContext struct { + pods map[string]*v1.Pod +} + +// Objects returns all the objects under this MockLintContext +func (l *MockLintContext) Objects() []lintcontext.Object { + result := make([]lintcontext.Object, 0, len(l.pods)) + for _, p := range l.pods { + result = append(result, lintcontext.Object{Metadata: lintcontext.ObjectMetadata{}, K8sObject: p}) + } + return result +} + +// InvalidObjects is not implemented. For now we don't care about invalid objects for mock context. +func (l *MockLintContext) InvalidObjects() []lintcontext.InvalidObject { + return nil +} + +// NewMockContext returns an empty mockLintContext +func NewMockContext() *MockLintContext { + return &MockLintContext{pods: make(map[string]*v1.Pod)} +} diff --git a/internal/lintcontext/mocks/pod.go b/internal/lintcontext/mocks/pod.go new file mode 100644 index 000000000..1645a78eb --- /dev/null +++ b/internal/lintcontext/mocks/pod.go @@ -0,0 +1,53 @@ +package mocks + +import ( + "github.com/pkg/errors" + v1 "k8s.io/api/core/v1" + metaV1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// AddMockPod adds a mock Pod to LintContext +func (l *MockLintContext) AddMockPod( + podName, namespace, clusterName string, + labels, annotations map[string]string, +) { + l.pods[podName] = + &v1.Pod{ + TypeMeta: metaV1.TypeMeta{}, + ObjectMeta: metaV1.ObjectMeta{ + Name: podName, + Namespace: namespace, + Labels: labels, + Annotations: annotations, + ClusterName: clusterName, + }, + Spec: v1.PodSpec{}, + Status: v1.PodStatus{}, + } +} + +// AddSecurityContextToPod adds a security context to the pod specified by name +func (l *MockLintContext) AddSecurityContextToPod( + podName string, + runAsUser *int64, + runAsNonRoot *bool, +) error { + pod, ok := l.pods[podName] + if !ok { + return errors.Errorf("pod with name %q is not found", podName) + } + // TODO: keep supporting other fields + pod.Spec.SecurityContext = &v1.PodSecurityContext{ + SELinuxOptions: nil, + WindowsOptions: nil, + RunAsUser: runAsUser, + RunAsGroup: nil, + RunAsNonRoot: runAsNonRoot, + SupplementalGroups: nil, + FSGroup: nil, + Sysctls: nil, + FSGroupChangePolicy: nil, + SeccompProfile: nil, + } + return nil +} diff --git a/internal/lintcontext/parse_yaml.go b/internal/lintcontext/parse_yaml.go index 6faf7e4cc..ff0adb99f 100644 --- a/internal/lintcontext/parse_yaml.go +++ b/internal/lintcontext/parse_yaml.go @@ -66,7 +66,7 @@ func (w nopWriter) Write(p []byte) (n int, err error) { return len(p), nil } -func (l *LintContext) renderHelmChart(dir string) (map[string]string, error) { +func (l *lintContextImpl) renderHelmChart(dir string) (map[string]string, error) { // Helm doesn't have great logging behaviour, and can spam stderr, so silence their logging. // TODO: capture these logs. log.SetOutput(nopWriter{}) @@ -95,11 +95,11 @@ func (l *LintContext) renderHelmChart(dir string) (map[string]string, error) { return rendered, nil } -func (l *LintContext) loadObjectsFromHelmChart(dir string) error { +func (l *lintContextImpl) loadObjectsFromHelmChart(dir string) error { metadata := ObjectMetadata{FilePath: dir} renderedFiles, err := l.renderHelmChart(dir) if err != nil { - l.invalidObjects = append(l.invalidObjects, InvalidObject{Metadata: metadata, LoadErr: err}) + l.addInvalidObjects(InvalidObject{Metadata: metadata, LoadErr: err}) return nil } for path, contents := range renderedFiles { @@ -113,7 +113,7 @@ func (l *LintContext) loadObjectsFromHelmChart(dir string) error { return nil } -func (l *LintContext) loadObjectFromYAMLReader(filePath string, r *yaml.YAMLReader) error { +func (l *lintContextImpl) loadObjectFromYAMLReader(filePath string, r *yaml.YAMLReader) error { doc, err := r.Read() if err != nil { return err @@ -130,14 +130,14 @@ func (l *LintContext) loadObjectFromYAMLReader(filePath string, r *yaml.YAMLRead objs, err := parseObjects(doc) if err != nil { - l.invalidObjects = append(l.invalidObjects, InvalidObject{ + l.addInvalidObjects(InvalidObject{ Metadata: metadata, LoadErr: err, }) return nil } for _, obj := range objs { - l.objects = append(l.objects, Object{ + l.addObjects(Object{ Metadata: metadata, K8sObject: obj, }) @@ -145,7 +145,7 @@ func (l *LintContext) loadObjectFromYAMLReader(filePath string, r *yaml.YAMLRead return nil } -func (l *LintContext) loadObjectsFromYAMLFile(filePath string, info os.FileInfo) error { +func (l *lintContextImpl) loadObjectsFromYAMLFile(filePath string, info os.FileInfo) error { if info.Size() > maxFileSizeBytes { return nil } @@ -160,7 +160,7 @@ func (l *LintContext) loadObjectsFromYAMLFile(filePath string, info os.FileInfo) return l.loadObjectsFromReader(filePath, file) } -func (l *LintContext) loadObjectsFromReader(filePath string, reader io.Reader) error { +func (l *lintContextImpl) loadObjectsFromReader(filePath string, reader io.Reader) error { yamlReader := yaml.NewYAMLReader(bufio.NewReader(reader)) for { if err := l.loadObjectFromYAMLReader(filePath, yamlReader); err != nil { diff --git a/internal/run/run.go b/internal/run/run.go index ab796ec54..827b8a2c7 100644 --- a/internal/run/run.go +++ b/internal/run/run.go @@ -15,7 +15,7 @@ type Result struct { } // Run runs the linter on the given context, with the given config. -func Run(lintCtxs []*lintcontext.LintContext, registry checkregistry.CheckRegistry, checks []string) (Result, error) { +func Run(lintCtxs []lintcontext.LintContext, registry checkregistry.CheckRegistry, checks []string) (Result, error) { instantiatedChecks := make([]*instantiatedcheck.InstantiatedCheck, 0, len(checks)) for _, checkName := range checks { instantiatedCheck := registry.Load(checkName) diff --git a/internal/templates/antiaffinity/template.go b/internal/templates/antiaffinity/template.go index 289081a64..8b5f18d4f 100644 --- a/internal/templates/antiaffinity/template.go +++ b/internal/templates/antiaffinity/template.go @@ -26,7 +26,7 @@ func init() { Parameters: params.ParamDescs, ParseAndValidateParams: params.ParseAndValidate, Instantiate: params.WrapInstantiateFunc(func(p params.Params) (check.Func, error) { - return func(_ *lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { + return func(_ lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { replicas, found := extract.Replicas(object.K8sObject) if !found { return nil diff --git a/internal/templates/containercapabilities/template.go b/internal/templates/containercapabilities/template.go index 784534f3c..8b22122aa 100644 --- a/internal/templates/containercapabilities/template.go +++ b/internal/templates/containercapabilities/template.go @@ -16,8 +16,9 @@ import ( ) const ( + templateKey = "verify-container-capabilities" reservedCapabilitiesAll = "all" - matchLiteralReservedCapabilitiesAll = "(?i)" + reservedCapabilitiesAll + matchLiteralReservedCapabilitiesAll = "^(?i)" + reservedCapabilitiesAll + "$" ) var ( @@ -26,6 +27,12 @@ var ( utils.Must(err) return m }() + + addListDiagMsgFmt = "container %q has ADD capability: %q, which matched with the forbidden capability for containers" + addListWithAllDiagMsgFmt = "container %q has ADD capability: %q, but no capabilities " + + "should be added at all and this capability is not included in the exceptions list" + dropListDiagMsgFmt = "container %q has DROP capabilities: %q, but does not drop capability %q which is required" + dropListWithAllDiagMsgFmt = "container %q has DROP capabilities: %q, but in fact all capabilities are required to be dropped" ) func checkCapabilityDropList( @@ -49,7 +56,7 @@ func checkCapabilityDropList( *result, diagnostic.Diagnostic{ Message: fmt.Sprintf( - "container %q has DROP capabilities: %q, but in fact all capabilities are required to be dropeed", + dropListWithAllDiagMsgFmt, containerName, scCaps.Drop), }) @@ -71,8 +78,8 @@ func checkCapabilityDropList( append( *result, diagnostic.Diagnostic{ - Message: fmt.Sprintf("container %q has DROP capabilities: %q, but does not drop "+ - "capability %q which is required", + Message: fmt.Sprintf( + dropListDiagMsgFmt, containerName, scCaps.Drop, paramCap), @@ -106,8 +113,7 @@ func checkCapabilityAddList( *result, diagnostic.Diagnostic{ Message: fmt.Sprintf( - "container %q has ADD capability: %q, but no capabilities should be added at all and"+ - " this capabilty is not included in the exceptions list", + addListWithAllDiagMsgFmt, containerName, scCap), }) @@ -128,7 +134,7 @@ func checkCapabilityAddList( *result, diagnostic.Diagnostic{ Message: fmt.Sprintf( - "container %q has ADD capability: %q, which matched with the forbidden capability for containers", + addListDiagMsgFmt, containerName, scCap), }) @@ -181,7 +187,7 @@ func validateExceptionsList(forbidAll bool, exceptions []string) error { func init() { templates.Register(check.Template{ HumanName: "Verify container capabilities", - Key: "verify-container-capabilities", + Key: templateKey, Description: "Flag containers that do not match capabilities requirements", SupportedObjectKinds: check.ObjectKindsDesc{ ObjectKinds: []string{objectkinds.DeploymentLike}, diff --git a/internal/templates/containercapabilities/template_test.go b/internal/templates/containercapabilities/template_test.go new file mode 100644 index 000000000..9e31dd264 --- /dev/null +++ b/internal/templates/containercapabilities/template_test.go @@ -0,0 +1,179 @@ +package containercapabilities + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/suite" + "golang.stackrox.io/kube-linter/internal/diagnostic" + "golang.stackrox.io/kube-linter/internal/lintcontext/mocks" + "golang.stackrox.io/kube-linter/internal/templates" + "golang.stackrox.io/kube-linter/internal/templates/containercapabilities/internal/params" + v1 "k8s.io/api/core/v1" +) + +var ( + podName = "test-pod" + containerName = "test-container" +) + +func TestContainerCapabilities(t *testing.T) { + suite.Run(t, new(ContainerCapabilitiesTestSuite)) +} + +type ContainerCapabilitiesTestSuite struct { + templates.TemplateTestSuite + + ctx *mocks.MockLintContext +} + +func (s *ContainerCapabilitiesTestSuite) SetupTest() { + s.Init(templateKey) + s.ctx = mocks.NewMockContext() +} + +func (s *ContainerCapabilitiesTestSuite) TestForbiddenCapabilities() { + addCaps := []v1.Capability{"FORBIDDEN_CAP", "ALLOWED_CAP"} + dropCaps := []v1.Capability{"DROPPED_CAP"} + + s.addPodAndAddContainerWithCaps(addCaps, dropCaps) + + s.Validate(s.ctx, []templates.TestCase{ + { + Param: params.Params{ + ForbiddenCapabilities: []string{"FORBIDDEN_CAP", "DROPPED_CAP"}, + Exceptions: nil, + }, + Diagnostics: []diagnostic.Diagnostic{ + {Message: fmt.Sprintf(addListDiagMsgFmt, containerName, "FORBIDDEN_CAP")}, + {Message: fmt.Sprintf(dropListDiagMsgFmt, containerName, dropCaps, "FORBIDDEN_CAP")}, + }, + ExpectInstantiationError: false, + }, + }) +} + +func (s *ContainerCapabilitiesTestSuite) TestForbiddenCapabilitiesWithAll() { + addCaps := []v1.Capability{"CAP_1", "CAP_2", "CAP_3"} + dropCaps := []v1.Capability{"DROPPED_CAP"} + + s.addPodAndAddContainerWithCaps(addCaps, dropCaps) + + s.Validate(s.ctx, []templates.TestCase{ + // Case 1: all are prohibited + { + Param: params.Params{ + ForbiddenCapabilities: []string{"all"}, + Exceptions: nil, + }, + Diagnostics: []diagnostic.Diagnostic{ + {Message: fmt.Sprintf(addListWithAllDiagMsgFmt, containerName, "CAP_1")}, + {Message: fmt.Sprintf(addListWithAllDiagMsgFmt, containerName, "CAP_2")}, + {Message: fmt.Sprintf(addListWithAllDiagMsgFmt, containerName, "CAP_3")}, + {Message: fmt.Sprintf(dropListWithAllDiagMsgFmt, containerName, dropCaps)}, + }, + ExpectInstantiationError: false, + }, + // Case 2: with some forgiven capabilities + { + Param: params.Params{ + // Also tests reserved word "all" should match irrespective of case + ForbiddenCapabilities: []string{"AlL"}, + Exceptions: []string{"CAP_1", "CAP_2"}, + }, + Diagnostics: []diagnostic.Diagnostic{ + {Message: fmt.Sprintf(addListWithAllDiagMsgFmt, containerName, "CAP_3")}, + {Message: fmt.Sprintf(dropListWithAllDiagMsgFmt, containerName, dropCaps)}, + }, + ExpectInstantiationError: false, + }, + }) +} + +func (s *ContainerCapabilitiesTestSuite) TestAddListHasAll() { + addCaps := []v1.Capability{"all", "REDUNDANT_CAP"} + dropCaps := make([]v1.Capability, 0) + + s.addPodAndAddContainerWithCaps(addCaps, dropCaps) + + s.Validate(s.ctx, []templates.TestCase{ + { + Param: params.Params{ + ForbiddenCapabilities: []string{"CAP_1"}, + Exceptions: nil, + }, + Diagnostics: []diagnostic.Diagnostic{ + {Message: fmt.Sprintf(addListDiagMsgFmt, containerName, "all")}, + {Message: fmt.Sprintf(dropListDiagMsgFmt, containerName, dropCaps, "CAP_1")}, + }, + ExpectInstantiationError: false, + }, + }) +} + +func (s *ContainerCapabilitiesTestSuite) TestDropListHasAll() { + addCaps := []v1.Capability{"FORGIVEN_CAP"} + dropCaps := []v1.Capability{"all"} + + s.addPodAndAddContainerWithCaps(addCaps, dropCaps) + + s.Validate(s.ctx, []templates.TestCase{ + // Case 1: caps are all dropped by "all" in drop list + { + Param: params.Params{ + ForbiddenCapabilities: []string{"CAP_1", "CAP_2"}, + Exceptions: nil, + }, + Diagnostics: []diagnostic.Diagnostic{}, + ExpectInstantiationError: false, + }, + // Case 2: forbidden caps include "all" + { + Param: params.Params{ + ForbiddenCapabilities: []string{"all"}, + Exceptions: nil, + }, + Diagnostics: []diagnostic.Diagnostic{ + {Message: fmt.Sprintf(addListWithAllDiagMsgFmt, containerName, "FORGIVEN_CAP")}, + }, + ExpectInstantiationError: false, + }, + // Case 3: now we forgive the FORGIVEN_CAP. Should see no error + { + Param: params.Params{ + ForbiddenCapabilities: []string{"all"}, + Exceptions: []string{"FORGIVEN_CAP"}, + }, + Diagnostics: []diagnostic.Diagnostic{}, + ExpectInstantiationError: false, + }, + }) +} + +func (s *ContainerCapabilitiesTestSuite) TestInvalidParams() { + addCaps := []v1.Capability{"CAP_1"} + dropCaps := []v1.Capability{"CAP_2"} + + s.addPodAndAddContainerWithCaps(addCaps, dropCaps) + + s.Validate(s.ctx, []templates.TestCase{ + { + Param: params.Params{ + ForbiddenCapabilities: []string{"THIS_IS_NOT_All_CAP"}, + Exceptions: []string{"BUT_WE_SPECIFY_EXCEPTIONS"}, + }, + Diagnostics: []diagnostic.Diagnostic{}, + ExpectInstantiationError: true, + }, + }) +} + +func (s *ContainerCapabilitiesTestSuite) addPodAndAddContainerWithCaps(addCaps, dropCaps []v1.Capability) { + s.ctx.AddMockPod(podName, "", "", nil, nil) + s.ctx.AddContainerToPod(podName, containerName, "", nil, nil, &v1.SecurityContext{ + Capabilities: &v1.Capabilities{ + Add: addCaps, + Drop: dropCaps, + }, + }) +} diff --git a/internal/templates/danglingservice/template.go b/internal/templates/danglingservice/template.go index c1358fca4..704886d78 100644 --- a/internal/templates/danglingservice/template.go +++ b/internal/templates/danglingservice/template.go @@ -26,7 +26,7 @@ func init() { Parameters: params.ParamDescs, ParseAndValidateParams: params.ParseAndValidate, Instantiate: params.WrapInstantiateFunc(func(_ params.Params) (check.Func, error) { - return func(lintCtx *lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { + return func(lintCtx lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { service, ok := object.K8sObject.(*v1.Service) if !ok { return nil diff --git a/internal/templates/deprecatedserviceaccount/template.go b/internal/templates/deprecatedserviceaccount/template.go index 1e4dd5713..ebe7f4556 100644 --- a/internal/templates/deprecatedserviceaccount/template.go +++ b/internal/templates/deprecatedserviceaccount/template.go @@ -23,7 +23,7 @@ func init() { Parameters: params.ParamDescs, ParseAndValidateParams: params.ParseAndValidate, Instantiate: params.WrapInstantiateFunc(func(_ params.Params) (check.Func, error) { - return func(_ *lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { + return func(_ lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { podSpec, found := extract.PodSpec(object.K8sObject) if !found { return nil diff --git a/internal/templates/disallowedgvk/template.go b/internal/templates/disallowedgvk/template.go index 5725e0b7a..59c2832e4 100644 --- a/internal/templates/disallowedgvk/template.go +++ b/internal/templates/disallowedgvk/template.go @@ -37,7 +37,7 @@ func init() { if err != nil { return nil, errors.Wrap(err, "invalid kind") } - return func(_ *lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { + return func(_ lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { gvk := extract.GVK(object.K8sObject) if groupMatcher(gvk.Group) && versionMatcher(gvk.Version) && kindMatcher(gvk.Kind) { return []diagnostic.Diagnostic{{Message: fmt.Sprintf("disallowed API object found: %s", gvk)}} diff --git a/internal/templates/mismatchingselector/template.go b/internal/templates/mismatchingselector/template.go index c7070b97f..ffd00aa9f 100644 --- a/internal/templates/mismatchingselector/template.go +++ b/internal/templates/mismatchingselector/template.go @@ -27,7 +27,7 @@ func init() { Parameters: params.ParamDescs, ParseAndValidateParams: params.ParseAndValidate, Instantiate: params.WrapInstantiateFunc(func(_ params.Params) (check.Func, error) { - return func(_ *lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { + return func(_ lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { selector, found := extract.Selector(object.K8sObject) if !found { return nil diff --git a/internal/templates/nonexistentserviceaccount/template.go b/internal/templates/nonexistentserviceaccount/template.go index d84df1d96..4f9554a5c 100644 --- a/internal/templates/nonexistentserviceaccount/template.go +++ b/internal/templates/nonexistentserviceaccount/template.go @@ -30,7 +30,7 @@ func init() { Parameters: params.ParamDescs, ParseAndValidateParams: params.ParseAndValidate, Instantiate: params.WrapInstantiateFunc(func(_ params.Params) (check.Func, error) { - return func(lintCtx *lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { + return func(lintCtx lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { podSpec, found := extract.PodSpec(object.K8sObject) if !found { return nil diff --git a/internal/templates/runasnonroot/template.go b/internal/templates/runasnonroot/template.go index 6c2b7dc60..c24700b0f 100644 --- a/internal/templates/runasnonroot/template.go +++ b/internal/templates/runasnonroot/template.go @@ -44,7 +44,7 @@ func init() { Parameters: params.ParamDescs, ParseAndValidateParams: params.ParseAndValidate, Instantiate: params.WrapInstantiateFunc(func(_ params.Params) (check.Func, error) { - return func(_ *lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { + return func(_ lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { podSpec, found := extract.PodSpec(object.K8sObject) if !found { return nil diff --git a/internal/templates/serviceaccount/template.go b/internal/templates/serviceaccount/template.go index a004121e3..e6fae0c30 100644 --- a/internal/templates/serviceaccount/template.go +++ b/internal/templates/serviceaccount/template.go @@ -29,7 +29,7 @@ func init() { if err != nil { return nil, err } - return func(_ *lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { + return func(_ lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { podSpec, found := extract.PodSpec(object.K8sObject) if !found { return nil diff --git a/internal/templates/templates_testutils.go b/internal/templates/templates_testutils.go new file mode 100644 index 000000000..b22c4359b --- /dev/null +++ b/internal/templates/templates_testutils.go @@ -0,0 +1,59 @@ +package templates + +import ( + "github.com/stretchr/testify/suite" + "golang.stackrox.io/kube-linter/internal/check" + "golang.stackrox.io/kube-linter/internal/diagnostic" + "golang.stackrox.io/kube-linter/internal/lintcontext" +) + +// TemplateTestSuite is a basic testing suite for all templates +// test with some generic helper functions +type TemplateTestSuite struct { + suite.Suite + + Template check.Template +} + +// TestCase represents a single test case which can be verified under a LintContext +type TestCase struct { + Param interface{} + Diagnostics []diagnostic.Diagnostic + ExpectInstantiationError bool +} + +// Init initializes the test suite with a template +func (s *TemplateTestSuite) Init(templateKey string) { + t, ok := Get(templateKey) + s.True(ok, "template with key %q not found", templateKey) + s.Template = t +} + +// Validate validates the given test cases against the LintContext passed in. +func (s *TemplateTestSuite) Validate( + ctx lintcontext.LintContext, + cases []TestCase, +) { + for _, c := range cases { + checkFunc, err := s.Template.Instantiate(c.Param) + if c.ExpectInstantiationError { + s.Error(err, "param should have caused error but did not raise one") + continue + } + for _, obj := range ctx.Objects() { + diagnostics := checkFunc(ctx, obj) + s.compareDiagnostics(c.Diagnostics, diagnostics) + } + } +} + +func (s *TemplateTestSuite) compareDiagnostics(expected, actual []diagnostic.Diagnostic) { + expectedMessages, actualMessages := make([]string, 0, len(expected)), make([]string, 0, len(actual)) + for _, diag := range expected { + expectedMessages = append(expectedMessages, diag.Message) + } + for _, diag := range actual { + actualMessages = append(actualMessages, diag.Message) + } + s.ElementsMatch(expectedMessages, actualMessages, "expected diagnostics and actual diagnostics do not match") +} diff --git a/internal/templates/util/per_container_check.go b/internal/templates/util/per_container_check.go index f3ebda582..514bc2241 100644 --- a/internal/templates/util/per_container_check.go +++ b/internal/templates/util/per_container_check.go @@ -12,7 +12,7 @@ import ( // that applies to containers. The given function is passed each container, and is allowed to return // diagnostics if an error is found. func PerContainerCheck(matchFunc func(container *v1.Container) []diagnostic.Diagnostic) check.Func { - return func(_ *lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { + return func(_ lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { podSpec, found := extract.PodSpec(object.K8sObject) if !found { return nil diff --git a/internal/templates/util/required_matcher.go b/internal/templates/util/required_matcher.go index 0dfb2c1f8..387f36297 100644 --- a/internal/templates/util/required_matcher.go +++ b/internal/templates/util/required_matcher.go @@ -34,7 +34,7 @@ func ConstructRequiredMapMatcher(key, value, fieldType string) (check.Func, erro return nil, errors.Errorf("unknown fieldType %q", fieldType) } - return func(_ *lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { + return func(_ lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { fields := extractFunc(object.K8sObject) for k, v := range fields { if keyMatcher(k) && valueMatcher(v) { diff --git a/internal/templates/writablehostmount/template.go b/internal/templates/writablehostmount/template.go index 9b4ef00db..5f4e14ae4 100644 --- a/internal/templates/writablehostmount/template.go +++ b/internal/templates/writablehostmount/template.go @@ -23,7 +23,7 @@ func init() { Parameters: params.ParamDescs, ParseAndValidateParams: params.ParseAndValidate, Instantiate: params.WrapInstantiateFunc(func(_ params.Params) (check.Func, error) { - return func(_ *lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { + return func(_ lintcontext.LintContext, object lintcontext.Object) []diagnostic.Diagnostic { podSpec, found := extract.PodSpec(object.K8sObject) if !found { return nil