diff --git a/internal/core/algorithm/ngt/ngt_test.go b/internal/core/algorithm/ngt/ngt_test.go index 0186f2abed..82cffb9464 100644 --- a/internal/core/algorithm/ngt/ngt_test.go +++ b/internal/core/algorithm/ngt/ngt_test.go @@ -4765,3 +4765,121 @@ func Test_ngt_Close(t *testing.T) { }) } } + +func Test_ngt_Property(t *testing.T) { + type fields struct { + dimension int + objectType objectType + distanceType distanceType + } + type want struct { + want *Property + err error + } + type test struct { + name string + fields fields + want want + createFunc func(t *testing.T, fields fields) (NGT, error) + checkFunc func(want, *Property, error) error + beforeFunc func() + afterFunc func(*testing.T, NGT) error + } + defaultCheckFunc := func(w want, prop *Property, err error) error { + if !errors.Is(err, w.err) { + return errors.Errorf("got_error: \"%#v\",\n\t\t\t\twant: \"%#v\"", err, w.err) + } + + return nil + } + defaultCreateFunc := func(t *testing.T, fields fields) (NGT, error) { + t.Helper() + + return New( + WithObjectType(fields.objectType), + WithDimension(fields.dimension), + WithDistanceType(fields.distanceType), + ) + } + tests := []test{ + { + name: "get ngt property", + fields: fields{ + dimension: 9, + objectType: Float, + distanceType: L2, + }, + want: want{ + want: &Property{ + Dimension: 9, + ObjectType: Float, + DistanceType: L2, + IndexType: GraphAndTree, + DatabaseType: Memory, + GraphType: ANNG, + }, + err: nil, + }, + checkFunc: func(w want, p *Property, err error) error { + if err := defaultCheckFunc(w, p, err); err != nil { + return err + } + if p.Dimension != w.want.Dimension { + return errors.Errorf("got_dimension: \"%d\", want_dimension: \"%d\"", p.Dimension, w.want.Dimension) + } + if p.ObjectType != w.want.ObjectType { + return errors.Errorf("got_object_type: \"%v\", want_object_type: \"%v\"", p.ObjectType, w.want.ObjectType) + } + if p.DistanceType != w.want.DistanceType { + return errors.Errorf("got_distance_type: \"%v\", want_distance_type: \"%v\"", p.DistanceType, w.want.DistanceType) + } + if p.IndexType != w.want.IndexType { + return errors.Errorf("got_index_type: \"%v\", want_index_type: \"%v\"", p.IndexType, w.want.IndexType) + } + if p.DatabaseType != w.want.DatabaseType { + return errors.Errorf("got_database_type: \"%v\", want_database_type: \"%v\"", p.DatabaseType, w.want.DatabaseType) + } + if p.GraphType != w.want.GraphType { + return errors.Errorf("got_graph_type: \"%v\", want_graph_type: \"%v\"", p.GraphType, w.want.GraphType) + } + return nil + }, + }, + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + defer goleak.VerifyNone(tt, goleak.IgnoreCurrent()) + if test.beforeFunc != nil { + test.beforeFunc() + } + if test.afterFunc == nil { + test.afterFunc = defaultAfterFunc + } + checkFunc := test.checkFunc + if test.checkFunc == nil { + checkFunc = defaultCheckFunc + } + if test.createFunc == nil { + test.createFunc = defaultCreateFunc + } + + n, err := test.createFunc(tt, test.fields) + if err != nil { + tt.Fatal(err) + } + defer func() { + if err := test.afterFunc(tt, n); err != nil { + tt.Error(err) + } + }() + + prop, err := n.GetProperty() + if err := checkFunc(test.want, prop, err); err != nil { + tt.Errorf("error = %v", err) + } + }) + } +}