Skip to content

Commit

Permalink
fix: reuse http client for externaldata requests
Browse files Browse the repository at this point in the history
  • Loading branch information
mannbiher committed Apr 26, 2024
1 parent c2efb00 commit 00fe810
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 135 deletions.
63 changes: 62 additions & 1 deletion constraint/pkg/client/drivers/rego/builtin.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package rego

import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"time"

"github.com/open-policy-agent/frameworks/constraint/pkg/apis/externaldata/unversioned"
"github.com/open-policy-agent/frameworks/constraint/pkg/externaldata"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
Expand All @@ -12,6 +18,9 @@ import (
const (
providerResponseAPIVersion = "externaldata.gatekeeper.sh/v1beta1"
providerResponseKind = "ProviderResponse"
HTTPSScheme = "https"
idleConnTimeout = 90 * time.Second
maxIdleConnsPerHost = 100
)

func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest *ast.Term) (*ast.Term, error) {
Expand All @@ -31,6 +40,12 @@ func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest *
return externaldata.HandleError(http.StatusBadRequest, err)
}

client, err := getClient(&provider, clientCert)
if err != nil {
return externaldata.HandleError(http.StatusInternalServerError,
fmt.Errorf("failed to get HTTP client: %w", err))
}

// check provider response cache
var providerRequestKeys []string
var providerResponseStatusCode int
Expand Down Expand Up @@ -71,7 +86,7 @@ func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest *
}

if len(providerRequestKeys) > 0 {
externaldataResponse, statusCode, err := d.sendRequestToProvider(bctx.Context, &provider, providerRequestKeys, clientCert)
externaldataResponse, statusCode, err := d.sendRequestToProvider(bctx.Context, &provider, providerRequestKeys, client)
if err != nil {
return externaldata.HandleError(statusCode, err)
}
Expand Down Expand Up @@ -115,3 +130,49 @@ func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest *
return externaldata.PrepareRegoResponse(regoResponse)
}
}

// getClient returns a new HTTP client, and set up its TLS configuration.
func getClient(provider *unversioned.Provider, clientCert *tls.Certificate) (*http.Client, error) {
u, err := url.Parse(provider.Spec.URL)
if err != nil {
return nil, fmt.Errorf("failed to parse provider URL %s: %w", provider.Spec.URL, err)
}

if u.Scheme != HTTPSScheme {
return nil, fmt.Errorf("only HTTPS scheme is supported")
}

client := &http.Client{
Timeout: time.Duration(provider.Spec.Timeout) * time.Second,
}

tlsConfig := &tls.Config{MinVersion: tls.VersionTLS13}

// present our client cert to the server
// in case provider wants to verify it
if clientCert != nil {
tlsConfig.Certificates = []tls.Certificate{*clientCert}
}

// if the provider presents its own CA bundle,
// we will use it to verify the server's certificate
caBundleData, err := base64.StdEncoding.DecodeString(provider.Spec.CABundle)
if err != nil {
return nil, fmt.Errorf("failed to decode CA bundle: %w", err)
}

providerCertPool := x509.NewCertPool()
if ok := providerCertPool.AppendCertsFromPEM(caBundleData); !ok {
return nil, fmt.Errorf("failed to append provider's CA bundle to certificate pool")
}

tlsConfig.RootCAs = providerCertPool

client.Transport = &http.Transport{
TLSClientConfig: tlsConfig,
IdleConnTimeout: idleConnTimeout,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
}

return client, nil
}
85 changes: 85 additions & 0 deletions constraint/pkg/client/drivers/rego/builtin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package rego

import (
"crypto/tls"
"testing"

"github.com/open-policy-agent/frameworks/constraint/pkg/apis/externaldata/unversioned"
)

