Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: reuse http client for externaldata requests #424

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion constraint/pkg/client/drivers/rego/builtin.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package rego

import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"time"

"github.com/open-policy-agent/frameworks/constraint/pkg/apis/externaldata/unversioned"
"github.com/open-policy-agent/frameworks/constraint/pkg/externaldata"
"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
Expand All @@ -12,6 +18,9 @@ import (
const (
providerResponseAPIVersion = "externaldata.gatekeeper.sh/v1beta1"
providerResponseKind = "ProviderResponse"
HTTPSScheme = "https"
idleConnTimeout = 90 * time.Second
maxIdleConnsPerHost = 100
)

func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest *ast.Term) (*ast.Term, error) {
Expand All @@ -31,6 +40,12 @@ func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest *
return externaldata.HandleError(http.StatusBadRequest, err)
}

client, err := getClient(&provider, clientCert)
if err != nil {
return externaldata.HandleError(http.StatusInternalServerError,
fmt.Errorf("failed to get HTTP client: %w", err))
}

// check provider response cache
var providerRequestKeys []string
var providerResponseStatusCode int
Expand Down Expand Up @@ -71,7 +86,7 @@ func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest *
}

if len(providerRequestKeys) > 0 {
externaldataResponse, statusCode, err := d.sendRequestToProvider(bctx.Context, &provider, providerRequestKeys, clientCert)
externaldataResponse, statusCode, err := d.sendRequestToProvider(bctx.Context, &provider, providerRequestKeys, client)
if err != nil {
return externaldata.HandleError(statusCode, err)
}
Expand Down Expand Up @@ -115,3 +130,49 @@ func externalDataBuiltin(d *Driver) func(bctx rego.BuiltinContext, regorequest *
return externaldata.PrepareRegoResponse(regoResponse)
}
}

// getClient returns a new HTTP client, and set up its TLS configuration.
func getClient(provider *unversioned.Provider, clientCert *tls.Certificate) (*http.Client, error) {
u, err := url.Parse(provider.Spec.URL)
if err != nil {
return nil, fmt.Errorf("failed to parse provider URL %s: %w", provider.Spec.URL, err)
}

if u.Scheme != HTTPSScheme {
return nil, fmt.Errorf("only HTTPS scheme is supported")
}

client := &http.Client{
Timeout: time.Duration(provider.Spec.Timeout) * time.Second,
}

tlsConfig := &tls.Config{MinVersion: tls.VersionTLS13}

// present our client cert to the server
// in case provider wants to verify it
if clientCert != nil {
tlsConfig.Certificates = []tls.Certificate{*clientCert}
}

// if the provider presents its own CA bundle,
// we will use it to verify the server's certificate
caBundleData, err := base64.StdEncoding.DecodeString(provider.Spec.CABundle)
if err != nil {
return nil, fmt.Errorf("failed to decode CA bundle: %w", err)
}

providerCertPool := x509.NewCertPool()
if ok := providerCertPool.AppendCertsFromPEM(caBundleData); !ok {
return nil, fmt.Errorf("failed to append provider's CA bundle to certificate pool")
}

tlsConfig.RootCAs = providerCertPool

client.Transport = &http.Transport{
TLSClientConfig: tlsConfig,
IdleConnTimeout: idleConnTimeout,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
}

return client, nil
}
85 changes: 85 additions & 0 deletions constraint/pkg/client/drivers/rego/builtin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package rego

import (
"crypto/tls"
"testing"

"github.com/open-policy-agent/frameworks/constraint/pkg/apis/externaldata/unversioned"
)

