Skip to content

Commit

Permalink
fix treatment of sync scripts + prevent running expired scripts on fl…
Browse files Browse the repository at this point in the history
…eet upgrade (#16567)

for #16547

# Checklist for submitter

If some of the following don't apply, delete the relevant line.

<!-- Note that API documentation changes are now addressed by the
product design team. -->

- [x] Input data is properly validated, `SELECT *` is avoided, SQL
injection is prevented (using placeholders for values in statements)
- [x] Added/updated tests
- For database migrations:
- [x] Checked schema for all modified table for columns that will
auto-update timestamps during migration.
- [x] Confirmed that updating the timestamps is acceptable, and will not
cause unwanted side effects.
- [x] Manual QA for all new/changed functionality
  • Loading branch information
Roberto Dip authored Feb 2, 2024
1 parent ffa929b commit 7ddf275
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 28 deletions.
8 changes: 7 additions & 1 deletion server/datastore/mysql/activities.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"strings"

"github.com/fleetdm/fleet/v4/pkg/scripts"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/jmoiron/sqlx"
Expand Down Expand Up @@ -212,9 +213,14 @@ func (ds *Datastore) ListHostUpcomingActivities(ctx context.Context, hostID uint
WHERE
hsr.host_id = ? AND
hsr.exit_code IS NULL
AND (
hsr.sync_request = 0
OR hsr.created_at >= DATE_SUB(NOW(), INTERVAL ? SECOND)
)
`

args := []any{fleet.ActivityTypeRanScript{}.ActivityName(), hostID}
seconds := int(scripts.MaxServerWaitTime.Seconds())
args := []any{fleet.ActivityTypeRanScript{}.ActivityName(), hostID, seconds}
stmt, args := appendListOptionsWithCursorToSQL(listStmt, args, &opt)

var activities []*fleet.Activity
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,25 @@ func Up_20240126020643(tx *sql.Tx) error {
return errors.Wrap(err, "create host_activities table")
}

// Prior to this update, the database didn't differentiate between
// "async" and "sync" requests. With Fleet v4.44.0, all async requests
// will execute regardless of their pending duration. To avoid
// unintended execution of old requests upon server upgrade, these are
// now marked as "sync", reflecting their original 5-minute execution
// limit.
const setOldScriptsAsSyncStmt = `
UPDATE host_script_results hsr
SET
sync_request = 1,
updated_at = hsr.updated_at
WHERE
user_id IS NULL
AND created_at < CURRENT_TIMESTAMP
`
if _, err := tx.Exec(setOldScriptsAsSyncStmt); err != nil {
return errors.Wrap(err, "set sync_request = 1 for old scripts")
}

return nil
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,37 @@ func TestUp_20240126020643(t *testing.T) {
// create a host execution request in the past
minutesAgo := time.Now().UTC().Add(-5 * time.Minute).Truncate(time.Second)
hsr1 := execNoErrLastID(t, db, `INSERT INTO host_script_results (host_id, execution_id, script_contents, output, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?)`, 1, uuid.NewString(), "echo 'hello'", "", minutesAgo, minutesAgo)
hsr2 := execNoErrLastID(t, db, `INSERT INTO host_script_results (host_id, execution_id, script_contents, output, created_at, updated_at, exit_code) VALUES (?, ?, ?, ?, ?, ?, ?)`, 1, uuid.NewString(), "echo 'hello'", "", minutesAgo, minutesAgo, 1)

// Apply current migration.
applyNext(t, db)

// existing host execution request's timestamp hasn't changed (despite added column)
type timestamps struct {
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
// async request is set to `true` for existing results
// existing host execution request's timestamp hasn't changed (despite
// added column, and modified sync_request)
type scriptResults struct {
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
SyncRequest bool `db:"sync_request"`
}

var ts timestamps
err := db.Get(&ts, `SELECT created_at, updated_at FROM host_script_results WHERE id = ?`, hsr1)
var sr scriptResults
err := db.Get(&sr, `SELECT created_at, updated_at, sync_request FROM host_script_results WHERE id = ?`, hsr1)
require.NoError(t, err)
assert.Equal(t, minutesAgo, ts.CreatedAt)
assert.Equal(t, minutesAgo, ts.UpdatedAt)
assert.Equal(t, minutesAgo, sr.CreatedAt)
assert.Equal(t, minutesAgo, sr.UpdatedAt)
assert.True(t, sr.SyncRequest)

sr = scriptResults{}
err = db.Get(&sr, `SELECT created_at, updated_at, sync_request FROM host_script_results WHERE id = ?`, hsr2)
require.NoError(t, err)
assert.Equal(t, minutesAgo, sr.CreatedAt)
assert.Equal(t, minutesAgo, sr.UpdatedAt)
assert.True(t, sr.SyncRequest)

// create a new host execution request with user u1 and one with u2
hsr2 := execNoErrLastID(t, db, `INSERT INTO host_script_results (host_id, execution_id, script_contents, output, user_id) VALUES (?, ?, ?, ?, ?)`, 1, uuid.NewString(), "echo 'hello'", "", u1)
hsr3 := execNoErrLastID(t, db, `INSERT INTO host_script_results (host_id, execution_id, script_contents, output, user_id) VALUES (?, ?, ?, ?, ?)`, 1, uuid.NewString(), "echo 'hello'", "", u2)
hsr3 := execNoErrLastID(t, db, `INSERT INTO host_script_results (host_id, execution_id, script_contents, output, user_id) VALUES (?, ?, ?, ?, ?)`, 1, uuid.NewString(), "echo 'hello'", "", u1)
hsr4 := execNoErrLastID(t, db, `INSERT INTO host_script_results (host_id, execution_id, script_contents, output, user_id) VALUES (?, ?, ?, ?, ?)`, 1, uuid.NewString(), "echo 'hello'", "", u2)

// create a host activity entry for act1
execNoErr(t, db, `INSERT INTO host_activities (host_id, activity_id) VALUES (?, ?)`, 1, act1)
Expand All @@ -51,10 +63,10 @@ func TestUp_20240126020643(t *testing.T) {

var userID sql.NullInt64
// hsr2 now has a NULL user id, but hsr3 still has user id u2
err = db.Get(&userID, `SELECT user_id FROM host_script_results WHERE id = ?`, hsr2)
err = db.Get(&userID, `SELECT user_id FROM host_script_results WHERE id = ?`, hsr3)
require.NoError(t, err)
assert.False(t, userID.Valid)
err = db.Get(&userID, `SELECT user_id FROM host_script_results WHERE id = ?`, hsr3)
err = db.Get(&userID, `SELECT user_id FROM host_script_results WHERE id = ?`, hsr4)
require.NoError(t, err)
assert.True(t, userID.Valid)
assert.Equal(t, u2, userID.Int64)
Expand Down
9 changes: 8 additions & 1 deletion server/datastore/mysql/scripts.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"
"unicode/utf8"

"github.com/fleetdm/fleet/v4/pkg/scripts"
"github.com/fleetdm/fleet/v4/server/contexts/ctxerr"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/google/uuid"
Expand Down Expand Up @@ -97,11 +98,17 @@ func (ds *Datastore) ListPendingHostScriptExecutions(ctx context.Context, hostID
WHERE
host_id = ? AND
exit_code IS NULL
-- async requests + sync requests created within the given interval
AND (
sync_request = 0
OR created_at >= DATE_SUB(NOW(), INTERVAL ? SECOND)
)
ORDER BY
created_at ASC`

var results []*fleet.HostScriptResult
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, listStmt, hostID); err != nil {
seconds := int(scripts.MaxServerWaitTime.Seconds())
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &results, listStmt, hostID, seconds); err != nil {
return nil, ctxerr.Wrap(ctx, err, "list pending host script executions")
}
return results, nil
Expand Down
50 changes: 50 additions & 0 deletions server/datastore/mysql/scripts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,56 @@ func testHostScriptResult(t *testing.T, ds *Datastore) {
script, err = ds.GetHostScriptExecutionResult(ctx, createdScript.ExecutionID)
require.NoError(t, err)
require.Equal(t, expectedOutput, script.Output)

// create an async execution request
createdScript, err = ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{
HostID: 1,
ScriptContents: "echo 3",
UserID: &u.ID,
SyncRequest: false,
})
require.NoError(t, err)
require.NotZero(t, createdScript.ID)
require.NotEmpty(t, createdScript.ExecutionID)
require.Equal(t, uint(1), createdScript.HostID)
require.NotEmpty(t, createdScript.ExecutionID)
require.Equal(t, "echo 3", createdScript.ScriptContents)
require.Nil(t, createdScript.ExitCode)
require.Empty(t, createdScript.Output)
require.NotNil(t, createdScript.UserID)
require.Equal(t, u.ID, *createdScript.UserID)
require.False(t, createdScript.SyncRequest)

// the script execution is now listed as pending for this host
pending, err = ds.ListPendingHostScriptExecutions(ctx, 1)
require.NoError(t, err)
require.Len(t, pending, 1)
require.Equal(t, createdScript.ID, pending[0].ID)

// modify the timestamp of the script to simulate an script that has
// been pending for a long time
ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(ctx, "UPDATE host_script_results SET created_at = ? WHERE id = ?", time.Now().Add(-24*time.Hour), createdScript.ID)
return err
})

