diff --git a/mount/mount_unix.go b/mount/mount_unix.go index a250bfc8..0ccbc7d5 100644 --- a/mount/mount_unix.go +++ b/mount/mount_unix.go @@ -4,7 +4,9 @@ package mount import ( "fmt" + "path" "sort" + "strings" "github.com/moby/sys/mountinfo" "golang.org/x/sys/unix" @@ -37,35 +39,56 @@ func Unmount(target string) error { } } -// RecursiveUnmount unmounts the target and all mounts underneath, starting -// with the deepest mount first. The argument does not have to be a mount -// point itself. -func RecursiveUnmount(target string) error { - // Fast path, works if target is a mount point that can be unmounted. +// UnmountAll unmounts all mounts and submounts underneath parent, +func UnmountAll(parent string) error { + // Get all mounts in "parent" + mounts, err := mountinfo.GetMounts(mountinfo.PrefixFilter(parent)) + if err != nil { + return err + } + + // Fast path: try to unmount top-level mounts first. This works if target is + // a mount point that can be unmounted. // On Linux, mntDetach flag ensures a recursive unmount. For other // platforms, if there are submounts, we'll get EBUSY (and fall back - // to the slow path). NOTE we do not ignore EINVAL here as target might - // not be a mount point itself (but there can be mounts underneath). - if err := unix.Unmount(target, mntDetach); err == nil { - return nil + // to the slow path). We're not using RecursiveUnmount() here, to avoid + // repeatedly calling mountinfo.GetMounts() + + var skipParents []string + for _, m := range mounts { + // Skip parent itself, and skip non-top-level mounts + if m.Mountpoint == parent || path.Dir(m.Mountpoint) != parent { + continue + } + if err := unix.Unmount(m.Mountpoint, mntDetach); err == nil { + skipParents = append(skipParents, m.Mountpoint) + } } - // Slow path: get all submounts, sort, unmount one by one. - mounts, err := mountinfo.GetMounts(mountinfo.PrefixFilter(target)) - if err != nil { - return err + // Remove all sub-mounts of paths that were successfully unmounted from the list + subMounts := mounts[:0] + for _, m := range mounts { + for _, p := range skipParents { + if m.Mountpoint == parent || m.Mountpoint == p { + // Skip parent itself, and mounts that already were unmounted + continue + } + if !strings.HasPrefix(m.Mountpoint, p) { + subMounts = append(subMounts, m) + } + } } // Make the deepest mount be first - sort.Slice(mounts, func(i, j int) bool { - return len(mounts[i].Mountpoint) > len(mounts[j].Mountpoint) + sort.Slice(subMounts, func(i, j int) bool { + return len(subMounts[i].Mountpoint) > len(subMounts[j].Mountpoint) }) var ( suberr error lastMount = len(mounts) - 1 ) - for i, m := range mounts { + for i, m := range subMounts { err = Unmount(m.Mountpoint) if err != nil { if i == lastMount { @@ -85,3 +108,20 @@ func RecursiveUnmount(target string) error { } return nil } + +// RecursiveUnmount unmounts the target and all mounts underneath, starting +// with the deepest mount first. The argument does not have to be a mount +// point itself. +func RecursiveUnmount(target string) error { + // Fast path, works if target is a mount point that can be unmounted. + // On Linux, mntDetach flag ensures a recursive unmount. For other + // platforms, if there are submounts, we'll get EBUSY (and fall back + // to the slow path). NOTE we do not ignore EINVAL here as target might + // not be a mount point itself (but there can be mounts underneath). + if err := unix.Unmount(target, mntDetach); err == nil { + return nil + } + + // Slow path: unmount all mounts inside target one by one. + return UnmountAll(target) +}