const (
validCABundle = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUIwekNDQVgyZ0F3SUJBZ0lKQUkvTTdCWWp3Qit1TUEwR0NTcUdTSWIzRFFFQkJRVUFNRVV4Q3pBSkJnTlYKQkFZVEFrRlZNUk13RVFZRFZRUUlEQXBUYjIxbExWTjBZWFJsTVNFd0h3WURWUVFLREJoSmJuUmxjbTVsZENCWAphV1JuYVhSeklGQjBlU0JNZEdRd0hoY05NVEl3T1RFeU1qRTFNakF5V2hjTk1UVXdPVEV5TWpFMU1qQXlXakJGCk1Rc3dDUVlEVlFRR0V3SkJWVEVUTUJFR0ExVUVDQXdLVTI5dFpTMVRkR0YwWlRFaE1COEdBMVVFQ2d3WVNXNTAKWlhKdVpYUWdWMmxrWjJsMGN5QlFkSGtnVEhSa01Gd3dEUVlKS29aSWh2Y05BUUVCQlFBRFN3QXdTQUpCQU5MSgpoUEhoSVRxUWJQa2xHM2liQ1Z4d0dNUmZwL3Y0WHFoZmRRSGRjVmZIYXA2TlE1V29rLzR4SUErdWkzNS9NbU5hCnJ0TnVDK0JkWjF0TXVWQ1BGWmNDQXdFQUFhTlFNRTR3SFFZRFZSME9CQllFRkp2S3M4UmZKYVhUSDA4VytTR3YKelF5S24wSDhNQjhHQTFVZEl3UVlNQmFBRkp2S3M4UmZKYVhUSDA4VytTR3Z6UXlLbjBIOE1Bd0dBMVVkRXdRRgpNQU1CQWY4d0RRWUpLb1pJaHZjTkFRRUZCUUFEUVFCSmxmZkpIeWJqREd4Uk1xYVJtRGhYMCs2djAyVFVLWnNXCnI1UXVWYnBRaEg2dSswVWdjVzBqcDlRd3B4b1BUTFRXR1hFV0JCQnVyeEZ3aUNCaGtRK1YKLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="
badCABundle = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCmhlbGxvCi0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K"
)

