Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

connector: add CAS connector #3836

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions connector/cas/cas.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
// Package cas provides authentication strategies using CAS.
package cas

import (
"fmt"
"log/slog"
"net/http"
"net/url"

"github.com/dexidp/dex/connector"
"github.com/pkg/errors"
"gopkg.in/cas.v2"
)

// Config holds configuration options for CAS logins.
type Config struct {
Portal string `json:"portal"`
Mapping map[string]string `json:"mapping"`
}

// Open returns a strategy for logging in through CAS.
func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
casURL, err := url.Parse(c.Portal)
if err != nil {
return "", fmt.Errorf("failed to parse casURL %q: %v", c.Portal, err)
}
return &casConnector{
client: http.DefaultClient,
portal: casURL,
mapping: c.Mapping,
logger: logger.With(slog.Group("connector", "type", "cas", "id", id)),
pathSuffix: "/" + id,
}, nil
}

var _ connector.CallbackConnector = (*casConnector)(nil)

type casConnector struct {
client *http.Client
portal *url.URL
mapping map[string]string
logger *slog.Logger
pathSuffix string
}

// LoginURL returns the URL to redirect the user to login with.
func (m *casConnector) LoginURL(s connector.Scopes, callbackURL, state string) (string, error) {
u, err := url.Parse(callbackURL)
if err != nil {
return "", fmt.Errorf("failed to parse callbackURL %q: %v", callbackURL, err)
}
u.Path += m.pathSuffix
// context = $callbackURL + $m.pathSuffix
v := u.Query()
v.Set("context", u.String()) // without query params
v.Set("state", state)
u.RawQuery = v.Encode()

loginURL := *m.portal
loginURL.Path += "/login"
// encode service url to context, which used in `HandleCallback`
// service = $callbackURL + $m.pathSuffix ? state=$state & context=$callbackURL + $m.pathSuffix
q := loginURL.Query()
q.Set("service", u.String()) // service = ...?state=...&context=...
loginURL.RawQuery = q.Encode()
return loginURL.String(), nil
}

// HandleCallback parses the request and returns the user's identity
func (m *casConnector) HandleCallback(s connector.Scopes, r *http.Request) (connector.Identity, error) {
state := r.URL.Query().Get("state")
ticket := r.URL.Query().Get("ticket")
// service=context = $callbackURL + $m.pathSuffix
serviceURL, err := url.Parse(r.URL.Query().Get("context"))
if err != nil {
return connector.Identity{}, fmt.Errorf("failed to parse serviceURL %q: %v", r.URL.Query().Get("context"), err)
}
// service = $callbackURL + $m.pathSuffix ? state=$state & context=$callbackURL + $m.pathSuffix
q := serviceURL.Query()
q.Set("context", serviceURL.String())
q.Set("state", state)
serviceURL.RawQuery = q.Encode()

user, err := m.getCasUserByTicket(ticket, serviceURL)
if err != nil {
return connector.Identity{}, err
}
m.logger.Info("cas user", "user", user)
return user, nil
}

func (m *casConnector) getCasUserByTicket(ticket string, serviceURL *url.URL) (connector.Identity, error) {
id := connector.Identity{}
// validate ticket
validator := cas.NewServiceTicketValidator(m.client, m.portal)
resp, err := validator.ValidateTicket(serviceURL, ticket)
if err != nil {
return id, errors.Wrapf(err, "failed to validate ticket via %q with ticket %q", serviceURL, ticket)
}
// fill identity
id.UserID = resp.User
id.Groups = resp.MemberOf
if len(m.mapping) == 0 {
return id, nil
}
if username, ok := m.mapping["username"]; ok {
id.Username = resp.Attributes.Get(username)
if id.Username == "" && username == "userid" {
id.Username = resp.User
}
}
if preferredUsername, ok := m.mapping["preferred_username"]; ok {
id.PreferredUsername = resp.Attributes.Get(preferredUsername)
if id.PreferredUsername == "" && preferredUsername == "userid" {
id.PreferredUsername = resp.User
}
}
if email, ok := m.mapping["email"]; ok {
id.Email = resp.Attributes.Get(email)
if id.Email != "" {
id.EmailVerified = true
}
}
// override memberOf
if groups, ok := m.mapping["groups"]; ok {
id.Groups = resp.Attributes[groups]
}
return id, nil
}
192 changes: 192 additions & 0 deletions connector/cas/cas_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
package cas

import (
"fmt"
"log/slog"
"math/rand"
"net/http"
"net/url"
"os"
"reflect"
"testing"
"time"

"github.com/dexidp/dex/connector"
"github.com/pkg/errors"
"gopkg.in/yaml.v3"
)

type tcase struct {
xml string
mapping map[string]string
id connector.Identity
err string
}

func TestOpen(t *testing.T) {
configSection := `
portal: https://example.org/cas
mapping:
username: name
preferred_username: username
email: email
groups: affiliation
`

var config Config
if err := yaml.Unmarshal([]byte(configSection), &config); err != nil {
t.Errorf("parse config: %v", err)
return
}

conn, err := config.Open("cas", slog.Default())
if err != nil {
t.Errorf("open connector: %v", err)
return
}

casConnector, _ := conn.(*casConnector)
if casConnector.portal.String() != config.Portal {
t.Errorf("expected portal %q, got %q", config.Portal, casConnector.portal.String())
return
}
if !reflect.DeepEqual(casConnector.mapping, config.Mapping) {
t.Errorf("expected mapping %v, got %v", config.Mapping, casConnector.mapping)
return
}
}

