diff --git a/event/bind_test.go b/event/bind_test.go index 073bf065..4fe1d942 100644 --- a/event/bind_test.go +++ b/event/bind_test.go @@ -56,10 +56,10 @@ func TestBus_Unbind(t *testing.T) { func TestBind(t *testing.T) { t.Run("SuperficialCoverage", func(t *testing.T) { b := event.NewBus(event.NewCallerMap()) - cid := randomCallerID() - b.GetCallerMap().Register(cid) + var cid event.CallerID + id := b.GetCallerMap().Register(cid) var calls int32 - b1 := event.Bind(b, event.Enter, cid, func(*event.CallerID, event.EnterPayload) event.Response { + b1 := event.Bind(b, event.Enter, id, func(event.CallerID, event.EnterPayload) event.Response { atomic.AddInt32(&calls, 1) return 0 }) @@ -90,31 +90,30 @@ func TestGlobalBind(t *testing.T) { func TestBus_UnbindAllFrom(t *testing.T) { t.Run("Basic", func(t *testing.T) { b := event.NewBus(event.NewCallerMap()) - var cid = new(event.CallerID) - b.GetCallerMap().Register(cid) + var cid event.CallerID + id := b.GetCallerMap().Register(cid) var calls int32 for i := 0; i < 5; i++ { - b1 := event.Bind(b, event.Enter, cid, func(*event.CallerID, event.EnterPayload) event.Response { + b1 := event.Bind(b, event.Enter, id, func(event.CallerID, event.EnterPayload) event.Response { atomic.AddInt32(&calls, 1) return 0 }) <-b1.Bound } - oldID := *cid - b.GetCallerMap().Register(cid) - b1 := event.Bind(b, event.Enter, cid, func(*event.CallerID, event.EnterPayload) event.Response { + id2 := b.GetCallerMap().Register(cid) + b1 := event.Bind(b, event.Enter, id2, func(event.CallerID, event.EnterPayload) event.Response { atomic.AddInt32(&calls, 1) return 0 }) <-b1.Bound <-event.TriggerOn(b, event.Enter, event.EnterPayload{}) if calls != 6 { - t.Fatal(expectedError("calls", 6, calls)) + t.Fatal(expectedError("calls", 1, calls)) } - <-b.UnbindAllFrom(oldID) + <-b.UnbindAllFrom(id) <-event.TriggerOn(b, event.Enter, event.EnterPayload{}) if calls != 7 { - t.Fatal(expectedError("calls", 7, calls)) + t.Fatal(expectedError("calls", 1, calls)) } }) } diff --git a/event/bus_test.go b/event/bus_test.go index 7993ba69..8b227ddb 100644 --- a/event/bus_test.go +++ b/event/bus_test.go @@ -1,6 +1,7 @@ package event_test import ( + "math/rand" "sync/atomic" "testing" "time" @@ -27,11 +28,11 @@ func TestBus_SetCallerMap(t *testing.T) { t.Run("Basic", func(t *testing.T) { cm1 := event.NewCallerMap() b := event.NewBus(cm1) - c1 := randomCallerID() + c1 := event.CallerID(rand.Intn(10000)) b.GetCallerMap().Register(c1) cm2 := event.NewCallerMap() b.SetCallerMap(cm2) - if b.GetCallerMap().HasEntity(c1.CID()) { + if b.GetCallerMap().HasEntity(c1) { t.Fatal("event had old entity after changed caller map") } }) diff --git a/event/caller.go b/event/caller.go index 1ab09bcd..880fd1a7 100644 --- a/event/caller.go +++ b/event/caller.go @@ -12,10 +12,6 @@ func (c CallerID) CID() CallerID { return c } -func (c *CallerID) SetCID(c2 CallerID) { - *c = c2 -} - // Global is the CallerID associated with global bindings. A caller must not be assigned // this ID. Global may be used to manually create bindings scoped to no callers, but the GlobalBind function // should be preferred when possible for type safety. @@ -23,7 +19,6 @@ const Global CallerID = 0 type Caller interface { CID() CallerID - SetCID(CallerID) } // A CallerMap tracks CallerID mappings to Entities. @@ -46,7 +41,7 @@ func NewCallerMap() *CallerMap { // NextID finds the next available caller id // and returns it, after adding the given entity to // the caller map. -func (cm *CallerMap) Register(e Caller) { +func (cm *CallerMap) Register(e Caller) CallerID { cm.callersLock.Lock() defer cm.callersLock.Unlock() // Q: Why not use atomic? @@ -63,7 +58,7 @@ func (cm *CallerMap) Register(e Caller) { // Increment before assigning to preserve Global == caller 0 cm.highestID++ cm.callers[cm.highestID] = e - e.SetCID(cm.highestID) + return cm.highestID } // Get returns the entity corresponding to the given ID within diff --git a/event/caller_test.go b/event/caller_test.go index 6ca3b8e6..10707660 100644 --- a/event/caller_test.go +++ b/event/caller_test.go @@ -25,47 +25,42 @@ func TestNewCallerMap(t *testing.T) { }) } -func randomCallerID() *event.CallerID { - c1 := event.CallerID(rand.Intn(10000)) - return &c1 -} - func TestCallerMap_Register(t *testing.T) { t.Run("Basic", func(t *testing.T) { m := event.NewCallerMap() - c1 := randomCallerID() - m.Register(c1) - c2 := m.GetEntity(c1.CID()) + c1 := event.CallerID(rand.Intn(10000)) + id := m.Register(c1) + c2 := m.GetEntity(id) if c2 != c1 { t.Fatalf("unable to retrieve registered caller") } - if !m.HasEntity(c1.CID()) { + if !m.HasEntity(id) { t.Fatalf("caller map does not have registered caller") } }) t.Run("Remove", func(t *testing.T) { m := event.NewCallerMap() - c1 := randomCallerID() - m.Register(c1) - m.RemoveEntity(c1.CID()) - c3 := m.GetEntity(c1.CID()) + c1 := event.CallerID(rand.Intn(10000)) + id := m.Register(c1) + m.RemoveEntity(id) + c3 := m.GetEntity(id) if c3 != nil { t.Fatalf("get entity had registered caller after remove") } - if m.HasEntity(c1.CID()) { + if m.HasEntity(id) { t.Fatalf("caller map has registered caller after remove") } }) t.Run("Clear", func(t *testing.T) { m := event.NewCallerMap() - c1 := randomCallerID() - m.Register(c1) + c1 := event.CallerID(rand.Intn(10000)) + id := m.Register(c1) m.Clear() - c3 := m.GetEntity(c1.CID()) + c3 := m.GetEntity(id) if c3 != nil { t.Fatalf("get entity had registered caller after clear") } - if m.HasEntity(c1.CID()) { + if m.HasEntity(id) { t.Fatalf("caller map has registered caller after clear") } }) diff --git a/event/trigger_test.go b/event/trigger_test.go index 4b52c2ff..2dae1215 100644 --- a/event/trigger_test.go +++ b/event/trigger_test.go @@ -87,14 +87,14 @@ func TestBus_TriggerForCaller(t *testing.T) { }) t.Run("WithValidCallerID", func(t *testing.T) { b := event.NewBus(event.NewCallerMap()) - var cid = new(event.CallerID) - b.GetCallerMap().Register(cid) + var cid event.CallerID + callerID := b.GetCallerMap().Register(cid) id := event.UnsafeEventID(rand.Intn(100000)) errs := make(chan error) - binding := b.UnsafeBind(id, *cid, func(ci event.CallerID, h event.Handler, i interface{}) event.Response { + binding := b.UnsafeBind(id, callerID, func(ci event.CallerID, h event.Handler, i interface{}) event.Response { defer close(errs) - if ci != *cid { - errs <- expectedError("callerID", *cid, ci) + if ci != callerID { + errs <- expectedError("callerID", callerID, ci) } if h != b { errs <- expectedError("bus", b, h) @@ -110,7 +110,7 @@ func TestBus_TriggerForCaller(t *testing.T) { t.Fatal("timeout waiting for bind to close channel") case <-binding.Bound: } - ch := b.TriggerForCaller(*cid, id, nil) + ch := b.TriggerForCaller(callerID, id, nil) select { case <-time.After(50 * time.Millisecond): t.Fatal("timeout waiting for trigger to close channel") @@ -195,14 +195,14 @@ func TestBus_Trigger(t *testing.T) { }) t.Run("WithValidCallerID", func(t *testing.T) { b := event.NewBus(event.NewCallerMap()) - var cid = new(event.CallerID) - b.GetCallerMap().Register(cid) + var cid event.CallerID + callerID := b.GetCallerMap().Register(cid) id := event.UnsafeEventID(rand.Intn(100000)) errs := make(chan error) - binding := b.UnsafeBind(id, *cid, func(ci event.CallerID, h event.Handler, i interface{}) event.Response { + binding := b.UnsafeBind(id, event.CallerID(callerID), func(ci event.CallerID, h event.Handler, i interface{}) event.Response { defer close(errs) - if ci != *cid { - errs <- expectedError("callerID", *cid, ci) + if ci != callerID { + errs <- expectedError("callerID", callerID, ci) } if h != b { errs <- expectedError("bus", b, h)