Skip to content

Commit

Permalink
fix (shell) : Improve shell detection on windows (#3767)
Browse files Browse the repository at this point in the history
We should detect the usage of $SHELL environment variable when using
CRC from linux like environments on Windows. We should also convert the
CRC binary paths to unix path format whenever unix shells are detected.

Signed-off-by: Rohan Kumar <[email protected]>
  • Loading branch information
rohanKanojia committed Dec 26, 2024
1 parent c72a45f commit ae1fd94
Show file tree
Hide file tree
Showing 6 changed files with 421 additions and 25 deletions.
93 changes: 90 additions & 3 deletions pkg/os/shell/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,17 @@ package shell

import (
"fmt"
"os"
"sort"
"strconv"
"strings"

crcos "github.com/crc-org/crc/v2/pkg/os"
)

var (
CommandRunner = crcos.NewLocalCommandRunner()
WindowsSubsystemLinuxKernelMetadataFile = "/proc/version"
)

type Config struct {
Expand Down Expand Up @@ -65,9 +75,9 @@ func GetEnvString(userShell string, envName string, envValue string) string {
case "cmd":
return fmt.Sprintf("SET %s=%s", envName, envValue)
case "fish":
return fmt.Sprintf("contains %s $fish_user_paths; or set -U fish_user_paths %s $fish_user_paths", envValue, envValue)
return fmt.Sprintf("contains %s $fish_user_paths; or set -U fish_user_paths %s $fish_user_paths", convertToLinuxStylePath(userShell, envValue), convertToLinuxStylePath(userShell, envValue))
default:
return fmt.Sprintf("export %s=\"%s\"", envName, envValue)
return fmt.Sprintf("export %s=\"%s\"", envName, convertToLinuxStylePath(userShell, envValue))
}
}

Expand All @@ -81,8 +91,85 @@ func GetPathEnvString(userShell string, prependedPath string) string {
case "cmd":
pathStr = fmt.Sprintf("%s;%%PATH%%", prependedPath)
default:
pathStr = fmt.Sprintf("%s:$PATH", prependedPath)
pathStr = fmt.Sprintf("%s:$PATH", convertToLinuxStylePath(userShell, prependedPath))
}

return GetEnvString(userShell, "PATH", pathStr)
}

func convertToLinuxStylePath(userShell string, path string) string {
if IsWindowsSubsystemLinux() {
return convertToWindowsSubsystemLinuxPath(path)
}
if strings.Contains(path, "\\") &&
(userShell == "bash" || userShell == "zsh" || userShell == "fish") {
path = strings.ReplaceAll(path, ":", "")
path = strings.ReplaceAll(path, "\\", "/")

return fmt.Sprintf("/%s", path)
}
return path
}

func convertToWindowsSubsystemLinuxPath(path string) string {
stdOut, _, err := CommandRunner.Run("wsl", "-e", "bash", "-c", fmt.Sprintf("wslpath -a '%s'", path))
if err != nil {
return path
}
return strings.TrimSpace(stdOut)
}

func IsWindowsSubsystemLinux() bool {
procVersionContent, err := os.ReadFile(WindowsSubsystemLinuxKernelMetadataFile)
if err != nil {
return false
}
if strings.Contains(string(procVersionContent), "Microsoft") ||
strings.Contains(string(procVersionContent), "WSL") {
return true
}
return false
}

func detectShellByInvokingCommand(defaultShell string, command string, args []string) string {
stdOut, _, err := CommandRunner.Run(command, args...)
if err != nil {
return defaultShell
}

detectedShell := inspectProcessOutputForRecentlyUsedShell(stdOut)
if detectedShell == "" {
return defaultShell
}
return detectedShell
}

func inspectProcessOutputForRecentlyUsedShell(psCommandOutput string) string {
type ProcessOutput struct {
processID int
output string
}
var processOutputs []ProcessOutput
lines := strings.Split(psCommandOutput, "\n")[1:]
for _, line := range lines {
lineParts := strings.Split(strings.TrimSpace(line), " ")
if len(lineParts) == 2 && (strings.Contains(lineParts[1], "zsh") ||
strings.Contains(lineParts[1], "bash") ||
strings.Contains(lineParts[1], "fish")) {
parsedProcessID, err := strconv.Atoi(lineParts[0])
if err == nil {
processOutputs = append(processOutputs, ProcessOutput{
processID: parsedProcessID,
output: lineParts[1],
})
}
}
}
sort.Slice(processOutputs, func(i, j int) bool {
return processOutputs[i].processID > processOutputs[j].processID
})
if len(processOutputs) > 0 {
return processOutputs[0].output
}
return ""
}
225 changes: 211 additions & 14 deletions pkg/os/shell/shell_test.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,228 @@
//go:build !windows
// +build !windows

package shell

import (
"os"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
)

