Skip to content

Commit

Permalink
Merge pull request #68 from k-capehart/jwt-flow
Browse files Browse the repository at this point in the history
jwt flow
  • Loading branch information
k-capehart authored Nov 4, 2024
2 parents 3496b09 + 5bdcded commit 1b0a3c3
Show file tree
Hide file tree
Showing 8 changed files with 241 additions and 7 deletions.
24 changes: 21 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ A REST API wrapper for interacting with Salesforce using the Go programming lang
- Read the [Golang documentation](https://go.dev/doc/)

## Table of Contents

- [Installation](#installation)
- [Types](#types)
- [Authentication](#authentication)
Expand Down Expand Up @@ -44,6 +45,8 @@ type Creds struct {
SecurityToken string
ConsumerKey string
ConsumerSecret string
ConsumerRSAPem string
AccessToken string
}

type SalesforceResults struct {
Expand Down Expand Up @@ -118,6 +121,20 @@ if err != nil {
}
```

[JWT Bearer Flow](https://help.salesforce.com/s/articleView?id=sf.remoteaccess_oauth_jwt_flow.htm&type=5)

```go
sf, err := salesforce.Init(salesforce.Creds{
Domain: DOMAIN,
Username: USERNAME,
ConsumerKey: CONSUMER_KEY,
ConsumerRSAPem: CONSUMER_RSA_PEM,
})
if err != nil {
panic(err)
}
```

Authenticate with an Access Token

- Implement your own OAuth flow and use the resulting `access_token` from the response to initialize go-salesforce
Expand Down Expand Up @@ -225,6 +242,7 @@ SELECT Id, Account.Name FROM Contact
```

#### Corresponding Go Structs

To effectively handle the data returned by this query, define your Go structs as follows:

```go
Expand Down Expand Up @@ -362,8 +380,8 @@ Insert, Update, Upsert, or Delete collections of records
- Perform operations in batches of up to 200 records at a time
- Consider making a Bulk request for very large operations
- Partial successes are enabled
- If a record fails then successes are still committed to the database
- Will return an instance of `SalesforceResults` which contains information on each affected record and whether DML errors were encountered
- If a record fails then successes are still committed to the database
- Will return an instance of `SalesforceResults` which contains information on each affected record and whether DML errors were encountered

### InsertCollection

Expand Down Expand Up @@ -509,7 +527,7 @@ Make numerous 'subrequests' contained within a single 'composite request', reduc
- So if batch size is 1, then max number of records to be included in request is 25
- If batch size is 200, then max is 5000
- Can optionally allow partial successes by setting allOrNone parameter
- If true, then successes are still committed to the database even if a record fails
- If true, then successes are still committed to the database even if a record fails
- Will return an instance of SalesforceResults which contains information on each affected record and whether DML errors were encountered

### InsertComposite
Expand Down
54 changes: 54 additions & 0 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@ package salesforce
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"

"github.com/golang-jwt/jwt/v5"
)

type authentication struct {
Expand All @@ -28,13 +33,17 @@ type Creds struct {
SecurityToken string
ConsumerKey string
ConsumerSecret string
ConsumerRSAPem string
AccessToken string
}

const JwtExpirationTime = 5 * time.Minute

const (
grantTypeUsernamePassword = "password"
grantTypeClientCredentials = "client_credentials"
grantTypeAccessToken = "access_token"
grantTypeJWT = "urn:ietf:params:oauth:grant-type:jwt-bearer"
)

func validateAuth(sf Salesforce) error {
Expand Down Expand Up @@ -80,6 +89,14 @@ func refreshSession(auth *authentication) error {
auth.creds.ConsumerKey,
auth.creds.ConsumerSecret,
)
case grantTypeJWT:
refreshedAuth, err = jwtFlow(
auth.InstanceUrl,
auth.creds.Username,
auth.creds.ConsumerKey,
auth.creds.ConsumerRSAPem,
JwtExpirationTime,
)
default:
return errors.New("invalid session, unable to refresh session")
}
Expand Down Expand Up @@ -166,3 +183,40 @@ func setAccessToken(domain string, accessToken string) (*authentication, error)
auth.grantType = grantTypeAccessToken
return auth, nil
}

func jwtFlow(domain string, username string, consumerKey string, consumerRSAPem string, expirationTime time.Duration) (*authentication, error) {
audience := domain
if(strings.Contains(audience, "sandbox")) {
audience = "https://test.salesforce.com"
} else {
audience = "https://login.salesforce.com"
}
claims := &jwt.MapClaims{
"exp": strconv.Itoa(int(time.Now().Unix() + int64(expirationTime.Seconds()))),
"aud": audience,
"iss": consumerKey,
"sub": username,
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signKey, err := jwt.ParseRSAPrivateKeyFromPEM([]byte(consumerRSAPem))
if err != nil {
return nil, fmt.Errorf("ParseRSAPrivateKeyFromPEM: %w", err)
}
tokenString, err := token.SignedString(signKey)
if err != nil {
return nil, fmt.Errorf("jwt.SignedString: %w", err)
}

payload := url.Values{
"grant_type": {grantTypeJWT},
"assertion": {tokenString},
}
endpoint := "/services/oauth2/token"
body := strings.NewReader(payload.Encode())
auth, err := doAuth(domain+endpoint, body)
if err != nil {
return nil, err
}
auth.grantType = grantTypeJWT
return auth, nil
}
97 changes: 97 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package salesforce

import (
"net/http"
"os"
"reflect"
"testing"
"time"
)

func Test_validateAuth(t *testing.T) {
Expand Down Expand Up @@ -245,13 +247,37 @@ func Test_refreshSession(t *testing.T) {
Signature: "signed",
}
serverClientCredentials, sfAuthClientCredentials := setupTestServer(refreshedAuth, http.StatusOK)
sfAuthClientCredentials.creds = Creds{
Domain: serverClientCredentials.URL,
ConsumerKey: "key",
ConsumerSecret: "secret",
}
defer serverClientCredentials.Close()
sfAuthClientCredentials.grantType = grantTypeClientCredentials

serverUserNamePassword, sfAuthUserNamePassword := setupTestServer(refreshedAuth, http.StatusOK)
sfAuthUserNamePassword.creds = Creds{
Domain: serverUserNamePassword.URL,
Username: "u",
Password: "p",
SecurityToken: "t",
ConsumerKey: "key",
ConsumerSecret: "secret",
}
defer serverUserNamePassword.Close()
sfAuthUserNamePassword.grantType = grantTypeUsernamePassword

serverJwt, sfAuthJwt := setupTestServer(refreshedAuth, http.StatusOK)
sampleKey, _ := os.ReadFile("test/sample_key.pem")
sfAuthJwt.creds = Creds{
Domain: serverJwt.URL,
Username: "u",
ConsumerKey: "key",
ConsumerRSAPem: string(sampleKey),
}
defer serverJwt.Close()
sfAuthJwt.grantType = grantTypeJWT

serverNoGrantType, sfAuthNoGrantType := setupTestServer(refreshedAuth, http.StatusOK)
defer serverNoGrantType.Close()

Expand Down Expand Up @@ -281,6 +307,11 @@ func Test_refreshSession(t *testing.T) {
args: args{auth: &sfAuthUserNamePassword},
wantErr: false,
},
{
name: "refresh_jwt",
args: args{auth: &sfAuthJwt},
wantErr: false,
},
{
name: "error_no_grant_type",
args: args{auth: &sfAuthNoGrantType},
Expand All @@ -305,3 +336,69 @@ func Test_refreshSession(t *testing.T) {
})
}
}

func Test_jwtFlow(t *testing.T) {
auth := authentication{
AccessToken: "1234",
InstanceUrl: "example.com",
Id: "123abc",
IssuedAt: "01/01/1970",
Signature: "signed",
grantType: grantTypeJWT,
}
server, _ := setupTestServer(auth, http.StatusOK)
defer server.Close()

badServer, _ := setupTestServer(auth, http.StatusForbidden)
defer badServer.Close()

sampleKey, _ := os.ReadFile("test/sample_key.pem")

type args struct {
domain string
username string
consumerKey string
consumerRSAPem string
}
tests := []struct {
name string
args args
want *authentication
wantErr bool
}{
{
name: "authentication_success",
args: args{
domain: server.URL,
username: "user",
consumerKey: "key",
consumerRSAPem: string(sampleKey),
},
want: &auth,
wantErr: false,
},
{
name: "authentication_fail",
args: args{
domain: badServer.URL,
username: "user",
consumerKey: "key",
consumerRSAPem: string(sampleKey),
},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := jwtFlow(tt.args.domain, tt.args.username, tt.args.consumerKey, tt.args.consumerRSAPem, 1*time.Minute)
if (err != nil) != tt.wantErr {
t.Errorf("jwtFlow() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("jwtFlow() = %v, want %v", got, tt.want)
}
})
}
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ require github.com/mitchellh/mapstructure v1.5.0
require github.com/forcedotcom/go-soql v0.0.0-20220705175410-00f698360bee

require (
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/spf13/afero v1.11.0
k8s.io/apimachinery v0.31.1
)
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ github.com/forcedotcom/go-soql v0.0.0-20220705175410-00f698360bee/go.mod h1:bON1
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
Expand Down
18 changes: 14 additions & 4 deletions salesforce.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,11 @@ func processSalesforceError(resp http.Response, auth *authentication, payload re
func Init(creds Creds) (*Salesforce, error) {
var auth *authentication
var err error
if creds != (Creds{}) && creds.Domain != "" && creds.ConsumerKey != "" && creds.ConsumerSecret != "" &&
if creds == (Creds{}) {
return nil, errors.New("creds is empty")
}
if creds.Domain != "" && creds.ConsumerKey != "" && creds.ConsumerSecret != "" &&
creds.Username != "" && creds.Password != "" && creds.SecurityToken != "" {

auth, err = usernamePasswordFlow(
creds.Domain,
creds.Username,
Expand All @@ -213,17 +215,25 @@ func Init(creds Creds) (*Salesforce, error) {
creds.ConsumerKey,
creds.ConsumerSecret,
)
} else if creds != (Creds{}) && creds.Domain != "" && creds.ConsumerKey != "" && creds.ConsumerSecret != "" {
} else if creds.Domain != "" && creds.ConsumerKey != "" && creds.ConsumerSecret != "" {
auth, err = clientCredentialsFlow(
creds.Domain,
creds.ConsumerKey,
creds.ConsumerSecret,
)
} else if creds != (Creds{}) && creds.AccessToken != "" {
} else if creds.AccessToken != "" {
auth, err = setAccessToken(
creds.Domain,
creds.AccessToken,
)
} else if creds.Domain != "" && creds.Username != "" && creds.ConsumerKey != "" && creds.ConsumerRSAPem != "" {
auth, err = jwtFlow(
creds.Domain,
creds.Username,
creds.ConsumerKey,
creds.ConsumerRSAPem,
JwtExpirationTime,
)
}

if err != nil {
Expand Down
25 changes: 25 additions & 0 deletions salesforce_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"os"
"reflect"
"strconv"
"strings"
Expand Down Expand Up @@ -464,6 +465,25 @@ func TestInit(t *testing.T) {
}
sfAuthAccessToken.creds = credsAccessToken

sfAuthJwt := authentication{
AccessToken: "1234",
InstanceUrl: "example.com",
Id: "123abc",
IssuedAt: "01/01/1970",
Signature: "signed",
grantType: grantTypeJWT,
}
serverJwt, _ := setupTestServer(sfAuthJwt, http.StatusOK)
defer serverJwt.Close()
sampleKey, _ := os.ReadFile("test/sample_key.pem")
credsJwt := Creds{
Domain: serverAccessToken.URL,
Username: "u",
ConsumerKey: "key",
ConsumerRSAPem: string(sampleKey),
}
sfAuthAccessToken.creds = credsAccessToken

type args struct {
creds Creds
}
Expand Down Expand Up @@ -496,6 +516,11 @@ func TestInit(t *testing.T) {
args: args{creds: credsAccessToken},
wantErr: false,
},
{
name: "authentication_jwt",
args: args{creds: credsJwt},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
Loading

0 comments on commit 1b0a3c3

Please sign in to comment.