Skip to content

Commit

Permalink
Fix error on MySQL backend when bsToken null (#63)
Browse files Browse the repository at this point in the history
  • Loading branch information
Calvin Lee authored Feb 28, 2023
1 parent c771e92 commit 0011b0f
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 80 deletions.
10 changes: 7 additions & 3 deletions service/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,19 @@ func CheckinRequest(svc Checkin, r *mdm.Request, bodyBytes []byte) ([]byte, erro
err = fmt.Errorf("setbootstraptoken service: %w", err)
}
case *mdm.GetBootstrapToken:
// https://developer.apple.com/documentation/devicemanagement/get_bootstrap_token
var bsToken *mdm.BootstrapToken
bsToken, err = svc.GetBootstrapToken(r, m)
if err != nil {
err = fmt.Errorf("getbootstraptoken service: %w", err)
break
}
respBytes, err = plist.Marshal(bsToken)
if err != nil {
err = fmt.Errorf("marshal bootstrap token: %w", err)
// If there is no bsToken, return an empty body
if bsToken != nil {
respBytes, err = plist.Marshal(bsToken)
if err != nil {
err = fmt.Errorf("marshal bootstrap token: %w", err)
}
}
case *mdm.DeclarativeManagement:
respBytes, err = svc.DeclarativeManagement(r, m)
Expand Down
8 changes: 5 additions & 3 deletions storage/mysql/bstoken.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package mysql

import (
"database/sql"

"github.com/micromdm/nanomdm/mdm"
)

Expand All @@ -18,17 +20,17 @@ func (s *MySQLStorage) StoreBootstrapToken(r *mdm.Request, msg *mdm.SetBootstrap
}

func (s *MySQLStorage) RetrieveBootstrapToken(r *mdm.Request, _ *mdm.GetBootstrapToken) (*mdm.BootstrapToken, error) {
var tokenB64 string
var tokenB64 sql.NullString
err := s.db.QueryRowContext(
r.Context,
`SELECT bootstrap_token_b64 FROM devices WHERE id = ?;`,
r.ID,
).Scan(&tokenB64)
if err != nil {
if err != nil || !tokenB64.Valid {
return nil, err
}
bsToken := new(mdm.BootstrapToken)
err = bsToken.SetTokenString(tokenB64)
err = bsToken.SetTokenString(tokenB64.String)
if err == nil {
err = s.updateLastSeen(r)
}
Expand Down
60 changes: 60 additions & 0 deletions storage/mysql/bstoken_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
//go:build integration
// +build integration

package mysql

import (
"bytes"
"context"
"encoding/base64"
"testing"

"github.com/micromdm/nanomdm/mdm"
)

func TestBSToken(t *testing.T) {
if *flDSN == "" {
t.Fatal("MySQL DSN flag not provided to test")
}

storage, err := New(WithDSN(*flDSN), WithDeleteCommands())
if err != nil {
t.Fatal(err)
}

var d Device
d, err = enrollTestDevice(storage)
if err != nil {
t.Fatal(err)
}

ctx := context.Background()

t.Run("BSToken nil", func(t *testing.T) {
tok, err := storage.RetrieveBootstrapToken(&mdm.Request{Context: ctx, EnrollID: d.EnrollID()}, nil)
if err != nil {
t.Fatal(err)
}
if tok != nil {
t.Fatal("Token for new device was nonnull")
}
})
t.Run("BSToken set/get", func(t *testing.T) {
data := []byte("test token")
bsToken := mdm.BootstrapToken{BootstrapToken: make([]byte, base64.StdEncoding.EncodedLen(len(data)))}
base64.StdEncoding.Encode(bsToken.BootstrapToken, data)
testReq := &mdm.Request{Context: ctx, EnrollID: d.EnrollID()}
err := storage.StoreBootstrapToken(testReq, &mdm.SetBootstrapToken{BootstrapToken: bsToken})
if err != nil {
t.Fatal(err)
}

tok, err := storage.RetrieveBootstrapToken(testReq, nil)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(bsToken.BootstrapToken, tok.BootstrapToken) {
t.Fatalf("Bootstap tokens disequal after roundtrip: %v!=%v", bsToken, tok)
}
})
}
8 changes: 8 additions & 0 deletions storage/mysql/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
//go:build integration
// +build integration

package mysql

import "flag"

var flDSN = flag.String("dsn", "", "DSN of test MySQL instance")
89 changes: 89 additions & 0 deletions storage/mysql/device_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
//go:build integration
// +build integration

package mysql

import (
"context"
"errors"
"io/ioutil"

"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/storage"
)

type DeviceInterfaces interface {
storage.CheckinStore
}

type Device struct {
UDID string
}

func (d *Device) EnrollID() *mdm.EnrollID {
return &mdm.EnrollID{Type: mdm.Device, ID: d.UDID}
}

func loadAuthMsg() (*mdm.Authenticate, Device, error) {
var d Device
b, err := ioutil.ReadFile("../../mdm/testdata/Authenticate.2.plist")
if err != nil {
return nil, d, err
}
r, err := mdm.DecodeCheckin(b)
if err != nil {
return nil, d, err
}
a, ok := r.(*mdm.Authenticate)
if !ok {
return nil, d, errors.New("not an Authenticate message")
}
d = Device{UDID: a.UDID}
return a, d, nil
}

func loadTokenMsg() (*mdm.TokenUpdate, error) {
b, err := ioutil.ReadFile("../../mdm/testdata/TokenUpdate.2.plist")
if err != nil {
return nil, err
}
r, err := mdm.DecodeCheckin(b)
if err != nil {
return nil, err
}
a, ok := r.(*mdm.TokenUpdate)
if !ok {
return nil, errors.New("not a TokenUpdate message")
}
return a, nil
}

func (d *Device) newMdmReq() *mdm.Request {
return &mdm.Request{
Context: context.Background(),
EnrollID: &mdm.EnrollID{
Type: mdm.Device,
ID: d.UDID,
},
}
}

func enrollTestDevice(storage DeviceInterfaces) (Device, error) {
authMsg, d, err := loadAuthMsg()
if err != nil {
return d, err
}
err = storage.StoreAuthenticate(d.newMdmReq(), authMsg)
if err != nil {
return d, err
}
tokenMsg, err := loadTokenMsg()
if err != nil {
return d, err
}
err = storage.StoreTokenUpdate(d.newMdmReq(), tokenMsg)
if err != nil {
return d, err
}
return d, nil
}
77 changes: 3 additions & 74 deletions storage/mysql/queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,84 +4,13 @@
package mysql

import (
"context"
"errors"
"flag"
"io/ioutil"
"testing"

"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/storage/internal/test"

_ "github.com/go-sql-driver/mysql"
)

var flDSN = flag.String("dsn", "", "DSN of test MySQL instance")

func loadAuthMsg() (*mdm.Authenticate, error) {
b, err := ioutil.ReadFile("../../mdm/testdata/Authenticate.2.plist")
if err != nil {
return nil, err
}
r, err := mdm.DecodeCheckin(b)
if err != nil {
return nil, err
}
a, ok := r.(*mdm.Authenticate)
if !ok {
return nil, errors.New("not an Authenticate message")
}
return a, nil
}

func loadTokenMsg() (*mdm.TokenUpdate, error) {
b, err := ioutil.ReadFile("../../mdm/testdata/TokenUpdate.2.plist")
if err != nil {
return nil, err
}
r, err := mdm.DecodeCheckin(b)
if err != nil {
return nil, err
}
a, ok := r.(*mdm.TokenUpdate)
if !ok {
return nil, errors.New("not a TokenUpdate message")
}
return a, nil
}

const deviceUDID = "66ADE930-5FDF-5EC4-8429-15640684C489"

func newMdmReq() *mdm.Request {
return &mdm.Request{
Context: context.Background(),
EnrollID: &mdm.EnrollID{
Type: mdm.Device,
ID: deviceUDID,
},
}
}

func enrollTestDevice(storage *MySQLStorage) error {
authMsg, err := loadAuthMsg()
if err != nil {
return err
}
err = storage.StoreAuthenticate(newMdmReq(), authMsg)
if err != nil {
return err
}
tokenMsg, err := loadTokenMsg()
if err != nil {
return err
}
err = storage.StoreTokenUpdate(newMdmReq(), tokenMsg)
if err != nil {
return err
}
return nil
}

func TestQueue(t *testing.T) {
if *flDSN == "" {
t.Fatal("MySQL DSN flag not provided to test")
Expand All @@ -92,13 +21,13 @@ func TestQueue(t *testing.T) {
t.Fatal(err)
}

err = enrollTestDevice(storage)
d, err := enrollTestDevice(storage)
if err != nil {
t.Fatal(err)
}

t.Run("WithDeleteCommands()", func(t *testing.T) {
test.TestQueue(t, deviceUDID, storage)
test.TestQueue(t, d.UDID, storage)
})

storage, err = New(WithDSN(*flDSN))
Expand All @@ -107,6 +36,6 @@ func TestQueue(t *testing.T) {
}

t.Run("normal", func(t *testing.T) {
test.TestQueue(t, deviceUDID, storage)
test.TestQueue(t, d.UDID, storage)
})
}

0 comments on commit 0011b0f

Please sign in to comment.