func TestDetectBash(t *testing.T) {
defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL"))
os.Setenv("SHELL", "/bin/bash")
type MockCommandRunner struct {
commandName string
commandArgs []string
expectedOutputToReturn string
expectedErrMessageToReturn string
expectedErrToReturn error
}

func NewMockCommandRunner() *MockCommandRunner {
return NewMockCommandRunnerWithOutputErr("", "", nil)
}

shell, err := detect()
func NewMockCommandRunnerWithOutputErr(output string, errorMsg string, err error) *MockCommandRunner {
return &MockCommandRunner{
commandName: "",
commandArgs: []string{},
expectedOutputToReturn: output,
expectedErrMessageToReturn: errorMsg,
expectedErrToReturn: err,
}
}

assert.Equal(t, "bash", shell)
assert.NoError(t, err)
func (e *MockCommandRunner) Run(command string, args ...string) (string, string, error) {
e.commandName = command
e.commandArgs = args
return e.expectedOutputToReturn, e.expectedErrMessageToReturn, e.expectedErrToReturn
}

func TestDetectFish(t *testing.T) {
defer func(shell string) { os.Setenv("SHELL", shell) }(os.Getenv("SHELL"))
os.Setenv("SHELL", "/bin/fish")
func (e *MockCommandRunner) RunPrivate(command string, args ...string) (string, string, error) {
e.commandName = command
e.commandArgs = args
return e.expectedOutputToReturn, e.expectedErrMessageToReturn, e.expectedErrToReturn
}

func (e *MockCommandRunner) RunPrivileged(_ string, cmdAndArgs ...string) (string, string, error) {
e.commandArgs = cmdAndArgs
return e.expectedOutputToReturn, e.expectedErrMessageToReturn, e.expectedErrToReturn
}

shell, err := detect()
func TestGetPathEnvString(t *testing.T) {
tests := []struct {
name string
userShell string
path string
expectedStr string
}{
{"fish shell", "fish", "C:\\Users\\foo\\.crc\\bin\\oc", "contains /C/Users/foo/.crc/bin/oc $fish_user_paths; or set -U fish_user_paths /C/Users/foo/.crc/bin/oc $fish_user_paths"},
{"powershell shell", "powershell", "C:\\Users\\foo\\oc.exe", "$Env:PATH = \"C:\\Users\\foo\\oc.exe;$Env:PATH\""},
{"cmd shell", "cmd", "C:\\Users\\foo\\oc.exe", "SET PATH=C:\\Users\\foo\\oc.exe;%PATH%"},
{"bash with windows path", "bash", "C:\\Users\\foo.exe", "export PATH=\"/C/Users/foo.exe:$PATH\""},
{"unknown with windows path", "unknown", "C:\\Users\\foo.exe", "export PATH=\"C:\\Users\\foo.exe:$PATH\""},
{"unknown shell with unix path", "unknown", "/home/foo/.crc/bin/oc", "export PATH=\"/home/foo/.crc/bin/oc:$PATH\""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPathEnvString(tt.userShell, tt.path)
if result != tt.expectedStr {
t.Errorf("GetPathEnvString(%s, %s) = %s; want %s", tt.userShell, tt.path, result, tt.expectedStr)
}
})
}
}

func TestConvertToLinuxStylePath(t *testing.T) {
tests := []struct {
name string
userShell string
path string
expectedPath string
}{
{"bash on windows, should convert", "bash", "C:\\Users\\foo\\.crc\\bin\\oc", "/C/Users/foo/.crc/bin/oc"},
{"zsh on windows, should convert", "zsh", "C:\\Users\\foo\\.crc\\bin\\oc", "/C/Users/foo/.crc/bin/oc"},
{"fish on windows, should convert", "fish", "C:\\Users\\foo\\.crc\\bin\\oc", "/C/Users/foo/.crc/bin/oc"},
{"powershell on windows, should NOT convert", "powershell", "C:\\Users\\foo\\.crc\\bin\\oc", "C:\\Users\\foo\\.crc\\bin\\oc"},
{"cmd on windows, should NOT convert", "cmd", "C:\\Users\\foo\\.crc\\bin\\oc", "C:\\Users\\foo\\.crc\\bin\\oc"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertToLinuxStylePath(tt.userShell, tt.path)
if result != tt.expectedPath {
t.Errorf("convertToLinuxStylePath(%s, %s) = %s; want %s", tt.userShell, tt.path, result, tt.expectedPath)
}
})
}
}

assert.Equal(t, "fish", shell)
func TestConvertToLinuxStylePath_WhenRunOnWSL_ThenExecuteWslPathBinary(t *testing.T) {
// Given
dir := t.TempDir()
wslVersionFilePath := filepath.Join(dir, "version")
wslVersionFile, err := os.Create(wslVersionFilePath)
assert.NoError(t, err)
defer func(wslVersionFile *os.File) {
err := wslVersionFile.Close()
assert.NoError(t, err)
}(wslVersionFile)
numberOfBytesWritten, err := wslVersionFile.WriteString("Linux version 5.15.167.4-microsoft-standard-WSL2 (root@f9c826d3017f) (gcc (GCC) 11.2.0, GNU ld (GNU Binutils) 2.37) #1 SMP Tue Nov 5 00:21:55 UTC 2024")
assert.NoError(t, err)
assert.Greater(t, numberOfBytesWritten, 0)
WindowsSubsystemLinuxKernelMetadataFile = wslVersionFilePath
mockCommandExecutor := NewMockCommandRunner()
CommandRunner = mockCommandExecutor
// When
convertToLinuxStylePath("wsl", "C:\\Users\\foo\\.crc\\bin\\oc")
// Then
assert.Equal(t, "wsl", mockCommandExecutor.commandName)
assert.Equal(t, []string{"-e", "bash", "-c", "wslpath -a 'C:\\Users\\foo\\.crc\\bin\\oc'"}, mockCommandExecutor.commandArgs)
}