const (
validCABundle = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUIwekNDQVgyZ0F3SUJBZ0lKQUkvTTdCWWp3Qit1TUEwR0NTcUdTSWIzRFFFQkJRVUFNRVV4Q3pBSkJnTlYKQkFZVEFrRlZNUk13RVFZRFZRUUlEQXBUYjIxbExWTjBZWFJsTVNFd0h3WURWUVFLREJoSmJuUmxjbTVsZENCWAphV1JuYVhSeklGQjBlU0JNZEdRd0hoY05NVEl3T1RFeU1qRTFNakF5V2hjTk1UVXdPVEV5TWpFMU1qQXlXakJGCk1Rc3dDUVlEVlFRR0V3SkJWVEVUTUJFR0ExVUVDQXdLVTI5dFpTMVRkR0YwWlRFaE1COEdBMVVFQ2d3WVNXNTAKWlhKdVpYUWdWMmxrWjJsMGN5QlFkSGtnVEhSa01Gd3dEUVlKS29aSWh2Y05BUUVCQlFBRFN3QXdTQUpCQU5MSgpoUEhoSVRxUWJQa2xHM2liQ1Z4d0dNUmZwL3Y0WHFoZmRRSGRjVmZIYXA2TlE1V29rLzR4SUErdWkzNS9NbU5hCnJ0TnVDK0JkWjF0TXVWQ1BGWmNDQXdFQUFhTlFNRTR3SFFZRFZSME9CQllFRkp2S3M4UmZKYVhUSDA4VytTR3YKelF5S24wSDhNQjhHQTFVZEl3UVlNQmFBRkp2S3M4UmZKYVhUSDA4VytTR3Z6UXlLbjBIOE1Bd0dBMVVkRXdRRgpNQU1CQWY4d0RRWUpLb1pJaHZjTkFRRUZCUUFEUVFCSmxmZkpIeWJqREd4Uk1xYVJtRGhYMCs2djAyVFVLWnNXCnI1UXVWYnBRaEg2dSswVWdjVzBqcDlRd3B4b1BUTFRXR1hFV0JCQnVyeEZ3aUNCaGtRK1YKLS0tLS1FTkQgQ0VSVElGSUNBVEUtLS0tLQo="
badCABundle = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCmhlbGxvCi0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K"
)

