Skip to content

Commit

Permalink
[TEMP] Redesigning mila init
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <[email protected]>
  • Loading branch information
lebrice committed Apr 18, 2024
1 parent 1d8ea61 commit 66397c8
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 90 deletions.
36 changes: 1 addition & 35 deletions milatools/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import rich.logging
from typing_extensions import TypedDict

from milatools.cli.init import init
from milatools.utils.vscode_utils import (
sync_vscode_extensions_with_hostnames,
)
Expand All @@ -32,22 +33,13 @@
from ..utils.remote_v1 import RemoteV1
from .code_command import add_mila_code_arguments
from .common import forward, standard_server
from .init import (
print_welcome_message,
setup_keys_on_login_node,
setup_passwordless_ssh_access,
setup_ssh_config,
setup_vscode_settings,
setup_windows_ssh_config_from_wsl,
)
from .utils import (
CLUSTERS,
MilatoolsUserError,
SortingHelpFormatter,
SSHConnectionError,
T,
get_fully_qualified_name,
running_inside_WSL,
)

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -424,32 +416,6 @@ def intranet(search: Sequence[str]) -> None:
webbrowser.open(url)


def init():
"""Set up your configuration and credentials."""

#############################
# Step 1: SSH Configuration #
#############################

print("Checking ssh config")

ssh_config = setup_ssh_config()

# if we're running on WSL, we actually just copy the id_rsa + id_rsa.pub and the
# ~/.ssh/config to the Windows ssh directory (taking care to remove the
# ControlMaster-related entries) so that the user doesn't need to install Python on
# the Windows side.
if running_inside_WSL():
setup_windows_ssh_config_from_wsl(linux_ssh_config=ssh_config)

success = setup_passwordless_ssh_access(ssh_config=ssh_config)
if not success:
exit()
setup_keys_on_login_node()
setup_vscode_settings()
print_welcome_message()


def forward_command(
remote: str,
page: str | None,
Expand Down
158 changes: 103 additions & 55 deletions milatools/cli/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,23 @@
import sys
import warnings
from logging import getLogger as get_logger
from pathlib import Path
from pathlib import Path, PosixPath
from typing import Any

import questionary as qn
from invoke.exceptions import UnexpectedExit
from paramiko.config import SSHConfig as SSHConfigReader

from milatools.utils.remote_v2 import SSH_CONFIG_FILE

from ..utils.local_v1 import LocalV1, check_passwordless, display
from ..utils.remote_v1 import RemoteV1
from ..utils.vscode_utils import (
from milatools.cli import console
from milatools.cli.utils import SSHConfig as SSHConfigWriter
from milatools.cli.utils import T, running_inside_WSL, yn
from milatools.utils.local_v1 import check_passwordless, display
from milatools.utils.local_v2 import LocalV2
from milatools.utils.remote_v1 import RemoteV1
from milatools.utils.vscode_utils import (
get_expected_vscode_settings_json_path,
vscode_installed,
)
from .utils import SSHConfig as SSHConfigWriter
from .utils import T, running_inside_WSL, yn

logger = get_logger(__name__)

Expand Down Expand Up @@ -114,6 +114,34 @@
}


def init():
"""Set up your configuration and credentials."""

#############################
# Step 1: SSH Configuration #
#############################

print("Checking ssh config")
ssh_config_path = Path("~/.ssh/config").expanduser()

setup_ssh_config(ssh_config_path=ssh_config_path)

# if we're running on WSL, we actually just copy the id_rsa + id_rsa.pub and the
# ~/.ssh/config to the Windows ssh directory (taking care to remove the
# ControlMaster-related entries) so that the user doesn't need to install Python on
# the Windows side.
if running_inside_WSL():
assert isinstance(ssh_config_path, PosixPath) # we're running in linux (WSL).
setup_windows_ssh_config_from_wsl(linux_ssh_config_path=ssh_config_path)

success = setup_passwordless_ssh_access(ssh_config_path)
if not success:
exit()
setup_keys_on_login_node()
setup_vscode_settings()
print_welcome_message()


def setup_ssh_config(
ssh_config_path: str | Path = "~/.ssh/config",
) -> SSHConfigReader:
Expand Down Expand Up @@ -180,7 +208,7 @@ def setup_ssh_config(
return ssh_config


def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfigWriter):
def setup_windows_ssh_config_from_wsl(linux_ssh_config_path: PosixPath):
"""Setup the Windows SSH configuration and public key from within WSL.
This copies over the entries from the linux ssh configuration file, except for the
Expand All @@ -192,6 +220,8 @@ def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfigWriter):
This makes it so the user doesn't need to install Python/Anaconda on the Windows
side in order to use `mila code` from within WSL.
"""
linux_ssh_config = SSHConfigWriter(linux_ssh_config_path)

assert running_inside_WSL()
# NOTE: This also assumes that a public/private key pair has already been generated
# at ~/.ssh/id_rsa.pub and ~/.ssh/id_rsa.
Expand Down Expand Up @@ -234,45 +264,34 @@ def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfigWriter):
_copy_if_needed(linux_key_file, windows_key_file)


