diff --git a/api/config_server.go b/api/config_server.go index 86a298b..ba56dce 100644 --- a/api/config_server.go +++ b/api/config_server.go @@ -64,6 +64,10 @@ func (s *server) Delete(_ context.Context, sel *pb.Selector) (*pb.DeleteRecordsR } } + if err := s.config.save(s.ConfigProvider); err != nil { + return nil, status.Error(codes.Internal, err.Error()) + } + return &pb.DeleteRecordsResponse{}, nil } diff --git a/api/config_server_test.go b/api/config_server_test.go index d685e6e..c73203d 100644 --- a/api/config_server_test.go +++ b/api/config_server_test.go @@ -13,56 +13,75 @@ import ( pb "github.com/pomerium/cli/proto" ) -func TestLoadSave(t *testing.T) { +func TestUpsertLoadSave(t *testing.T) { ctx := context.Background() - - opt := api.WithConfigProvider(new(api.MemCP)) - cfg, err := api.NewServer(ctx, opt) - require.NoError(t, err, "load empty config") + provider := api.WithConfigProvider(new(api.MemCP)) var ids []string - for _, r := range []*pb.Record{ - { - Tags: []string{"one"}, - Conn: &pb.Connection{ - Name: proto.String("test one"), - RemoteAddr: "test1.another.domain.com", - ListenAddr: proto.String(":9993"), + t.Run("upsert", func(t *testing.T) { + cfg, err := api.NewServer(ctx, provider) + require.NoError(t, err, "load empty config") + for _, r := range []*pb.Record{ + { + Tags: []string{"one"}, + Conn: &pb.Connection{ + Name: proto.String("test one"), + RemoteAddr: "test1.another.domain.com", + ListenAddr: proto.String(":9993"), + }, }, - }, - { - Tags: []string{"one", "two"}, - Conn: &pb.Connection{ - Name: proto.String("test two"), - RemoteAddr: "test2.some.domain.com", - ListenAddr: proto.String(":9991"), + { + Tags: []string{"one", "two"}, + Conn: &pb.Connection{ + Name: proto.String("test two"), + RemoteAddr: "test2.some.domain.com", + ListenAddr: proto.String(":9991"), + }, }, - }, - } { - r, err := cfg.Upsert(ctx, r) - if assert.NoError(t, err) { - assert.NotNil(t, r.Id) - ids = append(ids, r.GetId()) + } { + r, err := cfg.Upsert(ctx, r) + if assert.NoError(t, err) { + assert.NotNil(t, r.Id) + ids = append(ids, r.GetId()) + } } - } + }) - cfg, err = api.NewServer(ctx, opt) - require.NoError(t, err, "load config") + t.Run("load", func(t *testing.T) { + cfg, err := api.NewServer(ctx, provider) + require.NoError(t, err, "load config") - selectors := map[string]*pb.Selector{ - "all": { - All: true, - }, "ids": { - Ids: ids, - }, "tags": { - Tags: []string{"one"}, - }} - for label, s := range selectors { - recs, err := cfg.List(ctx, s) - if assert.NoError(t, err, label) && assert.NotNil(t, recs, label) { - assert.Len(t, recs.Records, len(ids), label) + selectors := map[string]*pb.Selector{ + "all": { + All: true, + }, "ids": { + Ids: ids, + }, "tags": { + Tags: []string{"one"}, + }} + for label, s := range selectors { + recs, err := cfg.List(ctx, s) + if assert.NoError(t, err, label) && assert.NotNil(t, recs, label) { + assert.Len(t, recs.Records, len(ids), label) + } } - } + }) + + t.Run("delete", func(t *testing.T) { + cfg, err := api.NewServer(ctx, provider) + require.NoError(t, err, "load config") + + _, err = cfg.Delete(ctx, &pb.Selector{All: true}) + assert.NoError(t, err) + }) + + t.Run("load", func(t *testing.T) { + cfg, err := api.NewServer(ctx, provider) + require.NoError(t, err, "load config") + recs, err := cfg.List(ctx, &pb.Selector{All: true}) + require.NoError(t, err) + assert.Empty(t, recs.Records) + }) } func TestCertInfo(t *testing.T) {