Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix XSS via missing Binding syntax validation #34

Closed
wants to merge 11 commits into from
Closed
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ jobs:
- name: golangci-lint
uses: golangci/golangci-lint-action@v2
with:
version: v1.46.2
version: v1.54.2
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ import (
)

func hello(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hello, %s!", samlsp.AttributeFromContext(r.Context(), "cn"))
fmt.Fprintf(w, "Hello, %s!", samlsp.AttributeFromContext(r.Context(), "displayName"))
}

func main() {
Expand Down
7 changes: 6 additions & 1 deletion example/trivial/trivial.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"log"
"net/http"
"net/url"
"time"

"github.com/crewjam/saml/samlsp"
)
Expand Down Expand Up @@ -72,5 +73,9 @@ func main() {
http.Handle("/hello", samlMiddleware.RequireAccount(app))
http.Handle("/saml/", samlMiddleware)
http.Handle("/logout", slo)
log.Fatal(http.ListenAndServe(":8000", nil))
server := &http.Server{
Addr: ":8000",
ReadHeaderTimeout: 3 * time.Second,
}
log.Fatal(server.ListenAndServe())
}
31 changes: 31 additions & 0 deletions flate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package saml

import (
"compress/flate"
"fmt"
"io"
)

const flateUncompressLimit = 10 * 1024 * 1024 // 10MB

func newSaferFlateReader(r io.Reader) io.ReadCloser {
return &saferFlateReader{r: flate.NewReader(r)}
}

type saferFlateReader struct {
r io.ReadCloser
count int
}

func (r *saferFlateReader) Read(p []byte) (n int, err error) {
if r.count+len(p) > flateUncompressLimit {
return 0, fmt.Errorf("flate: uncompress limit exceeded (%d bytes)", flateUncompressLimit)
}
n, err = r.r.Read(p)
r.count += n
return n, err
}

func (r *saferFlateReader) Close() error {
return r.r.Close()
}
5 changes: 3 additions & 2 deletions identity_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package saml

import (
"bytes"
"compress/flate"
"crypto"
"crypto/tls"
"crypto/x509"
Expand Down Expand Up @@ -130,6 +129,8 @@ func (idp *IdentityProvider) Metadata() *EntityDescriptor {
SSODescriptor: SSODescriptor{
RoleDescriptor: RoleDescriptor{
ProtocolSupportEnumeration: "urn:oasis:names:tc:SAML:2.0:protocol",
CacheDuration: validDuration,
ValidUntil: TimeNow().Add(validDuration),
KeyDescriptors: []KeyDescriptor{
{
Use: "signing",
Expand Down Expand Up @@ -363,7 +364,7 @@ func NewIdpAuthnRequest(idp *IdentityProvider, r *http.Request) (*IdpAuthnReques
if err != nil {
return nil, fmt.Errorf("cannot decode request: %s", err)
}
req.RequestBuffer, err = ioutil.ReadAll(flate.NewReader(bytes.NewReader(compressedRequest)))
req.RequestBuffer, err = ioutil.ReadAll(newSaferFlateReader(bytes.NewReader(compressedRequest)))
if err != nil {
return nil, fmt.Errorf("cannot decompress request: %s", err)
}
Expand Down
33 changes: 32 additions & 1 deletion identity_provider_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package saml

import (
"bytes"
"compress/flate"
"crypto"
"crypto/rsa"
"crypto/x509"
Expand Down Expand Up @@ -147,6 +149,8 @@ func TestIDPCanProduceMetadata(t *testing.T) {
{
SSODescriptor: SSODescriptor{
RoleDescriptor: RoleDescriptor{
ValidUntil: TimeNow().Add(DefaultValidDuration),
CacheDuration: DefaultValidDuration,
ProtocolSupportEnumeration: "urn:oasis:names:tc:SAML:2.0:protocol",
KeyDescriptors: []KeyDescriptor{
{
Expand Down Expand Up @@ -205,7 +209,8 @@ func TestIDPHTTPCanHandleMetadataRequest(t *testing.T) {
test.IDP.Handler().ServeHTTP(w, r)
assert.Check(t, is.Equal(http.StatusOK, w.Code))
assert.Check(t, is.Equal("application/samlmetadata+xml", w.Header().Get("Content-type")))
assert.Check(t, strings.HasPrefix(string(w.Body.Bytes()), "<EntityDescriptor"),
body := string(w.Body.Bytes())
assert.Check(t, strings.HasPrefix(body, "<EntityDescriptor"),
string(w.Body.Bytes()))
}

Expand Down Expand Up @@ -1013,3 +1018,29 @@ func TestIDPNoDestination(t *testing.T) {
err = req.MakeResponse()
assert.Check(t, err)
}

func TestIDPRejectDecompressionBomb(t *testing.T) {
test := NewIdentifyProviderTest(t)
test.IDP.SessionProvider = &mockSessionProvider{
GetSessionFunc: func(w http.ResponseWriter, r *http.Request, req *IdpAuthnRequest) *Session {
fmt.Fprintf(w, "RelayState: %s\nSAMLRequest: %s",
req.RelayState, req.RequestBuffer)
return nil
},
}

//w := httptest.NewRecorder()

data := bytes.Repeat([]byte("a"), 768*1024*1024)
var compressed bytes.Buffer
w, _ := flate.NewWriter(&compressed, flate.BestCompression)
w.Write(data)
w.Close()
encoded := base64.StdEncoding.EncodeToString(compressed.Bytes())

r, _ := http.NewRequest("GET", "/dontcare?"+url.Values{
"SAMLRequest": {encoded},
}.Encode(), nil)
_, err := NewIdpAuthnRequest(&test.IDP, r)
assert.Error(t, err, "cannot decompress request: flate: uncompress limit exceeded (10485760 bytes)")
}
Loading
Loading