func TestIsWindowsSubsystemLinux_whenInvalidKernelInfoFile_thenReturnFalse(t *testing.T) {
// Given + When
WindowsSubsystemLinuxKernelMetadataFile = "/i/dont/exist"
// Then
assert.Equal(t, false, IsWindowsSubsystemLinux())
}

func TestIsWindowsSubsystemLinux_whenValidKernelInfoFile_thenReturnTrue(t *testing.T) {
tests := []struct {
name string
versionFileContent string
expectedResult bool
}{
{
"version file contains WSL and Microsoft keywords, then return true",
"Linux version 5.15.167.4-microsoft-standard-WSL2 (root@f9c826d3017f) (gcc (GCC) 11.2.0, GNU ld (GNU Binutils) 2.37) #1 SMP Tue Nov 5 00:21:55 UTC 2024",
true,
},
{
"version file does NOT contain WSL and Microsoft keywords, then return false",
"invalid",
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Given
dir := t.TempDir()
wslVersionFilePath := filepath.Join(dir, "version")
wslVersionFile, err := os.Create(wslVersionFilePath)
assert.NoError(t, err)
defer func(wslVersionFile *os.File) {
err := wslVersionFile.Close()
assert.NoError(t, err)
err = os.Remove(wslVersionFile.Name())
assert.NoError(t, err)
}(wslVersionFile)
numberOfBytesWritten, err := wslVersionFile.WriteString(tt.versionFileContent)
assert.NoError(t, err)
assert.Greater(t, numberOfBytesWritten, 0)
WindowsSubsystemLinuxKernelMetadataFile = wslVersionFilePath
// When
result := IsWindowsSubsystemLinux()

// Then
assert.Equal(t, tt.expectedResult, result)
})
}
}

func TestConvertToWindowsSubsystemLinuxPath(t *testing.T) {
// Given
mockCommandExecutor := NewMockCommandRunner()
CommandRunner = mockCommandExecutor

// When
convertToWindowsSubsystemLinuxPath("C:\\Users\\foo\\.crc\\bin\\oc")

// Then
assert.Equal(t, "wsl", mockCommandExecutor.commandName)
assert.Equal(t, []string{"-e", "bash", "-c", "wslpath -a 'C:\\Users\\foo\\.crc\\bin\\oc'"}, mockCommandExecutor.commandArgs)
}

func TestInspectProcessForRecentlyUsedShell(t *testing.T) {
tests := []struct {
name string
psCommandOutput string
expectedShellType string
}{
{
"nothing provided, then return empty string",
"",
"",
},
{
"bash shell, then detect bash shell",
" PID COMMAND\n 435 bash\n 31162 ps",
"bash",
},
{
"zsh shell, then detect zsh shell",
" PID COMMAND\n 31253 zsh\n 31259 ps",
"zsh",
},
{
"fish shell, then detect fish shell",
" PID COMMAND\n 31352 fish\n 31372 ps",
"fish",
},
{"bash and zsh shell, then detect zsh with more recent process id",
" PID COMMAND\n 435 bash\n 31253 zsh\n 31259 ps",
"zsh",
},
{"bash and fish shell, then detect fish shell with more recent process id",
" PID COMMAND\n 435 bash\n 31352 fish\n 31372 ps",
"fish",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Given + When
result := inspectProcessOutputForRecentlyUsedShell(tt.psCommandOutput)
// Then
if result != tt.expectedShellType {
t.Errorf("%s inspectProcessOutputForRecentlyUsedShell() = %s; want %s", tt.name, result, tt.expectedShellType)
}
})
}
}
8 changes: 3 additions & 5 deletions pkg/os/shell/shell_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package shell
import (
"errors"
"fmt"
"os"
"path/filepath"
)

Expand All @@ -16,12 +15,11 @@ var (

// detect detects user's current shell.
func detect() (string, error) {
shell := os.Getenv("SHELL")

if shell == "" {
detectedShell := detectShellByInvokingCommand("", "ps", []string{"-o", "pid,comm"})
if detectedShell == "" {
fmt.Printf("The default lines below are for a sh/bash shell, you can specify the shell you're using, with the --shell flag.\n\n")
return "", ErrUnknownShell
}

return filepath.Base(shell), nil
return filepath.Base(detectedShell), nil
}
Loading

0 comments on commit ae1fd94

Please sign in to comment.