Skip to content

Commit

Permalink
Use proper TokenSource for IAP auth
Browse files Browse the repository at this point in the history
  • Loading branch information
cedws committed Apr 24, 2023
1 parent cb028d5 commit d8baa48
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 36 deletions.
32 changes: 18 additions & 14 deletions iap/dialopts.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
package iap

import (
"golang.org/x/oauth2"
)

type DialOption func(*dialOptions)

type dialOptions struct {
Zone string
Token string
Region string
Project string
Port string
Network string
Interface string
Instance string
Host string
Group string
Compress bool
Zone string
TokenSource *oauth2.TokenSource
Region string
Project string
Port string
Network string
Interface string
Instance string
Host string
Group string
Compress bool
}

func (d *dialOptions) collectOpts(opts []DialOption) {
Expand All @@ -22,10 +26,10 @@ func (d *dialOptions) collectOpts(opts []DialOption) {
}
}

// WithToken is a functional option that sets the authorization token.
func WithToken(token string) func(*dialOptions) {
// WithTokenSource is a functional option that sets the authorization toke source.
func WithTokenSource(tokenSource *oauth2.TokenSource) func(*dialOptions) {
return func(d *dialOptions) {
d.Token = token
d.TokenSource = tokenSource
}
}

Expand Down
17 changes: 13 additions & 4 deletions iap/iap.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,20 @@ func Dial(ctx context.Context, opts ...DialOption) (*Conn, error) {
dopts := &dialOptions{}
dopts.collectOpts(opts)

header := make(http.Header)
header.Set("Origin", proxyOrigin)

if dopts.TokenSource != nil {
token, err := (*dopts.TokenSource).Token()
if err != nil {
return nil, err
}

header.Set("Authorization", fmt.Sprintf("%v %v", token.Type(), token.AccessToken))
}

wsOptions := websocket.DialOptions{
HTTPHeader: http.Header{
"Authorization": []string{fmt.Sprintf("Bearer %v", dopts.Token)},
"Origin": []string{proxyOrigin},
},
HTTPHeader: header,
Subprotocols: []string{proxySubproto},
CompressionMode: websocket.CompressionDisabled,
}
Expand Down
12 changes: 12 additions & 0 deletions internal/cmd/root.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
package cmd

import (
"context"

"github.com/charmbracelet/log"
"github.com/spf13/cobra"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)

var (
Expand All @@ -17,6 +21,14 @@ var rootCmd = &cobra.Command{
Long: "Utility for Google Cloud's Identity-Aware Proxy",
}

func getTokenSource() *oauth2.TokenSource {
tokenSource, err := google.DefaultTokenSource(context.Background())
if err != nil {
log.Fatal(err)
}
return &tokenSource
}

func init() {
rootCmd.PersistentFlags().BoolVarP(&compress, "compress", "c", false, "Enable WebSocket compression")
rootCmd.PersistentFlags().StringVarP(&listen, "listen", "l", "127.0.0.1:0", "Listen address and port")
Expand Down
3 changes: 2 additions & 1 deletion internal/cmd/to_host.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ var hostCmd = &cobra.Command{
Long: "Create a tunnel to a remote private IP or FQDN (requires BeyondCorp Enterprise)",
Args: cobra.ExactArgs(1),
PreRun: func(cmd *cobra.Command, args []string) {
log.Info("Starting proxy", "listen", listen, "dest", fmt.Sprintf("%v:%v", args[0], port), "project", project)
log.Info("Started proxy", "listen", listen, "dest", fmt.Sprintf("%v:%v", args[0], port), "project", project)
},
Run: func(cmd *cobra.Command, args []string) {
opts := []iap.DialOption{
iap.WithProject(project),
iap.WithHost(args[0], region, network, destGroup),
iap.WithPort(fmt.Sprint(port)),
iap.WithTokenSource(getTokenSource()),
}
if compress {
opts = append(opts, iap.WithCompression())
Expand Down
1 change: 1 addition & 0 deletions internal/cmd/to_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ var instanceCmd = &cobra.Command{
iap.WithProject(project),
iap.WithInstance(args[0], zone, ninterface),
iap.WithPort(fmt.Sprint(port)),
iap.WithTokenSource(getTokenSource()),
}
if compress {
opts = append(opts, iap.WithCompression())
Expand Down
20 changes: 3 additions & 17 deletions internal/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,12 @@ import (

"github.com/cedws/iapc/iap"
"github.com/charmbracelet/log"
"golang.org/x/oauth2/google"
)

// Start starts a proxy server that listens on the given address and port.
func Start(listen string, opts []iap.DialOption) {
opts = append(opts, iap.WithToken(getToken()))
if err := testConn(opts); err != nil {
log.Fatal(err)
log.Fatalf("Error testing connection: %v", err)
}

listener, err := net.Listen("tcp", listen)
Expand Down Expand Up @@ -45,7 +43,7 @@ func handleClient(opts []iap.DialOption, conn net.Conn) {

tun, err := iap.Dial(context.Background(), opts...)
if err != nil {
log.Error(err)
log.Errorf("Error dialing IAP: %v", err)
return
}
defer tun.Close()
Expand All @@ -54,17 +52,5 @@ func handleClient(opts []iap.DialOption, conn net.Conn) {
go io.Copy(conn, tun)
io.Copy(tun, conn)

log.Info("Client disconnected", "client", conn.RemoteAddr())
}

func getToken() string {
credentials, err := google.FindDefaultCredentials(context.Background())
if err != nil {
log.Fatal(err)
}
tok, err := credentials.TokenSource.Token()
if err != nil {
log.Fatal(err)
}
return tok.AccessToken
log.Info("Client disconnected", "client", conn.RemoteAddr(), "sentbytes", tun.Sent(), "recvbytes", tun.Received())
}

0 comments on commit d8baa48

Please sign in to comment.