Skip to content

Commit

Permalink
add authproxy support
Browse files Browse the repository at this point in the history
  • Loading branch information
jessepeterson committed Jul 14, 2023
1 parent 65daeb8 commit 269a37b
Show file tree
Hide file tree
Showing 11 changed files with 217 additions and 3 deletions.
20 changes: 20 additions & 0 deletions cmd/nanomdm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/micromdm/nanomdm/cli"
mdmhttp "github.com/micromdm/nanomdm/http"
httpapi "github.com/micromdm/nanomdm/http/api"
"github.com/micromdm/nanomdm/http/authproxy"
httpmdm "github.com/micromdm/nanomdm/http/mdm"
"github.com/micromdm/nanomdm/log/stdlogfmt"
"github.com/micromdm/nanomdm/push/buford"
Expand All @@ -34,6 +35,8 @@ const (
endpointMDM = "/mdm"
endpointCheckin = "/checkin"

endpointAuthProxy = "/authproxy/"

endpointAPIPushCert = "/v1/pushcert"
endpointAPIPush = "/v1/push/"
endpointAPIEnqueue = "/v1/enqueue/"
Expand Down Expand Up @@ -62,6 +65,7 @@ func main() {
flMigration = flag.Bool("migration", false, "HTTP endpoint for enrollment migrations")
flRetro = flag.Bool("retro", false, "Allow retroactive certificate-authorization association")
flDMURLPfx = flag.String("dm", "", "URL to send Declarative Management requests to")
flAuthProxy = flag.String("auth-proxy", "", "Reverse proxy URL target for MDM-authenticated HTTP requests")
)
flag.Parse()

Expand Down Expand Up @@ -158,6 +162,22 @@ func main() {
}
mux.Handle(endpointCheckin, checkinHandler)
}

if *flAuthProxy != "" {
var authProxyHandler http.Handler
authProxyHandler, err = authproxy.New(*flAuthProxy, logger.With("handler", "authproxy"))
if err != nil {
stdlog.Fatal(err)
}
authProxyHandler = httpmdm.CertWithEnrollmentIDMiddleware(authProxyHandler, certauth.HashCert, mdmStorage, true, logger.With("handler", "with-enrollment-id"))

Check failure on line 172 in cmd/nanomdm/main.go

View workflow job for this annotation

GitHub Actions / Build, test, and format (1.17.x, ubuntu-latest)

undefined: "github.com/micromdm/nanomdm/http/mdm".CertWithEnrollmentIDMiddleware
authProxyHandler = httpmdm.CertVerifyMiddleware(authProxyHandler, verifier, logger.With("handler", "cert-verify"))
if *flCertHeader != "" {
authProxyHandler = httpmdm.CertExtractPEMHeaderMiddleware(authProxyHandler, *flCertHeader, logger.With("handler", "cert-extract"))
} else {
authProxyHandler = httpmdm.CertExtractMdmSignatureMiddleware(authProxyHandler, logger.With("handler", "cert-extract"))
}
mux.Handle(endpointAuthProxy, authProxyHandler)
}
}