// the script execution still shows as pending
pending, err = ds.ListPendingHostScriptExecutions(ctx, 1)
require.NoError(t, err)
require.Len(t, pending, 1)
require.Equal(t, createdScript.ID, pending[0].ID)

// modify the script to be a sync script that has
// been pending for a long time
ExecAdhocSQL(t, ds, func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(ctx, "UPDATE host_script_results SET sync_request = 1 WHERE id = ?", createdScript.ID)
return err
})

// the script is not pending anymore
pending, err = ds.ListPendingHostScriptExecutions(ctx, 1)
require.NoError(t, err)
require.Empty(t, pending, 0)
}

func testScripts(t *testing.T, ds *Datastore) {
Expand Down
25 changes: 16 additions & 9 deletions server/service/integration_core_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9784,7 +9784,7 @@ func (s *integrationTestSuite) TestListHostUpcomingActivities() {
})
require.NoError(t, err)

hsr, err := s.ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{HostID: host1.ID, ScriptContents: "A"})
hsr, err := s.ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{HostID: host1.ID, ScriptContents: "A", SyncRequest: true})
require.NoError(t, err)
h1A := hsr.ExecutionID
hsr, err = s.ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{HostID: host1.ID, ScriptContents: "B"})
Expand All @@ -9793,45 +9793,52 @@ func (s *integrationTestSuite) TestListHostUpcomingActivities() {
hsr, err = s.ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{HostID: host1.ID, ScriptContents: "C"})
require.NoError(t, err)
h1C := hsr.ExecutionID
hsr, err = s.ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{HostID: host1.ID, ScriptContents: "D"})
hsr, err = s.ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{HostID: host1.ID, ScriptContents: "D", SyncRequest: true})
require.NoError(t, err)
h1D := hsr.ExecutionID
hsr, err = s.ds.NewHostScriptExecutionRequest(ctx, &fleet.HostScriptRequestPayload{HostID: host1.ID, ScriptContents: "E"})
require.NoError(t, err)
h1E := hsr.ExecutionID

