diff --git a/http/api/api.go b/http/api/api.go index 8be45a7..8ebbcde 100644 --- a/http/api/api.go +++ b/http/api/api.go @@ -141,6 +141,12 @@ func PushHandler(pusher push.Pusher, logger log.Logger) http.HandlerFunc { // using. Also note we expose Go errors to the output as this is meant // for "API" users. func RawCommandEnqueueHandler(enqueuer storage.CommandEnqueuer, pusher push.Pusher, logger log.Logger) http.HandlerFunc { + if enqueuer == nil { + panic("nil enqueuer") + } + if logger == nil { + panic("nil logger") + } return func(w http.ResponseWriter, r *http.Request) { ids := strings.Split(r.URL.Path, ",") ctx, logger := setupCtxLog(r.Context(), ids, logger) @@ -201,14 +207,14 @@ func RawCommandEnqueueHandler(enqueuer storage.CommandEnqueuer, pusher push.Push // optionally send pushes pushResp := make(map[string]*push.Response) var pushErr error - if !nopush { + if !nopush && pusher != nil { pushResp, pushErr = pusher.Push(ctx, ids) if err != nil { logger.Info("msg", "push", "err", err) output.PushError = err.Error() } - } else { - pushErr = nil + } else if !nopush && pusher == nil { + pushErr = errors.New("nil pusher") } // loop through our push errors, if any, and add to output var pushCt, pushErrCt int diff --git a/http/mdm/mdm.go b/http/mdm/mdm.go index 9cb4830..5cc235b 100644 --- a/http/mdm/mdm.go +++ b/http/mdm/mdm.go @@ -29,6 +29,12 @@ func mdmReqFromHTTPReq(r *http.Request) *mdm.Request { // CheckinHandler decodes an MDM check-in request and adapts it to service. func CheckinHandler(svc service.Checkin, logger log.Logger) http.HandlerFunc { + if svc == nil { + panic("nil service") + } + if logger == nil { + panic("nil logger") + } return func(w http.ResponseWriter, r *http.Request) { logger := ctxlog.Logger(r.Context(), logger) bodyBytes, err := mdmhttp.ReadAllAndReplaceBody(r) diff --git a/mdm/checkin.go b/mdm/checkin.go index 27387b6..56de939 100644 --- a/mdm/checkin.go +++ b/mdm/checkin.go @@ -25,7 +25,7 @@ type Authenticate struct { // Fields that may be present but are not strictly required for the // operation of the MDM protocol. Nice-to-haves. - SerialNumber string + SerialNumber string `plist:",omitempty"` } type b64Data []byte diff --git a/mdm/command.go b/mdm/command.go index 3be9998..7b2e023 100644 --- a/mdm/command.go +++ b/mdm/command.go @@ -23,11 +23,10 @@ type ErrorChain struct { // See https://developer.apple.com/documentation/devicemanagement/implementing_device_management/sending_mdm_commands_to_a_device type CommandResults struct { Enrollment - CommandUUID string + CommandUUID string `plist:",omitempty"` Status string - ErrorChain []ErrorChain - RequestType string - Raw []byte `plist:"-"` // Original command result XML plist + ErrorChain []ErrorChain `plist:",omitempty"` + Raw []byte `plist:"-"` // Original command result XML plist } // DecodeCheckin unmarshals rawMessage into results diff --git a/mdm/command_test.go b/mdm/command_test.go index 2a423b4..d0d2a53 100644 --- a/mdm/command_test.go +++ b/mdm/command_test.go @@ -10,14 +10,12 @@ func TestCommandAndReportResults(t *testing.T) { for _, test := range []struct { filename string UDID string - RequestType string Status string CommandUUID string }{ { "testdata/DeviceInformation.1.plist", "66ADE930-5FDF-5EC4-8429-15640684C489", - "DeviceInformation", "Acknowledged", "76eda240-5488-4989-8339-f2ae160113c4", }, @@ -36,9 +34,6 @@ func TestCommandAndReportResults(t *testing.T) { if msg, have, want := "incorrect UDID", a.UDID, test.UDID; have != want { t.Errorf("%s: %q, want: %q", msg, have, want) } - if msg, have, want := "incorrect RequestType", a.RequestType, test.RequestType; have != want { - t.Errorf("%s: %q, want: %q", msg, have, want) - } if msg, have, want := "incorrect Status", a.Status, test.Status; have != want { t.Errorf("%s: %q, want: %q", msg, have, want) } diff --git a/service/certauth/certauth_test.go b/service/certauth/certauth_test.go index 1a81eac..dd0edb4 100644 --- a/service/certauth/certauth_test.go +++ b/service/certauth/certauth_test.go @@ -8,6 +8,7 @@ import ( "github.com/micromdm/nanomdm/mdm" "github.com/micromdm/nanomdm/storage/file" + "github.com/micromdm/nanomdm/test" ) func loadAuthMsg() (*mdm.Authenticate, error) { @@ -61,7 +62,7 @@ func TestNilCertAuth(t *testing.T) { } func TestCertAuth(t *testing.T) { - _, crt, err := SimpleSelfSignedRSAKeypair("TESTDEVICE", 1) + _, crt, err := test.SimpleSelfSignedRSAKeypair("TESTDEVICE", 1) if err != nil { t.Fatal(err) } @@ -69,7 +70,7 @@ func TestCertAuth(t *testing.T) { if err != nil { t.Fatal(err) } - certAuth := New(&NopService{}, storage) + certAuth := New(&test.NopService{}, storage) if certAuth == nil { t.Fatal("New returned nil") } @@ -109,7 +110,7 @@ func TestCertAuth(t *testing.T) { if err != nil { t.Fatal(err) } - _, crt2, err := SimpleSelfSignedRSAKeypair("TESTDEVICE", 2) + _, crt2, err := test.SimpleSelfSignedRSAKeypair("TESTDEVICE", 2) if err != nil { t.Fatal(err) } @@ -125,7 +126,7 @@ func TestCertAuth(t *testing.T) { } func TestCertAuthRetro(t *testing.T) { - _, crt, err := SimpleSelfSignedRSAKeypair("TESTDEVICE", 1) + _, crt, err := test.SimpleSelfSignedRSAKeypair("TESTDEVICE", 1) if err != nil { t.Fatal(err) } @@ -133,7 +134,7 @@ func TestCertAuthRetro(t *testing.T) { if err != nil { t.Fatal(err) } - certAuth := New(&NopService{}, storage, WithAllowRetroactive()) + certAuth := New(&test.NopService{}, storage, WithAllowRetroactive()) if certAuth == nil { t.Fatal("New returned nil") } @@ -153,7 +154,7 @@ func TestCertAuthRetro(t *testing.T) { if err != nil { t.Fatal(err) } - _, crt2, err := SimpleSelfSignedRSAKeypair("TESTDEVICE", 2) + _, crt2, err := test.SimpleSelfSignedRSAKeypair("TESTDEVICE", 2) if err != nil { t.Fatal(err) } diff --git a/storage/file/bstoken.go b/storage/file/bstoken.go index 9072bec..b497802 100644 --- a/storage/file/bstoken.go +++ b/storage/file/bstoken.go @@ -19,10 +19,15 @@ func (s *FileStorage) StoreBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrapT return nil } +// RetrieveBootstrapToken reads the BootstrapToken from disk and returns it. +// If no token yet exists a nil token and no error are returned. func (s *FileStorage) RetrieveBootstrapToken(r *mdm.Request, _ *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) { e := s.newEnrollment(r.ID) bsTokenRaw, err := e.readFile(BootstrapTokenFile) - if err != nil { + if errors.Is(err, os.ErrNotExist) { + // mute the error if we haven't escrowed a token yet. + return nil, nil + } else if err != nil { return nil, err } bsToken := &mdm.BootstrapToken{ diff --git a/storage/file/file.go b/storage/file/file.go index 31c1833..ba096aa 100644 --- a/storage/file/file.go +++ b/storage/file/file.go @@ -166,6 +166,14 @@ func (s *FileStorage) StoreAuthenticate(r *mdm.Request, msg *mdm.Authenticate) e return err } } + if err := e.resetNumericFile(TokenUpdateTallyFilename); err != nil { + return err + } + // remove the BootstrapToken when we receive an Authenticate message + // BS tokens are only valid when a new one is escrowed after enrollment. + if err := os.Remove(e.dirPrefix(BootstrapTokenFile)); err != nil && !errors.Is(err, os.ErrNotExist) { + return err + } return e.writeFile(AuthenticateFilename, []byte(msg.Raw)) } @@ -229,9 +237,6 @@ func (s *FileStorage) Disable(r *mdm.Request) error { if err := e.writeFile(DisabledFilename, nil); err != nil { return err } - if err := e.resetNumericFile(TokenUpdateTallyFilename); err != nil { - return err - } } return e.removeSubEnrollments() } diff --git a/storage/file/file_test.go b/storage/file/file_test.go index 6f8c6f6..5a30ccc 100644 --- a/storage/file/file_test.go +++ b/storage/file/file_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "github.com/micromdm/nanomdm/storage/test" + "github.com/micromdm/nanomdm/test/e2e" ) func TestFileStorage(t *testing.T) { @@ -13,6 +13,5 @@ func TestFileStorage(t *testing.T) { t.Fatal(err) } - test.TestQueue(t, "EA4E19F1-7F8B-493D-BEAB-264B33BCF4E6", s) - test.TestRetrievePushInfo(t, context.Background(), s) + t.Run("e2e", func(t *testing.T) { e2e.TestE2E(t, context.Background(), s) }) } diff --git a/storage/mysql/bstoken_test.go b/storage/mysql/bstoken_test.go deleted file mode 100644 index 9fbc801..0000000 --- a/storage/mysql/bstoken_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package mysql - -import ( - "bytes" - "context" - "encoding/base64" - "os" - "testing" - - "github.com/micromdm/nanomdm/mdm" -) - -func TestBSToken(t *testing.T) { - testDSN := os.Getenv("NANOMDM_MYSQL_STORAGE_TEST_DSN") - if testDSN == "" { - t.Skip("NANOMDM_MYSQL_STORAGE_TEST_DSN not set") - } - - storage, err := New(WithDSN(testDSN), WithDeleteCommands()) - if err != nil { - t.Fatal(err) - } - - var d Device - d, err = enrollTestDevice(storage) - if err != nil { - t.Fatal(err) - } - - ctx := context.Background() - - t.Run("BSToken nil", func(t *testing.T) { - tok, err := storage.RetrieveBootstrapToken(&mdm.Request{Context: ctx, EnrollID: d.EnrollID()}, nil) - if err != nil { - t.Fatal(err) - } - if tok != nil { - t.Fatal("Token for new device was nonnull") - } - }) - t.Run("BSToken set/get", func(t *testing.T) { - data := []byte("test token") - bsToken := mdm.BootstrapToken{BootstrapToken: make([]byte, base64.StdEncoding.EncodedLen(len(data)))} - base64.StdEncoding.Encode(bsToken.BootstrapToken, data) - testReq := &mdm.Request{Context: ctx, EnrollID: d.EnrollID()} - err := storage.StoreBootstrapToken(testReq, &mdm.SetBootstrapToken{BootstrapToken: bsToken}) - if err != nil { - t.Fatal(err) - } - - tok, err := storage.RetrieveBootstrapToken(testReq, nil) - if err != nil { - t.Fatal(err) - } - if !bytes.Equal(bsToken.BootstrapToken, tok.BootstrapToken) { - t.Fatalf("Bootstap tokens disequal after roundtrip: %v!=%v", bsToken, tok) - } - }) -} diff --git a/storage/mysql/device_test.go b/storage/mysql/device_test.go deleted file mode 100644 index 12ac9a4..0000000 --- a/storage/mysql/device_test.go +++ /dev/null @@ -1,86 +0,0 @@ -package mysql - -import ( - "context" - "errors" - "io/ioutil" - - "github.com/micromdm/nanomdm/mdm" - "github.com/micromdm/nanomdm/storage" -) - -type DeviceInterfaces interface { - storage.CheckinStore -} - -type Device struct { - UDID string -} - -func (d *Device) EnrollID() *mdm.EnrollID { - return &mdm.EnrollID{Type: mdm.Device, ID: d.UDID} -} - -func loadAuthMsg() (*mdm.Authenticate, Device, error) { - var d Device - b, err := ioutil.ReadFile("../../mdm/testdata/Authenticate.2.plist") - if err != nil { - return nil, d, err - } - r, err := mdm.DecodeCheckin(b) - if err != nil { - return nil, d, err - } - a, ok := r.(*mdm.Authenticate) - if !ok { - return nil, d, errors.New("not an Authenticate message") - } - d = Device{UDID: a.UDID} - return a, d, nil -} - -func loadTokenMsg() (*mdm.TokenUpdate, error) { - b, err := ioutil.ReadFile("../../mdm/testdata/TokenUpdate.2.plist") - if err != nil { - return nil, err - } - r, err := mdm.DecodeCheckin(b) - if err != nil { - return nil, err - } - a, ok := r.(*mdm.TokenUpdate) - if !ok { - return nil, errors.New("not a TokenUpdate message") - } - return a, nil -} - -func (d *Device) newMdmReq() *mdm.Request { - return &mdm.Request{ - Context: context.Background(), - EnrollID: &mdm.EnrollID{ - Type: mdm.Device, - ID: d.UDID, - }, - } -} - -func enrollTestDevice(storage DeviceInterfaces) (Device, error) { - authMsg, d, err := loadAuthMsg() - if err != nil { - return d, err - } - err = storage.StoreAuthenticate(d.newMdmReq(), authMsg) - if err != nil { - return d, err - } - tokenMsg, err := loadTokenMsg() - if err != nil { - return d, err - } - err = storage.StoreTokenUpdate(d.newMdmReq(), tokenMsg) - if err != nil { - return d, err - } - return d, nil -} diff --git a/storage/mysql/mysql.go b/storage/mysql/mysql.go index 2af0e83..e5c991f 100644 --- a/storage/mysql/mysql.go +++ b/storage/mysql/mysql.go @@ -109,6 +109,8 @@ ON DUPLICATE KEY UPDATE identity_cert = new.identity_cert, serial_number = new.serial_number, + bootstrap_token_b64 = NULL, + bootstrap_token_at = NULL, authenticate = new.authenticate, authenticate_at = CURRENT_TIMESTAMP;`, r.ID, pemCert, nullEmptyString(msg.SerialNumber), msg.Raw, diff --git a/storage/mysql/mysql_test.go b/storage/mysql/mysql_test.go new file mode 100644 index 0000000..0f8d8a5 --- /dev/null +++ b/storage/mysql/mysql_test.go @@ -0,0 +1,31 @@ +package mysql + +import ( + "context" + "os" + "testing" + + _ "github.com/go-sql-driver/mysql" + "github.com/micromdm/nanomdm/test/e2e" +) + +func TestMySQL(t *testing.T) { + testDSN := os.Getenv("NANOMDM_MYSQL_STORAGE_TEST_DSN") + if testDSN == "" { + t.Skip("NANOMDM_MYSQL_STORAGE_TEST_DSN not set") + } + + s, err := New(WithDSN(testDSN), WithDeleteCommands()) + if err != nil { + t.Fatal(err) + } + + t.Run("e2e-WithDeleteCommands()", func(t *testing.T) { e2e.TestE2E(t, context.Background(), s) }) + + s, err = New(WithDSN(testDSN)) + if err != nil { + t.Fatal(err) + } + + t.Run("e2e", func(t *testing.T) { e2e.TestE2E(t, context.Background(), s) }) +} diff --git a/storage/mysql/push_test.go b/storage/mysql/push_test.go deleted file mode 100644 index 64c48a2..0000000 --- a/storage/mysql/push_test.go +++ /dev/null @@ -1,23 +0,0 @@ -package mysql - -import ( - "context" - "os" - "testing" - - "github.com/micromdm/nanomdm/storage/test" -) - -func TestRetrievePushInfo(t *testing.T) { - testDSN := os.Getenv("NANOMDM_MYSQL_STORAGE_TEST_DSN") - if testDSN == "" { - t.Skip("NANOMDM_MYSQL_STORAGE_TEST_DSN not set") - } - - storage, err := New(WithDSN(testDSN), WithDeleteCommands()) - if err != nil { - t.Fatal(err) - } - - test.TestRetrievePushInfo(t, context.Background(), storage) -} diff --git a/storage/mysql/queue_test.go b/storage/mysql/queue_test.go deleted file mode 100644 index 382bfd1..0000000 --- a/storage/mysql/queue_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package mysql - -import ( - "os" - "testing" - - "github.com/micromdm/nanomdm/storage/test" - - _ "github.com/go-sql-driver/mysql" -) - -func TestQueue(t *testing.T) { - testDSN := os.Getenv("NANOMDM_MYSQL_STORAGE_TEST_DSN") - if testDSN == "" { - t.Skip("NANOMDM_MYSQL_STORAGE_TEST_DSN not set") - } - - storage, err := New(WithDSN(testDSN), WithDeleteCommands()) - if err != nil { - t.Fatal(err) - } - - d, err := enrollTestDevice(storage) - if err != nil { - t.Fatal(err) - } - - t.Run("WithDeleteCommands()", func(t *testing.T) { - test.TestQueue(t, d.UDID, storage) - }) - - storage, err = New(WithDSN(testDSN)) - if err != nil { - t.Fatal(err) - } - - t.Run("normal", func(t *testing.T) { - test.TestQueue(t, d.UDID, storage) - }) -} diff --git a/storage/pgsql/bstoken.go b/storage/pgsql/bstoken.go index 4189d96..d70b934 100644 --- a/storage/pgsql/bstoken.go +++ b/storage/pgsql/bstoken.go @@ -1,6 +1,8 @@ package pgsql import ( + "database/sql" + "github.com/micromdm/nanomdm/mdm" ) @@ -18,17 +20,17 @@ func (s *PgSQLStorage) StoreBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrap } func (s *PgSQLStorage) RetrieveBootstrapToken(r *mdm.Request, _ *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) { - var tokenB64 string + var tokenB64 sql.NullString err := s.db.QueryRowContext( r.Context, `SELECT bootstrap_token_b64 FROM devices WHERE id = $1;`, r.ID, ).Scan(&tokenB64) - if err != nil { + if err != nil || !tokenB64.Valid { return nil, err } bsToken := new(mdm.BootstrapToken) - err = bsToken.SetTokenString(tokenB64) + err = bsToken.SetTokenString(tokenB64.String) if err == nil { err = s.updateLastSeen(r) } diff --git a/storage/pgsql/pgsql_test.go b/storage/pgsql/pgsql_test.go new file mode 100644 index 0000000..a507dbc --- /dev/null +++ b/storage/pgsql/pgsql_test.go @@ -0,0 +1,31 @@ +package pgsql + +import ( + "context" + "os" + "testing" + + _ "github.com/lib/pq" + "github.com/micromdm/nanomdm/test/e2e" +) + +func TestMySQL(t *testing.T) { + testDSN := os.Getenv("NANOMDM_PGSQL_STORAGE_TEST_DSN") + if testDSN == "" { + t.Skip("NANOMDM_PGSQL_STORAGE_TEST_DSN not set") + } + + s, err := New(WithDSN(testDSN), WithDeleteCommands()) + if err != nil { + t.Fatal(err) + } + + t.Run("e2e-WithDeleteCommands()", func(t *testing.T) { e2e.TestE2E(t, context.Background(), s) }) + + s, err = New(WithDSN(testDSN)) + if err != nil { + t.Fatal(err) + } + + t.Run("e2e", func(t *testing.T) { e2e.TestE2E(t, context.Background(), s) }) +} diff --git a/storage/pgsql/postgresql.go b/storage/pgsql/postgresql.go index 9f3c4dc..2c3680a 100644 --- a/storage/pgsql/postgresql.go +++ b/storage/pgsql/postgresql.go @@ -109,6 +109,8 @@ ON CONFLICT ON CONSTRAINT devices_pkey DO UPDATE SET identity_cert = EXCLUDED.identity_cert, serial_number = EXCLUDED.serial_number, + bootstrap_token_b64 = NULL, + bootstrap_token_at = NULL, authenticate = EXCLUDED.authenticate, authenticate_at = CURRENT_TIMESTAMP;`, r.ID, nullEmptyString(string(pemCert)), nullEmptyString(msg.SerialNumber), msg.Raw, diff --git a/storage/pgsql/queue_test.go b/storage/pgsql/queue_test.go deleted file mode 100644 index 4d8bb89..0000000 --- a/storage/pgsql/queue_test.go +++ /dev/null @@ -1,111 +0,0 @@ -//go:build integration -// +build integration - -package pgsql - -import ( - "context" - "errors" - "flag" - "io/ioutil" - "testing" - - _ "github.com/lib/pq" - "github.com/micromdm/nanomdm/mdm" - "github.com/micromdm/nanomdm/storage/test" -) - -var flDSN = flag.String("dsn", "", "DSN of test PostgreSQL instance") - -func loadAuthMsg() (*mdm.Authenticate, error) { - b, err := ioutil.ReadFile("../../mdm/testdata/Authenticate.2.plist") - if err != nil { - return nil, err - } - r, err := mdm.DecodeCheckin(b) - if err != nil { - return nil, err - } - a, ok := r.(*mdm.Authenticate) - if !ok { - return nil, errors.New("not an Authenticate message") - } - return a, nil -} - -func loadTokenMsg() (*mdm.TokenUpdate, error) { - b, err := ioutil.ReadFile("../../mdm/testdata/TokenUpdate.2.plist") - if err != nil { - return nil, err - } - r, err := mdm.DecodeCheckin(b) - if err != nil { - return nil, err - } - a, ok := r.(*mdm.TokenUpdate) - if !ok { - return nil, errors.New("not a TokenUpdate message") - } - return a, nil -} - -const deviceUDID = "66ADE930-5FDF-5EC4-8429-15640684C489" - -func newMdmReq() *mdm.Request { - return &mdm.Request{ - Context: context.Background(), - EnrollID: &mdm.EnrollID{ - Type: mdm.Device, - ID: deviceUDID, - }, - } -} - -func enrollTestDevice(storage *PgSQLStorage) error { - authMsg, err := loadAuthMsg() - if err != nil { - return err - } - err = storage.StoreAuthenticate(newMdmReq(), authMsg) - if err != nil { - return err - } - tokenMsg, err := loadTokenMsg() - if err != nil { - return err - } - err = storage.StoreTokenUpdate(newMdmReq(), tokenMsg) - if err != nil { - return err - } - return nil -} - -func TestQueue(t *testing.T) { - if *flDSN == "" { - t.Fatal("PostgreSQL DSN flag not provided to test") - } - - storage, err := New(WithDSN(*flDSN), WithDeleteCommands()) - if err != nil { - t.Fatal(err) - } - - err = enrollTestDevice(storage) - if err != nil { - t.Fatal(err) - } - - t.Run("WithDeleteCommands()", func(t *testing.T) { - test.TestQueue(t, deviceUDID, storage) - }) - - storage, err = New(WithDSN(*flDSN)) - if err != nil { - t.Fatal(err) - } - - t.Run("normal", func(t *testing.T) { - test.TestQueue(t, deviceUDID, storage) - }) -} diff --git a/storage/storage.go b/storage/storage.go index 1885fe0..b7adde9 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -30,6 +30,9 @@ type CommandAndReportResultsStore interface { type BootstrapTokenStore interface { StoreBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrapToken) error + + // RetrieveBootstrapToken retrieves the previously-escrowed Bootstrap Token. + // If a token has not yet been escrowed then a nil token and no error should be returned. RetrieveBootstrapToken(r *mdm.Request, msg *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) } diff --git a/storage/test/push.go b/storage/test/push.go deleted file mode 100644 index 9492a4b..0000000 --- a/storage/test/push.go +++ /dev/null @@ -1,19 +0,0 @@ -package test - -import ( - "context" - "testing" - - "github.com/micromdm/nanomdm/storage" -) - -func TestRetrievePushInfo(t *testing.T, ctx context.Context, s storage.PushStore) { - t.Run("TestRetrievePushInfo", func(t *testing.T) { - _, err := s.RetrievePushInfo(ctx, []string{"INVALID"}) - if err != nil { - // should NOT recieve a "global" error for an enrollment that - // is merely invalid (or not enrolled yet, or not fully enrolled) - t.Errorf("should NOT have errored: %v", err) - } - }) -} diff --git a/storage/test/queue.go b/storage/test/queue.go deleted file mode 100644 index aa564c3..0000000 --- a/storage/test/queue.go +++ /dev/null @@ -1,155 +0,0 @@ -package test - -import ( - "context" - "testing" - - "github.com/micromdm/nanomdm/mdm" - "github.com/micromdm/nanomdm/storage" - - "github.com/groob/plist" -) - -// QueueInterfaces are the storage interfaces needed for testing queue operations. -type QueueInterfaces interface { - storage.CommandEnqueuer - storage.CommandAndReportResultsStore -} - -// newCommand assembles a fake command including the plist raw value -func newCommand(cmd string) (*mdm.Command, error) { - // assemble a fake struct just for marshalling to plist - fCmd := &struct { - CommandUUID string - Command struct { - RequestType string - } - }{ - CommandUUID: cmd, - Command: struct{ RequestType string }{cmd}, - } - // marshal it to plist - rawBytes, err := plist.Marshal(fCmd) - if err != nil { - return nil, err - } - // return a real *mdm.Command which includes the marshalled JSON - return &mdm.Command{ - CommandUUID: fCmd.CommandUUID, - Command: fCmd.Command, - Raw: rawBytes, - }, nil -} - -// enqueue queues a new command -func enqueue(t *testing.T, q QueueInterfaces, ctx context.Context, id, cmdStr string) { - cmd, err := newCommand(cmdStr) - if err != nil { - t.Fatal(err) - } - res, err := q.EnqueueCommand(ctx, []string{id}, cmd) - if err != nil { - t.Fatal(err) - } - for k, v := range res { - t.Fatalf("enqueuing to ID %s: %v", k, v) - } -} - -// compareCommand compares makes sure cmd looks similar to newCommand(cmdStr) -func compareCommand(t *testing.T, cmdStr string, cmd *mdm.Command) { - if cmdStr != "" && cmd == nil { - t.Errorf("expected next command, but got empty response. wanted: %q", cmdStr) - return - } - if cmdStr == "" && cmd != nil { - t.Errorf("expected empty next command, but got: %q", cmd.CommandUUID) - } - if cmd == nil { - return - } - if cmd.CommandUUID != cmdStr { - t.Errorf("mismatched command UUID. want: %q, have: %q", cmdStr, cmd.CommandUUID) - } - if cmd.Command.RequestType != cmdStr { - t.Errorf("mismatched command RequestType. want: %q, have: %q", cmdStr, cmd.Command.RequestType) - } -} - -// retrieve retrieves the next command from the backend -func retrieve(t *testing.T, q QueueInterfaces, r *mdm.Request, cmdStr string, skipNotNow bool) { - retCmd, err := q.RetrieveNextCommand(r, skipNotNow) - if err != nil { - t.Fatal(err) - } - compareCommand(t, cmdStr, retCmd) -} - -// report fakes a command result and reports it to the backend -func report(t *testing.T, q QueueInterfaces, r *mdm.Request, cmdStr, status string) { - fReport := &struct { - CommandUUID string `plist:",omitempty"` - Status string - RequestType string `plist:",omitempty"` - }{CommandUUID: cmdStr, Status: status, RequestType: cmdStr} - rawBytes, err := plist.Marshal(fReport) - if err != nil { - t.Fatal(err) - } - results := &mdm.CommandResults{ - CommandUUID: fReport.CommandUUID, - Status: fReport.Status, - RequestType: fReport.RequestType, - Raw: rawBytes, - } - err = q.StoreCommandReport(r, results) - if err != nil { - t.Error(err) - } -} - -// reportRetrieve behaves similarly to an MDM client: it first reports -// the results and then retrieves the next command. -func reportRetrieve(t *testing.T, q QueueInterfaces, r *mdm.Request, reportCmd, reportStatus, expectedCmd string) { - report(t, q, r, reportCmd, reportStatus) - skipNotNow := false - if reportStatus == "NotNow" { - skipNotNow = true - } - retrieve(t, q, r, expectedCmd, skipNotNow) -} - -// TestQueue performs basic testing of the storage queue -func TestQueue(t *testing.T, id string, q QueueInterfaces) { - ctx := context.Background() - - // build a fake MDM request object - r := &mdm.Request{ - EnrollID: &mdm.EnrollID{ - Type: mdm.Device, - ID: id, - ParentID: "", - }, - Context: ctx, - } - - t.Run("basic", func(t *testing.T) { - reportRetrieve(t, q, r, "", "Idle", "") - enqueue(t, q, ctx, id, "CMD1") - enqueue(t, q, ctx, id, "CMD2") - reportRetrieve(t, q, r, "", "Idle", "CMD1") - reportRetrieve(t, q, r, "CMD1", "Acknowledged", "CMD2") - reportRetrieve(t, q, r, "CMD2", "Acknowledged", "") - reportRetrieve(t, q, r, "", "Idle", "") - }) - - t.Run("notnow", func(t *testing.T) { - reportRetrieve(t, q, r, "", "Idle", "") - enqueue(t, q, ctx, id, "CMD3") - reportRetrieve(t, q, r, "", "Idle", "CMD3") - reportRetrieve(t, q, r, "CMD3", "NotNow", "") - reportRetrieve(t, q, r, "", "Idle", "CMD3") - reportRetrieve(t, q, r, "CMD3", "Acknowledged", "") - reportRetrieve(t, q, r, "", "Idle", "") - }) -} diff --git a/test/e2e/api.go b/test/e2e/api.go new file mode 100644 index 0000000..772a151 --- /dev/null +++ b/test/e2e/api.go @@ -0,0 +1,51 @@ +package e2e + +import ( + "context" + "errors" + "net/http" + "strings" + + "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/test" + "github.com/micromdm/nanomdm/test/enrollment" +) + +// Doer executes an HTTP request. +type Doer interface { + Do(*http.Request) (*http.Response, error) +} + +type api struct { + doer Doer +} + +func (a *api) RawCommandEnqueue(ctx context.Context, ids []string, cmd *mdm.Command, nopush bool) error { + r, err := test.PlistReader(cmd) + if err != nil { + return err + } + + if !strings.HasSuffix(enqueueURL, "/") { + return errors.New("missing trailing slash of enqueue URL") + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, enqueueURL+strings.Join(ids, ","), r) + if err != nil { + return err + } + + v := req.URL.Query() + if nopush { + v.Set("nopush", "1") + } + req.URL.RawQuery = v.Encode() + + resp, err := a.doer.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + return enrollment.HTTPErrors(resp) +} diff --git a/test/e2e/bstoken.go b/test/e2e/bstoken.go new file mode 100644 index 0000000..ee7e504 --- /dev/null +++ b/test/e2e/bstoken.go @@ -0,0 +1,53 @@ +package e2e + +import ( + "bytes" + "context" + "encoding/base64" + "fmt" + "testing" + + "github.com/micromdm/nanomdm/mdm" +) + +type bstokenDevice interface { + IDer + DoGetBootstrapToken(ctx context.Context) (*mdm.BootstrapToken, error) + DoEscrowBootstrapToken(ctx context.Context, token []byte) error +} + +// bstoken assumes d is a new enrollment and has had no BootstrapToken stored yet. +func bstoken(t *testing.T, ctx context.Context, d bstokenDevice) { + tok, err := d.DoGetBootstrapToken(ctx) + if err != nil { + // should not error. newly enrolled devices should not error + // if their BS token is requested. + t.Fatal(fmt.Errorf("error retrieving not-yet-escrowed bootstrap token: %w", err)) + } + + if tok != nil { + t.Errorf("token for supposedly freshly enrolled device %s was not nil", d.ID()) + } + + input := []byte("hello world") + + err = d.DoEscrowBootstrapToken(ctx, input) + if err != nil { + t.Fatal(err) + } + + tok, err = d.DoGetBootstrapToken(ctx) + if err != nil { + t.Fatal(err) + } + + x, err := base64.StdEncoding.DecodeString(string(tok.BootstrapToken)) + if err != nil { + t.Fatal(err) + } + + if have, want := x, input; !bytes.Equal(have, want) { + t.Errorf("bootstrap token: have: %v, want: %v", string(have), string(want)) + } + +} diff --git a/test/e2e/client.go b/test/e2e/client.go new file mode 100644 index 0000000..ece209d --- /dev/null +++ b/test/e2e/client.go @@ -0,0 +1,22 @@ +package e2e + +import ( + "net/http" + "net/http/httptest" +) + +// HandlerClient behaves like an HTTP client but merely routes to an http.Handler. +type HandlerClient struct { + handler http.Handler +} + +func NewHandlerClient(handler http.Handler) *HandlerClient { + return &HandlerClient{handler: handler} +} + +// Do routes HTTP requests to an http.Handler using an httptest.NewRecorder. +func (c *HandlerClient) Do(r *http.Request) (*http.Response, error) { + rec := httptest.NewRecorder() + c.handler.ServeHTTP(rec, r) + return rec.Result(), nil +} diff --git a/test/e2e/device.go b/test/e2e/device.go new file mode 100644 index 0000000..00f723f --- /dev/null +++ b/test/e2e/device.go @@ -0,0 +1,92 @@ +package e2e + +import ( + "context" + "fmt" + "io" + + "github.com/groob/plist" + "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/test" + "github.com/micromdm/nanomdm/test/enrollment" +) + +// device is a wrapper around our enrollment for ease of use. +type device struct { + *enrollment.Enrollment +} + +func newDeviceFromCheckins(doer Doer, serverURL, authPath, tokUpdPath string) (*device, error) { + e, err := enrollment.NewFromCheckins(doer, serverURL, "", authPath, tokUpdPath) + if err != nil { + return nil, err + } + return &device{Enrollment: e}, nil +} + +func newDevice(doer Doer, serverURL string) (*device, error) { + const topic = "com.example.apns.topic" + + e, err := enrollment.NewRandomDeviceEnrollment(doer, topic, serverURL, "") + if err != nil { + return nil, err + } + + return &device{Enrollment: e}, nil +} + +func newCommand(uuid, requestType string) *mdm.Command { + if uuid == "" && requestType == "" { + return nil + } + return &mdm.Command{ + CommandUUID: uuid, + Command: struct{ RequestType string }{ + RequestType: requestType, + }, + } +} + +func (d *device) NewCommandReport(uuid, status string, errors []mdm.ErrorChain) *mdm.CommandResults { + return &mdm.CommandResults{ + Enrollment: *d.GetEnrollment(), + CommandUUID: uuid, + Status: status, + ErrorChain: errors, + } +} + +const Limit1MiB = 1024 * 1024 + +func (d *device) CMDDoReportAndFetch(ctx context.Context, report *mdm.CommandResults) (*mdm.Command, error) { + reportReader, err := test.PlistReader(report) + if err != nil { + return nil, err + } + + resp, err := d.DoReportAndFetch(ctx, reportReader) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, Limit1MiB)) + if err != nil { + return nil, err + } + + if resp.StatusCode != 200 { + return nil, enrollment.NewHTTPError(resp, body) + } + + var cmd *mdm.Command + + if len(body) > 0 { + cmd = new(mdm.Command) + if err = plist.Unmarshal(body, cmd); err != nil { + return nil, fmt.Errorf("decoding command body: %w", err) + } + } + + return cmd, nil +} diff --git a/test/e2e/e2e.go b/test/e2e/e2e.go new file mode 100644 index 0000000..acbb8ba --- /dev/null +++ b/test/e2e/e2e.go @@ -0,0 +1,115 @@ +package e2e + +import ( + "context" + "fmt" + "net/http" + "testing" + + "github.com/micromdm/nanolib/log" + "github.com/micromdm/nanomdm/cryptoutil" + httpapi "github.com/micromdm/nanomdm/http/api" + httpmdm "github.com/micromdm/nanomdm/http/mdm" + "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/service" + "github.com/micromdm/nanomdm/service/certauth" + "github.com/micromdm/nanomdm/service/nanomdm" + "github.com/micromdm/nanomdm/storage" +) + +const ( + serverURL = "/mdm" + enqueueURL = "/api/enq/" +) + +// setupNanoMDM configures normal-ish NanoMDM HTTP server handlers for testing. +func setupNanoMDM(logger log.Logger, store storage.AllStorage) (http.Handler, error) { + // begin with the primary NanoMDM service + var svc service.CheckinAndCommandService = nanomdm.New(store, nanomdm.WithLogger(logger)) + + // chain the certificate auth middleware + svc = certauth.New(svc, store) + + // setup MDM (check-in and command) handlers + var mdmHandler http.Handler = httpmdm.CheckinAndCommandHandler(svc, logger.With("handler", "mdm")) + // mdmHandler = httpmdm.CertVerifyMiddleware(mdmHandler, , logger.With("handler", "verify")) + mdmHandler = httpmdm.CertExtractMdmSignatureMiddleware(mdmHandler, httpmdm.MdmSignatureVerifierFunc(cryptoutil.VerifyMdmSignature)) + + // setup API handlers + var enqueueHandler http.Handler = httpapi.RawCommandEnqueueHandler(store, nil, logger.With("handler", enqueueURL)) + enqueueHandler = http.StripPrefix(enqueueURL, enqueueHandler) + + // create a mux for them + mux := http.NewServeMux() + mux.Handle(serverURL, mdmHandler) + mux.Handle(enqueueURL, enqueueHandler) + + return mux, nil +} + +type NanoMDMAPI interface { + // RawCommandEnqueue enqueues cmd to ids. An APNs push is omitted if nopush is true. + RawCommandEnqueue(ctx context.Context, ids []string, cmd *mdm.Command, nopush bool) error +} + +type IDer interface { + ID() string +} + +func TestE2E(t *testing.T, ctx context.Context, store storage.AllStorage) { + var logger log.Logger = log.NopLogger // stdlogfmt.New(stdlogfmt.WithDebugFlag(true)) + + mux, err := setupNanoMDM(logger, store) + if err != nil { + t.Fatal(err) + } + + // create a fake HTTP client that dispatches to our raw handlers + c := NewHandlerClient(mux) + + // create our new device for testing + d, err := newDeviceFromCheckins( + c, + serverURL, + "../../mdm/testdata/Authenticate.2.plist", + "../../mdm/testdata/TokenUpdate.2.plist", + ) + if err != nil { + t.Fatal(err) + } + + // regression test for retrieving push info of missing devices. + t.Run("invalid-pushinfo", func(t *testing.T) { + _, err := store.RetrievePushInfo(ctx, []string{"INVALID"}) + if err != nil { + // should NOT recieve a "global" error for an enrollment that + // is merely invalid (or not enrolled yet, or not fully enrolled) + t.Errorf("should NOT have errored: %v", err) + } + }) + + t.Run("enroll", func(t *testing.T) { enroll(t, ctx, d, store) }) + + t.Run("tally", func(t *testing.T) { tally(t, ctx, d, store, 1) }) + + t.Run("bstoken", func(t *testing.T) { bstoken(t, ctx, d.Enrollment) }) + + // re-enroll device + // this is to try and catch any leftover crud that a storage backend didn't + // clean up (like the tally count, BS token, etc.) + err = d.DoEnroll(ctx) + if err != nil { + t.Fatal(fmt.Errorf("re-enrolling device %s: %w", d.ID(), err)) + } + + t.Run("tally-after-reenroll", func(t *testing.T) { tally(t, ctx, d, store, 1) }) + + t.Run("bstoken-after-reenroll", func(t *testing.T) { bstoken(t, ctx, d.Enrollment) }) + + err = store.ClearQueue(d.NewMDMRequest(ctx)) + if err != nil { + t.Fatal() + } + + t.Run("queue", func(t *testing.T) { queue(t, ctx, d, &api{doer: c}) }) +} diff --git a/test/e2e/enroll.go b/test/e2e/enroll.go new file mode 100644 index 0000000..5097f4c --- /dev/null +++ b/test/e2e/enroll.go @@ -0,0 +1,39 @@ +package e2e + +import ( + "context" + "reflect" + "testing" + + "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/storage" +) + +type enrollDevice interface { + IDer + DoEnroll(context.Context) error + GetPush() *mdm.Push +} + +func enroll(t *testing.T, ctx context.Context, d enrollDevice, store storage.PushStore) { + // enroll it + err := d.DoEnroll(ctx) + if err != nil { + t.Fatal(err) + } + + // extract the push info for the given id + pushInfos, err := store.RetrievePushInfo(ctx, []string{d.ID()}) + if err != nil { + t.Fatal(err) + } + + // test that we got the right push data data back + if want, have := 1, len(pushInfos); want != have { + t.Fatalf("len(pushInfos): want: %v, have: %v", want, have) + } + push := d.GetPush() + if !reflect.DeepEqual(pushInfos[d.ID()], push) { + t.Errorf("pushInfo have: %v, want: %v", pushInfos[d.ID()], push) + } +} diff --git a/test/e2e/queue.go b/test/e2e/queue.go new file mode 100644 index 0000000..a42aacc --- /dev/null +++ b/test/e2e/queue.go @@ -0,0 +1,96 @@ +package e2e + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/micromdm/nanomdm/mdm" +) + +type queueDevice interface { + CMDDoReportAndFetch(ctx context.Context, cmd *mdm.CommandResults) (*mdm.Command, error) + NewCommandReport(uuid, status string, errors []mdm.ErrorChain) *mdm.CommandResults + IDer +} + +// enqueue enqueues cmd to id using a. +func enqueue(t *testing.T, ctx context.Context, a NanoMDMAPI, id string, cmd *mdm.Command) { + err := a.RawCommandEnqueue(ctx, []string{id}, cmd, true) + if err != nil { + t.Fatal(err) + } +} + +// simpleCmd makes a command with a CommandUUID and RequestType the same string. +func simpleCmd(cmdID string) *mdm.Command { + return newCommand(cmdID, cmdID) +} + +// sendReportExpectCommandReply send a command report and expect a certain command reply. +func sendReportExpectCommandReply(t *testing.T, ctx context.Context, d queueDevice, reportCmd, reportStatus, expectedCmd string) { + cr := d.NewCommandReport(reportCmd, reportStatus, nil) + cmd, err := d.CMDDoReportAndFetch(ctx, cr) + if err != nil { + t.Fatal(fmt.Errorf("reporting cmd=%s status=%s: %w", reportCmd, reportStatus, err)) + } + + // make sure the command we expect was received + if have, want := cmd, simpleCmd(expectedCmd); !reflect.DeepEqual(have, want) { + t.Errorf("command: have: %v, want: %v", have, want) + } +} + +// enqueueSimple enqueues cmd to a for d. +func enqueueSimple(t *testing.T, ctx context.Context, d queueDevice, a NanoMDMAPI, cmd string) { + // we're assuming the UDID is all we need here. + enqueue(t, ctx, a, d.ID(), simpleCmd(cmd)) +} + +func queue(t *testing.T, ctx context.Context, d queueDevice, a NanoMDMAPI) { + t.Run("basic", func(t *testing.T) { + // report Idle. + // expect no command (empty queue for this id). + sendReportExpectCommandReply(t, ctx, d, "", "Idle", "") + // enqueue a couple commands. + enqueueSimple(t, ctx, d, a, "CMD1") + enqueueSimple(t, ctx, d, a, "CMD2") + // report Idle. + // but now expect the CMD1 result (first on the queue). + sendReportExpectCommandReply(t, ctx, d, "", "Idle", "CMD1") + // ack CMD1. + // expect CMD2. + sendReportExpectCommandReply(t, ctx, d, "CMD1", "Acknowledged", "CMD2") + // ack CMD2 (effectively clearning the queue). + // expect no command (only two commands queued). + sendReportExpectCommandReply(t, ctx, d, "CMD2", "Acknowledged", "") + // report Idle. + // expect no command (empty queue). + sendReportExpectCommandReply(t, ctx, d, "", "Idle", "") + }) + t.Run("notnow", func(t *testing.T) { + // report Idle. + // expect no command (empty queue). + sendReportExpectCommandReply(t, ctx, d, "", "Idle", "") + // enqueue CMD3. + enqueueSimple(t, ctx, d, a, "CMD3") + // report Idle. + // expect CMD3. + sendReportExpectCommandReply(t, ctx, d, "", "Idle", "CMD3") + // report NotNow for CMD3. + // expect no command (only NotNow commands in queue). + sendReportExpectCommandReply(t, ctx, d, "CMD3", "NotNow", "") + // report Idle. + // this could be considered as "resetting" NotNow for CMD3. + // expect CMD3 (the NotNow'd command). + sendReportExpectCommandReply(t, ctx, d, "", "Idle", "CMD3") + // ack CMD3. + // expect no command (empty queue). + sendReportExpectCommandReply(t, ctx, d, "CMD3", "Acknowledged", "") + // report Idle. + // expect no command (empty queue). + sendReportExpectCommandReply(t, ctx, d, "", "Idle", "") + }) + +} diff --git a/test/e2e/tally.go b/test/e2e/tally.go new file mode 100644 index 0000000..845c684 --- /dev/null +++ b/test/e2e/tally.go @@ -0,0 +1,44 @@ +package e2e + +import ( + "context" + "testing" + + "github.com/micromdm/nanomdm/storage" +) + +type tokenTallyDevice interface { + DoTokenUpdate(context.Context) error + IDer +} + +// tally tests to make sure the TokenUpdate tally functions nominally. +func tally(t *testing.T, ctx context.Context, d tokenTallyDevice, store storage.TokenUpdateTallyStore, initial int) { + // retrieve the tally + tally, err := store.RetrieveTokenUpdateTally(ctx, d.ID()) + if err != nil { + t.Fatal() + } + + // make sure it's what we want + if have, want := tally, initial; have != want { + t.Errorf("token update tally: have: %v, want: %v", have, want) + } + + // perform a TokenUpdate (should increase the tally) + err = d.DoTokenUpdate(ctx) + if err != nil { + t.Fatal() + } + + // retrieve the tally again + tally, err = store.RetrieveTokenUpdateTally(ctx, d.ID()) + if err != nil { + t.Fatal() + } + + // make sure it's what we want (+1) + if have, want := tally, initial+1; have != want { + t.Errorf("token update tally (2nd): have: %v, want: %v", have, want) + } +} diff --git a/test/enrollment/enrollment.go b/test/enrollment/enrollment.go new file mode 100644 index 0000000..841ef7c --- /dev/null +++ b/test/enrollment/enrollment.go @@ -0,0 +1,343 @@ +package enrollment + +import ( + "context" + "crypto" + "crypto/rand" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "io" + "net/http" + "os" + "sync" + + "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/test" + "github.com/micromdm/nanomdm/test/protocol" + + "github.com/groob/plist" +) + +var ErrAlreadyEnrolled = errors.New("already enrolled") + +type Transport interface { + // DoCheckIn performs an HTTP MDM check-in to the CheckInURL (or ServerURL). + // The caller is responsible for closing the response body. + DoCheckIn(context.Context, io.Reader) (*http.Response, error) + + // DoReportResultsAndFetchNext sends an HTTP MDM report-results-and-retrieve-next-command request to the ServerURL. + // The caller is responsible for closing the response body. + DoReportResultsAndFetchNext(ctx context.Context, report io.Reader) (*http.Response, error) +} + +// Enrollment emulates an MDM enrollment. +// Currently it mostly emulates device channel enrollments. +type Enrollment struct { + enrollID mdm.EnrollID + enrollment mdm.Enrollment + push mdm.Push + + cert *x509.Certificate + key crypto.PrivateKey + + serialNumber string + unlockToken []byte + + transport Transport + + enrolled bool + enrollM sync.Mutex +} + +func loadAuthTokUpd(authPath, tokUpdPath string) (*mdm.Authenticate, *mdm.TokenUpdate, error) { + authBytes, err := os.ReadFile(authPath) + if err != nil { + return nil, nil, err + } + msg, err := mdm.DecodeCheckin(authBytes) + if err != nil { + return nil, nil, err + } + auth, ok := msg.(*mdm.Authenticate) + if !ok { + return auth, nil, errors.New("not an Authenticate message") + } + tokUpdBytes, err := os.ReadFile(tokUpdPath) + if err != nil { + return auth, nil, err + } + msg, err = mdm.DecodeCheckin(tokUpdBytes) + if err != nil { + return auth, nil, err + } + tokUpd, ok := msg.(*mdm.TokenUpdate) + if !ok { + return auth, tokUpd, errors.New("not a TokenUpdate message") + } + return auth, tokUpd, nil +} + +// NewFromCheckins loads device information from authenticate and tokenupdate files on disk. +func NewFromCheckins(doer protocol.Doer, serverURL, checkInURL, authenticatePath, tokenUpdatePath string) (*Enrollment, error) { + auth, tokUpd, err := loadAuthTokUpd(authenticatePath, tokenUpdatePath) + if err != nil { + return nil, err + } + + e := &Enrollment{ + enrollment: auth.Enrollment, + push: tokUpd.Push, + serialNumber: auth.SerialNumber, + + // we're assuming the IDs here are devices + enrollID: mdm.EnrollID{Type: mdm.Device, ID: auth.UDID}, + } + e.key, e.cert, err = test.SimpleSelfSignedRSAKeypair("TESTDEVICE", 2) + + e.transport = protocol.NewTransport( + protocol.WithSignMessage(), + protocol.WithIdentityProvider(e.GetIdentity), + protocol.WithMDMURLs(serverURL, checkInURL), + protocol.WithClient(doer), + ) + + return e, err +} + +// NewRandomDeviceEnrollment creates a new randomly identified MDM enrollment. +func NewRandomDeviceEnrollment(doer protocol.Doer, topic, serverURL, checkInURL string) (*Enrollment, error) { + udid := randString(32) + e := &Enrollment{ + enrollment: mdm.Enrollment{UDID: udid}, + push: mdm.Push{ + Topic: topic, + PushMagic: randString(32), + // Token: []byte(randString(32)), // Token is populated in DoTokenUpdate() + }, + serialNumber: randString(8), + // unlockToken: , + enrollID: mdm.EnrollID{Type: mdm.Device, ID: udid}, + } + var err error + e.key, e.cert, err = test.SimpleSelfSignedRSAKeypair("TESTDEVICE", 2) + + e.transport = protocol.NewTransport( + protocol.WithSignMessage(), + protocol.WithIdentityProvider(e.GetIdentity), + protocol.WithMDMURLs(serverURL, checkInURL), + protocol.WithClient(doer), + ) + + return e, err +} + +// GetIdentity supplies the identity certificate and key of this enrollment. +func (c *Enrollment) GetIdentity(context.Context) (*x509.Certificate, crypto.PrivateKey, error) { + return c.cert, c.key, nil +} + +// genAuthenticate creates an XML Plist Authenticate check-in message. +func (e *Enrollment) genAuthenticate() (io.Reader, error) { + a := &mdm.Authenticate{ + Enrollment: e.enrollment, + MessageType: mdm.MessageType{MessageType: "Authenticate"}, + Topic: e.push.Topic, + SerialNumber: e.serialNumber, + } + return test.PlistReader(a) +} + +// genTokenUpdate creates an XML Plist TokenUpdate check-in message. +func (e *Enrollment) genTokenUpdate() (io.Reader, error) { + t := &mdm.TokenUpdate{ + Enrollment: e.enrollment, + MessageType: mdm.MessageType{MessageType: "TokenUpdate"}, + Push: e.push, + UnlockToken: e.unlockToken, + } + return test.PlistReader(t) +} + +// DoTokenUpdate sends a TokenUpdate to the MDM server. +func (e *Enrollment) DoTokenUpdate(ctx context.Context) error { + e.enrollM.Lock() + defer e.enrollM.Unlock() + return e.doTokenUpdate(ctx) +} + +// doTokenUpdate sends a TokenUpdate to the MDM server. +func (e *Enrollment) doTokenUpdate(ctx context.Context) error { + // generate new random push token. + // the token comes from Apple's APNs service. so we'll simulate this + // by re-generating the token every time we do a TokenUpdate. + e.push.Token = []byte(randString(32)) + + // generate TokenUpdate check-in message + msg, err := e.genTokenUpdate() + if err != nil { + return err + } + + // send it to the MDM server + resp, err := e.transport.DoCheckIn(ctx, msg) + if err != nil { + return err + } + defer resp.Body.Close() + + // check for errors + return HTTPErrors(resp) +} + +// DoEnroll enrolls (or re-enrolls) this enrollment into MDM. +// Authenticate and TokenUpdate check-in messages are sent via the +// transport to the MDM server. +func (e *Enrollment) DoEnroll(ctx context.Context) error { + e.enrollM.Lock() + defer e.enrollM.Unlock() + + if e.enrolled { + e.enrolled = false + } + + // generate Authenticate check-in message + auth, err := e.genAuthenticate() + if err != nil { + return err + } + + // send it to the MDM server + authResp, err := e.transport.DoCheckIn(ctx, auth) + if err != nil { + return err + } + + // check for any errors + if err = HTTPErrors(authResp); err != nil { + authResp.Body.Close() + return fmt.Errorf("enrollment authenticate check-in: %w", err) + } + authResp.Body.Close() + + err = e.doTokenUpdate(ctx) + if err != nil { + return err + } + + e.enrolled = true + + return nil +} + +// GetEnrollment returns the enrollment identifier data. +func (e *Enrollment) GetEnrollment() *mdm.Enrollment { + return &e.enrollment +} + +// ID returns the NanoMDM "normalized" enrollment ID. +func (e *Enrollment) ID() string { + // we know we're only dealing with device IDs at this point. + // make that assumption of the UDID for the normalized ID. + return e.enrollment.UDID +} + +// EnrollID returns the NanoMDM enroll ID. +func (e *Enrollment) EnrollID() *mdm.EnrollID { + return &e.enrollID +} + +func (e *Enrollment) NewMDMRequest(ctx context.Context) *mdm.Request { + return &mdm.Request{ + Context: ctx, + EnrollID: e.EnrollID(), + } +} + +// GetPush returns the enrollment push info data. +func (e *Enrollment) GetPush() *mdm.Push { + return &e.push +} + +// DoReportAndFetch sends report to the MDM server. +// Any new command delivered will be in the response. +// The caller is responsible for closing the response body. +func (e *Enrollment) DoReportAndFetch(ctx context.Context, report io.Reader) (*http.Response, error) { + return e.transport.DoReportResultsAndFetchNext(ctx, report) +} + +// genSetBootstrapToken creates an XML Plist SetBootstrapToken check-in message. +func (e *Enrollment) genSetBootstrapToken(token []byte) (io.Reader, error) { + msg := &mdm.SetBootstrapToken{ + Enrollment: e.enrollment, + MessageType: mdm.MessageType{MessageType: "SetBootstrapToken"}, + BootstrapToken: mdm.BootstrapToken{BootstrapToken: make([]byte, base64.StdEncoding.EncodedLen(len(token)))}, + } + base64.StdEncoding.Encode(msg.BootstrapToken.BootstrapToken, token) + return test.PlistReader(msg) +} + +// DoEscrowBootstrapToken sends the Bootstrap Token to the MDM server. +func (e *Enrollment) DoEscrowBootstrapToken(ctx context.Context, token []byte) error { + r, err := e.genSetBootstrapToken(token) + if err != nil { + return err + } + + // send it to the MDM server + resp, err := e.transport.DoCheckIn(ctx, r) + if err != nil { + return err + } + defer resp.Body.Close() + + // check for errors + return HTTPErrors(resp) +} + +// genGetBootstrapToken creates an XML Plist GetBootstrapToken check-in message. +func (e *Enrollment) genGetBootstrapToken() (io.Reader, error) { + msg := &mdm.GetBootstrapToken{ + Enrollment: e.enrollment, + MessageType: mdm.MessageType{MessageType: "GetBootstrapToken"}, + } + return test.PlistReader(msg) +} + +// DoGetBootstrapToken retrieves the Bootstrap Token from the MDM erver. +func (e *Enrollment) DoGetBootstrapToken(ctx context.Context) (*mdm.BootstrapToken, error) { + r, err := e.genGetBootstrapToken() + if err != nil { + return nil, err + } + + // send it to the MDM server + resp, err := e.transport.DoCheckIn(ctx, r) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(io.LimitReader(resp.Body, Limit10KiB)) + if err != nil { + return nil, err + } + + if resp.StatusCode != 200 { + return nil, NewHTTPError(resp, body) + } + + var tok *mdm.BootstrapToken + if len(body) > 0 { + tok = new(mdm.BootstrapToken) + err = plist.Unmarshal(body, tok) + } + return tok, err +} + +func randString(n int) string { + b := make([]byte, n) + rand.Read(b) + return fmt.Sprintf("%x", b) +} diff --git a/test/enrollment/utils.go b/test/enrollment/utils.go new file mode 100644 index 0000000..5b487cd --- /dev/null +++ b/test/enrollment/utils.go @@ -0,0 +1,66 @@ +package enrollment + +import ( + "errors" + "fmt" + "io" + "net/http" + "strconv" +) + +// HTTPError contains the body and status details. +type HTTPError struct { + Body []byte + Status string + StatusCode int +} + +func NewHTTPError(response *http.Response, body []byte) *HTTPError { + if response == nil { + response = &http.Response{} + } + return &HTTPError{ + Body: body, + Status: response.Status, + StatusCode: response.StatusCode, + } +} + +// Error returns strings for HTTP errors that may include body and status. +func (e *HTTPError) Error() (err string) { + err = "HTTP error" + if e == nil { + return + } + if e.Status != "" { + err += ": " + e.Status + } else { + err += ": " + strconv.Itoa(e.StatusCode) + } + if len(e.Body) > 0 { + err += ": " + string(e.Body) + } + return +} + +const Limit10KiB = 10 * 1024 + +// HTTPErrors reports an HTTP error for a non-200 HTTP response. +// The first 10KiB of the body is read for non-200 response. +// For a 200 response nil is returned. +// Caller is responsible for closing response body. +func HTTPErrors(r *http.Response) error { + if r == nil { + return errors.New("nil response") + } + + if r.StatusCode != 200 { + body, err := io.ReadAll(io.LimitReader(r.Body, Limit10KiB)) + if err != nil { + return fmt.Errorf("error reading body of non-200 response: %w", err) + } + return NewHTTPError(r, body) + } + + return nil +} diff --git a/service/certauth/helpers_test.go b/test/helpers.go similarity index 99% rename from service/certauth/helpers_test.go rename to test/helpers.go index 27f38a9..e53d0d4 100644 --- a/service/certauth/helpers_test.go +++ b/test/helpers.go @@ -1,4 +1,4 @@ -package certauth +package test import ( "crypto/rand" diff --git a/test/plist.go b/test/plist.go new file mode 100644 index 0000000..8280f6a --- /dev/null +++ b/test/plist.go @@ -0,0 +1,16 @@ +package test + +import ( + "bytes" + "io" + + "github.com/groob/plist" +) + +// PlistReader encodes v to XML Plist. +func PlistReader(v interface{}) (io.Reader, error) { + buf := new(bytes.Buffer) + enc := plist.NewEncoder(buf) + enc.Indent("\t") + return buf, enc.Encode(v) +} diff --git a/test/protocol/transport.go b/test/protocol/transport.go new file mode 100644 index 0000000..91aa100 --- /dev/null +++ b/test/protocol/transport.go @@ -0,0 +1,164 @@ +// Package protocol implements primitives and interfaces of the base Apple MDM protocol. +package protocol + +import ( + "bytes" + "context" + "crypto" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "io" + "net/http" + + "github.com/smallstep/pkcs7" +) + +const ( + // CheckInMIMEType is the HTTP MIME type of Apple MDM check-in messages. + CheckInMIMEType = "application/x-apple-aspen-mdm-checkin" + + // MDMSignatureHeader is the HTTP header name for the in-message + // signature checking. + MDMSignatureHeader = "Mdm-Signature" +) + +var ( + ErrMissingDeviceIdentity = errors.New("missing device identity") + ErrNilTransport = errors.New("nil transport") +) + +// Doer executes an HTTP request. +type Doer interface { + Do(*http.Request) (*http.Response, error) +} + +type IdentityProvider func(context.Context) (*x509.Certificate, crypto.PrivateKey, error) + +// Transport encapsulates the MDM enrollment underlying MDM transport. +// The MDM channels utilize this transport to communicate with the host. +type Transport struct { + checkInURL string + serverURL string + signMessage bool + provider IdentityProvider + doer Doer +} + +type TransportOption func(*Transport) + +// WithClient configures the HTTP client for this transport. +func WithClient(doer Doer) TransportOption { + return func(t *Transport) { + t.doer = doer + } +} + +// WithIdentityProvider configures the certificate and private key provider for this transport. +func WithIdentityProvider(f IdentityProvider) TransportOption { + return func(t *Transport) { + t.provider = f + } +} + +// WithMDMURLs supplies the ServerURL and CheckInURLs to the transport. +// Per MDM spec checkInURL is optional. +func WithMDMURLs(serverURL, checkInURL string) TransportOption { + return func(t *Transport) { + t.serverURL = serverURL + t.checkInURL = checkInURL + } +} + +// WithSignMessage include the signed message header. +func WithSignMessage() TransportOption { + return func(t *Transport) { + t.signMessage = true + } +} + +func NewTransport(opts ...TransportOption) *Transport { + t := &Transport{ + doer: http.DefaultClient, + } + for _, opt := range opts { + opt(t) + } + return t +} + +// SignMessage generates the CMS detached signature encoded as Base64. +func (t *Transport) SignMessage(ctx context.Context, body []byte) (string, error) { + if t.provider == nil { + return "", ErrMissingDeviceIdentity + } + cert, key, err := t.provider(ctx) + if err != nil { + return "", err + } + if cert == nil || key == nil { + return "", ErrMissingDeviceIdentity + } + sd, err := pkcs7.NewSignedData(body) + if err != nil { + return "", err + } + err = sd.AddSigner(cert, key, pkcs7.SignerInfoConfig{}) + if err != nil { + return "", err + } + sd.Detach() + sig, err := sd.Finish() + return base64.StdEncoding.EncodeToString(sig), err +} + +func (t *Transport) doRequest(ctx context.Context, body io.Reader, checkin bool) (*http.Response, error) { + if t == nil { + return nil, ErrNilTransport + } + var bodyBuf *bytes.Buffer + if t.signMessage { + bodyBuf = new(bytes.Buffer) + if _, err := bodyBuf.ReadFrom(body); err != nil { + return nil, fmt.Errorf("reading body into buffer: %w", err) + } + body = bodyBuf + } + + url := t.serverURL + if checkin && t.checkInURL != "" { + url = t.checkInURL + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, body) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + + if checkin { + req.Header.Set("Content-Type", CheckInMIMEType) + } + + if t.signMessage { + sig, err := t.SignMessage(ctx, bodyBuf.Bytes()) + if err != nil { + return nil, fmt.Errorf("generating mdm-signature: %w", err) + } + req.Header.Set(MDMSignatureHeader, sig) + } + + return t.doer.Do(req) +} + +// DoCheckIn executes a check-in request with body. +// The caller is responsible for closing the response body. +func (t *Transport) DoCheckIn(ctx context.Context, body io.Reader) (*http.Response, error) { + return t.doRequest(ctx, body, true) +} + +// DoReportResultsAndFetchNext executes a report and fetch request with body. +// The caller is responsible for closing the response body. +func (t *Transport) DoReportResultsAndFetchNext(ctx context.Context, body io.Reader) (*http.Response, error) { + return t.doRequest(ctx, body, false) +}