func Test_getClient(t *testing.T) {
type args struct {
provider *unversioned.Provider
clientCert *tls.Certificate
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "invalid http url",
args: args{
provider: &unversioned.Provider{
Spec: unversioned.ProviderSpec{
URL: "http://foo",
},
},
clientCert: nil,
},
wantErr: true,
},
{
name: "no CA bundle",
args: args{
provider: &unversioned.Provider{
Spec: unversioned.ProviderSpec{
URL: "https://foo",
},
},
clientCert: nil,
},
wantErr: true,
},
{
name: "invalid CA bundle",
args: args{
provider: &unversioned.Provider{
Spec: unversioned.ProviderSpec{
URL: "https://foo",
CABundle: badCABundle,
},
},
clientCert: nil,
},
wantErr: true,
},
{
name: "valid CA bundle",
args: args{
provider: &unversioned.Provider{
Spec: unversioned.ProviderSpec{
URL: "https://foo",
CABundle: validCABundle,
},
},
clientCert: nil,
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := getClient(tt.args.provider, tt.args.clientCert)
if (err != nil) != tt.wantErr {
t.Errorf("getClient() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
5 changes: 2 additions & 3 deletions constraint/pkg/client/drivers/rego/driver_unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package rego

import (
"context"
"crypto/tls"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -676,7 +675,7 @@ func TestDriver_ExternalData(t *testing.T) {
},
clientCertContent: clientCert,
clientKeyContent: clientKey,
sendRequestToProvider: func(_ context.Context, _ *unversioned.Provider, _ []string, _ *tls.Certificate) (*externaldata.ProviderResponse, int, error) {
sendRequestToProvider: func(_ context.Context, _ *unversioned.Provider, _ []string, _ *http.Client) (*externaldata.ProviderResponse, int, error) {
return nil, http.StatusBadRequest, errors.New("error from SendRequestToProvider")
},
errorExpected: true,
Expand All @@ -695,7 +694,7 @@ func TestDriver_ExternalData(t *testing.T) {
},
clientCertContent: clientCert,
clientKeyContent: clientKey,
sendRequestToProvider: func(_ context.Context, _ *unversioned.Provider, _ []string, _ *tls.Certificate) (*externaldata.ProviderResponse, int, error) {
sendRequestToProvider: func(_ context.Context, _ *unversioned.Provider, _ []string, _ *http.Client) (*externaldata.ProviderResponse, int, error) {
return &externaldata.ProviderResponse{
APIVersion: "v1beta1",
Kind: "Provider",
Expand Down
4 changes: 4 additions & 0 deletions constraint/pkg/externaldata/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ import (
"k8s.io/apimachinery/pkg/util/wait"
)

const (
HTTPSScheme = "https"
)

type ProviderCache struct {
cache map[string]unversioned.Provider
mux sync.RWMutex
Expand Down
58 changes: 2 additions & 56 deletions constraint/pkg/externaldata/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,15 @@ package externaldata
import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"

"github.com/open-policy-agent/frameworks/constraint/pkg/apis/externaldata/unversioned"
)

const (
HTTPScheme = "http"
HTTPSScheme = "https"
)

// RegoRequest is the request for external_data rego function.
type RegoRequest struct {
// ProviderName is the name of the external data provider.
Expand Down Expand Up @@ -57,17 +48,16 @@ func NewProviderRequest(keys []string) *ProviderRequest {
}

// SendRequestToProvider is a function that sends a request to the external data provider.
type SendRequestToProvider func(ctx context.Context, provider *unversioned.Provider, keys []string, clientCert *tls.Certificate) (*ProviderResponse, int, error)
type SendRequestToProvider func(ctx context.Context, provider *unversioned.Provider, keys []string, client *http.Client) (*ProviderResponse, int, error)

// DefaultSendRequestToProvider is the default function to send the request to the external data provider.
func DefaultSendRequestToProvider(ctx context.Context, provider *unversioned.Provider, keys []string, clientCert *tls.Certificate) (*ProviderResponse, int, error) {
func DefaultSendRequestToProvider(ctx context.Context, provider *unversioned.Provider, keys []string, client *http.Client) (*ProviderResponse, int, error) {
externaldataRequest := NewProviderRequest(keys)
body, err := json.Marshal(externaldataRequest)
if err != nil {
return nil, http.StatusInternalServerError, fmt.Errorf("failed to marshal external data request: %w", err)
}

client, err := getClient(provider, clientCert)
if err != nil {
return nil, http.StatusInternalServerError, fmt.Errorf("failed to get HTTP client: %w", err)
}
Expand Down Expand Up @@ -100,50 +90,6 @@ func DefaultSendRequestToProvider(ctx context.Context, provider *unversioned.Pro
return &externaldataResponse, resp.StatusCode, nil
}

// getClient returns a new HTTP client, and set up its TLS configuration.
func getClient(provider *unversioned.Provider, clientCert *tls.Certificate) (*http.Client, error) {
u, err := url.Parse(provider.Spec.URL)
if err != nil {
return nil, fmt.Errorf("failed to parse provider URL %s: %w", provider.Spec.URL, err)
}

if u.Scheme != HTTPSScheme {
return nil, fmt.Errorf("only HTTPS scheme is supported")
}

client := &http.Client{
Timeout: time.Duration(provider.Spec.Timeout) * time.Second,
}

tlsConfig := &tls.Config{MinVersion: tls.VersionTLS13}

// present our client cert to the server
// in case provider wants to verify it
if clientCert != nil {
tlsConfig.Certificates = []tls.Certificate{*clientCert}
}

// if the provider presents its own CA bundle,
// we will use it to verify the server's certificate
caBundleData, err := base64.StdEncoding.DecodeString(provider.Spec.CABundle)
if err != nil {
return nil, fmt.Errorf("failed to decode CA bundle: %w", err)
}

providerCertPool := x509.NewCertPool()
if ok := providerCertPool.AppendCertsFromPEM(caBundleData); !ok {
return nil, fmt.Errorf("failed to append provider's CA bundle to certificate pool")
}

tlsConfig.RootCAs = providerCertPool

client.Transport = &http.Transport{
TLSClientConfig: tlsConfig,
}

return client, nil
}

// ProviderKind strings are special string constants for Providers.
// +kubebuilder:validation:Enum=ProviderRequestKind;ProviderResponseKind
type ProviderKind string
Expand Down
Loading
Loading