if *flAPIKey != "" {
Expand Down
51 changes: 51 additions & 0 deletions http/authproxy/authproxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Package authproxy is a simple reverse proxy for Apple MDM clients.
package authproxy

import (
"net/http"
"net/http/httputil"
"net/url"

mdmhttp "github.com/micromdm/nanomdm/http"
httpmdm "github.com/micromdm/nanomdm/http/mdm"
"github.com/micromdm/nanomdm/log"
"github.com/micromdm/nanomdm/log/ctxlog"
)

const (
EnrollmentIDHeader = "X-Enrollment-ID"
TraceIDHeader = "X-Trace-ID"
)

// New creates a new NanoMDM enrollment authenticating reverse proxy.
// This reverse proxy is mostly the standard httputil proxy. It depends
// on middleware HTTP handlers to enforce authentication and set the
// context value for the enrollment ID.
func New(dest string, logger log.Logger) (*httputil.ReverseProxy, error) {
target, err := url.Parse(dest)
if err != nil {
return nil, err
}
proxy := httputil.NewSingleHostReverseProxy(target)
proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
ctxlog.Logger(r.Context(), logger).Info("err", err)
// use the same error as the standrad reverse proxy
w.WriteHeader(http.StatusBadGateway)
}
dir := proxy.Director
proxy.Director = func(req *http.Request) {
dir(req)
req.Host = target.Host
// save the effort of forwarding this huge header
req.Header.Del("Mdm-Signature")
if id := httpmdm.GetEnrollmentID(req.Context()); id != "" {
req.Header.Set(EnrollmentIDHeader, id)
}
// TODO: this couples us to our specific idea of trace logging
// Perhaps have an optional config for header specificaiton?
if id := mdmhttp.GetTraceID(req.Context()); id != "" {
req.Header.Set(TraceIDHeader, id)
}
}
return proxy, nil
}
6 changes: 6 additions & 0 deletions http/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ func VersionHandler(version string) http.HandlerFunc {

type ctxKeyTraceID struct{}

// GetTraceID returns the trace ID from ctx.
func GetTraceID(ctx context.Context) string {
id, _ := ctx.Value(ctxKeyTraceID{}).(string)
return id
}

// TraceLoggingMiddleware sets up a trace ID in the request context and
// logs HTTP requests.
func TraceLoggingMiddleware(next http.Handler, logger log.Logger, traceID func(*http.Request) string) http.HandlerFunc {
Expand Down
68 changes: 68 additions & 0 deletions http/mdm/mdm_cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ import (
mdmhttp "github.com/micromdm/nanomdm/http"
"github.com/micromdm/nanomdm/log"
"github.com/micromdm/nanomdm/log/ctxlog"
"github.com/micromdm/nanomdm/storage"
)

type contextKeyCert struct{}

var contextEnrollmentID struct{}

// CertExtractPEMHeaderMiddleware extracts the MDM enrollment identity
// certificate from the request into the HTTP request context. It looks
// at the request header which should be a URL-encoded PEM certificate.
Expand Down Expand Up @@ -128,3 +131,68 @@ func CertVerifyMiddleware(next http.Handler, verifier CertVerifier, logger log.L
next.ServeHTTP(w, r)
}
}

// GetEnrollmentID retrieves the MDM enrollment ID from ctx.
func GetEnrollmentID(ctx context.Context) string {
id, _ := ctx.Value(contextEnrollmentID).(string)
return id
}

type HashFn func(*x509.Certificate) string

// WithEnrollmentIDMiddleware tries to associate the enrollment ID to the request context.
// It does this by looking up the certificate on the context, hashing it with
// hasher, looking up the hash in storage, and setting the ID on the context.
//
// The next handler will be called even if cert or ID is not found unless
// enforce is true. This way next is able to use the existence of the ID on
// the context to make its own decisions.
func WithEnrollmentIDMiddleware(next http.Handler, hasher HashFn, store storage.CertAuthRetriever, enforce bool, logger log.Logger) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
cert := GetCert(r.Context())
if cert == nil {
if enforce {
ctxlog.Logger(r.Context(), logger).Info(
"err", "missing certificate",
)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusBadRequest)
return
} else {
ctxlog.Logger(r.Context(), logger).Debug(
"msg", "missing certificate",
)
next.ServeHTTP(w, r)
return
}
}
if store == nil || hasher == nil {
panic("store nor hasher must not be nil")
}
mr, err := store.EnrollmentFromHash(r.Context(), hasher(cert))
if err != nil {
ctxlog.Logger(r.Context(), logger).Info(
"msg", "retreiving enrollment from hash",
"err", err,
)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if mr == nil || mr.ID == "" {
if enforce {
ctxlog.Logger(r.Context(), logger).Info(
"err", "missing enrollment id",
)
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusBadRequest)
return
} else {
ctxlog.Logger(r.Context(), logger).Debug(
"msg", "missing enrollment id",
)
next.ServeHTTP(w, r)
return
}
}
ctx := context.WithValue(r.Context(), contextEnrollmentID, mr.ID)
next.ServeHTTP(w, r.WithContext(ctx))
}
}
7 changes: 4 additions & 3 deletions service/certauth/certauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ func New(next service.CheckinAndCommandService, storage storage.CertAuthStore, o
return certAuth
}

