diff --git a/remotefs/download.go b/remotefs/download.go new file mode 100644 index 00000000..35b1b75e --- /dev/null +++ b/remotefs/download.go @@ -0,0 +1,47 @@ +package remotefs + +import ( + "bytes" + "crypto/sha256" + "fmt" + "io" + "os" +) + +// Download a file from the remote host. +func Download(fs FS, src, dst string) error { + remote, err := fs.Open(src) + if err != nil { + return fmt.Errorf("open remote file for download: %w", err) + } + defer remote.Close() + + remoteStat, err := remote.Stat() + if err != nil { + return fmt.Errorf("stat remote file for download: %w", err) + } + + remoteSum := sha256.New() + localSum := sha256.New() + + local, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, remoteStat.Mode()) + if err != nil { + return fmt.Errorf("open local file for download: %w", err) + } + defer local.Close() + + remoteReader := io.TeeReader(remote, remoteSum) + if _, err := io.Copy(io.MultiWriter(local, localSum), remoteReader); err != nil { + _ = local.Close() + return fmt.Errorf("copy file from remote host: %w", err) + } + if err := local.Close(); err != nil { + return fmt.Errorf("close local file after download: %w", err) + } + + if !bytes.Equal(localSum.Sum(nil), remoteSum.Sum(nil)) { + return fmt.Errorf("downloading %s failed: %w", src, ErrChecksumMismatch) + } + + return nil +} diff --git a/remotefs/downloaddirectory.go b/remotefs/downloaddirectory.go new file mode 100644 index 00000000..de4c77d1 --- /dev/null +++ b/remotefs/downloaddirectory.go @@ -0,0 +1,43 @@ +package remotefs + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" +) + +// DownloadDirectory downloads all files and directories recursively from the remote system to local directory. +func DownloadDirectory(fsys FS, src, dst string) error { + walkErr := fs.WalkDir(fsys, src, func(path string, dir fs.DirEntry, err error) error { + if err != nil { + return fmt.Errorf("walk remote directory: %w", err) + } + + relPath, err := filepath.Rel(src, path) + if err != nil { + return fmt.Errorf("calculate relative path: %w", err) + } + targetPath := filepath.Join(dst, relPath) + + if dir.IsDir() { + dirInfo, err := dir.Info() + if err != nil { + return fmt.Errorf("get dir info: %w", err) + } + if err := os.MkdirAll(targetPath, dirInfo.Mode()&os.ModePerm); err != nil { + return fmt.Errorf("create local directory: %w", err) + } + } else { + if err := Download(fsys, path, targetPath); err != nil { + return fmt.Errorf("download directory: %w", err) + } + } + return nil + }) + + if walkErr != nil { + return fmt.Errorf("walk remote directory tree: %w", walkErr) + } + return nil +} diff --git a/remotefs/uploaddirectory.go b/remotefs/uploaddirectory.go new file mode 100644 index 00000000..bcd74697 --- /dev/null +++ b/remotefs/uploaddirectory.go @@ -0,0 +1,43 @@ +package remotefs + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" +) + +// UploadDirectory uploads all files and directories recursively to the remote system. +func UploadDirectory(fsys FS, src, dst string) error { + walkErr := filepath.WalkDir(src, func(path string, dir fs.DirEntry, err error) error { + if err != nil { + return fmt.Errorf("walk local directory: %w", err) + } + + relPath, err := filepath.Rel(src, path) + if err != nil { + return fmt.Errorf("calculate relative path: %w", err) + } + targetPath := filepath.Join(dst, relPath) + + if dir.IsDir() { + dirInfo, err := dir.Info() + if err != nil { + return fmt.Errorf("get dir info: %w", err) + } + if err := fsys.MkdirAll(targetPath, dirInfo.Mode()&os.ModePerm); err != nil { + return fmt.Errorf("create remote directory: %w", err) + } + } else { + if err := Upload(fsys, path, targetPath); err != nil { + return fmt.Errorf("upload directory: %w", err) + } + } + return nil + }) + + if walkErr != nil { + return fmt.Errorf("walk remote directory tree: %w", walkErr) + } + return nil +}