Skip to content

Commit

Permalink
feat: add a lock for grid cache
Browse files Browse the repository at this point in the history
  • Loading branch information
scorix committed Dec 3, 2024
1 parent 0e3b451 commit 4f5d750
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 9 deletions.
5 changes: 5 additions & 0 deletions pkg/geo/grids/gaussian/regular.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"cmp"
"fmt"
"math"
"sync"

"github.com/scorix/walg/pkg/geo/distance"
"github.com/scorix/walg/pkg/geo/grids"
Expand All @@ -18,11 +19,15 @@ type regular struct {

var regularCache = make(map[int]*regular)
var regularCacheGroup singleflight.Group
var regularCacheLock sync.Mutex

func NewRegular(n int) *regular {
name := fmt.Sprintf("F%d", n)

r, _, _ := regularCacheGroup.Do(name, func() (any, error) {
regularCacheLock.Lock()
defer regularCacheLock.Unlock()

if cached, ok := regularCache[n]; ok {
return cached, nil
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/geo/grids/grid.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ func GridIndex(g Grid, lat, lon float64, mode ScanMode) int {
return GridIndexFromIndices(g, latIdx, lonIdx, mode)
}

func GridPoint(g Grid, index int, mode ScanMode) (lat, lon float64) {
func GridPoint(g Grid, index int, mode ScanMode) (lat, lon float64, ok bool) {
if index < 0 || index >= g.Size() {
return math.NaN(), math.NaN()
return math.NaN(), math.NaN(), false
}

latitudesSize := len(g.Latitudes())
Expand Down Expand Up @@ -56,7 +56,7 @@ func GridPoint(g Grid, index int, mode ScanMode) (lat, lon float64) {
lonIdx = longitudesSize - 1 - lonIdx
}

return g.Latitudes()[latIdx], g.Longitudes()[lonIdx]
return g.Latitudes()[latIdx], g.Longitudes()[lonIdx], true
}

func GridIndexFromIndices(g Grid, latIdx, lonIdx int, mode ScanMode) int {
Expand Down
15 changes: 11 additions & 4 deletions pkg/geo/grids/grid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,27 @@ func runGridTests(t *testing.T, grid grids.Grid, tests []gridTestCase, mode grid
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
idx := grids.GridIndex(grid, tt.lat, tt.lon, mode)
recoveredLat, recoveredLon := grids.GridPoint(grid, idx, mode)
recoveredLat, recoveredLon, ok := grids.GridPoint(grid, idx, mode)
require.True(t, ok)

actualDist := distance.VincentyIterations(tt.lat, tt.lon, recoveredLat, recoveredLon, iterations)
t.Logf("Actual #%d index: (%.3f, %.3f), dist: %fkm from (%.3f, %.3f)",
idx, recoveredLat, recoveredLon, actualDist, tt.lat, tt.lon)

nearestIdxs := grids.NewNearestGrids(grid).NearestGrids(tt.lat, tt.lon, mode)
for i, nearIdx := range nearestIdxs {
nearLat, nearLon := grids.GridPoint(grid, nearIdx, mode)
nearLat, nearLon, ok := grids.GridPoint(grid, nearIdx, mode)
require.True(t, ok)

dist := distance.VincentyIterations(tt.lat, tt.lon, nearLat, nearLon, iterations)
t.Logf("Nearest %d index #%d: (%.3f, %.3f), dist: %fkm from (%.3f, %.3f)",
i, nearIdx, nearLat, nearLon, dist, tt.lat, tt.lon)
assert.GreaterOrEqual(t, dist, actualDist)
}

expectedLat, expectedLon := grids.GridPoint(grid, tt.expectedIdx, mode)
expectedLat, expectedLon, ok := grids.GridPoint(grid, tt.expectedIdx, mode)
require.True(t, ok)

expectedDist := distance.VincentyIterations(tt.lat, tt.lon, expectedLat, expectedLon, iterations)
t.Logf("Expected #%d index: (%.3f, %.3f), dist: %f from (%.3f, %.3f)",
tt.expectedIdx, expectedLat, expectedLon, expectedDist, tt.lat, tt.lon)
Expand All @@ -64,7 +69,9 @@ func runGridTests(t *testing.T, grid grids.Grid, tests []gridTestCase, mode grid

// guess
guessIdx := grids.GuessGridIndex(grid, tt.lat, tt.lon, mode)
guessLat, guessLon := grids.GridPoint(grid, guessIdx, mode)
guessLat, guessLon, ok := grids.GridPoint(grid, guessIdx, mode)
require.True(t, ok)

guessDist := distance.VincentyIterations(tt.lat, tt.lon, guessLat, guessLon, iterations)
t.Logf("Guess index #%d: (%.3f, %.3f), dist: %f from (%.3f, %.3f)",
guessIdx, guessLat, guessLon, guessDist, tt.lat, tt.lon)
Expand Down
5 changes: 5 additions & 0 deletions pkg/geo/grids/latlon/latlon.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"cmp"
"fmt"
"math"
"sync"

"github.com/scorix/walg/pkg/geo/distance"
"github.com/scorix/walg/pkg/geo/grids"
Expand All @@ -27,11 +28,15 @@ type latLon struct {

var latLonCache = make(map[string]*latLon)
var latLonCacheGroup singleflight.Group
var latLonCacheLock sync.Mutex

func NewLatLonGrid(minLat, maxLat, minLon, maxLon, latStep, lonStep float64) *latLon {
name := fmt.Sprintf("L%f,%f,%f,%f,%f,%f", minLat, maxLat, minLon, maxLon, latStep, lonStep)

ll, _, _ := latLonCacheGroup.Do(name, func() (any, error) {
latLonCacheLock.Lock()
defer latLonCacheLock.Unlock()

if cached, ok := latLonCache[name]; ok {
return cached, nil
}
Expand Down
8 changes: 6 additions & 2 deletions pkg/geo/grids/latlon/latlon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/scorix/walg/pkg/geo/grids"
"github.com/scorix/walg/pkg/geo/grids/latlon"
Expand Down Expand Up @@ -149,7 +150,9 @@ func TestGridPoint(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
lat, lon := grids.GridPoint(grid, tt.index, 0)
lat, lon, ok := grids.GridPoint(grid, tt.index, 0)
require.True(t, ok)

assert.Equal(t, tt.expectedLat, lat, "Latitude mismatch for index %d", tt.index)
assert.Equal(t, tt.expectedLon, lon, "Longitude mismatch for index %d", tt.index)
})
Expand All @@ -162,7 +165,8 @@ func TestGridRoundTrip(t *testing.T) {
// Test conversion from lat/lon to index and back
originalLat, originalLon := 32.5, 112.5
index := grids.GridIndex(grid, originalLat, originalLon, 0)
recoveredLat, recoveredLon := grids.GridPoint(grid, index, 0)
recoveredLat, recoveredLon, ok := grids.GridPoint(grid, index, 0)
require.True(t, ok)

assert.Equal(t, originalLat, recoveredLat, "Latitude should remain unchanged")
assert.Equal(t, originalLon, recoveredLon, "Longitude should remain unchanged")
Expand Down

0 comments on commit 4f5d750

Please sign in to comment.