Skip to content

Commit

Permalink
fix: sftp retry on connection lost (#465)
Browse files Browse the repository at this point in the history
Co-authored-by: ItsSudip <sudip.paul1997@gmail.com>
Co-authored-by: Dilip Kola <kdilipkola@gmail.com>
  • Loading branch information
3 people authored May 20, 2024
1 parent 71f4ea7 commit 8383ee7
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 34 deletions.
44 changes: 34 additions & 10 deletions sftp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func sshClientConfig(config *SSHConfig) (*ssh.ClientConfig, error) {
return sshConfig, nil
}

// NewSSHClient establishes an SSH connection and returns an SSH client
func NewSSHClient(config *SSHConfig) (*ssh.Client, error) {
// newSSHClient establishes an SSH connection and returns an SSH client
func newSSHClient(config *SSHConfig) (*ssh.Client, error) {
sshConfig, err := sshClientConfig(config)
if err != nil {
return nil, fmt.Errorf("cannot configure SSH client: %w", err)
Expand All @@ -80,34 +80,58 @@ func NewSSHClient(config *SSHConfig) (*ssh.Client, error) {
}

type clientImpl struct {
client *sftp.Client
sftpClient *sftp.Client
config *SSHConfig
}

type Client interface {
OpenFile(path string, f int) (io.ReadWriteCloser, error)
Remove(path string) error
MkdirAll(path string) error
Reset() error
}

// newSFTPClient creates an SFTP client with existing SSH client
func newSFTPClient(client *ssh.Client) (Client, error) {
sftpClient, err := sftp.NewClient(client)
func newSFTPClient(client *ssh.Client) (*sftp.Client, error) {
return sftp.NewClient(client)
}

func newSFTPClientFromConfig(config *SSHConfig) (*sftp.Client, error) {
sshClient, err := newSSHClient(config)
if err != nil {
return nil, fmt.Errorf("creating SSH client: %w", err)
}
return newSFTPClient(sshClient)
}

func newClient(config *SSHConfig) (Client, error) {
sftpClient, err := newSFTPClientFromConfig(config)
if err != nil {
return nil, fmt.Errorf("cannot create SFTP client: %w", err)
return nil, fmt.Errorf("creating SFTP client: %w", err)
}
return &clientImpl{
client: sftpClient,
sftpClient: sftpClient,
config: config,
}, nil
}

func (c *clientImpl) OpenFile(path string, f int) (io.ReadWriteCloser, error) {
return c.client.OpenFile(path, f)
return c.sftpClient.OpenFile(path, f)
}

func (c *clientImpl) Remove(path string) error {
return c.client.Remove(path)
return c.sftpClient.Remove(path)
}

func (c *clientImpl) MkdirAll(path string) error {
return c.client.MkdirAll(path)
return c.sftpClient.MkdirAll(path)
}

func (c *clientImpl) Reset() error {
newSFTPClient, err := newSFTPClientFromConfig(c.config)
if err != nil {
return err
}
c.sftpClient = newSFTPClient
return nil
}
14 changes: 14 additions & 0 deletions sftp/mock_sftp/mock_sftp_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

87 changes: 77 additions & 10 deletions sftp/sftp.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package sftp

import (
"errors"
"fmt"
"io"
"os"
"path/filepath"

"golang.org/x/crypto/ssh"
"github.com/pkg/sftp"
)

const (
Expand All @@ -23,21 +24,33 @@ type FileManager interface {
Delete(remoteFilePath string) error
}

type Option func(impl *fileManagerImpl)

// WithRetryOnIdleConnection enables retrying the operation once in case of a "connection lost" error due to an idle connection.
func WithRetryOnIdleConnection() Option {
return func(impl *fileManagerImpl) {
impl.retryOnIdleConnection = true
}
}

// fileManagerImpl is a real implementation of FileManager
type fileManagerImpl struct {
client Client
client Client
retryOnIdleConnection bool
}

func NewFileManager(sshClient *ssh.Client) (FileManager, error) {
sftpClient, err := newSFTPClient(sshClient)
if err != nil {
return nil, fmt.Errorf("cannot create SFTP client: %w", err)
// Upload uploads a file to the remote server
func (fm *fileManagerImpl) Upload(localFilePath, remoteFilePath string) error {
if fm.retryOnIdleConnection {
return fm.retryOnConnectionLost(func() error {
return fm.upload(localFilePath, remoteFilePath)
})
}
return &fileManagerImpl{client: sftpClient}, nil

return fm.upload(localFilePath, remoteFilePath)
}

// Upload uploads a file to the remote server
func (fm *fileManagerImpl) Upload(localFilePath, remoteFilePath string) error {
func (fm *fileManagerImpl) upload(localFilePath, remoteFilePath string) error {
localFile, err := os.Open(localFilePath)
if err != nil {
return fmt.Errorf("cannot open local file: %w", err)
Expand Down Expand Up @@ -70,6 +83,16 @@ func (fm *fileManagerImpl) Upload(localFilePath, remoteFilePath string) error {

// Download downloads a file from the remote server
func (fm *fileManagerImpl) Download(remoteFilePath, localDir string) error {
if fm.retryOnIdleConnection {
return fm.retryOnConnectionLost(func() error {
return fm.download(remoteFilePath, localDir)
})
}

return fm.download(remoteFilePath, localDir)
}

func (fm *fileManagerImpl) download(remoteFilePath, localDir string) error {
remoteFile, err := fm.client.OpenFile(remoteFilePath, os.O_RDONLY)
if err != nil {
return fmt.Errorf("cannot open remote file: %w", err)
Expand Down Expand Up @@ -97,10 +120,54 @@ func (fm *fileManagerImpl) Download(remoteFilePath, localDir string) error {

// Delete deletes a file on the remote server
func (fm *fileManagerImpl) Delete(remoteFilePath string) error {
if fm.retryOnIdleConnection {
return fm.retryOnConnectionLost(func() error {
return fm.delete(remoteFilePath)
})
}

return fm.delete(remoteFilePath)
}

func (fm *fileManagerImpl) delete(remoteFilePath string) error {
err := fm.client.Remove(remoteFilePath)
if err != nil {
return fmt.Errorf("cannot delete file: %w", err)
}

return nil
}

func (fm *fileManagerImpl) reset() error {
return fm.client.Reset()
}

// NewFileManager is not concurrent safe. It should not be used from multiple goroutines concurrently without additional synchronization.
func NewFileManager(config *SSHConfig, opts ...Option) (FileManager, error) {
sftpClient, err := newClient(config)
if err != nil {
return nil, err
}
fm := &fileManagerImpl{client: sftpClient}
for _, opt := range opts {
opt(fm)
}
return fm, nil
}

func (fm *fileManagerImpl) retryOnConnectionLost(fileOperation func() error) error {
err := fileOperation()
if err == nil || !isConnectionLostError(err) {
return err // Operation successful or non-retryable error
}

if err := fm.reset(); err != nil {
return err
}

// Retry the operation
return fileOperation()
}

func isConnectionLostError(err error) bool {
return errors.Is(err, sftp.ErrSshFxConnectionLost)
}
57 changes: 43 additions & 14 deletions sftp/sftp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"github.com/golang/mock/gomock"
"github.com/ory/dockertest/v3"
"github.com/pkg/sftp"
"github.com/stretchr/testify/require"

"github.com/rudderlabs/rudder-go-kit/sftp/mock_sftp"
Expand Down Expand Up @@ -115,7 +116,7 @@ func TestSSHClientConfig(t *testing.T) {
}
}

func TestUpload(t *testing.T) {
func TestUploadWithRetry(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

Expand All @@ -138,16 +139,24 @@ func TestUpload(t *testing.T) {

mockSFTPClient := mock_sftp.NewMockClient(ctrl)
mockSFTPClient.EXPECT().OpenFile(gomock.Any(), gomock.Any()).Return(&nopReadWriteCloser{remoteBuf}, nil)
mockSFTPClient.EXPECT().MkdirAll(gomock.Any()).Return(nil)

fileManager := &fileManagerImpl{client: mockSFTPClient}
mockSFTPClient.EXPECT().Reset().Return(nil)
callCounter := 0
mockSFTPClient.EXPECT().MkdirAll(gomock.Any()).DoAndReturn(func(_ interface{}) error {
callCounter++
if callCounter == 1 {
return sftp.ErrSshFxConnectionLost
}
return nil
}).Times(2)

fileManager := &fileManagerImpl{client: mockSFTPClient, retryOnIdleConnection: true}

err = fileManager.Upload(localFilePath, "someRemotePath")
require.NoError(t, err)
require.Equal(t, data, remoteBuf.Bytes())
}

func TestDownload(t *testing.T) {
func TestDownloadWithRetry(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

Expand All @@ -162,9 +171,17 @@ func TestDownload(t *testing.T) {
remoteBuf := bytes.NewBuffer(data)

mockSFTPClient := mock_sftp.NewMockClient(ctrl)
mockSFTPClient.EXPECT().OpenFile(gomock.Any(), gomock.Any()).Return(&nopReadWriteCloser{remoteBuf}, nil)

fileManager := &fileManagerImpl{client: mockSFTPClient}
callCounter := 0
mockSFTPClient.EXPECT().OpenFile(gomock.Any(), gomock.Any()).DoAndReturn(func(_, _ interface{}) (io.ReadWriteCloser, error) {
callCounter++
if callCounter == 1 {
return nil, sftp.ErrSSHFxConnectionLost
}
return &nopReadWriteCloser{remoteBuf}, nil
}).Times(2)
mockSFTPClient.EXPECT().Reset().Return(nil)

fileManager := &fileManagerImpl{client: mockSFTPClient, retryOnIdleConnection: true}

err = fileManager.Download(filepath.Join("someRemoteDir", "test_file.json"), localDir)
require.NoError(t, err)
Expand All @@ -173,15 +190,24 @@ func TestDownload(t *testing.T) {
require.Equal(t, data, localFileContents)
}

func TestDelete(t *testing.T) {
func TestDeleteWithRetry(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

remoteFilePath := "someRemoteFilePath"
mockSFTPClient := mock_sftp.NewMockClient(ctrl)
mockSFTPClient.EXPECT().Remove(remoteFilePath).Return(nil)
callCounter := 0
mockSFTPClient.EXPECT().Remove(gomock.Any()).DoAndReturn(func(_ interface{}) error {
callCounter++
if callCounter == 1 {
return sftp.ErrSSHFxConnectionLost
}
return nil
}).Times(2)

fileManager := &fileManagerImpl{client: mockSFTPClient}
mockSFTPClient.EXPECT().Reset().Return(nil)

fileManager := &fileManagerImpl{client: mockSFTPClient, retryOnIdleConnection: true}

err := fileManager.Delete(remoteFilePath)
require.NoError(t, err)
Expand Down Expand Up @@ -211,14 +237,15 @@ func TestSFTP(t *testing.T) {
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
sshClient, err := NewSSHClient(&SSHConfig{
sshConfig := &SSHConfig{
User: "linuxserver.io",
HostName: hostname,
Port: port,
AuthMethod: "keyAuth",
PrivateKey: string(privateKey),
DialTimeout: 10 * time.Second,
})
}
sshClient, err := newSSHClient(sshConfig)
require.NoError(t, err)

// Create session
Expand All @@ -230,9 +257,11 @@ func TestSFTP(t *testing.T) {
err = session.Run(fmt.Sprintf("mkdir -p %s", remoteDir))
require.NoError(t, err)

sftpManger, err := NewFileManager(sshClient)
sftpClient, err := newSFTPClient(sshClient)
require.NoError(t, err)

sftpManger := &fileManagerImpl{client: &clientImpl{sftpClient: sftpClient}}

// Create local and remote directories within the temporary directory
baseDir := t.TempDir()
localDir := filepath.Join(baseDir, "local")
Expand Down

0 comments on commit 8383ee7

Please sign in to comment.