Skip to content

Commit

Permalink
fix windows vss agent
Browse files Browse the repository at this point in the history
  • Loading branch information
Son Roy Almerol committed Nov 4, 2024
1 parent 48e821a commit 8bfa817
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 113 deletions.
37 changes: 7 additions & 30 deletions cmd/pbs_windows_agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"fmt"
"net/url"
"os"
"os/exec"
"strings"
"sync"
"syscall"
Expand All @@ -35,7 +34,7 @@ func main() {
serverUrl, ok := os.LookupEnv("PBS_AGENT_SERVER")
if !ok {
for {
serverUrl = promptInput("PBS Agent", "Server URL")
serverUrl = utils.PromptInput("PBS Agent", "Server URL")

_, err := url.ParseRequestURI(serverUrl)
if err == nil {
Expand All @@ -49,7 +48,7 @@ func main() {

_, err := url.ParseRequestURI(serverUrl)
if err != nil {
showMessageBox("Error", fmt.Sprintf("Invalid server URL: %s", err))
utils.ShowMessageBox("Error", fmt.Sprintf("Invalid server URL: %s", err))
os.Exit(1)
}

Expand All @@ -65,18 +64,18 @@ func main() {

err = sftpConfig.PopulateKeys()
if err != nil {
showMessageBox("Error", fmt.Sprintf("Unable to populate SFTP keys: %s", err))
utils.ShowMessageBox("Error", fmt.Sprintf("Unable to populate SFTP keys: %s", err))
os.Exit(1)
}

port, err := utils.DriveLetterPort(rune)
if err != nil {
showMessageBox("Error", fmt.Sprintf("Unable to map letter to port: %s", err))
utils.ShowMessageBox("Error", fmt.Sprintf("Unable to map letter to port: %s", err))
os.Exit(1)
}

wg.Add(1)
go sftp.Serve(ctx, &wg, sftpConfig, "0.0.0.0", port, fmt.Sprintf("%s:\\", driveLetter))
go sftp.Serve(ctx, &wg, sftpConfig, "0.0.0.0", port, driveLetter)
}

defer snapshots.CloseAllSnapshots()
Expand All @@ -87,28 +86,6 @@ func main() {
wg.Wait()
}

func showMessageBox(title, message string) {
windows.MessageBox(0,
windows.StringToUTF16Ptr(message),
windows.StringToUTF16Ptr(title),
windows.MB_OK|windows.MB_ICONERROR)
}

func promptInput(title, prompt string) string {
cmd := exec.Command("powershell", "-Command", fmt.Sprintf(`
[void][Reflection.Assembly]::LoadWithPartialName('Microsoft.VisualBasic');
$input = [Microsoft.VisualBasic.Interaction]::InputBox('%s', '%s');
$input`, prompt, title))

output, err := cmd.Output()
if err != nil {
fmt.Println("Failed to get input:", err)
return ""
}

return strings.TrimSpace(string(output))
}

func onReady(serverUrl string) func() {
return func() {
systray.SetIcon(icon)
Expand All @@ -117,7 +94,7 @@ func onReady(serverUrl string) func() {

url, err := url.Parse(serverUrl)
if err != nil {
showMessageBox("Error", fmt.Sprintf("Failed to parse server URL: %s", err))
utils.ShowMessageBox("Error", fmt.Sprintf("Failed to parse server URL: %s", err))
os.Exit(1)
}

Expand Down Expand Up @@ -152,6 +129,6 @@ func runAsAdmin() {

err := windows.ShellExecute(0, verbPtr, exePtr, argPtr, cwdPtr, showCmd)
if err != nil {
showMessageBox("Error", fmt.Sprintf("Failed to run as administrator: %s", err))
utils.ShowMessageBox("Error", fmt.Sprintf("Failed to run as administrator: %s", err))
}
}
4 changes: 2 additions & 2 deletions internal/agent/sftp/filelister.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (h *SftpHandler) FileStat(filename string) (*FileLister, error) {
var stat fs.FileInfo
var err error

isRoot := strings.TrimPrefix(filename, h.SnapshotPath) == ""
isRoot := strings.TrimPrefix(filename, h.Snapshot.SnapshotPath) == ""

if isRoot {
stat, err = os.Stat(filename)
Expand All @@ -80,7 +80,7 @@ func (h *SftpHandler) FileStat(filename string) (*FileLister, error) {
}

func (h *SftpHandler) setFilePath(r *sftp.Request) {
r.Filepath = filepath.Join(h.SnapshotPath, filepath.Clean(r.Filepath))
r.Filepath = filepath.Join(h.Snapshot.SnapshotPath, filepath.Clean(r.Filepath))
}

func (h *SftpHandler) fetch(path string, mode int) (*os.File, error) {
Expand Down
13 changes: 8 additions & 5 deletions internal/agent/sftp/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,20 @@ import (
"log"
"os"
"sync"
"time"

"github.com/pkg/sftp"
"github.com/sonroyaalmerol/pbs-d2d-backup/internal/agent/snapshots"
)

type SftpHandler struct {
mu sync.Mutex
BasePath string
SnapshotPath string
mu sync.Mutex
DriveLetter string
Snapshot *snapshots.WinVSSSnapshot
}

func NewSftpHandler(ctx context.Context, basePath string, snapshot *snapshots.WinVSSSnapshot) (*sftp.Handlers, error) {
handler := &SftpHandler{BasePath: basePath, SnapshotPath: snapshot.SnapshotPath}
func NewSftpHandler(ctx context.Context, driveLetter string, snapshot *snapshots.WinVSSSnapshot) (*sftp.Handlers, error) {
handler := &SftpHandler{DriveLetter: driveLetter, Snapshot: snapshot}

return &sftp.Handlers{
FileGet: handler,
Expand All @@ -36,6 +37,7 @@ func (h *SftpHandler) Fileread(r *sftp.Request) (io.ReaderAt, error) {
h.mu.Lock()
defer h.mu.Unlock()

h.Snapshot.LastAccessed = time.Now()
h.setFilePath(r)

file, err := h.fetch(r.Filepath, os.O_RDONLY)
Expand Down Expand Up @@ -64,6 +66,7 @@ func (h *SftpHandler) Filelist(r *sftp.Request) (sftp.ListerAt, error) {
h.mu.Lock()
defer h.mu.Unlock()

h.Snapshot.LastAccessed = time.Now()
h.setFilePath(r)

switch r.Method {
Expand Down
44 changes: 21 additions & 23 deletions internal/agent/sftp/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,23 @@ import (
"log"
"net"
"net/url"
"os"
"strings"
"sync"
"time"

"github.com/pkg/sftp"
"github.com/sonroyaalmerol/pbs-d2d-backup/internal/agent/snapshots"
"github.com/sonroyaalmerol/pbs-d2d-backup/internal/utils"
"golang.org/x/crypto/ssh"
)

func Serve(ctx context.Context, wg *sync.WaitGroup, sftpConfig *SFTPConfig, address, port string, baseDir string) {
func Serve(ctx context.Context, wg *sync.WaitGroup, sftpConfig *SFTPConfig, address, port string, driveLetter string) {
defer wg.Done()
listenAt := fmt.Sprintf("%s:%s", address, port)
listener, err := net.Listen("tcp", listenAt)
if err != nil {
log.Fatalf("failed to listen on %s: %v", listenAt, err)
utils.ShowMessageBox("Fatal Error", fmt.Sprintf("Port is already in use! Failed to listen on %s: %v", listenAt, err))
os.Exit(1)
}
defer listener.Close()

Expand All @@ -37,32 +39,32 @@ func Serve(ctx context.Context, wg *sync.WaitGroup, sftpConfig *SFTPConfig, addr
default:
conn, err := listener.Accept()
if err != nil {
log.Printf("failed to accept connection: %v", err)
utils.ShowMessageBox("Error", fmt.Sprintf("failed to accept connection: %v", err))
continue
}

go handleConnection(conn, sftpConfig, baseDir)
go handleConnection(conn, sftpConfig, driveLetter)
}
}
}

func handleConnection(conn net.Conn, sftpConfig *SFTPConfig, baseDir string) {
func handleConnection(conn net.Conn, sftpConfig *SFTPConfig, driveLetter string) {
defer conn.Close()

server, err := url.Parse(sftpConfig.Server)
if err != nil {
log.Printf("failed to parse server IP: %v", err)
utils.ShowMessageBox("Error", fmt.Sprintf("failed to parse server IP: %v", err))
return
}

if !strings.Contains(conn.RemoteAddr().String(), server.Hostname()) {
log.Printf("WARNING: an unregistered client has attempted to connect: %s", conn.RemoteAddr().String())
utils.ShowMessageBox("Error", fmt.Sprintf("WARNING: an unregistered client has attempted to connect: %s", conn.RemoteAddr().String()))
return
}

sconn, chans, reqs, err := ssh.NewServerConn(conn, sftpConfig.ServerConfig)
if err != nil {
log.Printf("failed to perform SSH handshake: %v", err)
utils.ShowMessageBox("Error", fmt.Sprintf("failed to perform SSH handshake: %v", err))
return
}

Expand All @@ -82,7 +84,7 @@ func handleConnection(conn net.Conn, sftpConfig *SFTPConfig, baseDir string) {
}

go handleRequests(requests)
go handleSFTP(channel, baseDir)
go handleSFTP(channel, driveLetter)
}
}

Expand All @@ -95,7 +97,7 @@ func handlePingPong(reqs <-chan *ssh.Request) {
log.Println("Failed to reply to ping:", err)
}
} else {
log.Printf("Received unknown request type: %s", req.Type)
log.Printf("Received unknown request type: %s\n", req.Type)
}
}
}
Expand All @@ -110,29 +112,25 @@ func handleRequests(requests <-chan *ssh.Request) {
}
}

func handleSFTP(channel ssh.Channel, baseDir string) {
func handleSFTP(channel ssh.Channel, driveLetter string) {
defer channel.Close()

snapshot, err := snapshots.Snapshot(baseDir)
snapshot, err := snapshots.Snapshot(driveLetter)
if err != nil {
log.Fatalf("failed to initialize snapshot: %s", err)
utils.ShowMessageBox("Fatal Error", fmt.Sprintf("failed to initialize snapshot: %s", err))
os.Exit(1)
}

ctx := context.Background()
sftpHandler, err := NewSftpHandler(ctx, baseDir, snapshot)
sftpHandler, err := NewSftpHandler(ctx, driveLetter, snapshot)
if err != nil {
_ = snapshot.Close()
log.Fatalf("failed to initialize handler: %s", err)
utils.ShowMessageBox("Fatal Error", fmt.Sprintf("failed to initialize handler: %s", err))
os.Exit(1)
}

snapshot.Used = true
snapshot.LastUsedUpdate = time.Now()

server := sftp.NewRequestServer(channel, *sftpHandler)
if err := server.Serve(); err != nil {
log.Printf("sftp server completed with error: %s", err)
log.Printf("sftp server completed with error: %s\n", err)
}

snapshot.Used = false
snapshot.LastUsedUpdate = time.Now()
}
Loading

0 comments on commit 8bfa817

Please sign in to comment.