Skip to content

Commit

Permalink
Merge pull request #80 from datadius/preserve-protocol
Browse files Browse the repository at this point in the history
Handle T message and prepare for adding -p option
  • Loading branch information
bramvdbogaerde authored Apr 28, 2024
2 parents bd16750 + b4cd115 commit db7cf4f
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 77 deletions.
86 changes: 59 additions & 27 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ package scp
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -85,13 +84,24 @@ func (a *Client) SSHClient() *ssh.Client {
}

// CopyFromFile copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem.
func (a *Client) CopyFromFile(ctx context.Context, file os.File, remotePath string, permissions string) error {
func (a *Client) CopyFromFile(
ctx context.Context,
file os.File,
remotePath string,
permissions string,
) error {
return a.CopyFromFilePassThru(ctx, file, remotePath, permissions, nil)
}

// CopyFromFilePassThru copies the contents of an os.File to a remote location, it will get the length of the file by looking it up from the filesystem.
// Access copied bytes by providing a PassThru reader factory.
func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remotePath string, permissions string, passThru PassThru) error {
func (a *Client) CopyFromFilePassThru(
ctx context.Context,
file os.File,
remotePath string,
permissions string,
passThru PassThru,
) error {
stat, err := file.Stat()
if err != nil {
return fmt.Errorf("failed to stat file: %w", err)
Expand All @@ -101,21 +111,39 @@ func (a *Client) CopyFromFilePassThru(ctx context.Context, file os.File, remoteP

// CopyFile copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF
// if the file length in know in advance please use "Copy" instead.
func (a *Client) CopyFile(ctx context.Context, fileReader io.Reader, remotePath string, permissions string) error {
func (a *Client) CopyFile(
ctx context.Context,
fileReader io.Reader,
remotePath string,
permissions string,
) error {
return a.CopyFilePassThru(ctx, fileReader, remotePath, permissions, nil)
}

// CopyFilePassThru copies the contents of an io.Reader to a remote location, the length is determined by reading the io.Reader until EOF
// if the file length in know in advance please use "Copy" instead.
// Access copied bytes by providing a PassThru reader factory.
func (a *Client) CopyFilePassThru(ctx context.Context, fileReader io.Reader, remotePath string, permissions string, passThru PassThru) error {
func (a *Client) CopyFilePassThru(
ctx context.Context,
fileReader io.Reader,
remotePath string,
permissions string,
passThru PassThru,
) error {
contentsBytes, err := ioutil.ReadAll(fileReader)
if err != nil {
return fmt.Errorf("failed to read all data from reader: %w", err)
}
bytesReader := bytes.NewReader(contentsBytes)

return a.CopyPassThru(ctx, bytesReader, remotePath, permissions, int64(len(contentsBytes)), passThru)
return a.CopyPassThru(
ctx,
bytesReader,
remotePath,
permissions,
int64(len(contentsBytes)),
passThru,
)
}

// wait waits for the waitgroup for the specified max timeout.
Expand All @@ -139,27 +167,36 @@ func wait(wg *sync.WaitGroup, ctx context.Context) error {
// checkResponse checks the response it reads from the remote, and will return a single error in case
// of failure.
func checkResponse(r io.Reader) error {
response, err := ParseResponse(r)
_, err := ParseResponse(r, nil)
if err != nil {
return err
}

if response.IsFailure() {
return errors.New(response.GetMessage())
}

return nil

}

// Copy copies the contents of an io.Reader to a remote location.
func (a *Client) Copy(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64) error {
func (a *Client) Copy(
ctx context.Context,
r io.Reader,
remotePath string,
permissions string,
size int64,
) error {
return a.CopyPassThru(ctx, r, remotePath, permissions, size, nil)
}

// CopyPassThru copies the contents of an io.Reader to a remote location.
// Access copied bytes by providing a PassThru reader factory
func (a *Client) CopyPassThru(ctx context.Context, r io.Reader, remotePath string, permissions string, size int64, passThru PassThru) error {
func (a *Client) CopyPassThru(
ctx context.Context,
r io.Reader,
remotePath string,
permissions string,
size int64,
passThru PassThru,
) error {
session, err := a.sshClient.NewSession()
if err != nil {
return fmt.Errorf("Error creating ssh session in copy to remote: %v", err)
Expand Down Expand Up @@ -272,7 +309,12 @@ func (a *Client) CopyFromRemote(ctx context.Context, file *os.File, remotePath s
// CopyFromRemotePassThru copies a file from the remote to the given writer. The passThru parameter can be used
// to keep track of progress and how many bytes that were download from the remote.
// `passThru` can be set to nil to disable this behaviour.
func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remotePath string, passThru PassThru) error {
func (a *Client) CopyFromRemotePassThru(
ctx context.Context,
w io.Writer,
remotePath string,
passThru PassThru,
) error {
session, err := a.sshClient.NewSession()
if err != nil {
return fmt.Errorf("Error creating ssh session in copy from remote: %v", err)
Expand Down Expand Up @@ -319,17 +361,7 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote
return
}

res, err := ParseResponse(r)
if err != nil {
errCh <- err
return
}
if res.IsFailure() {
errCh <- errors.New(res.GetMessage())
return
}

infos, err := res.ParseFileInfos()
fileInfo, err := ParseResponse(r, in)
if err != nil {
errCh <- err
return
Expand All @@ -342,10 +374,10 @@ func (a *Client) CopyFromRemotePassThru(ctx context.Context, w io.Writer, remote
}

if passThru != nil {
r = passThru(r, infos.Size)
r = passThru(r, fileInfo.Size)
}

_, err = CopyN(w, r, infos.Size)
_, err = CopyN(w, r, fileInfo.Size)
if err != nil {
errCh <- err
return
Expand Down
170 changes: 120 additions & 50 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,42 +9,30 @@ package scp
import (
"bufio"
"errors"
"fmt"
"io"
"strconv"
"strings"
)

type ResponseType = uint8
type ResponseType = byte

const (
Ok ResponseType = 0
Warning ResponseType = 1
Error ResponseType = 2
Create ResponseType = 'C'
Time ResponseType = 'T'
)

// Response represent a response from the SCP command.
// There are tree types of responses that the remote can send back:
// ok, warning and error
//
// The difference between warning and error is that the connection is not closed by the remote,
// however, a warning can indicate a file transfer failure (such as invalid destination directory)
// and such be handled as such.
//
// All responses except for the `Ok` type always have a message (although these can be empty)
//
// The remote sends a confirmation after every SCP command, because a failure can occur after every
// command, the response should be read and checked after sending them.
type Response struct {
Type ResponseType
Message string
}

// ParseResponse reads from the given reader (assuming it is the output of the remote) and parses it into a Response structure.
func ParseResponse(reader io.Reader) (Response, error) {
func ParseResponse(reader io.Reader, writer io.Writer) (*FileInfos, error) {
fileInfos := NewFileInfos()

buffer := make([]uint8, 1)
_, err := reader.Read(buffer)
if err != nil {
return Response{}, err
return fileInfos, err
}

responseType := buffer[0]
Expand All @@ -53,61 +41,143 @@ func ParseResponse(reader io.Reader) (Response, error) {
bufferedReader := bufio.NewReader(reader)
message, err = bufferedReader.ReadString('\n')
if err != nil {
return Response{}, err
return fileInfos, err
}
}

return Response{responseType, message}, nil
}
if responseType == Warning || responseType == Error {
return fileInfos, errors.New(message)
}

func (r *Response) IsOk() bool {
return r.Type == Ok
}
// Exit early because we're only interested in the ok response
if responseType == Ok {
return fileInfos, nil
}

func (r *Response) IsWarning() bool {
return r.Type == Warning
}
if !(responseType == Create || responseType == Time) {
return fileInfos, errors.New(
fmt.Sprintf(
"Message does not follow scp protocol: %s\n Cmmmm <length> <filename> or T<mtime> 0 <atime> 0",
message,
),
)
}

// IsError returns true when the remote responded with an error.
func (r *Response) IsError() bool {
return r.Type == Error
}
if responseType == Time {
err = ParseFileTime(message, fileInfos)
if err != nil {
return nil, err
}

message, err = bufferedReader.ReadString('\n')
if err == io.EOF {
err = Ack(writer)
if err != nil {
return fileInfos, err
}
message, err = bufferedReader.ReadString('\n')

if err != nil {
return fileInfos, err
}
}

if err != nil && err != io.EOF {
return fileInfos, err
}

responseType = message[0]
}

// IsFailure returns true when the remote answered with a warning or an error.
func (r *Response) IsFailure() bool {
return r.IsWarning() || r.IsError()
}
if responseType == Create {
err = ParseFileInfos(message, fileInfos)
if err != nil {
return nil, err
}
}
}

// GetMessage returns the message the remote sent back.
func (r *Response) GetMessage() string {
return r.Message
return fileInfos, nil
}

type FileInfos struct {
Message string
Filename string
Permissions string
Size int64
Atime int64
Mtime int64
}

func NewFileInfos() *FileInfos {
return &FileInfos{}
}

func (r *Response) ParseFileInfos() (*FileInfos, error) {
message := strings.ReplaceAll(r.Message, "\n", "")
parts := strings.Split(message, " ")
func (fileInfos *FileInfos) Update(new *FileInfos) {
if new == nil {
return
}
if new.Filename != "" {
fileInfos.Filename = new.Filename
}
if new.Permissions != "" {
fileInfos.Permissions = new.Permissions
}
if new.Size != 0 {
fileInfos.Size = new.Size
}
if new.Atime != 0 {
fileInfos.Atime = new.Atime
}
if new.Mtime != 0 {
fileInfos.Mtime = new.Mtime
}
}

func ParseFileInfos(message string, fileInfos *FileInfos) error {
processMessage := strings.ReplaceAll(message, "\n", "")
parts := strings.Split(processMessage, " ")
if len(parts) < 3 {
return nil, errors.New("unable to parse message as file infos")
return errors.New("unable to parse Chmod protocol")
}

size, err := strconv.Atoi(parts[1])
if err != nil {
return nil, err
return err
}

return &FileInfos{
Message: r.Message,
fileInfos.Update(&FileInfos{
Filename: parts[2],
Permissions: parts[0],
Size: int64(size),
Filename: parts[2],
}, nil
})

return nil
}

func ParseFileTime(
message string,
fileInfos *FileInfos,
) error {
processMessage := strings.ReplaceAll(message, "\n", "")
parts := strings.Split(processMessage, " ")
if len(parts) < 3 {
return errors.New("unable to parse Time protocol")
}

aTime, err := strconv.Atoi(string(parts[0][0:10]))
if err != nil {
return errors.New("unable to parse ATime component of message")
}
mTime, err := strconv.Atoi(string(parts[2][0:10]))
if err != nil {
return errors.New("unable to parse MTime component of message")
}

fileInfos.Update(&FileInfos{
Atime: int64(aTime),
Mtime: int64(mTime),
})
return nil
}

// Ack writes an `Ack` message to the remote, does not await its response, a seperate call to ParseResponse is
Expand Down

0 comments on commit db7cf4f

Please sign in to comment.