func TestCAS(t *testing.T) {
callback := "https://dex.example.org/dex/callback"
casURL, _ := url.Parse("https://example.org/cas")
scope := connector.Scopes{Groups: true}

cases := []tcase{{
xml: "testdata/cas_success.xml",
mapping: map[string]string{
"username": "name",
"preferred_username": "username",
"email": "email",
},
id: connector.Identity{
UserID: "123456",
Username: "jdoe",
PreferredUsername: "jdoe",
Email: "jdoe@example.org",
EmailVerified: true,
Groups: []string{"A", "B"},
ConnectorData: nil,
},
err: "",
}, {
xml: "testdata/cas_success.xml",
mapping: map[string]string{
"username": "name",
"preferred_username": "username",
"email": "email",
"groups": "affiliation",
},
id: connector.Identity{
UserID: "123456",
Username: "jdoe",
PreferredUsername: "jdoe",
Email: "jdoe@example.org",
EmailVerified: true,
Groups: []string{"staff", "faculty"},
ConnectorData: nil,
},
err: "",
}, {
xml: "testdata/cas_failure.xml",
mapping: map[string]string{},
id: connector.Identity{},
err: "INVALID_TICKET: Ticket ST-1856339-aA5Yuvrxzpv8Tau1cYQ7 not recognized",
}}

seed := rand.NewSource(time.Now().UnixNano())
for _, tc := range cases {
ticket := fmt.Sprintf("ST-%d", seed.Int63())
state := fmt.Sprintf("%d", seed.Int63())

conn := &casConnector{
portal: casURL,
mapping: tc.mapping,
logger: slog.Default(),
pathSuffix: "/cas",
client: &http.Client{
Transport: &mockTransport{
ticket: ticket,
file: tc.xml,
},
},
}

// login
login, err := conn.LoginURL(scope, callback, state)
if err != nil {
t.Errorf("get login url: %v", err)
return
}
loginURL, err := url.Parse(login)
if err != nil {
t.Errorf("parse login url: %v", err)
return
}

// cas server
queryService := loginURL.Query().Get("service")
serviceURL, err := url.Parse(queryService)
if err != nil {
t.Errorf("parse service url: %v", err)
return
}
serviceQueryState := serviceURL.Query().Get("state")
if serviceQueryState != state {
t.Errorf("state: expected %#v, got %#v", state, serviceQueryState)
return
}
req, _ := http.NewRequest(http.MethodGet, queryService, nil)
q := req.URL.Query()
q.Set("ticket", ticket)
req.URL.RawQuery = q.Encode()

// validate
id, err := conn.HandleCallback(scope, req)
if err != nil {
if c := errors.Cause(err); c != nil && tc.err != "" && c.Error() == tc.err {
continue
}
t.Errorf("handle callback: %v", err)
return
}
if !reflect.DeepEqual(id, tc.id) {
t.Errorf("identity: expected %#v, got %#v", tc.id, id)
return
}
}
}

type mockTransport struct {
ticket string
file string
}

func (f *mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
file, err := os.Open(f.file)
if err != nil {
return nil, err
}

if ticket := req.URL.Query().Get("ticket"); ticket != f.ticket {
return nil, fmt.Errorf("ticket: expected %#v, got %#v", f.ticket, ticket)
}

return &http.Response{
StatusCode: http.StatusOK,
Body: file,
Header: http.Header{
"Content-Type": []string{"text/xml"},
},
Request: req,
}, nil
}
5 changes: 5 additions & 0 deletions connector/cas/testdata/cas_failure.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
<cas:authenticationFailure code="INVALID_TICKET">
Ticket ST-1856339-aA5Yuvrxzpv8Tau1cYQ7 not recognized
</cas:authenticationFailure>
</cas:serviceResponse>
15 changes: 15 additions & 0 deletions connector/cas/testdata/cas_success.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
<cas:serviceResponse xmlns:cas="http://www.yale.edu/tp/cas">
<cas:authenticationSuccess>
<cas:user>123456</cas:user>
<cas:attributes>
<cas:name>jdoe</cas:name>
<cas:username>jdoe</cas:username>
<cas:email>jdoe@example.org</cas:email>
<cas:affiliation>staff</cas:affiliation>
<cas:affiliation>faculty</cas:affiliation>
<cas:memberOf>A</cas:memberOf>
<cas:memberOf>B</cas:memberOf>
</cas:attributes>
<cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket>
</cas:authenticationSuccess>
</cas:serviceResponse>
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ require (
google.golang.org/api v0.203.0
google.golang.org/grpc v1.67.1
google.golang.org/protobuf v1.35.1
gopkg.in/cas.v2 v2.2.2
gopkg.in/yaml.v3 v3.0.1
)

require (
Expand All @@ -63,6 +65,7 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-openapi/inflect v0.19.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/golang/glog v1.2.2 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/golang/protobuf v1.5.4 // indirect
github.com/google/go-cmp v0.6.0 // indirect
Expand Down Expand Up @@ -101,7 +104,6 @@ require (
google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20241015192408-796eee8c2d53 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

replace github.com/dexidp/dex/api/v2 => ./api/v2
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/glog v1.2.2 h1:1+mZ9upx1Dh6FmUTFR1naJ77miKiXgALjWOZ3NVFPmY=
github.com/golang/glog v1.2.2/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w=
github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
Expand Down Expand Up @@ -394,6 +396,8 @@ google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpAD
google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c=
google.golang.org/protobuf v1.35.1 h1:m3LfL6/Ca+fqnjnlqQXNpFPABW1UD7mjh8KO2mKFytA=
google.golang.org/protobuf v1.35.1/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
gopkg.in/cas.v2 v2.2.2 h1:teLr/JI7VDEQu6qkXKndYac9w5tfy57sWlV+eNYHH+o=
gopkg.in/cas.v2 v2.2.2/go.mod h1:mlmjh4qM/Jm3eSDD0QVr5GaaSW3nOonSUSWkLLvNYnI=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
Loading