Skip to content

Commit

Permalink
Add TLS support.
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Welch <me@alexwelch.com>
  • Loading branch information
bdehamer authored and alexwelch committed Jun 26, 2015
1 parent 25f32bc commit f6d6939
Show file tree
Hide file tree
Showing 14 changed files with 152 additions and 32 deletions.
37 changes: 35 additions & 2 deletions actions/actions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package actions
import (
"encoding/json"
"fmt"
"os/user"
"strings"

"github.com/CenturyLinkLabs/prettycli"
Expand All @@ -27,8 +28,40 @@ func init() {
}

type Options struct {
Args []string
Flags map[string]string
Args []string
Flags map[string]string
EndpointOptions EndpointOptions
}

type EndpointOptions struct {
Host string
TLS bool
TLSVerify bool
TLSCaCert string
TLSCert string
TLSKey string
}

func (eo EndpointOptions) tlsCaCert() string {
return resolveHomeDirectory(eo.TLSCaCert)
}

func (eo EndpointOptions) tlsCert() string {
return resolveHomeDirectory(eo.TLSCert)
}

func (eo EndpointOptions) tlsKey() string {
return resolveHomeDirectory(eo.TLSKey)
}

func resolveHomeDirectory(path string) string {
if strings.Contains(path, "~") {
usr, _ := user.Current()
dir := usr.HomeDir
return strings.Replace(path, "~", dir, 1)
}

return path
}

type Zodiaction func(Options) (prettycli.Output, error)
Expand Down
2 changes: 1 addition & 1 deletion actions/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
func Deploy(options Options) (prettycli.Output, error) {
fmt.Println("Deploying your application...")

endpoint, err := endpointFactory(options.Flags["endpoint"])
endpoint, err := endpointFactory(options.EndpointOptions)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion actions/deploy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestDeploy_Success(t *testing.T) {
},
}

endpointFactory = func(string) (Endpoint, error) {
endpointFactory = func(EndpointOptions) (Endpoint, error) {
return e, nil
}

Expand Down
54 changes: 49 additions & 5 deletions actions/endpoint.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
package actions

import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"io/ioutil"
"net/url"

log "github.com/Sirupsen/logrus"
Expand All @@ -13,17 +16,57 @@ var (
)

func init() {
endpointFactory = func(dockerHost string) (Endpoint, error) {
c, err := dockerclient.NewDockerClient(dockerHost, nil)
endpointFactory = func(endpointOpts EndpointOptions) (Endpoint, error) {

tlsConfig, err := getTlsConfig(endpointOpts)
if err != nil {
return nil, err
}

c, err := dockerclient.NewDockerClient(endpointOpts.Host, tlsConfig)
if err != nil {
return nil, err
}

return &DockerEndpoint{url: dockerHost, client: c}, nil
return &DockerEndpoint{url: endpointOpts.Host, client: c}, nil
}
}

type EndpointFactory func(string) (Endpoint, error)
func getTlsConfig(endpointOpts EndpointOptions) (*tls.Config, error) {
var tlsConfig *tls.Config

if endpointOpts.TLS {

tlsConfig = &tls.Config{
InsecureSkipVerify: !endpointOpts.TLSVerify,
}

if endpointOpts.tlsCert() != "" && endpointOpts.tlsKey() != "" {

cert, err := tls.LoadX509KeyPair(endpointOpts.tlsCert(), endpointOpts.tlsKey())
if err != nil {
return nil, err
}
tlsConfig.Certificates = []tls.Certificate{cert}
}

// Load CA cert
if endpointOpts.tlsCaCert() != "" {

caCert, err := ioutil.ReadFile(endpointOpts.tlsCaCert())
if err != nil {
return nil, err
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)

tlsConfig.RootCAs = caCertPool
}
}
return tlsConfig, nil
}

type EndpointFactory func(EndpointOptions) (Endpoint, error)

type Endpoint interface {
Version() (string, error)
Expand All @@ -40,6 +83,7 @@ type DockerEndpoint struct {
client dockerclient.Client
}

// TODO: can we ditch this? Should always have it on the client
func (e *DockerEndpoint) Host() string {
url, _ := url.Parse(e.url)
return url.Host
Expand All @@ -54,13 +98,13 @@ func (e *DockerEndpoint) Version() (string, error) {
return v.Version, nil
}

// TODO: can we ditch this? Should always have it on the client
func (e *DockerEndpoint) Name() string {
return e.url
}

func (e *DockerEndpoint) StartContainer(name string, cc ContainerConfig) error {
dcc, _ := translateContainerConfig(cc)

id, err := e.client.CreateContainer(&dcc, name)
if err != nil {
log.Fatalf("Problem creating container: ", err)
Expand Down
2 changes: 1 addition & 1 deletion actions/list.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (

func List(options Options) (prettycli.Output, error) {

endpoint, err := endpointFactory(options.Flags["endpoint"])
endpoint, err := endpointFactory(options.EndpointOptions)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion actions/list_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestList_Success(t *testing.T) {
},
}

endpointFactory = func(string) (Endpoint, error) {
endpointFactory = func(EndpointOptions) (Endpoint, error) {
return e, nil
}

Expand Down
2 changes: 1 addition & 1 deletion actions/rollback.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
func Rollback(options Options) (prettycli.Output, error) {
fmt.Println("Rolling back your application...")

endpoint, err := endpointFactory(options.Flags["endpoint"])
endpoint, err := endpointFactory(options.EndpointOptions)
if err != nil {
return nil, err
}
Expand Down
8 changes: 4 additions & 4 deletions actions/rollback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func TestRollback_Success(t *testing.T) {
},
}

endpointFactory = func(string) (Endpoint, error) {
endpointFactory = func(EndpointOptions) (Endpoint, error) {
return e, nil
}

Expand Down Expand Up @@ -168,7 +168,7 @@ func TestRollbackWithID_Success(t *testing.T) {
},
}

endpointFactory = func(string) (Endpoint, error) {
endpointFactory = func(EndpointOptions) (Endpoint, error) {
return e, nil
}

Expand Down Expand Up @@ -234,7 +234,7 @@ func TestRollbackWithNoPreviousDeployment_Error(t *testing.T) {
},
}

endpointFactory = func(string) (Endpoint, error) {
endpointFactory = func(EndpointOptions) (Endpoint, error) {
return e, nil
}

Expand Down Expand Up @@ -304,7 +304,7 @@ func TestRollbackWithNonexistingID_Error(t *testing.T) {
},
}

endpointFactory = func(string) (Endpoint, error) {
endpointFactory = func(EndpointOptions) (Endpoint, error) {
return e, nil
}

Expand Down
2 changes: 1 addition & 1 deletion actions/teardown.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (

func Teardown(options Options) (prettycli.Output, error) {

endpoint, err := endpointFactory(options.Flags["endpoint"])
endpoint, err := endpointFactory(options.EndpointOptions)
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion actions/teardown_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestTeardown_Success(t *testing.T) {
},
}

endpointFactory = func(string) (Endpoint, error) {
endpointFactory = func(EndpointOptions) (Endpoint, error) {
return e, nil
}

Expand Down
2 changes: 1 addition & 1 deletion actions/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ var RequredAPIVersion = semver.MustParse("1.6.0")

func Verify(options Options) (prettycli.Output, error) {

endpoint, err := endpointFactory(options.Flags["endpoint"])
endpoint, err := endpointFactory(options.EndpointOptions)
if err != nil {
return nil, err
}
Expand Down
8 changes: 4 additions & 4 deletions actions/verify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ func (e mockVerifyEndpoint) Name() string {

func TestVerify_Success(t *testing.T) {
e := mockVerifyEndpoint{version: "1.6.1", url: "http://foo.bar"}
endpointFactory = func(string) (Endpoint, error) {
endpointFactory = func(EndpointOptions) (Endpoint, error) {
return e, nil
}
o, err := Verify(Options{})
Expand All @@ -38,7 +38,7 @@ func TestVerify_Success(t *testing.T) {

func TestVerify_ErroredOldVersion(t *testing.T) {
e := mockVerifyEndpoint{version: "1.5.0"}
endpointFactory = func(string) (Endpoint, error) {
endpointFactory = func(EndpointOptions) (Endpoint, error) {
return e, nil
}
o, err := Verify(Options{})
Expand All @@ -49,7 +49,7 @@ func TestVerify_ErroredOldVersion(t *testing.T) {

func TestVerify_ErroredCrazyVersion(t *testing.T) {
e := mockVerifyEndpoint{version: "eleventy-billion"}
endpointFactory = func(string) (Endpoint, error) {
endpointFactory = func(EndpointOptions) (Endpoint, error) {
return e, nil
}
o, err := Verify(Options{})
Expand All @@ -60,7 +60,7 @@ func TestVerify_ErroredCrazyVersion(t *testing.T) {

func TestVerify_ErroredAPIError(t *testing.T) {
e := mockVerifyEndpoint{ErrorForVersion: errors.New("test error")}
endpointFactory = func(string) (Endpoint, error) {
endpointFactory = func(EndpointOptions) (Endpoint, error) {
return e, nil
}
o, err := Verify(Options{})
Expand Down
12 changes: 7 additions & 5 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ import (

var b *clitest.BuildTester

const TLSFlag = "--tls=false"

func setup(t *testing.T) {
if testing.Short() {
t.SkipNow()
Expand Down Expand Up @@ -49,7 +51,7 @@ func TestVerify_Successful(t *testing.T) {
s, endpointFlag := newFakeServerAndFlag()
defer s.Close()

r := b.Run(t, endpointFlag, "verify")
r := b.Run(t, TLSFlag, endpointFlag, "verify")
r.AssertSuccessful()
assert.Contains(t, r.Stdout(), "Successfully verified endpoint:")
assert.Empty(t, r.Stderr())
Expand All @@ -71,10 +73,10 @@ func TestVerify_EndpointEnvVar(t *testing.T) {

parts := strings.Split(endpointFlag, "=")

os.Setenv("ZODIAC_DOCKER_ENDPOINT", parts[1])
defer os.Unsetenv("ZODIAC_DOCKER_ENDPOINT")
os.Setenv("DOCKER_HOST", parts[1])
defer os.Unsetenv("DOCKER_HOST")

r := b.Run(t, "verify")
r := b.Run(t, TLSFlag, "verify")
r.AssertSuccessful()
assert.Contains(t, r.Stdout(), "Successfully verified endpoint:")
assert.Empty(t, r.Stderr())
Expand All @@ -85,7 +87,7 @@ func TestDeploy_Successful(t *testing.T) {
s, endpointFlag := newFakeServerAndFlag()
defer s.Close()

r := b.Run(t, endpointFlag, "deploy", "-f", "fixtures/webapp.yml")
r := b.Run(t, TLSFlag, endpointFlag, "deploy", "-f", "fixtures/webapp.yml")
fmt.Println(r.Stderr())
fmt.Println(r.Stdout())
r.AssertSuccessful()
Expand Down
Loading

0 comments on commit f6d6939

Please sign in to comment.