Skip to content

Commit

Permalink
Added reset to CLI tests
Browse files Browse the repository at this point in the history
Added `viper.Reset()` to tests, because the wrong string representation
was sticking around between tests
  • Loading branch information
arcward committed Sep 22, 2024
1 parent 3c50d6b commit f24f9a4
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 13 deletions.
8 changes: 8 additions & 0 deletions cmd/init_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"fmt"
"github.com/arcward/disconcierge/disconcierge"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/driver/sqlite"
Expand All @@ -14,6 +15,13 @@ import (
)

func TestInitCommand(t *testing.T) {
t.Cleanup(
func() {
viper.Reset()
cfg = disconcierge.DefaultConfig()
},
)

tempDir := t.TempDir()
dbPath := filepath.Join(tempDir, "test.db")

Expand Down
23 changes: 13 additions & 10 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ func initConfig() {
viper.SetDefault("runtime_config_ttl", disconcierge.DefaultRuntimeConfigTTL)
viper.SetDefault("user_cache_ttl", disconcierge.DefaultUserCacheTTL)

viper.SetDefault("log_level", disconcierge.DefaultLogLevel)
viper.SetDefault("api.log_level", disconcierge.DefaultAPILogLevel)
viper.SetDefault("log_level", disconcierge.DefaultLogLevel.String())
viper.SetDefault("api.log_level", disconcierge.DefaultAPILogLevel.String())

viper.SetDefault("startup_timeout", disconcierge.DefaultStartupTimeout)
viper.SetDefault("shutdown_timeout", disconcierge.DefaultShutdownTimeout)
Expand Down Expand Up @@ -209,7 +209,7 @@ func initConfig() {
)
viper.SetDefault(
"discord.webhook_server.log_level",
disconcierge.DefaultDiscordWebhookLogLevel,
disconcierge.DefaultDiscordWebhookLogLevel.String(),
)

fatalErr := func(err error) {
Expand Down Expand Up @@ -301,45 +301,48 @@ func initConfig() {
viper.GetStringSlice("api.cors.expose_headers"),
)

for k, v := range viper.AllSettings() {
log.Printf("config: %s: %v", k, v)
}
logLevelVar, err := levelStringToLevelVar(viper.GetString("log_level"))
if err != nil {
log.Fatalf("error parsing log level: %v", err)
log.Fatalf("error parsing log_level: %v", err)
}
viper.Set("log_level", logLevelVar)

logLevelVar, err = levelStringToLevelVar(viper.GetString("discord.log_level"))
if err != nil {
log.Fatalf("error parsing log level: %v", err)
log.Fatalf("error parsing discord log level: %v", err)
}
viper.Set("discord.log_level", logLevelVar)

logLevelVar, err = levelStringToLevelVar(viper.GetString("openai.log_level"))
if err != nil {
log.Fatalf("error parsing log level: %v", err)
log.Fatalf("error parsing openai log level: %v", err)
}
viper.Set("openai.log_level", logLevelVar)

logLevelVar, err = levelStringToLevelVar(viper.GetString("discord.discordgo_log_level"))
if err != nil {
log.Fatalf("error parsing log level: %v", err)
log.Fatalf("error parsing discordgo log level: %v", err)
}
viper.Set("discord.discordgo_log_level", logLevelVar)

logLevelVar, err = levelStringToLevelVar(viper.GetString("database_log_level"))
if err != nil {
log.Fatalf("error parsing log level: %v", err)
log.Fatalf("error parsing database log level: %v", err)
}
viper.Set("database_log_level", logLevelVar)

logLevelVar, err = levelStringToLevelVar(viper.GetString("api.log_level"))
if err != nil {
log.Fatalf("error parsing log level: %v", err)
log.Fatalf("error parsing api log level: %v", err)
}
viper.Set("api.log_level", logLevelVar)

logLevelVar, err = levelStringToLevelVar(viper.GetString("discord.webhook_server.log_level"))
if err != nil {
log.Fatalf("error parsing log level: %v", err)
log.Fatalf("error parsing webhook server log level: %v", err)
}
viper.Set("discord.webhook_server.log_level", logLevelVar)

Expand Down
10 changes: 7 additions & 3 deletions cmd/root_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

func TestLoadConfigFromEnvFile(t *testing.T) {
// Save the original environment

originalEnv := os.Environ()
t.Cleanup(
func() {
Expand All @@ -28,9 +29,12 @@ func TestLoadConfigFromEnvFile(t *testing.T) {
}
},
)

// Clear the environment before the test
os.Clearenv()
t.Cleanup(
func() {
viper.Reset()
cfg = disconcierge.DefaultConfig()
},
)

tmpdir := t.TempDir()

Expand Down

0 comments on commit f24f9a4

Please sign in to comment.