diff --git a/environment.go b/environment.go index 846b846..e25dfbb 100644 --- a/environment.go +++ b/environment.go @@ -16,21 +16,28 @@ package environment import ( "sync" + "sync/atomic" ) var ( // Mutex when operating on the current runtime environment. - mutex sync.Mutex + mutex sync.RWMutex // The current environment. current = Development + // Is the current runtime environment locked? + locked = int32(0) + // List of supported environments. supported = []Env{Development, Testing, Prerelease, Production} ) // Get returns the current runtime environment. func Get() Env { + mutex.RLock() + defer mutex.RUnlock() + return current } @@ -46,12 +53,34 @@ func Register(env Env) { } } +// Lock locks the current runtime environment. +// After locking, the current runtime environment cannot be changed. +func Lock() { + mutex.Lock() + defer mutex.Unlock() + + atomic.StoreInt32(&locked, 1) +} + +// Locked returns whether the current runtime environment is locked. +func Locked() bool { + return atomic.LoadInt32(&locked) == 1 +} + // Set sets the current runtime environment. -// If the given environment is not supported, an ErrInvalidEnv error is returned. +// If the given runtime environment is not supported, ErrInvalidEnv error is returned. +// If the current runtime environment is locked, ErrLocked error is returned. func Set(env Env) error { mutex.Lock() defer mutex.Unlock() + return doSet(env) +} + +func doSet(env Env) error { + if Locked() { + return ErrLocked + } if !env.In(supported) { return ErrInvalidEnv } @@ -59,3 +88,17 @@ func Set(env Env) error { current = env return nil } + +// SetAndLock sets and locks the current runtime environment. +// If the runtime environment settings fail, they are not locked. +func SetAndLock(env Env) error { + mutex.Lock() + defer mutex.Unlock() + + if err := doSet(env); err != nil { + return err + } + + atomic.StoreInt32(&locked, 1) + return nil +} diff --git a/environment_test.go b/environment_test.go index 57e0558..1763b7a 100644 --- a/environment_test.go +++ b/environment_test.go @@ -15,36 +15,109 @@ package environment import ( + "sync/atomic" "testing" ) -func TestEnvironment(t *testing.T) { +func doTestEnvironment(f func()) { + // Oh, we still need to restore the scene! defer func() { current = Development + atomic.StoreInt32(&locked, 0) supported = []Env{Development, Testing, Prerelease, Production} }() - if got := Get(); got != Development { - t.Fatal(got) - } + f() +} + +func TestEnvironment(t *testing.T) { + // Copy from supported. + list := []Env{Development, Testing, Prerelease, Production} - envs := []Env{Development, Testing, Prerelease, Production} - for i, j := 0, len(envs); i < j; i++ { - if err := Set(envs[i]); err != nil { + doTestEnvironment(func() { + // The default runtime environment is Development! + if got := Get(); got != Development { + t.Fatal(got) + } + + for _, env := range list { + if err := Set(env); err != nil { + t.Fatal(err) + } + if got := Get(); got != env { + t.Fatal(got) + } + } + + if err := Set("foo"); err == nil { + t.Fatal("No error") + } else { + if err != ErrInvalidEnv { + t.Fatal(err) + } + } + + Register("foo") + + if err := Set("foo"); err != nil { t.Fatal(err) } - } - if err := Set("foo"); err == nil { - t.Fatal(`Set("foo")`) - } else { - if err != ErrInvalidEnv { + if Locked() { + t.Fatal("Locked") + } + + Lock() + + if !Locked() { + t.Fatal("Not Locked") + } + for _, env := range list { + if err := Set(env); err == nil { + t.Fatal("No error") + } else { + if err != ErrLocked { + t.Fatal(err) + } + } + } + if err := Set("foo"); err == nil { + t.Fatal("No error") + } else { + if err != ErrLocked { + t.Fatal(err) + } + } + }) +} + +func TestSetAndLock(t *testing.T) { + doTestEnvironment(func() { + if Locked() { + t.Fatal("Locked") + } + + // The default runtime environment is Development! + if got := Get(); got != Development { + t.Fatal(got) + } + + if err := SetAndLock(Testing); err != nil { t.Fatal(err) } - } + if !Locked() { + t.Fatal("Not Locked") + } + if got := Get(); got != Testing { + t.Fatal(got) + } - Register("foo") - if err := Set("foo"); err != nil { - t.Fatal(err) - } + if err := SetAndLock(Production); err == nil { + t.Fatal("No error") + } else { + if err != ErrLocked { + t.Fatal(err) + } + } + }) } diff --git a/errors.go b/errors.go index 8f2281a..51941b7 100644 --- a/errors.go +++ b/errors.go @@ -20,4 +20,8 @@ import ( // ErrInvalidEnv represents that the given runtime environment is not // registered or supported. -var ErrInvalidEnv = errors.New("invalid environment") +var ErrInvalidEnv = errors.New("invalid runtime environment") + +// ErrLocked indicates that the current runtime environment is locked +// and cannot be changed. +var ErrLocked = errors.New("locked runtime environment")