Skip to content

Commit

Permalink
Merge pull request easybuilders#4453 from boegel/run_shell_cmd_qa
Browse files Browse the repository at this point in the history
implement support for running interactive commands with `run_shell_cmd`
  • Loading branch information
branfosj authored Apr 6, 2024
2 parents 5dac63e + 204f328 commit 471bbcc
Show file tree
Hide file tree
Showing 2 changed files with 344 additions and 21 deletions.
144 changes: 127 additions & 17 deletions easybuild/tools/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
* Ward Poelmans (Ghent University)
"""
import contextlib
import fcntl
import functools
import inspect
import locale
Expand Down Expand Up @@ -202,11 +203,66 @@ def fileprefix_from_cmd(cmd, allowed_chars=False):
return ''.join([c for c in cmd if c in allowed_chars])


def _answer_question(stdout, proc, qa_patterns, qa_wait_patterns):
    """
    Private helper function to try and answer questions raised in interactive shell commands.

    :param stdout: output produced by the command so far (bytes)
    :param proc: process object for the running command; the answer is written to its stdin
    :param qa_patterns: list of 2-tuples with patterns for questions + corresponding answers;
                        an answer is either a string, or a list of strings that is cycled through
    :param qa_wait_patterns: list of strings with patterns for non-questions
    :return: True if a question was answered or a wait pattern matched at the end of the output, False otherwise
    """
    # patterns are anchored at the end of the output, so the tail (not the head) is what is relevant in log messages
    stdout_tail = stdout[-1000:]

    for question, answers in qa_patterns:
        # allow extra whitespace at the end
        question += r'[\s\n]*$'
        regex = re.compile(question.encode())
        res = regex.search(stdout)
        if not res:
            continue

        _log.debug(f"Found match for question pattern '{question}' at end of stdout: {stdout_tail}")
        # if answer is specified as a list, we take the first item as current answer,
        # and add it to the back of the list (so we cycle through answers)
        if isinstance(answers, list):
            answer = answers.pop(0)
            answers.append(answer)
        elif isinstance(answers, str):
            answer = answers
        else:
            raise EasyBuildError(f"Unknown type of answers encountered for question ({question}): {answers}")

        # answer may need to be completed via pattern extracted from question;
        # a named group that did not take part in the match has value None, use an empty string for it
        _log.debug(f"Raw answer for question pattern '{question}': {answer}")
        answer = answer % {k: (v or b'').decode() for (k, v) in res.groupdict().items()}
        answer += '\n'
        _log.info(f"Found match for question pattern '{question}', replying with: {answer}")

        try:
            os.write(proc.stdin.fileno(), answer.encode())
        except OSError as err:
            raise EasyBuildError("Failed to answer question raised by interactive command: %s", err) from err

        return True

    _log.info("No match found for question patterns, considering question wait patterns")
    # if no match was found among question patterns,
    # take into account patterns for non-questions (qa_wait_patterns)
    for pattern in qa_wait_patterns:
        # allow extra whitespace at the end
        pattern += r'[\s\n]*$'
        regex = re.compile(pattern.encode())
        if regex.search(stdout):
            _log.info(f"Found match for wait pattern '{pattern}'")
            _log.debug(f"Found match for wait pattern '{pattern}' at end of stdout: {stdout_tail}")
            return True

    _log.info("No match found for question wait patterns")
    return False


@run_shell_cmd_cache
def run_shell_cmd(cmd, fail_on_error=True, split_stderr=False, stdin=None, env=None,
hidden=False, in_dry_run=False, verbose_dry_run=False, work_dir=None, use_bash=True,
output_file=True, stream_output=None, asynchronous=False, task_id=None, with_hooks=True,
qa_patterns=None, qa_wait_patterns=None):
qa_patterns=None, qa_wait_patterns=None, qa_timeout=100):
"""
Run specified (interactive) shell command, and capture output + exit code.
Expand All @@ -225,8 +281,9 @@ def run_shell_cmd(cmd, fail_on_error=True, split_stderr=False, stdin=None, env=N
:param task_id: task ID for specified shell command (included in return value)
:param with_hooks: trigger pre/post run_shell_cmd hooks (if defined)
:param qa_patterns: list of 2-tuples with patterns for questions + corresponding answers
:param qa_wait_patterns: list of 2-tuples with patterns for non-questions
and number of iterations to allow these patterns to match with end out command output
:param qa_wait_patterns: list of strings with patterns for non-questions
:param qa_timeout: amount of seconds to wait until more output is produced when there is no matching question
:return: Named tuple with:
- output: command output, stdout+stderr combined if split_stderr is disabled, only stdout otherwise
- exit_code: exit code of command (integer)
Expand All @@ -245,9 +302,13 @@ def to_cmd_str(cmd):

return cmd_str

# temporarily raise a NotImplementedError until all options are implemented
if qa_patterns or qa_wait_patterns:
raise NotImplementedError
# make sure that qa_patterns is a list of 2-tuples (not a dict, or something else)
if qa_patterns:
if not isinstance(qa_patterns, list) or any(not isinstance(x, tuple) or len(x) != 2 for x in qa_patterns):
raise EasyBuildError("qa_patterns passed to run_shell_cmd should be a list of 2-tuples!")

if qa_wait_patterns is None:
qa_wait_patterns = []

if work_dir is None:
work_dir = os.getcwd()
Expand Down Expand Up @@ -280,11 +341,14 @@ def to_cmd_str(cmd):
else:
cmd_out_fp, cmd_err_fp = None, None

interactive = bool(qa_patterns)
interactive_msg = 'interactive ' if interactive else ''

# early exit in 'dry run' mode, after printing the command that would be run (unless 'hidden' is enabled)
if not in_dry_run and build_option('extended_dry_run'):
if not hidden or verbose_dry_run:
silent = build_option('silent')
msg = f" running shell command \"{cmd_str}\"\n"
msg = f" running {interactive_msg}shell command \"{cmd_str}\"\n"
msg += f" (in {work_dir})"
dry_run_msg(msg, silent=silent)

Expand All @@ -293,7 +357,7 @@ def to_cmd_str(cmd):

start_time = datetime.now()
if not hidden:
_cmd_trace_msg(cmd_str, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp, thread_id)
_cmd_trace_msg(cmd_str, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp, thread_id, interactive=interactive)

if stream_output:
print_msg(f"(streaming) output for command '{cmd_str}':")
Expand All @@ -310,15 +374,19 @@ def to_cmd_str(cmd):

if with_hooks:
hooks = load_hooks(build_option('hooks'))
hook_res = run_hook(RUN_SHELL_CMD, hooks, pre_step_hook=True, args=[cmd], kwargs={'work_dir': work_dir})
kwargs = {
'interactive': interactive,
'work_dir': work_dir,
}
hook_res = run_hook(RUN_SHELL_CMD, hooks, pre_step_hook=True, args=[cmd], kwargs=kwargs)
if hook_res:
cmd, old_cmd = hook_res, cmd
cmd_str = to_cmd_str(cmd)
_log.info("Command to run was changed by pre-%s hook: '%s' (was: '%s')", RUN_SHELL_CMD, cmd, old_cmd)

stderr = subprocess.PIPE if split_stderr else subprocess.STDOUT

log_msg = f"Running shell command '{cmd_str}' in {work_dir}"
log_msg = f"Running {interactive_msg}shell command '{cmd_str}' in {work_dir}"
if thread_id:
log_msg += f" (via thread with ID {thread_id})"
_log.info(log_msg)
Expand All @@ -330,23 +398,62 @@ def to_cmd_str(cmd):
if stdin:
stdin = stdin.encode()

if stream_output:
if stream_output or qa_patterns:

if qa_patterns:
# make stdout, stderr, stdin non-blocking files
channels = [proc.stdout, proc.stdin]
if split_stderr:
channels += proc.stderr
for channel in channels:
fd = channel.fileno()
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

if stdin:
proc.stdin.write(stdin)

exit_code = None
stdout, stderr = b'', b''
check_interval_secs = 0.1
time_no_match = 0

# collect output piece-wise, while checking for questions to answer (if qa_patterns is provided)
while exit_code is None:
exit_code = proc.poll()

# use small read size (128 bytes) when streaming output, to make it stream more fluently
# -1 means reading until EOF
read_size = 128 if exit_code is None else -1

stdout += proc.stdout.read(read_size)
more_stdout = proc.stdout.read1(read_size) or b''
stdout += more_stdout

# note: we assume that there won't be any questions in stderr output
if split_stderr:
stderr += proc.stderr.read(read_size)
stderr += proc.stderr.read1(read_size) or b''

if qa_patterns:
if _answer_question(stdout, proc, qa_patterns, qa_wait_patterns):
time_no_match = 0
else:
_log.debug(f"No match found in question/wait patterns at end of stdout: {stdout[:1000]}")
# this will only run if the for loop above was *not* stopped by the break statement
time_no_match += check_interval_secs
if time_no_match > qa_timeout:
error_msg = "No matching questions found for current command output, "
error_msg += f"giving up after {qa_timeout} seconds!"
raise EasyBuildError(error_msg)
else:
_log.debug(f"{time_no_match:0.1f} seconds without match in output of interactive shell command")

time.sleep(check_interval_secs)

exit_code = proc.poll()

# collect last bit of output once the process has exited
stdout += proc.stdout.read()
if split_stderr:
stderr += proc.stderr.read()
else:
(stdout, stderr) = proc.communicate(input=stdin)

Expand Down Expand Up @@ -389,6 +496,7 @@ def to_cmd_str(cmd):
if with_hooks:
run_hook_kwargs = {
'exit_code': res.exit_code,
'interactive': interactive,
'output': res.output,
'stderr': res.stderr,
'work_dir': res.work_dir,
Expand All @@ -402,7 +510,7 @@ def to_cmd_str(cmd):
return res


def _cmd_trace_msg(cmd, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp, thread_id):
def _cmd_trace_msg(cmd, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp, thread_id, interactive=False):
"""
Helper function to construct and print trace message for command being run
Expand All @@ -413,13 +521,15 @@ def _cmd_trace_msg(cmd, start_time, work_dir, stdin, cmd_out_fp, cmd_err_fp, thr
:param cmd_out_fp: path to output file for command
:param cmd_err_fp: path to errors/warnings output file for command
:param thread_id: thread ID (None when not running shell command asynchronously)
:param interactive: boolean indicating whether it is an interactive command, or not
"""
start_time = start_time.strftime('%Y-%m-%d %H:%M:%S')

interactive = 'interactive ' if interactive else ''
if thread_id:
run_cmd_msg = f"running shell command (asynchronously, thread ID: {thread_id}):"
run_cmd_msg = f"running {interactive}shell command (asynchronously, thread ID: {thread_id}):"
else:
run_cmd_msg = "running shell command:"
run_cmd_msg = f"running {interactive}shell command:"

lines = [
run_cmd_msg,
Expand Down
Loading

0 comments on commit 471bbcc

Please sign in to comment.