From ba8f8d09f30fb2e277eae069335135a5e565c1ba Mon Sep 17 00:00:00 2001 From: Tim Wang Date: Mon, 23 Oct 2023 15:24:56 -0700 Subject: [PATCH] Refactor: add MatcherName to Matcher interface --- matching/matcher.go | 13 ++++++++++ planner/solver.go | 54 ++++++++++++++++-------------------------- planner/solver_test.go | 29 ++++++++--------------- 3 files changed, 43 insertions(+), 53 deletions(-) diff --git a/matching/matcher.go b/matching/matcher.go index 602b65375..61bc75212 100644 --- a/matching/matcher.go +++ b/matching/matcher.go @@ -10,6 +10,7 @@ import ( type Matcher interface { Match(req *FilterRequest) (places []Place, err error) + MatcherName() string } // FilterCriteria is an enum for various points of interest filtering criteria @@ -38,6 +39,10 @@ type FilterRequest struct { type MatcherForPriceRange struct { } +func (matcher MatcherForPriceRange) MatcherName() string { + return "Matcher for Price Range" +} + func (matcher MatcherForPriceRange) Match(req *FilterRequest) ([]Place, error) { filterParams := req.Params[FilterByPriceRange] @@ -90,6 +95,10 @@ func NearbySearchForCategory(ctx context.Context, searcher iowrappers.SearchClie return results, nil } +func (m MatcherForTime) MatcherName() string { + return "Matcher for Time" +} + func (m MatcherForTime) Match(req *FilterRequest) ([]Place, error) { var results []Place filterParams := req.Params[FilterByTimePeriod] @@ -134,6 +143,10 @@ type UserRatingFilterParams struct { type MatcherForUserRatings struct { } +func (m MatcherForUserRatings) MatcherName() string { + return "Matcher for User Ratings" +} + func (m MatcherForUserRatings) Match(req *FilterRequest) ([]Place, error) { var results []Place filterParams := req.Params[FilterByUserRating] diff --git a/planner/solver.go b/planner/solver.go index ae7ccc41f..2c27bae30 100644 --- a/planner/solver.go +++ b/planner/solver.go @@ -48,9 +48,7 @@ func (ps PlanningSolution) Key() float64 { type Solver struct { Searcher *iowrappers.PoiSearcher placeMatcher *PlaceMatcher - timeMatcher *matching.MatcherForTime - priceMatcher *matching.MatcherForPriceRange - userRatingsMatcher *matching.MatcherForUserRatings + concreteMatchers []matching.Matcher placeDedupeCountLimit int nearbyCitiesCountLimit int } @@ -111,10 +109,11 @@ func (s *Solver) Init(poiSearcher *iowrappers.PoiSearcher, placeDedupeCountLimit s.Searcher = poiSearcher s.placeDedupeCountLimit = placeDedupeCountLimit s.nearbyCitiesCountLimit = nearbyCitiesCountLimit - s.timeMatcher = &matching.MatcherForTime{} - s.priceMatcher = &matching.MatcherForPriceRange{} - s.userRatingsMatcher = &matching.MatcherForUserRatings{} - s.placeMatcher = NewMatcher(s.timeMatcher) + s.concreteMatchers = make([]matching.Matcher, 0) + s.concreteMatchers = append(s.concreteMatchers, &matching.MatcherForUserRatings{}) + s.concreteMatchers = append(s.concreteMatchers, &matching.MatcherForTime{}) + s.concreteMatchers = append(s.concreteMatchers, &matching.MatcherForPriceRange{}) + s.placeMatcher = NewPlaceMatcher() } func (s *Solver) ValidateLocation(ctx context.Context, location *POI.Location) bool { @@ -486,8 +485,8 @@ type PlaceMatcher struct { m matching.Matcher } -func NewMatcher(matcher matching.Matcher) *PlaceMatcher { - return &PlaceMatcher{m: matcher} +func NewPlaceMatcher() *PlaceMatcher { + return &PlaceMatcher{} } func (pm *PlaceMatcher) setMatcher(matcher matching.Matcher) { @@ -549,33 +548,20 @@ func (s *Solver) generatePlacesForSlots(ctx context.Context, req *PlanningReques func (s *Solver) filterPlaces(places []matching.Place, params map[matching.FilterCriteria]interface{}, c POI.PlaceCategory) ([]matching.Place, error) { logger := iowrappers.Logger - placesByRating, err := s.placeMatcher.MatchPlaces(&matching.FilterRequest{ - Places: places, - Params: params, - }, s.userRatingsMatcher) - if err != nil { - return nil, err - } - logger.Debugf("Filtered by zero user rating count, the number of places for category %s is %d", c, len(placesByRating)) - - placesByTime, err := s.placeMatcher.MatchPlaces(&matching.FilterRequest{ - Places: placesByRating, - Params: params, - }, s.timeMatcher) - if err != nil { - return nil, err + var res = places + var err error + for _, m := range s.concreteMatchers { + res, err = s.placeMatcher.MatchPlaces(&matching.FilterRequest{ + Places: res, + Params: params, + }, m) + if err != nil { + return nil, err + } + logger.Debugf("Filtered by %s, the number of places for category %s is %d", m.MatcherName(), c, len(res)) } - logger.Debugf("Filtered by time, the number of places for category %s is %d", c, len(placesByTime)) - placesByPrice, err := s.placeMatcher.MatchPlaces(&matching.FilterRequest{ - Places: placesByTime, - Params: params, - }, s.priceMatcher) - if err != nil { - return nil, err - } - logger.Debugf("Filtered by price range, the number of places for category %s is %d", c, len(placesByPrice)) - return placesByPrice, nil + return res, nil } func (s *Solver) generateSolutions(ctx context.Context, req *PlanningRequest) (resp *PlanningResp) { diff --git a/planner/solver_test.go b/planner/solver_test.go index 74788a950..ce02a9e55 100644 --- a/planner/solver_test.go +++ b/planner/solver_test.go @@ -11,18 +11,17 @@ import ( "testing" ) -func TestSolver_filterPlaces(t *testing.T) { +func TestSolver_filterPlaces1(t *testing.T) { type fields struct { - Searcher *iowrappers.PoiSearcher - placeMatcher *PlaceMatcher - timeMatcher *matching.MatcherForTime - priceMatcher *matching.MatcherForPriceRange - userRatingsMatcher *matching.MatcherForUserRatings + Searcher *iowrappers.PoiSearcher + placeMatcher *PlaceMatcher + placeDedupeCountLimit int + nearbyCitiesCountLimit int } type args struct { - c POI.PlaceCategory places []matching.Place params map[matching.FilterCriteria]interface{} + c POI.PlaceCategory } tests := []struct { name string @@ -34,11 +33,8 @@ func TestSolver_filterPlaces(t *testing.T) { { name: "test filter places should return correct results", fields: fields{ - Searcher: &iowrappers.PoiSearcher{}, - placeMatcher: &PlaceMatcher{}, - timeMatcher: &matching.MatcherForTime{}, - priceMatcher: &matching.MatcherForPriceRange{}, - userRatingsMatcher: &matching.MatcherForUserRatings{}, + Searcher: &iowrappers.PoiSearcher{}, + placeMatcher: &PlaceMatcher{}, }, args: args{ c: POI.PlaceCategoryEatery, @@ -101,13 +97,8 @@ func TestSolver_filterPlaces(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := &Solver{ - Searcher: tt.fields.Searcher, - placeMatcher: tt.fields.placeMatcher, - timeMatcher: tt.fields.timeMatcher, - priceMatcher: tt.fields.priceMatcher, - userRatingsMatcher: tt.fields.userRatingsMatcher, - } + s := &Solver{} + s.Init(tt.fields.Searcher, tt.fields.placeDedupeCountLimit, tt.fields.nearbyCitiesCountLimit) got, err := s.filterPlaces(tt.args.places, tt.args.params, tt.args.c) if (err != nil) != tt.wantErr { t.Errorf("filterPlaces() error = %v, wantErr %v", err, tt.wantErr)