// modify the timestamp h1D to simulate an script that has
// been pending for a long time
mysql.ExecAdhocSQL(t, s.ds, func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(ctx, "UPDATE host_script_results SET created_at = ? WHERE execution_id IN (?, ?)", time.Now().Add(-24*time.Hour), h1A, h1B)
return err
})

cases := []struct {
queries []string // alternate query name and value
wantExecs []string
wantMeta *fleet.PaginationMetadata
}{
{
wantExecs: []string{h1A, h1B, h1C, h1D, h1E},
wantExecs: []string{h1B, h1C, h1D, h1E},
wantMeta: &fleet.PaginationMetadata{HasNextResults: false, HasPreviousResults: false},
},
{
queries: []string{"per_page", "2"},
wantExecs: []string{h1A, h1B},
wantExecs: []string{h1B, h1C},
wantMeta: &fleet.PaginationMetadata{HasNextResults: true, HasPreviousResults: false},
},
{
queries: []string{"per_page", "2", "page", "1"},
wantExecs: []string{h1C, h1D},
wantMeta: &fleet.PaginationMetadata{HasNextResults: true, HasPreviousResults: true},
wantExecs: []string{h1D, h1E},
wantMeta: &fleet.PaginationMetadata{HasNextResults: false, HasPreviousResults: true},
},
{
queries: []string{"per_page", "2", "page", "2"},
wantExecs: []string{h1E},
wantExecs: nil,
wantMeta: &fleet.PaginationMetadata{HasNextResults: false, HasPreviousResults: true},
},
{
queries: []string{"per_page", "3"},
wantExecs: []string{h1A, h1B, h1C},
wantExecs: []string{h1B, h1C, h1D},
wantMeta: &fleet.PaginationMetadata{HasNextResults: true, HasPreviousResults: false},
},
{
queries: []string{"per_page", "3", "page", "1"},
wantExecs: []string{h1D, h1E},
wantExecs: []string{h1E},
wantMeta: &fleet.PaginationMetadata{HasNextResults: false, HasPreviousResults: true},
},
{
Expand Down
24 changes: 22 additions & 2 deletions server/service/integration_enterprise_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4653,8 +4653,28 @@ func (s *integrationEnterpriseTestSuite) TestRunHostScript() {
require.False(t, runSyncResp.HostTimeout)
require.Contains(t, runSyncResp.Message, "Scripts are disabled")

// create a sync execution request.
runSyncResp = runScriptSyncResponse{}
s.DoJSON("POST", "/api/latest/fleet/scripts/run/sync", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: "echo"}, http.StatusRequestTimeout, &runSyncResp)

// modify the timestamp of the script to simulate an script that has
// been pending for a long time
mysql.ExecAdhocSQL(t, s.ds, func(tx sqlx.ExtContext) error {
_, err := tx.ExecContext(context.Background(), "UPDATE host_script_results SET created_at = ? WHERE execution_id = ?", time.Now().Add(-24*time.Hour), runSyncResp.ExecutionID)
return err
})

// fetch the results for the timed-out script
scriptResultResp = getScriptResultResponse{}
s.DoJSON("GET", "/api/latest/fleet/scripts/results/"+runSyncResp.ExecutionID, nil, http.StatusOK, &scriptResultResp)
require.Equal(t, host.ID, scriptResultResp.HostID)
require.Equal(t, "echo", scriptResultResp.ScriptContents)
require.Nil(t, scriptResultResp.ExitCode)
require.True(t, scriptResultResp.HostTimeout)
require.Contains(t, scriptResultResp.Message, fleet.RunScriptHostTimeoutErrMsg)

// make the host "offline"
err = s.ds.MarkHostsSeen(ctx, []uint{host.ID}, time.Now().Add(-time.Hour))
err = s.ds.MarkHostsSeen(context.Background(), []uint{host.ID}, time.Now().Add(-time.Hour))
require.NoError(t, err)

// attempt to create a sync script execution request, fails because the host
Expand All @@ -4664,7 +4684,7 @@ func (s *integrationEnterpriseTestSuite) TestRunHostScript() {
require.Contains(t, errMsg, fleet.RunScriptHostOfflineErrMsg)

// attempt to create an async script execution request, succeeds because script is added to queue.
res = s.Do("POST", "/api/latest/fleet/scripts/run", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: "echo"}, http.StatusAccepted)
s.Do("POST", "/api/latest/fleet/scripts/run", fleet.HostScriptRequestPayload{HostID: host.ID, ScriptContents: "echo"}, http.StatusAccepted)
}

func (s *integrationEnterpriseTestSuite) TestRunHostSavedScript() {
Expand Down
3 changes: 0 additions & 3 deletions server/service/orbit.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,6 @@ func (svc *Service) GetOrbitConfig(ctx context.Context) (fleet.OrbitConfig, erro

// load the pending script executions for that host
if !appConfig.ServerSettings.ScriptsDisabled {
// it is important that the "ignoreOlder" parameter in this call is the
// same everywhere (which is here and in RunScript to check if there is
// already a pending script).
pending, err := svc.ds.ListPendingHostScriptExecutions(ctx, host.ID)
if err != nil {
return fleet.OrbitConfig{}, err
Expand Down

0 comments on commit 7ddf275

Please sign in to comment.