Skip to content

Commit

Permalink
pyinfra/connectors: Fix overwriting of users known_hosts file (pyinfr…
Browse files Browse the repository at this point in the history
…a-dev#1209)

This fixes issue pyinfra-dev#1209, making it so that we append new keys to the
users known_hosts file instead of overwriting it.

Additionally:

    Added a testcase that should discover this breaking in the future.
    Broke out the append functionality into a "append_hostkey" function,
making it so we don't needlessly reuse code for AskPolicy and
AcceptNewPolicy.
    Linting actually correct this time.

Previous behaviour when adding a new key:

    Create a paramiko.HostKeys object
    Read the users known_hosts file
    Add the new key to the object
    Save the object, overwriting the users host_keys file.

New behaviour:

    Create a paramiko.HostKeyEntry object using the new hostname and
corresponding key.
    Append this key to the existing known_hosts file.
  • Loading branch information
vo452 committed Dec 2, 2024
1 parent ef8acef commit 78c93ec
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 15 deletions.
40 changes: 26 additions & 14 deletions pyinfra/connectors/sshuserclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
SSHException,
)
from paramiko.agent import AgentRequestHandler
from paramiko.hostkeys import HostKeyEntry

from pyinfra import logger
from pyinfra.api.util import memoize
Expand All @@ -31,6 +32,28 @@ def missing_host_key(self, client, hostname, key):
)


def append_hostkey(client, hostname, key):
"""Append hostname to the clients host_keys_file"""

with HOST_KEYS_LOCK:
# The paramiko client saves host keys incorrectly whereas the host keys object does
# this correctly, so use that with the client filename variable.
# See: https://github.com/paramiko/paramiko/pull/1989
host_key_entry = HostKeyEntry([hostname], key)
if host_key_entry is None:
raise SSHException(
"Append Hostkey: Failed to parse host {0}, could not append to hostfile".format(
hostname
),
)
with open(client._host_keys_filename, "a") as host_keys_file:
hk_entry = host_key_entry.to_line()
if hk_entry is None:
raise SSHException(f"Append Hostkey: Failed to append hostkey ({host_key_entry})")

host_keys_file.write(hk_entry)


class AcceptNewPolicy(MissingHostKeyPolicy):
def missing_host_key(self, client, hostname, key):
logger.warning(
Expand All @@ -40,13 +63,8 @@ def missing_host_key(self, client, hostname, key):
),
)

with HOST_KEYS_LOCK:
host_keys = client.get_host_keys()
host_keys.add(hostname, key.get_name(), key)
# The paramiko client saves host keys incorrectly whereas the host keys object does
# this correctly, so use that with the client filename variable.
# See: https://github.com/paramiko/paramiko/pull/1989
host_keys.save(client._host_keys_filename)
append_hostkey(client, hostname, key)
logger.warning("Added host key for {0} to known_hosts".format(hostname))


class AskPolicy(MissingHostKeyPolicy):
Expand All @@ -60,13 +78,7 @@ def missing_host_key(self, client, hostname, key):
raise SSHException(
"AskPolicy: No host key for {0} found in known_hosts".format(hostname),
)
with HOST_KEYS_LOCK:
host_keys = client.get_host_keys()
host_keys.add(hostname, key.get_name(), key)
# The paramiko client saves host keys incorrectly whereas the host keys object does
# this correctly, so use that with the client filename variable.
# See: https://github.com/paramiko/paramiko/pull/1989
host_keys.save(client._host_keys_filename)
append_hostkey(client, hostname, key)
logger.warning("Added host key for {0} to known_hosts".format(hostname))
return

Expand Down
69 changes: 68 additions & 1 deletion tests/test_connectors/test_sshuserclient.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from base64 import b64decode
from unittest import TestCase
from unittest.mock import mock_open, patch

from paramiko import ProxyCommand
from paramiko import PKey, ProxyCommand, SSHException

