diff --git a/internal/pkg/conflictless/config.go b/internal/pkg/conflictless/config.go index 4456dc2..c3f5209 100644 --- a/internal/pkg/conflictless/config.go +++ b/internal/pkg/conflictless/config.go @@ -1,6 +1,9 @@ package conflictless -import "fmt" +import ( + "fmt" + "strings" +) // FlagCollection is a collection of flags. type FlagCollection struct { @@ -99,6 +102,7 @@ func (cfg *Config) SetChangeFileFormatFromFlags() error { } formatFlag := *cfg.Flags.ChangeFileFormat + formatFlag = strings.ToLower(formatFlag) switch formatFlag { case "yaml": diff --git a/internal/pkg/conflictless/config_test.go b/internal/pkg/conflictless/config_test.go index ab684fc..5cc83d6 100644 --- a/internal/pkg/conflictless/config_test.go +++ b/internal/pkg/conflictless/config_test.go @@ -34,6 +34,57 @@ func TestSetBumpFromFlags(t *testing.T) { } } +func TestSetChangeFileFormatFromFlags(t *testing.T) { + t.Parallel() + + for _, testCase := range []struct { + description string + format string + expected string + }{ + {"yml", "yml", "yml"}, + {"yaml", "yaml", "yaml"}, + {"json", "json", "json"}, + {"upper_case_json", "JSON", "json"}, + {"mixed_case_yaml", "yAmL", "yaml"}, + } { + t.Run(testCase.description, func(t *testing.T) { + t.Parallel() + + cfg := new(conflictless.Config) + cfg.Flags.ChangeFileFormat = &testCase.format + + err := cfg.SetChangeFileFormatFromFlags() + assert.NoError(t, err) + assert.Equal(t, testCase.expected, cfg.ChangeFileFormat) + }) + } +} + +func TestChangeFileFormatFromFlagsWithNil(t *testing.T) { + t.Parallel() + + cfg := new(conflictless.Config) + cfg.ChangeFileFormat = "yml" + cfg.Flags.ChangeFileFormat = nil + + err := cfg.SetChangeFileFormatFromFlags() + assert.NoError(t, err) + assert.Equal(t, "yml", cfg.ChangeFileFormat) +} + +func TestChangeFileFormatFromFlagsWithInvalid(t *testing.T) { + t.Parallel() + + invalidFormat := "foo" + + cfg := new(conflictless.Config) + cfg.Flags.ChangeFileFormat = &invalidFormat + + err := cfg.SetChangeFileFormatFromFlags() + assert.Error(t, err) +} + func TestSetBumpFromFlagsWhenInputIsInvalid(t *testing.T) { t.Parallel()