Skip to content

Commit

Permalink
Fix a race condition when browsers would send two requests
Browse files Browse the repository at this point in the history
Some browsers would send two requests, with one landing reliably between
when the OAuth2Listener was closed (and thus, its channel was closed)
and when the http server would be closed.

This change solves this problem by closing the channel after the
request is received and only ever processing a single request.

Other requests will receive responses, but will be silently ignored
  • Loading branch information
punmechanic committed Mar 26, 2024
1 parent 7555b79 commit ad169ef
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 28 deletions.
78 changes: 52 additions & 26 deletions cli/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"net/http"
"net/url"
"strings"
"sync"

"github.com/RobotsAndPencils/go-saml"
"github.com/coreos/go-oidc"
Expand Down Expand Up @@ -63,11 +64,6 @@ type OAuth2CallbackInfo struct {
ErrorDescription string
}

type OAuth2Listener struct {
Socket net.Listener
callbackCh chan OAuth2CallbackInfo
}

func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) {
info := OAuth2CallbackInfo{
Error: r.FormValue("error"),
Expand All @@ -79,44 +75,74 @@ func ParseCallbackRequest(r *http.Request) (OAuth2CallbackInfo, error) {
return info, nil
}

// OAuth2Listener will listen for a single callback request from a web server and return the code if it matched, or an error otherwise.
type OAuth2Listener struct {
socket net.Listener
once sync.Once
callbackCh chan OAuth2CallbackInfo
}

func NewOAuth2Listener(socket net.Listener) OAuth2Listener {
return OAuth2Listener{
Socket: socket,
socket: socket,
// This channel is only ever closed if a successful request is received.
// If the caller closes the socket, then that channel will leak resources.
//
// This probably indicates a problem with the way this struct is constructed:
// The channel should be 'bound' to the lifetime of the socket.
//
// Still, it's a minor resource waste, so we don't care that much.
//
// We can't have a Close() function on this struct, because the caller could call Close() before a request is received,
// which would result on a send on a closed channel - which will cause a panic.
//
// The correct thing to do is probably modify this constructor to instead:
// * Accept a context
// * Return a channel
// * Close the channel when the context expires or it receives a request (whichever is first).
//
// Unfortunately, this is challenging to do while also ensuring that this struct adheres to the http.Handler interface.
//
// The correct solution probably means we change this function signature to
//
// func(context.Context, state string) (http.Handler, <-chan string)
//
// or the less re-usable/testable
//
// func(context.Context, socket net.Listener, state string) <-chan string
//
// The real problem we have with the current layout is that the struct can be put into invalid states, and the easiest way to avoid that
// is to simply not allow state manipulation at all using a closure.
callbackCh: make(chan OAuth2CallbackInfo),
}
}

func (o OAuth2Listener) Close() error {
if o.callbackCh != nil {
func (o *OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// This can sometimes be called multiple times, depending on the browser.
// We will simply ignore any other requests and only serve the first.
o.once.Do(func() {
info, err := ParseCallbackRequest(r)
if err == nil {
// The only errors that might occur would be incorrectly formatted requests, which we will silently drop.
o.callbackCh <- info
}
close(o.callbackCh)
}
if o.Socket != nil {
return o.Socket.Close()
}
return nil
}

func (o OAuth2Listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
info, err := ParseCallbackRequest(r)
if err == nil {
// The only errors that might occur would be incorrectly formatted requests, which we will silently drop.
o.callbackCh <- info
}
})

// This is displayed to the end user in their browser.
// We still want to provide feedback to the end-user.
fmt.Fprintln(w, "You may close this window now.")
}

func (o OAuth2Listener) Listen() error {
err := http.Serve(o.Socket, o)
func (o *OAuth2Listener) Listen() error {
err := http.Serve(o.socket, o)
if errors.Is(err, http.ErrServerClosed) {
return nil
}

return err
}

func (o OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state string) (string, error) {
func (o *OAuth2Listener) WaitForAuthorizationCode(ctx context.Context, state string) (string, error) {
select {
case info := <-o.callbackCh:
if info.Error != "" {
Expand Down Expand Up @@ -218,6 +244,7 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe
if err != nil {
return nil, err
}
defer sock.Close()

_, port, err := net.SplitHostPort(sock.Addr().String())
if err != nil {
Expand All @@ -232,7 +259,6 @@ func (r RedirectionFlowHandler) HandlePendingSession(ctx context.Context, challe
)

listener := NewOAuth2Listener(sock)
defer listener.Close()
// This error can be ignored.
go listener.Listen()

Expand Down
2 changes: 0 additions & 2 deletions cli/oauth2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,13 @@ func Test_OAuth2Listener_WaitForAuthorizationCodeWorksCorrectly(t *testing.T) {
cancel()

assert.Equal(t, expectedCode, code)
assert.NoError(t, listener.Close())
}

func Test_OAuth2Listener_ZeroValueNeverPanics(t *testing.T) {
var listener OAuth2Listener
deadline, _ := context.WithTimeout(context.Background(), 500*time.Millisecond)
_, err := listener.WaitForAuthorizationCode(deadline, "")
assert.ErrorIs(t, context.DeadlineExceeded, err)
assert.NoError(t, listener.Close())
}

// This test is going to be flaky because processes may open ports outside of our control.
Expand Down

0 comments on commit ad169ef

Please sign in to comment.