From 0011b0fc413f8fae2a27d21e9b47c0c9a60498f7 Mon Sep 17 00:00:00 2001 From: Calvin Lee Date: Tue, 28 Feb 2023 09:38:31 -0800 Subject: [PATCH] Fix error on MySQL backend when bsToken null (#63) --- service/request.go | 10 ++-- storage/mysql/bstoken.go | 8 ++-- storage/mysql/bstoken_test.go | 60 +++++++++++++++++++++++ storage/mysql/common_test.go | 8 ++++ storage/mysql/device_test.go | 89 +++++++++++++++++++++++++++++++++++ storage/mysql/queue_test.go | 77 ++---------------------------- 6 files changed, 172 insertions(+), 80 deletions(-) create mode 100644 storage/mysql/bstoken_test.go create mode 100644 storage/mysql/common_test.go create mode 100644 storage/mysql/device_test.go diff --git a/service/request.go b/service/request.go index 332ba6f..6139ebc 100644 --- a/service/request.go +++ b/service/request.go @@ -59,15 +59,19 @@ func CheckinRequest(svc Checkin, r *mdm.Request, bodyBytes []byte) ([]byte, erro err = fmt.Errorf("setbootstraptoken service: %w", err) } case *mdm.GetBootstrapToken: + // https://developer.apple.com/documentation/devicemanagement/get_bootstrap_token var bsToken *mdm.BootstrapToken bsToken, err = svc.GetBootstrapToken(r, m) if err != nil { err = fmt.Errorf("getbootstraptoken service: %w", err) break } - respBytes, err = plist.Marshal(bsToken) - if err != nil { - err = fmt.Errorf("marshal bootstrap token: %w", err) + // If there is no bsToken, return an empty body + if bsToken != nil { + respBytes, err = plist.Marshal(bsToken) + if err != nil { + err = fmt.Errorf("marshal bootstrap token: %w", err) + } } case *mdm.DeclarativeManagement: respBytes, err = svc.DeclarativeManagement(r, m) diff --git a/storage/mysql/bstoken.go b/storage/mysql/bstoken.go index f21e150..0c2ed73 100644 --- a/storage/mysql/bstoken.go +++ b/storage/mysql/bstoken.go @@ -1,6 +1,8 @@ package mysql import ( + "database/sql" + "github.com/micromdm/nanomdm/mdm" ) @@ -18,17 +20,17 @@ func (s *MySQLStorage) StoreBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrap } func (s *MySQLStorage) RetrieveBootstrapToken(r *mdm.Request, _ *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) { - var tokenB64 string + var tokenB64 sql.NullString err := s.db.QueryRowContext( r.Context, `SELECT bootstrap_token_b64 FROM devices WHERE id = ?;`, r.ID, ).Scan(&tokenB64) - if err != nil { + if err != nil || !tokenB64.Valid { return nil, err } bsToken := new(mdm.BootstrapToken) - err = bsToken.SetTokenString(tokenB64) + err = bsToken.SetTokenString(tokenB64.String) if err == nil { err = s.updateLastSeen(r) } diff --git a/storage/mysql/bstoken_test.go b/storage/mysql/bstoken_test.go new file mode 100644 index 0000000..4ccea1e --- /dev/null +++ b/storage/mysql/bstoken_test.go @@ -0,0 +1,60 @@ +//go:build integration +// +build integration + +package mysql + +import ( + "bytes" + "context" + "encoding/base64" + "testing" + + "github.com/micromdm/nanomdm/mdm" +) + +func TestBSToken(t *testing.T) { + if *flDSN == "" { + t.Fatal("MySQL DSN flag not provided to test") + } + + storage, err := New(WithDSN(*flDSN), WithDeleteCommands()) + if err != nil { + t.Fatal(err) + } + + var d Device + d, err = enrollTestDevice(storage) + if err != nil { + t.Fatal(err) + } + + ctx := context.Background() + + t.Run("BSToken nil", func(t *testing.T) { + tok, err := storage.RetrieveBootstrapToken(&mdm.Request{Context: ctx, EnrollID: d.EnrollID()}, nil) + if err != nil { + t.Fatal(err) + } + if tok != nil { + t.Fatal("Token for new device was nonnull") + } + }) + t.Run("BSToken set/get", func(t *testing.T) { + data := []byte("test token") + bsToken := mdm.BootstrapToken{BootstrapToken: make([]byte, base64.StdEncoding.EncodedLen(len(data)))} + base64.StdEncoding.Encode(bsToken.BootstrapToken, data) + testReq := &mdm.Request{Context: ctx, EnrollID: d.EnrollID()} + err := storage.StoreBootstrapToken(testReq, &mdm.SetBootstrapToken{BootstrapToken: bsToken}) + if err != nil { + t.Fatal(err) + } + + tok, err := storage.RetrieveBootstrapToken(testReq, nil) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(bsToken.BootstrapToken, tok.BootstrapToken) { + t.Fatalf("Bootstap tokens disequal after roundtrip: %v!=%v", bsToken, tok) + } + }) +} diff --git a/storage/mysql/common_test.go b/storage/mysql/common_test.go new file mode 100644 index 0000000..84e9a89 --- /dev/null +++ b/storage/mysql/common_test.go @@ -0,0 +1,8 @@ +//go:build integration +// +build integration + +package mysql + +import "flag" + +var flDSN = flag.String("dsn", "", "DSN of test MySQL instance") diff --git a/storage/mysql/device_test.go b/storage/mysql/device_test.go new file mode 100644 index 0000000..a431196 --- /dev/null +++ b/storage/mysql/device_test.go @@ -0,0 +1,89 @@ +//go:build integration +// +build integration + +package mysql + +import ( + "context" + "errors" + "io/ioutil" + + "github.com/micromdm/nanomdm/mdm" + "github.com/micromdm/nanomdm/storage" +) + +type DeviceInterfaces interface { + storage.CheckinStore +} + +type Device struct { + UDID string +} + +func (d *Device) EnrollID() *mdm.EnrollID { + return &mdm.EnrollID{Type: mdm.Device, ID: d.UDID} +} + +func loadAuthMsg() (*mdm.Authenticate, Device, error) { + var d Device + b, err := ioutil.ReadFile("../../mdm/testdata/Authenticate.2.plist") + if err != nil { + return nil, d, err + } + r, err := mdm.DecodeCheckin(b) + if err != nil { + return nil, d, err + } + a, ok := r.(*mdm.Authenticate) + if !ok { + return nil, d, errors.New("not an Authenticate message") + } + d = Device{UDID: a.UDID} + return a, d, nil +} + +func loadTokenMsg() (*mdm.TokenUpdate, error) { + b, err := ioutil.ReadFile("../../mdm/testdata/TokenUpdate.2.plist") + if err != nil { + return nil, err + } + r, err := mdm.DecodeCheckin(b) + if err != nil { + return nil, err + } + a, ok := r.(*mdm.TokenUpdate) + if !ok { + return nil, errors.New("not a TokenUpdate message") + } + return a, nil +} + +func (d *Device) newMdmReq() *mdm.Request { + return &mdm.Request{ + Context: context.Background(), + EnrollID: &mdm.EnrollID{ + Type: mdm.Device, + ID: d.UDID, + }, + } +} + +func enrollTestDevice(storage DeviceInterfaces) (Device, error) { + authMsg, d, err := loadAuthMsg() + if err != nil { + return d, err + } + err = storage.StoreAuthenticate(d.newMdmReq(), authMsg) + if err != nil { + return d, err + } + tokenMsg, err := loadTokenMsg() + if err != nil { + return d, err + } + err = storage.StoreTokenUpdate(d.newMdmReq(), tokenMsg) + if err != nil { + return d, err + } + return d, nil +} diff --git a/storage/mysql/queue_test.go b/storage/mysql/queue_test.go index 49ee203..08cf8d6 100644 --- a/storage/mysql/queue_test.go +++ b/storage/mysql/queue_test.go @@ -4,84 +4,13 @@ package mysql import ( - "context" - "errors" - "flag" - "io/ioutil" "testing" - "github.com/micromdm/nanomdm/mdm" "github.com/micromdm/nanomdm/storage/internal/test" _ "github.com/go-sql-driver/mysql" ) -var flDSN = flag.String("dsn", "", "DSN of test MySQL instance") - -func loadAuthMsg() (*mdm.Authenticate, error) { - b, err := ioutil.ReadFile("../../mdm/testdata/Authenticate.2.plist") - if err != nil { - return nil, err - } - r, err := mdm.DecodeCheckin(b) - if err != nil { - return nil, err - } - a, ok := r.(*mdm.Authenticate) - if !ok { - return nil, errors.New("not an Authenticate message") - } - return a, nil -} - -func loadTokenMsg() (*mdm.TokenUpdate, error) { - b, err := ioutil.ReadFile("../../mdm/testdata/TokenUpdate.2.plist") - if err != nil { - return nil, err - } - r, err := mdm.DecodeCheckin(b) - if err != nil { - return nil, err - } - a, ok := r.(*mdm.TokenUpdate) - if !ok { - return nil, errors.New("not a TokenUpdate message") - } - return a, nil -} - -const deviceUDID = "66ADE930-5FDF-5EC4-8429-15640684C489" - -func newMdmReq() *mdm.Request { - return &mdm.Request{ - Context: context.Background(), - EnrollID: &mdm.EnrollID{ - Type: mdm.Device, - ID: deviceUDID, - }, - } -} - -func enrollTestDevice(storage *MySQLStorage) error { - authMsg, err := loadAuthMsg() - if err != nil { - return err - } - err = storage.StoreAuthenticate(newMdmReq(), authMsg) - if err != nil { - return err - } - tokenMsg, err := loadTokenMsg() - if err != nil { - return err - } - err = storage.StoreTokenUpdate(newMdmReq(), tokenMsg) - if err != nil { - return err - } - return nil -} - func TestQueue(t *testing.T) { if *flDSN == "" { t.Fatal("MySQL DSN flag not provided to test") @@ -92,13 +21,13 @@ func TestQueue(t *testing.T) { t.Fatal(err) } - err = enrollTestDevice(storage) + d, err := enrollTestDevice(storage) if err != nil { t.Fatal(err) } t.Run("WithDeleteCommands()", func(t *testing.T) { - test.TestQueue(t, deviceUDID, storage) + test.TestQueue(t, d.UDID, storage) }) storage, err = New(WithDSN(*flDSN)) @@ -107,6 +36,6 @@ func TestQueue(t *testing.T) { } t.Run("normal", func(t *testing.T) { - test.TestQueue(t, deviceUDID, storage) + test.TestQueue(t, d.UDID, storage) }) }