Skip to content

Commit

Permalink
feat: add server side sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
andygeiss committed Jan 5, 2025
1 parent 8e39644 commit 257285d
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 17 deletions.
14 changes: 10 additions & 4 deletions security/mux.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,21 @@ import (

// Mux creates a new mux with the liveness check endpoint (/liveness)
// and the readiness check endpoint (/readiness).
func Mux(ctx context.Context, efs embed.FS) *http.ServeMux {
// It also adds an authentication endpoint (/auth/callback) and a login endpoint (/auth/login).
// The mux is returned along with a new ServerSessions instance.
func Mux(ctx context.Context, efs embed.FS) (mux *http.ServeMux, serverSessions *ServerSessions) {
// Create a new mux with liveness and readyness endpoint.
mux := http.NewServeMux()
mux = http.NewServeMux()

// Create an in-memory store for the server sessions.
serverSessions = NewServerSessions()

// Add a file server to the mux.
mux.Handle("GET /", http.FileServerFS(efs))

// Add authentication to the mux.
mux.HandleFunc("GET /auth/callback", OAuthCallback(os.Getenv("HOME_PATH")))
homePath := os.Getenv("HOME_PATH")
mux.HandleFunc("GET /auth/callback", OAuthCallback(homePath, serverSessions))
mux.HandleFunc("GET /auth/login", OAuthLogin)

// Add a liveness check endpoint to the mux.
Expand All @@ -36,5 +42,5 @@ func Mux(ctx context.Context, efs embed.FS) *http.ServeMux {
}
})

return mux
return mux, serverSessions
}
12 changes: 6 additions & 6 deletions security/mux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ var efs embed.FS

func TestServeMux_Is_Not_Nil(t *testing.T) {
ctx := context.Background()
mux := security.Mux(ctx, efs)
mux, _ := security.Mux(ctx, efs)
assert.That(t, "mux must not be nil", mux != nil, true)
}

func TestServeMux_Has_Health_Check(t *testing.T) {
ctx := context.Background()
mux := security.Mux(ctx, efs)
mux, _ := security.Mux(ctx, efs)
req := httptest.NewRequest("GET", "/liveness", nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
Expand All @@ -31,7 +31,7 @@ func TestServeMux_Has_Health_Check(t *testing.T) {

func TestServeMux_Has_Readiness_Check_When_Context_Active(t *testing.T) {
ctx := context.Background()
mux := security.Mux(ctx, efs)
mux, _ := security.Mux(ctx, efs)
req := httptest.NewRequest("GET", "/readiness", nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
Expand All @@ -42,7 +42,7 @@ func TestServeMux_Has_Readiness_Check_When_Context_Active(t *testing.T) {
func TestServeMux_Has_Readiness_Check_When_Context_Canceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
cancel() // Immediately cancel the context.
mux := security.Mux(ctx, efs)
mux, _ := security.Mux(ctx, efs)
req := httptest.NewRequest("GET", "/readiness", nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
Expand All @@ -51,7 +51,7 @@ func TestServeMux_Has_Readiness_Check_When_Context_Canceled(t *testing.T) {

func TestServeMux_Unknown_Route(t *testing.T) {
ctx := context.Background()
mux := security.Mux(ctx, efs)
mux, _ := security.Mux(ctx, efs)
req := httptest.NewRequest("GET", "/unknown", nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
Expand All @@ -60,7 +60,7 @@ func TestServeMux_Unknown_Route(t *testing.T) {

func TestServeMux_Has_Static_Assets(t *testing.T) {
ctx := context.Background()
mux := security.Mux(ctx, efs)
mux, _ := security.Mux(ctx, efs)
req := httptest.NewRequest("GET", "/testdata/server.crt", nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)
Expand Down
19 changes: 12 additions & 7 deletions security/oauth_callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

// OAuthLogin is the handler for the /github/login route.
func OAuthCallback(homePath string) http.HandlerFunc {
func OAuthCallback(homePath string, sessions *ServerSessions) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
code := r.URL.Query().Get("code")
// state := r.URL.Query().Get("state")
Expand All @@ -30,14 +30,19 @@ func OAuthCallback(homePath string) http.HandlerFunc {
return
}

// Set the user's information in the request headers.
r.Header.Set("X-User-Avatar-URL", userInfo.AvatarURL)
r.Header.Set("X-User-Email", userInfo.EMail)
r.Header.Set("X-User-Login", userInfo.Login)
r.Header.Set("X-User-Name", userInfo.Name)
// Update the user's session.
sessionID := sessions.Update(ServerSession{
AvatarURL: userInfo.AvatarURL,
EMail: userInfo.EMail,
Login: userInfo.Login,
Name: userInfo.Name,
})

params := url.Values{}
params.Add("s", sessionID)

// Redirect the user to the home page.
http.Redirect(w, r, homePath, http.StatusSeeOther)
http.Redirect(w, r, homePath+"?"+params.Encode(), http.StatusSeeOther)
}
}

Expand Down
45 changes: 45 additions & 0 deletions security/server_sessions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package security

import (
"encoding/hex"
"sync"
)

// ServerSession is a session for a user.
type ServerSession struct {
AvatarURL string `json:"avatar_url"`
EMail string `json:"email"`
Login string `json:"login"`
Name string `json:"name"`
}

// ServerSessions is a thread-safe map of email addresses to tokens.
type ServerSessions struct {
sessions map[string]ServerSession
mutex sync.RWMutex
}

// NewServerSessions creates a new serverSessions.
func NewServerSessions() *ServerSessions {
return &ServerSessions{
sessions: make(map[string]ServerSession),
}
}

// Update adds a new session to the serverSessions.
func (a *ServerSessions) Update(info ServerSession) (sessionID string) {
a.mutex.Lock()
defer a.mutex.Unlock()
bytes := GenerateKey()
sessionID = hex.EncodeToString(bytes[:])
a.sessions[sessionID] = info
return sessionID
}

// Get returns the session for the given sessionID.
func (a *ServerSessions) Get(sessionID string) (ServerSession, bool) {
a.mutex.RLock()
defer a.mutex.RUnlock()
info, ok := a.sessions[sessionID]
return info, ok
}
23 changes: 23 additions & 0 deletions security/server_sessions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package security_test

import (
"testing"

"github.com/andygeiss/cloud-native-utils/assert"
"github.com/andygeiss/cloud-native-utils/security"
)

func TestServerSessions_Update(t *testing.T) {
sessions := security.NewServerSessions()
token := sessions.Update(security.ServerSession{AvatarURL: "avatar_url", EMail: "email", Login: "login", Name: "name"})
assert.That(t, "token is correct", len(token), 64)
}

func TestServerSessions_Get(t *testing.T) {
sessions := security.NewServerSessions()
session := security.ServerSession{AvatarURL: "avatar_url", EMail: "email", Login: "login", Name: "name"}
token := sessions.Update(session)
current, found := sessions.Get(token)
assert.That(t, "session must be found", found, true)
assert.That(t, "session is correct", current, session)
}

0 comments on commit 257285d

Please sign in to comment.