Skip to content

Commit

Permalink
Merge branch 'dev' into pythonicworkflow
Browse files Browse the repository at this point in the history
Signed-off-by: YunLiu <[email protected]>

Signed-off-by: YunLiu <[email protected]>
  • Loading branch information
KumoLiu committed Nov 25, 2024
2 parents ec202e0 + d94df3f commit 70dc9b5
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 11 deletions.
10 changes: 10 additions & 0 deletions monai/bundle/reference_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,16 @@ def get_resolved_content(self, id: str, **kwargs: Any) -> ConfigExpression | str
"""
return self._resolve_one_item(id=id, **kwargs)

def remove_resolved_content(self, id: str) -> Any | None:
"""
Remove the resolved ``ConfigItem`` by id.
Args:
id: id name of the expected item.
"""
return self.resolved_content.pop(id) if id in self.resolved_content else None

@classmethod
def normalize_id(cls, id: str | int) -> str:
"""
Expand Down
19 changes: 17 additions & 2 deletions monai/bundle/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,23 @@ def check_properties(self) -> list[str] | None:
ret.extend(wrong_props)
return ret

def _run_expr(self, id: str, **kwargs: dict) -> Any:
return self.parser.get_parsed_content(id, **kwargs) if id in self.parser else None
def _run_expr(self, id: str, **kwargs: dict) -> list[Any]:
"""
Evaluate the expression or expression list given by `id`. The resolved values from the evaluations are not stored,
allowing this to be evaluated repeatedly (eg. in streaming applications) without restarting the hosting process.
"""
ret = []
if id in self.parser:
# suppose all the expressions are in a list, run and reset the expressions
if isinstance(self.parser[id], list):
for i in range(len(self.parser[id])):
sub_id = f"{id}{ID_SEP_KEY}{i}"
ret.append(self.parser.get_parsed_content(sub_id, **kwargs))
self.parser.ref_resolver.remove_resolved_content(sub_id)
else:
ret.append(self.parser.get_parsed_content(id, **kwargs))
self.parser.ref_resolver.remove_resolved_content(id)
return ret

def _get_prop_id(self, name: str, property: dict) -> Any:
prop_id = property[BundlePropertyConfig.ID]
Expand Down
45 changes: 38 additions & 7 deletions tests/test_bundle_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from monai.data import Dataset
from monai.inferers import SimpleInferer, SlidingWindowInferer
from monai.networks.nets import UNet
from monai.transforms import Compose, LoadImage, LoadImaged
from monai.transforms import Compose, LoadImage, LoadImaged, SaveImaged
from tests.nonconfig_workflow import NonConfigWorkflow, PythonicWorkflowImpl

TEST_CASE_1 = [os.path.join(os.path.dirname(__file__), "testing_data", "inference.json")]
Expand All @@ -36,6 +36,8 @@

TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")]

TEST_CASE_4 = [os.path.join(os.path.dirname(__file__), "testing_data", "responsive_inference.json")]

TEST_CASE_NON_CONFIG_WRONG_LOG = [None, "logging.conf", "Cannot find the logging config file: logging.conf."]


Expand All @@ -46,9 +48,9 @@ def setUp(self):
self.expected_shape = (128, 128, 128)
test_image = np.random.rand(*self.expected_shape)
self.filename = os.path.join(self.data_dir, "image.nii")
self.filename2 = os.path.join(self.data_dir, "image2.nii")
self.filename1 = os.path.join(self.data_dir, "image1.nii")
nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename)
nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename2)
nib.save(nib.Nifti1Image(test_image, np.eye(4)), self.filename1)

def tearDown(self):
    # remove the temporary directory and any test images/outputs written during the test
    shutil.rmtree(self.data_dir)
Expand Down Expand Up @@ -119,6 +121,35 @@ def test_inference_config(self, config_file):
self._test_inferer(inferer)
self.assertEqual(inferer.workflow_type, "infer")

@parameterized.expand([TEST_CASE_4])
def test_responsive_inference_config(self, config_file):
    """Run the same instantiated bundle workflow twice with different inputs via the `dataflow` dict,
    checking that resolved expressions re-evaluate per run instead of being cached."""
    input_loader = LoadImaged(keys="image")
    output_saver = SaveImaged(keys="pred", output_dir=self.data_dir, output_postfix="seg")

    # test standard MONAI model-zoo config workflow
    inferer = ConfigWorkflow(
        workflow_type="infer",
        config_file=config_file,
        logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"),
    )
    # FIXME: temp add the property for test, we should add it to some formal realtime infer properties
    inferer.add_property(name="dataflow", required=True, config_id="dataflow")

    inferer.initialize()
    # feed the first image into the workflow through the shared `dataflow` dict
    inferer.dataflow.update(input_loader({"image": self.filename}))
    inferer.run()
    output_saver(inferer.dataflow)
    # SaveImaged writes to <output_dir>/<basename>/<basename>_<postfix>.nii.gz
    self.assertTrue(os.path.exists(os.path.join(self.data_dir, "image", "image_seg.nii.gz")))

    # bundle is instantiated and idle, just change the input for next inference
    inferer.dataflow.clear()
    inferer.dataflow.update(input_loader({"image": self.filename1}))
    inferer.run()
    output_saver(inferer.dataflow)
    self.assertTrue(os.path.exists(os.path.join(self.data_dir, "image1", "image1_seg.nii.gz")))

    inferer.finalize()

