diff --git a/fake/fake_test.go b/fake/fake_test.go index 8f61bf0..08b7641 100644 --- a/fake/fake_test.go +++ b/fake/fake_test.go @@ -228,9 +228,9 @@ func TestPipeline__Do_Finish(t *testing.T) { sess := pipe.LowerSession() calls := 0 - sess.AddNextCall(func() { + sess.AddNextCall(memproxy.NewEmptyCallback(func() { calls++ - }) + })) sess.Execute() assert.Equal(t, 1, calls) diff --git a/item/item.go b/item/item.go index 8049603..c4391c5 100644 --- a/item/item.go +++ b/item/item.go @@ -5,6 +5,7 @@ import ( "errors" "log" "time" + "unsafe" "github.com/QuangTung97/go-memcache/memcache" @@ -223,6 +224,20 @@ type Item[T Value, K Key] struct { stats Stats } +func (i *Item[T, K]) addNextCall(fn func(obj unsafe.Pointer)) { + i.sess.AddNextCall(memproxy.CallbackFunc{ + Object: nil, + Func: fn, + }) +} + +func (i *Item[T, K]) addDelayedCall(d time.Duration, fn func(obj unsafe.Pointer)) { + i.sess.AddDelayedCall(d, memproxy.CallbackFunc{ + Object: nil, + Func: fn, + }) +} + type getResultType[T any] struct { resp T err error @@ -230,7 +245,7 @@ type getResultType[T any] struct { func (s *getState[T, K]) handleLeaseGranted(cas uint64) { fillFn := s.it.filler(s.ctx, s.key) - s.it.sess.AddNextCall(func() { + s.it.addNextCall(func(_ unsafe.Pointer) { fillResp, err := fillFn() if err == ErrNotFound { @@ -253,7 +268,9 @@ func (s *getState[T, K]) handleLeaseGranted(cas uint64) { if cas > 0 { _ = s.it.pipeline.LeaseSet(s.keyStr, data, cas, memproxy.LeaseSetOptions{}) - s.it.sess.AddNextCall(s.it.pipeline.Execute) + s.it.addNextCall(func(obj unsafe.Pointer) { + s.it.pipeline.Execute() + }) } }) } @@ -298,7 +315,7 @@ func (s *getState[T, K]) handleCacheError(err error) { } } -func (s *getState[T, K]) nextFunc() { +func (s *getState[T, K]) nextFunc(_ unsafe.Pointer) { leaseGetResp, err := s.leaseGetResult.Result() s.leaseGetResult = nil @@ -333,11 +350,11 @@ func (s *getState[T, K]) nextFunc() { s.it.increaseRejectedCount(s.retryCount) if s.retryCount < len(s.it.options.sleepDurations) { - s.it.sess.AddDelayedCall(s.it.options.sleepDurations[s.retryCount], func() { + s.it.addDelayedCall(s.it.options.sleepDurations[s.retryCount], func(_ unsafe.Pointer) { s.retryCount++ s.leaseGetResult = s.it.pipeline.LeaseGet(s.keyStr, memproxy.LeaseGetOptions{}) - s.it.sess.AddNextCall(s.nextFunc) + s.it.addNextCall(s.nextFunc) }) return } @@ -383,7 +400,7 @@ func (i *Item[T, K]) Get(ctx context.Context, key K) func() (T, error) { state.leaseGetResult = i.pipeline.LeaseGet(keyStr, memproxy.LeaseGetOptions{}) - i.sess.AddNextCall(state.nextFunc) + i.addNextCall(state.nextFunc) return state.returnFunc } diff --git a/item/item_test.go b/item/item_test.go index e97d95c..2e3dfd7 100644 --- a/item/item_test.go +++ b/item/item_test.go @@ -120,12 +120,12 @@ func newItemTestWithSleepDurations( }, } - var calls []func() + var calls []memproxy.CallbackFunc - sess.AddNextCallFunc = func(fn func()) { + sess.AddNextCallFunc = func(fn memproxy.CallbackFunc) { calls = append(calls, fn) } - sess.AddDelayedCallFunc = func(d time.Duration, fn func()) { + sess.AddDelayedCallFunc = func(d time.Duration, fn memproxy.CallbackFunc) { i.delayCalls = append(i.delayCalls, d) calls = append(calls, fn) } @@ -134,7 +134,7 @@ func newItemTestWithSleepDurations( nextCalls := calls calls = nil for _, fn := range nextCalls { - fn() + fn.Call() } } } diff --git a/memproxy.go b/memproxy.go index 006875c..0ae1e76 100644 --- a/memproxy.go +++ b/memproxy.go @@ -3,6 +3,7 @@ package memproxy import ( "context" "time" + "unsafe" ) // Memcache represents a generic Memcache interface @@ -62,10 +63,31 @@ type SessionProvider interface { New() Session } +// CallbackFunc for session +type CallbackFunc struct { + Object unsafe.Pointer + Func func(obj unsafe.Pointer) +} + +// Call ... +func (f CallbackFunc) Call() { + f.Func(f.Object) +} + +// NewEmptyCallback creates CallbackFunc from empty args function +func NewEmptyCallback(fn func()) CallbackFunc { + return CallbackFunc{ + Object: nil, + Func: func(_ unsafe.Pointer) { + fn() + }, + } +} + // Session controlling session values & delayed tasks, this object is NOT Thread Safe type Session interface { - AddNextCall(fn func()) - AddDelayedCall(d time.Duration, fn func()) + AddNextCall(fn CallbackFunc) + AddDelayedCall(d time.Duration, fn CallbackFunc) Execute() GetLower() Session diff --git a/mhash/mhash.go b/mhash/mhash.go index e341c10..eb3100c 100644 --- a/mhash/mhash.go +++ b/mhash/mhash.go @@ -5,9 +5,10 @@ import ( "encoding/binary" "encoding/hex" "errors" + "math" + "github.com/QuangTung97/memproxy" "github.com/QuangTung97/memproxy/item" - "math" ) // ErrHashTooDeep when too many levels to go to @@ -196,7 +197,7 @@ func (h *Hash[T, R, K]) Get(ctx context.Context, rootKey R, key K) func() (Null[ Level: callCtx.level, Hash: computeHashAtLevel(keyHash, callCtx.level), }) - h.sess.AddNextCall(nextCallFn) + h.sess.AddNextCall(memproxy.NewEmptyCallback(nextCallFn)) } callCtx.doComputeFn = doGetFn diff --git a/mhash/updater.go b/mhash/updater.go index 8dc7800..62ffb9f 100644 --- a/mhash/updater.go +++ b/mhash/updater.go @@ -2,6 +2,7 @@ package mhash import ( "context" + "github.com/QuangTung97/memproxy" "github.com/QuangTung97/memproxy/item" ) @@ -285,9 +286,9 @@ func (u *HashUpdater[T, R, K]) UpsertBucket( if withUpdate { nextCallFn(true) } else { - u.sess.AddNextCall(func() { + u.sess.AddNextCall(memproxy.NewEmptyCallback(func() { nextCallFn(false) - }) + })) } } @@ -367,13 +368,13 @@ func (u *HashUpdater[T, R, K]) UpsertBucket( callCtx.doComputeFn() - u.lowerSession.AddNextCall(func() { + u.lowerSession.AddNextCall(memproxy.NewEmptyCallback(func() { callCtx = callContext{} callCtx.doComputeFn = func() { doComputeWithUpdate(true) } callCtx.doComputeFn() - }) + })) return func() error { u.execute() @@ -404,9 +405,9 @@ func (u *HashUpdater[T, R, K]) DeleteBucket( if withUpdate { nextCallFn(true) } else { - u.sess.AddNextCall(func() { + u.sess.AddNextCall(memproxy.NewEmptyCallback(func() { nextCallFn(false) - }) + })) } } @@ -448,7 +449,7 @@ func (u *HashUpdater[T, R, K]) DeleteBucket( callCtx.doComputeFn() - u.lowerSession.AddNextCall(func() { + u.lowerSession.AddNextCall(memproxy.NewEmptyCallback(func() { // clear state callCtx = callContext{} scannedBuckets = scannedBuckets[:0] @@ -457,7 +458,7 @@ func (u *HashUpdater[T, R, K]) DeleteBucket( doComputeWithUpdate(true) } callCtx.doComputeFn() - }) + })) return func() error { u.execute() diff --git a/mocks/memproxy_mocks.go b/mocks/memproxy_mocks.go index 5525993..bd47a2e 100644 --- a/mocks/memproxy_mocks.go +++ b/mocks/memproxy_mocks.go @@ -483,10 +483,10 @@ var _ Session = &SessionMock{} // // // make and configure a mocked Session // mockedSession := &SessionMock{ -// AddDelayedCallFunc: func(d time.Duration, fn func()) { +// AddDelayedCallFunc: func(d time.Duration, fn memproxy.CallbackFunc) { // panic("mock out the AddDelayedCall method") // }, -// AddNextCallFunc: func(fn func()) { +// AddNextCallFunc: func(fn memproxy.CallbackFunc) { // panic("mock out the AddNextCall method") // }, // ExecuteFunc: func() { @@ -503,10 +503,10 @@ var _ Session = &SessionMock{} // } type SessionMock struct { // AddDelayedCallFunc mocks the AddDelayedCall method. - AddDelayedCallFunc func(d time.Duration, fn func()) + AddDelayedCallFunc func(d time.Duration, fn memproxy.CallbackFunc) // AddNextCallFunc mocks the AddNextCall method. - AddNextCallFunc func(fn func()) + AddNextCallFunc func(fn memproxy.CallbackFunc) // ExecuteFunc mocks the Execute method. ExecuteFunc func() @@ -521,12 +521,12 @@ type SessionMock struct { // D is the d argument value. D time.Duration // Fn is the fn argument value. - Fn func() + Fn memproxy.CallbackFunc } // AddNextCall holds details about calls to the AddNextCall method. AddNextCall []struct { // Fn is the fn argument value. - Fn func() + Fn memproxy.CallbackFunc } // Execute holds details about calls to the Execute method. Execute []struct { @@ -542,13 +542,13 @@ type SessionMock struct { } // AddDelayedCall calls AddDelayedCallFunc. -func (mock *SessionMock) AddDelayedCall(d time.Duration, fn func()) { +func (mock *SessionMock) AddDelayedCall(d time.Duration, fn memproxy.CallbackFunc) { if mock.AddDelayedCallFunc == nil { panic("SessionMock.AddDelayedCallFunc: method is nil but Session.AddDelayedCall was just called") } callInfo := struct { D time.Duration - Fn func() + Fn memproxy.CallbackFunc }{ D: d, Fn: fn, @@ -565,11 +565,11 @@ func (mock *SessionMock) AddDelayedCall(d time.Duration, fn func()) { // len(mockedSession.AddDelayedCallCalls()) func (mock *SessionMock) AddDelayedCallCalls() []struct { D time.Duration - Fn func() + Fn memproxy.CallbackFunc } { var calls []struct { D time.Duration - Fn func() + Fn memproxy.CallbackFunc } mock.lockAddDelayedCall.RLock() calls = mock.calls.AddDelayedCall @@ -578,12 +578,12 @@ func (mock *SessionMock) AddDelayedCallCalls() []struct { } // AddNextCall calls AddNextCallFunc. -func (mock *SessionMock) AddNextCall(fn func()) { +func (mock *SessionMock) AddNextCall(fn memproxy.CallbackFunc) { if mock.AddNextCallFunc == nil { panic("SessionMock.AddNextCallFunc: method is nil but Session.AddNextCall was just called") } callInfo := struct { - Fn func() + Fn memproxy.CallbackFunc }{ Fn: fn, } @@ -598,10 +598,10 @@ func (mock *SessionMock) AddNextCall(fn func()) { // // len(mockedSession.AddNextCallCalls()) func (mock *SessionMock) AddNextCallCalls() []struct { - Fn func() + Fn memproxy.CallbackFunc } { var calls []struct { - Fn func() + Fn memproxy.CallbackFunc } mock.lockAddNextCall.RLock() calls = mock.calls.AddNextCall diff --git a/proxy/proxy.go b/proxy/proxy.go index f60d8b4..3b88d7b 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "unsafe" "github.com/QuangTung97/go-memcache/memcache" @@ -201,6 +202,11 @@ type leaseGetState struct { err error } +func retryOnOtherNodeCallback(obj unsafe.Pointer) { + s := (*leaseGetState)(obj) + s.retryOnOtherNode() +} + func (s *leaseGetState) retryOnOtherNode() { s.pipe.doExecuteForAllServers() @@ -212,6 +218,11 @@ func (s *leaseGetState) retryOnOtherNode() { } } +func leaseGetStateNextFuncCallback(obj unsafe.Pointer) { + s := (*leaseGetState)(obj) + s.nextFunc() +} + func (s *leaseGetState) nextFunc() { s.pipe.doExecuteForAllServers() @@ -229,7 +240,11 @@ func (s *leaseGetState) nextFunc() { pipe := s.pipe.getRoutePipeline(s.serverID) s.fn = pipe.LeaseGet(s.key, s.options) - s.pipe.sess.AddNextCall(s.retryOnOtherNode) + s.pipe.sess.AddNextCall(memproxy.CallbackFunc{ + Object: unsafe.Pointer(s), + Func: retryOnOtherNodeCallback, + }) + return } @@ -266,7 +281,10 @@ func (p *Pipeline) LeaseGet( fn: fn, } - p.sess.AddNextCall(state.nextFunc) + p.sess.AddNextCall(memproxy.CallbackFunc{ + Object: unsafe.Pointer(state), + Func: leaseGetStateNextFuncCallback, + }) return state } diff --git a/session.go b/session.go index d14b32a..7872ac2 100644 --- a/session.go +++ b/session.go @@ -70,7 +70,7 @@ func newSession( type sessionImpl struct { provider *sessionProviderImpl - nextCalls []func() + nextCalls []CallbackFunc heap delayedCallHeap isDirty bool // an optimization @@ -81,7 +81,7 @@ type sessionImpl struct { type delayedCall struct { startedAt time.Time - call func() + call CallbackFunc } var _ Session = &sessionImpl{} @@ -97,16 +97,16 @@ func setDirtyRecursive(s *sessionImpl) { } // AddNextCall ... -func (s *sessionImpl) AddNextCall(fn func()) { +func (s *sessionImpl) AddNextCall(fn CallbackFunc) { setDirtyRecursive(s) if s.nextCalls == nil { - s.nextCalls = make([]func(), 0, 32) + s.nextCalls = make([]CallbackFunc, 0, 32) } s.nextCalls = append(s.nextCalls, fn) } // AddDelayedCall ... -func (s *sessionImpl) AddDelayedCall(d time.Duration, fn func()) { +func (s *sessionImpl) AddDelayedCall(d time.Duration, fn CallbackFunc) { setDirtyRecursive(s) s.heap.push(delayedCall{ startedAt: s.provider.nowFn().Add(d), @@ -149,7 +149,7 @@ func (s *sessionImpl) executeNextCalls() { nextCalls := s.nextCalls s.nextCalls = nil for _, call := range nextCalls { - call() + call.Call() } } } @@ -170,7 +170,7 @@ MainLoop: continue MainLoop } s.heap.pop() - top.call() + top.call.Call() } } } diff --git a/session_test.go b/session_test.go index 0d5f0da..24affc6 100644 --- a/session_test.go +++ b/session_test.go @@ -1,9 +1,10 @@ package memproxy import ( - "github.com/stretchr/testify/assert" "testing" "time" + + "github.com/stretchr/testify/assert" ) type sessionTest struct { @@ -57,11 +58,11 @@ func newCallMock() *callMock { } } -func (m *callMock) get() func() { - return func() { +func (m *callMock) get() CallbackFunc { + return NewEmptyCallback(func() { m.count++ m.fn() - } + }) } func TestSessionAddNextCall(t *testing.T) {