Skip to content

Commit

Permalink
Merge pull request #147 from NASA-AMMOS/aerie-2.18.0
Browse files Browse the repository at this point in the history
Aerie 2.18.0
  • Loading branch information
cartermak authored Dec 4, 2024
2 parents ff4b2ec + a313db9 commit 6b413ba
Show file tree
Hide file tree
Showing 9 changed files with 157 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .env
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
DOCKER_TAG=v2.11.0
DOCKER_TAG=v2.18.0
REPOSITORY_DOCKER_URL=ghcr.io/nasa-ammos

AERIE_USERNAME=aerie
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
strategy:
matrix:
python-version: ["3.6.15", "3.11"]
aerie-version: ["2.11.0"]
aerie-version: ["2.18.0"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
Expand Down
2 changes: 1 addition & 1 deletion src/aerie_cli/aerie_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1862,7 +1862,7 @@ def __expand_activity_arguments(self, plan: ActivityPlanRead, full_args: str = N
for activity in plan.activities:
if expand_all or activity.type in expand_types:
query = """
query ($args: ActivityArguments!, $act_type: String!, $model_id: ID!) {
query ($args: ActivityArguments!, $act_type: String!, $model_id: Int!) {
getActivityEffectiveArguments(
activityArguments: $args,
activityTypeName: $act_type,
Expand Down
43 changes: 40 additions & 3 deletions src/aerie_cli/aerie_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@

from attrs import define, field

COMPATIBLE_AERIE_VERSIONS = [
"2.18.0"
]

class AerieHostVersionError(RuntimeError):
pass


def process_gateway_response(resp: requests.Response) -> dict:
"""Throw a RuntimeError if the Gateway response is malformed or contains errors
Expand All @@ -18,12 +25,12 @@ def process_gateway_response(resp: requests.Response) -> dict:
dict: Contents of response JSON
"""
if not resp.ok:
raise RuntimeError(f"Bad response from Aerie Gateway.")
raise RuntimeError("Bad response from Aerie Gateway")

try:
resp_json = resp.json()
except requests.exceptions.JSONDecodeError:
raise RuntimeError(f"Failed to get response JSON")
raise RuntimeError("Bad response from Aerie Gateway")

if "success" in resp_json.keys() and not resp_json["success"]:
raise RuntimeError(f"Aerie Gateway request was not successful")
Expand Down Expand Up @@ -260,7 +267,15 @@ def is_auth_enabled(self) -> bool:

return True

def authenticate(self, username: str, password: str = None):
def authenticate(self, username: str, password: str = None, force: bool = False):

try:
self.check_aerie_version()
except AerieHostVersionError as e:
if force:
print("Warning: " + e.args[0])
else:
raise

resp = self.session.post(
self.gateway_url + "/auth/login",
Expand All @@ -278,6 +293,28 @@ def authenticate(self, username: str, password: str = None):
if not self.check_auth():
raise RuntimeError(f"Failed to open session")

def check_aerie_version(self) -> None:
"""Assert that the Aerie host is a compatible version
Raises a `RuntimeError` if the host appears to be incompatible.
"""

resp = self.session.get(self.gateway_url + "/version")

try:
resp_json = process_gateway_response(resp)
host_version = resp_json["version"]
except (RuntimeError, KeyError):
# If the Gateway responded, the route doesn't exist
if resp.text and "Aerie Gateway" in resp.text:
raise AerieHostVersionError("Incompatible Aerie version: host version unknown")

# Otherwise, it could just be a failed connection
raise

if host_version not in COMPATIBLE_AERIE_VERSIONS:
raise AerieHostVersionError(f"Incompatible Aerie version: {host_version}")


@define
class ExternalAuthConfiguration:
Expand Down
5 changes: 3 additions & 2 deletions src/aerie_cli/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ def activate_session(
),
role: str = typer.Option(
None, "--role", "-r", help="Specify a non-default role", metavar="ROLE"
)
),
force: bool = typer.Option(False, "--force", help="Force connection to Aerie host and ignore version compatibility")
):
"""
Activate a session with an Aerie host using a given configuration
Expand All @@ -102,7 +103,7 @@ def activate_session(

conf = PersistentConfigurationManager.get_configuration_by_name(name)

session = start_session_from_configuration(conf, username)
session = start_session_from_configuration(conf, username, force=force)

if role is not None:
if role in session.aerie_jwt.allowed_roles:
Expand Down
6 changes: 4 additions & 2 deletions src/aerie_cli/utils/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ def start_session_from_configuration(
configuration: AerieHostConfiguration,
username: str = None,
password: str = None,
secret_post_vars: Dict[str, str] = None
secret_post_vars: Dict[str, str] = None,
force: bool = False
):
"""Start and authenticate an Aerie Host session, with prompts if necessary
Expand All @@ -136,6 +137,7 @@ def start_session_from_configuration(
username (str, optional): Aerie username.
password (str, optional): Aerie password.
secret_post_vars (Dict[str, str], optional): Optionally provide values for some or all secret post request variable values. Defaults to None.
force (bool, optional): Force connection to Aerie host and ignore version compatibility. Defaults to False.
Returns:
AerieHost:
Expand All @@ -162,6 +164,6 @@ def start_session_from_configuration(
if password is None and hs.is_auth_enabled():
password = typer.prompt("Aerie Password", hide_input=True)

hs.authenticate(username, password)
hs.authenticate(username, password, force)

return hs
2 changes: 1 addition & 1 deletion tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
CONFIGURATIONS_PATH = os.path.join(FILES_PATH, "configuration")
CONFIGURATION_PATH = os.path.join(CONFIGURATIONS_PATH, "localhost_config.json")
MODELS_PATH = os.path.join(FILES_PATH, "models")
MODEL_VERSION = os.environ.get("AERIE_VERSION", "2.11.0")
MODEL_VERSION = os.environ.get("AERIE_VERSION", "2.18.0")
MODEL_JAR = os.path.join(MODELS_PATH, f"banananation-{MODEL_VERSION}.jar")
MODEL_NAME = "banananation"
MODEL_VERSION = "0.0.1"
Expand Down
Binary file not shown.
106 changes: 106 additions & 0 deletions tests/unit_tests/test_aerie_host.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import Dict
import pytest
import requests

from aerie_cli.aerie_host import AerieHost, COMPATIBLE_AERIE_VERSIONS, AerieJWT


class MockJWT:
def __init__(self, *args, **kwargs):
self.default_role = 'viewer'

class MockResponse:
def __init__(self, json: Dict, text: str = None, ok: bool = True) -> None:
self.json_data = json
self.text = text
self.ok = ok

def json(self) -> Dict:
if self.json_data is None:
raise requests.exceptions.JSONDecodeError("", "", 0)
return self.json_data


class MockSession:

def __init__(self, mock_response: MockResponse) -> None:
self.mock_response = mock_response

def get(self, *args, **kwargs) -> MockResponse:
return self.mock_response

def post(self, *args, **kwargs) -> MockResponse:
return self.mock_response


def get_mock_aerie_host(json: Dict = None, text: str = None, ok: bool = True) -> AerieHost:
mock_response = MockResponse(json, text, ok)
mock_session = MockSession(mock_response)
return AerieHost("", "", mock_session)


def test_check_aerie_version():
aerie_host = get_mock_aerie_host(
json={"version": COMPATIBLE_AERIE_VERSIONS[0]})

aerie_host.check_aerie_version()


def test_authenticate_invalid_version(capsys, monkeypatch):
ah = AerieHost("", "")

def mock_get(*_, **__):
return MockResponse({"version": "1.0.0"})
def mock_post(*_, **__):
return MockResponse({"token": ""})
def mock_check_auth(*_, **__):
return True

monkeypatch.setattr(requests.Session, "get", mock_get)
monkeypatch.setattr(requests.Session, "post", mock_post)
monkeypatch.setattr(AerieHost, "check_auth", mock_check_auth)
monkeypatch.setattr(AerieJWT, "__init__", MockJWT.__init__)

with pytest.raises(RuntimeError) as e:
ah.authenticate("")

assert "Incompatible Aerie version: 1.0.0" in str(e.value)


def test_authenticate_invalid_version_force(capsys, monkeypatch):
ah = AerieHost("", "")

def mock_get(*_, **__):
return MockResponse({"version": "1.0.0"})
def mock_post(*_, **__):
return MockResponse({"token": ""})
def mock_check_auth(*_, **__):
return True

monkeypatch.setattr(requests.Session, "get", mock_get)
monkeypatch.setattr(requests.Session, "post", mock_post)
monkeypatch.setattr(AerieHost, "check_auth", mock_check_auth)
monkeypatch.setattr(AerieJWT, "__init__", MockJWT.__init__)

ah.authenticate("", force=True)

assert capsys.readouterr().out == "Warning: Incompatible Aerie version: 1.0.0\n"


def test_no_version_endpoint():
aerie_host = get_mock_aerie_host(text="blah Aerie Gateway blah", ok=True)

with pytest.raises(RuntimeError) as e:
aerie_host.check_aerie_version()

assert "Incompatible Aerie version: host version unknown" in str(e.value)


def test_version_broken_gateway():
aerie_host = get_mock_aerie_host(
text="502 Bad Gateway or something", ok=True)

with pytest.raises(RuntimeError) as e:
aerie_host.check_aerie_version()

assert "Bad response from Aerie Gateway" in str(e.value)

0 comments on commit 6b413ba

Please sign in to comment.