Skip to content

Commit

Permalink
Clean up container tests
Browse files Browse the repository at this point in the history
Always use parallel in `complete` tests
Always use serial in `ci` tests
  • Loading branch information
dostuffthatmatters committed Jan 31, 2024
1 parent ceba67a commit 39e8093
Showing 1 changed file with 105 additions and 112 deletions.
217 changes: 105 additions & 112 deletions tests/retrieval/test_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,33 @@ def test_container_lifecycle_ci(
provide_container_factory: src.retrieval.dispatching.container_factory.
ContainerFactory,
) -> None:
_run(
provide_config_template,
provide_container_factory,
only_run_mock_retrieval=True,
)
config = provide_config_template
container_factory = provide_container_factory

_point_config_to_test_data(config)
assert config.retrieval is not None

pending_jobs = _generate_job_list()
print("Running in serial mode")
for i, j in enumerate(pending_jobs):
print(f"#{i}: Spinning up new session")
session = src.retrieval.session.create_session.run(
container_factory,
j[2],
retrieval_algorithm=j[0],
atmospheric_profile_model=j[1],
job_settings=src.types.config.RetrievalJobSettingsConfig(
# test this for all alg/atm combinations
# for one of the sensor data contexts
use_local_pressure_in_pcxs=(
j[2].from_datetime.date() == datetime.date(2017, 6, 9)
),
)
)
print(f"#{i}: Running session")
run_session(session, config, True)
print(f"#{i}: Finished session")
container_factory.remove_container(session.ctn.container_id)


# this test will run the actual retrieval algorithm
Expand All @@ -121,11 +143,79 @@ def test_container_lifecycle_complete(
provide_container_factory: src.retrieval.dispatching.container_factory.
ContainerFactory,
) -> None:
_run(
provide_config_template,
provide_container_factory,
only_run_mock_retrieval=False,
)
config = provide_config_template
container_factory = provide_container_factory

_point_config_to_test_data(config)
assert config.retrieval is not None

pending_jobs = _generate_job_list()
active_processes: list[multiprocessing.Process] = []
finished_processes: list[multiprocessing.Process] = []

cpu_count = multiprocessing.cpu_count()
print(f"Detected {cpu_count} CPU cores")
process_count = max(1, cpu_count - 1)
print(f"Running {process_count} processes in parallel")

# wait for all processes to finish
while True:
while ((len(active_processes) < process_count) and
(len(pending_jobs) > 0)):

j = pending_jobs.pop(0)
print(f"Spinning up new session")
session = src.retrieval.session.create_session.run(
container_factory,
j[2],
retrieval_algorithm=j[0],
atmospheric_profile_model=j[1],
job_settings=src.types.config.RetrievalJobSettingsConfig(
# test this for all alg/atm combinations
# for one of the sensor data contexts
use_local_pressure_in_pcxs=(
j[2].from_datetime.date() == datetime.date(2017, 6, 9)
),
)
)
print(f"Creating new process")
p = multiprocessing.Process(
target=run_session,
args=(session, config, False),
name=(
f"{session.ctn.container_id}:{j[0]}-{j[1]}-" +
f"{j[2].sensor_id}-{j[2].from_datetime.date()}"
),
)
print(f"Starting process {p.name}")
p.start()
active_processes.append(p)
print(f"Started process {p.name}")

time.sleep(0.5)

newly_finished_processes: list[multiprocessing.Process] = []
for p in active_processes:
if not p.is_alive():
newly_finished_processes.append(p)

for p in newly_finished_processes:
print(f"Joining process {p.name}")
p.join()
active_processes.remove(p)
container_factory.remove_container(p.name.split(":")[0])
finished_processes.append(p)
p.close()
print(f"Finished process {p.name}")

if len(active_processes) == 0 and len(pending_jobs) == 0:
break

time.sleep(2)
print(
f"Pending | Active | Finished: {len(pending_jobs)} |" +
f" {len(active_processes)} | {len(finished_processes)}"
)


