diff --git a/probe/ssh/ssh.go b/probe/ssh/ssh.go index 382b2146..899a8dd5 100644 --- a/probe/ssh/ssh.go +++ b/probe/ssh/ssh.go @@ -20,6 +20,7 @@ package ssh import ( "bytes" + "context" "fmt" "net" @@ -140,10 +141,10 @@ func (s *Server) DoProbe() (bool, string) { message := "SSH Command has been Run Successfully!" if err != nil { - if _, ok := err.(*ssh.ExitMissingError); ok { - s.exitCode = UnknownExitCode // Error: remote server does not send an exit status - } else if e, ok := err.(*ssh.ExitError); ok { + if e, ok := err.(*ssh.ExitError); ok { s.exitCode = e.ExitStatus() + } else { + s.exitCode = UnknownExitCode } log.Errorf("[%s / %s] %v", s.ProbeKind, s.ProbeName, err) status = false @@ -259,11 +260,21 @@ func (s *Server) RunSSHCmd() (string, error) { var stdoutBuf, stderrBuf bytes.Buffer session.Stdout = &stdoutBuf session.Stderr = &stderrBuf - if err := session.Run(env + global.CommandLine(s.Command, s.Args)); err != nil { - return stderrBuf.String(), err - } - return stdoutBuf.String(), nil + errCh := make(chan error, 1) + go func() { + errCh <- session.Run(env + global.CommandLine(s.Command, s.Args)) + }() + + ctx, cancel := context.WithTimeout(context.Background(), s.Timeout()) + defer cancel() + select { + case <-ctx.Done(): + session.Signal(ssh.SIGINT) + return fmt.Sprintf("timeout after %s", s.Timeout()), ctx.Err() + case err := <-errCh: + return stdoutBuf.String(), err + } } // ExportMetrics export shell metrics