Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: re-enroll devices that are removed from ABM and then added back (#23757) #23835

Merged
merged 1 commit into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changes/23200-ade-enroll
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- Fixes a bug where a device that was removed from ABM and then added back wouldn't properly
re-enroll in Fleet MDM
36 changes: 36 additions & 0 deletions server/datastore/mysql/hosts.go
Original file line number Diff line number Diff line change
Expand Up @@ -5155,6 +5155,42 @@ func (ds *Datastore) GetMatchingHostSerials(ctx context.Context, serials []strin
return result, nil
}

func (ds *Datastore) GetMatchingHostSerialsMarkedDeleted(ctx context.Context, serials []string) (map[string]struct{}, error) {
result := map[string]struct{}{}
if len(serials) == 0 {
return result, nil
}

stmt := `
SELECT
hardware_serial
FROM
hosts h
JOIN host_dep_assignments hdep ON hdep.host_id = h.id
WHERE
h.hardware_serial IN (?) AND hdep.deleted_at IS NOT NULL;
`

var args []interface{}
for _, serial := range serials {
args = append(args, serial)
}
stmt, args, err := sqlx.In(stmt, args)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "building IN statement for matching hosts")
}
var matchingSerials []string
if err := sqlx.SelectContext(ctx, ds.reader(ctx), &matchingSerials, stmt, args...); err != nil {
return nil, err
}

for _, serial := range matchingSerials {
result[serial] = struct{}{}
}

return result, nil
}