func hashCert(cert *x509.Certificate) string {
// HashCert returns the string representation
func HashCert(cert *x509.Certificate) string {
hashed := sha256.Sum256(cert.Raw)
b := make([]byte, len(hashed))
copy(b, hashed[:])
Expand All @@ -112,7 +113,7 @@ func (s *CertAuth) associateNewEnrollment(r *mdm.Request) error {
return err
}
logger := ctxlog.Logger(r.Context, s.logger)
hash := hashCert(r.Certificate)
hash := HashCert(r.Certificate)
if hasHash, err := s.storage.HasCertHash(r, hash); err != nil {
return err
} else if hasHash {
Expand Down Expand Up @@ -157,7 +158,7 @@ func (s *CertAuth) validateAssociateExistingEnrollment(r *mdm.Request) error {
return err
}
logger := ctxlog.Logger(r.Context, s.logger)
hash := hashCert(r.Certificate)
hash := HashCert(r.Certificate)
if isAssoc, err := s.storage.IsCertHashAssociated(r, hash); err != nil {
return err
} else if isAssoc {
Expand Down
1 change: 1 addition & 0 deletions storage/all.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ type AllStorage interface {
PushCertStore
CommandEnqueuer
CertAuthStore
CertAuthRetriever
StoreMigrator
TokenUpdateTallyStore
}
9 changes: 9 additions & 0 deletions storage/allmulti/certauth.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package allmulti

import (
"context"

"github.com/micromdm/nanomdm/mdm"
"github.com/micromdm/nanomdm/storage"
)
Expand Down Expand Up @@ -32,3 +34,10 @@ func (ms *MultiAllStorage) AssociateCertHash(r *mdm.Request, hash string) error
})
return err
}

func (ms *MultiAllStorage) EnrollmentFromHash(ctx context.Context, hash string) (*mdm.Request, error) {
val, err := ms.execStores(ctx, func(s storage.AllStorage) (interface{}, error) {
return s.EnrollmentFromHash(ctx, hash)
})
return val.(*mdm.Request), err
}
21 changes: 21 additions & 0 deletions storage/file/certauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package file

import (
"bufio"
"context"
"errors"
"os"
"path"
Expand Down Expand Up @@ -68,3 +69,23 @@ func (s *FileStorage) AssociateCertHash(r *mdm.Request, hash string) error {
e := s.newEnrollment(r.ID)
return e.writeFile(CertAuthFilename, []byte(hash))
}

func (s *FileStorage) EnrollmentFromHash(_ context.Context, hash string) (*mdm.Request, error) {
f, err := os.Open(path.Join(s.path, CertAuthAssociationsFilename))
if err != nil {
return nil, err
}
defer f.Close()
scanner := bufio.NewScanner(f)
for scanner.Scan() {
text := scanner.Text()
if strings.Contains(text, hash) {
split := strings.Split(text, ",")
if len(split) < 2 {
return nil, errors.New("hash and enrollment id not present on line")
}
return &mdm.Request{EnrollID: &mdm.EnrollID{ID: split[0]}}, nil
}
}
return nil, nil
}
15 changes: 15 additions & 0 deletions storage/mysql/certauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package mysql

import (
"context"
"database/sql"
"errors"
"strings"

"github.com/micromdm/nanomdm/mdm"
Expand Down Expand Up @@ -49,3 +51,16 @@ UPDATE sha256 = new.sha256;`,
)
return err
}

func (s *MySQLStorage) EnrollmentFromHash(ctx context.Context, hash string) (*mdm.Request, error) {
var id string
err := s.db.QueryRowContext(
ctx,
`SELECT id FROM cert_auth_associations WHERE sha256 = ? LIMIT 1;`,
hash,
).Scan(&id)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return &mdm.Request{EnrollID: &mdm.EnrollID{ID: id}}, err
}
15 changes: 15 additions & 0 deletions storage/pgsql/certauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package pgsql

import (
"context"
"database/sql"
"errors"
"strings"

"github.com/micromdm/nanomdm/mdm"
Expand Down Expand Up @@ -50,3 +52,16 @@ ON CONFLICT ON CONSTRAINT cert_auth_associations_pkey DO UPDATE SET updated_at=n
)
return err
}

func (s *PgSQLStorage) EnrollmentFromHash(ctx context.Context, hash string) (*mdm.Request, error) {
var id string
err := s.db.QueryRowContext(
ctx,
`SELECT id FROM cert_auth_associations WHERE sha256 = $1 LIMIT 1;`,
hash,
).Scan(&id)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return &mdm.Request{EnrollID: &mdm.EnrollID{ID: id}}, err
}
7 changes: 7 additions & 0 deletions storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ type CertAuthStore interface {
AssociateCertHash(r *mdm.Request, hash string) error
}

type CertAuthRetriever interface {
// EnrollmentFromHash retrieves an MDM request from a cert hash.
// Implementations should return a nil pointer if no result is found.
// The ID member ought to be populated when non-nil.
EnrollmentFromHash(ctx context.Context, hash string) (*mdm.Request, error)
}

// StoreMigrator retrieves MDM check-ins
type StoreMigrator interface {
// RetrieveMigrationCheckins sends the (decoded) forms of
Expand Down

0 comments on commit 269a37b

Please sign in to comment.