def setup_passwordless_ssh_access(ssh_config: SSHConfigWriter) -> bool:
def setup_passwordless_ssh_access(ssh_config_path: Path) -> bool:
"""Sets up passwordless ssh access to the Mila and optionally also to DRAC.
Sets up ssh connection to the DRAC clusters if they are present in the SSH config
file.
Returns whether the operation completed successfully or not.
"""
print("Checking passwordless authentication")
print("Setting up passwordless SSH access.")

here = LocalV1()
sshdir = Path.home() / ".ssh"
ssh_config = SSHConfigReader.from_path(str(ssh_config_path))

# Check if there is a public key file in ~/.ssh
if not list(sshdir.glob("id*.pub")):
if yn("You have no public keys. Generate one?"):
# Run ssh-keygen with the given location and no passphrase.
ssh_private_key_path = Path.home() / ".ssh" / "id_rsa"
create_ssh_keypair(ssh_private_key_path, here)
else:
print("No public keys.")
return False
# TODO: Generate SSH keys with ssh-keygen (not setting the passphrase so users can choose to use a passphrase or not).
setup_passwordless_ssh_access_to_cluster("mila", ssh_config_path)

# TODO: This uses the public key set in the SSH config file, which may (or may not)
# be the random id*.pub file that was just checked for above.
success = setup_passwordless_ssh_access_to_cluster("mila")
if not success:
return False
setup_keys_on_login_node("mila")
hosts_in_ssh_config = [
hostname
for hostname in ssh_config.get_hostnames()
if not any(c in hostname for c in "!*?")
]

drac_clusters_in_ssh_config: list[str] = []
hosts_in_config = ssh_config.hosts()
for cluster in DRAC_CLUSTERS:
if any(cluster in hostname for hostname in hosts_in_config):
drac_clusters_in_ssh_config.append(cluster)
drac_clusters_in_ssh_config: list[str] = list(
set(DRAC_CLUSTERS).intersection(hosts_in_ssh_config)
)

if not drac_clusters_in_ssh_config:
logger.debug(
f"There are no DRAC clusters in the SSH config at {ssh_config.path}."
f"There are no DRAC clusters in the SSH config at {ssh_config_path}."
)
return True

Expand All @@ -285,35 +304,65 @@ def setup_passwordless_ssh_access(ssh_config: SSHConfigWriter) -> bool:
"See https://docs.alliancecan.ca/wiki/SSH_Keys#Using_CCDB for more info."
)
for drac_cluster in drac_clusters_in_ssh_config:
success = setup_passwordless_ssh_access_to_cluster(drac_cluster)
success = setup_passwordless_ssh_access_to_cluster(
drac_cluster, ssh_config_path
)
if not success:
return False
setup_keys_on_login_node(drac_cluster)
return True


