diff --git a/README.md b/README.md index 22d0a16..1e92fe0 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,7 @@ Self-Update library for Github, Gitea and Gitlab hosted applications in Go * [SHA256](#sha256) * [ECDSA](#ecdsa) * [Using a single checksum file for all your assets](#using-a-single-checksum-file-for-all-your-assets) +* [macOS universal binaries](#macos-universal-binaries) * [Other providers than Github](#other-providers-than-github) * [GitLab](#gitlab) * [Example:](#example-1) @@ -35,7 +36,7 @@ Self-Update library for Github, Gitea and Gitlab hosted applications in Go # Introduction -go-selfupdate detects the information of the latest release via a source provider and +`go-selfupdate` detects the information of the latest release via a source provider and checks the current version. If a newer version than itself is detected, it downloads the released binary from the source provider and replaces itself. @@ -43,17 +44,19 @@ the source provider and replaces itself. - Retrieve the proper binary for the OS and arch where the binary is running - Update the binary with rollback support on failure - Tested on Linux, macOS and Windows -- Many archive and compression formats are supported (zip, tar, gzip, xzip, bzip2) +- Support for different versions of ARM architecture +- Support macOS universal binaries +- Many archive and compression formats are supported (zip, tar, gzip, xz, bzip2) - Support private repositories - Support hash, signature validation -Two source providers are available: +Three source providers are available: - GitHub - Gitea - Gitlab This library started as a fork of https://github.com/rhysd/go-github-selfupdate. A few things have changed from the original implementation: -- don't expose an external semver.Version type, but provide the same functionality through the API: LessThan, Equal and GreaterThan +- don't expose an external `semver.Version` type, but provide the same functionality through the API: `LessThan`, `Equal` and `GreaterThan` - use an interface to send logs (compatible with standard log.Logger) - able to detect different ARM CPU architectures (the original library wasn't working on my different versions of raspberry pi) - support for assets compressed with bzip2 (.bz2) @@ -80,7 +83,7 @@ func update(version string) error { return nil } - exe, err := os.Executable() + exe, err := selfupdate.ExecutablePath() if err != nil { return errors.New("could not locate executable path") } @@ -301,6 +304,18 @@ Tools like [goreleaser][] produce a single checksum file for all your assets. A updater, _ := NewUpdater(Config{Validator: &ChecksumValidator{UniqueFilename: "checksums.txt"}}) ``` +# macOS universal binaries + +You can ask the updater to choose a macOS universal binary as a fallback if the native architecture wasn't found. + +You need to provide the architecture name for the universal binary in the `Config` struct: + +```go +updater, _ := NewUpdater(Config{UniversalArch: "all"}) +``` + +Default is empty, which means no fallback. + # Other providers than Github This library can be easily extended by providing a new source and release implementation for any git provider @@ -353,7 +368,7 @@ func update() { } fmt.Printf("found release %s\n", release.Version()) - exe, err := os.Executable() + exe, err := selfupdate.ExecutablePath() if err != nil { return errors.New("could not locate executable path") } diff --git a/arch.go b/arch.go index 8fae8d9..e4b1cf8 100644 --- a/arch.go +++ b/arch.go @@ -9,17 +9,26 @@ const ( maxARM = 7 ) -// generateAdditionalArch we can use depending on the type of CPU -func generateAdditionalArch(arch string, goarm uint8) []string { +// getAdditionalArch we can use depending on the type of CPU +func getAdditionalArch(arch string, goarm uint8, universalArch string) []string { + const defaultArchCapacity = 3 + additionalArch := make([]string, 0, defaultArchCapacity) + if arch == "arm" && goarm >= minARM && goarm <= maxARM { - additionalArch := make([]string, 0, maxARM-minARM) + // more precise arch at the top of the list for v := goarm; v >= minARM; v-- { additionalArch = append(additionalArch, fmt.Sprintf("armv%d", v)) } + additionalArch = append(additionalArch, "arm") return additionalArch } + + additionalArch = append(additionalArch, arch) if arch == "amd64" { - return []string{"x86_64"} + additionalArch = append(additionalArch, "x86_64") + } + if universalArch != "" { + additionalArch = append(additionalArch, universalArch) } - return []string{} + return additionalArch } diff --git a/arch_test.go b/arch_test.go index 3fa5c8e..8013b42 100644 --- a/arch_test.go +++ b/arch_test.go @@ -9,23 +9,26 @@ import ( func TestAdditionalArch(t *testing.T) { testData := []struct { - arch string - goarm uint8 - expected []string + arch string + goarm uint8 + universalArch string + expected []string }{ - {"arm64", 8, []string{}}, - {"arm", 8, []string{}}, // armv8 is called arm64 - this shouldn't happen - {"arm", 7, []string{"armv7", "armv6", "armv5"}}, - {"arm", 6, []string{"armv6", "armv5"}}, - {"arm", 5, []string{"armv5"}}, - {"arm", 4, []string{}}, // go is not supporting below armv5 - {"amd64", 0, []string{"x86_64"}}, + {"arm64", 0, "", []string{"arm64"}}, + {"arm64", 0, "all", []string{"arm64", "all"}}, + {"arm", 8, "", []string{"arm"}}, // armv8 is called arm64 - this shouldn't happen + {"arm", 7, "", []string{"armv7", "armv6", "armv5", "arm"}}, + {"arm", 6, "", []string{"armv6", "armv5", "arm"}}, + {"arm", 5, "", []string{"armv5", "arm"}}, + {"arm", 4, "", []string{"arm"}}, // go is not supporting below armv5 + {"amd64", 0, "", []string{"amd64", "x86_64"}}, + {"amd64", 0, "all", []string{"amd64", "x86_64", "all"}}, } for _, testItem := range testData { t.Run(fmt.Sprintf("%s-%d", testItem.arch, testItem.goarm), func(t *testing.T) { - result := generateAdditionalArch(testItem.arch, testItem.goarm) - assert.ElementsMatch(t, testItem.expected, result) + result := getAdditionalArch(testItem.arch, testItem.goarm, testItem.universalArch) + assert.Equal(t, testItem.expected, result) }) } } diff --git a/arm.go b/arm.go index e1bb78f..8e5c257 100644 --- a/arm.go +++ b/arm.go @@ -2,17 +2,8 @@ package selfupdate import ( "debug/buildinfo" - "os" ) -var goarm uint8 - -//nolint:gochecknoinits -func init() { - // avoid using runtime.goarm directly - goarm = getGOARM(os.Args[0]) -} - func getGOARM(goBinary string) uint8 { build, err := buildinfo.ReadFile(goBinary) if err != nil { @@ -21,7 +12,7 @@ func getGOARM(goBinary string) uint8 { for _, setting := range build.Settings { if setting.Key == "GOARM" { // the value is coming from the linker, so it should be safe to convert - return uint8(setting.Value[0] - '0') + return setting.Value[0] - '0' } } return 0 diff --git a/cmd/detect-latest-release/update.go b/cmd/detect-latest-release/update.go index 533ea5d..185fb94 100644 --- a/cmd/detect-latest-release/update.go +++ b/cmd/detect-latest-release/update.go @@ -2,10 +2,8 @@ package main import ( "context" - "errors" "fmt" "log" - "os" "runtime" "github.com/creativeprojects/go-selfupdate" @@ -26,9 +24,9 @@ func update(version string) error { return nil } - exe, err := os.Executable() + exe, err := selfupdate.ExecutablePath() if err != nil { - return errors.New("could not locate executable path") + return fmt.Errorf("could not locate executable path: %w", err) } if err := selfupdate.UpdateTo(context.Background(), latest.AssetURL, latest.AssetName, exe); err != nil { return fmt.Errorf("error occurred while updating binary: %w", err) diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..1203e7f --- /dev/null +++ b/codecov.yml @@ -0,0 +1,18 @@ +codecov: + notify: + after_n_builds: 6 + +comment: + after_n_builds: 6 + +coverage: + round: nearest + status: + project: + default: + target: auto + threshold: "2%" + patch: + default: + target: "70%" + threshold: "2%" diff --git a/config.go b/config.go index 1e364f6..2c9e37b 100644 --- a/config.go +++ b/config.go @@ -2,7 +2,7 @@ package selfupdate // Config represents the configuration of self-update. type Config struct { - // Source where to load the releases from (example: GitHubSource) + // Source where to load the releases from (example: GitHubSource). Source Source // Validator represents types which enable additional validation of downloaded release. Validator Validator @@ -10,14 +10,17 @@ type Config struct { // An asset is selected if it matches any of those, in addition to the regular tag, os, arch, extensions. // Please make sure that your filter(s) uniquely match an asset. Filters []string - // OS is set to the value of runtime.GOOS by default, but you can force another value here + // OS is set to the value of runtime.GOOS by default, but you can force another value here. OS string - // Arch is set to the value of runtime.GOARCH by default, but you can force another value here + // Arch is set to the value of runtime.GOARCH by default, but you can force another value here. Arch string - // Arm 32bits version. Valid values are 0 (unknown), 5, 6 or 7. Default is detected value (if any) + // Arm 32bits version. Valid values are 0 (unknown), 5, 6 or 7. Default is detected value (if available). Arm uint8 - // Draft permits an upgrade to a "draft" version (default to false) + // Arch name for macOS universal binary. Default to none. + // If set, the updater will only pick the universal binary if the Arch is not found. + UniversalArch string + // Draft permits an upgrade to a "draft" version (default to false). Draft bool - // Prerelease permits an upgrade to a "pre-release" version (default to false) + // Prerelease permits an upgrade to a "pre-release" version (default to false). Prerelease bool } diff --git a/detect.go b/detect.go index f5d5498..b2e5871 100644 --- a/detect.go +++ b/detect.go @@ -15,7 +15,7 @@ var reVersion = regexp.MustCompile(`\d+\.\d+\.\d+`) // It fetches releases information from the source provider and find out the latest release with matching the tag names and asset names. // Drafts and pre-releases are ignored. // Assets would be suffixed by the OS name and the arch name such as 'foo_linux_amd64' where 'foo' is a command name. -// '-' can also be used as a separator. File can be compressed with zip, gzip, zxip, bzip2, tar&gzip or tar&zxip. +// '-' can also be used as a separator. File can be compressed with zip, gzip, xz, bzip2, tar&gzip or tar&xz. // So the asset can have a file extension for the corresponding compression format such as '.zip'. // On Windows, '.exe' also can be contained such as 'foo_windows_amd64.exe.zip'. func (up *Updater) DetectLatest(ctx context.Context, repository Repository) (release *Release, found bool, err error) { @@ -131,7 +131,7 @@ func findValidationAsset(rel SourceRelease, validationName string) (SourceAsset, func (up *Updater) findReleaseAndAsset(rels []SourceRelease, targetVersion string) (SourceRelease, SourceAsset, *semver.Version, bool) { // we put the detected arch at the end of the list: that's fine for ARM so far, // as the additional arch are more accurate than the generic one - for _, arch := range append(generateAdditionalArch(up.arch, up.arm), up.arch) { + for _, arch := range getAdditionalArch(up.arch, up.arm, up.universalArch) { release, asset, version, found := up.findReleaseAndAssetForArch(arch, rels, targetVersion) if found { return release, asset, version, found diff --git a/detect_test.go b/detect_test.go index 4eec275..8624330 100644 --- a/detect_test.go +++ b/detect_test.go @@ -526,17 +526,23 @@ func TestFindReleaseAndAsset(t *testing.T) { rel2 := "rel2" assetLinux386 := "asset_linux_386.tgz" assetLinuxAMD64 := "asset_linux_amd64.tgz" - assetLinuxX86_64 := "asset_linux_x86_64.tgz" + assetLinuxX86 := "asset_linux_x86_64.tgz" assetLinuxARM := "asset_linux_arm.tgz" assetLinuxARMv5 := "asset_linux_armv5.tgz" assetLinuxARMv6 := "asset_linux_armv6.tgz" assetLinuxARMv7 := "asset_linux_armv7.tgz" assetLinuxARM64 := "asset_linux_arm64.tgz" + assetLinuxAll := "asset_linux_all.tgz" + assetDarwinAMD64 := "asset_darwin_amd64.tgz" + assetDarwinARM64 := "asset_darwin_arm64.tgz" + assetDarwinAll := "asset_darwin_all.tgz" + testData := []struct { name string os string arch string arm uint8 + universalArch string releases []SourceRelease version string filters []string @@ -765,7 +771,54 @@ func TestFindReleaseAndAsset(t *testing.T) { name: assetLinux386, }, &GitHubAsset{ - name: assetLinuxX86_64, + name: assetLinuxX86, + }, + }, + }, + }, + version: "v2.0.0", + filters: nil, + found: true, + expectedAssetName: assetLinuxX86, + }, + { + name: "universal binary ignored on linux", + os: "linux", // universal binary is for darwin only + arch: "amd64", + universalArch: "all", + releases: []SourceRelease{ + &GitHubRelease{ + name: rel2, + tagName: tag2, + assets: []SourceAsset{ + &GitHubAsset{ + name: assetLinuxAll, + }, + }, + }, + }, + version: "v2.0.0", + filters: nil, + found: false, + }, + { + name: "match amd64 instead of universal binary", + os: "darwin", // universal binary is for darwin only + arch: "amd64", + universalArch: "all", + releases: []SourceRelease{ + &GitHubRelease{ + name: rel2, + tagName: tag2, + assets: []SourceAsset{ + &GitHubAsset{ + name: assetDarwinAMD64, + }, + &GitHubAsset{ + name: assetDarwinARM64, + }, + &GitHubAsset{ + name: assetDarwinAll, }, }, }, @@ -773,17 +826,89 @@ func TestFindReleaseAndAsset(t *testing.T) { version: "v2.0.0", filters: nil, found: true, - expectedAssetName: assetLinuxX86_64, + expectedAssetName: assetDarwinAMD64, + }, + { + name: "match arm64 instead of universal binary", + os: "darwin", // universal binary is for darwin only + arch: "arm64", + universalArch: "all", + releases: []SourceRelease{ + &GitHubRelease{ + name: rel2, + tagName: tag2, + assets: []SourceAsset{ + &GitHubAsset{ + name: assetDarwinAMD64, + }, + &GitHubAsset{ + name: assetDarwinARM64, + }, + &GitHubAsset{ + name: assetDarwinAll, + }, + }, + }, + }, + version: "v2.0.0", + filters: nil, + found: true, + expectedAssetName: assetDarwinARM64, + }, + { + name: "match universal binary", + os: "darwin", // universal binary is for darwin only + arch: "arm64", + universalArch: "all", + releases: []SourceRelease{ + &GitHubRelease{ + name: rel2, + tagName: tag2, + assets: []SourceAsset{ + &GitHubAsset{ + name: assetDarwinAll, + }, + }, + }, + }, + version: "v2.0.0", + filters: nil, + found: true, + expectedAssetName: assetDarwinAll, + }, + { + name: "no match when universal binary not specified", + os: "darwin", + arch: "arm64", + universalArch: "", + releases: []SourceRelease{ + &GitHubRelease{ + name: rel2, + tagName: tag2, + assets: []SourceAsset{ + &GitHubAsset{ + name: assetDarwinAll, + }, + }, + }, + }, + version: "v2.0.0", + filters: nil, + found: false, }, } for _, testItem := range testData { + testItem := testItem t.Run(testItem.name, func(t *testing.T) { + t.Parallel() + updater, err := NewUpdater(Config{ - Filters: testItem.filters, - OS: testItem.os, - Arch: testItem.arch, - Arm: testItem.arm, + Filters: testItem.filters, + OS: testItem.os, + Arch: testItem.arch, + Arm: testItem.arm, + UniversalArch: testItem.universalArch, }) require.NoError(t, err) _, asset, _, found := updater.findReleaseAndAsset(testItem.releases, testItem.version) diff --git a/internal/path.go b/internal/path.go new file mode 100644 index 0000000..e33f8bb --- /dev/null +++ b/internal/path.go @@ -0,0 +1,21 @@ +package internal + +import ( + "os" + "path/filepath" +) + +// GetExecutablePath returns the path of the executable file with all symlinks resolved. +func GetExecutablePath() (string, error) { + exe, err := os.Executable() + if err != nil { + return "", err + } + + exe, err = filepath.EvalSymlinks(exe) + if err != nil { + return "", err + } + + return exe, nil +} diff --git a/internal/path_test.go b/internal/path_test.go new file mode 100644 index 0000000..0b75e00 --- /dev/null +++ b/internal/path_test.go @@ -0,0 +1,16 @@ +package internal_test + +import ( + "testing" + + "github.com/creativeprojects/go-selfupdate/internal" + "github.com/stretchr/testify/assert" +) + +func TestGetExecutablePath(t *testing.T) { + t.Parallel() + + exe, err := internal.GetExecutablePath() + assert.NoError(t, err) + assert.NotEmpty(t, exe) +} diff --git a/path.go b/path.go new file mode 100644 index 0000000..b7f0ae0 --- /dev/null +++ b/path.go @@ -0,0 +1,7 @@ +package selfupdate + +import "github.com/creativeprojects/go-selfupdate/internal" + +func ExecutablePath() (string, error) { + return internal.GetExecutablePath() +} diff --git a/path_test.go b/path_test.go new file mode 100644 index 0000000..dfcd578 --- /dev/null +++ b/path_test.go @@ -0,0 +1,15 @@ +package selfupdate + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestExecutablePath(t *testing.T) { + t.Parallel() + + exe, err := ExecutablePath() + assert.NoError(t, err) + assert.NotEmpty(t, exe) +} diff --git a/universal_binary.go b/universal_binary.go new file mode 100644 index 0000000..acf55c5 --- /dev/null +++ b/universal_binary.go @@ -0,0 +1,13 @@ +package selfupdate + +import "debug/macho" + +// IsDarwinUniversalBinary checks if the file is a universal binary (also called a fat binary). +func IsDarwinUniversalBinary(filename string) bool { + file, err := macho.OpenFat(filename) + if err == nil { + file.Close() + return true + } + return false +} diff --git a/update.go b/update.go index a8842c6..2b3447f 100644 --- a/update.go +++ b/update.go @@ -10,6 +10,7 @@ import ( "strings" "github.com/Masterminds/semver/v3" + "github.com/creativeprojects/go-selfupdate/internal" "github.com/creativeprojects/go-selfupdate/update" ) @@ -82,7 +83,7 @@ func (up *Updater) UpdateCommand(ctx context.Context, cmdPath string, current st // UpdateSelf updates the running executable itself to the latest version. // 'current' is used to check the latest version against the current version. func (up *Updater) UpdateSelf(ctx context.Context, current string, repository Repository) (*Release, error) { - cmdPath, err := os.Executable() + cmdPath, err := internal.GetExecutablePath() if err != nil { return nil, err } diff --git a/update/apply.go b/update/apply.go index a0800dc..6825e1d 100644 --- a/update/apply.go +++ b/update/apply.go @@ -3,13 +3,13 @@ package update import ( "bytes" "crypto" - "crypto/x509" - "encoding/pem" "errors" "fmt" "io" "os" "path/filepath" + + "github.com/creativeprojects/go-selfupdate/internal" ) var ( @@ -37,7 +37,7 @@ var ( // back to /path/to/target. // // If the roll back operation fails, the file system is left in an inconsistent state (between steps 5 and 6) where -// there is no new executable file and the old executable file could not be be moved to its original location. In this +// there is no new executable file and the old executable file could not be moved to its original location. In this // case you should notify the user of the bad news and ask them to recover manually. Applications can determine whether // the rollback failed by calling RollbackError, see the documentation on that function for additional detail. func Apply(update io.Reader, opts Options) error { @@ -61,14 +61,16 @@ func Apply(update io.Reader, opts Options) error { opts.Verifier = NewECDSAVerifier() } if opts.TargetMode == 0 { - opts.TargetMode = 0755 + opts.TargetMode = 0o755 } // get target path var err error - opts.TargetPath, err = opts.getPath() - if err != nil { - return err + if opts.TargetPath == "" { + opts.TargetPath, err = internal.GetExecutablePath() + if err != nil { + return err + } } var newBytes []byte @@ -180,95 +182,3 @@ type rollbackErr struct { error // original error rollbackErr error // error encountered while rolling back } - -// Options for Apply update -type Options struct { - // TargetPath defines the path to the file to update. - // The emptry string means 'the executable file of the running program'. - TargetPath string - - // Create TargetPath replacement with this file mode. If zero, defaults to 0755. - TargetMode os.FileMode - - // Checksum of the new binary to verify against. If nil, no checksum or signature verification is done. - Checksum []byte - - // Public key to use for signature verification. If nil, no signature verification is done. - PublicKey crypto.PublicKey - - // Signature to verify the updated file. If nil, no signature verification is done. - Signature []byte - - // Pluggable signature verification algorithm. If nil, ECDSA is used. - Verifier Verifier - - // Use this hash function to generate the checksum. If not set, SHA256 is used. - Hash crypto.Hash - - // Store the old executable file at this path after a successful update. - // The empty string means the old executable file will be removed after the update. - OldSavePath string -} - -// SetPublicKeyPEM is a convenience method to set the PublicKey property -// used for checking a completed update's signature by parsing a -// Public Key formatted as PEM data. -func (o *Options) SetPublicKeyPEM(pembytes []byte) error { - block, _ := pem.Decode(pembytes) - if block == nil { - return errors.New("couldn't parse PEM data") - } - - pub, err := x509.ParsePKIXPublicKey(block.Bytes) - if err != nil { - return err - } - o.PublicKey = pub - return nil -} - -func (o *Options) getPath() (string, error) { - if o.TargetPath != "" { - return o.TargetPath, nil - } - exe, err := os.Executable() - if err != nil { - return "", err - } - - exe, err = filepath.EvalSymlinks(exe) - if err != nil { - return "", err - } - - return exe, nil -} - -func (o *Options) verifyChecksum(updated []byte) error { - checksum, err := checksumFor(o.Hash, updated) - if err != nil { - return err - } - - if !bytes.Equal(o.Checksum, checksum) { - return fmt.Errorf("updated file has wrong checksum. Expected: %x, got: %x", o.Checksum, checksum) - } - return nil -} - -func (o *Options) verifySignature(updated []byte) error { - checksum, err := checksumFor(o.Hash, updated) - if err != nil { - return err - } - return o.Verifier.VerifySignature(checksum, o.Signature, o.Hash, o.PublicKey) -} - -func checksumFor(h crypto.Hash, payload []byte) ([]byte, error) { - if !h.Available() { - return nil, errors.New("requested hash function not available") - } - hash := h.New() - hash.Write(payload) // guaranteed not to error - return hash.Sum([]byte{}), nil -} diff --git a/update/apply_test.go b/update/apply_test.go index 2c2223a..a747093 100644 --- a/update/apply_test.go +++ b/update/apply_test.go @@ -10,6 +10,8 @@ import ( "fmt" "os" "testing" + + "github.com/stretchr/testify/assert" ) var ( @@ -24,13 +26,17 @@ func cleanup(path string) { } // we write with a separate name for each test so that we can run them in parallel -func writeOldFile(path string, t *testing.T) { - if err := os.WriteFile(path, oldFile, 0777); err != nil { +func writeOldFile(t *testing.T, path string) { + t.Helper() + + if err := os.WriteFile(path, oldFile, 0o600); err != nil { t.Fatalf("Failed to write file for testing preparation: %v", err) } } -func validateUpdate(path string, err error, t *testing.T) { +func validateUpdate(t *testing.T, path string, err error) { + t.Helper() + if err != nil { t.Fatalf("Failed to update: %v", err) } @@ -46,20 +52,24 @@ func validateUpdate(path string, err error, t *testing.T) { } func TestApplySimple(t *testing.T) { - fName := "TestApplySimple" + t.Parallel() + + fName := t.Name() defer cleanup(fName) - writeOldFile(fName, t) + writeOldFile(t, fName) err := Apply(bytes.NewReader(newFile), Options{ TargetPath: fName, }) - validateUpdate(fName, err, t) + validateUpdate(t, fName, err) } func TestApplyOldSavePath(t *testing.T) { - fName := "TestApplyOldSavePath" + t.Parallel() + + fName := t.Name() defer cleanup(fName) - writeOldFile(fName, t) + writeOldFile(t, fName) oldfName := "OldSavePath" @@ -67,7 +77,7 @@ func TestApplyOldSavePath(t *testing.T) { TargetPath: fName, OldSavePath: oldfName, }) - validateUpdate(fName, err, t) + validateUpdate(t, fName, err) if _, err := os.Stat(oldfName); os.IsNotExist(err) { t.Fatalf("Failed to find the old file: %v", err) @@ -77,21 +87,25 @@ func TestApplyOldSavePath(t *testing.T) { } func TestVerifyChecksum(t *testing.T) { - fName := "TestVerifyChecksum" + t.Parallel() + + fName := t.Name() defer cleanup(fName) - writeOldFile(fName, t) + writeOldFile(t, fName) err := Apply(bytes.NewReader(newFile), Options{ TargetPath: fName, Checksum: newFileChecksum[:], }) - validateUpdate(fName, err, t) + validateUpdate(t, fName, err) } func TestVerifyChecksumNegative(t *testing.T) { - fName := "TestVerifyChecksumNegative" + t.Parallel() + + fName := t.Name() defer cleanup(fName) - writeOldFile(fName, t) + writeOldFile(t, fName) badChecksum := []byte{0x0A, 0x0B, 0x0C, 0xFF} err := Apply(bytes.NewReader(newFile), Options{ @@ -190,10 +204,30 @@ func sign(parsePrivKey func([]byte) (crypto.Signer, error), privatePEM string, s return sig } +func TestSetInvalidPublicKeyPEM(t *testing.T) { + t.Parallel() + + const wrongPublicKey = ` +-----BEGIN PUBLIC KEY----- +== not valid base64 == +-----END PUBLIC KEY----- +` + + fName := t.Name() + defer cleanup(fName) + writeOldFile(t, fName) + + opts := Options{TargetPath: fName} + err := opts.SetPublicKeyPEM([]byte(wrongPublicKey)) + assert.Error(t, err, "Did not fail with invalid public key") +} + func TestVerifyECSignature(t *testing.T) { - fName := "TestVerifyECSignature" + t.Parallel() + + fName := t.Name() defer cleanup(fName) - writeOldFile(fName, t) + writeOldFile(t, fName) opts := Options{TargetPath: fName} err := opts.SetPublicKeyPEM([]byte(ecdsaPublicKey)) @@ -203,13 +237,15 @@ func TestVerifyECSignature(t *testing.T) { opts.Signature = signec(ecdsaPrivateKey, newFile, t) err = Apply(bytes.NewReader(newFile), opts) - validateUpdate(fName, err, t) + validateUpdate(t, fName, err) } func TestVerifyRSASignature(t *testing.T) { - fName := "TestVerifyRSASignature" + t.Parallel() + + fName := t.Name() defer cleanup(fName) - writeOldFile(fName, t) + writeOldFile(t, fName) opts := Options{ TargetPath: fName, @@ -222,13 +258,15 @@ func TestVerifyRSASignature(t *testing.T) { opts.Signature = signrsa(rsaPrivateKey, newFile, t) err = Apply(bytes.NewReader(newFile), opts) - validateUpdate(fName, err, t) + validateUpdate(t, fName, err) } func TestVerifyFailBadSignature(t *testing.T) { - fName := "TestVerifyFailBadSignature" + t.Parallel() + + fName := t.Name() defer cleanup(fName) - writeOldFile(fName, t) + writeOldFile(t, fName) opts := Options{ TargetPath: fName, @@ -246,9 +284,11 @@ func TestVerifyFailBadSignature(t *testing.T) { } func TestVerifyFailNoSignature(t *testing.T) { - fName := "TestVerifySignatureWithPEM" + t.Parallel() + + fName := t.Name() defer cleanup(fName) - writeOldFile(fName, t) + writeOldFile(t, fName) opts := Options{TargetPath: fName} err := opts.SetPublicKeyPEM([]byte(ecdsaPublicKey)) @@ -262,7 +302,10 @@ func TestVerifyFailNoSignature(t *testing.T) { } } -const wrongKey = ` +func TestVerifyFailWrongSignature(t *testing.T) { + t.Parallel() + + const wrongKey = ` -----BEGIN EC PRIVATE KEY----- MIGkAgEBBDBzqYp6N2s8YWYifBjS03/fFfmGeIPcxQEi+bbFeekIYt8NIKIkhD+r hpaIwSmot+qgBwYFK4EEACKhZANiAAR0EC8Usbkc4k30frfEB2ECmsIghu9DJSqE @@ -271,10 +314,9 @@ VBbP/Ff+05HOqwPC7rJMy1VAJLKg7Cw= -----END EC PRIVATE KEY----- ` -func TestVerifyFailWrongSignature(t *testing.T) { - fName := "TestVerifyFailWrongSignature" + fName := t.Name() defer cleanup(fName) - writeOldFile(fName, t) + writeOldFile(t, fName) opts := Options{TargetPath: fName} err := opts.SetPublicKeyPEM([]byte(ecdsaPublicKey)) @@ -290,9 +332,11 @@ func TestVerifyFailWrongSignature(t *testing.T) { } func TestSignatureButNoPublicKey(t *testing.T) { - fName := "TestSignatureButNoPublicKey" + t.Parallel() + + fName := t.Name() defer cleanup(fName) - writeOldFile(fName, t) + writeOldFile(t, fName) err := Apply(bytes.NewReader(newFile), Options{ TargetPath: fName, @@ -304,9 +348,11 @@ func TestSignatureButNoPublicKey(t *testing.T) { } func TestPublicKeyButNoSignature(t *testing.T) { - fName := "TestPublicKeyButNoSignature" + t.Parallel() + + fName := t.Name() defer cleanup(fName) - writeOldFile(fName, t) + writeOldFile(t, fName) opts := Options{TargetPath: fName} if err := opts.SetPublicKeyPEM([]byte(ecdsaPublicKey)); err != nil { @@ -319,9 +365,12 @@ func TestPublicKeyButNoSignature(t *testing.T) { } func TestWriteError(t *testing.T) { - fName := "TestWriteError" + // fix this test patching the global openFile variable + // t.Parallel() + + fName := t.Name() defer cleanup(fName) - writeOldFile(fName, t) + writeOldFile(t, fName) openFile = func(name string, flags int, perm os.FileMode) (*os.File, error) { f, err := os.OpenFile(name, flags, perm) diff --git a/update/hide_test.go b/update/hide_test.go new file mode 100644 index 0000000..94c0f80 --- /dev/null +++ b/update/hide_test.go @@ -0,0 +1,20 @@ +package update + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestHideFile(t *testing.T) { + t.Parallel() + + tempFile := filepath.Join(t.TempDir(), t.Name()) + err := os.WriteFile(tempFile, []byte("test"), 0o644) + assert.NoError(t, err) + + err = hideFile(tempFile) + assert.NoError(t, err) +} diff --git a/update/hide_windows.go b/update/hide_windows.go index c368b9c..3c2feb7 100644 --- a/update/hide_windows.go +++ b/update/hide_windows.go @@ -9,7 +9,11 @@ func hideFile(path string) error { kernel32 := syscall.NewLazyDLL("kernel32.dll") setFileAttributes := kernel32.NewProc("SetFileAttributesW") - r1, _, err := setFileAttributes.Call(uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(path))), 2) + utf16Str, err := syscall.UTF16PtrFromString(path) + if err != nil { + return err + } + r1, _, err := setFileAttributes.Call(uintptr(unsafe.Pointer(utf16Str)), 2) if r1 == 0 { return err diff --git a/update/options.go b/update/options.go new file mode 100644 index 0000000..ba8edb2 --- /dev/null +++ b/update/options.go @@ -0,0 +1,86 @@ +package update + +import ( + "bytes" + "crypto" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "os" +) + +// Options for Apply update +type Options struct { + // TargetPath defines the path to the file to update. + // The empty string means 'the executable file of the running program'. + TargetPath string + + // Create TargetPath replacement with this file mode. If zero, defaults to 0755. + TargetMode os.FileMode + + // Checksum of the new binary to verify against. If nil, no checksum or signature verification is done. + Checksum []byte + + // Public key to use for signature verification. If nil, no signature verification is done. + PublicKey crypto.PublicKey + + // Signature to verify the updated file. If nil, no signature verification is done. + Signature []byte + + // Pluggable signature verification algorithm. If nil, ECDSA is used. + Verifier Verifier + + // Use this hash function to generate the checksum. If not set, SHA256 is used. + Hash crypto.Hash + + // Store the old executable file at this path after a successful update. + // The empty string means the old executable file will be removed after the update. + OldSavePath string +} + +// SetPublicKeyPEM is a convenience method to set the PublicKey property +// used for checking a completed update's signature by parsing a +// Public Key formatted as PEM data. +func (o *Options) SetPublicKeyPEM(pembytes []byte) error { + block, _ := pem.Decode(pembytes) + if block == nil { + return errors.New("couldn't parse PEM data") + } + + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + o.PublicKey = pub + return nil +} + +func (o *Options) verifyChecksum(updated []byte) error { + checksum, err := checksumFor(o.Hash, updated) + if err != nil { + return err + } + + if !bytes.Equal(o.Checksum, checksum) { + return fmt.Errorf("updated file has wrong checksum. Expected: %x, got: %x", o.Checksum, checksum) + } + return nil +} + +func (o *Options) verifySignature(updated []byte) error { + checksum, err := checksumFor(o.Hash, updated) + if err != nil { + return err + } + return o.Verifier.VerifySignature(checksum, o.Signature, o.Hash, o.PublicKey) +} + +func checksumFor(h crypto.Hash, payload []byte) ([]byte, error) { + if !h.Available() { + return nil, errors.New("requested hash function not available") + } + hash := h.New() + _, _ = hash.Write(payload) + return hash.Sum([]byte{}), nil +} diff --git a/updater.go b/updater.go index 31dd35e..cdae795 100644 --- a/updater.go +++ b/updater.go @@ -4,18 +4,21 @@ import ( "fmt" "regexp" "runtime" + + "github.com/creativeprojects/go-selfupdate/internal" ) // Updater is responsible for managing the context of self-update. type Updater struct { - source Source - validator Validator - filters []*regexp.Regexp - os string - arch string - arm uint8 - prerelease bool - draft bool + source Source + validator Validator + filters []*regexp.Regexp + os string + arch string + arm uint8 + universalArch string // only filled in when needed + prerelease bool + draft bool } // keep the default updater instance in cache @@ -27,6 +30,7 @@ func NewUpdater(config Config) (*Updater, error) { source := config.Source if source == nil { // default source is GitHub + // an error can only be returned when using GitHub Enterprise URLs source, _ = NewGitHubSource(GitHubConfig{}) } @@ -40,27 +44,33 @@ func NewUpdater(config Config) (*Updater, error) { } os := config.OS - arch := config.Arch if os == "" { os = runtime.GOOS } + arch := config.Arch if arch == "" { arch = runtime.GOARCH } arm := config.Arm - if arm == 0 && goarm > 0 { - arm = goarm + if arm == 0 && arch == "arm" { + exe, _ := internal.GetExecutablePath() + arm = getGOARM(exe) + } + universalArch := "" + if os == "darwin" && config.UniversalArch != "" { + universalArch = config.UniversalArch } return &Updater{ - source: source, - validator: config.Validator, - filters: filtersRe, - os: os, - arch: arch, - arm: arm, - prerelease: config.Prerelease, - draft: config.Draft, + source: source, + validator: config.Validator, + filters: filtersRe, + os: os, + arch: arch, + arm: arm, + universalArch: universalArch, + prerelease: config.Prerelease, + draft: config.Draft, }, nil } @@ -73,14 +83,6 @@ func DefaultUpdater() *Updater { if defaultUpdater != nil { return defaultUpdater } - // an error can only be returned when using GitHub Enterprise URLs - // so we're safe here :) - source, _ := NewGitHubSource(GitHubConfig{}) - defaultUpdater = &Updater{ - source: source, - os: runtime.GOOS, - arch: runtime.GOARCH, - arm: goarm, - } + defaultUpdater, _ = NewUpdater(Config{}) return defaultUpdater }