Skip to content

Commit

Permalink
fleetctl now runs saved queries (#15667)
Browse files Browse the repository at this point in the history
📺 Looom:
https://www.loom.com/share/1aec4616fa4449e7abac579084aef0ba?sid=0884f742-feb3-48bb-82dc-b7834bc9a6e1

Fixed fleetctl issue where it was creating a new query when running a
query by name, as opposed to using the existing saved query.
#15630

API change will be in a separate PR:
#15673

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

<!-- Note that API documentation changes are now addressed by the
product design team. -->

- [x] Changes file added for user-visible changes in `changes/` or
`orbit/changes/`.
See [Changes
files](https://fleetdm.com/docs/contributing/committing-changes#changes-files)
for more information.
- [x] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)
- [x] Added/updated tests
- [x] Manual QA for all new/changed functionality
  • Loading branch information
getvictor authored Dec 15, 2023
1 parent 5e3f501 commit 0e040cc
Show file tree
Hide file tree
Showing 12 changed files with 188 additions and 30 deletions.
1 change: 1 addition & 0 deletions changes/15630-fleetctl-saved-query
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed fleetctl issue where it was creating a new query when running a query by name, as opposed to using the existing saved query.
6 changes: 3 additions & 3 deletions cmd/fleetctl/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ func queryToTableRow(query fleet.Query, teamName string) []string {

func printInheritedQueriesMsg(client *service.Client, teamID *uint) error {
if teamID != nil {
globalQueries, err := client.GetQueries(nil)
globalQueries, err := client.GetQueries(nil, nil)
if err != nil {
return fmt.Errorf("could not list global queries: %w", err)
}
Expand Down Expand Up @@ -410,7 +410,7 @@ func getQueriesCommand() *cli.Command {

// if name wasn't provided, list either all global queries or all team queries...
if name == "" {
queries, err := client.GetQueries(teamID)
queries, err := client.GetQueries(teamID, nil)
if err != nil {
return fmt.Errorf("could not list queries: %w", err)
}
Expand Down Expand Up @@ -559,7 +559,7 @@ func getPacksCommand() *cli.Command {
}

// Get global queries (teamID==nil), because 2017 packs reference global queries.
queries, err := client.GetQueries(nil)
queries, err := client.GetQueries(nil, nil)
if err != nil {
return fmt.Errorf("could not list queries: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion cmd/fleetctl/goquery.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (c *goqueryClient) ScheduleQuery(uuid, query string) (string, error) {
return "", errors.New("could not lookup host")
}

res, err := c.client.LiveQuery(query, []string{}, []string{hostname})
res, err := c.client.LiveQuery(query, nil, []string{}, []string{hostname})
if err != nil {
return "", err
}
Expand Down
27 changes: 19 additions & 8 deletions cmd/fleetctl/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,20 +97,31 @@ func queryCommand() *cli.Command {
return errors.New("--query and --query-name must not be provided together")
}

var queryID *uint
if flQueryName != "" {
var teamID *uint
if tid := c.Uint(teamFlagName); tid != 0 {
teamID = &tid
}
q, err := fleet.GetQuerySpec(teamID, flQueryName)
if err != nil {
queries, err := fleet.GetQueries(teamID, &flQueryName)
if err != nil || len(queries) == 0 {
return fmt.Errorf("Query '%s' not found", flQueryName)
}
flQuery = q.Query
}

if flQuery == "" {
return errors.New("Query must be specified with --query or --query-name")
// For backwards compatibility with older fleet server, we explicitly find the query in the result array
for _, query := range queries {
if query.Name == flQueryName {
id := query.ID // making an explicit copy of ID
queryID = &id
break
}
}
if queryID == nil {
return fmt.Errorf("Query '%s' not found", flQueryName)
}
} else {
if flQuery == "" {
return errors.New("Query must be specified with --query or --query-name")
}
}

var output outputWriter
Expand All @@ -123,7 +134,7 @@ func queryCommand() *cli.Command {
hosts := strings.Split(flHosts, ",")
labels := strings.Split(flLabels, ",")

res, err := fleet.LiveQuery(flQuery, labels, hosts)
res, err := fleet.LiveQuery(flQuery, queryID, labels, hosts)
if err != nil {
return err
}
Expand Down
143 changes: 135 additions & 8 deletions cmd/fleetctl/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
"github.com/stretchr/testify/require"
)

func TestLiveQuery(t *testing.T) {
func TestSavedLiveQuery(t *testing.T) {
rs := pubsub.NewInmemQueryResults()
lq := live_query_mock.New(t)

Expand All @@ -39,6 +39,15 @@ func TestLiveQuery(t *testing.T) {
}
}

const queryName = "saved-query"
const queryString = "select 42, * from time"
query := fleet.Query{
ID: 42,
Name: queryName,
Query: queryString,
Saved: true,
}

ds.HostIDsByNameFunc = func(ctx context.Context, filter fleet.TeamFilter, hostnames []string) ([]uint, error) {
return []uint{1234}, nil
}
Expand All @@ -48,9 +57,11 @@ func TestLiveQuery(t *testing.T) {
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
return &fleet.AppConfig{}, nil
}
ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) {
query.ID = 42
return query, nil
ds.ListQueriesFunc = func(ctx context.Context, opt fleet.ListQueryOptions) ([]*fleet.Query, error) {
if opt.MatchQuery == queryName {
return []*fleet.Query{&query}, nil
}
return []*fleet.Query{}, nil
}
ds.NewDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) (*fleet.DistributedQueryCampaign, error) {
camp.ID = 321
Expand All @@ -71,12 +82,12 @@ func TestLiveQuery(t *testing.T) {

lq.On("QueriesForHost", uint(1)).Return(
map[string]string{
"42": "select 42, * from time",
"42": queryString,
},
nil,
)
lq.On("QueryCompletedByHost", "42", 99).Return(nil)
lq.On("RunQuery", "321", "select 42, * from time", []uint{1}).Return(nil)
lq.On("RunQuery", "321", queryString, []uint{1}).Return(nil)

ds.DistributedQueryCampaignTargetIDsFunc = func(ctx context.Context, id uint) (targets *fleet.HostTargets, err error) {
return &fleet.HostTargets{HostIDs: []uint{99}}, nil
Expand All @@ -91,7 +102,7 @@ func TestLiveQuery(t *testing.T) {
return nil
}
ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) {
return &fleet.Query{}, nil
return &query, nil
}
ds.IsSavedQueryFunc = func(ctx context.Context, queryID uint) (bool, error) {
return true, nil
Expand Down Expand Up @@ -138,7 +149,7 @@ func TestLiveQuery(t *testing.T) {

expected := `{"host":"somehostname","rows":[{"bing":"fds","host_display_name":"somehostname","host_hostname":"somehostname"}]}
`
assert.Equal(t, expected, runAppForTest(t, []string{"query", "--hosts", "1234", "--query", "select 42, * from time"}))
assert.Equal(t, expected, runAppForTest(t, []string{"query", "--hosts", "1234", "--query-name", "saved-query"}))

// We need to use waitGroups to detect whether Database functions were called because this is an asynchronous test which will flag data races otherwise.
c := make(chan struct{})
Expand All @@ -157,3 +168,119 @@ func TestLiveQuery(t *testing.T) {
case <-c: // All good
}
}

func TestAdHocLiveQuery(t *testing.T) {
rs := pubsub.NewInmemQueryResults()
lq := live_query_mock.New(t)

logger := kitlog.NewJSONLogger(os.Stdout)
logger = level.NewFilter(logger, level.AllowDebug())

_, ds := runServerWithMockedDS(
t, &service.TestServerOpts{
Rs: rs,
Lq: lq,
Logger: logger,
},
)

users, err := ds.ListUsersFunc(context.Background(), fleet.UserListOptions{})
require.NoError(t, err)
var admin *fleet.User
for _, user := range users {
if user.GlobalRole != nil && *user.GlobalRole == fleet.RoleAdmin {
admin = user
}
}

ds.HostIDsByNameFunc = func(ctx context.Context, filter fleet.TeamFilter, hostnames []string) ([]uint, error) {
return []uint{1234}, nil
}
ds.LabelIDsByNameFunc = func(ctx context.Context, labels []string) ([]uint, error) {
return nil, nil
}
ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
return &fleet.AppConfig{}, nil
}
ds.NewQueryFunc = func(ctx context.Context, query *fleet.Query, opts ...fleet.OptionalArg) (*fleet.Query, error) {
query.ID = 42
return query, nil
}
ds.NewDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) (
*fleet.DistributedQueryCampaign, error,
) {
camp.ID = 321
return camp, nil
}
ds.NewDistributedQueryCampaignTargetFunc = func(
ctx context.Context, target *fleet.DistributedQueryCampaignTarget,
) (*fleet.DistributedQueryCampaignTarget, error) {
return target, nil
}
ds.HostIDsInTargetsFunc = func(ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets) ([]uint, error) {
return []uint{1}, nil
}
ds.CountHostsInTargetsFunc = func(
ctx context.Context, filter fleet.TeamFilter, targets fleet.HostTargets, now time.Time,
) (fleet.TargetMetrics, error) {
return fleet.TargetMetrics{TotalHosts: 1, OnlineHosts: 1}, nil
}
ds.NewActivityFunc = func(ctx context.Context, user *fleet.User, activity fleet.ActivityDetails) error {
return nil
}

lq.On("QueriesForHost", uint(1)).Return(
map[string]string{
"42": "select 42, * from time",
},
nil,
)
lq.On("QueryCompletedByHost", "42", 99).Return(nil)
lq.On("RunQuery", "321", "select 42, * from time", []uint{1}).Return(nil)

ds.DistributedQueryCampaignTargetIDsFunc = func(ctx context.Context, id uint) (targets *fleet.HostTargets, err error) {
return &fleet.HostTargets{HostIDs: []uint{99}}, nil
}
ds.DistributedQueryCampaignFunc = func(ctx context.Context, id uint) (*fleet.DistributedQueryCampaign, error) {
return &fleet.DistributedQueryCampaign{
ID: 321,
UserID: admin.ID,
}, nil
}
ds.SaveDistributedQueryCampaignFunc = func(ctx context.Context, camp *fleet.DistributedQueryCampaign) error {
return nil
}
ds.QueryFunc = func(ctx context.Context, id uint) (*fleet.Query, error) {
return &fleet.Query{}, nil
}
ds.IsSavedQueryFunc = func(ctx context.Context, queryID uint) (bool, error) {
return false, nil
}

go func() {
time.Sleep(2 * time.Second)
require.NoError(
t, rs.WriteResult(
fleet.DistributedQueryResult{
DistributedQueryCampaignID: 321,
Rows: []map[string]string{{"bing": "fds"}},
Host: fleet.ResultHostData{
ID: 99,
Hostname: "somehostname",
DisplayName: "somehostname",
},
Stats: &fleet.Stats{
WallTimeMs: 10,
UserTime: 20,
SystemTime: 30,
Memory: 40,
},
},
),
)
}()

expected := `{"host":"somehostname","rows":[{"bing":"fds","host_display_name":"somehostname","host_hostname":"somehostname"}]}
`
assert.Equal(t, expected, runAppForTest(t, []string{"query", "--hosts", "1234", "--query", "select 42, * from time"}))
}
2 changes: 1 addition & 1 deletion cmd/fleetctl/upgrade_packs.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func upgradePacksCommand() *cli.Command {
}

// get global queries (teamID==nil), because 2017 packs reference global queries.
queries, err := client.GetQueries(nil)
queries, err := client.GetQueries(nil, nil)
if err != nil {
return fmt.Errorf("could not list queries: %w", err)
}
Expand Down
5 changes: 5 additions & 0 deletions server/datastore/mysql/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,11 @@ func (ds *Datastore) ListQueries(ctx context.Context, opt fleet.ListQueryOptions
}
}

if opt.MatchQuery != "" {
whereClauses += " AND q.name = ?"
args = append(args, opt.MatchQuery)
}

sql += whereClauses
sql, args = appendListOptionsWithCursorToSQL(sql, args, &opt.ListOptions)

Expand Down
11 changes: 7 additions & 4 deletions server/datastore/mysql/queries_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"math/rand"
"sort"
"testing"
"time"

"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/ptr"
Expand Down Expand Up @@ -193,8 +194,9 @@ func testQueriesDelete(t *testing.T, ds *Datastore) {
require.True(t, fleet.IsNotFound(err))

// Ensure stats were deleted.
// The actual delete occurs asynchronously, but enough time should have passed
// given the above DB access to ensure the original query completed.
// The actual delete occurs asynchronously, so enough time should have passed
// to ensure the original query completed.
time.Sleep(10 * time.Millisecond)
stats, err := ds.GetLiveQueryStats(context.Background(), query.ID, []uint{hostID})
require.NoError(t, err)
require.Equal(t, 0, len(stats))
Expand Down Expand Up @@ -278,8 +280,9 @@ func testQueriesDeleteMany(t *testing.T, ds *Datastore) {
require.Nil(t, err)
assert.Len(t, queries, 2)
// Ensure stats were deleted.
// The actual delete occurs asynchronously, but enough time should have passed
// given the above DB access to ensure the original query completed.
// The actual delete occurs asynchronously, so enough time should have passed
// to ensure the original query completed.
time.Sleep(10 * time.Millisecond)
stats, err := ds.GetLiveQueryStats(context.Background(), q1.ID, hostIDs)
require.NoError(t, err)
require.Equal(t, 0, len(stats))
Expand Down
9 changes: 6 additions & 3 deletions server/service/client_live_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,15 @@ func (h *LiveQueryResultsHandler) Status() *campaignStatus {
}

// LiveQuery creates a new live query and begins streaming results.
func (c *Client) LiveQuery(query string, labels []string, hosts []string) (*LiveQueryResultsHandler, error) {
return c.LiveQueryWithContext(context.Background(), query, labels, hosts)
func (c *Client) LiveQuery(query string, queryID *uint, labels []string, hosts []string) (*LiveQueryResultsHandler, error) {
return c.LiveQueryWithContext(context.Background(), query, queryID, labels, hosts)
}

func (c *Client) LiveQueryWithContext(ctx context.Context, query string, labels []string, hosts []string) (*LiveQueryResultsHandler, error) {
func (c *Client) LiveQueryWithContext(
ctx context.Context, query string, queryID *uint, labels []string, hosts []string,
) (*LiveQueryResultsHandler, error) {
req := createDistributedQueryCampaignByNamesRequest{
QueryID: queryID,
QuerySQL: query,
Selected: distributedQueryCampaignTargetsByNames{Labels: labels, Hosts: hosts},
}
Expand Down
2 changes: 1 addition & 1 deletion server/service/client_live_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ func TestLiveQueryWithContext(t *testing.T) {
ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second)
defer cancelFunc()

res, err := client.LiveQueryWithContext(ctx, "select 1;", nil, []string{"host1"})
res, err := client.LiveQueryWithContext(ctx, "select 1;", nil, nil, []string{"host1"})
require.NoError(t, err)

gotResults := false
Expand Down
5 changes: 4 additions & 1 deletion server/service/client_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,15 @@ func (c *Client) GetQuerySpec(teamID *uint, name string) (*fleet.QuerySpec, erro
}

// GetQueries retrieves the list of all Queries.
func (c *Client) GetQueries(teamID *uint) ([]fleet.Query, error) {
func (c *Client) GetQueries(teamID *uint, name *string) ([]fleet.Query, error) {
verb, path := "GET", "/api/latest/fleet/queries"
query := url.Values{}
if teamID != nil {
query.Set("team_id", fmt.Sprint(*teamID))
}
if name != nil {
query.Set("query", *name)
}
var responseBody listQueriesResponse
err := c.authenticatedRequestWithQuery(nil, verb, path, &responseBody, query.Encode())
if err != nil {
Expand Down
5 changes: 5 additions & 0 deletions server/service/integration_core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2868,6 +2868,11 @@ func (s *integrationTestSuite) TestScheduledQueries() {
require.Len(t, listQryResp.Queries, 1)
assert.Equal(t, query.Name, listQryResp.Queries[0].Name)

// Return that query by name
s.DoJSON("GET", fmt.Sprintf("/api/latest/fleet/queries?query=%s", query.Name), nil, http.StatusOK, &listQryResp)
require.Len(t, listQryResp.Queries, 1)
assert.Equal(t, query.Name, listQryResp.Queries[0].Name)

// next page returns nothing
s.DoJSON("GET", "/api/latest/fleet/queries", nil, http.StatusOK, &listQryResp, "per_page", "2", "page", "1")
require.Len(t, listQryResp.Queries, 0)
Expand Down

0 comments on commit 0e040cc

Please sign in to comment.