diff --git a/matching/score.go b/matching/score.go index f728c5c2f..8eea1a1bd 100644 --- a/matching/score.go +++ b/matching/score.go @@ -17,7 +17,7 @@ const ( // TODO, RW: remove in the future func ScoreOld(places []Place) float64 { if len(places) == 1 { - return singlePlaceScore(places[0]) + return PlaceScore(places[0]) } distances := calDistances(places) // Haversine distances maxDist := math.Max(0.001, calMaxDistance(distances)) // protect against maximum distance being zero @@ -30,7 +30,7 @@ func ScoreOld(places []Place) float64 { // Score uses constant distance normalisation factor func Score(places []Place, distNorm int) float64 { if len(places) == 1 { - return singlePlaceScore(places[0]) + return PlaceScore(places[0]) } distances := calDistances(places) // Haversine distances avgDistance := stat.Mean(distances, nil) / float64(distNorm) // normalized average distance @@ -39,7 +39,7 @@ func Score(places []Place, distNorm int) float64 { return avgScore - avgDistance } -func singlePlaceScore(place Place) float64 { +func PlaceScore(place Place) float64 { var boostFactor float64 if place.PlacePrice() == 0 { boostFactor = float64(place.Rating()) / MaxPlaceRating @@ -70,7 +70,7 @@ func avgPlacesScore(places []Place) float64 { numPlaces := len(places) placeScores := make([]float64, numPlaces) for k, place := range places { - placeScores[k] = singlePlaceScore(place) + placeScores[k] = PlaceScore(place) } return stat.Mean(placeScores, nil) } diff --git a/matching/score_test.go b/matching/score_test.go new file mode 100644 index 000000000..88fbb9349 --- /dev/null +++ b/matching/score_test.go @@ -0,0 +1,43 @@ +package matching + +import ( + "github.com/weihesdlegend/Vacation-planner/POI" + "testing" +) + +func TestPlaceScore(t *testing.T) { + type args struct { + place Place + } + tests := []struct { + name string + args args + want float64 + }{ + { + name: "test compute place score with zero price should return correct result", + args: args{place: Place{ + Place: &POI.Place{Name: "Local Park", UserRatingsTotal: 99, Rating: 4.0}, + Category: POI.PlaceCategoryVisit, + Price: 0, + }}, + want: 1.6, + }, + { + name: "test compute place score with price at level two should return correct result", + args: args{place: Place{ + Place: &POI.Place{Name: "Uncle Pizza", UserRatingsTotal: 999, Rating: 4.0}, + Category: POI.PlaceCategoryEatery, + Price: 2, + }}, + want: 6.0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := PlaceScore(tt.args.place); got != tt.want { + t.Errorf("PlaceScore() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/planner/solver.go b/planner/solver.go index 4c116514e..013135e66 100644 --- a/planner/solver.go +++ b/planner/solver.go @@ -1,6 +1,7 @@ package planner import ( + "cmp" "container/heap" "context" "errors" @@ -537,6 +538,8 @@ func (s *Solver) generatePlacesForSlots(ctx context.Context, req *PlanningReques if len(placesByPrice) == 0 { return nil, fmt.Errorf("failed to find any place for category %s at slot %s for location %+v", slot.Category, slot.TimeSlot.ToString(), req.Location) } + // sort places by score descending so the solver checks places with higher score first + slices.SortFunc(placesByPrice, func(a, b matching.Place) int { return cmp.Compare(matching.PlaceScore(b), matching.PlaceScore(a)) }) placeClusters = append(placeClusters, placesByPrice) } return placeClusters, nil