From ae1fd941bdc6567e59050320003eabad7342cf6e Mon Sep 17 00:00:00 2001 From: Rohan Kumar Date: Thu, 19 Dec 2024 12:24:11 +0530 Subject: [PATCH] fix (shell) : Improve shell detection on windows (#3767) We should detect the usage of $SHELL environment variable when using CRC from linux like environments on Windows. We should also convert the CRC binary paths to unix path format whenever unix shells are detected. Signed-off-by: Rohan Kumar --- pkg/os/shell/shell.go | 93 +++++++++++- pkg/os/shell/shell_test.go | 225 +++++++++++++++++++++++++++-- pkg/os/shell/shell_unix.go | 8 +- pkg/os/shell/shell_unix_test.go | 69 ++++++++- pkg/os/shell/shell_windows.go | 7 +- pkg/os/shell/shell_windows_test.go | 44 ++++++ 6 files changed, 421 insertions(+), 25 deletions(-) diff --git a/pkg/os/shell/shell.go b/pkg/os/shell/shell.go index ccd776edc6..7bc796abfc 100644 --- a/pkg/os/shell/shell.go +++ b/pkg/os/shell/shell.go @@ -2,7 +2,17 @@ package shell import ( "fmt" + "os" + "sort" + "strconv" "strings" + + crcos "github.com/crc-org/crc/v2/pkg/os" +) + +var ( + CommandRunner = crcos.NewLocalCommandRunner() + WindowsSubsystemLinuxKernelMetadataFile = "/proc/version" ) type Config struct { @@ -65,9 +75,9 @@ func GetEnvString(userShell string, envName string, envValue string) string { case "cmd": return fmt.Sprintf("SET %s=%s", envName, envValue) case "fish": - return fmt.Sprintf("contains %s $fish_user_paths; or set -U fish_user_paths %s $fish_user_paths", envValue, envValue) + return fmt.Sprintf("contains %s $fish_user_paths; or set -U fish_user_paths %s $fish_user_paths", convertToLinuxStylePath(userShell, envValue), convertToLinuxStylePath(userShell, envValue)) default: - return fmt.Sprintf("export %s=\"%s\"", envName, envValue) + return fmt.Sprintf("export %s=\"%s\"", envName, convertToLinuxStylePath(userShell, envValue)) } } @@ -81,8 +91,85 @@ func GetPathEnvString(userShell string, prependedPath string) string { case "cmd": pathStr = fmt.Sprintf("%s;%%PATH%%", prependedPath) default: - pathStr = fmt.Sprintf("%s:$PATH", prependedPath) + pathStr = fmt.Sprintf("%s:$PATH", convertToLinuxStylePath(userShell, prependedPath)) } return GetEnvString(userShell, "PATH", pathStr) } + +func convertToLinuxStylePath(userShell string, path string) string { + if IsWindowsSubsystemLinux() { + return convertToWindowsSubsystemLinuxPath(path) + } + if strings.Contains(path, "\\") && + (userShell == "bash" || userShell == "zsh" || userShell == "fish") { + path = strings.ReplaceAll(path, ":", "") + path = strings.ReplaceAll(path, "\\", "/") + + return fmt.Sprintf("/%s", path) + } + return path +} + +func convertToWindowsSubsystemLinuxPath(path string) string { + stdOut, _, err := CommandRunner.Run("wsl", "-e", "bash", "-c", fmt.Sprintf("wslpath -a '%s'", path)) + if err != nil { + return path + } + return strings.TrimSpace(stdOut) +} + +func IsWindowsSubsystemLinux() bool { + procVersionContent, err := os.ReadFile(WindowsSubsystemLinuxKernelMetadataFile) + if err != nil { + return false + } + if strings.Contains(string(procVersionContent), "Microsoft") || + strings.Contains(string(procVersionContent), "WSL") { + return true + } + return false +} + +func detectShellByInvokingCommand(defaultShell string, command string, args []string) string { + stdOut, _, err := CommandRunner.Run(command, args...) + if err != nil { + return defaultShell + } + + detectedShell := inspectProcessOutputForRecentlyUsedShell(stdOut) + if detectedShell == "" { + return defaultShell + } + return detectedShell +} + +func inspectProcessOutputForRecentlyUsedShell(psCommandOutput string) string { + type ProcessOutput struct { + processID int + output string + } + var processOutputs []ProcessOutput + lines := strings.Split(psCommandOutput, "\n")[1:] + for _, line := range lines { + lineParts := strings.Split(strings.TrimSpace(line), " ") + if len(lineParts) == 2 && (strings.Contains(lineParts[1], "zsh") || + strings.Contains(lineParts[1], "bash") || + strings.Contains(lineParts[1], "fish")) { + parsedProcessID, err := strconv.Atoi(lineParts[0]) + if err == nil { + processOutputs = append(processOutputs, ProcessOutput{ + processID: parsedProcessID, + output: lineParts[1], + }) + } + } + } + sort.Slice(processOutputs, func(i, j int) bool { + return processOutputs[i].processID > processOutputs[j].processID + }) + if len(processOutputs) > 0 { + return processOutputs[0].output + } + return "" +} diff --git a/pkg/os/shell/shell_test.go b/pkg/os/shell/shell_test.go index e9e7a002b8..d1b5ce6c1e 100644 --- a/pkg/os/shell/shell_test.go +++ b/pkg/os/shell/shell_test.go @@ -1,31 +1,228 @@ -//go:build !windows -// +build !windows - package shell import ( "os" + "path/filepath" "testing" "github.com/stretchr/testify/assert" ) -func TestDetectBash(t *testing.T) { - defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL")) - os.Setenv("SHELL", "/bin/bash") +type MockCommandRunner struct { + commandName string + commandArgs []string + expectedOutputToReturn string + expectedErrMessageToReturn string + expectedErrToReturn error +} + +func NewMockCommandRunner() *MockCommandRunner { + return NewMockCommandRunnerWithOutputErr("", "", nil) +} - shell, err := detect() +func NewMockCommandRunnerWithOutputErr(output string, errorMsg string, err error) *MockCommandRunner { + return &MockCommandRunner{ + commandName: "", + commandArgs: []string{}, + expectedOutputToReturn: output, + expectedErrMessageToReturn: errorMsg, + expectedErrToReturn: err, + } +} - assert.Equal(t, "bash", shell) - assert.NoError(t, err) +func (e *MockCommandRunner) Run(command string, args ...string) (string, string, error) { + e.commandName = command + e.commandArgs = args + return e.expectedOutputToReturn, e.expectedErrMessageToReturn, e.expectedErrToReturn } -func TestDetectFish(t *testing.T) { - defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL")) - os.Setenv("SHELL", "/bin/fish") +func (e *MockCommandRunner) RunPrivate(command string, args ...string) (string, string, error) { + e.commandName = command + e.commandArgs = args + return e.expectedOutputToReturn, e.expectedErrMessageToReturn, e.expectedErrToReturn +} + +func (e *MockCommandRunner) RunPrivileged(_ string, cmdAndArgs ...string) (string, string, error) { + e.commandArgs = cmdAndArgs + return e.expectedOutputToReturn, e.expectedErrMessageToReturn, e.expectedErrToReturn +} - shell, err := detect() +func TestGetPathEnvString(t *testing.T) { + tests := []struct { + name string + userShell string + path string + expectedStr string + }{ + {"fish shell", "fish", "C:\\Users\\foo\\.crc\\bin\\oc", "contains /C/Users/foo/.crc/bin/oc $fish_user_paths; or set -U fish_user_paths /C/Users/foo/.crc/bin/oc $fish_user_paths"}, + {"powershell shell", "powershell", "C:\\Users\\foo\\oc.exe", "$Env:PATH = \"C:\\Users\\foo\\oc.exe;$Env:PATH\""}, + {"cmd shell", "cmd", "C:\\Users\\foo\\oc.exe", "SET PATH=C:\\Users\\foo\\oc.exe;%PATH%"}, + {"bash with windows path", "bash", "C:\\Users\\foo.exe", "export PATH=\"/C/Users/foo.exe:$PATH\""}, + {"unknown with windows path", "unknown", "C:\\Users\\foo.exe", "export PATH=\"C:\\Users\\foo.exe:$PATH\""}, + {"unknown shell with unix path", "unknown", "/home/foo/.crc/bin/oc", "export PATH=\"/home/foo/.crc/bin/oc:$PATH\""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetPathEnvString(tt.userShell, tt.path) + if result != tt.expectedStr { + t.Errorf("GetPathEnvString(%s, %s) = %s; want %s", tt.userShell, tt.path, result, tt.expectedStr) + } + }) + } +} + +func TestConvertToLinuxStylePath(t *testing.T) { + tests := []struct { + name string + userShell string + path string + expectedPath string + }{ + {"bash on windows, should convert", "bash", "C:\\Users\\foo\\.crc\\bin\\oc", "/C/Users/foo/.crc/bin/oc"}, + {"zsh on windows, should convert", "zsh", "C:\\Users\\foo\\.crc\\bin\\oc", "/C/Users/foo/.crc/bin/oc"}, + {"fish on windows, should convert", "fish", "C:\\Users\\foo\\.crc\\bin\\oc", "/C/Users/foo/.crc/bin/oc"}, + {"powershell on windows, should NOT convert", "powershell", "C:\\Users\\foo\\.crc\\bin\\oc", "C:\\Users\\foo\\.crc\\bin\\oc"}, + {"cmd on windows, should NOT convert", "cmd", "C:\\Users\\foo\\.crc\\bin\\oc", "C:\\Users\\foo\\.crc\\bin\\oc"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertToLinuxStylePath(tt.userShell, tt.path) + if result != tt.expectedPath { + t.Errorf("convertToLinuxStylePath(%s, %s) = %s; want %s", tt.userShell, tt.path, result, tt.expectedPath) + } + }) + } +} - assert.Equal(t, "fish", shell) +func TestConvertToLinuxStylePath_WhenRunOnWSL_ThenExecuteWslPathBinary(t *testing.T) { + // Given + dir := t.TempDir() + wslVersionFilePath := filepath.Join(dir, "version") + wslVersionFile, err := os.Create(wslVersionFilePath) assert.NoError(t, err) + defer func(wslVersionFile *os.File) { + err := wslVersionFile.Close() + assert.NoError(t, err) + }(wslVersionFile) + numberOfBytesWritten, err := wslVersionFile.WriteString("Linux version 5.15.167.4-microsoft-standard-WSL2 (root@f9c826d3017f) (gcc (GCC) 11.2.0, GNU ld (GNU Binutils) 2.37) #1 SMP Tue Nov 5 00:21:55 UTC 2024") + assert.NoError(t, err) + assert.Greater(t, numberOfBytesWritten, 0) + WindowsSubsystemLinuxKernelMetadataFile = wslVersionFilePath + mockCommandExecutor := NewMockCommandRunner() + CommandRunner = mockCommandExecutor + // When + convertToLinuxStylePath("wsl", "C:\\Users\\foo\\.crc\\bin\\oc") + // Then + assert.Equal(t, "wsl", mockCommandExecutor.commandName) + assert.Equal(t, []string{"-e", "bash", "-c", "wslpath -a 'C:\\Users\\foo\\.crc\\bin\\oc'"}, mockCommandExecutor.commandArgs) +} + +func TestIsWindowsSubsystemLinux_whenInvalidKernelInfoFile_thenReturnFalse(t *testing.T) { + // Given + When + WindowsSubsystemLinuxKernelMetadataFile = "/i/dont/exist" + // Then + assert.Equal(t, false, IsWindowsSubsystemLinux()) +} + +func TestIsWindowsSubsystemLinux_whenValidKernelInfoFile_thenReturnTrue(t *testing.T) { + tests := []struct { + name string + versionFileContent string + expectedResult bool + }{ + { + "version file contains WSL and Microsoft keywords, then return true", + "Linux version 5.15.167.4-microsoft-standard-WSL2 (root@f9c826d3017f) (gcc (GCC) 11.2.0, GNU ld (GNU Binutils) 2.37) #1 SMP Tue Nov 5 00:21:55 UTC 2024", + true, + }, + { + "version file does NOT contain WSL and Microsoft keywords, then return false", + "invalid", + false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given + dir := t.TempDir() + wslVersionFilePath := filepath.Join(dir, "version") + wslVersionFile, err := os.Create(wslVersionFilePath) + assert.NoError(t, err) + defer func(wslVersionFile *os.File) { + err := wslVersionFile.Close() + assert.NoError(t, err) + err = os.Remove(wslVersionFile.Name()) + assert.NoError(t, err) + }(wslVersionFile) + numberOfBytesWritten, err := wslVersionFile.WriteString(tt.versionFileContent) + assert.NoError(t, err) + assert.Greater(t, numberOfBytesWritten, 0) + WindowsSubsystemLinuxKernelMetadataFile = wslVersionFilePath + // When + result := IsWindowsSubsystemLinux() + + // Then + assert.Equal(t, tt.expectedResult, result) + }) + } +} + +func TestConvertToWindowsSubsystemLinuxPath(t *testing.T) { + // Given + mockCommandExecutor := NewMockCommandRunner() + CommandRunner = mockCommandExecutor + + // When + convertToWindowsSubsystemLinuxPath("C:\\Users\\foo\\.crc\\bin\\oc") + + // Then + assert.Equal(t, "wsl", mockCommandExecutor.commandName) + assert.Equal(t, []string{"-e", "bash", "-c", "wslpath -a 'C:\\Users\\foo\\.crc\\bin\\oc'"}, mockCommandExecutor.commandArgs) +} + +func TestInspectProcessForRecentlyUsedShell(t *testing.T) { + tests := []struct { + name string + psCommandOutput string + expectedShellType string + }{ + { + "nothing provided, then return empty string", + "", + "", + }, + { + "bash shell, then detect bash shell", + " PID COMMAND\n 435 bash\n 31162 ps", + "bash", + }, + { + "zsh shell, then detect zsh shell", + " PID COMMAND\n 31253 zsh\n 31259 ps", + "zsh", + }, + { + "fish shell, then detect fish shell", + " PID COMMAND\n 31352 fish\n 31372 ps", + "fish", + }, + {"bash and zsh shell, then detect zsh with more recent process id", + " PID COMMAND\n 435 bash\n 31253 zsh\n 31259 ps", + "zsh", + }, + {"bash and fish shell, then detect fish shell with more recent process id", + " PID COMMAND\n 435 bash\n 31352 fish\n 31372 ps", + "fish", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given + When + result := inspectProcessOutputForRecentlyUsedShell(tt.psCommandOutput) + // Then + if result != tt.expectedShellType { + t.Errorf("%s inspectProcessOutputForRecentlyUsedShell() = %s; want %s", tt.name, result, tt.expectedShellType) + } + }) + } } diff --git a/pkg/os/shell/shell_unix.go b/pkg/os/shell/shell_unix.go index e017d38df7..b935ba6f87 100644 --- a/pkg/os/shell/shell_unix.go +++ b/pkg/os/shell/shell_unix.go @@ -6,7 +6,6 @@ package shell import ( "errors" "fmt" - "os" "path/filepath" ) @@ -16,12 +15,11 @@ var ( // detect detects user's current shell. func detect() (string, error) { - shell := os.Getenv("SHELL") - - if shell == "" { + detectedShell := detectShellByInvokingCommand("", "ps", []string{"-o", "pid,comm"}) + if detectedShell == "" { fmt.Printf("The default lines below are for a sh/bash shell, you can specify the shell you're using, with the --shell flag.\n\n") return "", ErrUnknownShell } - return filepath.Base(shell), nil + return filepath.Base(detectedShell), nil } diff --git a/pkg/os/shell/shell_unix_test.go b/pkg/os/shell/shell_unix_test.go index 18fb907f65..6c59cdc4f2 100644 --- a/pkg/os/shell/shell_unix_test.go +++ b/pkg/os/shell/shell_unix_test.go @@ -4,6 +4,7 @@ package shell import ( + "bytes" "os" "testing" @@ -11,11 +12,75 @@ import ( ) func TestUnknownShell(t *testing.T) { - defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL")) - os.Setenv("SHELL", "") + // Given + mockCommandExecutor := NewMockCommandRunnerWithOutputErr("", "", nil) + CommandRunner = mockCommandExecutor + originalStdout := os.Stdout + r, w, _ := os.Pipe() + os.Stdout = w + // When shell, err := detect() + // Then assert.Equal(t, err, ErrUnknownShell) + err = w.Close() + assert.NoError(t, err) + os.Stdout = originalStdout + var buf bytes.Buffer + nBytesRead, err := buf.ReadFrom(r) + assert.NoError(t, err) + assert.Greater(t, nBytesRead, int64(0)) + assert.Equal(t, "The default lines below are for a sh/bash shell, you can specify the shell you're using, with the --shell flag.\n\n", buf.String()) + assert.Equal(t, "ps", mockCommandExecutor.commandName) + assert.Equal(t, []string{"-o", "pid,comm"}, mockCommandExecutor.commandArgs) assert.Empty(t, shell) } + +func TestDetect_GivenPsOutputContainsShell_ThenReturnShellProcessWithMostRecentPid(t *testing.T) { + tests := []struct { + name string + psCommandOutput string + expectedShellType string + }{ + { + "bash shell, then detect bash shell", + " PID COMMAND\n 435 bash\n 31162 ps", + "bash", + }, + { + "zsh shell, then detect zsh shell", + " PID COMMAND\n 31253 zsh\n 31259 ps", + "zsh", + }, + { + "fish shell, then detect fish shell", + " PID COMMAND\n 31352 fish\n 31372 ps", + "fish", + }, + {"bash and zsh shell, then detect zsh with more recent process id", + " PID COMMAND\n 435 bash\n 31253 zsh\n 31259 ps", + "zsh", + }, + {"bash and fish shell, then detect fish shell with more recent process id", + " PID COMMAND\n 435 bash\n 31352 fish\n 31372 ps", + "fish", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given + mockCommandExecutor := NewMockCommandRunnerWithOutputErr(tt.psCommandOutput, "", nil) + CommandRunner = mockCommandExecutor + + // When + shell, err := detect() + + // Then + assert.Equal(t, "ps", mockCommandExecutor.commandName) + assert.Equal(t, []string{"-o", "pid,comm"}, mockCommandExecutor.commandArgs) + assert.Equal(t, tt.expectedShellType, shell) + assert.NoError(t, err) + }) + } +} diff --git a/pkg/os/shell/shell_windows.go b/pkg/os/shell/shell_windows.go index 5075bb7402..041357ad02 100644 --- a/pkg/os/shell/shell_windows.go +++ b/pkg/os/shell/shell_windows.go @@ -4,13 +4,14 @@ import ( "fmt" "math" "os" + "path/filepath" "strings" "syscall" "unsafe" ) var ( - supportedShell = []string{"cmd", "powershell"} + supportedShell = []string{"cmd", "powershell", "bash", "zsh", "fish"} ) // re-implementation of private function in https://github.com/golang/go/blob/master/src/syscall/syscall_windows.go @@ -62,6 +63,10 @@ func shellType(shell string, defaultShell string) string { return "powershell" case strings.Contains(strings.ToLower(shell), "cmd"): return "cmd" + case strings.Contains(strings.ToLower(shell), "wsl"): + return detectShellByInvokingCommand("bash", "wsl", []string{"-e", "bash", "-c", "ps -ao pid,comm"}) + case filepath.IsAbs(shell) && strings.Contains(strings.ToLower(shell), "bash"): + return "bash" default: return defaultShell } diff --git a/pkg/os/shell/shell_windows_test.go b/pkg/os/shell/shell_windows_test.go index 381b2947c9..f6eab50c8d 100644 --- a/pkg/os/shell/shell_windows_test.go +++ b/pkg/os/shell/shell_windows_test.go @@ -43,3 +43,47 @@ func TestGetNameAndItsPpidOfParent(t *testing.T) { assert.Equal(t, "go.exe", shell) assert.NoError(t, err) } + +func TestSupportedShells(t *testing.T) { + assert.Equal(t, []string{"cmd", "powershell", "bash", "zsh", "fish"}, supportedShell) +} + +func TestShellType(t *testing.T) { + tests := []struct { + name string + userShell string + expectedShellType string + }{ + {"git bash", "C:\\Program Files\\Git\\usr\\bin\\bash.exe", "bash"}, + {"windows subsystem for linux", "wsl.exe", "bash"}, + {"powershell", "powershell", "powershell"}, + {"cmd.exe", "cmd.exe", "cmd"}, + {"pwsh", "pwsh.exe", "powershell"}, + {"empty value", "", "cmd"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Given + CommandRunner = NewMockCommandRunner() + // When + result := shellType(tt.userShell, "cmd") + // Then + if result != tt.expectedShellType { + t.Errorf("shellType(%s) = %s; want %s", tt.userShell, result, tt.expectedShellType) + } + }) + } +} + +func TestDetectShellInWindowsSubsystemLinux(t *testing.T) { + // Given + mockCommandExecutor := NewMockCommandRunner() + CommandRunner = mockCommandExecutor + + // When + shellType("wsl.exe", "cmd") + + // Then + assert.Equal(t, "wsl", mockCommandExecutor.commandName) + assert.Equal(t, []string{"-e", "bash", "-c", "ps -ao pid,comm"}, mockCommandExecutor.commandArgs) +}