diff --git a/pyinfra/connectors/sshuserclient/client.py b/pyinfra/connectors/sshuserclient/client.py index 362c99cb9..4dd88a055 100644 --- a/pyinfra/connectors/sshuserclient/client.py +++ b/pyinfra/connectors/sshuserclient/client.py @@ -14,6 +14,7 @@ SSHException, ) from paramiko.agent import AgentRequestHandler +from paramiko.hostkeys import HostKeyEntry from pyinfra import logger from pyinfra.api.util import memoize @@ -31,6 +32,28 @@ def missing_host_key(self, client, hostname, key): ) +def append_hostkey(client, hostname, key): + """Append hostname to the clients host_keys_file""" + + with HOST_KEYS_LOCK: + # The paramiko client saves host keys incorrectly whereas the host keys object does + # this correctly, so use that with the client filename variable. + # See: https://github.com/paramiko/paramiko/pull/1989 + host_key_entry = HostKeyEntry([hostname], key) + if host_key_entry is None: + raise SSHException( + "Append Hostkey: Failed to parse host {0}, could not append to hostfile".format( + hostname + ), + ) + with open(client._host_keys_filename, "a") as host_keys_file: + hk_entry = host_key_entry.to_line() + if hk_entry is None: + raise SSHException(f"Append Hostkey: Failed to append hostkey ({host_key_entry})") + + host_keys_file.write(hk_entry) + + class AcceptNewPolicy(MissingHostKeyPolicy): def missing_host_key(self, client, hostname, key): logger.warning( @@ -40,13 +63,8 @@ def missing_host_key(self, client, hostname, key): ), ) - with HOST_KEYS_LOCK: - host_keys = client.get_host_keys() - host_keys.add(hostname, key.get_name(), key) - # The paramiko client saves host keys incorrectly whereas the host keys object does - # this correctly, so use that with the client filename variable. - # See: https://github.com/paramiko/paramiko/pull/1989 - host_keys.save(client._host_keys_filename) + append_hostkey(client, hostname, key) + logger.warning("Added host key for {0} to known_hosts".format(hostname)) class AskPolicy(MissingHostKeyPolicy): @@ -60,13 +78,7 @@ def missing_host_key(self, client, hostname, key): raise SSHException( "AskPolicy: No host key for {0} found in known_hosts".format(hostname), ) - with HOST_KEYS_LOCK: - host_keys = client.get_host_keys() - host_keys.add(hostname, key.get_name(), key) - # The paramiko client saves host keys incorrectly whereas the host keys object does - # this correctly, so use that with the client filename variable. - # See: https://github.com/paramiko/paramiko/pull/1989 - host_keys.save(client._host_keys_filename) + append_hostkey(client, hostname, key) logger.warning("Added host key for {0} to known_hosts".format(hostname)) return diff --git a/tests/test_connectors/test_sshuserclient.py b/tests/test_connectors/test_sshuserclient.py index 3c3c17450..28f1468a3 100644 --- a/tests/test_connectors/test_sshuserclient.py +++ b/tests/test_connectors/test_sshuserclient.py @@ -1,7 +1,8 @@ +from base64 import b64decode from unittest import TestCase from unittest.mock import mock_open, patch -from paramiko import ProxyCommand +from paramiko import PKey, ProxyCommand, SSHException from pyinfra.connectors.sshuserclient import SSHClient from pyinfra.connectors.sshuserclient.client import AskPolicy, get_ssh_config @@ -41,6 +42,30 @@ Include other_file """ +# To ensure that we don't remove things from users hostfiles +# we should test that all modifications only append to the +# hostfile, and don't delete any data or comments. +EXAMPLE_KEY_1 = ( + "AAAAB3NzaC1yc2EAAAADAQABAAABgQCj7ndNxQowgcQnjshcLrqPEiiphnt+" + "VTTvDP6mHBL9j1aNUkY4Ue1gvwnGLVlOhGeYrnZaMgRK6+PKCUXaDbC7qtbW8gIkhL7aGCsOr/" + "C56SJMy/BCZfxd1nWzAOxSDPgVsmerOBYfNqltV9/hWCqBywINIR+5dIg6JTJ72pcEpEjcYgXk" + "E2YEFXV1JHnsKgbLWNlhScqb2UmyRkQyytRLtL+38TGxkxCflmO+5Z8CSSNY7GidjMIZ7Q4zMj" + "A2n1nGrlTDkzwDCsw+wqFPGQA179cnfGWOWRVruj16z6XyvxvjJwbz0wQZ75XK5tKSb7FNyeIE" + "s4TT4jk+S4dhPeAUC5y+bDYirYgM4GC7uEnztnZyaVWQ7B381AK4Qdrwt51ZqExKbQpTUNn+Ej" + "qoTwvqNj4kqx5QUCI0ThS/YkOxJCXmPUWZbhjpCg56i+2aB6CmK2JGhn57K5mj0MNdBXA4/Wnw" + "H6XoPWJzK5Nyu2zB3nAZp+S5hpQs+p1vN1/wsjk=" +) + +KNOWN_HOSTS_EXAMPLE_DATA = f""" +# this is an important comment + +# another comment after the newline + +@cert-authority example-domain.lan ssh-rsa {EXAMPLE_KEY_1} + +192.168.1.222 ssh-rsa {EXAMPLE_KEY_1} +""" + class TestSSHUserConfigMissing(TestCase): def setUp(self): @@ -199,3 +224,45 @@ def test_test_paramiko_connect_kwargs(self, fake_paramiko_connect): port=22, test="kwarg", ) + + def test_missing_hostkey(self): + client = SSHClient() + policy = AskPolicy() + example_hostname = "new_host" + example_keytype = "ecdsa-sha2-nistp256" + example_key = ( + "AAAAE2VjZHNhLXNoYTItbmlzdHAyNT" + "YAAAAIbmlzdHAyNTYAAABBBHNp1NM" + "ZjxPBuuKwIPfkVJqWaH3oUtW137kIW" + "P4PlCyACt8zVIIimFhIpwRUidcf7jw" + "VWPAJvfBjEPqewDApnZQ=" + ) + + key = PKey.from_type_string( + example_keytype, + b64decode(example_key), + ) + + # Check if AskPolicy respects not importing and properly raises SSHException + with self.subTest("Check user 'no'"): + with patch("builtins.input", return_value="n"): + self.assertRaises( + SSHException, lambda: policy.missing_host_key(client, example_hostname, key) + ) + + # Check if AskPolicy properly appends to hostfile + with self.subTest("Check user 'yes'"): + mock_data = mock_open(read_data=KNOWN_HOSTS_EXAMPLE_DATA) + # Read mock hostfile + with patch("pyinfra.connectors.sshuserclient.client.open", mock_data): + with patch("paramiko.hostkeys.open", mock_data): + with patch("builtins.input", return_value="y"): + policy.missing_host_key(client, "new_host", key) + + # Assert that we appended correctly to the file + write_call_args = mock_data.return_value.write.call_args + # Ensure we only wrote once and then closed the handle. + assert len(write_call_args) == 2 + # Ensure we wrote the correct content + correct_output = f"{example_hostname} {example_keytype} {example_key}\n" + assert write_call_args[0][0] == correct_output