def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
def _get_private_key_path_for_hostname(
hostname: str, ssh_config_path: Path
) -> Path | None:
config = SSHConfigReader.from_path(str(ssh_config_path))
identity_file = config.lookup(hostname).get("identityfile")
if not identity_file:
return None
# Seems to be a list for some reason?
if isinstance(identity_file, list):
assert identity_file
identity_file = identity_file[0]
return Path(identity_file).expanduser()


def setup_passwordless_ssh_access_to_cluster(
cluster: str, ssh_config_path: Path
) -> bool:
"""Sets up passwordless SSH access to the given hostname.
On Mac/Linux, uses `ssh-copy-id`. Performs the steps of ssh-copy-id manually on
Windows.
Returns whether the operation completed successfully or not.
"""
here = LocalV1()
here = LocalV2()
# Check that it is possible to connect without using a password.
print(f"Checking if passwordless SSH access is setup for the {cluster} cluster.")
# TODO: Potentially use a custom key like `~/.ssh/id_milatools.pub` instead of
# the default.
from paramiko.config import SSHConfig
ssh_private_key_path = _get_private_key_path_for_hostname(cluster, ssh_config_path)
# TODO: Simplify the code here by assuming that users just accepted the changes to
# their SSH config proposed by the first part of `mila init`.
# - Instead of making the code complicated with lots of corner cases, just raise an
# error if the SSH config doesn't match what we expect to see after `mila init`.
raise NotImplementedError()
if ssh_private_key_path is None:
# TODO: What to do if there isn't a private key set in the SSH config, but there
# is already a private key in the SSH dir? (it would be used by ssh).
console.log(
f"There is no private key set to be used for the {cluster} cluster."
)
ssh_private_key_path = Path("~/.ssh/id_rsa").expanduser()

if not ssh_private_key_path.exists():
console.log(
f"The ssh key to use for host {cluster} does not exist at {ssh_private_key_path}. Creating it now."
)
create_ssh_keypair(ssh_private_key_path)
config_writer = SSHConfigWriter(ssh_config_path)

config_writer.set(cluster, IdentityFile=str(ssh_private_key_path))

config = SSHConfig.from_path(str(SSH_CONFIG_FILE))
identity_file = config.lookup(cluster).get("identityfile", "~/.ssh/id_rsa")
# Seems to be a list for some reason?
if isinstance(identity_file, list):
assert identity_file
identity_file = identity_file[0]
ssh_private_key_path = Path(identity_file).expanduser()
ssh_public_key_path = ssh_private_key_path.with_suffix(".pub")
assert ssh_public_key_path.exists()

Expand Down Expand Up @@ -451,7 +500,7 @@ def get_windows_home_path_in_wsl() -> Path:

def create_ssh_keypair(
ssh_private_key_path: Path,
local: LocalV1 | None = None,
local: LocalV2 | None = None,
passphrase: str | None = "",
) -> None:
"""Creates a public/private key pair at the given path using ssh-keygen.
Expand All @@ -460,18 +509,17 @@ def create_ssh_keypair(
Otherwise, if passphrase is an empty string, no passphrase will be used (default).
If a string is passed, it is passed to ssh-keygen and used as the passphrase.
"""
local = local or LocalV1()
command = [
local = local or LocalV2()
command = (
"ssh-keygen",
"-f",
str(ssh_private_key_path.expanduser()),
"-t",
"rsa",
]
"rsa", # note: Could also let the user choose the type of encryption..
)
if passphrase is not None:
command.extend(["-N", passphrase])
display(command)
subprocess.run(command, check=True)
command += ("-N", passphrase)
local.run(command, display=True)


def has_passphrase(ssh_private_key_path: Path) -> bool:
Expand Down

0 comments on commit 66397c8

Please sign in to comment.