Skip to content

Commit

Permalink
Fix for pyinfra overwriting hostfiles (pyinfra-dev#1209)
Browse files Browse the repository at this point in the history
Instead of relying on paramiko to correctly parse the full
hostfile, this fix makes it so that we only append new keys to
the hostfile rather than overwriting it completely.

The old behaviour results in unwanted alterations to the hostfile,
such as hostfile entries not parseable by paramiko disappearing
without users consent, and comments being removed.

I also split out the logic for appending hostkeys into a new function
instead of having the code be unnecessarily repeated in multiple places.
This makes the MissingHostKeyPolicy children a bit more readable IMO.
  • Loading branch information
evoldstad committed Oct 1, 2024
1 parent 2fb4fb0 commit 3553df5
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions pyinfra/connectors/sshuserclient/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,24 @@ def missing_host_key(self, client, hostname, key):
)


def append_hostkey(client, hostname, key):
"""Append hostname to the clients host_keys_file"""

with HOST_KEYS_LOCK:
host_keys = client.get_host_keys()
host_keys.add(hostname, key.get_name(), key)
# The paramiko client saves host keys incorrectly whereas the host keys object does
# this correctly, so use that with the client filename variable.
# See: https://github.com/paramiko/paramiko/pull/1989
host_key_entry = host_keys.get(hostname).to_line()
if host_key_entry is None:
raise SSHException(
"Append Hostkey: Failed to parse host {0}, could not append to hostfile".format(hostname),
)
with open(client._host_keys_filename, "a") as host_keys_file:
host_keys_file.write(host_key_entry.to_line())


class AcceptNewPolicy(MissingHostKeyPolicy):
def missing_host_key(self, client, hostname, key):
logger.warning(
Expand All @@ -40,13 +58,8 @@ def missing_host_key(self, client, hostname, key):
),
)

with HOST_KEYS_LOCK:
host_keys = client.get_host_keys()
host_keys.add(hostname, key.get_name(), key)
# The paramiko client saves host keys incorrectly whereas the host keys object does
# this correctly, so use that with the client filename variable.
# See: https://github.com/paramiko/paramiko/pull/1989
host_keys.save(client._host_keys_filename)
append_hostkey(client, hostname, key)
logger.warning("Added host key for {0} to known_hosts".format(hostname))


class AskPolicy(MissingHostKeyPolicy):
Expand All @@ -60,13 +73,7 @@ def missing_host_key(self, client, hostname, key):
raise SSHException(
"AskPolicy: No host key for {0} found in known_hosts".format(hostname),
)
with HOST_KEYS_LOCK:
host_keys = client.get_host_keys()
host_keys.add(hostname, key.get_name(), key)
# The paramiko client saves host keys incorrectly whereas the host keys object does
# this correctly, so use that with the client filename variable.
# See: https://github.com/paramiko/paramiko/pull/1989
host_keys.save(client._host_keys_filename)
append_hostkey(client, hostname, key)
logger.warning("Added host key for {0} to known_hosts".format(hostname))
return

Expand Down

0 comments on commit 3553df5

Please sign in to comment.