diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml deleted file mode 100644 index 1512e3ec..00000000 --- a/.github/workflows/go.yml +++ /dev/null @@ -1,24 +0,0 @@ -name: Presubmit -on: - push: - branches: [main] - pull_request: - branches: [main] -jobs: - check: - name: Presubmit checks - runs-on: ubuntu-latest - steps: - - name: Set up Go 1.x - uses: actions/setup-go@v2 - with: - go-version: ^1.13 - - name: Check out code into the Go module directory - uses: actions/checkout@v2 - - name: Get dependencies - run: | - curl -sfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh| sh -s -- -b $(go env GOPATH)/bin v1.24.0 - - name: Lint - run: golangci-lint run - - name: Test - run: go test -v ./... diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 00000000..f1640505 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,18 @@ +name: lint + +on: + push: + branches: [ 'main' ] + pull_request: + branches: [ 'main' ] + +jobs: + golangci: + name: Run golangci-lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: golangci-lint + uses: golangci/golangci-lint-action@v2 + with: + version: v1.52.2 diff --git a/.github/workflows/maint.yml b/.github/workflows/maint.yml deleted file mode 100644 index 1081e7a4..00000000 --- a/.github/workflows/maint.yml +++ /dev/null @@ -1,34 +0,0 @@ -name: Maintainer -on: - workflow_dispatch: - schedule: - - cron: "0 12 * * 0" -jobs: - upgrade_go: - name: Upgrade go.mod - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: actions/setup-go@v2 - with: - go-version: "^1.15.6" - - name: Install goupdate - run: | - ( - cd $(mktemp -d) - go get github.com/crewjam/goupdate - ) - git config --global user.email noreply@github.com - git config --global user.name "Github Actions" - - name: Update go.mod - run: | - go version - go env - $(go env GOPATH)/bin/goupdate -test 'go test ./...' --commit -v - - name: Create Pull Request - uses: peter-evans/create-pull-request@v3 - with: - commit-message: "Update go.mod" - branch: auto/update-go - title: "Update go.mod" - body: "" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 00000000..60f3f8f0 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,26 @@ +name: test + +on: + push: + branches: [ 'main' ] + pull_request: + branches: [ 'main' ] +jobs: + tests: + name: Run tests + runs-on: ubuntu-latest + strategy: + matrix: + go: [ '1.17.x', '1.18.x', '1.19.x'] + steps: + - name: Check out code into the Go module directory + uses: actions/checkout@v2 + - name: Set up Go ${{ matrix.go }} + uses: actions/setup-go@v2 + with: + go-version: ${{ matrix.go }} + - name: Go version + run: go version + - name: Run Go tests + run: | + go test -v ./... diff --git a/.golangci.yml b/.golangci.yml index 1392c902..f93ef23b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -7,41 +7,35 @@ linters: enable: - - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification [fast: true, auto-fix: true] - - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports [fast: true, auto-fix: true] - - gosec # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases [fast: true, auto-fix: false] - - misspell # Finds commonly misspelled English words in comments [fast: true, auto-fix: true] - - deadcode # Finds unused code [fast: true, auto-fix: false] - - golint # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes [fast: true, auto-fix: false] - - unconvert # Remove unnecessary type conversions [fast: true, auto-fix: false] - - disable: - # TODO(ross): fix errors reported by these checkers and enable them - bodyclose # checks whether HTTP response body is closed successfully [fast: false, auto-fix: false] - depguard # Go linter that checks if package imports are in a list of acceptable packages [fast: true, auto-fix: false] - - dupl # Tool for code clone detection [fast: true, auto-fix: false] - errcheck # Inspects source code for security problems [fast: true, auto-fix: false] - - gochecknoglobals # Checks that no globals are present in Go code [fast: true, auto-fix: false] - - gochecknoinits # Checks that no init functions are present in Go code [fast: true, auto-fix: false] - - goconst # Finds repeated strings that could be replaced by a constant [fast: true, auto-fix: false] - gocritic # The most opinionated Go source code linter [fast: true, auto-fix: false] - gocyclo # Computes and checks the cyclomatic complexity of functions [fast: true, auto-fix: false] + - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification [fast: true, auto-fix: true] + - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports [fast: true, auto-fix: true] + - gosec # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases [fast: true, auto-fix: false] - gosimple # Linter for Go source code that specializes in simplifying a code [fast: false, auto-fix: false] - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string [fast: false, auto-fix: false] - ineffassign # Detects when assignments to existing variables are not used [fast: true, auto-fix: false] - - interfacer # Linter that suggests narrower interface types [fast: false, auto-fix: false] - - lll # Reports long lines [fast: true, auto-fix: false] - - maligned # Tool to detect Go structs that would take less memory if their fields were sorted [fast: true, auto-fix: false] + - misspell # Finds commonly misspelled English words in comments [fast: true, auto-fix: true] - nakedret # Finds naked returns in functions greater than a specified function length [fast: true, auto-fix: false] - prealloc # Finds slice declarations that could potentially be preallocated [fast: true, auto-fix: false] - - scopelint # Scopelint checks for unpinned variables in go programs [fast: true, auto-fix: false] + - revive # Golint differs from gofmt. Gofmt reformats Go source code, whereas golint prints out style mistakes [fast: true, auto-fix: false] - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks [fast: false, auto-fix: false] - - structcheck # Finds unused struct fields [fast: true, auto-fix: false] - stylecheck # Stylecheck is a replacement for golint [fast: false, auto-fix: false] - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code [fast: true, auto-fix: false] + - unconvert # Remove unnecessary type conversions [fast: true, auto-fix: false] - unparam # Reports unused function parameters [fast: false, auto-fix: false] - unused # Checks Go code for unused constants, variables, functions and types [fast: false, auto-fix: false] - - varcheck # Finds unused global variables and constants [fast: true, auto-fix: false] + + disable: + # TODO(ross): fix errors reported by these checkers and enable them + - dupl # Tool for code clone detection [fast: true, auto-fix: false] + - gochecknoglobals # Checks that no globals are present in Go code [fast: true, auto-fix: false] + - gochecknoinits # Checks that no init functions are present in Go code [fast: true, auto-fix: false] + - goconst # Finds repeated strings that could be replaced by a constant [fast: true, auto-fix: false] + - lll # Reports long lines [fast: true, auto-fix: false] linters-settings: goimports: local-prefixes: github.com/crewjam/saml diff --git a/README.md b/README.md index 71f24786..c0b98058 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ import ( ) func hello(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Hello, %s!", samlsp.AttributeFromContext(r.Context(), "cn")) + fmt.Fprintf(w, "Hello, %s!", samlsp.AttributeFromContext(r.Context(), "displayName")) } func main() { diff --git a/example/idp/idp.go b/example/idp/idp.go index 6069d379..4e47a56a 100644 --- a/example/idp/idp.go +++ b/example/idp/idp.go @@ -1,3 +1,4 @@ +// Package main contains an example identity provider implementation. package main import ( diff --git a/example/service.go b/example/service.go index c153b65f..5b6ddb27 100644 --- a/example/service.go +++ b/example/service.go @@ -32,7 +32,7 @@ type Link struct { } // CreateLink handles requests to create links -func CreateLink(c web.C, w http.ResponseWriter, r *http.Request) { +func CreateLink(_ web.C, w http.ResponseWriter, r *http.Request) { account := r.Header.Get("X-Remote-User") l := Link{ ShortLink: uniuri.New(), @@ -42,22 +42,20 @@ func CreateLink(c web.C, w http.ResponseWriter, r *http.Request) { links[l.ShortLink] = l fmt.Fprintf(w, "%s\n", l.ShortLink) - return } // ServeLink handles requests to redirect to a link -func ServeLink(c web.C, w http.ResponseWriter, r *http.Request) { +func ServeLink(_ web.C, w http.ResponseWriter, r *http.Request) { l, ok := links[strings.TrimPrefix(r.URL.Path, "/")] if !ok { http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) return } http.Redirect(w, r, l.Target, http.StatusFound) - return } // ListLinks returns a list of the current user's links -func ListLinks(c web.C, w http.ResponseWriter, r *http.Request) { +func ListLinks(_ web.C, w http.ResponseWriter, r *http.Request) { account := r.Header.Get("X-Remote-User") for _, l := range links { if l.Owner == account { @@ -145,14 +143,24 @@ func main() { spURL := *idpMetadataURL spURL.Path = "/services/sp" - http.Post(spURL.String(), "text/xml", bytes.NewReader(spMetadataBuf)) + resp, err := http.Post(spURL.String(), "text/xml", bytes.NewReader(spMetadataBuf)) + + if err != nil { + panic(err) + } + + if err := resp.Body.Close(); err != nil { + panic(err) + } goji.Handle("/saml/*", samlSP) authMux := web.New() authMux.Use(samlSP.RequireAccount) authMux.Get("/whoami", func(w http.ResponseWriter, r *http.Request) { - pretty.Fprintf(w, "%# v", r) + if _, err := pretty.Fprintf(w, "%# v", r); err != nil { + panic(err) + } }) authMux.Post("/", CreateLink) authMux.Get("/", ListLinks) diff --git a/example/trivial/trivial.go b/example/trivial/trivial.go index e8be7cb9..45f46080 100644 --- a/example/trivial/trivial.go +++ b/example/trivial/trivial.go @@ -1,3 +1,4 @@ +// Package main contains an example service provider implementation. package main import ( @@ -6,14 +7,34 @@ import ( "crypto/tls" "crypto/x509" "fmt" + "log" "net/http" "net/url" + "time" "github.com/crewjam/saml/samlsp" ) +var samlMiddleware *samlsp.Middleware + func hello(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Hello, %s!", samlsp.AttributeFromContext(r.Context(), "cn")) + fmt.Fprintf(w, "Hello, %s!", samlsp.AttributeFromContext(r.Context(), "displayName")) +} + +func logout(w http.ResponseWriter, r *http.Request) { + nameID := samlsp.AttributeFromContext(r.Context(), "urn:oasis:names:tc:SAML:attribute:subject-id") + url, err := samlMiddleware.ServiceProvider.MakeRedirectLogoutRequest(nameID, "") + if err != nil { + panic(err) // TODO handle error + } + + err = samlMiddleware.Session.DeleteSession(w, r) + if err != nil { + panic(err) // TODO handle error + } + + w.Header().Add("Location", url.String()) + w.WriteHeader(http.StatusFound) } func main() { @@ -26,30 +47,38 @@ func main() { panic(err) // TODO handle error } - rootURL, _ := url.Parse("http://localhost:8000") - idpMetadataURL, _ := url.Parse("https://samltest.id/saml/idp") - - idpMetadata, err := samlsp.FetchMetadata( - context.Background(), - http.DefaultClient, + idpMetadataURL, err := url.Parse("https://samltest.id/saml/idp") + if err != nil { + panic(err) // TODO handle error + } + idpMetadata, err := samlsp.FetchMetadata(context.Background(), http.DefaultClient, *idpMetadataURL) if err != nil { panic(err) // TODO handle error } - samlSP, err := samlsp.New(samlsp.Options{ - URL: *rootURL, - IDPMetadata: idpMetadata, - Key: keyPair.PrivateKey.(*rsa.PrivateKey), - Certificate: keyPair.Leaf, - SignRequest: true, - }) + rootURL, err := url.Parse("http://localhost:8000") if err != nil { panic(err) // TODO handle error } + samlMiddleware, _ = samlsp.New(samlsp.Options{ + URL: *rootURL, + Key: keyPair.PrivateKey.(*rsa.PrivateKey), + Certificate: keyPair.Leaf, + IDPMetadata: idpMetadata, + SignRequest: true, // some IdP require the SLO request to be signed + }) app := http.HandlerFunc(hello) - http.Handle("/hello", samlSP.RequireAccount(app)) - http.Handle("/saml/", samlSP) - http.ListenAndServe(":8000", nil) + slo := http.HandlerFunc(logout) + + http.Handle("/hello", samlMiddleware.RequireAccount(app)) + http.Handle("/saml/", samlMiddleware) + http.Handle("/logout", slo) + + server := &http.Server{ + Addr: ":8080", + ReadHeaderTimeout: 5 * time.Second, + } + log.Fatal(server.ListenAndServe()) } diff --git a/flate.go b/flate.go new file mode 100644 index 00000000..4d14e780 --- /dev/null +++ b/flate.go @@ -0,0 +1,31 @@ +package saml + +import ( + "compress/flate" + "fmt" + "io" +) + +const flateUncompressLimit = 10 * 1024 * 1024 // 10MB + +func newSaferFlateReader(r io.Reader) io.ReadCloser { + return &saferFlateReader{r: flate.NewReader(r)} +} + +type saferFlateReader struct { + r io.ReadCloser + count int +} + +func (r *saferFlateReader) Read(p []byte) (n int, err error) { + if r.count+len(p) > flateUncompressLimit { + return 0, fmt.Errorf("flate: uncompress limit exceeded (%d bytes)", flateUncompressLimit) + } + n, err = r.r.Read(p) + r.count += n + return n, err +} + +func (r *saferFlateReader) Close() error { + return r.r.Close() +} diff --git a/go.mod b/go.mod index 687d29a5..745c5c2c 100644 --- a/go.mod +++ b/go.mod @@ -1,22 +1,19 @@ module github.com/crewjam/saml -go 1.13 +go 1.16 require ( github.com/beevik/etree v1.1.0 github.com/crewjam/httperr v0.2.0 - github.com/davecgh/go-spew v1.1.1 // indirect - github.com/dchest/uniuri v0.0.0-20200228104902-7aecb25e1fe5 - github.com/form3tech-oss/jwt-go v3.2.2+incompatible - github.com/google/go-cmp v0.5.5 - github.com/jonboulle/clockwork v0.2.2 // indirect - github.com/kr/pretty v0.3.0 - github.com/kr/text v0.2.0 // indirect + github.com/dchest/uniuri v1.2.0 + github.com/golang-jwt/jwt/v4 v4.4.3 + github.com/google/go-cmp v0.5.9 + github.com/kr/pretty v0.3.1 github.com/mattermost/xml-roundtrip-validator v0.1.0 github.com/pkg/errors v0.9.1 // indirect - github.com/russellhaering/goxmldsig v1.1.1 + github.com/russellhaering/goxmldsig v1.3.0 + github.com/stretchr/testify v1.8.1 github.com/zenazn/goji v1.0.1 - golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 - golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect + golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed gotest.tools v2.2.0+incompatible ) diff --git a/go.sum b/go.sum index 80810704..7ab71ea2 100644 --- a/go.sum +++ b/go.sum @@ -6,20 +6,19 @@ github.com/crewjam/httperr v0.2.0/go.mod h1:Jlz+Sg/XqBQhyMjdDiC+GNNRzZTD7x39Gu3p github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dchest/uniuri v0.0.0-20200228104902-7aecb25e1fe5 h1:RAV05c0xOkJ3dZGS0JFybxFKZ2WMLabgx3uXnd7rpGs= -github.com/dchest/uniuri v0.0.0-20200228104902-7aecb25e1fe5/go.mod h1:GgB8SF9nRG+GqaDtLcwJZsQFhcogVCJ79j4EdT0c2V4= -github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= -github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= -github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/jonboulle/clockwork v0.2.0/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= +github.com/dchest/uniuri v1.2.0 h1:koIcOUdrTIivZgSLhHQvKgqdWZq5d7KdMEWF1Ud6+5g= +github.com/dchest/uniuri v1.2.0/go.mod h1:fSzm4SLHzNZvWLvWJew423PhAzkpNQYq+uNLq4kxhkY= +github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU= +github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/jonboulle/clockwork v0.2.2 h1:UOGuzwb1PwsrDAObMuhUnj0p5ULPj8V/xJ7Kx9qUBdQ= github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -33,37 +32,40 @@ github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINE github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= -github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= -github.com/russellhaering/goxmldsig v1.1.0 h1:lK/zeJie2sqG52ZAlPNn1oBBqsIsEKypUUBGpYYF6lk= -github.com/russellhaering/goxmldsig v1.1.0/go.mod h1:QK8GhXPB3+AfuCrfo0oRISa9NfzeCpWmxeGnqEpDF9o= -github.com/russellhaering/goxmldsig v1.1.1 h1:vI0r2osGF1A9PLvsGdPUAGwEIrKa4Pj5sesSBsebIxM= -github.com/russellhaering/goxmldsig v1.1.1/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/russellhaering/goxmldsig v1.3.0 h1:DllIWUgMy0cRUMfGiASiYEa35nsieyD3cigIwLonTPM= +github.com/russellhaering/goxmldsig v1.3.0/go.mod h1:gM4MDENBQf7M+V824SGfyIUVFWydB7n0KkEubVJl+Tw= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/zenazn/goji v1.0.1 h1:4lbD8Mx2h7IvloP7r2C0D6ltZP6Ufip8Hn0wmSK5LR8= github.com/zenazn/goji v1.0.1/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= -golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2 h1:It14KIkyBFYkHkwZ7k45minvA9aorojkyjGk9KJ5B/w= -golang.org/x/crypto v0.0.0-20210322153248-0c34fe9e7dc2/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed h1:YoWVYYAfvQ4ddHv3OKmIvX7NCAhFGTj62VP2l2kfBbA= +golang.org/x/crypto v0.0.0-20220128200615-198e4374d7ed/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= -gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= diff --git a/identity_provider.go b/identity_provider.go index 20258daf..b03dc89b 100644 --- a/identity_provider.go +++ b/identity_provider.go @@ -2,7 +2,6 @@ package saml import ( "bytes" - "compress/flate" "crypto" "crypto/tls" "crypto/x509" @@ -36,7 +35,10 @@ type Session struct { ExpireTime time.Time Index string - NameID string + NameID string + NameIDFormat string + SubjectID string + Groups []string UserName string UserEmail string @@ -94,6 +96,7 @@ type AssertionMaker interface { // and password). type IdentityProvider struct { Key crypto.PrivateKey + Signer crypto.Signer Logger logger.Interface Certificate *x509.Certificate Intermediates []*x509.Certificate @@ -131,13 +134,21 @@ func (idp *IdentityProvider) Metadata() *EntityDescriptor { { Use: "signing", KeyInfo: KeyInfo{ - Certificate: certStr, + X509Data: X509Data{ + X509Certificates: []X509Certificate{ + {Data: certStr}, + }, + }, }, }, { Use: "encryption", KeyInfo: KeyInfo{ - Certificate: certStr, + X509Data: X509Data{ + X509Certificates: []X509Certificate{ + {Data: certStr}, + }, + }, }, EncryptionMethods: []EncryptionMethod{ {Algorithm: "http://www.w3.org/2001/04/xmlenc#aes128-cbc"}, @@ -186,10 +197,13 @@ func (idp *IdentityProvider) Handler() http.Handler { } // ServeMetadata is an http.HandlerFunc that serves the IDP metadata -func (idp *IdentityProvider) ServeMetadata(w http.ResponseWriter, r *http.Request) { +func (idp *IdentityProvider) ServeMetadata(w http.ResponseWriter, _ *http.Request) { buf, _ := xml.MarshalIndent(idp.Metadata(), "", " ") w.Header().Set("Content-Type", "application/samlmetadata+xml") - w.Write(buf) + if _, err := w.Write(buf); err != nil { + idp.Logger.Printf("ERROR: %s", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + } } // ServeSSO handles SAML auth requests. @@ -352,7 +366,7 @@ func NewIdpAuthnRequest(idp *IdentityProvider, r *http.Request) (*IdpAuthnReques if err != nil { return nil, fmt.Errorf("cannot decode request: %s", err) } - req.RequestBuffer, err = ioutil.ReadAll(flate.NewReader(bytes.NewReader(compressedRequest))) + req.RequestBuffer, err = ioutil.ReadAll(newSaferFlateReader(bytes.NewReader(compressedRequest))) if err != nil { return nil, fmt.Errorf("cannot decompress request: %s", err) } @@ -706,9 +720,7 @@ func (DefaultAssertionMaker) MakeAssertion(req *IdpAuthnRequest, session *Sessio }) } - for _, ca := range session.CustomAttributes { - attributes = append(attributes, ca) - } + attributes = append(attributes, session.CustomAttributes...) if len(session.Groups) != 0 { groupMemberAttributeValues := []AttributeValue{} @@ -726,6 +738,19 @@ func (DefaultAssertionMaker) MakeAssertion(req *IdpAuthnRequest, session *Sessio }) } + if session.SubjectID != "" { + attributes = append(attributes, Attribute{ + Name: "urn:oasis:names:tc:SAML:attribute:subject-id", + NameFormat: "urn:oasis:names:tc:SAML:2.0:attrname-format:uri", + Values: []AttributeValue{ + { + Type: "xs:string", + Value: session.SubjectID, + }, + }, + }) + } + // allow for some clock skew in the validity period using the // issuer's apparent clock. notBefore := req.Now.Add(-1 * MaxClockSkew) @@ -735,6 +760,12 @@ func (DefaultAssertionMaker) MakeAssertion(req *IdpAuthnRequest, session *Sessio notOnOrAfterAfter = notBefore.Add(MaxIssueDelay) } + nameIDFormat := "urn:oasis:names:tc:SAML:2.0:nameid-format:transient" + + if session.NameIDFormat != "" { + nameIDFormat = session.NameIDFormat + } + req.Assertion = &Assertion{ ID: fmt.Sprintf("id-%x", randomBytes(20)), IssueInstant: TimeNow(), @@ -745,7 +776,7 @@ func (DefaultAssertionMaker) MakeAssertion(req *IdpAuthnRequest, session *Sessio }, Subject: &Subject{ NameID: &NameID{ - Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:transient", + Format: nameIDFormat, NameQualifier: req.IDP.Metadata().EntityID, SPNameQualifier: req.ServiceProviderMetadata.EntityID, Value: session.NameID, @@ -801,24 +832,8 @@ const canonicalizerPrefixList = "" // MakeAssertionEl sets `AssertionEl` to a signed, possibly encrypted, version of `Assertion`. func (req *IdpAuthnRequest) MakeAssertionEl() error { - keyPair := tls.Certificate{ - Certificate: [][]byte{req.IDP.Certificate.Raw}, - PrivateKey: req.IDP.Key, - Leaf: req.IDP.Certificate, - } - for _, cert := range req.IDP.Intermediates { - keyPair.Certificate = append(keyPair.Certificate, cert.Raw) - } - keyStore := dsig.TLSCertKeyStore(keyPair) - - signatureMethod := req.IDP.SignatureMethod - if signatureMethod == "" { - signatureMethod = dsig.RSASHA1SignatureMethod - } - - signingContext := dsig.NewDefaultSigningContext(keyStore) - signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) - if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { + signingContext, err := req.signingContext() + if err != nil { return err } @@ -854,7 +869,7 @@ func (req *IdpAuthnRequest) MakeAssertionEl() error { encryptor := xmlenc.OAEP() encryptor.BlockCipher = xmlenc.AES128CBC encryptor.DigestMethod = &xmlenc.SHA1 - encryptedDataEl, err := encryptor.Encrypt(certBuf, signedAssertionBuf) + encryptedDataEl, err := encryptor.Encrypt(certBuf, signedAssertionBuf, nil) if err != nil { return err } @@ -867,12 +882,23 @@ func (req *IdpAuthnRequest) MakeAssertionEl() error { return nil } -// WriteResponse writes the `Response` to the http.ResponseWriter. If -// `Response` is not already set, it calls MakeResponse to produce it. -func (req *IdpAuthnRequest) WriteResponse(w http.ResponseWriter) error { +// IdpAuthnRequestForm contans HTML form information to be submitted to the +// SAML HTTP POST binding ACS. +type IdpAuthnRequestForm struct { + URL string + SAMLResponse string + RelayState string +} + +// PostBinding creates the HTTP POST form information for this +// `IdpAuthnRequest`. If `Response` is not already set, it calls MakeResponse +// to produce it. +func (req *IdpAuthnRequest) PostBinding() (IdpAuthnRequestForm, error) { + var form IdpAuthnRequestForm + if req.ResponseEl == nil { if err := req.MakeResponse(); err != nil { - return err + return form, err } } @@ -880,45 +906,48 @@ func (req *IdpAuthnRequest) WriteResponse(w http.ResponseWriter) error { doc.SetRoot(req.ResponseEl) responseBuf, err := doc.WriteToBytes() if err != nil { - return err + return form, err } - // the only supported binding is the HTTP-POST binding - switch req.ACSEndpoint.Binding { - case HTTPPostBinding: - tmpl := template.Must(template.New("saml-post-form").Parse(`` + - `
` + - `` + - `` + - ``)) - data := struct { - URL string - SAMLResponse string - RelayState string - }{ - URL: req.ACSEndpoint.Location, - SAMLResponse: base64.StdEncoding.EncodeToString(responseBuf), - RelayState: req.RelayState, - } - - buf := bytes.NewBuffer(nil) - if err := tmpl.Execute(buf, data); err != nil { - return err - } - if _, err := io.Copy(w, buf); err != nil { - return err - } - return nil - - default: - return fmt.Errorf("%s: unsupported binding %s", + if req.ACSEndpoint.Binding != HTTPPostBinding { + return form, fmt.Errorf("%s: unsupported binding %s", req.ServiceProviderMetadata.EntityID, req.ACSEndpoint.Binding) } + + form.URL = req.ACSEndpoint.Location + form.SAMLResponse = base64.StdEncoding.EncodeToString(responseBuf) + form.RelayState = req.RelayState + + return form, nil +} + +// WriteResponse writes the `Response` to the http.ResponseWriter. If +// `Response` is not already set, it calls MakeResponse to produce it. +func (req *IdpAuthnRequest) WriteResponse(w http.ResponseWriter) error { + form, err := req.PostBinding() + if err != nil { + return err + } + + tmpl := template.Must(template.New("saml-post-form").Parse(`` + + `` + + `` + + `` + + ``)) + + buf := bytes.NewBuffer(nil) + if err := tmpl.Execute(buf, form); err != nil { + return err + } + if _, err := io.Copy(w, buf); err != nil { + return err + } + return nil } // getSPEncryptionCert returns the certificate which we can use to encrypt things @@ -927,7 +956,7 @@ func (req *IdpAuthnRequest) getSPEncryptionCert() (*x509.Certificate, error) { certStr := "" for _, keyDescriptor := range req.SPSSODescriptor.KeyDescriptors { if keyDescriptor.Use == "encryption" { - certStr = keyDescriptor.KeyInfo.Certificate + certStr = keyDescriptor.KeyInfo.X509Data.X509Certificates[0].Data break } } @@ -936,8 +965,8 @@ func (req *IdpAuthnRequest) getSPEncryptionCert() (*x509.Certificate, error) { // non-empty cert we find. if certStr == "" { for _, keyDescriptor := range req.SPSSODescriptor.KeyDescriptors { - if keyDescriptor.Use == "" && keyDescriptor.KeyInfo.Certificate != "" { - certStr = keyDescriptor.KeyInfo.Certificate + if keyDescriptor.Use == "" && len(keyDescriptor.KeyInfo.X509Data.X509Certificates) != 0 && keyDescriptor.KeyInfo.X509Data.X509Certificates[0].Data != "" { + certStr = keyDescriptor.KeyInfo.X509Data.X509Certificates[0].Data break } } @@ -1005,24 +1034,8 @@ func (req *IdpAuthnRequest) MakeResponse() error { // Sign the response element (we've already signed the Assertion element) { - keyPair := tls.Certificate{ - Certificate: [][]byte{req.IDP.Certificate.Raw}, - PrivateKey: req.IDP.Key, - Leaf: req.IDP.Certificate, - } - for _, cert := range req.IDP.Intermediates { - keyPair.Certificate = append(keyPair.Certificate, cert.Raw) - } - keyStore := dsig.TLSCertKeyStore(keyPair) - - signatureMethod := req.IDP.SignatureMethod - if signatureMethod == "" { - signatureMethod = dsig.RSASHA1SignatureMethod - } - - signingContext := dsig.NewDefaultSigningContext(keyStore) - signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) - if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { + signingContext, err := req.signingContext() + if err != nil { return err } @@ -1040,3 +1053,44 @@ func (req *IdpAuthnRequest) MakeResponse() error { req.ResponseEl = responseEl return nil } + +// signingContext will create a signing context for the request. +func (req *IdpAuthnRequest) signingContext() (*dsig.SigningContext, error) { + // Create a cert chain based off of the IDP cert and its intermediates. + certificates := [][]byte{req.IDP.Certificate.Raw} + for _, cert := range req.IDP.Intermediates { + certificates = append(certificates, cert.Raw) + } + + var signingContext *dsig.SigningContext + var err error + // If signer is set, use it instead of the private key. + if req.IDP.Signer != nil { + signingContext, err = dsig.NewSigningContext(req.IDP.Signer, certificates) + if err != nil { + return nil, err + } + } else { + keyPair := tls.Certificate{ + Certificate: certificates, + PrivateKey: req.IDP.Key, + Leaf: req.IDP.Certificate, + } + keyStore := dsig.TLSCertKeyStore(keyPair) + + signingContext = dsig.NewDefaultSigningContext(keyStore) + } + + // Default to using SHA1 if the signature method isn't set. + signatureMethod := req.IDP.SignatureMethod + if signatureMethod == "" { + signatureMethod = dsig.RSASHA1SignatureMethod + } + + signingContext.Canonicalizer = dsig.MakeC14N10ExclusiveCanonicalizerWithPrefixList(canonicalizerPrefixList) + if err := signingContext.SetSignatureMethod(signatureMethod); err != nil { + return nil, err + } + + return signingContext, nil +} diff --git a/identity_provider_go116_test.go b/identity_provider_go116_test.go new file mode 100644 index 00000000..6d4a0a53 --- /dev/null +++ b/identity_provider_go116_test.go @@ -0,0 +1,57 @@ +//go:build !go1.17 +// +build !go1.17 + +package saml + +import ( + "bytes" + "compress/flate" + "encoding/base64" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "gotest.tools/assert" + is "gotest.tools/assert/cmp" +) + +func TestIDPHTTPCanHandleSSORequest(t *testing.T) { + test := NewIdentityProviderTest(t, applyKey) + w := httptest.NewRecorder() + + const validRequest = `lJJBayoxFIX%2FypC9JhnU5wszAz7lgWCLaNtFd5fMbQ1MkmnunVb%2FfUfbUqEgdhs%2BTr5zkmLW8S5s8KVD4mzvm0Cl6FIwEciRCeCRDFuznd2sTD5Upk2Ro42NyGZEmNjFMI%2BBOo9pi%2BnVWbzfrEqxY27JSEntEPfg2waHNnpJ4JtcgiWRLfoLXYBjwDfu6p%2B8JIoiWy5K4eqBUipXIzVRUwXKKtRK53qkJ3qqQVuNPUjU4TIQQ%2BBS5EqPBzofKH2ntBn%2FMervo8jWnyX%2BuVC78FwKkT1gopNKX1JUxSklXTMIfM0gsv8xeeDL%2BPGk7%2FF0Qg0GdnwQ1cW5PDLUwFDID6uquO1Dlot1bJw9%2FPLRmia%2BzRMCYyk4dSiq6205QSDXOxfy3KAq5Pkvqt4DAAD%2F%2Fw%3D%3D` + + r, _ := http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&"+ + "SAMLRequest="+validRequest, nil) + test.IDP.Handler().ServeHTTP(w, r) + assert.Check(t, is.Equal(http.StatusOK, w.Code)) + + // rejects requests that are invalid + w = httptest.NewRecorder() + r, _ = http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&"+ + "SAMLRequest=PEF1dGhuUmVxdWVzdA%3D%3D", nil) + test.IDP.Handler().ServeHTTP(w, r) + assert.Check(t, is.Equal(http.StatusBadRequest, w.Code)) + + // rejects requests that contain malformed XML + { + a, _ := url.QueryUnescape(validRequest) + b, _ := base64.StdEncoding.DecodeString(a) + c, _ := ioutil.ReadAll(flate.NewReader(bytes.NewReader(b))) + d := bytes.Replace(c, []byte("