from pyinfra.connectors.sshuserclient import SSHClient
from pyinfra.connectors.sshuserclient.client import AskPolicy, get_ssh_config
Expand Down Expand Up @@ -41,6 +42,30 @@
Include other_file
"""

# To ensure that we don't remove things from users hostfiles
# we should test that all modifications only append to the
# hostfile, and don't delete any data or comments.
EXAMPLE_KEY_1 = (
"AAAAB3NzaC1yc2EAAAADAQABAAABgQCj7ndNxQowgcQnjshcLrqPEiiphnt+"
"VTTvDP6mHBL9j1aNUkY4Ue1gvwnGLVlOhGeYrnZaMgRK6+PKCUXaDbC7qtbW8gIkhL7aGCsOr/"
"C56SJMy/BCZfxd1nWzAOxSDPgVsmerOBYfNqltV9/hWCqBywINIR+5dIg6JTJ72pcEpEjcYgXk"
"E2YEFXV1JHnsKgbLWNlhScqb2UmyRkQyytRLtL+38TGxkxCflmO+5Z8CSSNY7GidjMIZ7Q4zMj"
"A2n1nGrlTDkzwDCsw+wqFPGQA179cnfGWOWRVruj16z6XyvxvjJwbz0wQZ75XK5tKSb7FNyeIE"
"s4TT4jk+S4dhPeAUC5y+bDYirYgM4GC7uEnztnZyaVWQ7B381AK4Qdrwt51ZqExKbQpTUNn+Ej"
"qoTwvqNj4kqx5QUCI0ThS/YkOxJCXmPUWZbhjpCg56i+2aB6CmK2JGhn57K5mj0MNdBXA4/Wnw"
"H6XoPWJzK5Nyu2zB3nAZp+S5hpQs+p1vN1/wsjk="
)

KNOWN_HOSTS_EXAMPLE_DATA = f"""
# this is an important comment
# another comment after the newline
@cert-authority example-domain.lan ssh-rsa {EXAMPLE_KEY_1}
192.168.1.222 ssh-rsa {EXAMPLE_KEY_1}
"""


class TestSSHUserConfigMissing(TestCase):
def setUp(self):
Expand Down Expand Up @@ -199,3 +224,45 @@ def test_test_paramiko_connect_kwargs(self, fake_paramiko_connect):
port=22,
test="kwarg",
)

def test_missing_hostkey(self):
client = SSHClient()
policy = AskPolicy()
example_hostname = "new_host"
example_keytype = "ecdsa-sha2-nistp256"
example_key = (
"AAAAE2VjZHNhLXNoYTItbmlzdHAyNT"
"YAAAAIbmlzdHAyNTYAAABBBHNp1NM"
"ZjxPBuuKwIPfkVJqWaH3oUtW137kIW"
"P4PlCyACt8zVIIimFhIpwRUidcf7jw"
"VWPAJvfBjEPqewDApnZQ="
)

key = PKey.from_type_string(
example_keytype,
b64decode(example_key),
)

# Check if AskPolicy respects not importing and properly raises SSHException
with self.subTest("Check user 'no'"):
with patch("builtins.input", return_value="n"):
self.assertRaises(
SSHException, lambda: policy.missing_host_key(client, example_hostname, key)
)

# Check if AskPolicy properly appends to hostfile
with self.subTest("Check user 'yes'"):
mock_data = mock_open(read_data=KNOWN_HOSTS_EXAMPLE_DATA)
# Read mock hostfile
with patch("pyinfra.connectors.sshuserclient.client.open", mock_data):
with patch("paramiko.hostkeys.open", mock_data):
with patch("builtins.input", return_value="y"):
policy.missing_host_key(client, "new_host", key)

# Assert that we appended correctly to the file
write_call_args = mock_data.return_value.write.call_args
# Ensure we only wrote once and then closed the handle.
assert len(write_call_args) == 2
# Ensure we wrote the correct content
correct_output = f"{example_hostname} {example_keytype} {example_key}\n"
assert write_call_args[0][0] == correct_output

0 comments on commit 78c93ec

Please sign in to comment.