diff --git a/execpath/dir.go b/execpath/dir.go new file mode 100644 index 0000000..82f341e --- /dev/null +++ b/execpath/dir.go @@ -0,0 +1,17 @@ +package execpath + +import ( + "os" + "path/filepath" +) + +// Dir returns the absolute path of the directory that contains the executable +// that started the current process. The returned path does not include a +// trailing slash. +func Dir() (string, error) { + exe, err := os.Executable() + if err != nil { + return "", err + } + return filepath.Dir(exe), nil +} diff --git a/execpath/dir_test.go b/execpath/dir_test.go new file mode 100644 index 0000000..c0ec637 --- /dev/null +++ b/execpath/dir_test.go @@ -0,0 +1,26 @@ +package execpath_test + +import ( + "os" + "path/filepath" + "testing" + + "github.com/samherrmann/serveit/execpath" +) + +func TestDir(t *testing.T) { + exe, err := os.Executable() + if err != nil { + t.Error(err) + } + want := filepath.Dir(filepath.ToSlash(exe)) + + got, err := execpath.Dir() + if err != nil { + t.Error(err) + } + + if got != want { + t.Errorf("Wrong executable directory: Got %v, want %v", got, want) + } +} diff --git a/execpath/doc.go b/execpath/doc.go new file mode 100644 index 0000000..8e516e6 --- /dev/null +++ b/execpath/doc.go @@ -0,0 +1,3 @@ +// Package execpath provides utilities to manipulate paths relative to the +// executable that started the current process. +package execpath diff --git a/main.go b/main.go index 6149c98..2cbba35 100644 --- a/main.go +++ b/main.go @@ -4,8 +4,10 @@ import ( "log" "net/http" "os" + "path/filepath" "strconv" + "github.com/samherrmann/serveit/execpath" "github.com/samherrmann/serveit/flag" "github.com/samherrmann/serveit/handlers" "github.com/samherrmann/serveit/security" @@ -29,8 +31,8 @@ func parseFlags() *flag.Config { return config } -func ensureSecrets(hostnames []string) { - if err := security.EnsureKeyPairs(hostnames); err != nil { +func ensureSecrets(dir string, hostnames []string) { + if err := security.EnsureKeyPairs(dir, hostnames); err != nil { log.Fatalln(err) } } @@ -38,14 +40,16 @@ func ensureSecrets(hostnames []string) { func listenAndServe(port int, tls bool, hostnames []string) { addr := ":" + strconv.Itoa(port) log.Println("Serving current directory on port " + addr) - var err error if tls { - ensureSecrets(hostnames) - err = http.ListenAndServeTLS(addr, security.CertFilename, security.KeyFilename, nil) + dir, err := execpath.Dir() + if err != nil { + log.Fatalln(err) + } + ensureSecrets(dir, hostnames) + keyPath := filepath.Join(dir, security.KeyFilename) + certPath := filepath.Join(dir, security.CertFilename) + log.Fatalln(http.ListenAndServeTLS(addr, certPath, keyPath, nil)) } else { - err = http.ListenAndServe(addr, nil) - } - if err != nil { - log.Fatalln(err) + log.Fatalln(http.ListenAndServe(addr, nil)) } } diff --git a/security/cert.go b/security/cert.go index 47daf0a..b1ce67b 100644 --- a/security/cert.go +++ b/security/cert.go @@ -4,8 +4,8 @@ import ( "fmt" "io/ioutil" "net" - "os" "os/exec" + "path/filepath" "strings" ) @@ -19,22 +19,25 @@ var CSRFilename = "serveit.csr" var ExtFilename = "serveit.ext" // EnsureCert creates a server X.509 certificate if it doesn't already exist. -func EnsureCert(hostnames []string) error { - _, err := os.Stat(CertFilename) - if os.IsNotExist(err) { - if err := createCSR(); err != nil { +func EnsureCert(dir string, hostnames []string) error { + exists, err := fileExists(dir, CertFilename) + if err != nil { + return err + } + if !exists { + if err := createCSR(dir); err != nil { return err } - if err = createExtFile(hostnames); err != nil { + if err = createExtFile(dir, hostnames); err != nil { return err } - return CreateCert() + return CreateCert(dir) } - return err + return nil } // CreateCert creates a server X.509 certificate. -func CreateCert() error { +func CreateCert(dir string) error { cmd := exec.Command( "openssl", "x509", "-req", @@ -48,12 +51,13 @@ func CreateCert() error { "-sha256", "-extfile", ExtFilename, ) + cmd.Dir = dir _, err := cmd.CombinedOutput() return err } // createCSR creates a certificate signing request. -func createCSR() error { +func createCSR(dir string) error { cmd := exec.Command( "openssl", "req", "-new", @@ -61,12 +65,13 @@ func createCSR() error { "-subj", "/C=CA/ST=Ontario/L=Ottawa/O=samherrmann/CN=serveit", "-out", CSRFilename, ) + cmd.Dir = dir _, err := cmd.CombinedOutput() return err } // createExtFile creates a certificate extensions file. -func createExtFile(hostnames []string) error { +func createExtFile(dir string, hostnames []string) error { dns := []string{} ips := []string{} @@ -92,5 +97,5 @@ func createExtFile(hostnames []string) error { content += (strings.Join(dns, "\n") + "\n") content += (strings.Join(ips, "\n") + "\n") - return ioutil.WriteFile(ExtFilename, []byte(content), 0644) + return ioutil.WriteFile(filepath.Join(dir, ExtFilename), []byte(content), 0644) } diff --git a/security/file_exists.go b/security/file_exists.go new file mode 100644 index 0000000..1810e3e --- /dev/null +++ b/security/file_exists.go @@ -0,0 +1,17 @@ +package security + +import ( + "os" + "path/filepath" +) + +func fileExists(dir string, filename string) (bool, error) { + _, err := os.Stat(filepath.Join(dir, filename)) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} diff --git a/security/key.go b/security/key.go index 6be8189..8faee93 100644 --- a/security/key.go +++ b/security/key.go @@ -1,7 +1,6 @@ package security import ( - "os" "os/exec" ) @@ -9,21 +8,25 @@ import ( var KeyFilename = "serveit.key" // EnsureKey creates an RSA key if it doesn't already exist. -func EnsureKey() error { - _, err := os.Stat(KeyFilename) - if os.IsNotExist(err) { - return CreateKey() +func EnsureKey(dir string) error { + exists, err := fileExists(dir, KeyFilename) + if err != nil { + return err } - return err + if !exists { + return CreateKey(dir) + } + return nil } // CreateKey creates an RSA key. -func CreateKey() error { +func CreateKey(dir string) error { cmd := exec.Command( "openssl", "genrsa", "-out", KeyFilename, "2048", ) + cmd.Dir = dir _, err := cmd.CombinedOutput() return err } diff --git a/security/keypairs.go b/security/keypairs.go index 09494bb..9dd034d 100644 --- a/security/keypairs.go +++ b/security/keypairs.go @@ -6,20 +6,20 @@ import ( // EnsureKeyPairs creates an RSA key and a X.509 certificate for both the // certificate authority (CA) and the server if they don't already exist. -func EnsureKeyPairs(hostnames []string) error { - err := EnsureRootCAKey() +func EnsureKeyPairs(dir string, hostnames []string) error { + err := EnsureRootCAKey(dir) if err != nil { return fmt.Errorf("Error creating %v: %w", RootCAKeyFilename, err) } - err = EnsureRootCACert() + err = EnsureRootCACert(dir) if err != nil { return fmt.Errorf("Error creating %v: %w", RootCACertFilename, err) } - err = EnsureKey() + err = EnsureKey(dir) if err != nil { return fmt.Errorf("Error creating %v: %w", KeyFilename, err) } - err = EnsureCert(hostnames) + err = EnsureCert(dir, hostnames) if err != nil { return fmt.Errorf("Error creating %v: %w", CertFilename, err) } diff --git a/security/keypairs_test.go b/security/keypairs_test.go index 9de156d..afffc19 100644 --- a/security/keypairs_test.go +++ b/security/keypairs_test.go @@ -5,6 +5,7 @@ import ( "fmt" "io/ioutil" "os" + "path/filepath" "strings" "testing" @@ -12,56 +13,60 @@ import ( ) func TestEnsureKeyPairs(t *testing.T) { + dir := "testdata" + // Start with a clean slate. - if err := removeAllFiles(); err != nil { + if err := removeAllFiles(dir); err != nil { t.Error(err) } - if err := security.EnsureKeyPairs([]string{"localhost"}); err != nil { + if err := security.EnsureKeyPairs(dir, []string{"localhost"}); err != nil { t.Error(err) } - if err := verifyKey(security.RootCAKeyFilename); err != nil { + if err := verifyKey(dir, security.RootCAKeyFilename); err != nil { t.Error(err) } - if err := verifyCert(security.RootCACertFilename); err != nil { + if err := verifyCert(dir, security.RootCACertFilename); err != nil { t.Error(err) } - if err := verifyKey(security.KeyFilename); err != nil { + if err := verifyKey(dir, security.KeyFilename); err != nil { t.Error(err) } - if err := verifyCert(security.CertFilename); err != nil { + if err := verifyCert(dir, security.CertFilename); err != nil { t.Error(err) } // Clean up. - if err := removeAllFiles(); err != nil { + if err := removeAllFiles(dir); err != nil { t.Error(err) } } -func verifyKey(filename string) error { +func verifyKey(dir, filename string) error { return verifyFileContent( + dir, filename, "-----BEGIN RSA PRIVATE KEY-----", "-----END RSA PRIVATE KEY-----", ) } -func verifyCert(filename string) error { +func verifyCert(dir, filename string) error { return verifyFileContent( + dir, filename, "-----BEGIN CERTIFICATE-----", "-----END CERTIFICATE-----", ) } -func verifyFileContent(filename string, prefix string, suffix string) error { +func verifyFileContent(dir, filename, prefix, suffix string) error { // Read content from file. - content, err := ioutil.ReadFile(filename) + content, err := ioutil.ReadFile(filepath.Join(dir, filename)) if err != nil { return fmt.Errorf("Error reading %v: %v", filename, err) } @@ -75,7 +80,7 @@ func verifyFileContent(filename string, prefix string, suffix string) error { return nil } -func removeAllFiles() error { +func removeAllFiles(dir string) error { files := []string{ security.RootCAKeyFilename, security.RootCACertFilename, @@ -87,7 +92,7 @@ func removeAllFiles() error { } for _, file := range files { - if err := os.Remove(file); err != nil && !errors.Is(err, os.ErrNotExist) { + if err := os.Remove(filepath.Join(dir, file)); err != nil && !errors.Is(err, os.ErrNotExist) { return fmt.Errorf("Error removing %v: %w", file, err) } } diff --git a/security/root_ca_cert.go b/security/root_ca_cert.go index 2a3d637..b54ff87 100644 --- a/security/root_ca_cert.go +++ b/security/root_ca_cert.go @@ -1,7 +1,6 @@ package security import ( - "os" "os/exec" ) @@ -15,16 +14,19 @@ var RootCACertSerialFilename = fileRootname + ".srl" // EnsureRootCACert creates a certificate authority (CA) X.509 certificate if it // doesn't already exist. -func EnsureRootCACert() error { - _, err := os.Stat(RootCACertFilename) - if os.IsNotExist(err) { - return CreateRootCACert() +func EnsureRootCACert(dir string) error { + exists, err := fileExists(dir, RootCACertFilename) + if err != nil { + return err } - return err + if !exists { + return CreateRootCACert(dir) + } + return nil } // CreateRootCACert creates a certificate authority (CA) X.509 certificate. -func CreateRootCACert() error { +func CreateRootCACert(dir string) error { cmd := exec.Command( "openssl", "req", "-x509", @@ -37,6 +39,7 @@ func CreateRootCACert() error { "-subj", `/C=CA/ST=Ontario/L=Ottawa/O=samherrmann/CN=Serveit Root Certificate Authority`, "-out", RootCACertFilename, ) + cmd.Dir = dir _, err := cmd.CombinedOutput() return err } diff --git a/security/root_ca_key.go b/security/root_ca_key.go index 03de2e0..0646ca6 100644 --- a/security/root_ca_key.go +++ b/security/root_ca_key.go @@ -1,7 +1,6 @@ package security import ( - "os" "os/exec" ) @@ -10,16 +9,19 @@ var RootCAKeyFilename = "serveit_root_ca.key" // EnsureRootCAKey creates a certificate authority (CA) RSA key if it doesn't // already exist. -func EnsureRootCAKey() error { - _, err := os.Stat(RootCAKeyFilename) - if os.IsNotExist(err) { - return CreateRootCAKey() +func EnsureRootCAKey(dir string) error { + exists, err := fileExists(dir, RootCAKeyFilename) + if err != nil { + return err } - return err + if !exists { + return CreateRootCAKey(dir) + } + return nil } // CreateRootCAKey creates a certificate authority (CA) RSA key. -func CreateRootCAKey() error { +func CreateRootCAKey(dir string) error { cmd := exec.Command( "openssl", "genrsa", "-aes256", @@ -27,6 +29,7 @@ func CreateRootCAKey() error { "-out", RootCAKeyFilename, "2048", ) + cmd.Dir = dir _, err := cmd.CombinedOutput() return err } diff --git a/security/testdata/.gitignore b/security/testdata/.gitignore new file mode 100644 index 0000000..c96a04f --- /dev/null +++ b/security/testdata/.gitignore @@ -0,0 +1,2 @@ +* +!.gitignore \ No newline at end of file