diff --git a/cmd/windows_agent/service.go b/cmd/windows_agent/service.go index d3d4e71..4ad4e32 100644 --- a/cmd/windows_agent/service.go +++ b/cmd/windows_agent/service.go @@ -14,7 +14,6 @@ import ( "github.com/sonroyaalmerol/pbs-plus/internal/agent" "github.com/sonroyaalmerol/pbs-plus/internal/agent/snapshots" "github.com/sonroyaalmerol/pbs-plus/internal/syslog" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" "golang.org/x/sys/windows/registry" ) @@ -57,12 +56,11 @@ func (p *agentService) startPing() { ping() for { - retryWait := utils.WaitChan(time.Second * 5) select { case <-p.ctx.Done(): agent.SetStatus("Agent service is not running") return - case <-retryWait: + case <-time.After(time.Second * 5): ping() } } @@ -91,11 +89,10 @@ func (p *agentService) run() { if !urlExists() { for !urlExists() { - retryWait := utils.WaitChan(time.Second * 5) select { case <-p.ctx.Done(): return - case <-retryWait: + case <-time.After(time.Second * 5): } } } @@ -106,11 +103,10 @@ func (p *agentService) run() { err = drive.serveSFTP(p) for err != nil { logger.Errorf("Drive SFTP error: %v", err) - retryWait := utils.WaitChan(time.Second * 5) select { case <-p.ctx.Done(): return - case <-retryWait: + case <-time.After(time.Second * 5): err = drive.serveSFTP(p) } } diff --git a/cmd/windows_agent/systray.go b/cmd/windows_agent/systray.go index 7a8307a..b6fb898 100644 --- a/cmd/windows_agent/systray.go +++ b/cmd/windows_agent/systray.go @@ -13,7 +13,6 @@ import ( "github.com/getlantern/systray" "github.com/kardianos/service" "github.com/sonroyaalmerol/pbs-plus/internal/agent" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" "golang.org/x/sys/windows/registry" ) @@ -59,11 +58,10 @@ func (p *agentTray) onReady(url string) func() { setIP() for { - retryWait := utils.WaitChan(time.Second * 2) select { case <-ctx.Done(): return - case <-retryWait: + case <-time.After(time.Second * 2): setIP() } } @@ -88,11 +86,10 @@ func (p *agentTray) onReady(url string) func() { setStatus() for { - retryWait := utils.WaitChan(time.Second * 2) select { case <-ctx.Done(): return - case <-retryWait: + case <-time.After(time.Second * 2): setStatus() } } diff --git a/internal/agent/nfs/fs.go b/internal/agent/nfs/handler.go similarity index 98% rename from internal/agent/nfs/fs.go rename to internal/agent/nfs/handler.go index eb03c0a..f567787 100644 --- a/internal/agent/nfs/fs.go +++ b/internal/agent/nfs/handler.go @@ -80,7 +80,7 @@ func (ro *ReadOnlyFS) Chroot(path string) (billy.Filesystem, error) { if err != nil { return nil, err } - return New(fs), nil + return NewROFS(fs), nil } func (ro *ReadOnlyFS) Root() string { diff --git a/internal/agent/nfs/listener.go b/internal/agent/nfs/listener.go new file mode 100644 index 0000000..eddc753 --- /dev/null +++ b/internal/agent/nfs/listener.go @@ -0,0 +1,26 @@ +package nfs + +import ( + "net" + "strings" +) + +type FilteredListener struct { + net.Listener + allowedIP string +} + +func (fl *FilteredListener) Accept() (net.Conn, error) { + for { + conn, err := fl.Listener.Accept() + if err != nil { + return nil, err + } + + if strings.Contains(conn.RemoteAddr().String(), fl.allowedIP) { + return conn, nil + } + + conn.Close() + } +} diff --git a/internal/agent/nfs/nfs.go b/internal/agent/nfs/nfs.go index 0c589b4..a988152 100644 --- a/internal/agent/nfs/nfs.go +++ b/internal/agent/nfs/nfs.go @@ -7,16 +7,37 @@ import ( "context" "fmt" "net" + "net/url" "time" "github.com/go-git/go-billy/v5/osfs" "github.com/sonroyaalmerol/pbs-plus/internal/agent/snapshots" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" "github.com/willscott/go-nfs" "github.com/willscott/go-nfs/helpers" + "golang.org/x/sys/windows/registry" ) func Serve(ctx context.Context, errChan chan string, address, port string, driveLetter string) { + baseKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, "Software\\PBSPlus\\Config", registry.QUERY_VALUE) + if err != nil { + errChan <- fmt.Sprintf("Unable to create registry key -> %v", err) + return + } + + defer baseKey.Close() + + var server string + if server, _, err = baseKey.GetStringValue("ServerURL"); err != nil { + errChan <- fmt.Sprintf("Unable to get server url -> %v", err) + return + } + + serverUrl, err := url.Parse(server) + if err != nil { + errChan <- fmt.Sprintf("failed to parse server IP: %v", err) + return + } + var listener net.Listener listening := false @@ -28,17 +49,18 @@ func Serve(ctx context.Context, errChan chan string, address, port string, drive errChan <- fmt.Sprintf("Port is already in use! Failed to listen on %s: %v", listenAt, err) return } + + listener = &FilteredListener{Listener: listener, allowedIP: serverUrl.Hostname()} listening = true } listen() for !listening { - retryWait := utils.WaitChan(time.Second * 5) select { case <-ctx.Done(): return - case <-retryWait: + case <-time.After(time.Second * 5): listen() } } @@ -56,21 +78,21 @@ func Serve(ctx context.Context, errChan chan string, address, port string, drive readOnlyFs := NewROFS(fs) nfsHandler := helpers.NewNullAuthHandler(readOnlyFs) - go func() { - for { - go func() { - err := nfs.Serve(listener, nfsHandler) - if err != nil { - errChan <- fmt.Sprintf("NFS server error: %v", err) - } - }() - - select { - case <-ctx.Done(): - listener.Close() - return - case <-errChan: + for { + done := make(chan struct{}) + go func() { + err := nfs.Serve(listener, nfsHandler) + if err != nil { + errChan <- fmt.Sprintf("NFS server error: %v", err) } + close(done) + }() + + select { + case <-ctx.Done(): + listener.Close() + return + case <-done: } - }() + } } diff --git a/internal/agent/sftp/sftp.go b/internal/agent/sftp/sftp.go index 0c688aa..c501c28 100644 --- a/internal/agent/sftp/sftp.go +++ b/internal/agent/sftp/sftp.go @@ -13,7 +13,6 @@ import ( "github.com/pkg/sftp" "github.com/sonroyaalmerol/pbs-plus/internal/agent/snapshots" - "github.com/sonroyaalmerol/pbs-plus/internal/utils" "golang.org/x/crypto/ssh" ) @@ -37,11 +36,10 @@ func Serve(ctx context.Context, errChan chan string, sftpConfig *SFTPConfig, add listen() for !listening { - retryWait := utils.WaitChan(time.Second * 5) select { case <-ctx.Done(): return - case <-retryWait: + case <-time.After(time.Second * 5): listen() } } diff --git a/internal/agent/systray_comm.go b/internal/agent/systray_comm.go index a5964fb..1ff1924 100644 --- a/internal/agent/systray_comm.go +++ b/internal/agent/systray_comm.go @@ -1,3 +1,5 @@ +//go:build windows + package agent import ( diff --git a/internal/utils/wait.go b/internal/utils/wait.go deleted file mode 100644 index 0286bcc..0000000 --- a/internal/utils/wait.go +++ /dev/null @@ -1,14 +0,0 @@ -package utils - -import ( - "time" -) - -func WaitChan(duration time.Duration) <-chan struct{} { - done := make(chan struct{}) - go func() { - time.Sleep(duration) - close(done) - }() - return done -}