func (ds *Datastore) GetHostHealth(ctx context.Context, id uint) (*fleet.HostHealth, error) {
sqlStmt := `
SELECT h.os_version, h.updated_at, h.platform, h.team_id, hd.encrypted as disk_encryption_enabled FROM hosts h
Expand Down
81 changes: 81 additions & 0 deletions server/datastore/mysql/hosts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ func TestHosts(t *testing.T) {
{"UpdateHostIssues", testUpdateHostIssues},
{"ListUpcomingHostMaintenanceWindows", testListUpcomingHostMaintenanceWindows},
{"GetHostEmails", testGetHostEmails},
{"TestGetMatchingHostSerialsMarkedDeleted", testGetMatchingHostSerialsMarkedDeleted},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
Expand Down Expand Up @@ -9751,3 +9752,83 @@ func testGetHostEmails(t *testing.T, ds *Datastore) {
require.NoError(t, err)
assert.ElementsMatch(t, []string{"foo@example.com", "bar@example.com"}, emails)
}

func testGetMatchingHostSerialsMarkedDeleted(t *testing.T, ds *Datastore) {
ctx := context.Background()
serials := []string{"foo", "bar", "baz"}
team, err := ds.NewTeam(context.Background(), &fleet.Team{
Name: "team1",
})
require.NoError(t, err)
abmTok, err := ds.InsertABMToken(ctx, &fleet.ABMToken{OrganizationName: t.Name(), EncryptedToken: []byte("token")})
require.NoError(t, err)
var hosts []fleet.Host
for i, serial := range serials {
var tmID *uint
if serial == "bar" {
tmID = &team.ID
}
h, err := ds.NewHost(ctx, &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
PolicyUpdatedAt: time.Now(),
SeenTime: time.Now(),
NodeKey: ptr.String(fmt.Sprint(i)),
UUID: fmt.Sprint(i),
OsqueryHostID: ptr.String(fmt.Sprint(i)),
Hostname: "foo.local",
PrimaryIP: "192.168.1.1",
PrimaryMac: "30-65-EC-6F-C4-58",
HardwareSerial: serial,
TeamID: tmID,
ID: uint(i),
})
require.NoError(t, err)
require.NotNil(t, h)

// Only "foo" and "baz" are
if i%2 == 0 {
hosts = append(hosts, *h)
}
}

require.NoError(t, ds.UpsertMDMAppleHostDEPAssignments(ctx, hosts, abmTok.ID))
require.NoError(t, ds.DeleteHostDEPAssignments(ctx, abmTok.ID, serials))

cases := []struct {
name string
in []string
want map[string]struct{}
err string
}{
{"no serials provided", []string{}, map[string]struct{}{}, ""},
{"no matching serials", []string{"oof", "rab", "bar"}, map[string]struct{}{}, ""},
{
"partial matches",
[]string{"foo", "rab", "bar"},
map[string]struct{}{"foo": {}},
"",
},
{
"all matching",
[]string{"foo", "baz"},
map[string]struct{}{
"foo": {},
"baz": {},
},
"",
},
}

for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got, err := ds.GetMatchingHostSerialsMarkedDeleted(ctx, tt.in)
if tt.err == "" {
require.NoError(t, err)
} else {
require.ErrorContains(t, err, tt.err)
}
require.Equal(t, tt.want, got)
})
}
}
5 changes: 5 additions & 0 deletions server/fleet/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -1298,6 +1298,11 @@ type Datastore interface {
// a map that only contains the serials that have a matching row in the `hosts` table.
GetMatchingHostSerials(ctx context.Context, serials []string) (map[string]*Host, error)

// GetMatchingHostSerialsMarkedDeleted takes a list of device serial numbers and returns a map
// of only the ones that were found in the `hosts` table AND have a row in
// `host_dep_assignments` that is marked as deleted.
GetMatchingHostSerialsMarkedDeleted(ctx context.Context, serials []string) (map[string]struct{}, error)

// DeleteHostDEPAssignmentsFromAnotherABM makes as deleted any DEP entry that matches one of the provided serials only if the entry is NOT associated to the provided ABM token.
DeleteHostDEPAssignmentsFromAnotherABM(ctx context.Context, abmTokenID uint, serials []string) error

Expand Down
17 changes: 15 additions & 2 deletions server/mdm/apple/apple_mdm.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,18 @@ func (d *DEPService) processDeviceResponse(
for _, device := range addedDevicesSlice {
addedSerials = append(addedSerials, device.SerialNumber)
}

// Check if any of the "added" or "modified" hosts are hosts that we've recently removed from
// Fleet in ABM. A host in this state will have a row in `host_dep_assignments` where the
// `deleted_at ` col is NOT NULL. Down below we skip assigning the profile to devices that we
// think are still enrolled; doing this check here allows us to avoid skipping devices that
// _seem_ like they're still enrolled but were actually removed and should get the profile.
// See https://github.com/fleetdm/fleet/issues/23200 for more context.
existingDeletedSerials, err := d.ds.GetMatchingHostSerialsMarkedDeleted(ctx, addedSerials)
if err != nil {
return ctxerr.Wrap(ctx, err, "get matching deleted host serials")
}

err = d.ds.DeleteHostDEPAssignmentsFromAnotherABM(ctx, abmTokenID, addedSerials)
if err != nil {
return ctxerr.Wrap(ctx, err, "deleting dep assignments from another abm")
Expand All @@ -682,7 +694,7 @@ func (d *DEPService) processDeviceResponse(
}

level.Debug(kitlog.With(d.logger)).Log("msg", "devices to assign DEP profiles", "to_add", len(addedDevicesSlice), "to_remove",
deletedSerials, "to_modify", modifiedSerials)
strings.Join(deletedSerials, ", "), "to_modify", strings.Join(modifiedSerials, ", "))

// at this point, the hosts rows are created for the devices, with the
// correct team_id, so we know what team-specific profile needs to be applied.
Expand Down Expand Up @@ -754,7 +766,8 @@ func (d *DEPService) processDeviceResponse(
for profUUID, devices := range profileToDevices {
var serials []string
for _, device := range devices {
if device.ProfileUUID == profUUID {
_, ok := existingDeletedSerials[device.SerialNumber]
if device.ProfileUUID == profUUID && !ok {
skippedSerials = append(skippedSerials, device.SerialNumber)
continue
}
Expand Down
12 changes: 12 additions & 0 deletions server/mock/datastore_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -859,6 +859,8 @@ type GetMDMAppleDefaultSetupAssistantFunc func(ctx context.Context, teamID *uint

type GetMatchingHostSerialsFunc func(ctx context.Context, serials []string) (map[string]*fleet.Host, error)

type GetMatchingHostSerialsMarkedDeletedFunc func(ctx context.Context, serials []string) (map[string]struct{}, error)

type DeleteHostDEPAssignmentsFromAnotherABMFunc func(ctx context.Context, abmTokenID uint, serials []string) error

type DeleteHostDEPAssignmentsFunc func(ctx context.Context, abmTokenID uint, serials []string) error
Expand Down Expand Up @@ -2405,6 +2407,9 @@ type DataStore struct {
GetMatchingHostSerialsFunc GetMatchingHostSerialsFunc
GetMatchingHostSerialsFuncInvoked bool

GetMatchingHostSerialsMarkedDeletedFunc GetMatchingHostSerialsMarkedDeletedFunc
GetMatchingHostSerialsMarkedDeletedFuncInvoked bool

DeleteHostDEPAssignmentsFromAnotherABMFunc DeleteHostDEPAssignmentsFromAnotherABMFunc
DeleteHostDEPAssignmentsFromAnotherABMFuncInvoked bool

Expand Down Expand Up @@ -5773,6 +5778,13 @@ func (s *DataStore) GetMatchingHostSerials(ctx context.Context, serials []string
return s.GetMatchingHostSerialsFunc(ctx, serials)
}

func (s *DataStore) GetMatchingHostSerialsMarkedDeleted(ctx context.Context, serials []string) (map[string]struct{}, error) {
s.mu.Lock()
s.GetMatchingHostSerialsMarkedDeletedFuncInvoked = true
s.mu.Unlock()
return s.GetMatchingHostSerialsMarkedDeletedFunc(ctx, serials)
}

func (s *DataStore) DeleteHostDEPAssignmentsFromAnotherABM(ctx context.Context, abmTokenID uint, serials []string) error {
s.mu.Lock()
s.DeleteHostDEPAssignmentsFromAnotherABMFuncInvoked = true
Expand Down
Loading
Loading