func Test_getClient(t *testing.T) {
type args struct {
provider *unversioned.Provider
clientCert *tls.Certificate
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "invalid http url",
args: args{
provider: &unversioned.Provider{
Spec: unversioned.ProviderSpec{
URL: "http://foo",
},
},
clientCert: nil,
},
wantErr: true,
},
{
name: "no CA bundle",
args: args{
provider: &unversioned.Provider{
Spec: unversioned.ProviderSpec{
URL: "https://foo",
},
},
clientCert: nil,
},
wantErr: true,
},
{
name: "invalid CA bundle",
args: args{
provider: &unversioned.Provider{
Spec: unversioned.ProviderSpec{
URL: "https://foo",
CABundle: badCABundle,
},
},
clientCert: nil,
},
wantErr: true,
},
{
name: "valid CA bundle",
args: args{
provider: &unversioned.Provider{
Spec: unversioned.ProviderSpec{
URL: "https://foo",
CABundle: validCABundle,
},
},
clientCert: nil,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := getClient(tt.args.provider, tt.args.clientCert)
if (err != nil) != tt.wantErr {
t.Errorf("getClient() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
5 changes: 2 additions & 3 deletions constraint/pkg/client/drivers/rego/driver_unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package rego

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -676,7 +675,7 @@ func TestDriver_ExternalData(t *testing.T) {
},
clientCertContent: clientCert,
clientKeyContent: clientKey,
sendRequestToProvider: func(_ context.Context, _ *unversioned.Provider, _ []string, _ *tls.Certificate) (*externaldata.ProviderResponse, int, error) {
sendRequestToProvider: func(_ context.Context, _ *unversioned.Provider, _ []string, _ *http.Client) (*externaldata.ProviderResponse, int, error) {
return nil, http.StatusBadRequest, errors.New("error from SendRequestToProvider")
},
errorExpected: true,
Expand All @@ -695,7 +694,7 @@ func TestDriver_ExternalData(t *testing.T) {
},
clientCertContent: clientCert,
clientKeyContent: clientKey,
sendRequestToProvider: func(_ context.Context, _ *unversioned.Provider, _ []string, _ *tls.Certificate) (*externaldata.ProviderResponse, int, error) {
sendRequestToProvider: func(_ context.Context, _ *unversioned.Provider, _ []string, _ *http.Client) (*externaldata.ProviderResponse, int, error) {
return &externaldata.ProviderResponse{
APIVersion: "v1beta1",
Kind: "Provider",
Expand Down
4 changes: 4 additions & 0 deletions constraint/pkg/externaldata/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ import (
"k8s.io/apimachinery/pkg/util/wait"
)

const (
HTTPSScheme = "https"
)

type ProviderCache struct {
cache map[string]unversioned.Provider
mux sync.RWMutex
Expand Down
58 changes: 2 additions & 56 deletions constraint/pkg/externaldata/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,15 @@ package externaldata
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"

"github.com/open-policy-agent/frameworks/constraint/pkg/apis/externaldata/unversioned"
)

const (
HTTPScheme = "http"
HTTPSScheme = "https"
)

// RegoRequest is the request for external_data rego function.
type RegoRequest struct {
// ProviderName is the name of the external data provider.
Expand Down Expand Up @@ -57,17 +48,16 @@ func NewProviderRequest(keys []string) *ProviderRequest {
}

// SendRequestToProvider is a function that sends a request to the external data provider.
type SendRequestToProvider func(ctx context.Context, provider *unversioned.Provider, keys []string, clientCert *tls.Certificate) (*ProviderResponse, int, error)
type SendRequestToProvider func(ctx context.Context, provider *unversioned.Provider, keys []string, client *http.Client) (*ProviderResponse, int, error)

// DefaultSendRequestToProvider is the default function to send the request to the external data provider.
func DefaultSendRequestToProvider(ctx context.Context, provider *unversioned.Provider, keys []string, clientCert *tls.Certificate) (*ProviderResponse, int, error) {
func DefaultSendRequestToProvider(ctx context.Context, provider *unversioned.Provider, keys []string, client *http.Client) (*ProviderResponse, int, error) {
externaldataRequest := NewProviderRequest(keys)
body, err := json.Marshal(externaldataRequest)
if err != nil {
return nil, http.StatusInternalServerError, fmt.Errorf("failed to marshal external data request: %w", err)
}

client, err := getClient(provider, clientCert)
if err != nil {
return nil, http.StatusInternalServerError, fmt.Errorf("failed to get HTTP client: %w", err)
}
Expand Down Expand Up @@ -100,50 +90,6 @@ func DefaultSendRequestToProvider(ctx context.Context, provider *unversioned.Pro
return &externaldataResponse, resp.StatusCode, nil
}

// getClient returns a new HTTP client, and set up its TLS configuration.
func getClient(provider *unversioned.Provider, clientCert *tls.Certificate) (*http.Client, error) {
u, err := url.Parse(provider.Spec.URL)
if err != nil {
return nil, fmt.Errorf("failed to parse provider URL %s: %w", provider.Spec.URL, err)
}

if u.Scheme != HTTPSScheme {
return nil, fmt.Errorf("only HTTPS scheme is supported")
}

client := &http.Client{
Timeout: time.Duration(provider.Spec.Timeout) * time.Second,
}

tlsConfig := &tls.Config{MinVersion: tls.VersionTLS13}

// present our client cert to the server
// in case provider wants to verify it
if clientCert != nil {
tlsConfig.Certificates = []tls.Certificate{*clientCert}
}

// if the provider presents its own CA bundle,
// we will use it to verify the server's certificate
caBundleData, err := base64.StdEncoding.DecodeString(provider.Spec.CABundle)
if err != nil {
return nil, fmt.Errorf("failed to decode CA bundle: %w", err)
}

providerCertPool := x509.NewCertPool()
if ok := providerCertPool.AppendCertsFromPEM(caBundleData); !ok {
return nil, fmt.Errorf("failed to append provider's CA bundle to certificate pool")
}

tlsConfig.RootCAs = providerCertPool

client.Transport = &http.Transport{
TLSClientConfig: tlsConfig,
}

return client, nil
}

// ProviderKind strings are special string constants for Providers.
// +kubebuilder:validation:Enum=ProviderRequestKind;ProviderResponseKind
type ProviderKind string
Expand Down
Loading

0 comments on commit 00fe810

Please sign in to comment.