diff --git a/cmd/file.go b/cmd/file.go index 3f69e5c..9b038fe 100644 --- a/cmd/file.go +++ b/cmd/file.go @@ -19,6 +19,7 @@ package cmd import ( "fmt" "path" + "time" "gopkg.in/ini.v1" @@ -34,6 +35,7 @@ func init() { fileCmd.PersistentFlags().StringVarP(&destination, "output", "o", getDefaultCredentialsFile(), "output file for credentials") fileCmd.PersistentFlags().StringVarP(&profileName, "profile", "p", "default", "profile name") fileCmd.PersistentFlags().BoolVarP(&force, "force", "f", false, "overwrite existing profile without prompting") + fileCmd.PersistentFlags().BoolVarP(&autoRefresh, "refresh", "R", false, "automatically refresh credentials in file") rootCmd.AddCommand(fileCmd) } @@ -46,17 +48,54 @@ var fileCmd = &cobra.Command{ func runFile(cmd *cobra.Command, args []string) error { role = args[0] + err := updateCredentialsFile(role, profileName, destination, noIpRestrict, assumeRole) + if err != nil { + return err + } + if autoRefresh { + log.Infof("starting automatic file refresh for %s", role) + go fileRefresher(role, profileName, destination, noIpRestrict, assumeRole) + <-shutdown + } + return nil +} + +func updateCredentialsFile(role, profile, filename string, noIpRestrict bool, assumeRole []string) error { credentials, err := creds.GetCredentials(role, noIpRestrict, assumeRole) if err != nil { return err } - err = writeCredentialsFile(credentials) + err = writeCredentialsFile(credentials, profile, filename) if err != nil { return err } return nil } +func fileRefresher(role, profile, filename string, noIpRestrict bool, assumeRole []string) { + ticker := time.NewTicker(time.Minute) + + for { + select { + case _ = <-ticker.C: + log.Debug("checking credentials") + expiring, err := isExpiring(filename, profile, 10) + if err != nil { + log.Errorf("error checking credential expiration: %v", err) + } + if expiring { + log.Info("credentials are expiring soon, refreshing...") + err = updateCredentialsFile(role, profile, filename, noIpRestrict, assumeRole) + if err != nil { + log.Errorf("error updating credentials: %v", err) + } else { + log.Info("credentials refreshed!") + } + } + } + } +} + func getDefaultCredentialsFile() string { home, err := homedir.Dir() if err != nil { @@ -74,7 +113,7 @@ func getDefaultAwsConfigFile() string { } func shouldOverwriteCredentials() bool { - if force { + if force || autoRefresh { return true } userForce, err := util.PromptBool(fmt.Sprintf("Overwrite %s profile?", profileName)) @@ -84,7 +123,35 @@ func shouldOverwriteCredentials() bool { return userForce } -func writeCredentialsFile(credentials *creds.AwsCredentials) error { +func isExpiring(filename, profile string, thresholdMinutes int) (bool, error) { + fileContents, err := ini.Load(filename) + if err != nil { + return false, err + } + section, err := fileContents.GetSection(profile) + if err != nil { + return true, err + } + expiration, err := section.GetKey("expiration") + if err != nil { + return true, err + } + expirationTime, err := expiration.Time() + if err != nil { + return true, err + } + diff := time.Duration(thresholdMinutes) * time.Minute + timeUntilExpiration := expirationTime.Sub(time.Now()).Round(0) + log.Debugf("%s until expiration, refresh threshold is %s", timeUntilExpiration, diff) + if timeUntilExpiration < diff { + log.Debug("will refresh") + return true, nil + } + log.Debug("will not refresh") + return false, nil +} + +func writeCredentialsFile(credentials *creds.AwsCredentials, profile, filename string) error { var credentialsINI *ini.File var err error @@ -92,8 +159,8 @@ func writeCredentialsFile(credentials *creds.AwsCredentials) error { ini.PrettyFormat = false ini.PrettyEqual = true - if util.FileExists(destination) { - credentialsINI, err = ini.Load(destination) + if util.FileExists(filename) { + credentialsINI, err = ini.Load(filename) if err != nil { return err } @@ -101,18 +168,19 @@ func writeCredentialsFile(credentials *creds.AwsCredentials) error { credentialsINI = ini.Empty() } - if _, err := credentialsINI.GetSection(profileName); err == nil { + if _, err := credentialsINI.GetSection(profile); err == nil { // section already exists, should we overwrite? if !shouldOverwriteCredentials() { // user says no, so we'll just bail out - return fmt.Errorf("not overwriting %s profile", profileName) + return fmt.Errorf("not overwriting %s profile", profile) } } - credentialsINI.Section(profileName).Key("aws_access_key_id").SetValue(credentials.AccessKeyId) - credentialsINI.Section(profileName).Key("aws_secret_access_key").SetValue(credentials.SecretAccessKey) - credentialsINI.Section(profileName).Key("aws_session_token").SetValue(credentials.SessionToken) - err = credentialsINI.SaveTo(destination) + credentialsINI.Section(profile).Key("aws_access_key_id").SetValue(credentials.AccessKeyId) + credentialsINI.Section(profile).Key("aws_secret_access_key").SetValue(credentials.SecretAccessKey) + credentialsINI.Section(profile).Key("aws_session_token").SetValue(credentials.SessionToken) + credentialsINI.Section(profile).Key("expiration").SetValue(credentials.Expiration.Format("2006-01-02T15:04:05Z07:00")) + err = credentialsINI.SaveTo(filename) if err != nil { return err } diff --git a/cmd/vars.go b/cmd/vars.go index 4b3d84e..63aa8d2 100644 --- a/cmd/vars.go +++ b/cmd/vars.go @@ -25,6 +25,7 @@ var ( destination string destinationConfig string force bool + autoRefresh bool noIpRestrict bool metadataRegion string metadataListenAddr string