diff --git a/cmd/nanomdm/main.go b/cmd/nanomdm/main.go index 5ba2c81..5ccae1a 100644 --- a/cmd/nanomdm/main.go +++ b/cmd/nanomdm/main.go @@ -15,6 +15,7 @@ import ( "github.com/micromdm/nanomdm/cli" mdmhttp "github.com/micromdm/nanomdm/http" httpapi "github.com/micromdm/nanomdm/http/api" + "github.com/micromdm/nanomdm/http/authproxy" httpmdm "github.com/micromdm/nanomdm/http/mdm" "github.com/micromdm/nanomdm/log/stdlogfmt" "github.com/micromdm/nanomdm/push/buford" @@ -34,6 +35,8 @@ const ( endpointMDM = "/mdm" endpointCheckin = "/checkin" + endpointAuthProxy = "/authproxy/" + endpointAPIPushCert = "/v1/pushcert" endpointAPIPush = "/v1/push/" endpointAPIEnqueue = "/v1/enqueue/" @@ -62,6 +65,7 @@ func main() { flMigration = flag.Bool("migration", false, "HTTP endpoint for enrollment migrations") flRetro = flag.Bool("retro", false, "Allow retroactive certificate-authorization association") flDMURLPfx = flag.String("dm", "", "URL to send Declarative Management requests to") + flAuthProxy = flag.String("auth-proxy", "", "Reverse proxy URL target for MDM-authenticated HTTP requests") ) flag.Parse() @@ -158,6 +162,22 @@ func main() { } mux.Handle(endpointCheckin, checkinHandler) } + + if *flAuthProxy != "" { + var authProxyHandler http.Handler + authProxyHandler, err = authproxy.New(*flAuthProxy, logger.With("handler", "authproxy")) + if err != nil { + stdlog.Fatal(err) + } + authProxyHandler = httpmdm.CertWithEnrollmentIDMiddleware(authProxyHandler, certauth.HashCert, mdmStorage, true, logger.With("handler", "with-enrollment-id")) + authProxyHandler = httpmdm.CertVerifyMiddleware(authProxyHandler, verifier, logger.With("handler", "cert-verify")) + if *flCertHeader != "" { + authProxyHandler = httpmdm.CertExtractPEMHeaderMiddleware(authProxyHandler, *flCertHeader, logger.With("handler", "cert-extract")) + } else { + authProxyHandler = httpmdm.CertExtractMdmSignatureMiddleware(authProxyHandler, logger.With("handler", "cert-extract")) + } + mux.Handle(endpointAuthProxy, authProxyHandler) + } } if *flAPIKey != "" { diff --git a/http/authproxy/authproxy.go b/http/authproxy/authproxy.go new file mode 100644 index 0000000..4bb8c5b --- /dev/null +++ b/http/authproxy/authproxy.go @@ -0,0 +1,51 @@ +// Package authproxy is a simple reverse proxy for Apple MDM clients. +package authproxy + +import ( + "net/http" + "net/http/httputil" + "net/url" + + mdmhttp "github.com/micromdm/nanomdm/http" + httpmdm "github.com/micromdm/nanomdm/http/mdm" + "github.com/micromdm/nanomdm/log" + "github.com/micromdm/nanomdm/log/ctxlog" +) + +const ( + EnrollmentIDHeader = "X-Enrollment-ID" + TraceIDHeader = "X-Trace-ID" +) + +// New creates a new NanoMDM enrollment authenticating reverse proxy. +// This reverse proxy is mostly the standard httputil proxy. It depends +// on middleware HTTP handlers to enforce authentication and set the +// context value for the enrollment ID. +func New(dest string, logger log.Logger) (*httputil.ReverseProxy, error) { + target, err := url.Parse(dest) + if err != nil { + return nil, err + } + proxy := httputil.NewSingleHostReverseProxy(target) + proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) { + ctxlog.Logger(r.Context(), logger).Info("err", err) + // use the same error as the standrad reverse proxy + w.WriteHeader(http.StatusBadGateway) + } + dir := proxy.Director + proxy.Director = func(req *http.Request) { + dir(req) + req.Host = target.Host + // save the effort of forwarding this huge header + req.Header.Del("Mdm-Signature") + if id := httpmdm.GetEnrollmentID(req.Context()); id != "" { + req.Header.Set(EnrollmentIDHeader, id) + } + // TODO: this couples us to our specific idea of trace logging + // Perhaps have an optional config for header specificaiton? + if id := mdmhttp.GetTraceID(req.Context()); id != "" { + req.Header.Set(TraceIDHeader, id) + } + } + return proxy, nil +} diff --git a/http/http.go b/http/http.go index 9065b0d..4233705 100644 --- a/http/http.go +++ b/http/http.go @@ -51,6 +51,12 @@ func VersionHandler(version string) http.HandlerFunc { type ctxKeyTraceID struct{} +// GetTraceID returns the trace ID from ctx. +func GetTraceID(ctx context.Context) string { + id, _ := ctx.Value(ctxKeyTraceID{}).(string) + return id +} + // TraceLoggingMiddleware sets up a trace ID in the request context and // logs HTTP requests. func TraceLoggingMiddleware(next http.Handler, logger log.Logger, traceID func(*http.Request) string) http.HandlerFunc { diff --git a/http/mdm/mdm_cert.go b/http/mdm/mdm_cert.go index 86e51fb..abfcd7e 100644 --- a/http/mdm/mdm_cert.go +++ b/http/mdm/mdm_cert.go @@ -10,10 +10,13 @@ import ( mdmhttp "github.com/micromdm/nanomdm/http" "github.com/micromdm/nanomdm/log" "github.com/micromdm/nanomdm/log/ctxlog" + "github.com/micromdm/nanomdm/storage" ) type contextKeyCert struct{} +var contextEnrollmentID struct{} + // CertExtractPEMHeaderMiddleware extracts the MDM enrollment identity // certificate from the request into the HTTP request context. It looks // at the request header which should be a URL-encoded PEM certificate. @@ -128,3 +131,68 @@ func CertVerifyMiddleware(next http.Handler, verifier CertVerifier, logger log.L next.ServeHTTP(w, r) } } + +// GetEnrollmentID retrieves the MDM enrollment ID from ctx. +func GetEnrollmentID(ctx context.Context) string { + id, _ := ctx.Value(contextEnrollmentID).(string) + return id +} + +type HashFn func(*x509.Certificate) string + +// WithEnrollmentIDMiddleware tries to associate the enrollment ID to the request context. +// It does this by looking up the certificate on the context, hashing it with +// hasher, looking up the hash in storage, and setting the ID on the context. +// +// The next handler will be called even if cert or ID is not found unless +// enforce is true. This way next is able to use the existence of the ID on +// the context to make its own decisions. +func WithEnrollmentIDMiddleware(next http.Handler, hasher HashFn, store storage.CertAuthRetriever, enforce bool, logger log.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + cert := GetCert(r.Context()) + if cert == nil { + if enforce { + ctxlog.Logger(r.Context(), logger).Info( + "err", "missing certificate", + ) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusBadRequest) + return + } else { + ctxlog.Logger(r.Context(), logger).Debug( + "msg", "missing certificate", + ) + next.ServeHTTP(w, r) + return + } + } + if store == nil || hasher == nil { + panic("store nor hasher must not be nil") + } + mr, err := store.EnrollmentFromHash(r.Context(), hasher(cert)) + if err != nil { + ctxlog.Logger(r.Context(), logger).Info( + "msg", "retreiving enrollment from hash", + "err", err, + ) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return + } + if mr == nil || mr.ID == "" { + if enforce { + ctxlog.Logger(r.Context(), logger).Info( + "err", "missing enrollment id", + ) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusBadRequest) + return + } else { + ctxlog.Logger(r.Context(), logger).Debug( + "msg", "missing enrollment id", + ) + next.ServeHTTP(w, r) + return + } + } + ctx := context.WithValue(r.Context(), contextEnrollmentID, mr.ID) + next.ServeHTTP(w, r.WithContext(ctx)) + } +} diff --git a/service/certauth/certauth.go b/service/certauth/certauth.go index fda00b0..fbe3c84 100644 --- a/service/certauth/certauth.go +++ b/service/certauth/certauth.go @@ -97,7 +97,8 @@ func New(next service.CheckinAndCommandService, storage storage.CertAuthStore, o return certAuth } -func hashCert(cert *x509.Certificate) string { +// HashCert returns the string representation +func HashCert(cert *x509.Certificate) string { hashed := sha256.Sum256(cert.Raw) b := make([]byte, len(hashed)) copy(b, hashed[:]) @@ -112,7 +113,7 @@ func (s *CertAuth) associateNewEnrollment(r *mdm.Request) error { return err } logger := ctxlog.Logger(r.Context, s.logger) - hash := hashCert(r.Certificate) + hash := HashCert(r.Certificate) if hasHash, err := s.storage.HasCertHash(r, hash); err != nil { return err } else if hasHash { @@ -157,7 +158,7 @@ func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error { return err } logger := ctxlog.Logger(r.Context, s.logger) - hash := hashCert(r.Certificate) + hash := HashCert(r.Certificate) if isAssoc, err := s.storage.IsCertHashAssociated(r, hash); err != nil { return err } else if isAssoc { diff --git a/storage/all.go b/storage/all.go index bb69851..10b48c0 100644 --- a/storage/all.go +++ b/storage/all.go @@ -7,6 +7,7 @@ type AllStorage interface { PushCertStore CommandEnqueuer CertAuthStore + CertAuthRetriever StoreMigrator TokenUpdateTallyStore } diff --git a/storage/allmulti/certauth.go b/storage/allmulti/certauth.go index 3593c6c..4221226 100644 --- a/storage/allmulti/certauth.go +++ b/storage/allmulti/certauth.go @@ -1,6 +1,8 @@ package allmulti import ( + "context" + "github.com/micromdm/nanomdm/mdm" "github.com/micromdm/nanomdm/storage" ) @@ -32,3 +34,10 @@ func (ms *MultiAllStorage) AssociateCertHash(r *mdm.Request, hash string) error }) return err } + +func (ms *MultiAllStorage) EnrollmentFromHash(ctx context.Context, hash string) (*mdm.Request, error) { + val, err := ms.execStores(ctx, func(s storage.AllStorage) (interface{}, error) { + return s.EnrollmentFromHash(ctx, hash) + }) + return val.(*mdm.Request), err +} diff --git a/storage/file/certauth.go b/storage/file/certauth.go index 47ee0d6..8a166cc 100644 --- a/storage/file/certauth.go +++ b/storage/file/certauth.go @@ -2,6 +2,7 @@ package file import ( "bufio" + "context" "errors" "os" "path" @@ -68,3 +69,23 @@ func (s *FileStorage) AssociateCertHash(r *mdm.Request, hash string) error { e := s.newEnrollment(r.ID) return e.writeFile(CertAuthFilename, []byte(hash)) } + +func (s *FileStorage) EnrollmentFromHash(_ context.Context, hash string) (*mdm.Request, error) { + f, err := os.Open(path.Join(s.path, CertAuthAssociationsFilename)) + if err != nil { + return nil, err + } + defer f.Close() + scanner := bufio.NewScanner(f) + for scanner.Scan() { + text := scanner.Text() + if strings.Contains(text, hash) { + split := strings.Split(text, ",") + if len(split) < 2 { + return nil, errors.New("hash and enrollment id not present on line") + } + return &mdm.Request{EnrollID: &mdm.EnrollID{ID: split[0]}}, nil + } + } + return nil, nil +} diff --git a/storage/mysql/certauth.go b/storage/mysql/certauth.go index cc64c0e..610e766 100644 --- a/storage/mysql/certauth.go +++ b/storage/mysql/certauth.go @@ -2,6 +2,8 @@ package mysql import ( "context" + "database/sql" + "errors" "strings" "github.com/micromdm/nanomdm/mdm" @@ -49,3 +51,16 @@ UPDATE sha256 = new.sha256;`, ) return err } + +func (s *MySQLStorage) EnrollmentFromHash(ctx context.Context, hash string) (*mdm.Request, error) { + var id string + err := s.db.QueryRowContext( + ctx, + `SELECT id FROM cert_auth_associations WHERE sha256 = ? LIMIT 1;`, + hash, + ).Scan(&id) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return &mdm.Request{EnrollID: &mdm.EnrollID{ID: id}}, err +} diff --git a/storage/pgsql/certauth.go b/storage/pgsql/certauth.go index 1ceec98..1c16fc7 100644 --- a/storage/pgsql/certauth.go +++ b/storage/pgsql/certauth.go @@ -2,6 +2,8 @@ package pgsql import ( "context" + "database/sql" + "errors" "strings" "github.com/micromdm/nanomdm/mdm" @@ -50,3 +52,16 @@ ON CONFLICT ON CONSTRAINT cert_auth_associations_pkey DO UPDATE SET updated_at=n ) return err } + +func (s *PgSQLStorage) EnrollmentFromHash(ctx context.Context, hash string) (*mdm.Request, error) { + var id string + err := s.db.QueryRowContext( + ctx, + `SELECT id FROM cert_auth_associations WHERE sha256 = $1 LIMIT 1;`, + hash, + ).Scan(&id) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return &mdm.Request{EnrollID: &mdm.EnrollID{ID: id}}, err +} diff --git a/storage/storage.go b/storage/storage.go index 22a1976..16914f5 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -65,6 +65,13 @@ type CertAuthStore interface { AssociateCertHash(r *mdm.Request, hash string) error } +type CertAuthRetriever interface { + // EnrollmentFromHash retrieves an MDM request from a cert hash. + // Implementations should return a nil pointer if no result is found. + // The ID member ought to be populated when non-nil. + EnrollmentFromHash(ctx context.Context, hash string) (*mdm.Request, error) +} + // StoreMigrator retrieves MDM check-ins type StoreMigrator interface { // RetrieveMigrationCheckins sends the (decoded) forms of