diff --git a/future.go b/future.go index 0bd38d9..83b1240 100644 --- a/future.go +++ b/future.go @@ -46,11 +46,11 @@ type Future[T any] struct { state *state[T] } -func (s *state[T]) set(val T, err error) { +func (s *state[T]) set(val T, err error) bool { for { st := atomic.LoadUint64(&s.state) if !isFree(st) { - panic("promise already satisfied") + return false } if atomic.CompareAndSwapUint64(&s.state, st, st+stateDelta) { s.val = val @@ -69,7 +69,7 @@ func (s *state[T]) set(val T, err error) { head.next = nil } } - return + return true } } } @@ -135,7 +135,13 @@ func NewPromise[T any]() *Promise[T] { } func (p *Promise[T]) Set(val T, err error) { - p.state.set(val, err) + if !p.state.set(val, err) { + panic("promise already satisfied") + } +} + +func (p *Promise[T]) SetSafety(val T, err error) bool { + return p.state.set(val, err) } func (p *Promise[T]) Future() *Future[T] { diff --git a/future_test.go b/future_test.go index 2990c91..7cee5d2 100644 --- a/future_test.go +++ b/future_test.go @@ -58,6 +58,16 @@ func TestPromiseSetTwice(t *testing.T) { }) } +func TestPromiseSetSafetyTwice(t *testing.T) { + p := NewPromise[int]() + f := p.Future() + p.SetSafety(1, nil) + p.SetSafety(2, nil) + val, err := f.Get() + assert.Equal(t, 1, val) + assert.NoError(t, err) +} + func TestFutureSubscribe(t *testing.T) { p := NewPromise[int]() f := p.Future()