diff --git a/changes/15630-fleetctl-saved-query b/changes/15630-fleetctl-saved-query new file mode 100644 index 000000000000..c9ca92f4dfff --- /dev/null +++ b/changes/15630-fleetctl-saved-query @@ -0,0 +1 @@ +Fixed fleetctl issue where it was creating a new query when running a query by name, as opposed to using the existing saved query. \ No newline at end of file diff --git a/cmd/fleetctl/get.go b/cmd/fleetctl/get.go index 1a748aacb56e..fef9c827a32f 100644 --- a/cmd/fleetctl/get.go +++ b/cmd/fleetctl/get.go @@ -343,7 +343,7 @@ func queryToTableRow(query fleet.Query, teamName string) []string { func printInheritedQueriesMsg(client *service.Client, teamID *uint) error { if teamID != nil { - globalQueries, err := client.GetQueries(nil) + globalQueries, err := client.GetQueries(nil, nil) if err != nil { return fmt.Errorf("could not list global queries: %w", err) } @@ -410,7 +410,7 @@ func getQueriesCommand() *cli.Command { // if name wasn't provided, list either all global queries or all team queries... if name == "" { - queries, err := client.GetQueries(teamID) + queries, err := client.GetQueries(teamID, nil) if err != nil { return fmt.Errorf("could not list queries: %w", err) } @@ -559,7 +559,7 @@ func getPacksCommand() *cli.Command { } // Get global queries (teamID==nil), because 2017 packs reference global queries. - queries, err := client.GetQueries(nil) + queries, err := client.GetQueries(nil, nil) if err != nil { return fmt.Errorf("could not list queries: %w", err) } diff --git a/cmd/fleetctl/goquery.go b/cmd/fleetctl/goquery.go index 4a5e4747ae9b..4c8eee406edd 100644 --- a/cmd/fleetctl/goquery.go +++ b/cmd/fleetctl/goquery.go @@ -76,7 +76,7 @@ func (c *goqueryClient) ScheduleQuery(uuid, query string) (string, error) { return "", errors.New("could not lookup host") } - res, err := c.client.LiveQuery(query, []string{}, []string{hostname}) + res, err := c.client.LiveQuery(query, nil, []string{}, []string{hostname}) if err != nil { return "", err } diff --git a/cmd/fleetctl/query.go b/cmd/fleetctl/query.go index 912b9f1661f2..5ea0b31baa4e 100644 --- a/cmd/fleetctl/query.go +++ b/cmd/fleetctl/query.go @@ -97,20 +97,31 @@ func queryCommand() *cli.Command { return errors.New("--query and --query-name must not be provided together") } + var queryID *uint if flQueryName != "" { var teamID *uint if tid := c.Uint(teamFlagName); tid != 0 { teamID = &tid } - q, err := fleet.GetQuerySpec(teamID, flQueryName) - if err != nil { + queries, err := fleet.GetQueries(teamID, &flQueryName) + if err != nil || len(queries) == 0 { return fmt.Errorf("Query '%s' not found", flQueryName) } - flQuery = q.Query - } - - if flQuery == "" { - return errors.New("Query must be specified with --query or --query-name") + // For backwards compatibility with older fleet server, we explicitly find the query in the result array + for _, query := range queries { + if query.Name == flQueryName { + id := query.ID // making an explicit copy of ID + queryID = &id + break + } + } + if queryID == nil { + return fmt.Errorf("Query '%s' not found", flQueryName) + } + } else { + if flQuery == "" { + return errors.New("Query must be specified with --query or --query-name") + } } var output outputWriter @@ -123,7 +134,7 @@ func queryCommand() *cli.Command { hosts := strings.Split(flHosts, ",") labels := strings.Split(flLabels, ",") - res, err := fleet.LiveQuery(flQuery, labels, hosts) + res, err := fleet.LiveQuery(flQuery, queryID, labels, hosts) if err != nil { return err } diff --git a/cmd/fleetctl/query_test.go b/cmd/fleetctl/query_test.go index d57564e39ecc..ce2188ee721c 100644 --- a/cmd/fleetctl/query_test.go +++ b/cmd/fleetctl/query_test.go @@ -17,7 +17,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestLiveQuery(t *testing.T) { +func TestSavedLiveQuery(t *testing.T) { rs := pubsub.NewInmemQueryResults() lq := live_query_mock.New(t) @@ -39,6 +39,15 @@ func TestLiveQuery(t *testing.T) { } } + const queryName = "saved-query" + const queryString = "select 42, * from time" + query := fleet.Query{ + ID: 42, + Name: queryName, + Query: queryString, + Saved: true, + } + ds.HostIDsByNameFunc = func(ctx context.Context, filter fleet.TeamFilter, hostnames []string) ([]uint, error) { return []uint{1234}, nil } @@ -48,9 +57,11 @@ func TestLiveQuery(t *testing.T) { ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { return &fleet.AppConfig{}, nil } - ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) { - query.ID = 42 - return query, nil + ds.ListQueriesFunc = func(ctx context.Context, opt fleet.ListQueryOptions) ([]*fleet.Query, error) { + if opt.MatchQuery == queryName { + return []*fleet.Query{&query}, nil + } + return []*fleet.Query{}, nil } ds.NewDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) { camp.ID = 321 @@ -71,12 +82,12 @@ func TestLiveQuery(t *testing.T) { lq.On("QueriesForHost", uint(1)).Return( map[string]string{ - "42": "select 42, * from time", + "42": queryString, }, nil, ) lq.On("QueryCompletedByHost", "42", 99).Return(nil) - lq.On("RunQuery", "321", "select 42, * from time", []uint{1}).Return(nil) + lq.On("RunQuery", "321", queryString, []uint{1}).Return(nil) ds.DistributedQueryCampaignTargetIDsFunc = func(ctx context.Context, id uint) (targets *fleet.HostTargets, err error) { return &fleet.HostTargets{HostIDs: []uint{99}}, nil @@ -91,7 +102,7 @@ func TestLiveQuery(t *testing.T) { return nil } ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) { - return &fleet.Query{}, nil + return &query, nil } ds.IsSavedQueryFunc = func(ctx context.Context, queryID uint) (bool, error) { return true, nil @@ -138,7 +149,7 @@ func TestLiveQuery(t *testing.T) { expected := `{"host":"somehostname","rows":[{"bing":"fds","host_display_name":"somehostname","host_hostname":"somehostname"}]} ` - assert.Equal(t, expected, runAppForTest(t, []string{"query", "--hosts", "1234", "--query", "select 42, * from time"})) + assert.Equal(t, expected, runAppForTest(t, []string{"query", "--hosts", "1234", "--query-name", "saved-query"})) // We need to use waitGroups to detect whether Database functions were called because this is an asynchronous test which will flag data races otherwise. c := make(chan struct{}) @@ -157,3 +168,119 @@ func TestLiveQuery(t *testing.T) { case <-c: // All good } } + +func TestAdHocLiveQuery(t *testing.T) { + rs := pubsub.NewInmemQueryResults() + lq := live_query_mock.New(t) + + logger := kitlog.NewJSONLogger(os.Stdout) + logger = level.NewFilter(logger, level.AllowDebug()) + + _, ds := runServerWithMockedDS( + t, &service.TestServerOpts{ + Rs: rs, + Lq: lq, + Logger: logger, + }, + ) + + users, err := ds.ListUsersFunc(context.Background(), fleet.UserListOptions{}) + require.NoError(t, err) + var admin *fleet.User + for _, user := range users { + if user.GlobalRole != nil && *user.GlobalRole == fleet.RoleAdmin { + admin = user + } + } + + ds.HostIDsByNameFunc = func(ctx context.Context, filter fleet.TeamFilter, hostnames []string) ([]uint, error) { + return []uint{1234}, nil + } + ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) ([]uint, error) { + return nil, nil + } + ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { + return &fleet.AppConfig{}, nil + } + ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) { + query.ID = 42 + return query, nil + } + ds.NewDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) ( + *fleet.DistributedQueryCampaign, error, + ) { + camp.ID = 321 + return camp, nil + } + ds.NewDistributedQueryCampaignTargetFunc = func( + ctx context.Context, target *fleet.DistributedQueryCampaignTarget, + ) (*fleet.DistributedQueryCampaignTarget, error) { + return target, nil + } + ds.HostIDsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets) ([]uint, error) { + return []uint{1}, nil + } + ds.CountHostsInTargetsFunc = func( + ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets, now time.Time, + ) (fleet.TargetMetrics, error) { + return fleet.TargetMetrics{TotalHosts: 1, OnlineHosts: 1}, nil + } + ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error { + return nil + } + + lq.On("QueriesForHost", uint(1)).Return( + map[string]string{ + "42": "select 42, * from time", + }, + nil, + ) + lq.On("QueryCompletedByHost", "42", 99).Return(nil) + lq.On("RunQuery", "321", "select 42, * from time", []uint{1}).Return(nil) + + ds.DistributedQueryCampaignTargetIDsFunc = func(ctx context.Context, id uint) (targets *fleet.HostTargets, err error) { + return &fleet.HostTargets{HostIDs: []uint{99}}, nil + } + ds.DistributedQueryCampaignFunc = func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error) { + return &fleet.DistributedQueryCampaign{ + ID: 321, + UserID: admin.ID, + }, nil + } + ds.SaveDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) error { + return nil + } + ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) { + return &fleet.Query{}, nil + } + ds.IsSavedQueryFunc = func(ctx context.Context, queryID uint) (bool, error) { + return false, nil + } + + go func() { + time.Sleep(2 * time.Second) + require.NoError( + t, rs.WriteResult( + fleet.DistributedQueryResult{ + DistributedQueryCampaignID: 321, + Rows: []map[string]string{{"bing": "fds"}}, + Host: fleet.ResultHostData{ + ID: 99, + Hostname: "somehostname", + DisplayName: "somehostname", + }, + Stats: &fleet.Stats{ + WallTimeMs: 10, + UserTime: 20, + SystemTime: 30, + Memory: 40, + }, + }, + ), + ) + }() + + expected := `{"host":"somehostname","rows":[{"bing":"fds","host_display_name":"somehostname","host_hostname":"somehostname"}]} +` + assert.Equal(t, expected, runAppForTest(t, []string{"query", "--hosts", "1234", "--query", "select 42, * from time"})) +} diff --git a/cmd/fleetctl/upgrade_packs.go b/cmd/fleetctl/upgrade_packs.go index b3a62f858834..0533960d8f2f 100644 --- a/cmd/fleetctl/upgrade_packs.go +++ b/cmd/fleetctl/upgrade_packs.go @@ -80,7 +80,7 @@ func upgradePacksCommand() *cli.Command { } // get global queries (teamID==nil), because 2017 packs reference global queries. - queries, err := client.GetQueries(nil) + queries, err := client.GetQueries(nil, nil) if err != nil { return fmt.Errorf("could not list queries: %w", err) } diff --git a/server/datastore/mysql/queries.go b/server/datastore/mysql/queries.go index a399cffdbbe6..e4a279cd82a3 100644 --- a/server/datastore/mysql/queries.go +++ b/server/datastore/mysql/queries.go @@ -504,6 +504,11 @@ func (ds *Datastore) ListQueries(ctx context.Context, opt fleet.ListQueryOptions } } + if opt.MatchQuery != "" { + whereClauses += " AND q.name = ?" + args = append(args, opt.MatchQuery) + } + sql += whereClauses sql, args = appendListOptionsWithCursorToSQL(sql, args, &opt.ListOptions) diff --git a/server/datastore/mysql/queries_test.go b/server/datastore/mysql/queries_test.go index ed748ca84987..383752b87d58 100644 --- a/server/datastore/mysql/queries_test.go +++ b/server/datastore/mysql/queries_test.go @@ -7,6 +7,7 @@ import ( "math/rand" "sort" "testing" + "time" "github.com/fleetdm/fleet/v4/server/fleet" "github.com/fleetdm/fleet/v4/server/ptr" @@ -193,8 +194,9 @@ func testQueriesDelete(t *testing.T, ds *Datastore) { require.True(t, fleet.IsNotFound(err)) // Ensure stats were deleted. - // The actual delete occurs asynchronously, but enough time should have passed - // given the above DB access to ensure the original query completed. + // The actual delete occurs asynchronously, so enough time should have passed + // to ensure the original query completed. + time.Sleep(10 * time.Millisecond) stats, err := ds.GetLiveQueryStats(context.Background(), query.ID, []uint{hostID}) require.NoError(t, err) require.Equal(t, 0, len(stats)) @@ -278,8 +280,9 @@ func testQueriesDeleteMany(t *testing.T, ds *Datastore) { require.Nil(t, err) assert.Len(t, queries, 2) // Ensure stats were deleted. - // The actual delete occurs asynchronously, but enough time should have passed - // given the above DB access to ensure the original query completed. + // The actual delete occurs asynchronously, so enough time should have passed + // to ensure the original query completed. + time.Sleep(10 * time.Millisecond) stats, err := ds.GetLiveQueryStats(context.Background(), q1.ID, hostIDs) require.NoError(t, err) require.Equal(t, 0, len(stats)) diff --git a/server/service/client_live_query.go b/server/service/client_live_query.go index 38efa75455fa..eb0a69197ea9 100644 --- a/server/service/client_live_query.go +++ b/server/service/client_live_query.go @@ -62,12 +62,15 @@ func (h *LiveQueryResultsHandler) Status() *campaignStatus { } // LiveQuery creates a new live query and begins streaming results. -func (c *Client) LiveQuery(query string, labels []string, hosts []string) (*LiveQueryResultsHandler, error) { - return c.LiveQueryWithContext(context.Background(), query, labels, hosts) +func (c *Client) LiveQuery(query string, queryID *uint, labels []string, hosts []string) (*LiveQueryResultsHandler, error) { + return c.LiveQueryWithContext(context.Background(), query, queryID, labels, hosts) } -func (c *Client) LiveQueryWithContext(ctx context.Context, query string, labels []string, hosts []string) (*LiveQueryResultsHandler, error) { +func (c *Client) LiveQueryWithContext( + ctx context.Context, query string, queryID *uint, labels []string, hosts []string, +) (*LiveQueryResultsHandler, error) { req := createDistributedQueryCampaignByNamesRequest{ + QueryID: queryID, QuerySQL: query, Selected: distributedQueryCampaignTargetsByNames{Labels: labels, Hosts: hosts}, } diff --git a/server/service/client_live_query_test.go b/server/service/client_live_query_test.go index 3d3e79430e42..f7933fe1cdc3 100644 --- a/server/service/client_live_query_test.go +++ b/server/service/client_live_query_test.go @@ -99,7 +99,7 @@ func TestLiveQueryWithContext(t *testing.T) { ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second) defer cancelFunc() - res, err := client.LiveQueryWithContext(ctx, "select 1;", nil, []string{"host1"}) + res, err := client.LiveQueryWithContext(ctx, "select 1;", nil, nil, []string{"host1"}) require.NoError(t, err) gotResults := false diff --git a/server/service/client_queries.go b/server/service/client_queries.go index 00c27ee80d57..0d0f8ce2f9a2 100644 --- a/server/service/client_queries.go +++ b/server/service/client_queries.go @@ -29,12 +29,15 @@ func (c *Client) GetQuerySpec(teamID *uint, name string) (*fleet.QuerySpec, erro } // GetQueries retrieves the list of all Queries. -func (c *Client) GetQueries(teamID *uint) ([]fleet.Query, error) { +func (c *Client) GetQueries(teamID *uint, name *string) ([]fleet.Query, error) { verb, path := "GET", "/api/latest/fleet/queries" query := url.Values{} if teamID != nil { query.Set("team_id", fmt.Sprint(*teamID)) } + if name != nil { + query.Set("query", *name) + } var responseBody listQueriesResponse err := c.authenticatedRequestWithQuery(nil, verb, path, &responseBody, query.Encode()) if err != nil { diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 8e63bdcd85a0..ea25d60f075e 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -2868,6 +2868,11 @@ func (s *integrationTestSuite) TestScheduledQueries() { require.Len(t, listQryResp.Queries, 1) assert.Equal(t, query.Name, listQryResp.Queries[0].Name) + // Return that query by name + s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/queries?query=%s", query.Name), nil, http.StatusOK, &listQryResp) + require.Len(t, listQryResp.Queries, 1) + assert.Equal(t, query.Name, listQryResp.Queries[0].Name) + // next page returns nothing s.DoJSON("GET", "/api/latest/fleet/queries", nil, http.StatusOK, &listQryResp, "per_page", "2", "page", "1") require.Len(t, listQryResp.Queries, 0)