@parameterized.expand([TEST_CASE_3])
def test_train_config(self, config_file):
# test standard MONAI model-zoo config workflow
Expand Down Expand Up @@ -187,11 +218,11 @@ def test_pythonic_workflow(self):
self.assertEqual(workflow.inferer.roi_size, (64, 64, 32))
workflow.run()
# update input data and run again
workflow.dataflow.update(input_loader({"image": self.filename2}))
workflow.dataflow.update(input_loader({"image": self.filename1}))
workflow.run()
pred = workflow.dataflow["pred"]
self.assertEqual(pred.shape[2:], self.expected_shape)
self.assertEqual(pred.meta["filename_or_obj"], self.filename2)
self.assertEqual(pred.meta["filename_or_obj"], self.filename1)
workflow.finalize()

def test_create_pythonic_workflow(self):
Expand Down Expand Up @@ -223,11 +254,11 @@ def test_create_pythonic_workflow(self):

workflow.run()
# update input data and run again
workflow.dataflow.update(input_loader({"image": self.filename2}))
workflow.dataflow.update(input_loader({"image": self.filename1}))
workflow.run()
pred = workflow.dataflow["pred"]
self.assertEqual(pred.shape[2:], self.expected_shape)
self.assertEqual(pred.meta["filename_or_obj"], self.filename2)
self.assertEqual(pred.meta["filename_or_obj"], self.filename1)

# test add properties
workflow.add_property(name="net", required=True, desc="network for the training.")
Expand Down
8 changes: 6 additions & 2 deletions tests/test_module_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,17 @@ def test_transform_api(self):
continue
with self.subTest(n=n):
basename = n[:-1] # Transformd basename is Transform

# remove aliases to check, do this before the assert below so that a failed assert does skip this
for postfix in ("D", "d", "Dict"):
remained.remove(f"{basename}{postfix}")

for docname in (f"{basename}", f"{basename}d"):
if docname in to_exclude_docs:
continue
if (contents is not None) and f"`{docname}`" not in f"{contents}":
self.assertTrue(False, f"please add `{docname}` to docs/source/transforms.rst")
for postfix in ("D", "d", "Dict"):
remained.remove(f"{basename}{postfix}")

self.assertFalse(remained)


Expand Down
101 changes: 101 additions & 0 deletions tests/testing_data/responsive_inference.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
{
"imports": [
"$from collections import defaultdict"
],
"bundle_root": "will override",
"device": "$torch.device('cpu')",
"network_def": {
"_target_": "UNet",
"spatial_dims": 3,
"in_channels": 1,
"out_channels": 2,
"channels": [
2,
2,
4,
8,
4
],
"strides": [
2,
2,
2,
2
],
"num_res_units": 2,
"norm": "batch"
},
"network": "$@network_def.to(@device)",
"dataflow": "$defaultdict()",
"preprocessing": {
"_target_": "Compose",
"transforms": [
{
"_target_": "EnsureChannelFirstd",
"keys": "image"
},
{
"_target_": "ScaleIntensityd",
"keys": "image"
},
{
"_target_": "RandRotated",
"_disabled_": true,
"keys": "image"
}
]
},
"dataset": {
"_target_": "Dataset",
"data": [
"@dataflow"
],
"transform": "@preprocessing"
},
"dataloader": {
"_target_": "DataLoader",
"dataset": "@dataset",
"batch_size": 1,
"shuffle": false,
"num_workers": 0
},
"inferer": {
"_target_": "SlidingWindowInferer",
"roi_size": [
64,
64,
32
],
"sw_batch_size": 4,
"overlap": 0.25
},
"postprocessing": {
"_target_": "Compose",
"transforms": [
{
"_target_": "Activationsd",
"keys": "pred",
"softmax": true
},
{
"_target_": "AsDiscreted",
"keys": "pred",
"argmax": true
}
]
},
"evaluator": {
"_target_": "SupervisedEvaluator",
"device": "@device",
"val_data_loader": "@dataloader",
"network": "@network",
"inferer": "@inferer",
"postprocessing": "@postprocessing",
"amp": false,
"epoch_length": 1
},
"run": [
"[email protected]()",
"[email protected](@evaluator.state.output[0])"
]
}

0 comments on commit 70dc9b5

Please sign in to comment.