def run_session(
Expand All @@ -144,15 +234,11 @@ def run_session(
)


def _run(
config: src.types.Config,
container_factory: src.retrieval.dispatching.container_factory.
ContainerFactory,
only_run_mock_retrieval: bool,
) -> None:
_point_config_to_test_data(config)
def _generate_job_list(
) -> list[tuple[src.types.RetrievalAlgorithm, src.types.AtmosphericProfileModel,
em27_metadata.types.SensorDataContext]]:

src.retrieval.utils.retrieval_status.RetrievalStatusList.reset()
assert config.retrieval is not None

pending_jobs: list[tuple[src.types.RetrievalAlgorithm,
src.types.AtmosphericProfileModel,
Expand All @@ -177,100 +263,7 @@ def _run(
f" #{i}: {j[0]} | {j[1]} | {j[2].sensor_id} | {j[2].from_datetime.date()}"
)

ci_env_var = os.getenv("CI", "not set")
print(f'Environment variable "CI" = {ci_env_var}')
run_parallel = (ci_env_var not in ["true", "True", "1", True])

if run_parallel:
active_processes: list[multiprocessing.Process] = []
finished_processes: list[multiprocessing.Process] = []

cpu_count = multiprocessing.cpu_count()
print(f"Detected {cpu_count} CPU cores")
process_count = max(1, cpu_count - 1)
print(f"Running {process_count} processes in parallel")

# wait for all processes to finish
while True:
while ((len(active_processes) < process_count) and
(len(pending_jobs) > 0)):

j = pending_jobs.pop(0)
print(f"Spinning up new session")
session = src.retrieval.session.create_session.run(
container_factory,
j[2],
retrieval_algorithm=j[0],
atmospheric_profile_model=j[1],
job_settings=src.types.config.RetrievalJobSettingsConfig(
# test this for all alg/atm combinations
# for one of the sensor data contexts
use_local_pressure_in_pcxs=(
j[2].from_datetime.date() == datetime.date(
2017, 6, 9
)
),
)
)
print(f"Creating new process")
p = multiprocessing.Process(
target=run_session,
args=(session, config, only_run_mock_retrieval),
name=(
f"{session.ctn.container_id}:{alg}-{atm}-" +
f"{sdc.sensor_id}-{sdc.from_datetime.date()}"
),
)
print(f"Starting process {p.name}")
p.start()
active_processes.append(p)
print(f"Started process {p.name}")

time.sleep(0.5)

newly_finished_processes: list[multiprocessing.Process] = []
for p in active_processes:
if not p.is_alive():
newly_finished_processes.append(p)

for p in newly_finished_processes:
print(f"Joining process {p.name}")
p.join()
active_processes.remove(p)
container_factory.remove_container(p.name.split(":")[0])
finished_processes.append(p)
p.close()
print(f"Finished process {p.name}")

if len(active_processes) == 0 and len(pending_jobs) == 0:
break

time.sleep(2)
print(
f"Pending | Active | Finished: {len(pending_jobs)} |" +
f" {len(active_processes)} | {len(finished_processes)}"
)
else:
print("Running in serial mode")
for i, j in enumerate(pending_jobs):
print(f"#{i}: Spinning up new session")
session = src.retrieval.session.create_session.run(
container_factory,
j[2],
retrieval_algorithm=j[0],
atmospheric_profile_model=j[1],
job_settings=src.types.config.RetrievalJobSettingsConfig(
# test this for all alg/atm combinations
# for one of the sensor data contexts
use_local_pressure_in_pcxs=(
j[2].from_datetime.date() == datetime.date(2017, 6, 9)
),
)
)
print(f"#{i}: Running session")
run_session(session, config, only_run_mock_retrieval)
print(f"#{i}: Finished session")
container_factory.remove_container(session.ctn.container_id)
return pending_jobs


def _point_config_to_test_data(config: src.types.Config) -> None:
Expand Down

0 comments on commit 39e8093

Please sign in to comment.