Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mk-armah committed Dec 3, 2024
1 parent 5cc7258 commit 7e5e81e
Showing 1 changed file with 38 additions and 19 deletions.
57 changes: 38 additions & 19 deletions integrations/aws/tests/utils/test_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from aws.aws_credentials import AwsCredentials
from aws.session_manager import SessionManager
from aioboto3 import Session
from port_ocean.context.event import event_context


class TestUpdateAvailableAccessCredentials(unittest.IsolatedAsyncioTestCase):
Expand All @@ -26,17 +27,19 @@ async def _create_iterator_tasks(func: Any, count: int) -> List[Any]:
async def test_multiple_task_execution(
self, mock_lock: AsyncMock, mock_reset: AsyncMock
) -> None:
tasks: List[Any] = await self._create_iterator_tasks(
self._run_update_access_iterator_result, 10
)
async for result in stream_async_iterators_tasks(*tasks):
self.assertTrue(result)

# Assert that the reset method was awaited exactly once (i.e., no thundering herd)
mock_reset.assert_awaited_once()
async with event_context("test_event"):
tasks: List[Any] = await self._create_iterator_tasks(
self._run_update_access_iterator_result, 10
)
async for result in stream_async_iterators_tasks(*tasks):
self.assertTrue(result)

# Assert that the reset method was awaited exactly once (i.e., no thundering herd)
mock_reset.assert_awaited_once()

mock_lock.__aenter__.assert_awaited_once()
mock_lock.__aexit__.assert_awaited_once()
mock_lock.__aenter__.assert_awaited_once()
mock_lock.__aexit__.assert_awaited_once()


class TestAwsSessions(unittest.IsolatedAsyncioTestCase):
Expand All @@ -53,7 +56,9 @@ def tearDown(self) -> None:

async def test_get_sessions_with_custom_account_id(self) -> None:
"""Test get_sessions with a custom account ID and region."""
self.credentials_mock.create_session = AsyncMock(return_value=self.session_mock)
self.credentials_mock.create_refreshable_session = AsyncMock(
return_value=self.session_mock
)

self.session_manager_mock.find_credentials_by_account_id.return_value = (
self.credentials_mock
Expand All @@ -66,12 +71,16 @@ async def test_get_sessions_with_custom_account_id(self) -> None:
)
]

self.credentials_mock.create_session.assert_called_once_with("us-west-2")
self.credentials_mock.create_refreshable_session.assert_called_once_with(
"us-west-2"
)
self.assertEqual(sessions[0], self.session_mock)

async def test_session_factory_with_custom_region(self) -> None:
"""Test session_factory with custom region."""
self.credentials_mock.create_session = AsyncMock(return_value=self.session_mock)
self.credentials_mock.create_refreshable_session = AsyncMock(
return_value=self.session_mock
)
sessions: List[Session] = [
s
async for s in session_factory(
Expand All @@ -81,20 +90,26 @@ async def test_session_factory_with_custom_region(self) -> None:
)
]

self.credentials_mock.create_session.assert_called_once_with("us-east-1")
self.credentials_mock.create_refreshable_session.assert_called_once_with(
"us-east-1"
)
self.assertEqual(sessions[0], self.session_mock)

async def test_get_sessions_with_default_region(self) -> None:
"""Test get_sessions with default region."""
self.credentials_mock.default_regions = ["us-west-1"]
self.credentials_mock.create_session = AsyncMock(return_value=self.session_mock)
self.credentials_mock.create_refreshable_session = AsyncMock(
return_value=self.session_mock
)
self.session_manager_mock._aws_credentials = [self.credentials_mock]

sessions: List[Session] = [
s async for s in get_sessions(use_default_region=True)
]

self.credentials_mock.create_session.assert_called_once_with("us-west-1")
self.credentials_mock.create_refreshable_session.assert_called_once_with(
"us-west-1"
)
self.assertEqual(len(sessions), 1)
self.assertEqual(sessions[0], self.session_mock)

Expand All @@ -106,10 +121,10 @@ async def test_get_sessions_with_multiple_credentials(self) -> None:
self.credentials_mock_1.default_regions = ["us-west-1"]
self.credentials_mock_2.default_regions = ["us-east-1"]

self.credentials_mock_1.create_session = AsyncMock(
self.credentials_mock_1.create_refreshable_session = AsyncMock(
return_value=self.session_mock
)
self.credentials_mock_2.create_session = AsyncMock(
self.credentials_mock_2.create_refreshable_session = AsyncMock(
return_value=self.session_mock
)

Expand All @@ -123,5 +138,9 @@ async def test_get_sessions_with_multiple_credentials(self) -> None:
]

self.assertEqual(len(sessions), 2)
self.credentials_mock_1.create_session.assert_called_once_with("us-west-1")
self.credentials_mock_2.create_session.assert_called_once_with("us-east-1")
self.credentials_mock_1.create_refreshable_session.assert_called_once_with(
"us-west-1"
)
self.credentials_mock_2.create_refreshable_session.assert_called_once_with(
"us-east-1"
)

0 comments on commit 7e5e81e

Please sign in to comment.