diff --git a/.github/workflows/blossom-ci.yml b/.github/workflows/blossom-ci.yml
index 1d6ee8a46c..63ac5536d8 100644
--- a/.github/workflows/blossom-ci.yml
+++ b/.github/workflows/blossom-ci.yml
@@ -29,13 +29,14 @@ jobs:
args: ${{ env.args }}
# This job only runs for pull request comments
- if: contains('\
- Nic-Ma,\
- wyli,\
- pxLi,\
- YanxuanLiu,\
- KumoLiu,\
- ', format('{0},', github.actor)) && github.event.comment.body == '/build'
+ if: |
+ github.event.comment.body == '/build' &&
+ (
+ github.actor == 'Nic-Ma' ||
+ github.actor == 'wyli' ||
+ github.actor == 'wendell-hom' ||
+ github.actor == 'KumoLiu'
+ )
steps:
- name: Check if comment is issued by authorized person
run: blossom-ci
diff --git a/.github/workflows/conda.yml b/.github/workflows/conda.yml
index 367a24cbde..394685acd3 100644
--- a/.github/workflows/conda.yml
+++ b/.github/workflows/conda.yml
@@ -32,6 +32,10 @@ jobs:
maximum-size: 16GB
disk-root: "D:"
- uses: actions/checkout@v4
+ - name: Clean up disk space
+ run: |
+ find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
+ rm -rf /usr/share/dotnet/
- uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
@@ -56,6 +60,10 @@ jobs:
conda deactivate
- name: Test env (CPU ${{ runner.os }})
shell: bash -el {0}
+ env:
+ NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
+ NGC_ORG: ${{ secrets.NGC_ORG }}
+ NGC_TEAM: ${{ secrets.NGC_TEAM }}
run: |
conda activate monai
$(pwd)/runtests.sh --build --unittests
diff --git a/.github/workflows/cron-ngc-bundle.yml b/.github/workflows/cron-ngc-bundle.yml
index bd45bc8d1e..d4b45e1d55 100644
--- a/.github/workflows/cron-ngc-bundle.yml
+++ b/.github/workflows/cron-ngc-bundle.yml
@@ -18,10 +18,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- - name: Set up Python 3.8
+ - name: Set up Python 3.9
uses: actions/setup-python@v5
with:
- python-version: '3.8'
+ python-version: '3.9'
- name: cache weekly timestamp
id: pip-cache
run: echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml
index 792fda5279..516e2d4743 100644
--- a/.github/workflows/cron.yml
+++ b/.github/workflows/cron.yml
@@ -13,24 +13,28 @@ jobs:
strategy:
matrix:
environment:
- - "PT191+CUDA113"
- "PT110+CUDA113"
- - "PT113+CUDA113"
- - "PTLATEST+CUDA121"
+ - "PT113+CUDA118"
+ - "PT210+CUDA121"
+ - "PT240+CUDA126"
+ - "PTLATEST+CUDA126"
include:
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes
- - environment: PT191+CUDA113
- pytorch: "torch==1.9.1 torchvision==0.10.1 --extra-index-url https://download.pytorch.org/whl/cu113"
- base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3
- environment: PT110+CUDA113
pytorch: "torch==1.10.2 torchvision==0.11.3 --extra-index-url https://download.pytorch.org/whl/cu113"
base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3
- - environment: PT113+CUDA113
- pytorch: "torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu113"
- base: "nvcr.io/nvidia/pytorch:21.06-py3" # CUDA 11.3
- - environment: PTLATEST+CUDA121
- pytorch: "-U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118"
- base: "nvcr.io/nvidia/pytorch:23.08-py3" # CUDA 12.2
+ - environment: PT113+CUDA118
+ pytorch: "torch==1.13.1 torchvision==0.14.1 --extra-index-url https://download.pytorch.org/whl/cu121"
+ base: "nvcr.io/nvidia/pytorch:22.10-py3" # CUDA 11.8
+ - environment: PT210+CUDA121
+ pytorch: "pytorch==2.1.0 torchvision==0.16.0 --extra-index-url https://download.pytorch.org/whl/cu121"
+ base: "nvcr.io/nvidia/pytorch:23.08-py3" # CUDA 12.1
+ - environment: PT240+CUDA126
+ pytorch: "pytorch==2.4.0 torchvision==0.19.0 --extra-index-url https://download.pytorch.org/whl/cu121"
+ base: "nvcr.io/nvidia/pytorch:24.08-py3" # CUDA 12.6
+ - environment: PTLATEST+CUDA126
+ pytorch: "-U torch torchvision --extra-index-url https://download.pytorch.org/whl/cu121"
+ base: "nvcr.io/nvidia/pytorch:24.10-py3" # CUDA 12.6
container:
image: ${{ matrix.base }}
options: "--gpus all"
@@ -50,6 +54,10 @@ jobs:
python -m pip install -r requirements-dev.txt
python -m pip list
- name: Run tests report coverage
+ env:
+ NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
+ NGC_ORG: ${{ secrets.NGC_ORG }}
+ NGC_TEAM: ${{ secrets.NGC_TEAM }}
run: |
export LAUNCH_DELAY=$[ $RANDOM % 16 * 60 ]
echo "Sleep $LAUNCH_DELAY"
@@ -76,7 +84,7 @@ jobs:
if: github.repository == 'Project-MONAI/MONAI'
strategy:
matrix:
- container: ["pytorch:22.10", "pytorch:23.08"]
+ container: ["pytorch:23.08", "pytorch:24.08", "pytorch:24.10"]
container:
image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image
options: "--gpus all"
@@ -94,6 +102,10 @@ jobs:
python -m pip install -r requirements-dev.txt
python -m pip list
- name: Run tests report coverage
+ env:
+ NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
+ NGC_ORG: ${{ secrets.NGC_ORG }}
+ NGC_TEAM: ${{ secrets.NGC_TEAM }}
run: |
export LAUNCH_DELAY=$[ $RANDOM % 16 * 60 ]
echo "Sleep $LAUNCH_DELAY"
@@ -121,7 +133,7 @@ jobs:
if: github.repository == 'Project-MONAI/MONAI'
strategy:
matrix:
- container: ["pytorch:23.08"]
+ container: ["pytorch:24.10"]
container:
image: nvcr.io/nvidia/${{ matrix.container }}-py3 # testing with the latest pytorch base image
options: "--gpus all"
@@ -196,6 +208,10 @@ jobs:
- name: Run tests report coverage
# The docker image process has done the compilation.
# BUILD_MONAI=1 is necessary for triggering the USE_COMPILED flag.
+ env:
+ NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
+ NGC_ORG: ${{ secrets.NGC_ORG }}
+ NGC_TEAM: ${{ secrets.NGC_TEAM }}
run: |
cd /opt/monai
nvidia-smi
@@ -221,7 +237,7 @@ jobs:
if: github.repository == 'Project-MONAI/MONAI'
needs: cron-gpu # so that monai itself is verified first
container:
- image: nvcr.io/nvidia/pytorch:23.08-py3 # testing with the latest pytorch base image
+ image: nvcr.io/nvidia/pytorch:24.10-py3 # testing with the latest pytorch base image
options: "--gpus all --ipc=host"
runs-on: [self-hosted, linux, x64, integration]
steps:
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index 65716f86f9..17ffe4cf90 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -100,3 +100,6 @@ jobs:
shell: bash
env:
QUICKTEST: True
+ NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
+ NGC_ORG: ${{ secrets.NGC_ORG }}
+ NGC_TEAM: ${{ secrets.NGC_TEAM }}
diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index c82530a551..5be2ebb86c 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -68,6 +68,9 @@ jobs:
shell: bash
env:
BUILD_MONAI: 1
+ NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
+ NGC_ORG: ${{ secrets.NGC_ORG }}
+ NGC_TEAM: ${{ secrets.NGC_TEAM }}
run: ./runtests.sh --build --net
- name: Add reaction
diff --git a/.github/workflows/pythonapp-gpu.yml b/.github/workflows/pythonapp-gpu.yml
index a6d7981814..70c3153076 100644
--- a/.github/workflows/pythonapp-gpu.yml
+++ b/.github/workflows/pythonapp-gpu.yml
@@ -29,10 +29,6 @@ jobs:
- "PT210+CUDA121DOCKER"
include:
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes
- - environment: PT19+CUDA114DOCKER
- # 21.10: 1.10.0a0+0aef44c
- pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error
- base: "nvcr.io/nvidia/pytorch:21.10-py3"
- environment: PT110+CUDA111
pytorch: "torch==1.10.2 torchvision==0.11.3 --extra-index-url https://download.pytorch.org/whl/cu111"
base: "nvcr.io/nvidia/cuda:11.1.1-devel-ubuntu18.04"
@@ -47,6 +43,10 @@ jobs:
# 23.08: 2.1.0a0+29c30b1
pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error
base: "nvcr.io/nvidia/pytorch:23.08-py3"
+ - environment: PT210+CUDA121DOCKER
+ # 24.08: 2.3.0a0+40ec155e58.nv24.3
+ pytorch: "-h" # we explicitly set pytorch to -h to avoid pip install error
+ base: "nvcr.io/nvidia/pytorch:24.08-py3"
container:
image: ${{ matrix.base }}
options: --gpus all --env NVIDIA_DISABLE_REQUIRE=true # workaround for unsatisfied condition: cuda>=11.6
@@ -62,7 +62,7 @@ jobs:
if [ ${{ matrix.environment }} = "PT110+CUDA111" ] || \
[ ${{ matrix.environment }} = "PT113+CUDA116" ]
then
- PYVER=3.8 PYSFX=3 DISTUTILS=python3-distutils && \
+ PYVER=3.9 PYSFX=3 DISTUTILS=python3-distutils && \
apt-get update && apt-get install -y --no-install-recommends \
curl \
pkg-config \
diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml
index bbe7579774..b0d37937e9 100644
--- a/.github/workflows/pythonapp-min.yml
+++ b/.github/workflows/pythonapp-min.yml
@@ -9,6 +9,8 @@ on:
- main
- releasing/*
pull_request:
+ head_ref-ignore:
+ - dev
concurrency:
# automatically cancel the previously triggered workflows when there's a newer version
@@ -29,10 +31,10 @@ jobs:
timeout-minutes: 40
steps:
- uses: actions/checkout@v4
- - name: Set up Python 3.8
+ - name: Set up Python 3.9
uses: actions/setup-python@v5
with:
- python-version: '3.8'
+ python-version: '3.9'
- name: Prepare pip wheel
run: |
which python
@@ -65,13 +67,16 @@ jobs:
shell: bash
env:
QUICKTEST: True
+ NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
+ NGC_ORG: ${{ secrets.NGC_ORG }}
+ NGC_TEAM: ${{ secrets.NGC_TEAM }}
min-dep-py3: # min dependencies installed tests for different python
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- python-version: ['3.8', '3.9', '3.10', '3.11']
+ python-version: ['3.9', '3.10', '3.11', '3.12']
timeout-minutes: 40
steps:
- uses: actions/checkout@v4
@@ -110,20 +115,23 @@ jobs:
./runtests.sh --min
env:
QUICKTEST: True
+ NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
+ NGC_ORG: ${{ secrets.NGC_ORG }}
+ NGC_TEAM: ${{ secrets.NGC_TEAM }}
min-dep-pytorch: # min dependencies installed tests for different pytorch
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
- pytorch-version: ['1.9.1', '1.10.2', '1.11.0', '1.12.1', '1.13', 'latest']
+ pytorch-version: ['1.10.2', '1.11.0', '1.12.1', '1.13', '2.0.1', 'latest']
timeout-minutes: 40
steps:
- uses: actions/checkout@v4
- - name: Set up Python 3.8
+ - name: Set up Python 3.9
uses: actions/setup-python@v5
with:
- python-version: '3.8'
+ python-version: '3.9'
- name: Prepare pip wheel
run: |
which python
@@ -159,3 +167,6 @@ jobs:
./runtests.sh --min
env:
QUICKTEST: True
+ NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
+ NGC_ORG: ${{ secrets.NGC_ORG }}
+ NGC_TEAM: ${{ secrets.NGC_TEAM }}
diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml
index b7f2cfb9db..3c39166c1e 100644
--- a/.github/workflows/pythonapp.yml
+++ b/.github/workflows/pythonapp.yml
@@ -9,6 +9,8 @@ on:
- main
- releasing/*
pull_request:
+ head_ref-ignore:
+ - dev
concurrency:
# automatically cancel the previously triggered workflows when there's a newer version
@@ -27,10 +29,10 @@ jobs:
opt: ["codeformat", "pytype", "mypy"]
steps:
- uses: actions/checkout@v4
- - name: Set up Python 3.8
+ - name: Set up Python 3.9
uses: actions/setup-python@v5
with:
- python-version: '3.8'
+ python-version: '3.9'
- name: cache weekly timestamp
id: pip-cache
run: |
@@ -43,6 +45,7 @@ jobs:
key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
- name: Install dependencies
run: |
+ find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
python -m pip install --upgrade pip wheel
python -m pip install -r requirements-dev.txt
- name: Lint and type check
@@ -68,10 +71,10 @@ jobs:
maximum-size: 16GB
disk-root: "D:"
- uses: actions/checkout@v4
- - name: Set up Python 3.8
+ - name: Set up Python 3.9
uses: actions/setup-python@v5
with:
- python-version: '3.8'
+ python-version: '3.9'
- name: Prepare pip wheel
run: |
which python
@@ -96,8 +99,10 @@ jobs:
name: Install itk pre-release (Linux only)
run: |
python -m pip install --pre -U itk
+ find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
- name: Install the dependencies
run: |
+ python -m pip install --user --upgrade pip wheel
python -m pip install torch==1.13.1 torchvision==0.14.1
cat "requirements-dev.txt"
python -m pip install -r requirements-dev.txt
@@ -127,10 +132,10 @@ jobs:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- - name: Set up Python 3.8
+ - name: Set up Python 3.9
uses: actions/setup-python@v5
with:
- python-version: '3.8'
+ python-version: '3.9'
- name: cache weekly timestamp
id: pip-cache
run: |
@@ -145,7 +150,8 @@ jobs:
key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
- name: Install dependencies
run: |
- python -m pip install --user --upgrade pip setuptools wheel twine
+ find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
+ python -m pip install --user --upgrade pip setuptools wheel twine packaging
# install the latest pytorch for testing
# however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated
# fresh torch installation according to pyproject.toml
@@ -208,10 +214,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- - name: Set up Python 3.8
+ - name: Set up Python 3.9
uses: actions/setup-python@v5
with:
- python-version: '3.8'
+ python-version: '3.9'
- name: cache weekly timestamp
id: pip-cache
run: |
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index c134724665..cb0e109bb7 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -24,7 +24,7 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install setuptools
run: |
- python -m pip install --user --upgrade setuptools wheel
+ python -m pip install --user --upgrade setuptools wheel packaging
- name: Build and test source archive and wheel file
run: |
find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
@@ -78,12 +78,13 @@ jobs:
rm dist/monai*.tar.gz
ls -al dist/
- - if: matrix.python-version == '3.9' && startsWith(github.ref, 'refs/tags/')
- name: Publish to Test PyPI
- uses: pypa/gh-action-pypi-publish@release/v1
- with:
- password: ${{ secrets.TEST_PYPI }}
- repository-url: https://test.pypi.org/legacy/
+ # remove publishing to Test PyPI as it is moved to blossom
+ # - if: matrix.python-version == '3.9' && startsWith(github.ref, 'refs/tags/')
+ # name: Publish to Test PyPI
+ # uses: pypa/gh-action-pypi-publish@release/v1
+ # with:
+ # password: ${{ secrets.TEST_PYPI }}
+ # repository-url: https://test.pypi.org/legacy/
versioning:
# compute versioning file from python setup.py
@@ -104,7 +105,7 @@ jobs:
run: |
find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
git describe
- python -m pip install --user --upgrade setuptools wheel
+ python -m pip install --user --upgrade setuptools wheel packaging
python setup.py build
cat build/lib/monai/_version.py
- name: Upload version
@@ -119,7 +120,8 @@ jobs:
rm -rf {*,.[^.]*}
release_tag_docker:
- if: github.repository == 'Project-MONAI/MONAI'
+ # if: github.repository == 'Project-MONAI/MONAI'
+ if: ${{ false }}
needs: versioning
runs-on: ubuntu-latest
steps:
diff --git a/.github/workflows/setupapp.yml b/.github/workflows/setupapp.yml
index c6ad243b81..7e01f55cd9 100644
--- a/.github/workflows/setupapp.yml
+++ b/.github/workflows/setupapp.yml
@@ -49,6 +49,10 @@ jobs:
python -m pip install --upgrade torch torchvision
python -m pip install -r requirements-dev.txt
- name: Run unit tests report coverage
+ env:
+ NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
+ NGC_ORG: ${{ secrets.NGC_ORG }}
+ NGC_TEAM: ${{ secrets.NGC_TEAM }}
run: |
python -m pip list
git config --global --add safe.directory /__w/MONAI/MONAI
@@ -77,7 +81,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
- python-version: ['3.8', '3.9', '3.10']
+ python-version: ['3.9', '3.10', '3.11']
steps:
- uses: actions/checkout@v4
with:
@@ -104,6 +108,10 @@ jobs:
python -m pip install --upgrade pip wheel
python -m pip install -r requirements-dev.txt
- name: Run quick tests CPU ubuntu
+ env:
+ NGC_API_KEY: ${{ secrets.NGC_API_KEY }}
+ NGC_ORG: ${{ secrets.NGC_ORG }}
+ NGC_TEAM: ${{ secrets.NGC_TEAM }}
run: |
python -m pip list
python -c 'import torch; print(torch.__version__); print(torch.rand(5,3))'
@@ -119,10 +127,10 @@ jobs:
install: # pip install from github url, the default branch is dev
runs-on: ubuntu-latest
steps:
- - name: Set up Python 3.8
+ - name: Set up Python 3.9
uses: actions/setup-python@v5
with:
- python-version: '3.8'
+ python-version: '3.9'
- name: cache weekly timestamp
id: pip-cache
run: |
diff --git a/.github/workflows/weekly-preview.yml b/.github/workflows/weekly-preview.yml
index e94e1dac5a..f89e0a11c4 100644
--- a/.github/workflows/weekly-preview.yml
+++ b/.github/workflows/weekly-preview.yml
@@ -5,6 +5,39 @@ on:
- cron: "0 2 * * 0" # 02:00 of every Sunday
jobs:
+ flake8-py3:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ opt: ["codeformat", "pytype", "mypy"]
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python 3.9
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.9'
+ - name: cache weekly timestamp
+ id: pip-cache
+ run: |
+ echo "datew=$(date '+%Y-%V')" >> $GITHUB_OUTPUT
+ - name: cache for pip
+ uses: actions/cache@v4
+ id: cache
+ with:
+ path: ~/.cache/pip
+ key: ${{ runner.os }}-pip-${{ steps.pip-cache.outputs.datew }}
+ - name: Install dependencies
+ run: |
+ find /opt/hostedtoolcache/* -maxdepth 0 ! -name 'Python' -exec rm -rf {} \;
+ python -m pip install --upgrade pip wheel
+ python -m pip install -r requirements-dev.txt
+ - name: Lint and type check
+ run: |
+ # clean up temporary files
+ $(pwd)/runtests.sh --build --clean
+ # Github actions have 2 cores, so parallelize pytype
+ $(pwd)/runtests.sh --build --${{ matrix.opt }} -j 2
+
packaging:
if: github.repository == 'Project-MONAI/MONAI'
runs-on: ubuntu-latest
@@ -19,7 +52,7 @@ jobs:
python-version: '3.9'
- name: Install setuptools
run: |
- python -m pip install --user --upgrade setuptools wheel
+ python -m pip install --user --upgrade setuptools wheel packaging
- name: Build distribution
run: |
export HEAD_COMMIT_ID=$(git rev-parse HEAD)
@@ -33,7 +66,7 @@ jobs:
export YEAR_WEEK=$(date +'%y%U')
echo "Year week for tag is ${YEAR_WEEK}"
if ! [[ $YEAR_WEEK =~ ^[0-9]{4}$ ]] ; then echo "Wrong 'year week' format. Should be 4 digits."; exit 1 ; fi
- git tag "1.4.dev${YEAR_WEEK}"
+ git tag "1.5.dev${YEAR_WEEK}"
git log -1
git tag --list
python setup.py sdist bdist_wheel
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 14b41bbeb8..2a57fbf31a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,7 +9,7 @@ ci:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.4.0
+ rev: v5.0.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
@@ -26,39 +26,36 @@ repos:
args: ['--autofix', '--no-sort-keys', '--indent=4']
- id: end-of-file-fixer
- id: mixed-line-ending
- - repo: https://github.com/charliermarsh/ruff-pre-commit
- rev: v0.0.261
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.7.0
hooks:
- id: ruff
args:
- --fix
- repo: https://github.com/asottile/pyupgrade
- rev: v3.3.1
+ rev: v3.19.0
hooks:
- id: pyupgrade
- args: [--py37-plus]
- name: Upgrade code excluding monai networks
+ args: [--py39-plus, --keep-runtime-typing]
+ name: Upgrade code with exceptions
exclude: |
(?x)(
^versioneer.py|
^monai/_version.py|
- ^monai/networks/| # no PEP 604 for torchscript tensorrt
- ^monai/losses/ # no PEP 604 for torchscript tensorrt
+ ^monai/networks/| # avoid typing rewrites
+ ^monai/apps/detection/utils/anchor_utils.py| # avoid typing rewrites
+ ^tests/test_compute_panoptic_quality.py # avoid typing rewrites
)
- - id: pyupgrade
- args: [--py37-plus, --keep-runtime-typing]
- name: Upgrade monai networks
- files: (?x)(^monai/networks/)
- repo: https://github.com/asottile/yesqa
- rev: v1.4.0
+ rev: v1.5.0
hooks:
- id: yesqa
name: Unused noqa
additional_dependencies:
- flake8>=3.8.1
- - flake8-bugbear
+ - flake8-bugbear<=24.2.6
- flake8-comprehensions
- pep8-naming
exclude: |
@@ -69,7 +66,7 @@ repos:
)$
- repo: https://github.com/hadialqattan/pycln
- rev: v2.1.3
+ rev: v2.4.0
hooks:
- id: pycln
args: [--config=pyproject.toml]
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 61be8f07c1..ffd773727f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,195 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [Unreleased]
+## [1.4.0] - 2024-10-17
+## What's Changed
+### Added
+* Implemented Conjugate Gradient Solver to generate confidence maps. (#7876)
+* Added norm parameter to `ResNet` (#7752, #7805)
+* Introduced alpha parameter to `DiceFocalLoss` for improved flexibility (#7841)
+* Integrated Tailored ControlNet Implementations (#7875)
+* Integrated Tailored Auto-Encoder Model (#7861)
+* Integrated Tailored Diffusion U-Net Model (#7867)
+* Added Maisi morphological functions (#7893)
+* Added support for downloading bundles from NGC private registry (#7907, #7929, #8076)
+* Integrated generative refactor into the core (#7886, #7962)
+* Made `ViT` and `UNETR` models compatible with TorchScript (#7937)
+* Implemented post-download checks for MONAI bundles and compatibility warnings (#7938)
+* Added NGC prefix argument when downloading bundles (#7974)
+* Added flash attention support in the attention block for improved performance (#7977)
+* Enhanced `MLPBlock` for compatibility with VISTA-3D (#7995)
+* Added support for Neighbor-Aware Calibration Loss (NACL) for calibrated models in segmentation tasks (#7819)
+* Added label_smoothing parameter to `DiceCELoss` for enhanced model calibration (#8000)
+* Add `include_fc` and `use_combined_linear` argument in the `SABlock` (#7996)
+* Added utilities, networks, and an inferer specific to VISTA-3D (#7999, #7987, #8047, #8059, #8021)
+* Integrated a new network, `CellSamWrapper`, for cell-based applications (#7981)
+* Introduced `WriteFileMapping` transform to map between input image paths and their corresponding output paths (#7769)
+* Added `TrtHandler` to accelerate models using TensorRT (#7990, #8064)
+* Added box and points conversion transforms for more flexible spatial manipulation (#8053)
+* Enhanced `RandSimulateLowResolutiond` transform with deterministic support (#8057)
+* Added a contiguous argument to the `Fourier` class to facilitate contiguous tensor outputs (#7969)
+* Allowed `ApplyTransformToPointsd` to receive a sequence of reference keys for more versatile point manipulation (#8063)
+* Made printing of `MetaTensor` metadata optional in `DataStats` and `DataStatsd` for more concise logging (#7814)
+#### misc.
+* Refactored Dataset to utilize Compose for handling transforms. (#7784)
+* Combined `map_classes_to_indices` and `generate_label_classes_crop_centers` into a unified function (#7712)
+* Introduced metadata schema directly into the codebase for improved structure and validation (#7409)
+* Renamed `optional_packages_version` to `required_packages_version` for clearer package dependency management. (#7253)
+* Replaced `pkg_resources` with the more modern packaging module for package handling (#7953)
+* Refactored MAISI-related networks to align with the new generative components (#7989, #7993, #8005)
+* Added a badge displaying monthly download statistics to enhance project visibility (#7891)
+### Fixed
+#### transforms
+* Ensured deterministic behavior in `MixUp`, `CutMix`, and `CutOut` transforms (#7813)
+* Applied a minor correction to `AsDiscrete` transform (#7984)
+* Fixed handling of integer weightmaps in `RandomWeightedCrop` (#8097)
+* Resolved data type bug in `ScaleIntensityRangePercentile` (#8109)
+#### data
+* Fixed negative strides issue in the `NrrdReader` (#7809)
+* Addressed wsireader issue with retrieving MPP (#7921)
+* Ensured location is returned as a tuple in wsireader (#8007)
+* Corrected interpretation of space directions in nrrd reader (#8091)
+#### metrics and losses
+* Improved memory management for `NACLLoss` (#8020)
+* Fixed reduction logic in `GeneralizedDiceScore` (#7970)
+#### networks
+* Resolved issue with loading pre-trained weights in `ResNet` (#7924)
+* Fixed error where `torch.device` object had no attribute gpu_id during TensorRT export (#8019)
+* Corrected function for loading older weights in `DiffusionModelUNet` (#8031)
+* Switched to `torch_tensorrt.Device` instead of `torch.device` during TensorRT compilation (#8051)
+#### engines and handlers
+* Attempted to resolve the "experiment already exists" issue in `MLFlowHandler` (#7916)
+* Refactored the model export process for conversion and saving (#7934)
+#### misc.
+* Adjusted requirements to exclude Numpy version 2.0 (#7859)
+* Updated deprecated `scipy.ndimage` namespaces in optional imports (#7847, #7897)
+* Resolved `load_module()` deprecation in Python 3.12 (#7881)
+* Fixed Ruff type check issues (#7885)
+* Cleaned disk space in the conda test pipeline (#7902)
+* Replaced deprecated `pkgutil.find_loader` usage (#7906)
+* Enhanced docstrings in various modules (#7913, #8055)
+* Fixed various test cases (#7905, #7794, #7808)
+* Fixed mypy issue introduced in 1.11.0 (#7941)
+* Cleaned up warnings during test collection (#7914)
+* Fixed incompatible types in assignment issue (#7950)
+* Fixed outdated link in the docs (#7971)
+* Addressed CI issues (#7983, #8013)
+* Fixed an issue where a module could not be imported correctly (#8015)
+* Fixed AttributeError when using torch.min and max (#8041)
+* Ensured synchronization by adding `cuda.synchronize` (#8058)
+* Ignored warning from nptyping as a workaround (#8062)
+* Suppressed deprecation warning when importing monai (#8067)
+* Fixed link in test bundle under MONAI-extra-test-data (#8092)
+### Changed
+* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:24.08-py3` from `nvcr.io/nvidia/pytorch:23.08-py3`
+* Changed blossom-ci to the ACL security format (#7843)
+* Moved the PyType test to the weekly test pipeline (#8025)
+* Adjusted to meet Numpy 2.0 requirements (#7857)
+### Deprecated
+* Dropped support for Python 3.8 (#7909)
+* Removed deprecated arguments and class for v1.4 (#8079)
+### Removed
+* Removed use of `strtobool`, which is no longer available in Python 3.12 (#7900)
+* Removed the pipeline for publishing to Test PyPI (#8086)
+* Cleaned up some very old and now obsolete infrastructure (#8113, #8118, #8121)
+
+## [1.3.2] - 2024-06-25
+### Fixed
+#### misc.
+* Updated Numpy version constraint to < 2.0 (#7859)
+
+## [1.3.1] - 2024-05-17
+### Added
+* Support for `by_measure` argument in `RemoveSmallObjects` (#7137)
+* Support for `pretrained` flag in `ResNet` (#7095)
+* Support for uploading and downloading bundles to and from the Hugging Face Hub (#6454)
+* Added weight parameter in DiceLoss to apply weight to voxels of each class (#7158)
+* Support for returning dice for each class in `DiceMetric` (#7163)
+* Introduced `ComponentStore` for storage purposes (#7159)
+* Added utilities used in MONAI Generative (#7134)
+* Enabled Python 3.11 support for `convert_to_torchscript` and `convert_to_onnx` (#7182)
+* Support for MLflow in `AutoRunner` (#7176)
+* `fname_regex` option in PydicomReader (#7181)
+* Allowed setting AutoRunner parameters from config (#7175)
+* `VoxelMorphUNet` and `VoxelMorph` (#7178)
+* Enabled `cache` option in `GridPatchDataset` (#7180)
+* Introduced `class_labels` option in `write_metrics_reports` for improved readability (#7249)
+* `DiffusionLoss` for image registration task (#7272)
+* Supported specifying `filename` in `SaveImage` (#7318)
+* Compile support in `SupervisedTrainer` and `SupervisedEvaluator` (#7375)
+* `mlflow_experiment_name` support in `Auto3DSeg` (#7442)
+* Arm support (#7500)
+* `BarlowTwinsLoss` for representation learning (#7530)
+* `SURELoss` and `ConjugateGradient` for diffusion models (#7308)
+* Support for `CutMix`, `CutOut`, and `MixUp` augmentation techniques (#7198)
+* `meta_file` and `logging_file` options to `BundleWorkflow` (#7549)
+* `properties_path` option to `BundleWorkflow` for customized properties (#7542)
+* Support for both soft and hard clipping in `ClipIntensityPercentiles` (#7535)
+* Support for not saving artifacts in `MLFlowHandler` (#7604)
+* Support for multi-channel images in `PerceptualLoss` (#7568)
+* Added ResNet backbone for `FlexibleUNet` (#7571)
+* Introduced `dim_head` option in `SABlock` to set dimensions for each head (#7664)
+* Direct links to github source code to docs (#7738, #7779)
+#### misc.
+* Refactored `list_data_collate` and `collate_meta_tensor` to utilize the latest PyTorch API (#7165)
+* Added `__str__` method in `Metric` base class (#7487)
+* Made enhancements for testing files (#7662, #7670, #7663, #7671, #7672)
+* Improved documentation for bundles (#7116)
+### Fixed
+#### transforms
+* Addressed issue where lazy mode was ignored in `SpatialPadd` (#7316)
+* Tracked applied operations in `ImageFilter` (#7395)
+* Warnings are now given only if missing class is not set to 0 in `generate_label_classes_crop_centers` (#7602)
+* Input is now always converted to C-order in `distance_transform_edt` to ensure consistent behavior (#7675)
+#### data
+* Modified .npz file behavior to use keys in `NumpyReader` (#7148)
+* Handled corrupted cached files in `PersistentDataset` (#7244)
+* Corrected affine update in `NrrdReader` (#7415)
+#### metrics and losses
+* Addressed precision issue in `get_confusion_matrix` (#7187)
+* Harmonized and clarified documentation and tests for dice losses variants (#7587)
+#### networks
+* Removed hard-coded `spatial_dims` in `SwinTransformer` (#7302)
+* Fixed learnable `position_embeddings` in `PatchEmbeddingBlock` (#7564, #7605)
+* Removed `memory_pool_limit` in TRT config (#7647)
+* Propagated `kernel_size` to `ConvBlocks` within `AttentionUnet` (#7734)
+* Addressed hard-coded activation layer in `ResNet` (#7749)
+#### bundle
+* Resolved bundle download issue (#7280)
+* Updated `bundle_root` directory for `NNIGen` (#7586)
+* Checked for `num_fold` and failed early if incorrect (#7634)
+* Enhanced logging logic in `ConfigWorkflow` (#7745)
+#### misc.
+* Enabled chaining in `Auto3DSeg` CLI (#7168)
+* Addressed useless error message in `nnUNetV2Runner` (#7217)
+* Resolved typing and deprecation issues in Mypy (#7231)
+* Quoted `$PY_EXE` variable to handle Python path that contains spaces in Bash (#7268)
+* Improved documentation, code examples, and warning messages in various modules (#7234, #7213, #7271, #7326, #7569, #7584)
+* Fixed typos in various modules (#7321, #7322, #7458, #7595, #7612)
+* Enhanced docstrings in various modules (#7245, #7381, #7746)
+* Handled error when data is on CPU in `DataAnalyzer` (#7310)
+* Updated version requirements for third-party packages (#7343, #7344, #7384, #7448, #7659, #7704, #7744, #7742, #7780)
+* Addressed incorrect slice compute in `ImageStats` (#7374)
+* Avoided editing a loop's mutable iterable to address B308 (#7397)
+* Fixed issue with `CUDA_VISIBLE_DEVICES` setting being ignored (#7408, #7581)
+* Avoided changing Python version in CICD (#7424)
+* Renamed partial to callable in instantiate mode (#7413)
+* Imported AttributeError for Python 3.12 compatibility (#7482)
+* Updated `nnUNetV2Runner` to support nnunetv2 2.2 (#7483)
+* Used uint8 instead of int8 in `LabelStats` (#7489)
+* Utilized subprocess for nnUNet training (#7576)
+* Addressed deprecated warning in ruff (#7625)
+* Fixed downloading failure on FIPS machine (#7698)
+* Updated `torch_tensorrt` compile parameters to avoid warning (#7714)
+* Restricted `Auto3DSeg` fold input based on the datalist (#7778)
+### Changed
+* Base Docker image upgraded to `nvcr.io/nvidia/pytorch:24.03-py3` from `nvcr.io/nvidia/pytorch:23.08-py3`
+### Removed
+* Removed unrecommended star-arg unpacking after a keyword argument, addressed B026 (#7262)
+* Skipped old PyTorch version test for `SwinUNETR` (#7266)
+* Dropped docker build workflow and migrated to Nvidia Blossom system (#7450)
+* Dropped Python 3.8 test on quick-py3 workflow (#7719)
+
## [1.3.0] - 2023-10-12
### Added
* Intensity transforms `ScaleIntensityFixedMean` and `RandScaleIntensityFixedMean` (#6542)
@@ -943,7 +1132,10 @@ the postprocessing steps should be used before calling the metrics methods
[highlights]: https://github.com/Project-MONAI/MONAI/blob/master/docs/source/highlights.md
-[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/1.3.0...HEAD
+[Unreleased]: https://github.com/Project-MONAI/MONAI/compare/1.4.0...HEAD
+[1.4.0]: https://github.com/Project-MONAI/MONAI/compare/1.3.2...1.4.0
+[1.3.2]: https://github.com/Project-MONAI/MONAI/compare/1.3.1...1.3.2
+[1.3.1]: https://github.com/Project-MONAI/MONAI/compare/1.3.0...1.3.1
[1.3.0]: https://github.com/Project-MONAI/MONAI/compare/1.2.0...1.3.0
[1.2.0]: https://github.com/Project-MONAI/MONAI/compare/1.1.0...1.2.0
[1.1.0]: https://github.com/Project-MONAI/MONAI/compare/1.0.1...1.1.0
diff --git a/CITATION.cff b/CITATION.cff
index cac47faae4..86b147ce84 100644
--- a/CITATION.cff
+++ b/CITATION.cff
@@ -6,8 +6,8 @@ title: "MONAI: Medical Open Network for AI"
abstract: "AI Toolkit for Healthcare Imaging"
authors:
- name: "MONAI Consortium"
-date-released: 2023-10-12
-version: "1.3.0"
+date-released: 2024-10-17
+version: "1.4.0"
identifiers:
- description: "This DOI represents all versions of MONAI, and will always resolve to the latest one."
type: doi
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 5c886cff30..e780f26420 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -123,7 +123,7 @@ or (for new features that would not break existing functionality):
```
It is recommended that the new test `test_[module_name].py` is constructed by using only
-python 3.8+ build-in functions, `torch`, `numpy`, `coverage` (for reporting code coverages) and `parameterized` (for organising test cases) packages.
+python 3.9+ built-in functions, `torch`, `numpy`, `coverage` (for reporting code coverages) and `parameterized` (for organising test cases) packages.
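+
+For illustration, a minimal sketch of such a test (the module and test names here are hypothetical) could look like:
+
+```python
+# tests/test_my_module.py -- hypothetical test file using only torch, numpy and parameterized
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+
+class TestMyModule(unittest.TestCase):
+    @parameterized.expand([(2,), (3,)])
+    def test_shape(self, dim):
+        # replace this with calls to the module under test
+        result = torch.zeros(1, dim)
+        np.testing.assert_allclose(result.shape, (1, dim))
+
+
+if __name__ == "__main__":
+    unittest.main()
+```
+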
If it requires any other external packages, please make sure:
- the packages are listed in [`requirements-dev.txt`](requirements-dev.txt)
- the new test `test_[module_name].py` is added to the `exclude_cases` in [`./tests/min_tests.py`](./tests/min_tests.py) so that
@@ -383,3 +383,8 @@ then make PRs to the `releasing/[version number]` to fix the bugs via the regula
If any error occurs after the release process, first check out a new hotfix branch from the `main` branch,
make a patch version release following the semantic versioning, for example, `releasing/0.1.1`.
Make sure the `releasing/0.1.1` is merged back into both `dev` and `main` and all the test pipelines succeed.
+
+
+
+ ⬆️ Back to Top
+
diff --git a/Dockerfile b/Dockerfile
index 7383837585..5fcfcf274d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -11,14 +11,18 @@
# To build with a different base image
# please run `docker build` using the `--build-arg PYTORCH_IMAGE=...` flag.
-ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:23.08-py3
+ARG PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:24.10-py3
FROM ${PYTORCH_IMAGE}
LABEL maintainer="monai.contact@gmail.com"
# TODO: remark for issue [revise the dockerfile](https://github.com/zarr-developers/numcodecs/issues/431)
-WORKDIR /opt
-RUN git clone --recursive https://github.com/zarr-developers/numcodecs.git && pip wheel numcodecs
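+# on aarch64, (re)install numcodecs with the x86-specific SSE2/AVX2 code paths disabled,
+# since those instruction sets are not available on ARM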
+RUN if [[ $(uname -m) =~ "aarch64" ]]; then \
+ export CFLAGS="-O3" && \
+ export DISABLE_NUMCODECS_SSE2=true && \
+ export DISABLE_NUMCODECS_AVX2=true && \
+ pip install numcodecs; \
+ fi
WORKDIR /opt/monai
@@ -52,4 +56,5 @@ RUN apt-get update \
&& rm -rf /var/lib/apt/lists/*
# append /opt/tools to runtime path for NGC CLI to be accessible from all file system locations
ENV PATH=${PATH}:/opt/tools
+ENV POLYGRAPHY_AUTOINSTALL_DEPS=1
WORKDIR /opt/monai
diff --git a/README.md b/README.md
index 7565fea1b7..e5607ccb02 100644
--- a/README.md
+++ b/README.md
@@ -12,15 +12,15 @@
[![premerge](https://github.com/Project-MONAI/MONAI/actions/workflows/pythonapp.yml/badge.svg?branch=dev)](https://github.com/Project-MONAI/MONAI/actions/workflows/pythonapp.yml)
[![postmerge](https://img.shields.io/github/checks-status/project-monai/monai/dev?label=postmerge)](https://github.com/Project-MONAI/MONAI/actions?query=branch%3Adev)
-[![docker](https://github.com/Project-MONAI/MONAI/actions/workflows/docker.yml/badge.svg?branch=dev)](https://github.com/Project-MONAI/MONAI/actions/workflows/docker.yml)
[![Documentation Status](https://readthedocs.org/projects/monai/badge/?version=latest)](https://docs.monai.io/en/latest/)
[![codecov](https://codecov.io/gh/Project-MONAI/MONAI/branch/dev/graph/badge.svg?token=6FTC7U1JJ4)](https://codecov.io/gh/Project-MONAI/MONAI)
+[![monai Downloads Last Month](https://assets.piptrends.com/get-last-month-downloads-badge/monai.svg 'monai Downloads Last Month by pip Trends')](https://piptrends.com/package/monai)
-MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of [PyTorch Ecosystem](https://pytorch.org/ecosystem/).
-Its ambitions are:
-- developing a community of academic, industrial and clinical researchers collaborating on a common foundation;
-- creating state-of-the-art, end-to-end training workflows for healthcare imaging;
-- providing researchers with the optimized and standardized way to create and evaluate deep learning models.
+MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of the [PyTorch Ecosystem](https://pytorch.org/ecosystem/).
+Its ambitions are as follows:
+- Developing a community of academic, industrial and clinical researchers collaborating on a common foundation;
+- Creating state-of-the-art, end-to-end training workflows for healthcare imaging;
+- Providing researchers with the optimized and standardized way to create and evaluate deep learning models.
## Features
diff --git a/docs/images/maisi_train.png b/docs/images/maisi_train.png
new file mode 100644
index 0000000000..8c4936456d
Binary files /dev/null and b/docs/images/maisi_train.png differ
diff --git a/docs/images/python.svg b/docs/images/python.svg
index b7aa7c59bd..8ef6b61c03 100644
--- a/docs/images/python.svg
+++ b/docs/images/python.svg
@@ -1 +1 @@
-
+
diff --git a/docs/images/vista2d.png b/docs/images/vista2d.png
new file mode 100644
index 0000000000..5d09c1a275
Binary files /dev/null and b/docs/images/vista2d.png differ
diff --git a/docs/images/vista3d.png b/docs/images/vista3d.png
new file mode 100644
index 0000000000..c8a94fbecd
Binary files /dev/null and b/docs/images/vista3d.png differ
diff --git a/docs/requirements.txt b/docs/requirements.txt
index e5bedf8552..7307d8e5f9 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -6,7 +6,7 @@ itk>=5.2
nibabel
parameterized
scikit-image>=0.19.0
-scipy>=1.7.1
+scipy>=1.12.0; python_version >= '3.9'
tensorboard
commonmark==0.9.1
recommonmark==0.6.0
@@ -21,8 +21,8 @@ sphinxcontrib-serializinghtml
sphinx-autodoc-typehints==1.11.1
pandas
einops
-transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157
-mlflow>=1.28.0
+transformers>=4.36.0, <4.41.0; python_version <= '3.10'
+mlflow>=2.12.2
clearml>=1.10.0rc0
tensorboardX
imagecodecs; platform_system == "Linux" or platform_system == "Darwin"
@@ -40,3 +40,6 @@ onnx>=1.13.0
onnxruntime; python_version <= '3.10'
zarr
huggingface_hub
+pyamg>=5.0.0
+packaging
+polygraphy
diff --git a/docs/source/apps.rst b/docs/source/apps.rst
index 7fa7b9e9ff..cc4cea8c1e 100644
--- a/docs/source/apps.rst
+++ b/docs/source/apps.rst
@@ -248,6 +248,22 @@ FastMRIReader
~~~~~~~~~~~~~
.. autofunction:: monai.apps.reconstruction.complex_utils.complex_conj
+`Vista3d`
+---------
+.. automodule:: monai.apps.vista3d.inferer
+.. autofunction:: point_based_window_inferer
+
+.. automodule:: monai.apps.vista3d.transforms
+.. autoclass:: VistaPreTransformd
+ :members:
+.. autoclass:: VistaPostTransformd
+ :members:
+.. autoclass:: Relabeld
+ :members:
+
+.. automodule:: monai.apps.vista3d.sampler
+.. autofunction:: sample_prompt_pairs
+
`Auto3DSeg`
-----------
.. automodule:: monai.apps.auto3dseg
diff --git a/docs/source/conf.py b/docs/source/conf.py
index fdb10fbe03..a91f38081f 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -13,6 +13,8 @@
import os
import subprocess
import sys
+import importlib
+import inspect
sys.path.insert(0, os.path.abspath(".."))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
@@ -95,7 +97,7 @@ def generate_apidocs(*args):
"sphinx.ext.mathjax",
"sphinx.ext.napoleon",
"sphinx.ext.autodoc",
- "sphinx.ext.viewcode",
+ "sphinx.ext.linkcode",
"sphinx.ext.autosectionlabel",
"sphinx.ext.autosummary",
"sphinx_autodoc_typehints",
@@ -137,8 +139,8 @@ def generate_apidocs(*args):
"github_user": "Project-MONAI",
"github_repo": "MONAI",
"github_version": "dev",
- "doc_path": "docs/",
- "conf_py_path": "/docs/",
+ "doc_path": "docs/source",
+ "conf_py_path": "/docs/source",
"VERSION": version,
}
html_scaled_image_link = False
@@ -162,3 +164,60 @@ def setup(app):
# Hook to allow for automatic generation of API docs
# before doc deployment begins.
app.connect("builder-inited", generate_apidocs)
+
+
+# -- Linkcode configuration --------------------------------------------------
+DEFAULT_REF = "dev"
+read_the_docs_ref = os.environ.get("READTHEDOCS_GIT_IDENTIFIER", None)
+if read_the_docs_ref:
+ # When building on ReadTheDocs, link to the specific commit
+ # https://docs.readthedocs.io/en/stable/reference/environment-variables.html#envvar-READTHEDOCS_GIT_IDENTIFIER
+ git_ref = read_the_docs_ref
+elif os.environ.get("GITHUB_REF_TYPE", "branch") == "tag":
+ # When building a tag, link to the tag itself
+ git_ref = os.environ.get("GITHUB_REF", DEFAULT_REF)
+else:
+ git_ref = os.environ.get("GITHUB_SHA", DEFAULT_REF)
+
+DEFAULT_REPOSITORY = "Project-MONAI/MONAI"
+repository = os.environ.get("GITHUB_REPOSITORY", DEFAULT_REPOSITORY)
+
+base_code_url = f"https://github.com/{repository}/blob/{git_ref}"
+MODULE_ROOT_FOLDER = "monai"
+repo_root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+
+
+# Adjusted from https://github.com/python-websockets/websockets/blob/main/docs/conf.py
+def linkcode_resolve(domain, info):
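+    # Given a Sphinx domain/info pair, return the GitHub URL of the source lines
+    # for the documented object, or None if no "[source]" link should be generated.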
+ if domain != "py":
+ raise ValueError(
+ f"expected domain to be 'py', got {domain}."
+ "Please adjust linkcode_resolve to either handle this domain or ignore it."
+ )
+
+ mod = importlib.import_module(info["module"])
+ if "." in info["fullname"]:
+ objname, attrname = info["fullname"].split(".")
+ obj = getattr(mod, objname)
+ try:
+ # object is a method of a class
+ obj = getattr(obj, attrname)
+ except AttributeError:
+ # object is an attribute of a class
+ return None
+ else:
+ obj = getattr(mod, info["fullname"])
+
+ try:
+ file = inspect.getsourcefile(obj)
+ source, lineno = inspect.getsourcelines(obj)
+ except TypeError:
+ # e.g. object is a typing.Union
+ return None
+ file = os.path.relpath(file, repo_root_path)
+ if not file.startswith(MODULE_ROOT_FOLDER):
+ # e.g. object is a typing.NewType
+ return None
+ start, end = lineno, lineno + len(source) - 1
+ url = f"{base_code_url}/{file}#L{start}-L{end}"
+ return url
diff --git a/docs/source/config_syntax.md b/docs/source/config_syntax.md
index c932879b5a..742841acca 100644
--- a/docs/source/config_syntax.md
+++ b/docs/source/config_syntax.md
@@ -16,6 +16,7 @@ Content:
- [`$` to evaluate as Python expressions](#to-evaluate-as-python-expressions)
- [`%` to textually replace configuration elements](#to-textually-replace-configuration-elements)
- [`_target_` (`_disabled_`, `_desc_`, `_requires_`, `_mode_`) to instantiate a Python object](#instantiate-a-python-object)
+ - [`+` to alter semantics of merging config keys from multiple configuration files](#multiple-config-files)
- [The command line interface](#the-command-line-interface)
- [Recommendations](#recommendations)
@@ -175,6 +176,47 @@ _Description:_ `_requires_`, `_disabled_`, `_desc_`, and `_mode_` are optional k
- `"debug"` -- execute with debug prompt and return the return value of ``pdb.runcall(_target_, **kwargs)``,
see also [`pdb.runcall`](https://docs.python.org/3/library/pdb.html#pdb.runcall).
+## Multiple config files
+
+_Description:_ Multiple config files may be specified on the command line.
+The content of those config files is merged. When the same key appears in more than one config file,
+the value associated with that key is overridden, in the order the config files are specified.
+If the desired behaviour is to merge values from both files, the key in the second config file should be prefixed with `+`.
+The value types of the contents to be merged must match: both must be of `dict` type or both of `list` type.
+`dict` values are merged via `update()`, while `list` values are concatenated via `extend()`.
+Here's an example. In this case, the "amp" value will be overridden by extra_config.json,
+while the `imports` and `preprocessing#transforms` lists will be merged. An error is raised if the value type of `"+imports"` is not `list`:
+
+config.json:
+```json
+{
+    "amp": "$True",
+    "imports": [
+        "$import torch"
+    ],
+    "preprocessing": {
+        "_target_": "Compose",
+        "transforms": [
+            "$@t1",
+            "$@t2"
+        ]
+    }
+}
+```
+
+extra_config.json:
+```json
+{
+    "amp": "$False",
+    "+imports": [
+        "$from monai.networks import trt_compile"
+    ],
+    "+preprocessing#transforms": [
+        "$@t3"
+    ]
+}
+```
+
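+The same merge semantics can also be exercised programmatically. Below is a minimal sketch, assuming `ConfigParser.read_config` accepts a list of config files (as in recent MONAI versions) and applies the `+` merge rules described above:
+
+```python
+from monai.bundle import ConfigParser
+
+parser = ConfigParser()
+# later files override earlier ones; "+"-prefixed keys are merged instead
+parser.read_config(["config.json", "extra_config.json"])
+print(parser["amp"])      # "$False": overridden by extra_config.json
+print(parser["imports"])  # merged list from both files
+```
+
+The bundle command line interface described in the next section accepts multiple config files in the same way.
+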
## The command line interface
In addition to the Pythonic APIs, a few command line interfaces (CLI) are provided to interact with the bundle.
diff --git a/docs/source/engines.rst b/docs/source/engines.rst
index afb2682822..a015c7b2a3 100644
--- a/docs/source/engines.rst
+++ b/docs/source/engines.rst
@@ -30,6 +30,11 @@ Workflows
.. autoclass:: GanTrainer
:members:
+`AdversarialTrainer`
+~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: AdversarialTrainer
+ :members:
+
`Evaluator`
~~~~~~~~~~~
.. autoclass:: Evaluator
diff --git a/docs/source/index.rst b/docs/source/index.rst
index b6c8c22f98..85adee7e44 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -37,7 +37,7 @@ Features
Getting started
---------------
-`MedNIST demo `_ and `MONAI for PyTorch Users `_ are available on Colab.
+`MedNIST demo `_ and `MONAI for PyTorch Users `_ are available on Colab.
Examples and notebook tutorials are located at `Project-MONAI/tutorials `_.
diff --git a/docs/source/inferers.rst b/docs/source/inferers.rst
index 33f9e14d83..326f56e96c 100644
--- a/docs/source/inferers.rst
+++ b/docs/source/inferers.rst
@@ -49,6 +49,29 @@ Inferers
:members:
:special-members: __call__
+`DiffusionInferer`
+~~~~~~~~~~~~~~~~~~
+.. autoclass:: DiffusionInferer
+ :members:
+ :special-members: __call__
+
+`LatentDiffusionInferer`
+~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: LatentDiffusionInferer
+ :members:
+ :special-members: __call__
+
+`ControlNetDiffusionInferer`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: ControlNetDiffusionInferer
+ :members:
+ :special-members: __call__
+
+`ControlNetLatentDiffusionInferer`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: ControlNetLatentDiffusionInferer
+ :members:
+ :special-members: __call__
Splitters
---------
diff --git a/docs/source/installation.md b/docs/source/installation.md
index d77253f0f9..4308a07647 100644
--- a/docs/source/installation.md
+++ b/docs/source/installation.md
@@ -19,7 +19,7 @@
---
-MONAI's core functionality is written in Python 3 (>= 3.8) and only requires [Numpy](https://numpy.org/) and [Pytorch](https://pytorch.org/).
+MONAI's core functionality is written in Python 3 (>= 3.9) and only requires [Numpy](https://numpy.org/) and [Pytorch](https://pytorch.org/).
The package is currently distributed via Github as the primary source code repository,
and the Python package index (PyPI). The pre-built Docker images are made available on DockerHub.
@@ -258,6 +258,6 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is
```
which correspond to `nibabel`, `scikit-image`,`scipy`, `pillow`, `tensorboard`,
-`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, and `huggingface_hub` respectively.
+`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `clearml`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, `pyyaml`, `fire`, `jsonschema`, `ninja`, `pynrrd`, `pydicom`, `h5py`, `nni`, `optuna`, `onnx`, `onnxruntime`, `zarr`, `lpips`, `nvidia-ml-py`, `huggingface_hub` and `pyamg` respectively.
- `pip install 'monai[all]'` installs all the optional dependencies.
diff --git a/docs/source/losses.rst b/docs/source/losses.rst
index ba794af3eb..528ccd1173 100644
--- a/docs/source/losses.rst
+++ b/docs/source/losses.rst
@@ -93,6 +93,11 @@ Segmentation Losses
.. autoclass:: SoftDiceclDiceLoss
:members:
+`NACLLoss`
+~~~~~~~~~~
+.. autoclass:: NACLLoss
+ :members:
+
Registration Losses
-------------------
diff --git a/docs/source/mb_specification.rst b/docs/source/mb_specification.rst
index cedafa0d23..56d660e35c 100644
--- a/docs/source/mb_specification.rst
+++ b/docs/source/mb_specification.rst
@@ -63,12 +63,12 @@ This file contains the metadata information relating to the model, including wha
* **monai_version**: version of MONAI the bundle was generated on, later versions expected to work.
* **pytorch_version**: version of Pytorch the bundle was generated on, later versions expected to work.
* **numpy_version**: version of Numpy the bundle was generated on, later versions expected to work.
-* **optional_packages_version**: dictionary relating optional package names to their versions, these packages are not needed but are recommended to be installed with this stated minimum version.
+* **required_packages_version**: dictionary relating required package names to their versions. These are packages in addition to the base requirements of MONAI which this bundle absolutely needs. For example, if the bundle must load Nifti files, the Nibabel package will be required.
* **task**: plain-language description of what the model is meant to do.
* **description**: longer form plain-language description of what the model is, what it does, etc.
* **authors**: state author(s) of the model.
* **copyright**: state model copyright.
-* **network_data_format**: defines the format, shape, and meaning of inputs and outputs to the model, contains keys "inputs" and "outputs" relating named inputs/outputs to their format specifiers (defined below).
+* **network_data_format**: defines the format, shape, and meaning of inputs and outputs to the (primary) model; it contains keys "inputs" and "outputs" relating named inputs/outputs to their format specifiers (defined below). There is also an optional "post_processed_outputs" key stating the format of "outputs" after postprocessing transforms are applied; this is used to describe the final output from the bundle if it varies from the raw network output. These keys can also relate to primitive values (number, string, boolean), instead of the tensor format specified below.
Tensor format specifiers are used to define input and output tensors and their meanings, and must be a dictionary containing at least these keys:
@@ -89,6 +89,8 @@ Optional keys:
* **data_source**: description of where training/validation can be sourced.
* **data_type**: type of source data used for training/validation.
* **references**: list of published references relating to the model.
+* **supported_apps**: list of supported applications which use bundles, e.g. 'monai-label' would be present if the bundle is compatible with MONAI Label applications.
+* **\*_data_format**: defines the format, shape, and meaning of inputs and outputs to additional models which are secondary to the main model. This contains the same sort of information as **network_data_format** but describes networks providing secondary functionality, e.g. a localisation network used to identify an ROI in an image for cropping before data is sent to the primary network of this bundle (see the sketch below).
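+
+A minimal sketch of how these optional keys might appear in a bundle's ``metadata.json`` (the key name ``localisation_network_data_format`` and all values are purely illustrative):
+
+.. code-block:: python
+
+    # hypothetical fragment of metadata.json, shown here as a Python dict; only a few
+    # fields are filled in -- real entries follow the tensor format specifiers defined below
+    extra_metadata = {
+        "supported_apps": ["monai-label"],
+        "localisation_network_data_format": {
+            "inputs": {"image": {"type": "image", "format": "magnitude", "num_channels": 1}},
+            "outputs": {"pred": {"type": "image", "format": "labels", "num_channels": 2}},
+        },
+    }
+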
The format for tensors used as inputs and outputs can be used to specify semantic meaning of these values, and later is used by software handling bundles to determine how to process and interpret this data. There are various types of image data that MONAI uses, and other data types such as point clouds, dictionary sequences, time signals, and others. The following list is provided as a set of supported definitions of what a tensor "format" is but is not exhaustive and users can provide their own which would be left up to the model users to interpret:
@@ -124,7 +126,7 @@ An example JSON metadata file:
"monai_version": "0.9.0",
"pytorch_version": "1.10.0",
"numpy_version": "1.21.2",
- "optional_packages_version": {"nibabel": "3.2.1"},
+ "required_packages_version": {"nibabel": "3.2.1"},
"task": "Decathlon spleen segmentation",
"description": "A pre-trained model for volumetric (3D) segmentation of the spleen from CT image",
"authors": "MONAI team",
diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index b59c8af5fc..64a3a4c9d1 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -481,6 +481,11 @@ Nets
.. autoclass:: SegResNetDS
:members:
+`SegResNetDS2`
+~~~~~~~~~~~~~~
+.. autoclass:: SegResNetDS2
+ :members:
+
`SegResNetVAE`
~~~~~~~~~~~~~~
.. autoclass:: SegResNetVAE
@@ -491,6 +496,11 @@ Nets
.. autoclass:: ResNet
:members:
+`ResNetFeatures`
+~~~~~~~~~~~~~~~~
+.. autoclass:: ResNetFeatures
+ :members:
+
`SENet`
~~~~~~~
.. autoclass:: SENet
@@ -551,6 +561,11 @@ Nets
.. autoclass:: UNETR
:members:
+`VISTA3D`
+~~~~~~~~~
+.. autoclass:: VISTA3D
+ :members:
+
`SwinUNETR`
~~~~~~~~~~~
.. autoclass:: SwinUNETR
@@ -720,14 +735,9 @@ Nets
.. autoclass:: VoxelMorphUNet
:members:
-.. autoclass:: voxelmorphunet
- :members:
-
.. autoclass:: VoxelMorph
:members:
-.. autoclass:: voxelmorph
-
Utilities
---------
.. automodule:: monai.networks.utils
diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst
index 21a7e5e44e..59a5ed9a26 100644
--- a/docs/source/transforms.rst
+++ b/docs/source/transforms.rst
@@ -309,6 +309,12 @@ Intensity
:members:
:special-members: __call__
+`ClipIntensityPercentiles`
+""""""""""""""""""""""""""
+.. autoclass:: ClipIntensityPercentiles
+ :members:
+ :special-members: __call__
+
`RandScaleIntensity`
""""""""""""""""""""
.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandScaleIntensity.png
@@ -548,6 +554,12 @@ IO
:members:
:special-members: __call__
+`WriteFileMapping`
+""""""""""""""""""
+.. autoclass:: WriteFileMapping
+ :members:
+ :special-members: __call__
+
NVIDIA Tool Extension (NVTX)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -661,6 +673,12 @@ Post-processing
:members:
:special-members: __call__
+`Invert`
+"""""""""
+.. autoclass:: Invert
+ :members:
+ :special-members: __call__
+
Regularization
^^^^^^^^^^^^^^
@@ -958,6 +976,18 @@ Spatial
:members:
:special-members: __call__
+`ConvertBoxToPoints`
+""""""""""""""""""""
+.. autoclass:: ConvertBoxToPoints
+ :members:
+ :special-members: __call__
+
+`ConvertPointsToBoxes`
+""""""""""""""""""""""
+.. autoclass:: ConvertPointsToBoxes
+ :members:
+ :special-members: __call__
+
Smooth Field
^^^^^^^^^^^^
@@ -1211,6 +1241,12 @@ Utility
:members:
:special-members: __call__
+`ApplyTransformToPoints`
+""""""""""""""""""""""""
+.. autoclass:: ApplyTransformToPoints
+ :members:
+ :special-members: __call__
+
Dictionary Transforms
---------------------
@@ -1412,6 +1448,12 @@ Intensity (Dict)
:members:
:special-members: __call__
+`ClipIntensityPercentilesd`
+"""""""""""""""""""""""""""
+.. autoclass:: ClipIntensityPercentilesd
+ :members:
+ :special-members: __call__
+
`RandScaleIntensityd`
"""""""""""""""""""""
.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RandScaleIntensityd.png
@@ -1631,6 +1673,12 @@ IO (Dict)
:members:
:special-members: __call__
+`WriteFileMappingd`
+"""""""""""""""""""
+.. autoclass:: WriteFileMappingd
+ :members:
+ :special-members: __call__
+
Post-processing (Dict)
^^^^^^^^^^^^^^^^^^^^^^
@@ -1950,6 +1998,18 @@ Spatial (Dict)
:members:
:special-members: __call__
+`ConvertBoxToPointsd`
+"""""""""""""""""""""
+.. autoclass:: ConvertBoxToPointsd
+ :members:
+ :special-members: __call__
+
+`ConvertPointsToBoxesd`
+"""""""""""""""""""""""
+.. autoclass:: ConvertPointsToBoxesd
+ :members:
+ :special-members: __call__
+
Smooth Field (Dict)
^^^^^^^^^^^^^^^^^^^
@@ -2260,6 +2320,12 @@ Utility (Dict)
:members:
:special-members: __call__
+`ApplyTransformToPointsd`
+"""""""""""""""""""""""""
+.. autoclass:: ApplyTransformToPointsd
+ :members:
+ :special-members: __call__
+
MetaTensor
^^^^^^^^^^
@@ -2305,6 +2371,9 @@ Utilities
.. automodule:: monai.transforms.utils_pytorch_numpy_unification
:members:
+.. automodule:: monai.transforms.utils_morphological_ops
+ :members:
+
By Categories
-------------
.. toctree::
diff --git a/docs/source/utils.rst b/docs/source/utils.rst
index 527247799f..ae3b476c3e 100644
--- a/docs/source/utils.rst
+++ b/docs/source/utils.rst
@@ -17,12 +17,6 @@ Module utils
:members:
-Aliases
--------
-.. automodule:: monai.utils.aliases
- :members:
-
-
Misc
----
.. automodule:: monai.utils.misc
@@ -81,3 +75,8 @@ Component store
---------------
.. autoclass:: monai.utils.component_store.ComponentStore
:members:
+
+Ordering
+--------
+.. automodule:: monai.utils.ordering
+ :members:
diff --git a/docs/source/whatsnew.rst b/docs/source/whatsnew.rst
index a12dbe6959..b1f6b2dac7 100644
--- a/docs/source/whatsnew.rst
+++ b/docs/source/whatsnew.rst
@@ -6,6 +6,7 @@ What's New
.. toctree::
:maxdepth: 1
+ whatsnew_1_4.md
whatsnew_1_3.md
whatsnew_1_2.md
whatsnew_1_1.md
diff --git a/docs/source/whatsnew_1_3.md b/docs/source/whatsnew_1_3.md
index c4b14810b5..6480547eec 100644
--- a/docs/source/whatsnew_1_3.md
+++ b/docs/source/whatsnew_1_3.md
@@ -1,4 +1,4 @@
-# What's new in 1.3 🎉🎉
+# What's new in 1.3
- Bundle usability enhancements
- Integrating MONAI Generative into MONAI core
diff --git a/docs/source/whatsnew_1_4.md b/docs/source/whatsnew_1_4.md
new file mode 100644
index 0000000000..0fc82ff820
--- /dev/null
+++ b/docs/source/whatsnew_1_4.md
@@ -0,0 +1,68 @@
+# What's new in 1.4 🎉🎉
+
+- MAISI: state-of-the-art 3D Latent Diffusion Model
+- VISTA-3D: interactive foundation model for segmenting and annotating human anatomies
+- VISTA-2D: cell segmentation pipeline
+- Integrating MONAI Generative into MONAI core
+- Lazy TensorRT export via `trt_compile`
+- Geometric Data Support
+
+
+## MAISI: state-of-the-art 3D Latent Diffusion Model
+
+![maisi](../images/maisi_train.png)
+
+MAISI (Medical AI for Synthetic Imaging) is a state-of-the-art three-dimensional (3D) Latent Diffusion Model designed for generating high-quality synthetic CT images, with or without anatomical annotations. It excels at data augmentation and at creating realistic medical imaging data to supplement datasets that are limited by privacy concerns or rare conditions. It can also significantly enhance the performance of other medical imaging AI models by generating diverse and realistic training data.
+
+A tutorial for generating large CT images accompanied by corresponding segmentation masks using MAISI is provided within
+[`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/blob/main/generation/maisi).
+The MAISI pipeline provides the following components (a minimal instantiation sketch follows the list):
+- A foundation Variational Auto-Encoder (VAE) model for latent feature compression that works for both CT and MRI with flexible volume size and voxel size
+- A foundation Diffusion model that can generate large CT volumes up to 512 × 512 × 768 in size, with flexible volume size and voxel size
+- A ControlNet to generate image/mask pairs that can improve downstream tasks, with controllable organ/tumor size
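+
+For orientation, a minimal sketch of instantiating the MAISI VAE added in this release (`AutoencoderKlMaisi`, under `monai.apps.generation.maisi.networks`) is shown below; the channel sizes, split settings, and input shape are illustrative placeholders rather than the configuration of the released foundation model, and `encode` is assumed to behave as in the parent `AutoencoderKL` class.
+
+```python
+import torch
+
+from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi
+
+# illustrative configuration only; not the settings or weights of the released MAISI model
+vae = AutoencoderKlMaisi(
+    spatial_dims=3,
+    in_channels=1,
+    out_channels=1,
+    num_res_blocks=(2, 2, 2),
+    num_channels=(32, 64, 128),
+    attention_levels=(False, False, False),
+    latent_channels=4,
+    num_splits=4,  # feature maps are split along `dim_split` to bound peak GPU memory
+    dim_split=0,
+)
+
+with torch.no_grad():
+    # encode() is inherited from AutoencoderKL and is assumed to return the latent mean and std
+    z_mu, z_sigma = vae.encode(torch.rand(1, 1, 64, 64, 64))
+```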
+
+## VISTA-3D: interactive foundation model for 3D medical imaging
+
+![vista-3d](../images/vista3d.png)
+
+VISTA-3D is a specialized interactive foundation model for 3D medical imaging. It excels in providing accurate and adaptable segmentation analysis across anatomies and modalities. Utilizing a multi-head architecture, VISTA-3D adapts to varying conditions and anatomical areas, helping guide users' annotation workflow.
+
+A tutorial showing how to fine-tune VISTA-3D on the spleen dataset is provided within
+[`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/blob/main/vista_3d).
+VISTA-3D supports three core workflows:
+- Segment everything: Enables whole body exploration, crucial for understanding complex diseases affecting multiple organs and for holistic treatment planning.
+- Segment using class: Provides detailed sectional views based on specific classes, essential for targeted disease analysis or organ mapping, such as tumor identification in critical organs.
+- Segment point prompts: Enhances segmentation precision through user-directed, click-based selection. This interactive approach accelerates the creation of accurate ground-truth data, essential in medical imaging analysis.
+
+## VISTA-2D: cell segmentation pipeline
+
+![vista-2d](../images/vista2d.png)
+
+VISTA-2D is a comprehensive training and inference pipeline for cell segmentation in imaging applications. For more information, refer to this [blog post](https://developer.nvidia.com/blog/advancing-cell-segmentation-and-morphology-analysis-with-nvidia-ai-foundation-model-vista-2d/).
+
+Key features of the model include:
+- A robust deep learning algorithm utilizing transformers
+- A foundation model, in contrast to specialist models
+- Supports a wide variety of datasets and file formats
+- Capable of handling multiple imaging modalities
+- Multi-GPU and multinode training support
+
+A tutorial demonstrating how to train a cell segmentation model using the MONAI framework on the Cellpose dataset can be found in [`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/blob/main/vista_2d).
+
+## Integrating MONAI Generative into MONAI Core
+
+Key modules originally developed in the [MONAI GenerativeModels](https://github.com/Project-MONAI/GenerativeModels) repository have been integrated into the core MONAI codebase. This integration ensures consistent maintenance and streamlined release of essential components for generative AI. In this version, all utilities, networks, diffusion schedulers, inferers, and engines have been migrated into the core codebase. Special care has been taken to ensure saved weights from models trained using GenerativeModels can be loaded into those now integrated into core.
+
+Additionally, several tutorials have been ported and are available within [`project-monai/tutorials`](https://github.com/Project-MONAI/tutorials/blob/main/generation).
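+
+For orientation, a hedged sketch of importing some of the migrated components from their new core locations follows; the module paths for `AutoencoderKL`, `ControlNet`, and the diffusion UNet appear elsewhere in this changeset, while the scheduler and inferer locations are assumptions.
+
+```python
+# the network module paths below are taken from this changeset;
+# the scheduler and inferer locations are assumed destinations of the migration
+from monai.networks.nets.autoencoderkl import AutoencoderKL
+from monai.networks.nets.controlnet import ControlNet
+from monai.networks.nets.diffusion_model_unet import DiffusionModelUNet
+
+from monai.networks.schedulers import DDPMScheduler  # assumed location
+from monai.inferers import DiffusionInferer  # assumed location
+
+# weights saved from the GenerativeModels versions of these classes are expected
+# to load into the core versions via the usual load_state_dict() mechanism
+```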
+
+## Lazy TensorRT export via `trt_compile`
+This release expands the TensorRT optimization options for MONAI bundles with the `trt_compile` API.
+The existing `trt_export` API requires the user to run a separate export script to prepare a TensorRT engine-based TorchScript model.
+`trt_compile` builds and saves a TensorRT engine the first time a bundle is run and provides limited dependency support.
+It also allows partial TensorRT export, where only a selected submodule is optimized, which improves usability.
+A few bundles in the MONAI model zoo, like the new [VISTA-3D](https://github.com/Project-MONAI/model-zoo/tree/dev/models/vista3d)
+and [VISTA-2D](https://github.com/Project-MONAI/model-zoo/tree/dev/models/vista2d) bundles, already come with `trt_inference.json` config files which use `trt_compile`.
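+
+A minimal sketch of the lazy-compilation workflow is shown below; it assumes `trt_compile` can be imported from `monai.networks`, takes the model plus a base path used to cache the generated engine, and returns the wrapped model (a CUDA device and a TensorRT installation are required).
+
+```python
+import torch
+
+from monai.networks import trt_compile  # assumed import location
+from monai.networks.nets import UNet
+
+model = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(16, 32, 64), strides=(2, 2)).eval().cuda()
+
+# the TensorRT engine is built lazily on the first forward pass and cached next to the given base path
+model = trt_compile(model, "models/unet_trt")
+
+with torch.no_grad():
+    y = model(torch.rand(1, 1, 96, 96, 96, device="cuda"))
+```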
+
+## Geometric Data Support
+
+MONAI now includes support for geometric data transformations as a key feature. As a starting point, the `ApplyTransformToPoints` transform has been added to apply affine matrix operations to point data, enabling flexible and efficient handling of geometric transformations. Alongside this, the framework now supports conversions between boxes and points, providing seamless interoperability within detection pipelines. These updates have been integrated into existing pipelines, such as the [detection tutorial](https://github.com/Project-MONAI/tutorials/blob/main/detection) and the [3D registration workflow](https://github.com/Project-MONAI/tutorials/blob/main/3d_registration/learn2reg_nlst_paired_lung_ct.ipynb), leveraging the latest APIs for improved functionality.
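+
+A short, hedged usage sketch is shown below; the constructor arguments (`affine`, `invert_affine`) and the exact tensor layouts are assumptions based on the class names rather than a documented contract, so consult the API documentation for the precise signatures.
+
+```python
+import torch
+
+from monai.transforms import ApplyTransformToPoints, ConvertBoxToPoints, ConvertPointsToBoxes
+
+boxes = torch.tensor([[10.0, 12.0, 8.0, 30.0, 40.0, 24.0]])  # one 3D box, assumed xyzxyz layout
+corner_points = ConvertBoxToPoints()(boxes)                   # corner points for each box
+
+affine = torch.eye(4)  # in practice, e.g. an image's `affine` metadata
+moved = ApplyTransformToPoints(affine=affine, invert_affine=False)(corner_points)
+
+boxes_back = ConvertPointsToBoxes()(moved)                    # convert corner points back to boxes
+```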
diff --git a/environment-dev.yml b/environment-dev.yml
index 20427d5d5c..4a1723e8a5 100644
--- a/environment-dev.yml
+++ b/environment-dev.yml
@@ -5,11 +5,11 @@ channels:
- nvidia
- conda-forge
dependencies:
- - numpy>=1.20
+ - numpy>=1.24,<2.0
- pytorch>=1.9
- torchio
- torchvision
- - pytorch-cuda=11.6
+ - pytorch-cuda>=11.6
- pip
- pip:
- -r requirements-dev.txt
diff --git a/monai/__init__.py b/monai/__init__.py
index eb05ac993d..d92557a8e1 100644
--- a/monai/__init__.py
+++ b/monai/__init__.py
@@ -11,13 +11,55 @@
from __future__ import annotations
+import logging
import os
import sys
+import warnings
from ._version import get_versions
+old_showwarning = warnings.showwarning
+
+
+def custom_warning_handler(message, category, filename, lineno, file=None, line=None):
+ ignore_files = ["ignite/handlers/checkpoint", "modelopt/torch/quantization/tensor_quant"]
+ if any(ignore in filename for ignore in ignore_files):
+ return
+ old_showwarning(message, category, filename, lineno, file, line)
+
+
+class DeprecatedTypesWarningFilter(logging.Filter):
+ def filter(self, record):
+ message_bodies_to_ignore = [
+ "np.bool8",
+ "np.object0",
+ "np.int0",
+ "np.uint0",
+ "np.void0",
+ "np.str0",
+ "np.bytes0",
+ "@validator",
+ "@root_validator",
+ "class-based `config`",
+ "pkg_resources",
+ "Implicitly cleaning up",
+ ]
+ for message in message_bodies_to_ignore:
+ if message in record.getMessage():
+ return False
+ return True
+
+
+# workaround for https://github.com/Project-MONAI/MONAI/issues/8060
+# TODO: remove this workaround after upstream fixed the warning
+# Set the custom warning handler to filter warning
+warnings.showwarning = custom_warning_handler
+# Get the logger for warnings and add the filter to the logger
+logging.getLogger("py.warnings").addFilter(DeprecatedTypesWarningFilter())
+
+
PY_REQUIRED_MAJOR = 3
-PY_REQUIRED_MINOR = 8
+PY_REQUIRED_MINOR = 9
version_dict = get_versions()
__version__: str = version_dict.get("version", "0+unknown")
@@ -37,6 +79,7 @@
category=RuntimeWarning,
)
+
from .utils.module import load_submodules # noqa: E402
# handlers_* have some external decorators the users may not have installed
diff --git a/monai/apps/auto3dseg/auto_runner.py b/monai/apps/auto3dseg/auto_runner.py
index 52a0824227..99bf92c481 100644
--- a/monai/apps/auto3dseg/auto_runner.py
+++ b/monai/apps/auto3dseg/auto_runner.py
@@ -298,9 +298,13 @@ def __init__(
pass
# inspect and update folds
- num_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename)
+ self.max_fold = self.inspect_datalist_folds(datalist_filename=datalist_filename)
if "num_fold" in self.data_src_cfg:
num_fold = int(self.data_src_cfg["num_fold"]) # override from config
+ logger.info(f"Setting num_fold {num_fold} based on the input config.")
+ else:
+ num_fold = self.max_fold
+ logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.")
self.data_src_cfg["datalist"] = datalist_filename # update path to a version in work_dir and save user input
ConfigParser.export_config_file(
@@ -398,7 +402,10 @@ def inspect_datalist_folds(self, datalist_filename: str) -> int:
if len(fold_list) > 0:
num_fold = max(fold_list) + 1
- logger.info(f"Setting num_fold {num_fold} based on the input datalist {datalist_filename}.")
+ logger.info(f"Found num_fold {num_fold} based on the input datalist {datalist_filename}.")
+ # check if every fold is present
+ if len(set(fold_list)) != num_fold:
+ raise ValueError(f"Fold numbers are not continuous from 0 to {num_fold - 1}")
elif "validation" in datalist and len(datalist["validation"]) > 0:
logger.info("No fold numbers provided, attempting to use a single fold based on the validation key")
# update the datalist file
@@ -492,6 +499,11 @@ def set_num_fold(self, num_fold: int = 5) -> AutoRunner:
if num_fold <= 0:
raise ValueError(f"num_fold is expected to be an integer greater than zero. Now it gets {num_fold}")
+ if num_fold > self.max_fold:
+ # Auto3DSeg must contain a validation set, so the maximum fold number is max_fold.
+ raise ValueError(
+ f"num_fold is greater than the maximum fold number {self.max_fold} in {self.datalist_filename}."
+ )
self.num_fold = num_fold
return self
@@ -539,7 +551,7 @@ def set_device_info(
cmd_prefix: command line prefix for subprocess running in BundleAlgo and EnsembleRunner.
Default using env "CMD_PREFIX" or None, examples are:
- - single GPU/CPU or multinode bcprun: "python " or "/opt/conda/bin/python3.8 ",
+ - single GPU/CPU or multinode bcprun: "python " or "/opt/conda/bin/python3.9 ",
- single node multi-GPU running "torchrun --nnodes=1 --nproc_per_node=2 "
If user define this prefix, please make sure --nproc_per_node matches cuda_visible_device or
diff --git a/monai/apps/auto3dseg/hpo_gen.py b/monai/apps/auto3dseg/hpo_gen.py
index b755b99feb..ed6d903897 100644
--- a/monai/apps/auto3dseg/hpo_gen.py
+++ b/monai/apps/auto3dseg/hpo_gen.py
@@ -53,7 +53,7 @@ def update_params(self, *args, **kwargs):
raise NotImplementedError
@abstractmethod
- def set_score(self):
+ def set_score(self, *args, **kwargs):
"""Report the evaluated results to HPO."""
raise NotImplementedError
diff --git a/monai/apps/deepedit/transforms.py b/monai/apps/deepedit/transforms.py
index 6d0825f54a..5af082e2b0 100644
--- a/monai/apps/deepedit/transforms.py
+++ b/monai/apps/deepedit/transforms.py
@@ -30,7 +30,7 @@
logger = logging.getLogger(__name__)
-distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")
+distance_transform_cdt, _ = optional_import("scipy.ndimage", name="distance_transform_cdt")
class DiscardAddGuidanced(MapTransform):
diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py
index 9aca77a36c..721c0db489 100644
--- a/monai/apps/deepgrow/transforms.py
+++ b/monai/apps/deepgrow/transforms.py
@@ -27,7 +27,7 @@
from monai.utils.enums import PostFix
measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
-distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")
+distance_transform_cdt, _ = optional_import("scipy.ndimage", name="distance_transform_cdt")
DEFAULT_POST_FIX = PostFix.meta()
@@ -803,6 +803,14 @@ class RestoreLabeld(MapTransform):
original_shape_key: key that records original shape for foreground.
cropped_shape_key: key that records cropped shape for foreground.
allow_missing_keys: don't raise exception if key is missing.
+ restore_resizing: used to enable or disable resizing restoration, default is True.
+ If True, the transform will resize the items back to their original shape.
+ restore_cropping: used to enable or disable cropping restoration, default is True.
+ If True, the transform will restore the items to their uncropped size.
+ restore_spacing: used to enable or disable spacing restoration, default is True.
+ If True, the transform will resample the items back to the spacing they had before being altered.
+ restore_slicing: used to enable or disable slicing restoration, default is True.
+ If True, the transform will reassemble the full volume by restoring the slices to their original positions.
"""
def __init__(
@@ -819,6 +827,10 @@ def __init__(
original_shape_key: str = "foreground_original_shape",
cropped_shape_key: str = "foreground_cropped_shape",
allow_missing_keys: bool = False,
+ restore_resizing: bool = True,
+ restore_cropping: bool = True,
+ restore_spacing: bool = True,
+ restore_slicing: bool = True,
) -> None:
super().__init__(keys, allow_missing_keys)
self.ref_image = ref_image
@@ -833,6 +845,10 @@ def __init__(
self.end_coord_key = end_coord_key
self.original_shape_key = original_shape_key
self.cropped_shape_key = cropped_shape_key
+ self.restore_resizing = restore_resizing
+ self.restore_cropping = restore_cropping
+ self.restore_spacing = restore_spacing
+ self.restore_slicing = restore_slicing
def __call__(self, data: Any) -> dict:
d = dict(data)
@@ -842,38 +858,45 @@ def __call__(self, data: Any) -> dict:
image = d[key]
# Undo Resize
- current_shape = image.shape
- cropped_shape = meta_dict[self.cropped_shape_key]
- if np.any(np.not_equal(current_shape, cropped_shape)):
- resizer = Resize(spatial_size=cropped_shape[1:], mode=mode)
- image = resizer(image, mode=mode, align_corners=align_corners)
+ if self.restore_resizing:
+ current_shape = image.shape
+ cropped_shape = meta_dict[self.cropped_shape_key]
+ if np.any(np.not_equal(current_shape, cropped_shape)):
+ resizer = Resize(spatial_size=cropped_shape[1:], mode=mode)
+ image = resizer(image, mode=mode, align_corners=align_corners)
# Undo Crop
- original_shape = meta_dict[self.original_shape_key]
- result = np.zeros(original_shape, dtype=np.float32)
- box_start = meta_dict[self.start_coord_key]
- box_end = meta_dict[self.end_coord_key]
-
- spatial_dims = min(len(box_start), len(image.shape[1:]))
- slices = tuple(
- [slice(None)] + [slice(s, e) for s, e in zip(box_start[:spatial_dims], box_end[:spatial_dims])]
- )
- result[slices] = image
+ if self.restore_cropping:
+ original_shape = meta_dict[self.original_shape_key]
+ result = np.zeros(original_shape, dtype=np.float32)
+ box_start = meta_dict[self.start_coord_key]
+ box_end = meta_dict[self.end_coord_key]
+
+ spatial_dims = min(len(box_start), len(image.shape[1:]))
+ slices = tuple(
+ [slice(None)] + [slice(s, e) for s, e in zip(box_start[:spatial_dims], box_end[:spatial_dims])]
+ )
+ result[slices] = image
+ else:
+ result = image
# Undo Spacing
- current_size = result.shape[1:]
- # change spatial_shape from HWD to DHW
- spatial_shape = list(np.roll(meta_dict["spatial_shape"], 1))
- spatial_size = spatial_shape[-len(current_size) :]
+ if self.restore_spacing:
+ current_size = result.shape[1:]
+ # change spatial_shape from HWD to DHW
+ spatial_shape = list(np.roll(meta_dict["spatial_shape"], 1))
+ spatial_size = spatial_shape[-len(current_size) :]
- if np.any(np.not_equal(current_size, spatial_size)):
- resizer = Resize(spatial_size=spatial_size, mode=mode)
- result = resizer(result, mode=mode, align_corners=align_corners) # type: ignore
+ if np.any(np.not_equal(current_size, spatial_size)):
+ resizer = Resize(spatial_size=spatial_size, mode=mode)
+ result = resizer(result, mode=mode, align_corners=align_corners) # type: ignore
# Undo Slicing
slice_idx = meta_dict.get("slice_idx")
final_result: NdarrayOrTensor
- if slice_idx is None or self.slice_only:
+ if not self.restore_slicing: # do nothing if restore slicing isn't requested
+ final_result = result
+ elif slice_idx is None or self.slice_only:
final_result = result if len(result.shape) <= 3 else result[0]
else:
slice_idx = meta_dict["slice_idx"][0]
diff --git a/monai/apps/detection/networks/retinanet_network.py b/monai/apps/detection/networks/retinanet_network.py
index ec86c3b0e9..ca6a8f5c19 100644
--- a/monai/apps/detection/networks/retinanet_network.py
+++ b/monai/apps/detection/networks/retinanet_network.py
@@ -42,7 +42,7 @@
import math
import warnings
from collections.abc import Callable, Sequence
-from typing import Any, Dict
+from typing import Any
import torch
from torch import Tensor, nn
@@ -330,7 +330,7 @@ def forward(self, images: Tensor) -> Any:
features = self.feature_extractor(images)
if isinstance(features, Tensor):
feature_maps = [features]
- elif torch.jit.isinstance(features, Dict[str, Tensor]):
+ elif torch.jit.isinstance(features, dict[str, Tensor]):
feature_maps = list(features.values())
else:
feature_maps = list(features)
diff --git a/monai/apps/detection/transforms/array.py b/monai/apps/detection/transforms/array.py
index d8ffce4584..301a636b6c 100644
--- a/monai/apps/detection/transforms/array.py
+++ b/monai/apps/detection/transforms/array.py
@@ -15,7 +15,8 @@
from __future__ import annotations
-from typing import Any, Sequence
+from collections.abc import Sequence
+from typing import Any
import numpy as np
import torch
diff --git a/monai/apps/detection/utils/anchor_utils.py b/monai/apps/detection/utils/anchor_utils.py
index 283169b653..cbde3ebae9 100644
--- a/monai/apps/detection/utils/anchor_utils.py
+++ b/monai/apps/detection/utils/anchor_utils.py
@@ -189,7 +189,7 @@ def generate_anchors(
w_ratios = 1 / area_scale
h_ratios = area_scale
# if 3d, w:h:d = 1:aspect_ratios[:,0]:aspect_ratios[:,1]
- elif self.spatial_dims == 3:
+ else:
area_scale = torch.pow(aspect_ratios_t[:, 0] * aspect_ratios_t[:, 1], 1 / 3.0)
w_ratios = 1 / area_scale
h_ratios = aspect_ratios_t[:, 0] / area_scale
@@ -199,7 +199,7 @@ def generate_anchors(
hs = (h_ratios[:, None] * scales_t[None, :]).view(-1)
if self.spatial_dims == 2:
base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2.0
- elif self.spatial_dims == 3:
+ else: # elif self.spatial_dims == 3:
ds = (d_ratios[:, None] * scales_t[None, :]).view(-1)
base_anchors = torch.stack([-ws, -hs, -ds, ws, hs, ds], dim=1) / 2.0
diff --git a/monai/apps/generation/__init__.py b/monai/apps/generation/__init__.py
new file mode 100644
index 0000000000..1e97f89407
--- /dev/null
+++ b/monai/apps/generation/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/monai/apps/generation/maisi/__init__.py b/monai/apps/generation/maisi/__init__.py
new file mode 100644
index 0000000000..1e97f89407
--- /dev/null
+++ b/monai/apps/generation/maisi/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/monai/apps/generation/maisi/networks/__init__.py b/monai/apps/generation/maisi/networks/__init__.py
new file mode 100644
index 0000000000..1e97f89407
--- /dev/null
+++ b/monai/apps/generation/maisi/networks/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
new file mode 100644
index 0000000000..6251ea8e83
--- /dev/null
+++ b/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
@@ -0,0 +1,991 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import gc
+import logging
+from collections.abc import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from monai.networks.blocks import Convolution
+from monai.networks.blocks.spatialattention import SpatialAttentionBlock
+from monai.networks.nets.autoencoderkl import AEKLResBlock, AutoencoderKL
+from monai.utils.type_conversion import convert_to_tensor
+
+# Set up logging configuration
+logger = logging.getLogger(__name__)
+
+
+def _empty_cuda_cache(save_mem: bool) -> None:
+ if torch.cuda.is_available() and save_mem:
+ torch.cuda.empty_cache()
+ return
+
+
+class MaisiGroupNorm3D(nn.GroupNorm):
+ """
+ Custom 3D Group Normalization with optional print_info output.
+
+ Args:
+ num_groups: Number of groups for the group norm.
+ num_channels: Number of channels for the group norm.
+ eps: Epsilon value for numerical stability.
+ affine: Whether to use learnable affine parameters, default to `True`.
+ norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+ print_info: Whether to print information, default to `False`.
+ save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+ """
+
+ def __init__(
+ self,
+ num_groups: int,
+ num_channels: int,
+ eps: float = 1e-5,
+ affine: bool = True,
+ norm_float16: bool = False,
+ print_info: bool = False,
+ save_mem: bool = True,
+ ):
+ super().__init__(num_groups, num_channels, eps, affine)
+ self.norm_float16 = norm_float16
+ self.print_info = print_info
+ self.save_mem = save_mem
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ if self.print_info:
+ logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")
+
+ if len(input.shape) != 5:
+ raise ValueError("Expected a 5D tensor")
+
+ param_n, param_c, param_d, param_h, param_w = input.shape
+ input = input.view(param_n, self.num_groups, param_c // self.num_groups, param_d, param_h, param_w)
+
+ inputs = []
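+ # normalize each group separately in float32; optionally cast the result back to float16 to reduce memory use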
+ for i in range(input.size(1)):
+ array = input[:, i : i + 1, ...].to(dtype=torch.float32)
+ mean = array.mean([2, 3, 4, 5], keepdim=True)
+ std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_()
+ if self.norm_float16:
+ inputs.append(((array - mean) / std).to(dtype=torch.float16))
+ else:
+ inputs.append((array - mean) / std)
+
+ del input
+ _empty_cuda_cache(self.save_mem)
+
+ input = torch.cat(inputs, dim=1) if max(inputs[0].size()) < 500 else self._cat_inputs(inputs)
+
+ input = input.view(param_n, param_c, param_d, param_h, param_w)
+ if self.affine:
+ input.mul_(self.weight.view(1, param_c, 1, 1, 1)).add_(self.bias.view(1, param_c, 1, 1, 1))
+
+ if self.print_info:
+ logger.info(f"MaisiGroupNorm3D with output size: {input.size()}")
+
+ return input
+
+ def _cat_inputs(self, inputs):
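+ # move the per-group results to the CPU and concatenate them one at a time to limit peak GPU memory, then return to the original device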
+ input_type = inputs[0].device.type
+ input = inputs[0].clone().to("cpu", non_blocking=True) if input_type == "cuda" else inputs[0].clone()
+ inputs[0] = 0
+ _empty_cuda_cache(self.save_mem)
+
+ for k in range(len(inputs) - 1):
+ input = torch.cat((input, inputs[k + 1].cpu()), dim=1)
+ inputs[k + 1] = 0
+ _empty_cuda_cache(self.save_mem)
+ gc.collect()
+
+ if self.print_info:
+ logger.info(f"MaisiGroupNorm3D concat progress: {k + 1}/{len(inputs) - 1}.")
+
+ return input.to("cuda", non_blocking=True) if input_type == "cuda" else input
+
+
+class MaisiConvolution(nn.Module):
+ """
+ Convolutional layer with optional print_info output and custom splitting mechanism.
+
+ Args:
+ spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+ in_channels: Number of input channels.
+ out_channels: Number of output channels.
+ num_splits: Number of splits for the input tensor.
+ dim_split: Dimension of splitting for the input tensor.
+ print_info: Whether to print information.
+ save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+ Additional arguments for the convolution operation.
+ https://docs.monai.io/en/stable/networks.html#convolution
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ num_splits: int,
+ dim_split: int,
+ print_info: bool,
+ save_mem: bool = True,
+ strides: Sequence[int] | int = 1,
+ kernel_size: Sequence[int] | int = 3,
+ adn_ordering: str = "NDA",
+ act: tuple | str | None = "PRELU",
+ norm: tuple | str | None = "INSTANCE",
+ dropout: tuple | str | float | None = None,
+ dropout_dim: int = 1,
+ dilation: Sequence[int] | int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ conv_only: bool = False,
+ is_transposed: bool = False,
+ padding: Sequence[int] | int | None = None,
+ output_padding: Sequence[int] | int | None = None,
+ ) -> None:
+ super().__init__()
+ self.conv = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ strides=strides,
+ kernel_size=kernel_size,
+ adn_ordering=adn_ordering,
+ act=act,
+ norm=norm,
+ dropout=dropout,
+ dropout_dim=dropout_dim,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ conv_only=conv_only,
+ is_transposed=is_transposed,
+ padding=padding,
+ output_padding=output_padding,
+ )
+
+ self.dim_split = dim_split
+ self.stride = strides[self.dim_split] if isinstance(strides, list) else strides
+ self.num_splits = num_splits
+ self.print_info = print_info
+ self.save_mem = save_mem
+
+ def _split_tensor(self, x: torch.Tensor, split_size: int, padding: int) -> list[torch.Tensor]:
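+ # split x along dim_split into num_splits overlapping chunks: each chunk carries `padding` extra
+ # voxels of context from its neighbours (the last chunk also absorbs the remainder), and the
+ # overlap is trimmed again in _concatenate_tensors after the convolution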
+ overlaps = [0] + [padding] * (self.num_splits - 1)
+ last_padding = x.size(self.dim_split + 2) % split_size
+
+ slices = [slice(None)] * 5
+ splits: list[torch.Tensor] = []
+ for i in range(self.num_splits):
+ slices[self.dim_split + 2] = slice(
+ i * split_size - overlaps[i],
+ (i + 1) * split_size + (padding if i != self.num_splits - 1 else last_padding),
+ )
+ splits.append(x[tuple(slices)])
+
+ if self.print_info:
+ for j in range(len(splits)):
+ logger.info(f"Split {j + 1}/{len(splits)} size: {splits[j].size()}")
+
+ return splits
+
+ def _concatenate_tensors(self, outputs: list[torch.Tensor], split_size: int, padding: int) -> torch.Tensor:
+ slices = [slice(None)] * 5
+ for i in range(self.num_splits):
+ slices[self.dim_split + 2] = slice(None, split_size) if i == 0 else slice(padding, padding + split_size)
+ outputs[i] = outputs[i][tuple(slices)]
+
+ if self.print_info:
+ for i in range(self.num_splits):
+ logger.info(f"Output {i + 1}/{len(outputs)} size after: {outputs[i].size()}")
+
+ if max(outputs[0].size()) < 500:
+ x = torch.cat(outputs, dim=self.dim_split + 2)
+ else:
+ x = outputs[0].clone().to("cpu", non_blocking=True)
+ outputs[0] = torch.Tensor(0)
+ _empty_cuda_cache(self.save_mem)
+ for k in range(len(outputs) - 1):
+ x = torch.cat((x, outputs[k + 1].cpu()), dim=self.dim_split + 2)
+ outputs[k + 1] = torch.Tensor(0)
+ _empty_cuda_cache(self.save_mem)
+ gc.collect()
+ if self.print_info:
+ logger.info(f"MaisiConvolution concat progress: {k + 1}/{len(outputs) - 1}.")
+
+ x = x.to("cuda", non_blocking=True)
+ return x
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.print_info:
+ logger.info(f"Number of splits: {self.num_splits}")
+
+ # compute size of splits
+ length = x.size(self.dim_split + 2)
+ split_size = length // self.num_splits
+
+ # update padding length if necessary
+ padding = 3
+ if padding % self.stride > 0:
+ padding = (padding // self.stride + 1) * self.stride
+ if self.print_info:
+ logger.info(f"Padding size: {padding}")
+
+ # split tensor into a list of tensors
+ splits = self._split_tensor(x, split_size, padding)
+
+ del x
+ _empty_cuda_cache(self.save_mem)
+
+ # convolution
+ outputs = [self.conv(split) for split in splits]
+ if self.print_info:
+ for j in range(len(outputs)):
+ logger.info(f"Output {j + 1}/{len(outputs)} size before: {outputs[j].size()}")
+
+ # update size of splits and padding length for output
+ split_size_out = split_size
+ padding_s = padding
+ non_dim_split = self.dim_split + 1 if self.dim_split < 2 else 0
+ if outputs[0].size(non_dim_split + 2) // splits[0].size(non_dim_split + 2) == 2:
+ split_size_out *= 2
+ padding_s *= 2
+ elif splits[0].size(non_dim_split + 2) // outputs[0].size(non_dim_split + 2) == 2:
+ split_size_out //= 2
+ padding_s //= 2
+
+ # concatenate list of tensors
+ x = self._concatenate_tensors(outputs, split_size_out, padding_s)
+
+ del outputs
+ _empty_cuda_cache(self.save_mem)
+
+ return x
+
+
+class MaisiUpsample(nn.Module):
+ """
+ Convolution-based upsampling layer.
+
+ Args:
+ spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+ in_channels: Number of input channels to the layer.
+ use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
+ num_splits: Number of splits for the input tensor.
+ dim_split: Dimension of splitting for the input tensor.
+ print_info: Whether to print information.
+ save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ use_convtranspose: bool,
+ num_splits: int,
+ dim_split: int,
+ print_info: bool,
+ save_mem: bool = True,
+ ) -> None:
+ super().__init__()
+ self.conv = MaisiConvolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=in_channels,
+ strides=2 if use_convtranspose else 1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ is_transposed=use_convtranspose,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ self.use_convtranspose = use_convtranspose
+ self.save_mem = save_mem
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ if self.use_convtranspose:
+ x = self.conv(x)
+ x_tensor: torch.Tensor = convert_to_tensor(x)
+ return x_tensor
+
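+ # default path: upsample by a factor of 2 with trilinear interpolation, then apply a split-aware 3x3 convolution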
+ x = F.interpolate(x, scale_factor=2.0, mode="trilinear")
+ _empty_cuda_cache(self.save_mem)
+ x = self.conv(x)
+ _empty_cuda_cache(self.save_mem)
+
+ out_tensor: torch.Tensor = convert_to_tensor(x)
+ return out_tensor
+
+
+class MaisiDownsample(nn.Module):
+ """
+ Convolution-based downsampling layer.
+
+ Args:
+ spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+ in_channels: Number of input channels.
+ num_splits: Number of splits for the input tensor.
+ dim_split: Dimension of splitting for the input tensor.
+ print_info: Whether to print information.
+ save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ num_splits: int,
+ dim_split: int,
+ print_info: bool,
+ save_mem: bool = True,
+ ) -> None:
+ super().__init__()
+ self.pad = (0, 1) * spatial_dims
+ self.conv = MaisiConvolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=in_channels,
+ strides=2,
+ kernel_size=3,
+ padding=0,
+ conv_only=True,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = F.pad(x, self.pad, mode="constant", value=0.0)
+ x = self.conv(x)
+ return x
+
+
+class MaisiResBlock(nn.Module):
+ """
+ Residual block consisting of a cascade of two convolution + activation + normalization blocks, and a
+ residual connection between the input and the output.
+
+ Args:
+ spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+ in_channels: Input channels to the layer.
+ norm_num_groups: Number of groups for the group norm layer.
+ norm_eps: Epsilon for the normalization.
+ out_channels: Number of output channels.
+ num_splits: Number of splits for the input tensor.
+ dim_split: Dimension of splitting for the input tensor.
+ norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+ print_info: Whether to print information, default to `False`.
+ save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ norm_num_groups: int,
+ norm_eps: float,
+ out_channels: int,
+ num_splits: int,
+ dim_split: int,
+ norm_float16: bool = False,
+ print_info: bool = False,
+ save_mem: bool = True,
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.save_mem = save_mem
+
+ self.norm1 = MaisiGroupNorm3D(
+ num_groups=norm_num_groups,
+ num_channels=in_channels,
+ eps=norm_eps,
+ affine=True,
+ norm_float16=norm_float16,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ self.conv1 = MaisiConvolution(
+ spatial_dims=spatial_dims,
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ self.norm2 = MaisiGroupNorm3D(
+ num_groups=norm_num_groups,
+ num_channels=out_channels,
+ eps=norm_eps,
+ affine=True,
+ norm_float16=norm_float16,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ self.conv2 = MaisiConvolution(
+ spatial_dims=spatial_dims,
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+
+ self.nin_shortcut = (
+ MaisiConvolution(
+ spatial_dims=spatial_dims,
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ if self.in_channels != self.out_channels
+ else nn.Identity()
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = self.norm1(x)
+ _empty_cuda_cache(self.save_mem)
+
+ h = F.silu(h)
+ _empty_cuda_cache(self.save_mem)
+ h = self.conv1(h)
+ _empty_cuda_cache(self.save_mem)
+
+ h = self.norm2(h)
+ _empty_cuda_cache(self.save_mem)
+
+ h = F.silu(h)
+ _empty_cuda_cache(self.save_mem)
+ h = self.conv2(h)
+ _empty_cuda_cache(self.save_mem)
+
+ if self.in_channels != self.out_channels:
+ x = self.nin_shortcut(x)
+ _empty_cuda_cache(self.save_mem)
+
+ out = x + h
+ out_tensor: torch.Tensor = convert_to_tensor(out)
+ return out_tensor
+
+
+class MaisiEncoder(nn.Module):
+ """
+ Convolutional cascade that downsamples the image into a spatial latent space.
+
+ Args:
+ spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+ in_channels: Number of input channels.
+ num_channels: Sequence of block output channels.
+ out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
+ num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
+ norm_num_groups: Number of groups for the group norm layers.
+ norm_eps: Epsilon for the normalization.
+ attention_levels: Indicate which level from num_channels contain an attention block.
+ with_nonlocal_attn: If True, use non-local attention block.
+ include_fc: whether to include the final linear layer in the attention block. Default to False.
+ use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
+ use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
+ num_splits: Number of splits for the input tensor.
+ dim_split: Dimension of splitting for the input tensor.
+ norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+ print_info: Whether to print information, default to `False`.
+ save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ num_channels: Sequence[int],
+ out_channels: int,
+ num_res_blocks: Sequence[int],
+ norm_num_groups: int,
+ norm_eps: float,
+ attention_levels: Sequence[bool],
+ num_splits: int,
+ dim_split: int,
+ norm_float16: bool = False,
+ print_info: bool = False,
+ save_mem: bool = True,
+ with_nonlocal_attn: bool = True,
+ include_fc: bool = False,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+
+ # Check if attention_levels and num_channels have the same size
+ if len(attention_levels) != len(num_channels):
+ raise ValueError("attention_levels and num_channels must have the same size")
+
+ # Check if num_res_blocks and num_channels have the same size
+ if len(num_res_blocks) != len(num_channels):
+ raise ValueError("num_res_blocks and num_channels must have the same size")
+
+ self.save_mem = save_mem
+
+ blocks: list[nn.Module] = []
+
+ blocks.append(
+ MaisiConvolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=num_channels[0],
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ )
+
+ output_channel = num_channels[0]
+ for i in range(len(num_channels)):
+ input_channel = output_channel
+ output_channel = num_channels[i]
+ is_final_block = i == len(num_channels) - 1
+
+ for _ in range(num_res_blocks[i]):
+ blocks.append(
+ MaisiResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=input_channel,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=output_channel,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ norm_float16=norm_float16,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ )
+ input_channel = output_channel
+ if attention_levels[i]:
+ blocks.append(
+ SpatialAttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=input_channel,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+
+ if not is_final_block:
+ blocks.append(
+ MaisiDownsample(
+ spatial_dims=spatial_dims,
+ in_channels=input_channel,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ )
+
+ if with_nonlocal_attn:
+ blocks.append(
+ AEKLResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=num_channels[-1],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=num_channels[-1],
+ )
+ )
+
+ blocks.append(
+ SpatialAttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=num_channels[-1],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+ blocks.append(
+ AEKLResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=num_channels[-1],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=num_channels[-1],
+ )
+ )
+
+ blocks.append(
+ MaisiGroupNorm3D(
+ num_groups=norm_num_groups,
+ num_channels=num_channels[-1],
+ eps=norm_eps,
+ affine=True,
+ norm_float16=norm_float16,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ )
+ blocks.append(
+ MaisiConvolution(
+ spatial_dims=spatial_dims,
+ in_channels=num_channels[-1],
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ )
+
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for block in self.blocks:
+ x = block(x)
+ _empty_cuda_cache(self.save_mem)
+ return x
+
+
+class MaisiDecoder(nn.Module):
+ """
+ Convolutional cascade upsampling from a spatial latent space into an image space.
+
+ Args:
+ spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+ num_channels: Sequence of block output channels.
+ in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
+ out_channels: Number of output channels.
+ num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
+ norm_num_groups: Number of groups for the group norm layers.
+ norm_eps: Epsilon for the normalization.
+ attention_levels: Indicate which level from num_channels contain an attention block.
+ with_nonlocal_attn: If True, use non-local attention block.
+ include_fc: whether to include the final linear layer in the attention block. Default to False.
+ use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
+ use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
+ use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
+ num_splits: Number of splits for the input tensor.
+ dim_split: Dimension of splitting for the input tensor.
+ norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+ print_info: Whether to print information, default to `False`.
+ save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ num_channels: Sequence[int],
+ in_channels: int,
+ out_channels: int,
+ num_res_blocks: Sequence[int],
+ norm_num_groups: int,
+ norm_eps: float,
+ attention_levels: Sequence[bool],
+ num_splits: int,
+ dim_split: int,
+ norm_float16: bool = False,
+ print_info: bool = False,
+ save_mem: bool = True,
+ with_nonlocal_attn: bool = True,
+ include_fc: bool = False,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ use_convtranspose: bool = False,
+ ) -> None:
+ super().__init__()
+ self.print_info = print_info
+ self.save_mem = save_mem
+
+ reversed_block_out_channels = list(reversed(num_channels))
+
+ blocks: list[nn.Module] = []
+
+ blocks.append(
+ MaisiConvolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=reversed_block_out_channels[0],
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ )
+
+ if with_nonlocal_attn:
+ blocks.append(
+ AEKLResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=reversed_block_out_channels[0],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=reversed_block_out_channels[0],
+ )
+ )
+ blocks.append(
+ SpatialAttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=reversed_block_out_channels[0],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+ blocks.append(
+ AEKLResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=reversed_block_out_channels[0],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=reversed_block_out_channels[0],
+ )
+ )
+
+ reversed_attention_levels = list(reversed(attention_levels))
+ reversed_num_res_blocks = list(reversed(num_res_blocks))
+ block_out_ch = reversed_block_out_channels[0]
+ for i in range(len(reversed_block_out_channels)):
+ block_in_ch = block_out_ch
+ block_out_ch = reversed_block_out_channels[i]
+ is_final_block = i == len(num_channels) - 1
+
+ for _ in range(reversed_num_res_blocks[i]):
+ blocks.append(
+ MaisiResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=block_in_ch,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=block_out_ch,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ norm_float16=norm_float16,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ )
+ block_in_ch = block_out_ch
+
+ if reversed_attention_levels[i]:
+ blocks.append(
+ SpatialAttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=block_in_ch,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+
+ if not is_final_block:
+ blocks.append(
+ MaisiUpsample(
+ spatial_dims=spatial_dims,
+ in_channels=block_in_ch,
+ use_convtranspose=use_convtranspose,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ )
+
+ blocks.append(
+ MaisiGroupNorm3D(
+ num_groups=norm_num_groups,
+ num_channels=block_in_ch,
+ eps=norm_eps,
+ affine=True,
+ norm_float16=norm_float16,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ )
+ blocks.append(
+ MaisiConvolution(
+ spatial_dims=spatial_dims,
+ in_channels=block_in_ch,
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+ )
+
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for block in self.blocks:
+ x = block(x)
+ _empty_cuda_cache(self.save_mem)
+ return x
+
+
+class AutoencoderKlMaisi(AutoencoderKL):
+ """
+ AutoencoderKL with custom MaisiEncoder and MaisiDecoder.
+
+ Args:
+ spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
+ in_channels: Number of input channels.
+ out_channels: Number of output channels.
+ num_res_blocks: Number of residual blocks per level.
+ num_channels: Sequence of block output channels.
+ attention_levels: Indicate which level from num_channels contain an attention block.
+ latent_channels: Number of channels in the latent space.
+ norm_num_groups: Number of groups for the group norm layers.
+ norm_eps: Epsilon for the normalization.
+ with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
+ with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
+ include_fc: whether to include the final linear layer. Default to False.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
+ use_checkpointing: If True, use activation checkpointing.
+ use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
+ num_splits: Number of splits for the input tensor.
+ dim_split: Dimension of splitting for the input tensor.
+ norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+ print_info: Whether to print information, default to `False`.
+ save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ num_res_blocks: Sequence[int],
+ num_channels: Sequence[int],
+ attention_levels: Sequence[bool],
+ latent_channels: int = 3,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ with_encoder_nonlocal_attn: bool = False,
+ with_decoder_nonlocal_attn: bool = False,
+ include_fc: bool = False,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ use_checkpointing: bool = False,
+ use_convtranspose: bool = False,
+ num_splits: int = 16,
+ dim_split: int = 0,
+ norm_float16: bool = False,
+ print_info: bool = False,
+ save_mem: bool = True,
+ ) -> None:
+ super().__init__(
+ spatial_dims,
+ in_channels,
+ out_channels,
+ num_res_blocks,
+ num_channels,
+ attention_levels,
+ latent_channels,
+ norm_num_groups,
+ norm_eps,
+ with_encoder_nonlocal_attn,
+ with_decoder_nonlocal_attn,
+ use_checkpointing,
+ use_convtranspose,
+ include_fc,
+ use_combined_linear,
+ use_flash_attention,
+ )
+
+ self.encoder: nn.Module = MaisiEncoder(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ num_channels=num_channels,
+ out_channels=latent_channels,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ attention_levels=attention_levels,
+ with_nonlocal_attn=with_encoder_nonlocal_attn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ norm_float16=norm_float16,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
+
+ self.decoder: nn.Module = MaisiDecoder(
+ spatial_dims=spatial_dims,
+ num_channels=num_channels,
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ attention_levels=attention_levels,
+ with_nonlocal_attn=with_decoder_nonlocal_attn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ use_convtranspose=use_convtranspose,
+ num_splits=num_splits,
+ dim_split=dim_split,
+ norm_float16=norm_float16,
+ print_info=print_info,
+ save_mem=save_mem,
+ )
diff --git a/monai/apps/generation/maisi/networks/controlnet_maisi.py b/monai/apps/generation/maisi/networks/controlnet_maisi.py
new file mode 100644
index 0000000000..7c13fd7bc6
--- /dev/null
+++ b/monai/apps/generation/maisi/networks/controlnet_maisi.py
@@ -0,0 +1,175 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+
+from monai.networks.nets.controlnet import ControlNet
+from monai.networks.nets.diffusion_model_unet import get_timestep_embedding
+
+
+class ControlNetMaisi(ControlNet):
+ """
+ Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
+ Diffusion Models" (https://arxiv.org/abs/2302.05543)
+
+ Args:
+ spatial_dims: number of spatial dimensions.
+ in_channels: number of input channels.
+ num_res_blocks: number of residual blocks (see ResnetBlock) per level.
+ num_channels: tuple of block output channels.
+ attention_levels: list of levels to add attention.
+ norm_num_groups: number of groups for the normalization.
+ norm_eps: epsilon for the normalization.
+ resblock_updown: if True use residual blocks for up/downsampling.
+ num_head_channels: number of channels in each attention head.
+ with_conditioning: if True add spatial transformers to perform conditioning.
+ transformer_num_layers: number of layers of Transformer blocks to use.
+ cross_attention_dim: number of context dimensions to use.
+ num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
+ classes.
+ upcast_attention: if True, upcast attention operations to full precision.
+ conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
+ conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
+ use_checkpointing: if True, use activation checkpointing to save memory.
+ include_fc: whether to include the final linear layer. Default to False.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+ num_channels: Sequence[int] = (32, 64, 64, 64),
+ attention_levels: Sequence[bool] = (False, False, True, True),
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ resblock_updown: bool = False,
+ num_head_channels: int | Sequence[int] = 8,
+ with_conditioning: bool = False,
+ transformer_num_layers: int = 1,
+ cross_attention_dim: int | None = None,
+ num_class_embeds: int | None = None,
+ upcast_attention: bool = False,
+ conditioning_embedding_in_channels: int = 1,
+ conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256),
+ use_checkpointing: bool = True,
+ include_fc: bool = False,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__(
+ spatial_dims,
+ in_channels,
+ num_res_blocks,
+ num_channels,
+ attention_levels,
+ norm_num_groups,
+ norm_eps,
+ resblock_updown,
+ num_head_channels,
+ with_conditioning,
+ transformer_num_layers,
+ cross_attention_dim,
+ num_class_embeds,
+ upcast_attention,
+ conditioning_embedding_in_channels,
+ conditioning_embedding_num_channels,
+ include_fc,
+ use_combined_linear,
+ use_flash_attention,
+ )
+ self.use_checkpointing = use_checkpointing
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ timesteps: torch.Tensor,
+ controlnet_cond: torch.Tensor,
+ conditioning_scale: float = 1.0,
+ context: torch.Tensor | None = None,
+ class_labels: torch.Tensor | None = None,
+ ) -> tuple[list[torch.Tensor], torch.Tensor]:
+ emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels)
+ h = self._apply_initial_convolution(x)
+ if self.use_checkpointing:
+ controlnet_cond = torch.utils.checkpoint.checkpoint(
+ self.controlnet_cond_embedding, controlnet_cond, use_reentrant=False
+ )
+ else:
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+ h += controlnet_cond
+ down_block_res_samples, h = self._apply_down_blocks(emb, context, h)
+ h = self._apply_mid_block(emb, context, h)
+ down_block_res_samples, mid_block_res_sample = self._apply_controlnet_blocks(h, down_block_res_samples)
+ # scaling
+ down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples]
+ mid_block_res_sample *= conditioning_scale
+
+ return down_block_res_samples, mid_block_res_sample
+
+ def _prepare_time_and_class_embedding(self, x, timesteps, class_labels):
+ # 1. time
+ t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=x.dtype)
+ emb = self.time_embed(t_emb)
+
+ # 2. class
+ if self.num_class_embeds is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+ class_emb = self.class_embedding(class_labels)
+ class_emb = class_emb.to(dtype=x.dtype)
+ emb = emb + class_emb
+
+ return emb
+
+ def _apply_initial_convolution(self, x):
+ # 3. initial convolution
+ h = self.conv_in(x)
+ return h
+
+ def _apply_down_blocks(self, emb, context, h):
+ # 4. down
+ if context is not None and self.with_conditioning is False:
+ raise ValueError("model should have with_conditioning = True if context is provided")
+ down_block_res_samples: list[torch.Tensor] = [h]
+ for downsample_block in self.down_blocks:
+ h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
+ for residual in res_samples:
+ down_block_res_samples.append(residual)
+
+ return down_block_res_samples, h
+
+ def _apply_mid_block(self, emb, context, h):
+ # 5. mid
+ h = self.middle_block(hidden_states=h, temb=emb, context=context)
+ return h
+
+ def _apply_controlnet_blocks(self, h, down_block_res_samples):
+ # 6. Control net blocks
+ controlnet_down_block_res_samples = []
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples.append(down_block_res_sample)
+
+ mid_block_res_sample = self.controlnet_mid_block(h)
+
+ return controlnet_down_block_res_samples, mid_block_res_sample
diff --git a/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py
new file mode 100644
index 0000000000..e990b5fc98
--- /dev/null
+++ b/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py
@@ -0,0 +1,410 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/huggingface/diffusers
+# which has the following license:
+# https://github.com/huggingface/diffusers/blob/main/LICENSE
+#
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+from torch import nn
+
+from monai.networks.blocks import Convolution
+from monai.networks.nets.diffusion_model_unet import (
+ get_down_block,
+ get_mid_block,
+ get_timestep_embedding,
+ get_up_block,
+ zero_module,
+)
+from monai.utils import ensure_tuple_rep
+from monai.utils.type_conversion import convert_to_tensor
+
+__all__ = ["DiffusionModelUNetMaisi"]
+
+
+class DiffusionModelUNetMaisi(nn.Module):
+ """
+ U-Net network with timestep embedding and attention mechanisms for conditioning based on
+ Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
+ and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162
+
+ Args:
+ spatial_dims: Number of spatial dimensions.
+ in_channels: Number of input channels.
+ out_channels: Number of output channels.
+ num_res_blocks: Number of residual blocks (see ResnetBlock) per level. Can be a single integer or a sequence of integers.
+ num_channels: Tuple of block output channels.
+ attention_levels: List of levels to add attention.
+ norm_num_groups: Number of groups for the normalization.
+ norm_eps: Epsilon for the normalization.
+ resblock_updown: If True, use residual blocks for up/downsampling.
+ num_head_channels: Number of channels in each attention head. Can be a single integer or a sequence of integers.
+ with_conditioning: If True, add spatial transformers to perform conditioning.
+ transformer_num_layers: Number of layers of Transformer blocks to use.
+ cross_attention_dim: Number of context dimensions to use.
+ num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
+ upcast_attention: If True, upcast attention operations to full precision.
+ include_fc: Whether to include the final linear layer. Defaults to False.
+ use_combined_linear: Whether to use a single linear layer for qkv projection. Defaults to False.
+ use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
+ dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers.
+ include_top_region_index_input: If True, use top region index input.
+ include_bottom_region_index_input: If True, use bottom region index input.
+ include_spacing_input: If True, use spacing input.
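+
+ Example (a minimal usage sketch; the hyperparameters and tensor shapes below are illustrative):
+
+ .. code-block:: python
+
+ import torch
+
+ # four levels; all channel counts must be multiples of norm_num_groups (32 by default)
+ net = DiffusionModelUNetMaisi(
+ spatial_dims=3,
+ in_channels=4,
+ out_channels=4,
+ num_channels=(32, 64, 64, 64),
+ attention_levels=(False, False, True, True),
+ )
+ x = torch.randn(1, 4, 32, 32, 32)  # (N, C, H, W, D)
+ t = torch.randint(0, 1000, (1,))  # one timestep per batch element
+ y = net(x, t)  # output has the same spatial shape as x, with out_channels channels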
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+ num_channels: Sequence[int] = (32, 64, 64, 64),
+ attention_levels: Sequence[bool] = (False, False, True, True),
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ resblock_updown: bool = False,
+ num_head_channels: int | Sequence[int] = 8,
+ with_conditioning: bool = False,
+ transformer_num_layers: int = 1,
+ cross_attention_dim: int | None = None,
+ num_class_embeds: int | None = None,
+ upcast_attention: bool = False,
+ include_fc: bool = False,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ dropout_cattn: float = 0.0,
+ include_top_region_index_input: bool = False,
+ include_bottom_region_index_input: bool = False,
+ include_spacing_input: bool = False,
+ ) -> None:
+ super().__init__()
+ if with_conditioning is True and cross_attention_dim is None:
+ raise ValueError(
+ "DiffusionModelUNetMaisi expects dimension of the cross-attention conditioning (cross_attention_dim) "
+ "when using with_conditioning."
+ )
+ if cross_attention_dim is not None and with_conditioning is False:
+ raise ValueError(
+ "DiffusionModelUNetMaisi expects with_conditioning=True when specifying the cross_attention_dim."
+ )
+ if dropout_cattn > 1.0 or dropout_cattn < 0.0:
+ raise ValueError("Dropout cannot be negative or >1.0!")
+
+ # All numbers of channels should be multiples of norm_num_groups
+ if any((out_channel % norm_num_groups) != 0 for out_channel in num_channels):
+ raise ValueError(
+ f"DiffusionModelUNetMaisi expects all num_channels being multiple of norm_num_groups, "
+ f"but get num_channels: {num_channels} and norm_num_groups: {norm_num_groups}"
+ )
+
+ if len(num_channels) != len(attention_levels):
+ raise ValueError(
+ f"DiffusionModelUNetMaisi expects num_channels being same size of attention_levels, "
+ f"but get num_channels: {len(num_channels)} and attention_levels: {len(attention_levels)}"
+ )
+
+ if isinstance(num_head_channels, int):
+ num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
+
+ if len(num_head_channels) != len(attention_levels):
+ raise ValueError(
+ "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
+ " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
+ )
+
+ if isinstance(num_res_blocks, int):
+ num_res_blocks = ensure_tuple_rep(num_res_blocks, len(num_channels))
+
+ if len(num_res_blocks) != len(num_channels):
+ raise ValueError(
+ "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
+ "`num_channels`."
+ )
+
+ if use_flash_attention is True and not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. Flash attention is only available for GPU."
+ )
+
+ self.in_channels = in_channels
+ self.block_out_channels = num_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_levels = attention_levels
+ self.num_head_channels = num_head_channels
+ self.with_conditioning = with_conditioning
+
+ # input
+ self.conv_in = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=num_channels[0],
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+
+ # time
+ time_embed_dim = num_channels[0] * 4
+ self.time_embed = self._create_embedding_module(num_channels[0], time_embed_dim)
+
+ # class embedding
+ self.num_class_embeds = num_class_embeds
+ if num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
+ self.include_top_region_index_input = include_top_region_index_input
+ self.include_bottom_region_index_input = include_bottom_region_index_input
+ self.include_spacing_input = include_spacing_input
+
+ new_time_embed_dim = time_embed_dim
+ if self.include_top_region_index_input:
+ self.top_region_index_layer = self._create_embedding_module(4, time_embed_dim)
+ new_time_embed_dim += time_embed_dim
+ if self.include_bottom_region_index_input:
+ self.bottom_region_index_layer = self._create_embedding_module(4, time_embed_dim)
+ new_time_embed_dim += time_embed_dim
+ if self.include_spacing_input:
+ self.spacing_layer = self._create_embedding_module(3, time_embed_dim)
+ new_time_embed_dim += time_embed_dim
+
+ # down
+ self.down_blocks = nn.ModuleList([])
+ output_channel = num_channels[0]
+ for i in range(len(num_channels)):
+ input_channel = output_channel
+ output_channel = num_channels[i]
+ is_final_block = i == len(num_channels) - 1
+ down_block = get_down_block(
+ spatial_dims=spatial_dims,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=new_time_embed_dim,
+ num_res_blocks=num_res_blocks[i],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_downsample=not is_final_block,
+ resblock_updown=resblock_updown,
+ with_attn=(attention_levels[i] and not with_conditioning),
+ with_cross_attn=(attention_levels[i] and with_conditioning),
+ num_head_channels=num_head_channels[i],
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ dropout_cattn=dropout_cattn,
+ )
+
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.middle_block = get_mid_block(
+ spatial_dims=spatial_dims,
+ in_channels=num_channels[-1],
+ temb_channels=new_time_embed_dim,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ with_conditioning=with_conditioning,
+ num_head_channels=num_head_channels[-1],
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ dropout_cattn=dropout_cattn,
+ )
+
+ # up
+ self.up_blocks = nn.ModuleList([])
+ reversed_block_out_channels = list(reversed(num_channels))
+ reversed_num_res_blocks = list(reversed(num_res_blocks))
+ reversed_attention_levels = list(reversed(attention_levels))
+ reversed_num_head_channels = list(reversed(num_head_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i in range(len(reversed_block_out_channels)):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(num_channels) - 1)]
+
+ is_final_block = i == len(num_channels) - 1
+
+ up_block = get_up_block(
+ spatial_dims=spatial_dims,
+ in_channels=input_channel,
+ prev_output_channel=prev_output_channel,
+ out_channels=output_channel,
+ temb_channels=new_time_embed_dim,
+ num_res_blocks=reversed_num_res_blocks[i] + 1,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_upsample=not is_final_block,
+ resblock_updown=resblock_updown,
+ with_attn=(reversed_attention_levels[i] and not with_conditioning),
+ with_cross_attn=(reversed_attention_levels[i] and with_conditioning),
+ num_head_channels=reversed_num_head_channels[i],
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ dropout_cattn=dropout_cattn,
+ )
+
+ self.up_blocks.append(up_block)
+
+ # out
+ self.out = nn.Sequential(
+ nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels[0], eps=norm_eps, affine=True),
+ nn.SiLU(),
+ zero_module(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=num_channels[0],
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ ),
+ )
+
+ def _create_embedding_module(self, input_dim, embed_dim):
+ model = nn.Sequential(nn.Linear(input_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim, embed_dim))
+ return model
+
+ def _get_time_and_class_embedding(self, x, timesteps, class_labels):
+ t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=x.dtype)
+ emb = self.time_embed(t_emb)
+
+ if self.num_class_embeds is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+ class_emb = self.class_embedding(class_labels)
+ class_emb = class_emb.to(dtype=x.dtype)
+ emb += class_emb
+ return emb
+
+ def _get_input_embeddings(self, emb, top_index, bottom_index, spacing):
+ if self.include_top_region_index_input:
+ _emb = self.top_region_index_layer(top_index)
+ emb = torch.cat((emb, _emb), dim=1)
+ if self.include_bottom_region_index_input:
+ _emb = self.bottom_region_index_layer(bottom_index)
+ emb = torch.cat((emb, _emb), dim=1)
+ if self.include_spacing_input:
+ _emb = self.spacing_layer(spacing)
+ emb = torch.cat((emb, _emb), dim=1)
+ return emb
+
+ def _apply_down_blocks(self, h, emb, context, down_block_additional_residuals):
+ if context is not None and self.with_conditioning is False:
+ raise ValueError("model should have with_conditioning = True if context is provided")
+ down_block_res_samples: list[torch.Tensor] = [h]
+ for downsample_block in self.down_blocks:
+ h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
+ down_block_res_samples.extend(res_samples)
+
+ # Additional residual connections for ControlNets
+ if down_block_additional_residuals is not None:
+ new_down_block_res_samples: list[torch.Tensor] = []
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample += down_block_additional_residual
+ new_down_block_res_samples.append(down_block_res_sample)
+
+ down_block_res_samples = new_down_block_res_samples
+ return h, down_block_res_samples
+
+ def _apply_up_blocks(self, h, emb, context, down_block_res_samples):
+ for upsample_block in self.up_blocks:
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+ h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)
+
+ return h
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ timesteps: torch.Tensor,
+ context: torch.Tensor | None = None,
+ class_labels: torch.Tensor | None = None,
+ down_block_additional_residuals: tuple[torch.Tensor] | None = None,
+ mid_block_additional_residual: torch.Tensor | None = None,
+ top_region_index_tensor: torch.Tensor | None = None,
+ bottom_region_index_tensor: torch.Tensor | None = None,
+ spacing_tensor: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Forward pass through the UNet model.
+
+ Args:
+ x: Input tensor of shape (N, C, SpatialDims).
+ timesteps: Timestep tensor of shape (N,).
+ context: Context tensor of shape (N, 1, ContextDim).
+ class_labels: Class labels tensor of shape (N,).
+ down_block_additional_residuals: Additional residual tensors for down blocks of shape (N, C, FeatureMapsDims).
+ mid_block_additional_residual: Additional residual tensor for mid block of shape (N, C, FeatureMapsDims).
+ top_region_index_tensor: Tensor representing top region index of shape (N, 4).
+ bottom_region_index_tensor: Tensor representing bottom region index of shape (N, 4).
+ spacing_tensor: Tensor representing spacing of shape (N, 3).
+
+ Returns:
+ A tensor representing the output of the UNet model.
+ """
+
+ emb = self._get_time_and_class_embedding(x, timesteps, class_labels)
+ emb = self._get_input_embeddings(emb, top_region_index_tensor, bottom_region_index_tensor, spacing_tensor)
+ h = self.conv_in(x)
+ h, _updated_down_block_res_samples = self._apply_down_blocks(h, emb, context, down_block_additional_residuals)
+ h = self.middle_block(h, emb, context)
+
+ # Additional residual connections for ControlNets
+ if mid_block_additional_residual is not None:
+ h += mid_block_additional_residual
+
+ h = self._apply_up_blocks(h, emb, context, _updated_down_block_res_samples)
+ h = self.out(h)
+ h_tensor: torch.Tensor = convert_to_tensor(h)
+ return h_tensor
diff --git a/monai/apps/nnunet/nnunetv2_runner.py b/monai/apps/nnunet/nnunetv2_runner.py
index 44b3c24256..8a10849904 100644
--- a/monai/apps/nnunet/nnunetv2_runner.py
+++ b/monai/apps/nnunet/nnunetv2_runner.py
@@ -22,6 +22,7 @@
from monai.apps.nnunet.utils import analyze_data, create_new_data_copy, create_new_dataset_json
from monai.bundle import ConfigParser
from monai.utils import ensure_tuple, optional_import
+from monai.utils.misc import run_cmd
load_pickle, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="load_pickle")
join, _ = optional_import("batchgenerators.utilities.file_and_folder_operations", name="join")
@@ -495,65 +496,64 @@ def train_single_model(self, config: Any, fold: int, gpu_id: tuple | list | int
fold: fold of the 5-fold cross-validation. Should be an int between 0 and 4.
gpu_id: an integer to select the device to use, or a tuple/list of GPU device indices used for multi-GPU
training (e.g., (0,1)). Default: 0.
- from nnunetv2.run.run_training import run_training
kwargs: this optional parameter allows you to specify additional arguments in
- ``nnunetv2.run.run_training.run_training``. Currently supported args are
- - plans_identifier: custom plans identifier. Default: "nnUNetPlans".
- - pretrained_weights: path to nnU-Net checkpoint file to be used as pretrained model. Will only be
- used when actually training. Beta. Use with caution. Default: False.
- - use_compressed_data: True to use compressed data for training. Reading compressed data is much
- more CPU and (potentially) RAM intensive and should only be used if you know what you are
- doing. Default: False.
- - continue_training: continue training from latest checkpoint. Default: False.
- - only_run_validation: True to run the validation only. Requires training to have finished.
- Default: False.
- - disable_checkpointing: True to disable checkpointing. Ideal for testing things out and you
- don't want to flood your hard drive with checkpoints. Default: False.
+ ``nnunetv2.run.run_training.run_training_entry``.
+
+ Currently supported args are:
+
+ - p: custom plans identifier. Default: "nnUNetPlans".
+ - pretrained_weights: path to nnU-Net checkpoint file to be used as pretrained model. Will only be
+ used when actually training. Beta. Use with caution. Default: False.
+ - use_compressed: True to use compressed data for training. Reading compressed data is much
+ more CPU and (potentially) RAM intensive and should only be used if you know what you are
+ doing. Default: False.
+ - c: continue training from latest checkpoint. Default: False.
+ - val: True to run the validation only. Requires training to have finished.
+ Default: False.
+ - disable_checkpointing: True to disable checkpointing. Ideal for testing things out and you
+ don't want to flood your hard drive with checkpoints. Default: False.
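+
+ For example, assuming default runner settings, ``train_single_model("3d_fullres", fold=0, gpu_id=0)``
+ builds and runs a command roughly of the form
+ ``CUDA_VISIBLE_DEVICES=0 nnUNetv2_train <dataset_name_or_id> 3d_fullres 0 -tr <trainer_class_name> -num_gpus 1``,
+ with ``--npz`` appended when ``export_validation_probabilities`` is enabled.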
"""
if "num_gpus" in kwargs:
kwargs.pop("num_gpus")
logger.warning("please use gpu_id to set the GPUs to use")
- if "trainer_class_name" in kwargs:
- kwargs.pop("trainer_class_name")
+ if "tr" in kwargs:
+ kwargs.pop("tr")
logger.warning("please specify the `trainer_class_name` in the __init__ of `nnUNetV2Runner`.")
- if "export_validation_probabilities" in kwargs:
- kwargs.pop("export_validation_probabilities")
+ if "npz" in kwargs:
+ kwargs.pop("npz")
logger.warning("please specify the `export_validation_probabilities` in the __init__ of `nnUNetV2Runner`.")
+ cmd = self.train_single_model_command(config, fold, gpu_id, kwargs)
+ run_cmd(cmd, shell=True)
+
+ def train_single_model_command(self, config, fold, gpu_id, kwargs):
if isinstance(gpu_id, (tuple, list)):
if len(gpu_id) > 1:
gpu_ids_str = ""
for _i in range(len(gpu_id)):
gpu_ids_str += f"{gpu_id[_i]},"
- os.environ["CUDA_VISIBLE_DEVICES"] = gpu_ids_str[:-1]
+ device_setting = f"CUDA_VISIBLE_DEVICES={gpu_ids_str[:-1]}"
else:
- os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id[0]}"
- else:
- os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_id}"
-
- from nnunetv2.run.run_training import run_training
-
- if isinstance(gpu_id, int) or len(gpu_id) == 1:
- run_training(
- dataset_name_or_id=self.dataset_name_or_id,
- configuration=config,
- fold=fold,
- trainer_class_name=self.trainer_class_name,
- export_validation_probabilities=self.export_validation_probabilities,
- **kwargs,
- )
+ device_setting = f"CUDA_VISIBLE_DEVICES={gpu_id[0]}"
else:
- run_training(
- dataset_name_or_id=self.dataset_name_or_id,
- configuration=config,
- fold=fold,
- num_gpus=len(gpu_id),
- trainer_class_name=self.trainer_class_name,
- export_validation_probabilities=self.export_validation_probabilities,
- **kwargs,
- )
+ device_setting = f"CUDA_VISIBLE_DEVICES={gpu_id}"
+ num_gpus = 1 if isinstance(gpu_id, int) or len(gpu_id) == 1 else len(gpu_id)
+
+ cmd = (
+ f"{device_setting} nnUNetv2_train "
+ + f"{self.dataset_name_or_id} {config} {fold} "
+ + f"-tr {self.trainer_class_name} -num_gpus {num_gpus}"
+ )
+ if self.export_validation_probabilities:
+ cmd += " --npz"
+ for _key, _value in kwargs.items():
+ if _key == "p" or _key == "pretrained_weights":
+ cmd += f" -{_key} {_value}"
+ else:
+ cmd += f" --{_key} {_value}"
+ return cmd
def train(
self,
@@ -637,15 +637,7 @@ def train_parallel_cmd(
if _config in ensure_tuple(configs):
for _i in range(self.num_folds):
the_device = gpu_id_for_all[_index % n_devices] # type: ignore
- cmd = (
- "python -m monai.apps.nnunet nnUNetV2Runner train_single_model "
- + f"--input_config '{self.input_config_or_dict}' --work_dir '{self.work_dir}' "
- + f"--config '{_config}' --fold {_i} --gpu_id {the_device} "
- + f"--trainer_class_name {self.trainer_class_name} "
- + f"--export_validation_probabilities {self.export_validation_probabilities}"
- )
- for _key, _value in kwargs.items():
- cmd += f" --{_key} {_value}"
+ cmd = self.train_single_model_command(_config, _i, the_device, kwargs)
all_cmds[-1][the_device].append(cmd)
_index += 1
return all_cmds
diff --git a/monai/apps/nuclick/transforms.py b/monai/apps/nuclick/transforms.py
index f22ea764be..4828bd2e5a 100644
--- a/monai/apps/nuclick/transforms.py
+++ b/monai/apps/nuclick/transforms.py
@@ -24,7 +24,7 @@
measure, _ = optional_import("skimage.measure")
morphology, _ = optional_import("skimage.morphology")
-distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")
+distance_transform_cdt, _ = optional_import("scipy.ndimage", name="distance_transform_cdt")
class NuclickKeys(StrEnum):
diff --git a/monai/apps/pathology/engines/utils.py b/monai/apps/pathology/engines/utils.py
index 02249ed640..87ca0f8e76 100644
--- a/monai/apps/pathology/engines/utils.py
+++ b/monai/apps/pathology/engines/utils.py
@@ -11,7 +11,8 @@
from __future__ import annotations
-from typing import Any, Sequence
+from collections.abc import Sequence
+from typing import Any
import torch
diff --git a/monai/apps/pathology/inferers/inferer.py b/monai/apps/pathology/inferers/inferer.py
index 71259ca7df..392cba221f 100644
--- a/monai/apps/pathology/inferers/inferer.py
+++ b/monai/apps/pathology/inferers/inferer.py
@@ -11,7 +11,8 @@
from __future__ import annotations
-from typing import Any, Callable, Sequence
+from collections.abc import Sequence
+from typing import Any, Callable
import numpy as np
import torch
diff --git a/monai/apps/pathology/metrics/lesion_froc.py b/monai/apps/pathology/metrics/lesion_froc.py
index f4bf51ab28..138488348a 100644
--- a/monai/apps/pathology/metrics/lesion_froc.py
+++ b/monai/apps/pathology/metrics/lesion_froc.py
@@ -11,7 +11,8 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Iterable
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Any
import numpy as np
diff --git a/monai/apps/pathology/transforms/post/array.py b/monai/apps/pathology/transforms/post/array.py
index 99e94f89c0..561ed3ae20 100644
--- a/monai/apps/pathology/transforms/post/array.py
+++ b/monai/apps/pathology/transforms/post/array.py
@@ -12,7 +12,8 @@
from __future__ import annotations
import warnings
-from typing import Callable, Sequence
+from collections.abc import Sequence
+from typing import Callable
import numpy as np
import torch
@@ -28,12 +29,12 @@
SobelGradients,
)
from monai.transforms.transform import Transform
-from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique
+from monai.transforms.utils_pytorch_numpy_unification import max, maximum, min, sum, unique, where
from monai.utils import TransformBackends, convert_to_numpy, optional_import
from monai.utils.misc import ensure_tuple_rep
from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
-label, _ = optional_import("scipy.ndimage.measurements", name="label")
+label, _ = optional_import("scipy.ndimage", name="label")
disk, _ = optional_import("skimage.morphology", name="disk")
opening, _ = optional_import("skimage.morphology", name="opening")
watershed, _ = optional_import("skimage.segmentation", name="watershed")
@@ -162,7 +163,8 @@ def __call__(self, prob_map: NdarrayOrTensor) -> NdarrayOrTensor:
pred = label(pred)[0]
if self.remove_small_objects is not None:
pred = self.remove_small_objects(pred)
- pred[pred > 0] = 1
+ pred_indices = np.where(pred > 0)
+ pred[pred_indices] = 1
return convert_to_dst_type(pred, prob_map, dtype=self.dtype)[0]
@@ -338,7 +340,8 @@ def __call__(self, mask: NdarrayOrTensor, instance_border: NdarrayOrTensor) -> N
instance_border = instance_border >= self.threshold # uncertain area
marker = mask - convert_to_dst_type(instance_border, mask)[0] # certain foreground
- marker[marker < 0] = 0
+ marker_indices = where(marker < 0)
+ marker[marker_indices] = 0 # type: ignore[index]
marker = self.postprocess_fn(marker)
marker = convert_to_numpy(marker)
@@ -379,6 +382,7 @@ def _generate_contour_coord(self, current: np.ndarray, previous: np.ndarray) ->
"""
p_delta = (current[0] - previous[0], current[1] - previous[1])
+ row, col = -1, -1
if p_delta in ((0.0, 1.0), (0.5, 0.5), (1.0, 0.0)):
row = int(current[0] + 0.5)
@@ -634,7 +638,7 @@ def __call__( # type: ignore
seg_map_crop = convert_to_dst_type(seg_map_crop == instance_id, type_map_crop, dtype=bool)[0]
- inst_type = type_map_crop[seg_map_crop]
+ inst_type = type_map_crop[seg_map_crop] # type: ignore[index]
type_list, type_pixels = unique(inst_type, return_counts=True)
type_list = list(zip(type_list, type_pixels))
type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
diff --git a/monai/apps/pathology/utils.py b/monai/apps/pathology/utils.py
index d3ebe0a7a6..3aa0bfab86 100644
--- a/monai/apps/pathology/utils.py
+++ b/monai/apps/pathology/utils.py
@@ -33,10 +33,10 @@ def compute_multi_instance_mask(mask: np.ndarray, threshold: float) -> Any:
"""
neg = 255 - mask * 255
- distance = ndimage.morphology.distance_transform_edt(neg)
+ distance = ndimage.distance_transform_edt(neg)
binary = distance < threshold
- filled_image = ndimage.morphology.binary_fill_holes(binary)
+ filled_image = ndimage.binary_fill_holes(binary)
multi_instance_mask = measure.label(filled_image, connectivity=2)
return multi_instance_mask
diff --git a/monai/apps/tcia/utils.py b/monai/apps/tcia/utils.py
index 5524b488e9..f023cdbc87 100644
--- a/monai/apps/tcia/utils.py
+++ b/monai/apps/tcia/utils.py
@@ -12,7 +12,7 @@
from __future__ import annotations
import os
-from typing import Iterable
+from collections.abc import Iterable
import monai
from monai.config.type_definitions import PathLike
diff --git a/monai/apps/utils.py b/monai/apps/utils.py
index db541923b5..c2e17d3247 100644
--- a/monai/apps/utils.py
+++ b/monai/apps/utils.py
@@ -135,7 +135,9 @@ def check_hash(filepath: PathLike, val: str | None = None, hash_type: str = "md5
logger.info(f"Expected {hash_type} is None, skip {hash_type} check for file {filepath}.")
return True
actual_hash_func = look_up_option(hash_type.lower(), SUPPORTED_HASH_TYPES)
- actual_hash = actual_hash_func()
+
+ actual_hash = actual_hash_func(usedforsecurity=False) # allows checks on FIPS enabled machines
+
try:
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(1024 * 1024), b""):
diff --git a/monai/apps/vista3d/__init__.py b/monai/apps/vista3d/__init__.py
new file mode 100644
index 0000000000..1e97f89407
--- /dev/null
+++ b/monai/apps/vista3d/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/monai/apps/vista3d/inferer.py b/monai/apps/vista3d/inferer.py
new file mode 100644
index 0000000000..8f622ef6cd
--- /dev/null
+++ b/monai/apps/vista3d/inferer.py
@@ -0,0 +1,177 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import copy
+from collections.abc import Sequence
+from typing import Any
+
+import torch
+
+from monai.data.meta_tensor import MetaTensor
+from monai.utils import optional_import
+
+tqdm, _ = optional_import("tqdm", name="tqdm")
+
+__all__ = ["point_based_window_inferer"]
+
+
+def point_based_window_inferer(
+ inputs: torch.Tensor | MetaTensor,
+ roi_size: Sequence[int],
+ predictor: torch.nn.Module,
+ point_coords: torch.Tensor,
+ point_labels: torch.Tensor,
+ class_vector: torch.Tensor | None = None,
+ prompt_class: torch.Tensor | None = None,
+ prev_mask: torch.Tensor | MetaTensor | None = None,
+ point_start: int = 0,
+ center_only: bool = True,
+ margin: int = 5,
+ **kwargs: Any,
+) -> torch.Tensor:
+ """
+ Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image.
+ The inferer algorithm crops the input image into patches that centered at the point sets, which is followed by
+ patch inference and average output stitching, and finally returns the segmented mask.
+
+ Args:
+ inputs: [1CHWD], input image to be processed.
+ roi_size: the spatial window size for inferences.
+ When its components have None or non-positives, the corresponding inputs dimension will be used.
+ if the components of the `roi_size` are non-positive values, the transform will use the
+ corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
+ to `(32, 64)` if the second spatial dimension size of img is `64`.
+ predictor: the model. For VISTA3D, the output is [B, 1, H, W, D], which needs to be transposed to
+ [1, B, H, W, D]; add transpose=True to kwargs for VISTA3D.
+ point_coords: [B, N, 3]. Point coordinates for B foreground objects, each has N points.
+ point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes.
+ 2/3 means negative/positive points for special supported classes (e.g. tumor, vessel).
+ class_vector: [B]. Used for class-head automatic segmentation. Can be None value.
+ prompt_class: [B]. The same as class_vector, representing the point class and informing the point head about
+ supported classes or zero-shot classes; not used for automatic segmentation. If None, the point head defaults
+ to supported-class segmentation.
+ prev_mask: [1, B, H, W, D]. The value is before sigmoid. An optional tensor of previously segmented masks.
+ point_start: only use points starting from this number. All points before this number are used to generate
+ prev_mask. This is used to avoid re-calculating the points from previous iterations if prev_mask is given.
+ center_only: for each point, only crop the patch centered at this point. If False, crop 3 patches for each point.
+ margin: if center_only is False, the distance from the point to the patch boundary.
+ Returns:
+ stitched_output: [1, B, H, W, D]. The value is before sigmoid.
+ Notice: The function only supports SINGLE OBJECT INFERENCE with B=1.
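+
+ Example (a minimal sketch; ``image`` is assumed to be a [1, C, H, W, D] tensor, ``model`` a VISTA3D-style
+ network, and the coordinates are illustrative):
+
+ .. code-block:: python
+
+ point_coords = torch.tensor([[[64, 64, 32], [70, 60, 32]]])  # [B=1, N=2, 3]
+ point_labels = torch.tensor([[1, 1]])  # two positive clicks
+ logits = point_based_window_inferer(
+ inputs=image,
+ roi_size=(128, 128, 64),
+ predictor=model,
+ point_coords=point_coords,
+ point_labels=point_labels,
+ transpose=True,
+ )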
+ """
+ if not point_coords.shape[0] == 1:
+ raise ValueError("Only supports single object point click.")
+ if not len(inputs.shape) == 5:
+ raise ValueError("Input image should be 5D.")
+ image, pad = _pad_previous_mask(copy.deepcopy(inputs), roi_size)
+ point_coords = point_coords + torch.tensor([pad[-2], pad[-4], pad[-6]]).to(point_coords.device)
+ prev_mask = _pad_previous_mask(copy.deepcopy(prev_mask), roi_size)[0] if prev_mask is not None else None
+ stitched_output = None
+ for p in point_coords[0][point_start:]:
+ lx_, rx_ = _get_window_idx(p[0], roi_size[0], image.shape[-3], center_only=center_only, margin=margin)
+ ly_, ry_ = _get_window_idx(p[1], roi_size[1], image.shape[-2], center_only=center_only, margin=margin)
+ lz_, rz_ = _get_window_idx(p[2], roi_size[2], image.shape[-1], center_only=center_only, margin=margin)
+ for i in range(len(lx_)):
+ for j in range(len(ly_)):
+ for k in range(len(lz_)):
+ lx, rx, ly, ry, lz, rz = (lx_[i], rx_[i], ly_[j], ry_[j], lz_[k], rz_[k])
+ unravel_slice = [
+ slice(None),
+ slice(None),
+ slice(int(lx), int(rx)),
+ slice(int(ly), int(ry)),
+ slice(int(lz), int(rz)),
+ ]
+ batch_image = image[unravel_slice]
+ output = predictor(
+ batch_image,
+ point_coords=point_coords,
+ point_labels=point_labels,
+ class_vector=class_vector,
+ prompt_class=prompt_class,
+ patch_coords=[unravel_slice],
+ prev_mask=prev_mask,
+ **kwargs,
+ )
+ if stitched_output is None:
+ stitched_output = torch.zeros(
+ [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
+ )
+ stitched_mask = torch.zeros(
+ [1, output.shape[1], image.shape[-3], image.shape[-2], image.shape[-1]], device="cpu"
+ )
+ stitched_output[unravel_slice] += output.to("cpu")
+ stitched_mask[unravel_slice] = 1
+ # voxels where stitched_mask is 0 become NaN after this division
+ stitched_output = stitched_output / stitched_mask
+ # revert padding
+ stitched_output = stitched_output[
+ :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
+ ]
+ stitched_mask = stitched_mask[
+ :, :, pad[4] : image.shape[-3] - pad[5], pad[2] : image.shape[-2] - pad[3], pad[0] : image.shape[-1] - pad[1]
+ ]
+ if prev_mask is not None:
+ prev_mask = prev_mask[
+ :,
+ :,
+ pad[4] : image.shape[-3] - pad[5],
+ pad[2] : image.shape[-2] - pad[3],
+ pad[0] : image.shape[-1] - pad[1],
+ ]
+ prev_mask = prev_mask.to("cpu") # type: ignore
+ # for voxels not covered by any patch, fall back to the previous mask
+ stitched_output[stitched_mask < 1] = prev_mask[stitched_mask < 1]
+ if isinstance(inputs, torch.Tensor):
+ inputs = MetaTensor(inputs)
+ if not hasattr(stitched_output, "meta"):
+ stitched_output = MetaTensor(stitched_output, affine=inputs.meta["affine"], meta=inputs.meta)
+ return stitched_output
+
+
+def _get_window_idx_c(p: int, roi: int, s: int) -> tuple[int, int]:
+ """Helper function to get the window index."""
+ if p - roi // 2 < 0:
+ left, right = 0, roi
+ elif p + roi // 2 > s:
+ left, right = s - roi, s
+ else:
+ left, right = int(p) - roi // 2, int(p) + roi // 2
+ return left, right
+
+
+def _get_window_idx(p: int, roi: int, s: int, center_only: bool = True, margin: int = 5) -> tuple[list[int], list[int]]:
+ """Get the window index."""
+ left, right = _get_window_idx_c(p, roi, s)
+ if center_only:
+ return [left], [right]
+ left_most = max(0, p - roi + margin)
+ right_most = min(s, p + roi - margin)
+ left_list = [left_most, right_most - roi, left]
+ right_list = [left_most + roi, right_most, right]
+ return left_list, right_list
+
+
+def _pad_previous_mask(
+ inputs: torch.Tensor | MetaTensor, roi_size: Sequence[int], padvalue: int = 0
+) -> tuple[torch.Tensor | MetaTensor, list[int]]:
+ """Helper function to pad inputs."""
+ pad_size = []
+ for k in range(len(inputs.shape) - 1, 1, -1):
+ diff = max(roi_size[k - 2] - inputs.shape[k], 0)
+ half = diff // 2
+ pad_size.extend([half, diff - half])
+ if any(pad_size):
+ inputs = torch.nn.functional.pad(inputs, pad=pad_size, mode="constant", value=padvalue) # type: ignore
+ return inputs, pad_size
diff --git a/monai/apps/vista3d/sampler.py b/monai/apps/vista3d/sampler.py
new file mode 100644
index 0000000000..17b2d34911
--- /dev/null
+++ b/monai/apps/vista3d/sampler.py
@@ -0,0 +1,179 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import copy
+import random
+from collections.abc import Callable, Sequence
+from typing import Any
+
+import numpy as np
+import torch
+from torch import Tensor
+
+ENABLE_SPECIAL = True
+SPECIAL_INDEX = (23, 24, 25, 26, 27, 57, 128)
+MERGE_LIST = {
+ 1: [25, 26], # hepatic tumor and vessel merge into liver
+ 4: [24], # pancreatic tumor merge into pancreas
+ 132: [57], # overlap with trachea merge into airway
+}
+
+__all__ = ["sample_prompt_pairs"]
+
+
+def _get_point_label(id: int) -> tuple[int, int]:
+ if id in SPECIAL_INDEX and ENABLE_SPECIAL:
+ return 2, 3
+ else:
+ return 0, 1
+
+
+def sample_prompt_pairs(
+ labels: Tensor,
+ label_set: Sequence[int],
+ max_prompt: int | None = None,
+ max_foreprompt: int | None = None,
+ max_backprompt: int = 1,
+ max_point: int = 20,
+ include_background: bool = False,
+ drop_label_prob: float = 0.2,
+ drop_point_prob: float = 0.2,
+ point_sampler: Callable | None = None,
+ **point_sampler_kwargs: Any,
+) -> tuple[Tensor | None, Tensor | None, Tensor | None, Tensor | None]:
+ """
+ Sample training pairs for VISTA3D training.
+
+ Args:
+ labels: [1, 1, H, W, D], ground truth labels.
+ label_set: the label list for the specific dataset. Note that if 0 is included in label_set,
+ it will be added to the automatic-branch training. It is recommended to remove 0 from label_set
+ when training on multiple partially labeled datasets, and to add 0 when finetuning on a specific dataset,
+ because a region labeled 0 in one partially labeled dataset may contain foreground from another dataset.
+ max_prompt: maximum total number of prompts, including foreground and background.
+ max_foreprompt: maximum number of foreground prompts.
+ max_backprompt: maximum number of background prompts.
+ max_point: maximum number of points for each object.
+ include_background: whether to include 0 in the training prompt. If included, background 0 is treated
+ the same as foreground and points will be sampled for it. Set to True only if the user wants to segment
+ background 0 with point clicks; otherwise keep it False.
+ drop_label_prob: probability to drop label prompt.
+ drop_point_prob: probability to drop point prompt.
+ point_sampler: sampler to augment masks with supervoxel.
+ point_sampler_kwargs: arguments for point_sampler.
+
+ Returns:
+ tuple:
+ - label_prompt (Tensor | None): Tensor of shape [B, 1] containing the classes used for
+ training automatic segmentation.
+ - point (Tensor | None): Tensor of shape [B, N, 3] representing the corresponding points
+ for each class. Note that background label prompts require matching points as well
+ (e.g., [0, 0, 0] is used).
+ - point_label (Tensor | None): Tensor of shape [B, N] representing the corresponding point
+ labels for each point (negative or positive). -1 is used for padding the background
+ label prompt and will be ignored.
+ - prompt_class (Tensor | None): Tensor of shape [B, 1], exactly the same as label_prompt
+ for label indexing during training. If label_prompt is None, prompt_class is used to
+ identify point classes.
+
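+ Example (a minimal sketch; ``gt_labels`` is an assumed [1, 1, H, W, D] ground-truth label tensor):
+
+ .. code-block:: python
+
+ label_prompt, point, point_label, prompt_class = sample_prompt_pairs(
+ gt_labels, label_set=[1, 2, 3, 4], max_point=5
+ )
+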
+ """
+
+ # class label number
+ if not labels.shape[0] == 1:
+ raise ValueError("only support batch size 1")
+ labels = labels[0, 0]
+ device = labels.device
+ unique_labels = labels.unique().cpu().numpy().tolist()
+ if include_background:
+ unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)))
+ else:
+ unique_labels = list(set(unique_labels) - (set(unique_labels) - set(label_set)) - {0})
+ background_labels = list(set(label_set) - set(unique_labels))
+ # during training, balance background and foreground prompts
+ if max_backprompt is not None:
+ if len(background_labels) > max_backprompt:
+ random.shuffle(background_labels)
+ background_labels = background_labels[:max_backprompt]
+
+ if max_foreprompt is not None:
+ if len(unique_labels) > max_foreprompt:
+ random.shuffle(unique_labels)
+ unique_labels = unique_labels[:max_foreprompt]
+
+ if max_prompt is not None:
+ if len(unique_labels) + len(background_labels) > max_prompt:
+ if len(unique_labels) > max_prompt:
+ unique_labels = random.sample(unique_labels, max_prompt)
+ background_labels = []
+ else:
+ background_labels = random.sample(background_labels, max_prompt - len(unique_labels))
+ _point = []
+ _point_label = []
+ # if use regular sampling
+ if point_sampler is None:
+ num_p = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))) + 1)
+ num_n = min(max_point, int(np.abs(random.gauss(mu=0, sigma=max_point // 2))))
+ for id in unique_labels:
+ neg_id, pos_id = _get_point_label(id)
+ plabels = labels == int(id)
+ nlabels = ~plabels
+ plabelpoints = torch.nonzero(plabels)
+ nlabelpoints = torch.nonzero(nlabels)
+ # final sampled positive points
+ num_pa = min(len(plabelpoints), num_p)
+ # final sampled negative points
+ num_na = min(len(nlabelpoints), num_n)
+ _point.append(
+ torch.stack(
+ random.choices(plabelpoints, k=num_pa)
+ + random.choices(nlabelpoints, k=num_na)
+ + [torch.tensor([0, 0, 0], device=device)] * (num_p + num_n - num_pa - num_na)
+ )
+ )
+ _point_label.append(
+ torch.tensor([pos_id] * num_pa + [neg_id] * num_na + [-1] * (num_p + num_n - num_pa - num_na)).to(
+ device
+ )
+ )
+ for _ in background_labels:
+ # pad the background labels
+ _point.append(torch.zeros(num_p + num_n, 3).to(device)) # all 0
+ _point_label.append(torch.zeros(num_p + num_n).to(device) - 1) # -1 not a point
+ else:
+ _point, _point_label = point_sampler(unique_labels, **point_sampler_kwargs)
+ for _ in background_labels:
+ # pad the background labels
+ _point.append(torch.zeros(len(_point_label[0]), 3).to(device)) # all 0
+ _point_label.append(torch.zeros(len(_point_label[0])).to(device) - 1) # -1 not a point
+ if len(unique_labels) == 0 and len(background_labels) == 0:
+ # if max_backprompt is 0 and len(unique_labels) == 0, there is no effective prompt and the iteration must
+ # be skipped. Handle this in trainer.
+ label_prompt, point, point_label, prompt_class = None, None, None, None
+ else:
+ label_prompt = torch.tensor(unique_labels + background_labels).unsqueeze(-1).to(device).long()
+ point = torch.stack(_point)
+ point_label = torch.stack(_point_label)
+ prompt_class = copy.deepcopy(label_prompt)
+ if random.uniform(0, 1) < drop_label_prob and len(unique_labels) > 0:
+ label_prompt = None
+ # If label prompt is dropped, there is no need to pad with points with label -1.
+ pad = len(background_labels)
+ point = point[: len(point) - pad] # type: ignore
+ point_label = point_label[: len(point_label) - pad]
+ prompt_class = prompt_class[: len(prompt_class) - pad]
+ else:
+ if random.uniform(0, 1) < drop_point_prob:
+ point = None
+ point_label = None
+ return label_prompt, point, point_label, prompt_class
diff --git a/monai/apps/vista3d/transforms.py b/monai/apps/vista3d/transforms.py
new file mode 100644
index 0000000000..bd7fb19493
--- /dev/null
+++ b/monai/apps/vista3d/transforms.py
@@ -0,0 +1,224 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import warnings
+from collections.abc import Sequence
+
+import numpy as np
+import torch
+
+from monai.config import DtypeLike, KeysCollection
+from monai.transforms import MapLabelValue
+from monai.transforms.transform import MapTransform
+from monai.transforms.utils import keep_components_with_positive_points
+from monai.utils import look_up_option
+
+__all__ = ["VistaPreTransformd", "VistaPostTransformd", "Relabeld"]
+
+
+def _get_name_to_index_mapping(labels_dict: dict | None) -> dict:
+ """get the label name to index mapping"""
+ name_to_index_mapping = {}
+ if labels_dict is not None:
+ name_to_index_mapping = {v.lower(): int(k) for k, v in labels_dict.items()}
+ return name_to_index_mapping
+
+
+def _convert_name_to_index(name_to_index_mapping: dict, label_prompt: list | None) -> list | None:
+ """convert the label name to index"""
+ if label_prompt is not None and isinstance(label_prompt, list):
+ converted_label_prompt = []
+ # for new class, add to the mapping
+ for l in label_prompt:
+ if isinstance(l, str) and not l.isdigit():
+ if l.lower() not in name_to_index_mapping:
+ name_to_index_mapping[l.lower()] = len(name_to_index_mapping)
+ for l in label_prompt:
+ if isinstance(l, (int, str)):
+ converted_label_prompt.append(
+ name_to_index_mapping.get(l.lower(), int(l) if l.isdigit() else 0) if isinstance(l, str) else int(l)
+ )
+ else:
+ converted_label_prompt.append(l)
+ return converted_label_prompt
+ return label_prompt
+
+
+class VistaPreTransformd(MapTransform):
+ def __init__(
+ self,
+ keys: KeysCollection,
+ allow_missing_keys: bool = False,
+ special_index: Sequence[int] = (25, 26, 27, 28, 29, 117),
+ labels_dict: dict | None = None,
+ subclass: dict | None = None,
+ ) -> None:
+ """
+ Pre-transform for Vista3d.
+
+ It performs two functionalities:
+
+ 1. If the label prompt indicates that the points belong to a special class (defined by special_index,
+ e.g. tumors, vessels), convert the point labels from 0 (negative), 1 (positive) to the special values
+ 2 (negative), 3 (positive).
+
+ 2. If the label prompt is among the keys of subclass, convert it to its subclasses defined by subclass[key],
+ e.g. the "lung" label is converted to ["left lung", "right lung"].
+
+ `label_prompt` is a list of length B of int values (or label names), and `point_labels` is a list of length B,
+ where each element is a list of N point labels.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ special_index: the index that defines the special class.
+ labels_dict: an optional dictionary that maps label indices (string keys) to label names, used to convert
+ name prompts to indices.
+ subclass: a dictionary that maps a label prompt to its subclasses.
+ allow_missing_keys: don't raise exception if key is missing.
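+
+ Example (a minimal sketch; ``img``, the label names and the indices are illustrative):
+
+ .. code-block:: python
+
+ pre = VistaPreTransformd(
+ keys="image",
+ labels_dict={"1": "liver", "25": "tumor"},
+ subclass={"1": [25, 26]},
+ )
+ data = pre({"image": img, "label_prompt": ["tumor"], "point_labels": [[1, 0]]})
+ # "tumor" is converted to index 25, which is a special index,
+ # so the point labels become [[3, 2]]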
+ """
+ super().__init__(keys, allow_missing_keys)
+ self.special_index = special_index
+ self.subclass = subclass
+ self.name_to_index_mapping = _get_name_to_index_mapping(labels_dict)
+
+ def __call__(self, data):
+ label_prompt = data.get("label_prompt", None)
+ point_labels = data.get("point_labels", None)
+ # convert the label name to index if needed
+ label_prompt = _convert_name_to_index(self.name_to_index_mapping, label_prompt)
+ try:
+ # The evaluator will check prompt. The invalid prompt will be skipped here and captured by evaluator.
+ if self.subclass is not None and label_prompt is not None:
+ _label_prompt = []
+ subclass_keys = list(map(int, self.subclass.keys()))
+ for i in range(len(label_prompt)):
+ if label_prompt[i] in subclass_keys:
+ _label_prompt.extend(self.subclass[str(label_prompt[i])])
+ else:
+ _label_prompt.append(label_prompt[i])
+ data["label_prompt"] = _label_prompt
+ if label_prompt is not None and point_labels is not None:
+ if label_prompt[0] in self.special_index:
+ point_labels = np.array(point_labels)
+ point_labels[point_labels == 0] = 2
+ point_labels[point_labels == 1] = 3
+ point_labels = point_labels.tolist()
+ data["point_labels"] = point_labels
+ except Exception:
+ # There is specific requirements for `label_prompt` and `point_labels`.
+ # If B > 1 or `label_prompt` is in subclass_keys, `point_labels` must be None.
+ # Those formatting errors should be captured later.
+ warnings.warn("VistaPreTransformd failed to transform label prompt or point labels.")
+
+ return data
+
+
+class VistaPostTransformd(MapTransform):
+ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
+ """
+ Post-transform for Vista3d. It converts the model output logits into final segmentation masks.
+ If `label_prompt` is None, the output will be thresholded to sequential indexes [0, 1, 2, ...];
+ otherwise the indexes will be [0, label_prompt[0], label_prompt[1], ...].
+ If `label_prompt` is None while `points` are provided, the model will perform postprocessing to remove
+ regions that do not contain positive points.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ allow_missing_keys: don't raise exception if key is missing.
+
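+ Example (a minimal sketch; ``logits`` is an assumed [B, H, W, D] model output and ``label_prompt`` an assumed
+ tensor of B class indices, none of them 0):
+
+ .. code-block:: python
+
+ post = VistaPostTransformd(keys="pred")
+ data = post({"pred": logits, "label_prompt": label_prompt})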
+ """
+ super().__init__(keys, allow_missing_keys)
+
+ def __call__(self, data):
+ """data["label_prompt"] should not contain 0"""
+ for keys in self.keys:
+ if keys in data:
+ pred = data[keys]
+ object_num = pred.shape[0]
+ device = pred.device
+ if data.get("label_prompt", None) is None and data.get("points", None) is not None:
+ pred = keep_components_with_positive_points(
+ pred.unsqueeze(0),
+ point_coords=data.get("points").to(device),
+ point_labels=data.get("point_labels").to(device),
+ )[0]
+ pred[pred < 0] = 0.0
+ # if it's multichannel, perform argmax
+ if object_num > 1:
+ # concatenate background channel. Make sure the user did not provide 0 as prompt.
+ is_bk = torch.all(pred <= 0, dim=0, keepdim=True)
+ pred = pred.argmax(0).unsqueeze(0).float() + 1.0
+ pred[is_bk] = 0.0
+ else:
+ # AsDiscrete will remove NaN
+ # pred = monai.transforms.AsDiscrete(threshold=0.5)(pred)
+ pred[pred > 0] = 1.0
+ if "label_prompt" in data and data["label_prompt"] is not None:
+ pred += 0.5 # inplace mapping to avoid cloning pred
+ label_prompt = data["label_prompt"].to(device) # Ensure label_prompt is on the same device
+ for i in range(1, object_num + 1):
+ frac = i + 0.5
+ pred[pred == frac] = label_prompt[i - 1].to(pred.dtype)
+ pred[pred == 0.5] = 0.0
+ data[keys] = pred
+ return data
+
+
+class Relabeld(MapTransform):
+ def __init__(
+ self,
+ keys: KeysCollection,
+ label_mappings: dict[str, list[tuple[int, int]]],
+ dtype: DtypeLike = np.int16,
+ dataset_key: str = "dataset_name",
+ allow_missing_keys: bool = False,
+ ) -> None:
+ """
+ Remap the voxel labels in the input data dictionary based on the specified mapping.
+
+ This list of local -> global label mappings will be applied to each input `data[keys]`.
+ If `data[dataset_key]` is not in `label_mappings`, `label_mappings['default']` will be used.
+ If `label_mappings[data[dataset_key]]` is None, no relabeling will be performed.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ label_mappings: a dictionary that specifies how local dataset class indices are mapped to the
+ global class indices. The dictionary keys are dataset names and the values are lists of
+ (local label, global label) pairs. This list of local -> global label mappings
+ will be applied to each input `data[keys]`. If `data[dataset_key]` is not in `label_mappings`,
+ `label_mappings['default']` will be used. If `label_mappings[data[dataset_key]]` is None,
+ no relabeling will be performed. Please set `label_mappings={}` to completely skip this transform.
+ dtype: convert the output data to this dtype, defaults to np.int16.
+ dataset_key: key to get the dataset name from the data dictionary, default to "dataset_name".
+ allow_missing_keys: don't raise exception if key is missing.
+
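+ Example (a minimal sketch; the mapping values, ``label_tensor`` and the dataset name are illustrative):
+
+ .. code-block:: python
+
+ relabel = Relabeld(keys="label", label_mappings={"default": [(1, 100), (2, 200)]})
+ out = relabel({"label": label_tensor, "dataset_name": "default"})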
+ """
+ super().__init__(keys, allow_missing_keys)
+ self.mappers = {}
+ self.dataset_key = dataset_key
+ for name, mapping in label_mappings.items():
+ self.mappers[name] = MapLabelValue(
+ orig_labels=[int(pair[0]) for pair in mapping],
+ target_labels=[int(pair[1]) for pair in mapping],
+ dtype=dtype,
+ )
+
+ def __call__(self, data):
+ d = dict(data)
+ dataset_name = d.get(self.dataset_key, "default")
+ _m = look_up_option(dataset_name, self.mappers, default=None)
+ if _m is None:
+ return d
+ for key in self.key_iterator(d):
+ d[key] = _m(d[key])
+ return d
diff --git a/monai/auto3dseg/analyzer.py b/monai/auto3dseg/analyzer.py
index 37f3faea21..e60327b551 100644
--- a/monai/auto3dseg/analyzer.py
+++ b/monai/auto3dseg/analyzer.py
@@ -470,7 +470,7 @@ def __call__(self, data: Mapping[Hashable, MetaTensor]) -> dict[Hashable, MetaTe
unique_label = unique(ndas_label)
if isinstance(ndas_label, (MetaTensor, torch.Tensor)):
- unique_label = unique_label.data.cpu().numpy()
+ unique_label = unique_label.data.cpu().numpy() # type: ignore[assignment]
unique_label = unique_label.astype(np.int16).tolist()
diff --git a/monai/auto3dseg/utils.py b/monai/auto3dseg/utils.py
index 58b900d410..211f23c415 100644
--- a/monai/auto3dseg/utils.py
+++ b/monai/auto3dseg/utils.py
@@ -407,7 +407,7 @@ def _prepare_cmd_default(cmd: str, cmd_prefix: str | None = None, **kwargs: Any)
Args:
cmd: the command or script to run in the distributed job.
- cmd_prefix: the command prefix to run the script, e.g., "python", "python -m", "python3", "/opt/conda/bin/python3.8 ".
+ cmd_prefix: the command prefix to run the script, e.g., "python", "python -m", "python3", "/opt/conda/bin/python3.9 ".
kwargs: the keyword arguments to be passed to the script.
Returns:
diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py
index 829036af6f..1d9920a230 100644
--- a/monai/bundle/config_parser.py
+++ b/monai/bundle/config_parser.py
@@ -20,7 +20,7 @@
from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem
from monai.bundle.reference_resolver import ReferenceResolver
-from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY
+from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY, MACRO_KEY, merge_kv
from monai.config import PathLike
from monai.utils import ensure_tuple, look_up_option, optional_import
from monai.utils.misc import CheckKeyDuplicatesYamlLoader, check_key_duplicates
@@ -118,7 +118,7 @@ def __init__(
self.ref_resolver = ReferenceResolver()
if config is None:
config = {self.meta_key: {}}
- self.set(config=config)
+ self.set(config=self.ref_resolver.normalize_meta_id(config))
def __repr__(self):
return f"{self.config}"
@@ -221,7 +221,7 @@ def set(self, config: Any, id: str = "", recursive: bool = True) -> None:
if isinstance(conf_, dict) and k not in conf_:
conf_[k] = {}
conf_ = conf_[k if isinstance(conf_, dict) else int(k)]
- self[ReferenceResolver.normalize_id(id)] = config
+ self[ReferenceResolver.normalize_id(id)] = self.ref_resolver.normalize_meta_id(config)
def update(self, pairs: dict[str, Any]) -> None:
"""
@@ -423,8 +423,10 @@ def load_config_files(cls, files: PathLike | Sequence[PathLike] | dict, **kwargs
if isinstance(files, str) and not Path(files).is_file() and "," in files:
files = files.split(",")
for i in ensure_tuple(files):
- for k, v in (cls.load_config_file(i, **kwargs)).items():
- parser[k] = v
+ config_dict = cls.load_config_file(i, **kwargs)
+ for k, v in config_dict.items():
+ merge_kv(parser, k, v)
+
return parser.get() # type: ignore
@classmethod
diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py
index b36f2cc4a5..df69b021e1 100644
--- a/monai/bundle/reference_resolver.py
+++ b/monai/bundle/reference_resolver.py
@@ -13,11 +13,11 @@
import re
import warnings
-from collections.abc import Sequence
-from typing import Any, Iterator
+from collections.abc import Iterator, Sequence
+from typing import Any
from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem
-from monai.bundle.utils import ID_REF_KEY, ID_SEP_KEY
+from monai.bundle.utils import DEPRECATED_ID_MAPPING, ID_REF_KEY, ID_SEP_KEY
from monai.utils import allow_missing_reference, look_up_option
__all__ = ["ReferenceResolver"]
@@ -202,6 +202,23 @@ def normalize_id(cls, id: str | int) -> str:
"""
return str(id).replace("#", cls.sep) # backward compatibility `#` is the old separator
+ def normalize_meta_id(self, config: Any) -> Any:
+ """
+ Update deprecated identifiers in `config` using `DEPRECATED_ID_MAPPING`.
+ This will replace names that are marked as deprecated with their replacement.
+
+ Args:
+ config: input config to be updated.
+ """
+ if isinstance(config, dict):
+ for _id, _new_id in DEPRECATED_ID_MAPPING.items():
+ if _id in config.keys():
+ warnings.warn(
+ f"Detected deprecated name '{_id}' in configuration file, replacing with '{_new_id}'."
+ )
+ config[_new_id] = config.pop(_id)
+ return config
+
@classmethod
def split_id(cls, id: str | int, last: bool = False) -> list[str]:
"""
diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py
index 2565a3cf64..131c78008b 100644
--- a/monai/bundle/scripts.py
+++ b/monai/bundle/scripts.py
@@ -16,7 +16,9 @@
import os
import re
import warnings
+import zipfile
from collections.abc import Mapping, Sequence
+from functools import partial
from pathlib import Path
from pydoc import locate
from shutil import copyfile
@@ -26,13 +28,13 @@
import torch
from torch.cuda import is_available
-from monai.apps.mmars.mmars import _get_all_ngc_models
+from monai._version import get_versions
from monai.apps.utils import _basename, download_url, extractall, get_logger
from monai.bundle.config_item import ConfigComponent
from monai.bundle.config_parser import ConfigParser
-from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA
+from monai.bundle.utils import DEFAULT_INFERENCE, DEFAULT_METADATA, merge_kv
from monai.bundle.workflows import BundleWorkflow, ConfigWorkflow
-from monai.config import IgniteInfo, PathLike
+from monai.config import PathLike
from monai.data import load_net_with_metadata, save_net_with_metadata
from monai.networks import (
convert_to_onnx,
@@ -43,6 +45,7 @@
save_state,
)
from monai.utils import (
+ IgniteInfo,
check_parent_dir,
deprecated_arg,
ensure_tuple,
@@ -66,6 +69,9 @@
DEFAULT_DOWNLOAD_SOURCE = os.environ.get("BUNDLE_DOWNLOAD_SRC", "monaihosting")
PPRINT_CONFIG_N = 5
+MONAI_HOSTING_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
+NGC_BASE_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit"
+
def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kwargs: Any) -> dict:
"""
@@ -100,7 +106,7 @@ def update_kwargs(args: str | dict | None = None, ignore_none: bool = True, **kw
if isinstance(v, dict) and isinstance(args_.get(k), dict):
args_[k] = update_kwargs(args_[k], ignore_none, **v)
else:
- args_[k] = v
+ merge_kv(args_, k, v)
return args_
@@ -168,12 +174,19 @@ def _get_git_release_url(repo_owner: str, repo_name: str, tag_name: str, filenam
def _get_ngc_bundle_url(model_name: str, version: str) -> str:
- return f"https://api.ngc.nvidia.com/v2/models/nvidia/monaitoolkit/{model_name.lower()}/versions/{version}/zip"
+ return f"{NGC_BASE_URL}/{model_name.lower()}/versions/{version}/zip"
+
+
+def _get_ngc_private_base_url(repo: str) -> str:
+ return f"https://api.ngc.nvidia.com/v2/{repo}/models"
+
+
+def _get_ngc_private_bundle_url(model_name: str, version: str, repo: str) -> str:
+ return f"{_get_ngc_private_base_url(repo)}/{model_name.lower()}/versions/{version}/zip"
def _get_monaihosting_bundle_url(model_name: str, version: str) -> str:
- monaihosting_root_path = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
- return f"{monaihosting_root_path}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"
+ return f"{MONAI_HOSTING_BASE_URL}/{model_name.lower()}/versions/{version}/files/{model_name}_v{version}.zip"
def _download_from_github(repo: str, download_path: Path, filename: str, progress: bool = True) -> None:
@@ -206,10 +219,15 @@ def _remove_ngc_prefix(name: str, prefix: str = "monai_") -> str:
def _download_from_ngc(
- download_path: Path, filename: str, version: str, remove_prefix: str | None, progress: bool
+ download_path: Path,
+ filename: str,
+ version: str,
+ prefix: str = "monai_",
+ remove_prefix: str | None = "monai_",
+ progress: bool = True,
) -> None:
# ensure prefix is contained
- filename = _add_ngc_prefix(filename)
+ filename = _add_ngc_prefix(filename, prefix=prefix)
url = _get_ngc_bundle_url(model_name=filename, version=version)
filepath = download_path / f"{filename}_v{version}.zip"
if remove_prefix:
@@ -219,29 +237,175 @@ def _download_from_ngc(
extractall(filepath=filepath, output_dir=extract_path, has_base=True)
+def _download_from_ngc_private(
+ download_path: Path,
+ filename: str,
+ version: str,
+ repo: str,
+ prefix: str = "monai_",
+ remove_prefix: str | None = "monai_",
+ headers: dict | None = None,
+) -> None:
+ # ensure prefix is contained
+ filename = _add_ngc_prefix(filename, prefix=prefix)
+ request_url = _get_ngc_private_bundle_url(model_name=filename, version=version, repo=repo)
+ if has_requests:
+ headers = {} if headers is None else headers
+ response = requests_get(request_url, headers=headers)
+ response.raise_for_status()
+ else:
+ raise ValueError("NGC API requires requests package. Please install it.")
+
+ os.makedirs(download_path, exist_ok=True)
+ zip_path = download_path / f"{filename}_v{version}.zip"
+ with open(zip_path, "wb") as f:
+ f.write(response.content)
+ logger.info(f"Downloading: {zip_path}.")
+ if remove_prefix:
+ filename = _remove_ngc_prefix(filename, prefix=remove_prefix)
+ extract_path = download_path / f"{filename}"
+ with zipfile.ZipFile(zip_path, "r") as z:
+ z.extractall(extract_path)
+ logger.info(f"Writing into directory: {extract_path}.")
+
+
+def _get_ngc_token(api_key, retry=0):
+ """Try to connect to NGC."""
+ url = "https://authn.nvidia.com/token?service=ngc"
+ headers = {"Accept": "application/json", "Authorization": "ApiKey " + api_key}
+ if has_requests:
+ response = requests_get(url, headers=headers)
+ if not response.ok:
+ # retry 3 times, if failed, raise an error.
+ if retry < 3:
+                logger.info(f"Retrying ({retry + 1}/3) to GET {url}.")
+                return _get_ngc_token(api_key, retry + 1)
+ raise RuntimeError("NGC API response is not ok. Failed to get token.")
+ else:
+ token = response.json()["token"]
+ return token
+
+
def _get_latest_bundle_version_monaihosting(name):
- url = "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting"
- full_url = f"{url}/{name.lower()}"
+ full_url = f"{MONAI_HOSTING_BASE_URL}/{name.lower()}"
requests_get, has_requests = optional_import("requests", name="get")
if has_requests:
resp = requests_get(full_url)
resp.raise_for_status()
else:
- raise ValueError("NGC API requires requests package. Please install it.")
+ raise ValueError("NGC API requires requests package. Please install it.")
model_info = json.loads(resp.text)
return model_info["model"]["latestVersionIdStr"]
-def _get_latest_bundle_version(source: str, name: str, repo: str) -> dict[str, list[str] | str] | Any | None:
+def _examine_monai_version(monai_version: str) -> tuple[bool, str]:
+ """Examine if the package version is compatible with the MONAI version in the metadata."""
+ version_dict = get_versions()
+ package_version = version_dict.get("version", "0+unknown")
+ if package_version == "0+unknown":
+ return False, "Package version is not available. Skipping version check."
+ if monai_version == "0+unknown":
+ return False, "MONAI version is not specified in the bundle. Skipping version check."
+ # treat rc versions as the same as the release version
+ package_version = re.sub(r"rc\d.*", "", package_version)
+ monai_version = re.sub(r"rc\d.*", "", monai_version)
+ if package_version < monai_version:
+ return (
+ False,
+ f"Your MONAI version is {package_version}, but the bundle is built on MONAI version {monai_version}.",
+ )
+ return True, ""
+
+
+def _check_monai_version(bundle_dir: PathLike, name: str) -> None:
+    """Get the `monai_version` from metadata.json and warn if the installed `monai` package is older than it."""
+ metadata_file = Path(bundle_dir) / name / "configs" / "metadata.json"
+ if not metadata_file.exists():
+ logger.warning(f"metadata file not found in {metadata_file}.")
+ return
+ with open(metadata_file) as f:
+ metadata = json.load(f)
+ is_compatible, msg = _examine_monai_version(metadata.get("monai_version", "0+unknown"))
+ if not is_compatible:
+ logger.warning(msg)
+
+
+def _list_latest_versions(data: dict, max_versions: int = 3) -> list[str]:
+ """
+ Extract the latest versions from the data dictionary.
+
+ Args:
+ data: the data dictionary.
+ max_versions: the maximum number of versions to return.
+
+ Returns:
+ versions of the latest models in the reverse order of creation date, e.g. ['1.0.0', '0.9.0', '0.8.0'].
+ """
+ # Check if the data is a dictionary and it has the key 'modelVersions'
+ if not isinstance(data, dict) or "modelVersions" not in data:
+ raise ValueError("The data is not a dictionary or it does not have the key 'modelVersions'.")
+
+ # Extract the list of model versions
+ model_versions = data["modelVersions"]
+
+ if (
+ not isinstance(model_versions, list)
+ or len(model_versions) == 0
+ or "createdDate" not in model_versions[0]
+ or "versionId" not in model_versions[0]
+ ):
+ raise ValueError(
+            "The model versions must be provided as a non-empty list of entries containing the keys 'createdDate' and 'versionId'."
+ )
+
+ # Sort the versions by the 'createdDate' in descending order
+ sorted_versions = sorted(model_versions, key=lambda x: x["createdDate"], reverse=True)
+ return [v["versionId"] for v in sorted_versions[:max_versions]]
+
+
+def _get_latest_bundle_version_ngc(name: str, repo: str | None = None, headers: dict | None = None) -> str:
+ base_url = _get_ngc_private_base_url(repo) if repo else NGC_BASE_URL
+ version_endpoint = base_url + f"/{name.lower()}/versions/"
+
+ if not has_requests:
+ raise ValueError("requests package is required, please install it.")
+
+ version_header = {"Accept-Encoding": "gzip, deflate"} # Excluding 'zstd' to fit NGC requirements
+ if headers:
+ version_header.update(headers)
+ resp = requests_get(version_endpoint, headers=version_header)
+ resp.raise_for_status()
+ model_info = json.loads(resp.text)
+ latest_versions = _list_latest_versions(model_info)
+
+ for version in latest_versions:
+ file_endpoint = base_url + f"/{name.lower()}/versions/{version}/files/configs/metadata.json"
+        resp = requests_get(file_endpoint, headers=headers)
+        resp.raise_for_status()
+        metadata = json.loads(resp.text)
+ # if the package version is not available or the model is compatible with the package version
+ is_compatible, _ = _examine_monai_version(metadata["monai_version"])
+ if is_compatible:
+ if version != latest_versions[0]:
+ logger.info(f"Latest version is {latest_versions[0]}, but the compatible version is {version}.")
+ return version
+
+ # if no compatible version is found, return the latest version
+ return latest_versions[0]
+
+
+def _get_latest_bundle_version(
+ source: str, name: str, repo: str, **kwargs: Any
+) -> dict[str, list[str] | str] | Any | None:
if source == "ngc":
name = _add_ngc_prefix(name)
- model_dict = _get_all_ngc_models(name)
- for v in model_dict.values():
- if v["name"] == name:
- return v["latest"]
- return None
+ return _get_latest_bundle_version_ngc(name)
elif source == "monaihosting":
return _get_latest_bundle_version_monaihosting(name)
+ elif source == "ngc_private":
+ headers = kwargs.pop("headers", {})
+ name = _add_ngc_prefix(name)
+ return _get_latest_bundle_version_ngc(name, repo=repo, headers=headers)
elif source == "github":
repo_owner, repo_name, tag_name = repo.split("/")
return get_bundle_versions(name, repo=f"{repo_owner}/{repo_name}", tag=tag_name)["latest_version"]
@@ -308,6 +472,9 @@ def download(
# Execute this module as a CLI entry, and download bundle via URL:
            python -m monai.bundle download --name <bundle_name> --url <url>
+ # Execute this module as a CLI entry, and download bundle from ngc_private with latest version:
+            python -m monai.bundle download --name <bundle_name> --source "ngc_private" --bundle_dir "./" --repo "org/org_name"
+
# Set default args of `run` in a JSON / YAML file, help to record and simplify the command line.
# Other args still can override the default args at runtime.
# The content of the JSON / YAML file is a dictionary. For example:
@@ -328,14 +495,17 @@ def download(
Default is `bundle` subfolder under `torch.hub.get_dir()`.
source: storage location name. This argument is used when `url` is `None`.
In default, the value is achieved from the environment variable BUNDLE_DOWNLOAD_SRC, and
- it should be "ngc", "monaihosting", "github", or "huggingface_hub".
+ it should be "ngc", "monaihosting", "github", "ngc_private", or "huggingface_hub".
+            If source is "ngc_private", you need to specify the `NGC_API_KEY` environment variable.
repo: repo name. This argument is used when `url` is `None` and `source` is "github" or "huggingface_hub".
If `source` is "github", it should be in the form of "repo_owner/repo_name/release_tag".
If `source` is "huggingface_hub", it should be in the form of "repo_owner/repo_name".
+ If `source` is "ngc_private", it should be in the form of "org/org_name" or "org/org_name/team/team_name",
+            or you can specify the environment variables NGC_ORG and NGC_TEAM.
url: url to download the data. If not `None`, data will be downloaded directly
and `source` will not be checked.
If `name` is `None`, filename is determined by `monai.apps.utils._basename(url)`.
- remove_prefix: This argument is used when `source` is "ngc". Currently, all ngc bundles
+        remove_prefix: This argument is used when `source` is "ngc" or "ngc_private". Currently, all NGC bundles
            have the ``monai_`` prefix, which does not exist in their model-zoo counterparts. To keep
            the two sources consistent, removing the prefix is necessary.
            Therefore, if specified, the downloaded folder name will have the prefix removed.
@@ -363,11 +533,18 @@ def download(
bundle_dir_ = _process_bundle_dir(bundle_dir_)
if repo_ is None:
- repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
- if len(repo_.split("/")) != 3 and source_ != "huggingface_hub":
- raise ValueError("repo should be in the form of `repo_owner/repo_name/release_tag`.")
+ org_ = os.getenv("NGC_ORG", None)
+ team_ = os.getenv("NGC_TEAM", None)
+ if org_ is not None and source_ == "ngc_private":
+ repo_ = f"org/{org_}/team/{team_}" if team_ is not None else f"org/{org_}"
+ else:
+ repo_ = "Project-MONAI/model-zoo/hosting_storage_v1"
+ if len(repo_.split("/")) not in (2, 4) and source_ == "ngc_private":
+ raise ValueError(f"repo should be in the form of `org/org_name/team/team_name` or `org/org_name`, got {repo_}.")
+ if len(repo_.split("/")) != 3 and source_ == "github":
+ raise ValueError(f"repo should be in the form of `repo_owner/repo_name/release_tag`, got {repo_}.")
elif len(repo_.split("/")) != 2 and source_ == "huggingface_hub":
- raise ValueError("Hugging Face Hub repo should be in the form of `repo_owner/repo_name`")
+ raise ValueError(f"Hugging Face Hub repo should be in the form of `repo_owner/repo_name`, got {repo_}.")
if url_ is not None:
if name_ is not None:
filepath = bundle_dir_ / f"{name_}.zip"
@@ -376,14 +553,22 @@ def download(
download_url(url=url_, filepath=filepath, hash_val=None, progress=progress_)
extractall(filepath=filepath, output_dir=bundle_dir_, has_base=True)
else:
+ headers = {}
if name_ is None:
raise ValueError(f"To download from source: {source_}, `name` must be provided.")
+        if source_ == "ngc_private":
+ api_key = os.getenv("NGC_API_KEY", None)
+ if api_key is None:
+ raise ValueError("API key is required for ngc_private source.")
+ else:
+ token = _get_ngc_token(api_key)
+ headers = {"Authorization": f"Bearer {token}"}
+
if version_ is None:
- version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_)
+ version_ = _get_latest_bundle_version(source=source_, name=name_, repo=repo_, headers=headers)
if source_ == "github":
- if version_ is not None:
- name_ = "_v".join([name_, version_])
- _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_, progress=progress_)
+ name_ver = "_v".join([name_, version_]) if version_ is not None else name_
+ _download_from_github(repo=repo_, download_path=bundle_dir_, filename=name_ver, progress=progress_)
elif source_ == "monaihosting":
_download_from_monaihosting(download_path=bundle_dir_, filename=name_, version=version_, progress=progress_)
elif source_ == "ngc":
@@ -394,6 +579,15 @@ def download(
remove_prefix=remove_prefix_,
progress=progress_,
)
+ elif source_ == "ngc_private":
+ _download_from_ngc_private(
+ download_path=bundle_dir_,
+ filename=name_,
+ version=version_,
+ remove_prefix=remove_prefix_,
+ repo=repo_,
+ headers=headers,
+ )
elif source_ == "huggingface_hub":
extract_path = os.path.join(bundle_dir_, name_)
huggingface_hub.snapshot_download(repo_id=repo_, revision=version_, local_dir=extract_path)
@@ -403,6 +597,8 @@ def download(
f"got source: {source_}."
)
+ _check_monai_version(bundle_dir_, name_)
+
@deprecated_arg("net_name", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
@deprecated_arg("net_kwargs", since="1.2", removed="1.5", msg_suffix="please use ``model`` instead.")
@@ -778,10 +974,19 @@ def run(
https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
Default to None.
tracking: if not None, enable the experiment tracking at runtime with optionally configurable and extensible.
- if "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings,
- if other string, treat it as file path to load the tracking settings.
- if `dict`, treat it as tracking settings.
- will patch the target config content with `tracking handlers` and the top-level items of `configs`.
+ If "mlflow", will add `MLFlowHandler` to the parsed bundle with default tracking settings where a set of
+ common parameters shown below will be added and can be passed through the `override` parameter of this method.
+
+            - ``"output_dir"``: the path to save mlflow tracking outputs locally, default to "<bundle root>/eval".
+ - ``"tracking_uri"``: uri to save mlflow tracking outputs, default to "/output_dir/mlruns".
+ - ``"experiment_name"``: experiment name for this run, default to "monai_experiment".
+ - ``"run_name"``: the name of current run.
+ - ``"save_execute_config"``: whether to save the executed config files. It can be `False`, `/path/to/artifacts`
+              or `True`. If set to `True`, will save to the default path "<bundle root>/eval". Default to `True`.
+
+ If other string, treat it as file path to load the tracking settings.
+ If `dict`, treat it as tracking settings.
+ Will patch the target config content with `tracking handlers` and the top-level items of `configs`.
for detailed usage examples, please check the tutorial:
https://github.com/Project-MONAI/tutorials/blob/main/experiment_management/bundle_integrate_mlflow.ipynb.
args_file: a JSON or YAML file to provide default values for `run_id`, `meta_file`,
@@ -1052,6 +1257,7 @@ def verify_net_in_out(
def _export(
converter: Callable,
+ saver: Callable,
parser: ConfigParser,
net_id: str,
filepath: str,
@@ -1066,6 +1272,8 @@ def _export(
Args:
converter: a callable object that takes a torch.nn.module and kwargs as input and
converts the module to another type.
+ saver: a callable object that accepts the converted model to save, a filepath to save to, meta values
+ (extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input.
parser: a ConfigParser of the bundle to be converted.
net_id: ID name of the network component in the parser, it must be `torch.nn.Module`.
filepath: filepath to export, if filename has no extension, it becomes `.ts`.
@@ -1105,14 +1313,9 @@ def _export(
# add .json extension to all extra files which are always encoded as JSON
extra_files = {k + ".json": v for k, v in extra_files.items()}
- save_net_with_metadata(
- jit_obj=net,
- filename_prefix_or_stream=filepath,
- include_config_vals=False,
- append_timestamp=False,
- meta_values=parser.get().pop("_meta_", None),
- more_extra_files=extra_files,
- )
+ meta_values = parser.get().pop("_meta_", None)
+ saver(net, filepath, meta_values=meta_values, more_extra_files=extra_files)
+
logger.info(f"exported to file: {filepath}.")
@@ -1211,17 +1414,23 @@ def onnx_export(
input_shape_ = _get_fake_input_shape(parser=parser)
inputs_ = [torch.rand(input_shape_)]
- net = parser.get_parsed_content(net_id_)
- if has_ignite:
- # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
- Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_)
- else:
- ckpt = torch.load(ckpt_file_)
- copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_])
converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
- onnx_model = convert_to_onnx(model=net, **converter_kwargs_)
- onnx.save(onnx_model, filepath_)
+
+ def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> None:
+ onnx.save(onnx_obj, filename_prefix_or_stream)
+
+ _export(
+ convert_to_onnx,
+ save_onnx,
+ parser,
+ net_id=net_id_,
+ filepath=filepath_,
+ ckpt_file=ckpt_file_,
+ config_file=config_file_,
+ key_in_ckpt=key_in_ckpt_,
+ **converter_kwargs_,
+ )
def ckpt_export(
@@ -1342,8 +1551,12 @@ def ckpt_export(
converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
# Use the given converter to convert a model and save with metadata, config content
+
+ save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)
+
_export(
convert_to_torchscript,
+ save_ts,
parser,
net_id=net_id_,
filepath=filepath_,
@@ -1376,6 +1589,8 @@ def trt_export(
"""
Export the model checkpoint to the given filepath as a TensorRT engine-based TorchScript.
Currently, this API only supports converting models whose inputs are all tensors.
+ Note: NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5.
+ Review the TensorRT Support Matrix for which GPUs are supported.
There are two ways to export a model:
1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript.
@@ -1513,8 +1728,11 @@ def trt_export(
}
converter_kwargs_.update(trt_api_parameters)
+ save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)
+
_export(
convert_to_trt,
+ save_ts,
parser,
net_id=net_id_,
filepath=filepath_,
@@ -1729,7 +1947,6 @@ def create_workflow(
"""
_args = update_kwargs(args=args_file, workflow_name=workflow_name, config_file=config_file, **kwargs)
- _log_input_summary(tag="run", args=_args)
(workflow_name, config_file) = _pop_args(
_args, workflow_name=ConfigWorkflow, config_file=None
) # the default workflow name is "ConfigWorkflow"
@@ -1753,7 +1970,7 @@ def create_workflow(
workflow_ = workflow_class(**_args)
workflow_.initialize()
-
+ _log_input_summary(tag="run", args=_args)
return workflow_
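
A hedged Python counterpart to the new `ngc_private` CLI example documented above; the bundle name, org and API key are placeholders:

```python
import os

from monai.bundle import download

os.environ["NGC_API_KEY"] = "<your NGC API key>"  # placeholder; required for the ngc_private source
os.environ["NGC_ORG"] = "<org_name>"              # placeholder; used when `repo` is not passed explicitly

# resolves the latest bundle version compatible with the installed MONAI, downloads the zip
# using a bearer token obtained from the API key, and extracts it under the bundle directory
download(name="<bundle_name>", source="ngc_private", bundle_dir="./")
```
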
diff --git a/monai/bundle/utils.py b/monai/bundle/utils.py
index b187159c89..53d619f234 100644
--- a/monai/bundle/utils.py
+++ b/monai/bundle/utils.py
@@ -13,6 +13,7 @@
import json
import os
+import warnings
import zipfile
from typing import Any
@@ -21,12 +22,21 @@
yaml, _ = optional_import("yaml")
-__all__ = ["ID_REF_KEY", "ID_SEP_KEY", "EXPR_KEY", "MACRO_KEY", "DEFAULT_MLFLOW_SETTINGS", "DEFAULT_EXP_MGMT_SETTINGS"]
+__all__ = [
+ "ID_REF_KEY",
+ "ID_SEP_KEY",
+ "EXPR_KEY",
+ "MACRO_KEY",
+ "MERGE_KEY",
+ "DEFAULT_MLFLOW_SETTINGS",
+ "DEFAULT_EXP_MGMT_SETTINGS",
+]
ID_REF_KEY = "@" # start of a reference to a ConfigItem
ID_SEP_KEY = "::" # separator for the ID of a ConfigItem
EXPR_KEY = "$" # start of a ConfigExpression
MACRO_KEY = "%" # start of a macro of a config
+MERGE_KEY = "+" # prefix indicating merge instead of override in case of multiple configs.
_conf_values = get_config_values()
@@ -36,7 +46,7 @@
"monai_version": _conf_values["MONAI"],
"pytorch_version": str(_conf_values["Pytorch"]).split("+")[0].split("a")[0], # 1.9.0a0+df837d0 or 1.13.0+cu117
"numpy_version": _conf_values["Numpy"],
- "optional_packages_version": {},
+ "required_packages_version": {},
"task": "Describe what the network predicts",
"description": "A longer description of what the network does, use context, inputs, outputs, etc.",
"authors": "Your Name Here",
@@ -113,7 +123,7 @@
"experiment_name": "monai_experiment",
"run_name": None,
# may fill it at runtime
- "execute_config": None,
+ "save_execute_config": True,
"is_not_rank0": (
"$torch.distributed.is_available() \
and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0"
@@ -125,7 +135,7 @@
"tracking_uri": "@tracking_uri",
"experiment_name": "@experiment_name",
"run_name": "@run_name",
- "artifacts": "@execute_config",
+ "artifacts": "@save_execute_config",
"iteration_log": True,
"epoch_log": True,
"tag_name": "train_loss",
@@ -148,7 +158,7 @@
"tracking_uri": "@tracking_uri",
"experiment_name": "@experiment_name",
"run_name": "@run_name",
- "artifacts": "@execute_config",
+ "artifacts": "@save_execute_config",
"iteration_log": False,
"close_on_complete": True,
},
@@ -157,6 +167,8 @@
DEFAULT_EXP_MGMT_SETTINGS = {"mlflow": DEFAULT_MLFLOW_SETTINGS} # default experiment management settings
+DEPRECATED_ID_MAPPING = {"optional_packages_version": "required_packages_version"}
+
def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any) -> Any:
"""
@@ -221,6 +233,7 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
raise ValueError(f"Cannot find config file '{full_cname}'")
ardata = archive.read(full_cname)
+ cdata = {}
if full_cname.lower().endswith("json"):
cdata = json.loads(ardata, **load_kw_args)
@@ -230,3 +243,27 @@ def load_bundle_config(bundle_path: str, *config_names: str, **load_kw_args: Any
parser.read_config(f=cdata)
return parser
+
+
+def merge_kv(args: dict | Any, k: str, v: Any) -> None:
+ """
+ Update the `args` dict-like object with the key/value pair `k` and `v`.
+ """
+ if k.startswith(MERGE_KEY):
+        # Both values associated with a `+`-prefixed key must be of `dict` or `list` type:
+        # `dict` values are merged, `list` values are concatenated.
+ id = k[1:]
+ if id in args:
+ if isinstance(v, dict) and isinstance(args[id], dict):
+ args[id].update(v)
+ elif isinstance(v, list) and isinstance(args[id], list):
+ args[id].extend(v)
+ else:
+                raise ValueError(f"config must be dict or list for key `{k}`, but got {type(v)}: {v}.")
+ else:
+ warnings.warn(f"Can't merge entry ['{k}'], '{id}' is not in target dict - copying instead.")
+ args[id] = v
+ else:
+ args[k] = v
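
A minimal, standalone sketch of the `+`-prefixed merge behaviour defined by `merge_kv` above:

```python
from monai.bundle.utils import merge_kv

args = {"transforms": ["LoadImaged"], "net": {"spatial_dims": 2}}

merge_kv(args, "+transforms", ["EnsureChannelFirstd"])  # lists are concatenated
merge_kv(args, "+net", {"in_channels": 1})              # dicts are merged
merge_kv(args, "lr", 1e-4)                              # keys without the prefix simply insert/override

# args == {"transforms": ["LoadImaged", "EnsureChannelFirstd"],
#          "net": {"spatial_dims": 2, "in_channels": 1}, "lr": 0.0001}
```

The same prefix is honoured when multiple config files are loaded through `ConfigParser.load_config_files` or when keyword arguments are merged by `update_kwargs` (see the `config_parser.py` and `scripts.py` hunks above).
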
diff --git a/monai/bundle/workflows.py b/monai/bundle/workflows.py
index da3aa30141..3ecd5dfbc5 100644
--- a/monai/bundle/workflows.py
+++ b/monai/bundle/workflows.py
@@ -11,20 +11,23 @@
from __future__ import annotations
+import json
import os
import sys
import time
from abc import ABC, abstractmethod
+from collections.abc import Sequence
from copy import copy
from logging.config import fileConfig
from pathlib import Path
-from typing import Any, Sequence
+from typing import Any
from monai.apps.utils import get_logger
from monai.bundle.config_parser import ConfigParser
from monai.bundle.properties import InferProperties, MetaProperties, TrainProperties
from monai.bundle.utils import DEFAULT_EXP_MGMT_SETTINGS, EXPR_KEY, ID_REF_KEY, ID_SEP_KEY
-from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg, deprecated_arg_default, ensure_tuple
+from monai.config import PathLike
+from monai.utils import BundleProperty, BundlePropertyConfig, deprecated_arg, ensure_tuple
__all__ = ["BundleWorkflow", "ConfigWorkflow"]
@@ -41,11 +44,15 @@ class BundleWorkflow(ABC):
workflow_type: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
- default to `None` for common workflow.
+ default to `train` for train workflow.
workflow: specifies the workflow type: "train" or "training" for a training workflow,
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
+ properties_path: the path to the JSON file of properties.
+        meta_file: filepath of the metadata file; if this is a list of file paths, their contents will be merged in order.
+ logging_file: config file for `logging` module in the program. for more details:
+ https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
"""
@@ -59,21 +66,62 @@ class BundleWorkflow(ABC):
new_name="workflow_type",
msg_suffix="please use `workflow_type` instead.",
)
- def __init__(self, workflow_type: str | None = None, workflow: str | None = None):
+ def __init__(
+ self,
+ workflow_type: str | None = None,
+ workflow: str | None = None,
+ properties_path: PathLike | None = None,
+ meta_file: str | Sequence[str] | None = None,
+ logging_file: str | None = None,
+ ):
+ if logging_file is not None:
+ if not os.path.isfile(logging_file):
+ raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.")
+ logger.info(f"Setting logging properties based on config: {logging_file}.")
+ fileConfig(logging_file, disable_existing_loggers=False)
+
+ if meta_file is not None:
+ if isinstance(meta_file, str) and not os.path.isfile(meta_file):
+ logger.error(
+ f"Cannot find the metadata config file: {meta_file}. "
+ "Please see: https://docs.monai.io/en/stable/mb_specification.html"
+ )
+ meta_file = None
+ if isinstance(meta_file, list):
+ for f in meta_file:
+ if not os.path.isfile(f):
+ logger.error(
+ f"Cannot find the metadata config file: {f}. "
+ "Please see: https://docs.monai.io/en/stable/mb_specification.html"
+ )
+ meta_file = None
+
workflow_type = workflow if workflow is not None else workflow_type
- if workflow_type is None:
+ if workflow_type is None and properties_path is None:
self.properties = copy(MetaProperties)
self.workflow_type = None
+ self.meta_file = meta_file
+ return
+ if properties_path is not None:
+ properties_path = Path(properties_path)
+ if not properties_path.is_file():
+ raise ValueError(f"Property file {properties_path} does not exist.")
+ with open(properties_path) as json_file:
+ self.properties = json.load(json_file)
+ self.workflow_type = None
+ self.meta_file = meta_file
return
- if workflow_type.lower() in self.supported_train_type:
+ if workflow_type.lower() in self.supported_train_type: # type: ignore[union-attr]
self.properties = {**TrainProperties, **MetaProperties}
self.workflow_type = "train"
- elif workflow_type.lower() in self.supported_infer_type:
+ elif workflow_type.lower() in self.supported_infer_type: # type: ignore[union-attr]
self.properties = {**InferProperties, **MetaProperties}
self.workflow_type = "infer"
else:
raise ValueError(f"Unsupported workflow type: '{workflow_type}'.")
+ self.meta_file = meta_file
+
@abstractmethod
def initialize(self, *args: Any, **kwargs: Any) -> Any:
"""
@@ -142,6 +190,13 @@ def get_workflow_type(self):
"""
return self.workflow_type
+ def get_meta_file(self):
+ """
+ Get the meta file.
+
+ """
+ return self.meta_file
+
def add_property(self, name: str, required: str, desc: str | None = None) -> None:
"""
Besides the default predefined properties, some 3rd party applications may need the bundle
@@ -185,6 +240,7 @@ class ConfigWorkflow(BundleWorkflow):
logging_file: config file for `logging` module in the program. for more details:
https://docs.python.org/3/library/logging.config.html#logging.config.fileConfig.
If None, default to "configs/logging.conf", which is commonly used for bundles in MONAI model zoo.
+ If False, the logging logic for the bundle will not be modified.
init_id: ID name of the expected config expression to initialize before running, default to "initialize".
allow a config to have no `initialize` logic and the ID.
run_id: ID name of the expected config expression to run, default to "run".
@@ -206,6 +262,7 @@ class ConfigWorkflow(BundleWorkflow):
or "infer", "inference", "eval", "evaluation" for a inference workflow,
other unsupported string will raise a ValueError.
default to `None` for common workflow.
+ properties_path: the path to the JSON file of properties.
override: id-value pairs to override or add the corresponding config content.
e.g. ``--net#input_chns 42``, ``--net %/data/other.json#net_arg``
@@ -218,58 +275,55 @@ class ConfigWorkflow(BundleWorkflow):
new_name="workflow_type",
msg_suffix="please use `workflow_type` instead.",
)
- @deprecated_arg_default("workflow_type", None, "train", since="1.2", replaced="1.4")
def __init__(
self,
config_file: str | Sequence[str],
meta_file: str | Sequence[str] | None = None,
- logging_file: str | None = None,
+ logging_file: str | bool | None = None,
init_id: str = "initialize",
run_id: str = "run",
final_id: str = "finalize",
tracking: str | dict | None = None,
- workflow_type: str | None = None,
+ workflow_type: str | None = "train",
workflow: str | None = None,
+ properties_path: PathLike | None = None,
**override: Any,
) -> None:
workflow_type = workflow if workflow is not None else workflow_type
- super().__init__(workflow_type=workflow_type)
if config_file is not None:
_config_files = ensure_tuple(config_file)
- self.config_root_path = Path(_config_files[0]).parent
+ config_root_path = Path(_config_files[0]).parent
for _config_file in _config_files:
_config_file = Path(_config_file)
- if _config_file.parent != self.config_root_path:
+ if _config_file.parent != config_root_path:
logger.warn(
- f"Not all config files are in {self.config_root_path}. If logging_file and meta_file are"
- f"not specified, {self.config_root_path} will be used as the default config root directory."
+ f"Not all config files are in {config_root_path}. If logging_file and meta_file are"
+                        f" not specified, {config_root_path} will be used as the default config root directory."
)
if not _config_file.is_file():
raise FileNotFoundError(f"Cannot find the config file: {_config_file}.")
else:
- self.config_root_path = Path("configs")
-
+ config_root_path = Path("configs")
+ meta_file = str(config_root_path / "metadata.json") if meta_file is None else meta_file
+ super().__init__(workflow_type=workflow_type, meta_file=meta_file, properties_path=properties_path)
+ self.config_root_path = config_root_path
logging_file = str(self.config_root_path / "logging.conf") if logging_file is None else logging_file
- if logging_file is not None:
- if not os.path.exists(logging_file):
+ if logging_file is False:
+ logger.warn(f"Logging file is set to {logging_file}, skipping logging.")
+ else:
+ if not os.path.isfile(logging_file):
if logging_file == str(self.config_root_path / "logging.conf"):
logger.warn(f"Default logging file in {logging_file} does not exist, skipping logging.")
else:
raise FileNotFoundError(f"Cannot find the logging config file: {logging_file}.")
else:
+ fileConfig(str(logging_file), disable_existing_loggers=False)
logger.info(f"Setting logging properties based on config: {logging_file}.")
- fileConfig(logging_file, disable_existing_loggers=False)
self.parser = ConfigParser()
self.parser.read_config(f=config_file)
- meta_file = str(self.config_root_path / "metadata.json") if meta_file is None else meta_file
- if isinstance(meta_file, str) and not os.path.exists(meta_file):
- logger.error(
- f"Cannot find the metadata config file: {meta_file}. "
- "Please see: https://docs.monai.io/en/stable/mb_specification.html"
- )
- else:
- self.parser.read_meta(f=meta_file)
+ if self.meta_file is not None:
+ self.parser.read_meta(f=self.meta_file)
# the rest key-values in the _args are to override config content
self.parser.update(pairs=override)
@@ -455,13 +509,19 @@ def patch_bundle_tracking(parser: ConfigParser, settings: dict) -> None:
parser[k] = v
# save the executed config into file
default_name = f"config_{time.strftime('%Y%m%d_%H%M%S')}.json"
- filepath = parser.get("execute_config", None)
- if filepath is None:
- if "output_dir" not in parser:
- # if no "output_dir" in the bundle config, default to "/eval"
- parser["output_dir"] = f"{EXPR_KEY}{ID_REF_KEY}bundle_root + '/eval'"
- # experiment management tools can refer to this config item to track the config info
- parser["execute_config"] = parser["output_dir"] + f" + '/{default_name}'"
- filepath = os.path.join(parser.get_parsed_content("output_dir"), default_name)
- Path(filepath).parent.mkdir(parents=True, exist_ok=True)
- parser.export_config_file(parser.get(), filepath)
+ # Users can set the `save_execute_config` to `False`, `/path/to/artifacts` or `True`.
+ # If set to False, nothing will be recorded. If set to True, the default path will be logged.
+ # If set to a file path, the given path will be logged.
+ filepath = parser.get("save_execute_config", True)
+ if filepath:
+ if isinstance(filepath, bool):
+ if "output_dir" not in parser:
+ # if no "output_dir" in the bundle config, default to "/eval"
+ parser["output_dir"] = f"{EXPR_KEY}{ID_REF_KEY}bundle_root + '/eval'"
+ # experiment management tools can refer to this config item to track the config info
+ parser["save_execute_config"] = parser["output_dir"] + f" + '/{default_name}'"
+ filepath = os.path.join(parser.get_parsed_content("output_dir"), default_name)
+ Path(filepath).parent.mkdir(parents=True, exist_ok=True)
+ parser.export_config_file(parser.get(), filepath)
+ else:
+ parser["save_execute_config"] = None
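
A short sketch of the new `ConfigWorkflow` options from the hunk above (the file paths are placeholders): `logging_file=False` leaves the process logging configuration untouched, and `properties_path` can point to a JSON file that replaces the built-in train/infer property sets.

```python
from monai.bundle import ConfigWorkflow

workflow = ConfigWorkflow(
    config_file="configs/inference.json",  # placeholder bundle config
    workflow_type="infer",
    logging_file=False,  # skip touching the logging configuration
    # properties_path="configs/properties.json",  # optionally load a custom property definition instead
)
workflow.initialize()
workflow.run()
workflow.finalize()
```
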
diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py
index a4580c741b..05842245ce 100644
--- a/monai/config/deviceconfig.py
+++ b/monai/config/deviceconfig.py
@@ -23,6 +23,8 @@
import torch
import monai
+from monai.utils.deprecate_utils import deprecated
+from monai.utils.enums import IgniteInfo as _IgniteInfo
from monai.utils.module import OptionalImportError, get_package_version, optional_import
try:
@@ -261,13 +263,11 @@ def print_debug_info(file: TextIO = sys.stdout) -> None:
print_gpu_info(file)
+@deprecated(since="1.4.0", removed="1.6.0", msg_suffix="Please use `monai.utils.enums.IgniteInfo` instead.")
class IgniteInfo:
- """
- Config information of the PyTorch ignite package.
-
- """
+ """Deprecated Import of IgniteInfo enum, which was moved to `monai.utils.enums.IgniteInfo`."""
- OPT_IMPORT_VERSION = "0.4.4"
+ OPT_IMPORT_VERSION = _IgniteInfo.OPT_IMPORT_VERSION
if __name__ == "__main__":
diff --git a/monai/config/type_definitions.py b/monai/config/type_definitions.py
index 57454a94e1..48c0547e31 100644
--- a/monai/config/type_definitions.py
+++ b/monai/config/type_definitions.py
@@ -12,7 +12,8 @@
from __future__ import annotations
import os
-from typing import Collection, Hashable, Iterable, Sequence, TypeVar, Union
+from collections.abc import Collection, Hashable, Iterable, Sequence
+from typing import TypeVar, Union
import numpy as np
import torch
diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index 531893d768..8c53338d66 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -22,6 +22,7 @@
import warnings
from collections.abc import Callable, Sequence
from copy import copy, deepcopy
+from inspect import signature
from multiprocessing.managers import ListProxy
from multiprocessing.pool import ThreadPool
from pathlib import Path
@@ -36,15 +37,7 @@
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing
-from monai.transforms import (
- Compose,
- Randomizable,
- RandomizableTrait,
- Transform,
- apply_transform,
- convert_to_contiguous,
- reset_ops_id,
-)
+from monai.transforms import Compose, Randomizable, RandomizableTrait, Transform, convert_to_contiguous, reset_ops_id
from monai.utils import MAX_SEED, convert_to_tensor, get_seed, look_up_option, min_version, optional_import
from monai.utils.misc import first
@@ -77,15 +70,19 @@ class Dataset(_TorchDataset):
}, }, }]
"""
- def __init__(self, data: Sequence, transform: Callable | None = None) -> None:
+ def __init__(self, data: Sequence, transform: Sequence[Callable] | Callable | None = None) -> None:
"""
Args:
data: input data to load and transform to generate dataset for model.
- transform: a callable data transform on input data.
-
+ transform: a callable, sequence of callables or None. If transform is not
+ a `Compose` instance, it will be wrapped in a `Compose` instance. Sequences
+ of callables are applied in order and if `None` is passed, the data is returned as is.
"""
self.data = data
- self.transform: Any = transform
+ try:
+ self.transform = Compose(transform) if not isinstance(transform, Compose) else transform
+ except Exception as e:
+ raise ValueError("`transform` must be a callable or a list of callables that is Composable") from e
def __len__(self) -> int:
return len(self.data)
@@ -95,7 +92,7 @@ def _transform(self, index: int):
Fetch single data item from `self.data`.
"""
data_i = self.data[index]
- return apply_transform(self.transform, data_i) if self.transform is not None else data_i
+ return self.transform(data_i)
def __getitem__(self, index: int | slice | Sequence[int]):
"""
@@ -264,8 +261,6 @@ def __init__(
using the cached content and with re-created transform instances.
"""
- if not isinstance(transform, Compose):
- transform = Compose(transform)
super().__init__(data=data, transform=transform)
self.cache_dir = Path(cache_dir) if cache_dir is not None else None
self.hash_func = hash_func
@@ -323,9 +318,6 @@ def _pre_transform(self, item_transformed):
random transform object
"""
- if not isinstance(self.transform, Compose):
- raise ValueError("transform must be an instance of monai.transforms.Compose.")
-
first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
@@ -346,9 +338,6 @@ def _post_transform(self, item_transformed):
the transformed element through the random transforms
"""
- if not isinstance(self.transform, Compose):
- raise ValueError("transform must be an instance of monai.transforms.Compose.")
-
first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
@@ -383,7 +372,10 @@ def _cachecheck(self, item_transformed):
if hashfile is not None and hashfile.is_file(): # cache hit
try:
- return torch.load(hashfile)
+ if "weights_only" in signature(torch.load).parameters:
+ return torch.load(hashfile, weights_only=False)
+ else:
+ return torch.load(hashfile)
except PermissionError as e:
if sys.platform != "win32":
raise e
@@ -427,7 +419,7 @@ def _transform(self, index: int):
class CacheNTransDataset(PersistentDataset):
"""
- Extension of `PersistentDataset`, tt can also cache the result of first N transforms, no matter it's random or not.
+ Extension of `PersistentDataset`, it can also cache the result of first N transforms, no matter it's random or not.
"""
@@ -501,9 +493,6 @@ def _pre_transform(self, item_transformed):
Returns:
the transformed element up to the N transform object
"""
- if not isinstance(self.transform, Compose):
- raise ValueError("transform must be an instance of monai.transforms.Compose.")
-
item_transformed = self.transform(item_transformed, end=self.cache_n_trans, threading=True)
reset_ops_id(item_transformed)
@@ -519,9 +508,6 @@ def _post_transform(self, item_transformed):
Returns:
the final transformed result
"""
- if not isinstance(self.transform, Compose):
- raise ValueError("transform must be an instance of monai.transforms.Compose.")
-
return self.transform(item_transformed, start=self.cache_n_trans)
@@ -809,8 +795,6 @@ def __init__(
Not following these recommendations may lead to runtime errors or duplicated cache across processes.
"""
- if not isinstance(transform, Compose):
- transform = Compose(transform)
super().__init__(data=data, transform=transform)
self.set_num = cache_num # tracking the user-provided `cache_num` option
self.set_rate = cache_rate # tracking the user-provided `cache_rate` option
@@ -1282,8 +1266,10 @@ def to_list(x):
data = []
for dataset in self.data:
data.extend(to_list(dataset[index]))
+
if self.transform is not None:
- data = apply_transform(self.transform, data, map_items=False) # transform the list data
+            self.transform.map_items = False  # set the Compose's map_items to False so the transform is applied to the whole list
+ data = self.transform(data)
# use tuple instead of list as the default collate_fn callback of MONAI DataLoader flattens nested lists
return tuple(data)
@@ -1432,15 +1418,11 @@ def __len__(self):
def _transform(self, index: int):
data = {k: v[index] for k, v in self.arrays.items()}
-
- if not self.transform:
- return data
-
- result = apply_transform(self.transform, data)
+ result = self.transform(data) if self.transform is not None else data
if isinstance(result, dict) or (isinstance(result, list) and isinstance(result[0], dict)):
return result
- raise AssertionError("With a dict supplied to apply_transform, should return a dict or a list of dicts.")
+ raise AssertionError("With a dict supplied to Compose, should return a dict or a list of dicts.")
class CSVDataset(Dataset):
@@ -1692,4 +1674,7 @@ def _load_meta_cache(self, meta_hash_file_name):
if meta_hash_file_name in self._meta_cache:
return self._meta_cache[meta_hash_file_name]
else:
- return torch.load(self.cache_dir / meta_hash_file_name)
+ if "weights_only" in signature(torch.load).parameters:
+ return torch.load(self.cache_dir / meta_hash_file_name, weights_only=False)
+ else:
+ return torch.load(self.cache_dir / meta_hash_file_name)
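
With the `Dataset` changes above, `transform` may be a single callable, a sequence of callables, or `None`; anything that is not already a `Compose` is wrapped in one. A quick sketch with placeholder file names:

```python
from monai.data import Dataset
from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ScaleIntensityd

files = [{"image": "img0.nii.gz"}, {"image": "img1.nii.gz"}]  # placeholder file names

# a bare list of transforms is wrapped into a Compose internally
ds = Dataset(data=files, transform=[LoadImaged("image"), EnsureChannelFirstd("image"), ScaleIntensityd("image")])
assert isinstance(ds.transform, Compose)
```
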
diff --git a/monai/data/dataset_summary.py b/monai/data/dataset_summary.py
index 769ae33b46..5b9e32afca 100644
--- a/monai/data/dataset_summary.py
+++ b/monai/data/dataset_summary.py
@@ -84,6 +84,7 @@ def collect_meta_data(self):
"""
for data in self.data_loader:
+ meta_dict = {}
if isinstance(data[self.image_key], MetaTensor):
meta_dict = data[self.image_key].meta
elif self.meta_key in data:
diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py
index 2361bb63a7..b4ae562911 100644
--- a/monai/data/image_reader.py
+++ b/monai/data/image_reader.py
@@ -1331,7 +1331,7 @@ def get_data(self, img: NrrdImage | list[NrrdImage]) -> tuple[np.ndarray, dict]:
header[MetaKeys.SPACE] = SpaceKeys.LPS # assuming LPS if not specified
header[MetaKeys.AFFINE] = header[MetaKeys.ORIGINAL_AFFINE].copy()
- header[MetaKeys.SPATIAL_SHAPE] = header["sizes"]
+ header[MetaKeys.SPATIAL_SHAPE] = header["sizes"].copy()
[header.pop(k) for k in ("sizes", "space origin", "space directions")] # rm duplicated data in header
if self.channel_dim is None: # default to "no_channel" or -1
@@ -1359,7 +1359,7 @@ def _get_affine(self, header: dict) -> np.ndarray:
x, y = direction.shape
affine_diam = min(x, y) + 1
affine: np.ndarray = np.eye(affine_diam)
- affine[:x, :y] = direction
+ affine[:x, :y] = direction.T
affine[: (affine_diam - 1), -1] = origin # len origin is always affine_diam - 1
return affine
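
The `direction.T` fix above reflects that NRRD's `space directions` field stores one row per array axis, whereas the affine stores those vectors as columns. A hedged NumPy sketch of the intended construction with made-up header values:

```python
import numpy as np

# each ROW is the world-space step vector of one array axis (illustrative values)
direction = np.array([[0.0, 0.0, 2.5], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
origin = np.array([-120.0, -120.0, 60.0])

affine = np.eye(4)
affine[:3, :3] = direction.T  # per-axis step vectors become the affine columns
affine[:3, -1] = origin
# affine @ [i, j, k, 1] now yields the physical coordinate of voxel (i, j, k)
```
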
diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py
index 0dccaa9e1c..15e6e8be15 100644
--- a/monai/data/meta_obj.py
+++ b/monai/data/meta_obj.py
@@ -13,8 +13,9 @@
import itertools
import pprint
+from collections.abc import Iterable
from copy import deepcopy
-from typing import Any, Iterable
+from typing import Any
import numpy as np
import torch
diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py
index cad0851a8e..ac171e8508 100644
--- a/monai/data/meta_tensor.py
+++ b/monai/data/meta_tensor.py
@@ -13,8 +13,9 @@
import functools
import warnings
+from collections.abc import Sequence
from copy import deepcopy
-from typing import Any, Sequence
+from typing import Any
import numpy as np
import torch
@@ -505,7 +506,7 @@ def peek_pending_rank(self):
a = self.pending_operations[-1].get(LazyAttr.AFFINE, None) if self.pending_operations else self.affine
return 1 if a is None else int(max(1, len(a) - 1))
- def new_empty(self, size, dtype=None, device=None, requires_grad=False):
+ def new_empty(self, size, dtype=None, device=None, requires_grad=False): # type: ignore[override]
"""
must be defined for deepcopy to work
@@ -580,7 +581,7 @@ def ensure_torch_and_prune_meta(
img.affine = MetaTensor.get_default_affine()
return img
- def __repr__(self):
+ def __repr__(self): # type: ignore[override]
"""
Prints a representation of the tensor.
Prepends "meta" to ``torch.Tensor.__repr__``.
diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py
index 23572dcef4..bcd5ea91a9 100644
--- a/monai/data/test_time_augmentation.py
+++ b/monai/data/test_time_augmentation.py
@@ -106,6 +106,8 @@ class TestTimeAugmentation:
mode, mean, std, vvc = tt_aug(test_data)
"""
+ __test__ = False # indicate to pytest that this class is not intended for collection
+
def __init__(
self,
transform: InvertibleTransform,
diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py
index cabf06ce89..507cf411d6 100644
--- a/monai/data/torchscript_utils.py
+++ b/monai/data/torchscript_utils.py
@@ -116,7 +116,7 @@ def load_net_with_metadata(
Returns:
Triple containing loaded object, metadata dict, and extra files dict containing other file data if present
"""
- extra_files = {f: "" for f in more_extra_files}
+ extra_files = dict.fromkeys(more_extra_files, "")
extra_files[METADATA_FILENAME] = ""
jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files)
diff --git a/monai/data/ultrasound_confidence_map.py b/monai/data/ultrasound_confidence_map.py
index 03813e7559..865e4a0a0f 100644
--- a/monai/data/ultrasound_confidence_map.py
+++ b/monai/data/ultrasound_confidence_map.py
@@ -19,9 +19,11 @@
__all__ = ["UltrasoundConfidenceMap"]
cv2, _ = optional_import("cv2")
-csc_matrix, _ = optional_import("scipy.sparse", "1.7.1", min_version, "csc_matrix")
-spsolve, _ = optional_import("scipy.sparse.linalg", "1.7.1", min_version, "spsolve")
-hilbert, _ = optional_import("scipy.signal", "1.7.1", min_version, "hilbert")
+csc_matrix, _ = optional_import("scipy.sparse", "1.12.0", min_version, "csc_matrix")
+spsolve, _ = optional_import("scipy.sparse.linalg", "1.12.0", min_version, "spsolve")
+cg, _ = optional_import("scipy.sparse.linalg", "1.12.0", min_version, "cg")
+hilbert, _ = optional_import("scipy.signal", "1.12.0", min_version, "hilbert")
+ruge_stuben_solver, _ = optional_import("pyamg", "5.0.0", min_version, "ruge_stuben_solver")
class UltrasoundConfidenceMap:
@@ -30,6 +32,9 @@ class UltrasoundConfidenceMap:
It generates a confidence map by setting source and sink points in the image and computing the probability
for random walks to reach the source for each pixel.
+ The official code is available at:
+ https://campar.in.tum.de/Main/AthanasiosKaramalisCode
+
Args:
alpha (float, optional): Alpha parameter. Defaults to 2.0.
beta (float, optional): Beta parameter. Defaults to 90.0.
@@ -37,15 +42,33 @@ class UltrasoundConfidenceMap:
mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'.
        sink_mode (str, optional): Sink mode. Defaults to 'all'. If 'mask' is selected, a mask must be provided when calling
the transform. Can be 'all', 'mid', 'min', or 'mask'.
+ use_cg (bool, optional): Use Conjugate Gradient method for solving the linear system. Defaults to False.
+ cg_tol (float, optional): Tolerance for the Conjugate Gradient method. Defaults to 1e-6.
+ Will be used only if `use_cg` is True.
+ cg_maxiter (int, optional): Maximum number of iterations for the Conjugate Gradient method. Defaults to 200.
+ Will be used only if `use_cg` is True.
"""
- def __init__(self, alpha: float = 2.0, beta: float = 90.0, gamma: float = 0.05, mode="B", sink_mode="all"):
+ def __init__(
+ self,
+ alpha: float = 2.0,
+ beta: float = 90.0,
+ gamma: float = 0.05,
+ mode="B",
+ sink_mode="all",
+ use_cg=False,
+ cg_tol=1e-6,
+ cg_maxiter=200,
+ ):
# The hyperparameters for confidence map estimation
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.mode = mode
self.sink_mode = sink_mode
+ self.use_cg = use_cg
+ self.cg_tol = cg_tol
+ self.cg_maxiter = cg_maxiter
# The precision to use for all computations
self.eps = np.finfo("float64").eps
@@ -228,17 +251,18 @@ def confidence_laplacian(self, padded_index: NDArray, padded_image: NDArray, bet
s = self.normalize(s)
# Horizontal penalty
- s[:vertical_end] += gamma
- # s[vertical_end:diagonal_end] += gamma * np.sqrt(2) # --> In the paper it is sqrt(2)
- # since the diagonal edges are longer yet does not exist in the original code
+ s[vertical_end:] += gamma
+ # Here there is a difference between the official MATLAB code and the paper
+ # on the edge penalty. We directly implement what the official code does.
# Normalize differences
s = self.normalize(s)
# Gaussian weighting function
s = -(
- (np.exp(-beta * s, dtype="float64")) + 1.0e-6
- ) # --> This epsilon changes results drastically default: 1.e-6
+ (np.exp(-beta * s, dtype="float64")) + 1e-5
+        )  # --> this epsilon changes the results drastically; the official default is 10e-6
+        # note that 10e-6 equals 1e-5 (not 1e-6), which is why the constant above differs from the previous 1.0e-6
# Create Laplacian, diagonal missing
lap = csc_matrix((s, (i, j)))
@@ -256,7 +280,14 @@ def confidence_laplacian(self, padded_index: NDArray, padded_image: NDArray, bet
return lap
def _solve_linear_system(self, lap, rhs):
- x = spsolve(lap, rhs)
+
+ if self.use_cg:
+ lap_sparse = lap.tocsr()
+ ml = ruge_stuben_solver(lap_sparse, coarse_solver="pinv")
+ m = ml.aspreconditioner(cycle="V")
+ x, _ = cg(lap, rhs, rtol=self.cg_tol, maxiter=self.cg_maxiter, M=m)
+ else:
+ x = spsolve(lap, rhs)
return x
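
A sketch of the new conjugate-gradient solver path, which additionally requires `pyamg` (>= 5.0.0) next to SciPy >= 1.12; the input is random data for illustration only, and the direct call on a 2-D array is an assumption about how the class is normally invoked:

```python
import numpy as np

from monai.data.ultrasound_confidence_map import UltrasoundConfidenceMap

image = np.random.rand(128, 64).astype(np.float64)  # placeholder B-mode frame

cm = UltrasoundConfidenceMap(alpha=2.0, beta=90.0, gamma=0.05, use_cg=True, cg_tol=1e-6, cg_maxiter=200)
confidence = cm(image)  # assumed call: per-pixel probability that a random walk reaches the transducer
```
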
diff --git a/monai/data/utils.py b/monai/data/utils.py
index 585f02ec9e..f35c5124d8 100644
--- a/monai/data/utils.py
+++ b/monai/data/utils.py
@@ -53,10 +53,6 @@
pytorch_after,
)
-if pytorch_after(1, 13):
- # import private code for reuse purposes, comment in case things break in the future
- from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map
-
pd, _ = optional_import("pandas")
DataFrame, _ = optional_import("pandas", name="DataFrame")
nib, _ = optional_import("nibabel")
@@ -454,8 +450,13 @@ def collate_meta_tensor_fn(batch, *, collate_fn_map=None):
Collate a sequence of meta tensors into a single batched metatensor. This is called by `collate_meta_tensor`
and so should not be used as a collate function directly in dataloaders.
"""
- collate_fn = collate_tensor_fn if pytorch_after(1, 13) else default_collate
- collated = collate_fn(batch) # type: ignore
+ if pytorch_after(1, 13):
+ from torch.utils.data._utils.collate import collate_tensor_fn # imported here for pylint/mypy issues
+
+ collated = collate_tensor_fn(batch)
+ else:
+ collated = default_collate(batch)
+
meta_dicts = [i.meta or TraceKeys.NONE for i in batch]
common_ = set.intersection(*[set(d.keys()) for d in meta_dicts if isinstance(d, dict)])
if common_:
@@ -496,6 +497,8 @@ def list_data_collate(batch: Sequence):
if pytorch_after(1, 13):
# needs to go here to avoid circular import
+ from torch.utils.data._utils.collate import default_collate_fn_map
+
from monai.data.meta_tensor import MetaTensor
default_collate_fn_map.update({MetaTensor: collate_meta_tensor_fn})
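For context on the deferred imports in the two hunks above: since PyTorch 1.13, `default_collate` dispatches through the extensible (but private) `default_collate_fn_map`, which is how `MetaTensor` receives its own collate function here. A hedged sketch of the same registration pattern, using a hypothetical `TaggedTensor` subclass purely for illustration:

    import torch
    from torch.utils.data import default_collate
    from torch.utils.data._utils.collate import collate_tensor_fn, default_collate_fn_map  # private API, PyTorch >= 1.13

    class TaggedTensor(torch.Tensor):  # hypothetical tensor subclass
        pass

    def collate_tagged_tensor_fn(batch, *, collate_fn_map=None):
        # reuse the stock tensor collate, then restore the subclass on the stacked result
        return collate_tensor_fn(batch).as_subclass(TaggedTensor)

    default_collate_fn_map.update({TaggedTensor: collate_tagged_tensor_fn})

    items = [torch.ones(3).as_subclass(TaggedTensor) for _ in range(4)]
    print(type(default_collate(items)))  # the batch keeps the TaggedTensor type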
@@ -924,7 +927,7 @@ def compute_shape_offset(
corners = in_affine_ @ corners
all_dist = corners_out[:-1].copy()
corners_out = corners_out[:-1] / corners_out[-1]
- out_shape = np.round(corners_out.ptp(axis=1)) if scale_extent else np.round(corners_out.ptp(axis=1) + 1.0)
+ out_shape = np.round(np.ptp(corners_out, axis=1)) if scale_extent else np.round(np.ptp(corners_out, axis=1) + 1.0)
offset = None
for i in range(corners.shape[1]):
min_corner = np.min(all_dist - all_dist[:, i : i + 1], 1)
diff --git a/monai/data/video_dataset.py b/monai/data/video_dataset.py
index be3bcf5bd5..031e85db26 100644
--- a/monai/data/video_dataset.py
+++ b/monai/data/video_dataset.py
@@ -173,15 +173,15 @@ def get_available_codecs() -> dict[str, str]:
all_codecs = {"mp4v": ".mp4", "X264": ".avi", "H264": ".mp4", "MP42": ".mp4", "MJPG": ".mjpeg", "DIVX": ".avi"}
codecs = {}
with SuppressStderr():
- writer = cv2.VideoWriter()
with tempfile.TemporaryDirectory() as tmp_dir:
for codec, ext in all_codecs.items():
+ writer = cv2.VideoWriter()
fname = os.path.join(tmp_dir, f"test{ext}")
fourcc = cv2.VideoWriter_fourcc(*codec)
noviderr = writer.open(fname, fourcc, 1, (10, 10))
if noviderr:
codecs[codec] = ext
- writer.release()
+ writer.release()
return codecs
def get_num_frames(self) -> int:
diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py
index 3488029a7a..2ee8c9d363 100644
--- a/monai/data/wsi_datasets.py
+++ b/monai/data/wsi_datasets.py
@@ -23,7 +23,7 @@
from monai.data.utils import iter_patch_position
from monai.data.wsi_reader import BaseWSIReader, WSIReader
from monai.transforms import ForegroundMask, Randomizable, apply_transform
-from monai.utils import convert_to_dst_type, ensure_tuple_rep
+from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import CommonKeys, ProbMapKeys, WSIPatchKeys
__all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset", "MaskedPatchWSIDataset"]
@@ -123,9 +123,9 @@ def _get_label(self, sample: dict):
def _get_location(self, sample: dict):
if self.center_location:
size = self._get_size(sample)
- return [sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size))]
+ return ensure_tuple(sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size)))
else:
- return sample[WSIPatchKeys.LOCATION]
+ return ensure_tuple(sample[WSIPatchKeys.LOCATION])
def _get_level(self, sample: dict):
if self.patch_level is None:
diff --git a/monai/data/wsi_reader.py b/monai/data/wsi_reader.py
index b31d4d9c3a..2a4fe9f7a8 100644
--- a/monai/data/wsi_reader.py
+++ b/monai/data/wsi_reader.py
@@ -1097,8 +1097,8 @@ def get_mpp(self, wsi, level: int) -> tuple[float, float]:
):
unit = wsi.pages[level].tags.get("ResolutionUnit")
if unit is not None:
- unit = str(unit.value)[8:]
- else:
+ unit = str(unit.value.name)
+ if unit is None or len(unit) == 0:
warnings.warn("The resolution unit is missing. `micrometer` will be used as default.")
unit = "micrometer"
diff --git a/monai/engines/__init__.py b/monai/engines/__init__.py
index d8dc51f620..93cc40e292 100644
--- a/monai/engines/__init__.py
+++ b/monai/engines/__init__.py
@@ -12,12 +12,14 @@
from __future__ import annotations
from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
-from .trainer import GanTrainer, SupervisedTrainer, Trainer
+from .trainer import AdversarialTrainer, GanTrainer, SupervisedTrainer, Trainer
from .utils import (
+ DiffusionPrepareBatch,
IterationEvents,
PrepareBatch,
PrepareBatchDefault,
PrepareBatchExtraInput,
+ VPredictionPrepareBatch,
default_make_latent,
default_metric_cmp_fn,
default_prepare_batch,
diff --git a/monai/engines/evaluator.py b/monai/engines/evaluator.py
index 2c8dfe6b85..d70a39726b 100644
--- a/monai/engines/evaluator.py
+++ b/monai/engines/evaluator.py
@@ -12,19 +12,20 @@
from __future__ import annotations
import warnings
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
+from collections.abc import Iterable, Sequence
+from typing import TYPE_CHECKING, Any, Callable
import torch
from torch.utils.data import DataLoader
-from monai.config import IgniteInfo, KeysCollection
+from monai.config import KeysCollection
from monai.data import MetaTensor
from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.networks.utils import eval_mode, train_mode
from monai.transforms import Transform
-from monai.utils import ForwardMode, ensure_tuple, min_version, optional_import
+from monai.utils import ForwardMode, IgniteInfo, ensure_tuple, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from monai.utils.enums import EngineStatsKeys as ESKeys
from monai.utils.module import look_up_option, pytorch_after
diff --git a/monai/engines/trainer.py b/monai/engines/trainer.py
index f1513ea73b..a0be86bae5 100644
--- a/monai/engines/trainer.py
+++ b/monai/engines/trainer.py
@@ -12,19 +12,19 @@
from __future__ import annotations
import warnings
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
+from collections.abc import Iterable, Sequence
+from typing import TYPE_CHECKING, Any, Callable
import torch
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
-from monai.config import IgniteInfo
from monai.data import MetaTensor
from monai.engines.utils import IterationEvents, default_make_latent, default_metric_cmp_fn, default_prepare_batch
from monai.engines.workflow import Workflow
from monai.inferers import Inferer, SimpleInferer
from monai.transforms import Transform
-from monai.utils import GanKeys, min_version, optional_import
+from monai.utils import AdversarialIterationEvents, AdversarialKeys, GanKeys, IgniteInfo, min_version, optional_import
from monai.utils.enums import CommonKeys as Keys
from monai.utils.enums import EngineStatsKeys as ESKeys
from monai.utils.module import pytorch_after
@@ -37,7 +37,7 @@
Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
-__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer"]
+__all__ = ["Trainer", "SupervisedTrainer", "GanTrainer", "AdversarialTrainer"]
class Trainer(Workflow):
@@ -471,3 +471,282 @@ def _iteration(
GanKeys.GLOSS: g_loss.item(),
GanKeys.DLOSS: d_total_loss.item(),
}
+
+
+class AdversarialTrainer(Trainer):
+ """
+ Standard supervised training workflow for adversarial loss enabled neural networks.
+
+ Args:
+ device: an object representing the device on which to run.
+ max_epochs: the total epoch number for engine to run.
+ train_data_loader: core Ignite engines use `DataLoader` for the training loop batchdata.
+ g_network: generator (G) network architecture.
+ g_optimizer: G optimizer function.
+ g_loss_function: G loss function for adversarial training.
+ recon_loss_function: G loss function for reconstructions.
+ d_network: discriminator (D) network architecture.
+ d_optimizer: D optimizer function.
+ d_loss_function: D loss function for adversarial training.
+ epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
+ non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously with respect to
+ the host. For other cases, this argument has no effect.
+ prepare_batch: function to parse image and label for current iteration.
+ iteration_update: the callable function for every iteration, expect to accept `engine` and `batchdata` as input
+ parameters. if not provided, use `self._iteration()` instead.
+ g_inferer: inference method to execute G model forward. Defaults to ``SimpleInferer()``.
+ d_inferer: inference method to execute D model forward. Defaults to ``SimpleInferer()``.
+ postprocessing: execute additional transformation for the model output data. Typically, several Tensor based
+ transforms composed by `Compose`. Defaults to None
+ key_train_metric: compute metric when every iteration completed, and save average value to engine.state.metrics
+ when epoch completed. key_train_metric is the main metric to compare and save the checkpoint into files.
+ additional_metrics: more Ignite metrics that also attach to Ignite Engine.
+ metric_cmp_fn: function to compare current key metric with previous best key metric value, it must accept 2 args
+ (current_metric, previous_best) and return a bool result: if `True`, will update `best_metric` and
+ `best_metric_epoch` with current metric and epoch, default to `greater than`.
+ train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
+ CheckpointHandler, StatsHandler, etc.
+ amp: whether to enable auto-mixed-precision training, default is False.
+ event_names: additional custom ignite events that will register to the engine.
+ new events can be a list of str or `ignite.engine.events.EventEnum`.
+ event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
+ for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
+ #ignite.engine.engine.Engine.register_events.
+ decollate: whether to decollate the batch-first data to a list of data after model computation, recommend
+ `decollate=True` when `postprocessing` uses components from `monai.transforms`. default to `True`.
+ optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
+ more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
+ to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
+ `device`, `non_blocking`.
+ amp_kwargs: dict of the args for `torch.cuda.amp.autocast()` API, for more details:
+ https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.autocast.
+ """
+
+ def __init__(
+ self,
+ device: torch.device | str,
+ max_epochs: int,
+ train_data_loader: Iterable | DataLoader,
+ g_network: torch.nn.Module,
+ g_optimizer: Optimizer,
+ g_loss_function: Callable,
+ recon_loss_function: Callable,
+ d_network: torch.nn.Module,
+ d_optimizer: Optimizer,
+ d_loss_function: Callable,
+ epoch_length: int | None = None,
+ non_blocking: bool = False,
+ prepare_batch: Callable = default_prepare_batch,
+ iteration_update: Callable | None = None,
+ g_inferer: Inferer | None = None,
+ d_inferer: Inferer | None = None,
+ postprocessing: Transform | None = None,
+ key_train_metric: dict[str, Metric] | None = None,
+ additional_metrics: dict[str, Metric] | None = None,
+ metric_cmp_fn: Callable = default_metric_cmp_fn,
+ train_handlers: Sequence | None = None,
+ amp: bool = False,
+ event_names: list[str | EventEnum | type[EventEnum]] | None = None,
+ event_to_attr: dict | None = None,
+ decollate: bool = True,
+ optim_set_to_none: bool = False,
+ to_kwargs: dict | None = None,
+ amp_kwargs: dict | None = None,
+ ):
+ super().__init__(
+ device=device,
+ max_epochs=max_epochs,
+ data_loader=train_data_loader,
+ epoch_length=epoch_length,
+ non_blocking=non_blocking,
+ prepare_batch=prepare_batch,
+ iteration_update=iteration_update,
+ postprocessing=postprocessing,
+ key_metric=key_train_metric,
+ additional_metrics=additional_metrics,
+ metric_cmp_fn=metric_cmp_fn,
+ handlers=train_handlers,
+ amp=amp,
+ event_names=event_names,
+ event_to_attr=event_to_attr,
+ decollate=decollate,
+ to_kwargs=to_kwargs,
+ amp_kwargs=amp_kwargs,
+ )
+
+ self.register_events(*AdversarialIterationEvents)
+
+ self.state.g_network = g_network
+ self.state.g_optimizer = g_optimizer
+ self.state.g_loss_function = g_loss_function
+ self.state.recon_loss_function = recon_loss_function
+
+ self.state.d_network = d_network
+ self.state.d_optimizer = d_optimizer
+ self.state.d_loss_function = d_loss_function
+
+ self.g_inferer = SimpleInferer() if g_inferer is None else g_inferer
+ self.d_inferer = SimpleInferer() if d_inferer is None else d_inferer
+
+ self.state.g_scaler = torch.cuda.amp.GradScaler() if self.amp else None
+ self.state.d_scaler = torch.cuda.amp.GradScaler() if self.amp else None
+
+ self.optim_set_to_none = optim_set_to_none
+ self._complete_state_dict_user_keys()
+
+ def _complete_state_dict_user_keys(self) -> None:
+ """
+ This method appends the AdversarialTrainer elements that are required for checkpoint saving to
+ `_state_dict_user_keys`.
+
+ Follows the example found at:
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html#ignite.engine.engine.Engine.state_dict
+ """
+ self._state_dict_user_keys.extend(
+ ["g_network", "g_optimizer", "d_network", "d_optimizer", "g_scaler", "d_scaler"]
+ )
+
+ g_loss_state_dict = getattr(self.state.g_loss_function, "state_dict", None)
+ if callable(g_loss_state_dict):
+ self._state_dict_user_keys.append("g_loss_function")
+
+ d_loss_state_dict = getattr(self.state.d_loss_function, "state_dict", None)
+ if callable(d_loss_state_dict):
+ self._state_dict_user_keys.append("d_loss_function")
+
+ recon_loss_state_dict = getattr(self.state.recon_loss_function, "state_dict", None)
+ if callable(recon_loss_state_dict):
+ self._state_dict_user_keys.append("recon_loss_function")
+
+ def _iteration(
+ self, engine: AdversarialTrainer, batchdata: dict[str, torch.Tensor]
+ ) -> dict[str, torch.Tensor | int | float | bool]:
+ """
+ Callback function for the Adversarial Training processing logic of 1 iteration in Ignite Engine.
+ Return below items in a dictionary:
+ - IMAGE: image Tensor data for model input, already moved to device.
+ - LABEL: label Tensor data corresponding to the image, already moved to device. In case of Unsupervised
+ Learning this is equal to IMAGE.
+ - PRED: prediction result of model.
+ - LOSS: loss value computed by loss functions of the generator (reconstruction and adversarial summed up).
+ - AdversarialKeys.REALS: real images from the batch; identical to IMAGE.
+ - AdversarialKeys.FAKES: fake images generated by the generator; identical to PRED.
+ - AdversarialKeys.REAL_LOGITS: logits of the discriminator for the real images.
+ - AdversarialKeys.FAKE_LOGITS: logits of the discriminator for the fake images.
+ - AdversarialKeys.RECONSTRUCTION_LOSS: loss value computed by the reconstruction loss function.
+ - AdversarialKeys.GENERATOR_LOSS: loss value computed by the generator loss function, i.e. the
+ discriminator loss for the fake images, which is backpropagated through the generator only.
+ - AdversarialKeys.DISCRIMINATOR_LOSS: loss value computed by the discriminator loss function, i.e. the
+ discriminator loss for the real and the fake images, which is backpropagated through the discriminator only.
+
+ Args:
+ engine: `AdversarialTrainer` to execute operation for an iteration.
+ batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
+
+ Raises:
+ ValueError: must provide batch data for current iteration.
+
+ """
+
+ if batchdata is None:
+ raise ValueError("Must provide batch data for current iteration.")
+ batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
+
+ if len(batch) == 2:
+ inputs, targets = batch
+ args: tuple = ()
+ kwargs: dict = {}
+ else:
+ inputs, targets, args, kwargs = batch
+
+ engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets, AdversarialKeys.REALS: inputs}
+
+ def _compute_generator_loss() -> None:
+ engine.state.output[AdversarialKeys.FAKES] = engine.g_inferer(
+ inputs, engine.state.g_network, *args, **kwargs
+ )
+ engine.state.output[Keys.PRED] = engine.state.output[AdversarialKeys.FAKES]
+ engine.fire_event(AdversarialIterationEvents.GENERATOR_FORWARD_COMPLETED)
+
+ engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer(
+ engine.state.output[AdversarialKeys.FAKES].float().contiguous(), engine.state.d_network, *args, **kwargs
+ )
+ engine.fire_event(AdversarialIterationEvents.GENERATOR_DISCRIMINATOR_FORWARD_COMPLETED)
+
+ engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS] = engine.state.recon_loss_function(
+ engine.state.output[AdversarialKeys.FAKES], targets
+ ).mean()
+ engine.fire_event(AdversarialIterationEvents.RECONSTRUCTION_LOSS_COMPLETED)
+
+ engine.state.output[AdversarialKeys.GENERATOR_LOSS] = engine.state.g_loss_function(
+ engine.state.output[AdversarialKeys.FAKE_LOGITS]
+ ).mean()
+ engine.fire_event(AdversarialIterationEvents.GENERATOR_LOSS_COMPLETED)
+
+ # Train Generator
+ engine.state.g_network.train()
+ engine.state.g_optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
+
+ if engine.amp and engine.state.g_scaler is not None:
+ with torch.cuda.amp.autocast(**engine.amp_kwargs):
+ _compute_generator_loss()
+
+ engine.state.output[Keys.LOSS] = (
+ engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS]
+ + engine.state.output[AdversarialKeys.GENERATOR_LOSS]
+ )
+ engine.state.g_scaler.scale(engine.state.output[Keys.LOSS]).backward()
+ engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED)
+ engine.state.g_scaler.step(engine.state.g_optimizer)
+ engine.state.g_scaler.update()
+ else:
+ _compute_generator_loss()
+ (
+ engine.state.output[AdversarialKeys.RECONSTRUCTION_LOSS]
+ + engine.state.output[AdversarialKeys.GENERATOR_LOSS]
+ ).backward()
+ engine.fire_event(AdversarialIterationEvents.GENERATOR_BACKWARD_COMPLETED)
+ engine.state.g_optimizer.step()
+ engine.fire_event(AdversarialIterationEvents.GENERATOR_MODEL_COMPLETED)
+
+ def _compute_discriminator_loss() -> None:
+ engine.state.output[AdversarialKeys.REAL_LOGITS] = engine.d_inferer(
+ engine.state.output[AdversarialKeys.REALS].contiguous().detach(),
+ engine.state.d_network,
+ *args,
+ **kwargs,
+ )
+ engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_REALS_FORWARD_COMPLETED)
+
+ engine.state.output[AdversarialKeys.FAKE_LOGITS] = engine.d_inferer(
+ engine.state.output[AdversarialKeys.FAKES].contiguous().detach(),
+ engine.state.d_network,
+ *args,
+ **kwargs,
+ )
+ engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_FAKES_FORWARD_COMPLETED)
+
+ engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS] = engine.state.d_loss_function(
+ engine.state.output[AdversarialKeys.REAL_LOGITS], engine.state.output[AdversarialKeys.FAKE_LOGITS]
+ ).mean()
+ engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_LOSS_COMPLETED)
+
+ # Train Discriminator
+ engine.state.d_network.train()
+ engine.state.d_network.zero_grad(set_to_none=engine.optim_set_to_none)
+
+ if engine.amp and engine.state.d_scaler is not None:
+ with torch.cuda.amp.autocast(**engine.amp_kwargs):
+ _compute_discriminator_loss()
+
+ engine.state.d_scaler.scale(engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS]).backward()
+ engine.fire_event(AdversarialIterationEvents.DISCRIMINATOR_BACKWARD_COMPLETED)
+ engine.state.d_scaler.step(engine.state.d_optimizer)
+ engine.state.d_scaler.update()
+ else:
+ _compute_discriminator_loss()
+ engine.state.output[AdversarialKeys.DISCRIMINATOR_LOSS].backward()
+ engine.state.d_optimizer.step()
+
+ return engine.state.output
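A hedged usage sketch of the `AdversarialTrainer` added above (ignite required); the tiny networks, lambda losses and random data are placeholders chosen only to show the expected wiring, not a meaningful GAN setup:

    import torch
    from torch.utils.data import DataLoader
    from monai.engines import AdversarialTrainer

    g_net = torch.nn.Conv2d(1, 1, 3, padding=1)  # stand-in generator
    d_net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(64, 1))  # stand-in discriminator
    data = [{"image": torch.rand(1, 8, 8), "label": torch.rand(1, 8, 8)} for _ in range(4)]

    trainer = AdversarialTrainer(
        device=torch.device("cpu"),
        max_epochs=1,
        train_data_loader=DataLoader(data, batch_size=2),
        g_network=g_net,
        g_optimizer=torch.optim.Adam(g_net.parameters(), 1e-3),
        g_loss_function=lambda fake_logits: torch.nn.functional.softplus(-fake_logits),
        recon_loss_function=torch.nn.L1Loss(),
        d_network=d_net,
        d_optimizer=torch.optim.Adam(d_net.parameters(), 1e-3),
        d_loss_function=lambda real_logits, fake_logits: torch.nn.functional.softplus(-real_logits)
        + torch.nn.functional.softplus(fake_logits),
    )
    trainer.run()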
diff --git a/monai/engines/utils.py b/monai/engines/utils.py
index 02c718cd14..8e19a18601 100644
--- a/monai/engines/utils.py
+++ b/monai/engines/utils.py
@@ -12,14 +12,14 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from collections.abc import Callable, Sequence
+from collections.abc import Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
import torch
+import torch.nn as nn
-from monai.config import IgniteInfo
from monai.transforms import apply_transform
-from monai.utils import ensure_tuple, min_version, optional_import
+from monai.utils import IgniteInfo, ensure_tuple, min_version, optional_import
from monai.utils.enums import CommonKeys, GanKeys
if TYPE_CHECKING:
@@ -36,6 +36,8 @@
"PrepareBatch",
"PrepareBatchDefault",
"PrepareBatchExtraInput",
+ "DiffusionPrepareBatch",
+ "VPredictionPrepareBatch",
"default_make_latent",
"engine_apply_transform",
"default_metric_cmp_fn",
@@ -238,6 +240,78 @@ def _get_data(key: str) -> torch.Tensor:
return cast(torch.Tensor, image), cast(torch.Tensor, label), tuple(args_), kwargs_
+class DiffusionPrepareBatch(PrepareBatch):
+ """
+ This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
+
+ Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
+ return the image and noise field as the image/target pair, plus the noise field in the kwargs under the key "noise".
+ This assumes the inferer being used in conjunction with this class expects a "noise" parameter to be provided.
+
+ If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
+ field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".
+
+ """
+
+ def __init__(self, num_train_timesteps: int, condition_name: str | None = None) -> None:
+ self.condition_name = condition_name
+ self.num_train_timesteps = num_train_timesteps
+
+ def get_noise(self, images: torch.Tensor) -> torch.Tensor:
+ """Returns the noise tensor for input tensor `images`, override this for different noise distributions."""
+ return torch.randn_like(images)
+
+ def get_timesteps(self, images: torch.Tensor) -> torch.Tensor:
+ """Get a timestep, by default this is a random integer between 0 and `self.num_train_timesteps`."""
+ return torch.randint(0, self.num_train_timesteps, (images.shape[0],), device=images.device).long()
+
+ def get_target(self, images: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+ """Return the target for the loss function, this is the `noise` value by default."""
+ return noise
+
+ def __call__(
+ self,
+ batchdata: dict[str, torch.Tensor],
+ device: str | torch.device | None = None,
+ non_blocking: bool = False,
+ **kwargs: Any,
+ ) -> tuple[torch.Tensor, torch.Tensor, tuple, dict]:
+ images, _ = default_prepare_batch(batchdata, device, non_blocking, **kwargs)
+ noise = self.get_noise(images).to(device, non_blocking=non_blocking, **kwargs)
+ timesteps = self.get_timesteps(images).to(device, non_blocking=non_blocking, **kwargs)
+
+ target = self.get_target(images, noise, timesteps).to(device, non_blocking=non_blocking, **kwargs)
+ infer_kwargs = {"noise": noise, "timesteps": timesteps}
+
+ if self.condition_name is not None and isinstance(batchdata, Mapping):
+ infer_kwargs["condition"] = batchdata[self.condition_name].to(device, non_blocking=non_blocking, **kwargs)
+
+ # return input, target, arguments, and keyword arguments where noise is the target and also a keyword value
+ return images, target, (), infer_kwargs
+
+
+class VPredictionPrepareBatch(DiffusionPrepareBatch):
+ """
+ This class is used as a callable for the `prepare_batch` parameter of engine classes for diffusion training.
+
+ Assuming a supervised training process, it will generate a noise field using `get_noise` for an input image, and
+ from this compute the velocity using the provided scheduler. This value is used as the target in place of the
+ noise field itself, although the noise field is still passed in the kwargs under the key "noise". This assumes the inferer
+ being used in conjunction with this class expects a "noise" parameter to be provided.
+
+ If the `condition_name` is provided, this must refer to a key in the input dictionary containing the condition
+ field to be passed to the inferer. This will appear in the keyword arguments under the key "condition".
+
+ """
+
+ def __init__(self, scheduler: nn.Module, num_train_timesteps: int, condition_name: str | None = None) -> None:
+ super().__init__(num_train_timesteps=num_train_timesteps, condition_name=condition_name)
+ self.scheduler = scheduler
+
+ def get_target(self, images, noise, timesteps):
+ return self.scheduler.get_velocity(images, noise, timesteps)
+
+
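A hedged sketch of what `DiffusionPrepareBatch` hands to the engine for one batch; the dictionary keys and shapes are illustrative assumptions. `VPredictionPrepareBatch` follows the same flow but swaps the target for `scheduler.get_velocity(images, noise, timesteps)`:

    import torch
    from monai.engines.utils import DiffusionPrepareBatch

    prepare_batch = DiffusionPrepareBatch(num_train_timesteps=1000)
    batch = {"image": torch.rand(2, 1, 16, 16), "label": torch.rand(2, 1, 16, 16)}

    images, target, args, kwargs = prepare_batch(batch, device=torch.device("cpu"))
    # the target is the sampled noise, which is also forwarded to the inferer via kwargs
    print(torch.equal(target, kwargs["noise"]), kwargs["timesteps"].shape)  # True torch.Size([2])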
def default_make_latent(
num_latents: int,
latent_size: int,
diff --git a/monai/engines/workflow.py b/monai/engines/workflow.py
index 30622c2b93..3629659db1 100644
--- a/monai/engines/workflow.py
+++ b/monai/engines/workflow.py
@@ -20,10 +20,9 @@
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
-from monai.config import IgniteInfo
from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
from monai.transforms import Decollated
-from monai.utils import ensure_tuple, is_scalar, min_version, optional_import
+from monai.utils import IgniteInfo, ensure_tuple, is_scalar, min_version, optional_import
from .utils import engine_apply_transform
diff --git a/monai/fl/client/monai_algo.py b/monai/fl/client/monai_algo.py
index 9acf131bd9..a3ac58c221 100644
--- a/monai/fl/client/monai_algo.py
+++ b/monai/fl/client/monai_algo.py
@@ -134,12 +134,14 @@ def initialize(self, extra=None):
Args:
extra: Dict with additional information that should be provided by FL system,
- i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`.
+ i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`.
+ You can disable the logging logic in the MONAI bundle by setting {ExtraItems.LOGGING_FILE} to False.
"""
if extra is None:
extra = {}
self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname")
+ logging_file = extra.get(ExtraItems.LOGGING_FILE, None)
self.logger.info(f"Initializing {self.client_name} ...")
# FL platform needs to provide filepath to configuration files
@@ -149,7 +151,7 @@ def initialize(self, extra=None):
if self.workflow is None:
config_train_files = self._add_config_files(self.config_train_filename)
self.workflow = ConfigWorkflow(
- config_file=config_train_files, meta_file=None, logging_file=None, workflow_type="train"
+ config_file=config_train_files, meta_file=None, logging_file=logging_file, workflow_type="train"
)
self.workflow.initialize()
self.workflow.bundle_root = self.bundle_root
@@ -412,13 +414,15 @@ def initialize(self, extra=None):
Args:
extra: Dict with additional information that should be provided by FL system,
- i.e., `ExtraItems.CLIENT_NAME` and `ExtraItems.APP_ROOT`.
+ i.e., `ExtraItems.CLIENT_NAME`, `ExtraItems.APP_ROOT` and `ExtraItems.LOGGING_FILE`.
+ You can disable the logging logic in the MONAI bundle by setting {ExtraItems.LOGGING_FILE} to False.
"""
self._set_cuda_device()
if extra is None:
extra = {}
self.client_name = extra.get(ExtraItems.CLIENT_NAME, "noname")
+ logging_file = extra.get(ExtraItems.LOGGING_FILE, None)
timestamp = time.strftime("%Y%m%d_%H%M%S")
self.logger.info(f"Initializing {self.client_name} ...")
# FL platform needs to provide filepath to configuration files
@@ -434,7 +438,7 @@ def initialize(self, extra=None):
self.train_workflow = ConfigWorkflow(
config_file=config_train_files,
meta_file=None,
- logging_file=None,
+ logging_file=logging_file,
workflow_type="train",
**self.train_kwargs,
)
@@ -459,7 +463,7 @@ def initialize(self, extra=None):
self.eval_workflow = ConfigWorkflow(
config_file=config_eval_files,
meta_file=None,
- logging_file=None,
+ logging_file=logging_file,
workflow_type=self.eval_workflow_name,
**self.eval_kwargs,
)
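A hedged sketch of the new `ExtraItems.LOGGING_FILE` behaviour described in the docstrings above; the bundle path, app root and client name are placeholders:

    from monai.fl.client import MonaiAlgo
    from monai.fl.utils.constants import ExtraItems

    algo = MonaiAlgo(bundle_root="./my_bundle")  # illustrative bundle location
    algo.initialize(
        extra={
            ExtraItems.CLIENT_NAME: "site-1",
            ExtraItems.APP_ROOT: "./app",
            ExtraItems.LOGGING_FILE: False,  # skip the bundle's logging config entirely
        }
    )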
diff --git a/monai/fl/utils/constants.py b/monai/fl/utils/constants.py
index eda1a6b4f9..18beceeaee 100644
--- a/monai/fl/utils/constants.py
+++ b/monai/fl/utils/constants.py
@@ -30,6 +30,7 @@ class ExtraItems(StrEnum):
CLIENT_NAME = "fl_client_name"
APP_ROOT = "fl_app_root"
STATS_SENDER = "fl_stats_sender"
+ LOGGING_FILE = "logging_file"
class FlPhase(StrEnum):
diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py
index 641f9aae7d..c1fa448f25 100644
--- a/monai/handlers/__init__.py
+++ b/monai/handlers/__init__.py
@@ -20,7 +20,7 @@
from .earlystop_handler import EarlyStopHandler
from .garbage_collector import GarbageCollector
from .hausdorff_distance import HausdorffDistance
-from .ignite_metric import IgniteMetric, IgniteMetricHandler
+from .ignite_metric import IgniteMetricHandler
from .logfile_handler import LogfileHandler
from .lr_schedule_handler import LrScheduleHandler
from .mean_dice import MeanDice
@@ -40,5 +40,6 @@
from .stats_handler import StatsHandler
from .surface_distance import SurfaceDistance
from .tensorboard_handlers import TensorBoardHandler, TensorBoardImageHandler, TensorBoardStatsHandler
+from .trt_handler import TrtHandler
from .utils import from_engine, ignore_data, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports
from .validation_handler import ValidationHandler
diff --git a/monai/handlers/checkpoint_loader.py b/monai/handlers/checkpoint_loader.py
index 9a867534a3..f48968ecfd 100644
--- a/monai/handlers/checkpoint_loader.py
+++ b/monai/handlers/checkpoint_loader.py
@@ -17,9 +17,8 @@
import torch
-from monai.config import IgniteInfo
from monai.networks.utils import copy_model_state
-from monai.utils import min_version, optional_import
+from monai.utils import IgniteInfo, min_version, optional_import
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
Checkpoint, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Checkpoint")
diff --git a/monai/handlers/checkpoint_saver.py b/monai/handlers/checkpoint_saver.py
index 0651c6ff33..2a3a467570 100644
--- a/monai/handlers/checkpoint_saver.py
+++ b/monai/handlers/checkpoint_saver.py
@@ -17,8 +17,7 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
-from monai.config import IgniteInfo
-from monai.utils import is_scalar, min_version, optional_import
+from monai.utils import IgniteInfo, is_scalar, min_version, optional_import
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py
index 831808f4fb..ffcfe3c1fb 100644
--- a/monai/handlers/classification_saver.py
+++ b/monai/handlers/classification_saver.py
@@ -18,8 +18,8 @@
import torch
-from monai.config import IgniteInfo
from monai.data import CSVSaver, decollate_batch
+from monai.utils import IgniteInfo
from monai.utils import ImageMetaKey as Key
from monai.utils import evenly_divisible_all_gather, min_version, optional_import, string_list_all_gather
diff --git a/monai/handlers/clearml_handlers.py b/monai/handlers/clearml_handlers.py
index 1cfd6a33fb..0aa2a5cc08 100644
--- a/monai/handlers/clearml_handlers.py
+++ b/monai/handlers/clearml_handlers.py
@@ -11,7 +11,8 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Mapping, Sequence
+from collections.abc import Mapping, Sequence
+from typing import TYPE_CHECKING, Any
from monai.utils import optional_import
diff --git a/monai/handlers/decollate_batch.py b/monai/handlers/decollate_batch.py
index ac3aa94145..81415bd56e 100644
--- a/monai/handlers/decollate_batch.py
+++ b/monai/handlers/decollate_batch.py
@@ -13,10 +13,10 @@
from typing import TYPE_CHECKING
-from monai.config import IgniteInfo, KeysCollection
+from monai.config import KeysCollection
from monai.engines.utils import IterationEvents
from monai.transforms import Decollated
-from monai.utils import min_version, optional_import
+from monai.utils import IgniteInfo, min_version, optional_import
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
diff --git a/monai/handlers/earlystop_handler.py b/monai/handlers/earlystop_handler.py
index 93334bf5c0..0562335192 100644
--- a/monai/handlers/earlystop_handler.py
+++ b/monai/handlers/earlystop_handler.py
@@ -14,8 +14,7 @@
from collections.abc import Callable
from typing import TYPE_CHECKING
-from monai.config import IgniteInfo
-from monai.utils import min_version, optional_import
+from monai.utils import IgniteInfo, min_version, optional_import
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
EarlyStopping, _ = optional_import("ignite.handlers", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EarlyStopping")
diff --git a/monai/handlers/garbage_collector.py b/monai/handlers/garbage_collector.py
index 3d7e948364..586fa10d33 100644
--- a/monai/handlers/garbage_collector.py
+++ b/monai/handlers/garbage_collector.py
@@ -14,8 +14,7 @@
import gc
from typing import TYPE_CHECKING
-from monai.config import IgniteInfo
-from monai.utils import min_version, optional_import
+from monai.utils import IgniteInfo, min_version, optional_import
if TYPE_CHECKING:
from ignite.engine import Engine, Events
diff --git a/monai/handlers/ignite_metric.py b/monai/handlers/ignite_metric.py
index 021154d705..44a5634c42 100644
--- a/monai/handlers/ignite_metric.py
+++ b/monai/handlers/ignite_metric.py
@@ -18,9 +18,8 @@
import torch
from torch.nn.modules.loss import _Loss
-from monai.config import IgniteInfo
from monai.metrics import CumulativeIterationMetric, LossMetric
-from monai.utils import MetricReduction, deprecated, min_version, optional_import
+from monai.utils import IgniteInfo, MetricReduction, min_version, optional_import
idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")
@@ -153,25 +152,3 @@ def attach(self, engine: Engine, name: str) -> None: # type: ignore[override]
self._name = name
if self.save_details and not hasattr(engine.state, "metric_details"):
engine.state.metric_details = {} # type: ignore
-
-
-@deprecated(since="1.2", removed="1.4", msg_suffix="Use IgniteMetricHandler instead of IgniteMetric.")
-class IgniteMetric(IgniteMetricHandler):
-
- def __init__(
- self,
- metric_fn: CumulativeIterationMetric | None = None,
- loss_fn: _Loss | None = None,
- output_transform: Callable = lambda x: x,
- save_details: bool = True,
- reduction: MetricReduction | str = MetricReduction.MEAN,
- get_not_nans: bool = False,
- ) -> None:
- super().__init__(
- metric_fn=metric_fn,
- loss_fn=loss_fn,
- output_transform=output_transform,
- save_details=save_details,
- reduction=reduction,
- get_not_nans=get_not_nans,
- )
diff --git a/monai/handlers/logfile_handler.py b/monai/handlers/logfile_handler.py
index df6ebd34a7..0c44ae47f4 100644
--- a/monai/handlers/logfile_handler.py
+++ b/monai/handlers/logfile_handler.py
@@ -15,8 +15,7 @@
import os
from typing import TYPE_CHECKING
-from monai.config import IgniteInfo
-from monai.utils import min_version, optional_import
+from monai.utils import IgniteInfo, min_version, optional_import
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
diff --git a/monai/handlers/lr_schedule_handler.py b/monai/handlers/lr_schedule_handler.py
index a79722517d..8d90992a84 100644
--- a/monai/handlers/lr_schedule_handler.py
+++ b/monai/handlers/lr_schedule_handler.py
@@ -17,8 +17,7 @@
from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler
-from monai.config import IgniteInfo
-from monai.utils import ensure_tuple, min_version, optional_import
+from monai.utils import IgniteInfo, ensure_tuple, min_version, optional_import
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
diff --git a/monai/handlers/metric_logger.py b/monai/handlers/metric_logger.py
index d59205a021..62cdee6509 100644
--- a/monai/handlers/metric_logger.py
+++ b/monai/handlers/metric_logger.py
@@ -17,8 +17,7 @@
from threading import RLock
from typing import TYPE_CHECKING, Any
-from monai.config import IgniteInfo
-from monai.utils import min_version, optional_import
+from monai.utils import IgniteInfo, min_version, optional_import
from monai.utils.enums import CommonKeys
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py
index 88a0926b91..6175b1242a 100644
--- a/monai/handlers/metrics_saver.py
+++ b/monai/handlers/metrics_saver.py
@@ -14,9 +14,9 @@
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING
-from monai.config import IgniteInfo
from monai.data import decollate_batch
from monai.handlers.utils import write_metrics_reports
+from monai.utils import IgniteInfo
from monai.utils import ImageMetaKey as Key
from monai.utils import ensure_tuple, min_version, optional_import, string_list_all_gather
diff --git a/monai/handlers/mlflow_handler.py b/monai/handlers/mlflow_handler.py
index df209c1c8b..3078d89f97 100644
--- a/monai/handlers/mlflow_handler.py
+++ b/monai/handlers/mlflow_handler.py
@@ -21,14 +21,17 @@
import torch
from torch.utils.data import Dataset
-from monai.config import IgniteInfo
-from monai.utils import CommonKeys, ensure_tuple, min_version, optional_import
+from monai.apps.utils import get_logger
+from monai.utils import CommonKeys, IgniteInfo, ensure_tuple, flatten_dict, min_version, optional_import
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
mlflow, _ = optional_import("mlflow", descriptor="Please install mlflow before using MLFlowHandler.")
mlflow.entities, _ = optional_import(
"mlflow.entities", descriptor="Please install mlflow.entities before using MLFlowHandler."
)
+MlflowException, _ = optional_import(
+ "mlflow.exceptions", name="MlflowException", descriptor="Please install mlflow before using MLFlowHandler."
+)
pandas, _ = optional_import("pandas", descriptor="Please install pandas for recording the dataset.")
tqdm, _ = optional_import("tqdm", "4.47.0", min_version, "tqdm")
@@ -41,6 +44,8 @@
DEFAULT_TAG = "Loss"
+logger = get_logger(module_name=__name__)
+
class MLFlowHandler:
"""
@@ -236,10 +241,21 @@ def start(self, engine: Engine) -> None:
def _set_experiment(self):
experiment = self.experiment
if not experiment:
- experiment = self.client.get_experiment_by_name(self.experiment_name)
- if not experiment:
- experiment_id = self.client.create_experiment(self.experiment_name)
- experiment = self.client.get_experiment(experiment_id)
+ for _retry_time in range(3):
+ try:
+ experiment = self.client.get_experiment_by_name(self.experiment_name)
+ if not experiment:
+ experiment_id = self.client.create_experiment(self.experiment_name)
+ experiment = self.client.get_experiment(experiment_id)
+ break
+ except MlflowException as e:
+ if "RESOURCE_ALREADY_EXISTS" in str(e):
+ logger.warning("Experiment already exists; delaying before retrying.")
+ time.sleep(1)
+ if _retry_time == 2:
+ raise e
+ else:
+ raise e
if experiment.lifecycle_stage != mlflow.entities.LifecycleStage.ACTIVE:
raise ValueError(f"Cannot set a deleted experiment '{self.experiment_name}' as the active experiment")
@@ -287,7 +303,9 @@ def _log_metrics(self, metrics: dict[str, Any], step: int | None = None) -> None
run_id = self.cur_run.info.run_id
timestamp = int(time.time() * 1000)
- metrics_arr = [mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in metrics.items()]
+ metrics_arr = [
+ mlflow.entities.Metric(key, value, timestamp, step or 0) for key, value in flatten_dict(metrics).items()
+ ]
self.client.log_batch(run_id=run_id, metrics=metrics_arr, params=[], tags=[])
def _parse_artifacts(self):
diff --git a/monai/handlers/nvtx_handlers.py b/monai/handlers/nvtx_handlers.py
index 38eef6f05b..bd22af0db8 100644
--- a/monai/handlers/nvtx_handlers.py
+++ b/monai/handlers/nvtx_handlers.py
@@ -16,8 +16,7 @@
from typing import TYPE_CHECKING
-from monai.config import IgniteInfo
-from monai.utils import ensure_tuple, min_version, optional_import
+from monai.utils import IgniteInfo, ensure_tuple, min_version, optional_import
_nvtx, _ = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?")
if TYPE_CHECKING:
diff --git a/monai/handlers/parameter_scheduler.py b/monai/handlers/parameter_scheduler.py
index d12e6e072c..1ce6193b6d 100644
--- a/monai/handlers/parameter_scheduler.py
+++ b/monai/handlers/parameter_scheduler.py
@@ -16,8 +16,7 @@
from collections.abc import Callable
from typing import TYPE_CHECKING
-from monai.config import IgniteInfo
-from monai.utils import min_version, optional_import
+from monai.utils import IgniteInfo, min_version, optional_import
if TYPE_CHECKING:
from ignite.engine import Engine, Events
diff --git a/monai/handlers/postprocessing.py b/monai/handlers/postprocessing.py
index c698c84338..541b5924d1 100644
--- a/monai/handlers/postprocessing.py
+++ b/monai/handlers/postprocessing.py
@@ -14,9 +14,8 @@
from collections.abc import Callable
from typing import TYPE_CHECKING
-from monai.config import IgniteInfo
from monai.engines.utils import IterationEvents, engine_apply_transform
-from monai.utils import min_version, optional_import
+from monai.utils import IgniteInfo, min_version, optional_import
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
diff --git a/monai/handlers/probability_maps.py b/monai/handlers/probability_maps.py
index 8a60fcc983..e21bd199f8 100644
--- a/monai/handlers/probability_maps.py
+++ b/monai/handlers/probability_maps.py
@@ -17,10 +17,10 @@
import numpy as np
-from monai.config import DtypeLike, IgniteInfo
+from monai.config import DtypeLike
from monai.data.folder_layout import FolderLayout
from monai.utils import ProbMapKeys, min_version, optional_import
-from monai.utils.enums import CommonKeys
+from monai.utils.enums import CommonKeys, IgniteInfo
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
diff --git a/monai/handlers/smartcache_handler.py b/monai/handlers/smartcache_handler.py
index ee043635db..e07e98e541 100644
--- a/monai/handlers/smartcache_handler.py
+++ b/monai/handlers/smartcache_handler.py
@@ -13,9 +13,8 @@
from typing import TYPE_CHECKING
-from monai.config import IgniteInfo
from monai.data import SmartCacheDataset
-from monai.utils import min_version, optional_import
+from monai.utils import IgniteInfo, min_version, optional_import
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
diff --git a/monai/handlers/stats_handler.py b/monai/handlers/stats_handler.py
index c49fcda819..214872fef4 100644
--- a/monai/handlers/stats_handler.py
+++ b/monai/handlers/stats_handler.py
@@ -19,8 +19,7 @@
import torch
from monai.apps import get_logger
-from monai.config import IgniteInfo
-from monai.utils import is_scalar, min_version, optional_import
+from monai.utils import IgniteInfo, flatten_dict, is_scalar, min_version, optional_import
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
@@ -75,7 +74,7 @@ def __init__(
output_transform: Callable = lambda x: x[0],
global_epoch_transform: Callable = lambda x: x,
state_attributes: Sequence[str] | None = None,
- name: str | None = "StatsHandler",
+ name: str | None = "monai.handlers.StatsHandler",
tag_name: str = DEFAULT_TAG,
key_var_format: str = DEFAULT_KEY_VAL_FORMAT,
) -> None:
@@ -212,8 +211,7 @@ def _default_epoch_print(self, engine: Engine) -> None:
"""
current_epoch = self.global_epoch_transform(engine.state.epoch)
-
- prints_dict = engine.state.metrics
+ prints_dict = flatten_dict(engine.state.metrics)
if prints_dict is not None and len(prints_dict) > 0:
out_str = f"Epoch[{current_epoch}] Metrics -- "
for name in sorted(prints_dict):
diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py
index 7b7e3968fb..44a03710de 100644
--- a/monai/handlers/tensorboard_handlers.py
+++ b/monai/handlers/tensorboard_handlers.py
@@ -18,8 +18,7 @@
import numpy as np
import torch
-from monai.config import IgniteInfo
-from monai.utils import is_scalar, min_version, optional_import
+from monai.utils import IgniteInfo, is_scalar, min_version, optional_import
from monai.visualize import plot_2d_or_3d_image
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
diff --git a/monai/handlers/trt_handler.py b/monai/handlers/trt_handler.py
new file mode 100644
index 0000000000..45e2669f70
--- /dev/null
+++ b/monai/handlers/trt_handler.py
@@ -0,0 +1,60 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from monai.networks import trt_compile
+from monai.utils import IgniteInfo, min_version, optional_import
+
+Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
+if TYPE_CHECKING:
+ from ignite.engine import Engine
+else:
+ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
+
+
+class TrtHandler:
+ """
+ TrtHandler acts as an Ignite handler to apply TRT acceleration to the model.
+ Usage example::
+ handler = TrtHandler(model=model, base_path="/test/checkpoint.pt", args={"precision": "fp16"})
+ handler.attach(engine)
+ engine.run()
+ """
+
+ def __init__(self, model, base_path, args=None, submodule=None):
+ """
+ Args:
+ model: the model to apply TRT acceleration to.
+ base_path: TRT path basename. TRT plan(s) saved to "base_path[.submodule].plan"
+ args: passed to trt_compile(). See trt_compile() for details.
+ submodule: hierarchical ids of submodules to convert, e.g. 'image_decoder.decoder'
+ """
+ self.model = model
+ self.base_path = base_path
+ self.args = args
+ self.submodule = submodule
+
+ def attach(self, engine: Engine) -> None:
+ """
+ Args:
+ engine: Ignite Engine, it can be a trainer, validator or evaluator.
+ """
+ self.logger = engine.logger
+ engine.add_event_handler(Events.STARTED, self)
+
+ def __call__(self, engine: Engine) -> None:
+ """
+ Args:
+ engine: Ignite Engine, it can be a trainer, validator or evaluator.
+ """
+ trt_compile(self.model, self.base_path, args=self.args, submodule=self.submodule, logger=self.logger)
diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py
index 0cd31b89c2..b6771f2dcc 100644
--- a/monai/handlers/utils.py
+++ b/monai/handlers/utils.py
@@ -19,8 +19,8 @@
import numpy as np
import torch
-from monai.config import IgniteInfo, KeysCollection, PathLike
-from monai.utils import ensure_tuple, look_up_option, min_version, optional_import
+from monai.config import KeysCollection, PathLike
+from monai.utils import IgniteInfo, ensure_tuple, look_up_option, min_version, optional_import
idist, _ = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")
if TYPE_CHECKING:
diff --git a/monai/handlers/validation_handler.py b/monai/handlers/validation_handler.py
index 89c7715f42..38dd511aa4 100644
--- a/monai/handlers/validation_handler.py
+++ b/monai/handlers/validation_handler.py
@@ -13,9 +13,8 @@
from typing import TYPE_CHECKING
-from monai.config import IgniteInfo
from monai.engines.evaluator import Evaluator
-from monai.utils import min_version, optional_import
+from monai.utils import IgniteInfo, min_version, optional_import
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
if TYPE_CHECKING:
diff --git a/monai/inferers/__init__.py b/monai/inferers/__init__.py
index 960380bfb8..fc78b9f7c4 100644
--- a/monai/inferers/__init__.py
+++ b/monai/inferers/__init__.py
@@ -12,13 +12,18 @@
from __future__ import annotations
from .inferer import (
+ ControlNetDiffusionInferer,
+ ControlNetLatentDiffusionInferer,
+ DiffusionInferer,
Inferer,
+ LatentDiffusionInferer,
PatchInferer,
SaliencyInferer,
SimpleInferer,
SliceInferer,
SlidingWindowInferer,
SlidingWindowInfererAdapt,
+ VQVAETransformerInferer,
)
from .merger import AvgMerger, Merger, ZarrAvgMerger
from .splitter import SlidingWindowSplitter, Splitter, WSISlidingWindowSplitter
diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index 0b4199938d..769b6cc0e7 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -11,24 +11,41 @@
from __future__ import annotations
+import math
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
+from functools import partial
from pydoc import locate
from typing import Any
import torch
import torch.nn as nn
+import torch.nn.functional as F
from monai.apps.utils import get_logger
+from monai.data import decollate_batch
from monai.data.meta_tensor import MetaTensor
from monai.data.thread_buffer import ThreadBuffer
from monai.inferers.merger import AvgMerger, Merger
from monai.inferers.splitter import Splitter
from monai.inferers.utils import compute_importance_map, sliding_window_inference
-from monai.utils import BlendMode, PatchKeys, PytorchPadMode, ensure_tuple, optional_import
+from monai.networks.nets import (
+ VQVAE,
+ AutoencoderKL,
+ ControlNet,
+ DecoderOnlyTransformer,
+ DiffusionModelUNet,
+ SPADEAutoencoderKL,
+ SPADEDiffusionModelUNet,
+)
+from monai.networks.schedulers import Scheduler
+from monai.transforms import CenterSpatialCrop, SpatialPad
+from monai.utils import BlendMode, Ordering, PatchKeys, PytorchPadMode, ensure_tuple, optional_import
from monai.visualize import CAM, GradCAM, GradCAMpp
+tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
+
logger = get_logger(__name__)
__all__ = [
@@ -752,3 +769,1264 @@ def network_wrapper(
return out
return tuple(out_i.unsqueeze(dim=self.spatial_dim + 2) for out_i in out)
+
+
+class DiffusionInferer(Inferer):
+ """
+ DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a single forward pass
+ for a training iteration, and sample from the model.
+
+ Args:
+ scheduler: diffusion scheduler.
+ """
+
+ def __init__(self, scheduler: Scheduler) -> None: # type: ignore[override]
+ super().__init__()
+
+ self.scheduler = scheduler
+
+ def __call__( # type: ignore[override]
+ self,
+ inputs: torch.Tensor,
+ diffusion_model: DiffusionModelUNet,
+ noise: torch.Tensor,
+ timesteps: torch.Tensor,
+ condition: torch.Tensor | None = None,
+ mode: str = "crossattn",
+ seg: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Implements the forward pass for a supervised training iteration.
+
+ Args:
+ inputs: Input image to which noise is added.
+ diffusion_model: diffusion model.
+ noise: random noise, of the same shape as the input.
+ timesteps: random timesteps.
+ condition: Conditioning for network input.
+ mode: Conditioning mode for the network.
+ seg: if the model is an instance of SPADEDiffusionModelUNet, the segmentation must be
+ provided on the forward pass (for SPADE-like AE or SPADE-like DM)
+ """
+ if mode not in ["crossattn", "concat"]:
+ raise NotImplementedError(f"{mode} condition is not supported")
+
+ noisy_image: torch.Tensor = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+ if mode == "concat":
+ if condition is None:
+ raise ValueError("Conditioning is required for concat condition")
+ else:
+ noisy_image = torch.cat([noisy_image, condition], dim=1)
+ condition = None
+ diffusion_model = (
+ partial(diffusion_model, seg=seg)
+ if isinstance(diffusion_model, SPADEDiffusionModelUNet)
+ else diffusion_model
+ )
+ prediction: torch.Tensor = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition)
+
+ return prediction
+
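A hedged sketch of the training-iteration forward pass that this `__call__` implements, assuming `model` is an already constructed diffusion UNet; only the call pattern and a typical noise-prediction loss are shown:

    import torch
    from monai.inferers import DiffusionInferer
    from monai.networks.schedulers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    inferer = DiffusionInferer(scheduler)

    def training_step(model, images):
        noise = torch.randn_like(images)
        timesteps = torch.randint(0, 1000, (images.shape[0],), device=images.device).long()
        # the inferer noises the input internally and returns the model's noise prediction
        noise_pred = inferer(inputs=images, diffusion_model=model, noise=noise, timesteps=timesteps)
        return torch.nn.functional.mse_loss(noise_pred, noise)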
+ @torch.no_grad()
+ def sample(
+ self,
+ input_noise: torch.Tensor,
+ diffusion_model: DiffusionModelUNet,
+ scheduler: Scheduler | None = None,
+ save_intermediates: bool | None = False,
+ intermediate_steps: int | None = 100,
+ conditioning: torch.Tensor | None = None,
+ mode: str = "crossattn",
+ verbose: bool = True,
+ seg: torch.Tensor | None = None,
+ ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+ """
+ Args:
+ input_noise: random noise, of the same shape as the desired sample.
+ diffusion_model: model to sample from.
+ scheduler: diffusion scheduler. If none is provided, the class attribute scheduler is used.
+ save_intermediates: whether to return intermediates along the sampling chain
+ intermediate_steps: if save_intermediates is True, saves every n steps
+ conditioning: Conditioning for network input.
+ mode: Conditioning mode for the network.
+ verbose: if true, prints the progress bar of the sampling process.
+ seg: if the diffusion model is an instance of SPADEDiffusionModelUNet, the segmentation must be provided.
+ """
+ if mode not in ["crossattn", "concat"]:
+ raise NotImplementedError(f"{mode} condition is not supported")
+ if mode == "concat" and conditioning is None:
+ raise ValueError("Conditioning must be supplied for if condition mode is concat.")
+ if not scheduler:
+ scheduler = self.scheduler
+ image = input_noise
+ if verbose and has_tqdm:
+ progress_bar = tqdm(scheduler.timesteps)
+ else:
+ progress_bar = iter(scheduler.timesteps)
+ intermediates = []
+ for t in progress_bar:
+ # 1. predict noise model_output
+ diffusion_model = (
+ partial(diffusion_model, seg=seg)
+ if isinstance(diffusion_model, SPADEDiffusionModelUNet)
+ else diffusion_model
+ )
+ if mode == "concat" and conditioning is not None:
+ model_input = torch.cat([image, conditioning], dim=1)
+ model_output = diffusion_model(
+ model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None
+ )
+ else:
+ model_output = diffusion_model(
+ image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
+ )
+
+ # 2. compute previous image: x_t -> x_t-1
+ image, _ = scheduler.step(model_output, t, image)
+ if save_intermediates and t % intermediate_steps == 0:
+ intermediates.append(image)
+ if save_intermediates:
+ return image, intermediates
+ else:
+ return image
+
+ @torch.no_grad()
+ def get_likelihood(
+ self,
+ inputs: torch.Tensor,
+ diffusion_model: DiffusionModelUNet,
+ scheduler: Scheduler | None = None,
+ save_intermediates: bool | None = False,
+ conditioning: torch.Tensor | None = None,
+ mode: str = "crossattn",
+ original_input_range: tuple = (0, 255),
+ scaled_input_range: tuple = (0, 1),
+ verbose: bool = True,
+ seg: torch.Tensor | None = None,
+ ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+ """
+ Computes the log-likelihoods for an input.
+
+ Args:
+ inputs: input images, NxCxHxW[xD]
+ diffusion_model: model to compute likelihood from
+ scheduler: diffusion scheduler. If none is provided, the class attribute scheduler is used.
+ save_intermediates: save the intermediate spatial KL maps
+ conditioning: Conditioning for network input.
+ mode: Conditioning mode for the network.
+ original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+ scaled_input_range: the [min,max] intensity range of the input data after scaling.
+ verbose: if true, prints the progress bar of the sampling process.
+ seg: if the diffusion model is an instance of SPADEDiffusionModelUNet, the segmentation must be provided.
+ """
+
+ if not scheduler:
+ scheduler = self.scheduler
+ if scheduler._get_name() != "DDPMScheduler":
+ raise NotImplementedError(
+ f"Likelihood computation is only compatible with DDPMScheduler,"
+ f" you are using {scheduler._get_name()}"
+ )
+ if mode not in ["crossattn", "concat"]:
+ raise NotImplementedError(f"{mode} condition is not supported")
+ if mode == "concat" and conditioning is None:
+ raise ValueError("Conditioning must be supplied for if condition mode is concat.")
+ if verbose and has_tqdm:
+ progress_bar = tqdm(scheduler.timesteps)
+ else:
+ progress_bar = iter(scheduler.timesteps)
+ intermediates = []
+ noise = torch.randn_like(inputs).to(inputs.device)
+ total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)
+ for t in progress_bar:
+ timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
+ noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+ diffusion_model = (
+ partial(diffusion_model, seg=seg)
+ if isinstance(diffusion_model, SPADEDiffusionModelUNet)
+ else diffusion_model
+ )
+ if mode == "concat" and conditioning is not None:
+ noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+ model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None)
+ else:
+ model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
+
+ # get the model's predicted mean, and variance if it is predicted
+ if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
+ model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ alpha_prod_t = scheduler.alphas_cumprod[t]
+ alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if scheduler.prediction_type == "epsilon":
+ pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif scheduler.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif scheduler.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output
+ # 3. Clip "predicted x_0"
+ if scheduler.clip_sample:
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t
+ current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image
+
+ # get the posterior mean and variance
+ posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
+ posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+
+ log_posterior_variance = torch.log(posterior_variance)
+ log_predicted_variance = torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance
+
+ if t == 0:
+ # compute -log p(x_0|x_1)
+ kl = -self._get_decoder_log_likelihood(
+ inputs=inputs,
+ means=predicted_mean,
+ log_scales=0.5 * log_predicted_variance,
+ original_input_range=original_input_range,
+ scaled_input_range=scaled_input_range,
+ )
+ else:
+ # compute kl between two normals
+ kl = 0.5 * (
+ -1.0
+ + log_predicted_variance
+ - log_posterior_variance
+ + torch.exp(log_posterior_variance - log_predicted_variance)
+ + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)
+ )
+ total_kl += kl.view(kl.shape[0], -1).mean(dim=1)
+ if save_intermediates:
+ intermediates.append(kl.cpu())
+
+ if save_intermediates:
+ return total_kl, intermediates
+ else:
+ return total_kl
+
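+ # Usage sketch (illustrative): the returned tensor accumulates one KL value per batch element,
+ # assuming `model` is a trained DiffusionModelUNet and `images` a NxCxHxW batch.
+ #
+ #   inferer = DiffusionInferer(DDPMScheduler(num_train_timesteps=1000))
+ #   likelihood = inferer.get_likelihood(inputs=images, diffusion_model=model, verbose=False)
+ #   assert likelihood.shape == (images.shape[0],)
+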
+ def _approx_standard_normal_cdf(self, x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal. Code adapted from https://github.com/openai/improved-diffusion.
+ """
+
+ return 0.5 * (
+ 1.0 + torch.tanh(torch.sqrt(torch.Tensor([2.0 / math.pi]).to(x.device)) * (x + 0.044715 * torch.pow(x, 3)))
+ )
+
+ def _get_decoder_log_likelihood(
+ self,
+ inputs: torch.Tensor,
+ means: torch.Tensor,
+ log_scales: torch.Tensor,
+ original_input_range: tuple = (0, 255),
+ scaled_input_range: tuple = (0, 1),
+ ) -> torch.Tensor:
+ """
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
+ given image. Code adapted from https://github.com/openai/improved-diffusion.
+
+ Args:
+ input: the target images. It is assumed that this was uint8 values,
+ rescaled to the range [-1, 1].
+ means: the Gaussian mean Tensor.
+ log_scales: the Gaussian log stddev Tensor.
+ original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+ scaled_input_range: the [min,max] intensity range of the input data after scaling.
+ """
+ if inputs.shape != means.shape:
+ raise ValueError(f"Inputs and means must have the same shape, got {inputs.shape} and {means.shape}")
+ bin_width = (scaled_input_range[1] - scaled_input_range[0]) / (
+ original_input_range[1] - original_input_range[0]
+ )
+ centered_x = inputs - means
+ inv_stdv = torch.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + bin_width / 2)
+ cdf_plus = self._approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - bin_width / 2)
+ cdf_min = self._approx_standard_normal_cdf(min_in)
+ log_cdf_plus = torch.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = torch.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = torch.where(
+ inputs < -0.999,
+ log_cdf_plus,
+ torch.where(inputs > 0.999, log_one_minus_cdf_min, torch.log(cdf_delta.clamp(min=1e-12))),
+ )
+ return log_probs
+
+
+class LatentDiffusionInferer(DiffusionInferer):
+ """
+ LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can
+ be used to perform a signal forward pass for a training iteration, and sample from the model.
+
+ Args:
+ scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents.
+ scale_factor: scale factor to multiply the values of the latent representation before processing it by the
+ second stage.
+ ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape.
+ autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a
+ difference between the autoencoder's latent shape and the DM shape.
+ """
+
+ def __init__(
+ self,
+ scheduler: Scheduler,
+ scale_factor: float = 1.0,
+ ldm_latent_shape: list | None = None,
+ autoencoder_latent_shape: list | None = None,
+ ) -> None:
+ super().__init__(scheduler=scheduler)
+ self.scale_factor = scale_factor
+ if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None):
+ raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None, and vice versa.")
+ self.ldm_latent_shape = ldm_latent_shape
+ self.autoencoder_latent_shape = autoencoder_latent_shape
+ if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None:
+ self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
+ self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape)
+
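+ # Construction sketch (illustrative): `autoencoder` and `unet` are assumed to be trained
+ # AutoencoderKL / DiffusionModelUNet instances; scale_factor is commonly set from the
+ # standard deviation of the encoded training latents.
+ #
+ #   with torch.no_grad():
+ #       z = autoencoder.encode_stage_2_inputs(first_batch)
+ #   scheduler = DDPMScheduler(num_train_timesteps=1000)
+ #   inferer = LatentDiffusionInferer(scheduler, scale_factor=1.0 / z.flatten().std().item())
+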
+ def __call__( # type: ignore[override]
+ self,
+ inputs: torch.Tensor,
+ autoencoder_model: AutoencoderKL | VQVAE,
+ diffusion_model: DiffusionModelUNet,
+ noise: torch.Tensor,
+ timesteps: torch.Tensor,
+ condition: torch.Tensor | None = None,
+ mode: str = "crossattn",
+ seg: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Implements the forward pass for a supervised training iteration.
+
+ Args:
+ inputs: input image to which the latent representation will be extracted and noise is added.
+ autoencoder_model: first stage model.
+ diffusion_model: diffusion model.
+ noise: random noise, of the same shape as the latent representation.
+ timesteps: random timesteps.
+ condition: conditioning for network input.
+ mode: Conditioning mode for the network.
+ seg: if the diffusion model is an instance of SPADEDiffusionModelUNet, the segmentation must be provided.
+ """
+ with torch.no_grad():
+ latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+
+ if self.ldm_latent_shape is not None:
+ latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)
+
+ prediction: torch.Tensor = super().__call__(
+ inputs=latent,
+ diffusion_model=diffusion_model,
+ noise=noise,
+ timesteps=timesteps,
+ condition=condition,
+ mode=mode,
+ seg=seg,
+ )
+ return prediction
+
+ @torch.no_grad()
+ def sample( # type: ignore[override]
+ self,
+ input_noise: torch.Tensor,
+ autoencoder_model: AutoencoderKL | VQVAE,
+ diffusion_model: DiffusionModelUNet,
+ scheduler: Scheduler | None = None,
+ save_intermediates: bool | None = False,
+ intermediate_steps: int | None = 100,
+ conditioning: torch.Tensor | None = None,
+ mode: str = "crossattn",
+ verbose: bool = True,
+ seg: torch.Tensor | None = None,
+ ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+ """
+ Args:
+ input_noise: random noise, of the same shape as the desired latent representation.
+ autoencoder_model: first stage model.
+ diffusion_model: model to sample from.
+ scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
+ save_intermediates: whether to return intermediates along the sampling chain
+ intermediate_steps: if save_intermediates is True, saves every n steps
+ conditioning: Conditioning for network input.
+ mode: Conditioning mode for the network.
+ verbose: if true, prints the progress bar of the sampling process.
+ seg: if the diffusion model is an instance of SPADEDiffusionModelUNet, or autoencoder_model
+ is an instance of SPADEAutoencoderKL, the segmentation must be provided.
+ """
+
+ if (
+ isinstance(autoencoder_model, SPADEAutoencoderKL)
+ and isinstance(diffusion_model, SPADEDiffusionModelUNet)
+ and autoencoder_model.decoder.label_nc != diffusion_model.label_nc
+ ):
+ raise ValueError(
+ f"If both autoencoder_model and diffusion_model implement SPADE, the number of semantic"
+ f"labels for each must be compatible, but got {autoencoder_model.decoder.label_nc} and"
+ f"{diffusion_model.label_nc}"
+ )
+
+ outputs = super().sample(
+ input_noise=input_noise,
+ diffusion_model=diffusion_model,
+ scheduler=scheduler,
+ save_intermediates=save_intermediates,
+ intermediate_steps=intermediate_steps,
+ conditioning=conditioning,
+ mode=mode,
+ verbose=verbose,
+ seg=seg,
+ )
+
+ if save_intermediates:
+ latent, latent_intermediates = outputs
+ else:
+ latent = outputs
+
+ if self.autoencoder_latent_shape is not None:
+ latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
+ latent_intermediates = [
+ torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+ ]
+
+ decode = autoencoder_model.decode_stage_2_outputs
+ if isinstance(autoencoder_model, SPADEAutoencoderKL):
+ decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
+ image = decode(latent / self.scale_factor)
+
+ if save_intermediates:
+ intermediates = []
+ for latent_intermediate in latent_intermediates:
+ decode = autoencoder_model.decode_stage_2_outputs
+ if isinstance(autoencoder_model, SPADEAutoencoderKL):
+ decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
+ intermediates.append(decode(latent_intermediate / self.scale_factor))
+ return image, intermediates
+
+ else:
+ return image
+
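+ # Sampling sketch (illustrative, continuing the construction sketch above): the noise must match
+ # the *latent* shape, e.g. 3 latent channels and 8x spatial downsampling of a 64x64 image.
+ #
+ #   latent_noise = torch.randn(1, 3, 8, 8)
+ #   images = inferer.sample(
+ #       input_noise=latent_noise, autoencoder_model=autoencoder, diffusion_model=unet, scheduler=scheduler
+ #   )
+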
+ @torch.no_grad()
+ def get_likelihood( # type: ignore[override]
+ self,
+ inputs: torch.Tensor,
+ autoencoder_model: AutoencoderKL | VQVAE,
+ diffusion_model: DiffusionModelUNet,
+ scheduler: Scheduler | None = None,
+ save_intermediates: bool | None = False,
+ conditioning: torch.Tensor | None = None,
+ mode: str = "crossattn",
+ original_input_range: tuple | None = (0, 255),
+ scaled_input_range: tuple | None = (0, 1),
+ verbose: bool = True,
+ resample_latent_likelihoods: bool = False,
+ resample_interpolation_mode: str = "nearest",
+ seg: torch.Tensor | None = None,
+ ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+ """
+ Computes the log-likelihoods of the latent representations of the input.
+
+ Args:
+ inputs: input images, NxCxHxW[xD]
+ autoencoder_model: first stage model.
+ diffusion_model: model to compute likelihood from
+ scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
+ save_intermediates: save the intermediate spatial KL maps
+ conditioning: Conditioning for network input.
+ mode: Conditioning mode for the network.
+ original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+ scaled_input_range: the [min,max] intensity range of the input data after scaling.
+ verbose: if true, prints the progress bar of the sampling process.
+ resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
+ dimension as the input images.
+ resample_interpolation_mode: if using resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
+ or 'trilinear'.
+ seg: if the diffusion model is an instance of SPADEDiffusionModelUNet, or autoencoder_model
+ is an instance of SPADEAutoencoderKL, the segmentation must be provided.
+ """
+ if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
+ raise ValueError(
+ f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
+ )
+ latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+
+ if self.ldm_latent_shape is not None:
+ latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)
+
+ outputs = super().get_likelihood(
+ inputs=latents,
+ diffusion_model=diffusion_model,
+ scheduler=scheduler,
+ save_intermediates=save_intermediates,
+ conditioning=conditioning,
+ mode=mode,
+ verbose=verbose,
+ seg=seg,
+ )
+
+ if save_intermediates and resample_latent_likelihoods:
+ intermediates = outputs[1]
+ resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)
+ intermediates = [resizer(x) for x in intermediates]
+ outputs = (outputs[0], intermediates)
+ return outputs
+
+
+class ControlNetDiffusionInferer(DiffusionInferer):
+ """
+ ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal
+ forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning.
+
+ Args:
+ scheduler: diffusion scheduler.
+ """
+
+ def __init__(self, scheduler: Scheduler) -> None:
+ Inferer.__init__(self)
+ self.scheduler = scheduler
+
+ def __call__( # type: ignore[override]
+ self,
+ inputs: torch.Tensor,
+ diffusion_model: DiffusionModelUNet,
+ controlnet: ControlNet,
+ noise: torch.Tensor,
+ timesteps: torch.Tensor,
+ cn_cond: torch.Tensor,
+ condition: torch.Tensor | None = None,
+ mode: str = "crossattn",
+ seg: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Implements the forward pass for a supervised training iteration.
+
+ Args:
+ inputs: Input image to which noise is added.
+ diffusion_model: diffusion model.
+ controlnet: controlnet sub-network.
+ noise: random noise, of the same shape as the input.
+ timesteps: random timesteps.
+ cn_cond: conditioning image for the ControlNet.
+ condition: Conditioning for network input.
+ mode: Conditioning mode for the network.
+ seg: if the diffusion model is an instance of SPADEDiffusionModelUNet, the segmentation must be
+ provided on the forward pass (for SPADE-like AE or SPADE-like DM)
+ """
+ if mode not in ["crossattn", "concat"]:
+ raise NotImplementedError(f"{mode} condition is not supported")
+
+ noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+ down_block_res_samples, mid_block_res_sample = controlnet(
+ x=noisy_image, timesteps=timesteps, controlnet_cond=cn_cond
+ )
+ if mode == "concat" and condition is not None:
+ noisy_image = torch.cat([noisy_image, condition], dim=1)
+ condition = None
+
+ diffuse = diffusion_model
+ if isinstance(diffusion_model, SPADEDiffusionModelUNet):
+ diffuse = partial(diffusion_model, seg=seg)
+
+ prediction: torch.Tensor = diffuse(
+ x=noisy_image,
+ timesteps=timesteps,
+ context=condition,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ )
+
+ return prediction
+
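+ # Training-step sketch (illustrative): `model` and `controlnet` are assumed to be initialised
+ # elsewhere, `images` is a NxCxHxW batch and `masks` the ControlNet conditioning of matching spatial size.
+ #
+ #   scheduler = DDPMScheduler(num_train_timesteps=1000)
+ #   inferer = ControlNetDiffusionInferer(scheduler)
+ #   noise = torch.randn_like(images)
+ #   timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],), device=images.device).long()
+ #   noise_pred = inferer(
+ #       inputs=images, diffusion_model=model, controlnet=controlnet, noise=noise, timesteps=timesteps, cn_cond=masks
+ #   )
+ #   loss = F.mse_loss(noise_pred.float(), noise.float())
+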
+ @torch.no_grad()
+ def sample( # type: ignore[override]
+ self,
+ input_noise: torch.Tensor,
+ diffusion_model: DiffusionModelUNet,
+ controlnet: ControlNet,
+ cn_cond: torch.Tensor,
+ scheduler: Scheduler | None = None,
+ save_intermediates: bool | None = False,
+ intermediate_steps: int | None = 100,
+ conditioning: torch.Tensor | None = None,
+ mode: str = "crossattn",
+ verbose: bool = True,
+ seg: torch.Tensor | None = None,
+ ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+ """
+ Args:
+ input_noise: random noise, of the same shape as the desired sample.
+ diffusion_model: model to sample from.
+ controlnet: controlnet sub-network.
+ cn_cond: conditioning image for the ControlNet.
+ scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
+ save_intermediates: whether to return intermediates along the sampling chain
+ intermediate_steps: if save_intermediates is True, saves every n steps
+ conditioning: Conditioning for network input.
+ mode: Conditioning mode for the network.
+ verbose: if true, prints the progress bar of the sampling process.
+ seg: if the diffusion model is an instance of SPADEDiffusionModelUNet, the segmentation must be provided.
+ """
+ if mode not in ["crossattn", "concat"]:
+ raise NotImplementedError(f"{mode} condition is not supported")
+
+ if not scheduler:
+ scheduler = self.scheduler
+ image = input_noise
+ if verbose and has_tqdm:
+ progress_bar = tqdm(scheduler.timesteps)
+ else:
+ progress_bar = iter(scheduler.timesteps)
+ intermediates = []
+ for t in progress_bar:
+ # 1. ControlNet forward
+ down_block_res_samples, mid_block_res_sample = controlnet(
+ x=image, timesteps=torch.Tensor((t,)).to(input_noise.device), controlnet_cond=cn_cond
+ )
+ # 2. predict noise model_output
+ diffuse = diffusion_model
+ if isinstance(diffusion_model, SPADEDiffusionModelUNet):
+ diffuse = partial(diffusion_model, seg=seg)
+
+ if mode == "concat" and conditioning is not None:
+ model_input = torch.cat([image, conditioning], dim=1)
+ model_output = diffuse(
+ model_input,
+ timesteps=torch.Tensor((t,)).to(input_noise.device),
+ context=None,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ )
+ else:
+ model_output = diffuse(
+ image,
+ timesteps=torch.Tensor((t,)).to(input_noise.device),
+ context=conditioning,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ )
+
+ # 3. compute previous image: x_t -> x_t-1
+ image, _ = scheduler.step(model_output, t, image)
+ if save_intermediates and t % intermediate_steps == 0:
+ intermediates.append(image)
+ if save_intermediates:
+ return image, intermediates
+ else:
+ return image
+
+ @torch.no_grad()
+ def get_likelihood( # type: ignore[override]
+ self,
+ inputs: torch.Tensor,
+ diffusion_model: DiffusionModelUNet,
+ controlnet: ControlNet,
+ cn_cond: torch.Tensor,
+ scheduler: Scheduler | None = None,
+ save_intermediates: bool | None = False,
+ conditioning: torch.Tensor | None = None,
+ mode: str = "crossattn",
+ original_input_range: tuple = (0, 255),
+ scaled_input_range: tuple = (0, 1),
+ verbose: bool = True,
+ seg: torch.Tensor | None = None,
+ ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+ """
+ Computes the log-likelihoods for an input.
+
+ Args:
+ inputs: input images, NxCxHxW[xD]
+ diffusion_model: model to compute likelihood from
+ controlnet: controlnet sub-network.
+ cn_cond: conditioning image for the ControlNet.
+ scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
+ save_intermediates: save the intermediate spatial KL maps
+ conditioning: Conditioning for network input.
+ mode: Conditioning mode for the network.
+ original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+ scaled_input_range: the [min,max] intensity range of the input data after scaling.
+ verbose: if true, prints the progress bar of the sampling process.
+ seg: if the diffusion model is an instance of SPADEDiffusionModelUNet, the segmentation must be provided.
+ """
+
+ if not scheduler:
+ scheduler = self.scheduler
+ if scheduler._get_name() != "DDPMScheduler":
+ raise NotImplementedError(
+ f"Likelihood computation is only compatible with DDPMScheduler,"
+ f" you are using {scheduler._get_name()}"
+ )
+ if mode not in ["crossattn", "concat"]:
+ raise NotImplementedError(f"{mode} condition is not supported")
+ if verbose and has_tqdm:
+ progress_bar = tqdm(scheduler.timesteps)
+ else:
+ progress_bar = iter(scheduler.timesteps)
+ intermediates = []
+ noise = torch.randn_like(inputs).to(inputs.device)
+ total_kl = torch.zeros(inputs.shape[0]).to(inputs.device)
+ for t in progress_bar:
+ timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
+ noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
+ down_block_res_samples, mid_block_res_sample = controlnet(
+ x=noisy_image, timesteps=torch.Tensor((t,)).to(inputs.device), controlnet_cond=cn_cond
+ )
+
+ diffuse = diffusion_model
+ if isinstance(diffusion_model, SPADEDiffusionModelUNet):
+ diffuse = partial(diffusion_model, seg=seg)
+
+ if mode == "concat" and conditioning is not None:
+ noisy_image = torch.cat([noisy_image, conditioning], dim=1)
+ model_output = diffuse(
+ noisy_image,
+ timesteps=timesteps,
+ context=None,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ )
+ else:
+ model_output = diffuse(
+ x=noisy_image,
+ timesteps=timesteps,
+ context=conditioning,
+ down_block_additional_residuals=down_block_res_samples,
+ mid_block_additional_residual=mid_block_res_sample,
+ )
+ # get the model's predicted mean, and variance if it is predicted
+ if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
+ model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ alpha_prod_t = scheduler.alphas_cumprod[t]
+ alpha_prod_t_prev = scheduler.alphas_cumprod[t - 1] if t > 0 else scheduler.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if scheduler.prediction_type == "epsilon":
+ pred_original_sample = (noisy_image - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif scheduler.prediction_type == "sample":
+ pred_original_sample = model_output
+ elif scheduler.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * noisy_image - (beta_prod_t**0.5) * model_output
+ # 3. Clip "predicted x_0"
+ if scheduler.clip_sample:
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * scheduler.betas[t]) / beta_prod_t
+ current_sample_coeff = scheduler.alphas[t] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ predicted_mean = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * noisy_image
+
+ # get the posterior mean and variance
+ posterior_mean = scheduler._get_mean(timestep=t, x_0=inputs, x_t=noisy_image)
+ posterior_variance = scheduler._get_variance(timestep=t, predicted_variance=predicted_variance)
+
+ log_posterior_variance = torch.log(posterior_variance)
+ log_predicted_variance = torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance
+
+ if t == 0:
+ # compute -log p(x_0|x_1)
+ kl = -super()._get_decoder_log_likelihood(
+ inputs=inputs,
+ means=predicted_mean,
+ log_scales=0.5 * log_predicted_variance,
+ original_input_range=original_input_range,
+ scaled_input_range=scaled_input_range,
+ )
+ else:
+ # compute kl between two normals
+ kl = 0.5 * (
+ -1.0
+ + log_predicted_variance
+ - log_posterior_variance
+ + torch.exp(log_posterior_variance - log_predicted_variance)
+ + ((posterior_mean - predicted_mean) ** 2) * torch.exp(-log_predicted_variance)
+ )
+ total_kl += kl.view(kl.shape[0], -1).mean(dim=1)
+ if save_intermediates:
+ intermediates.append(kl.cpu())
+
+ if save_intermediates:
+ return total_kl, intermediates
+ else:
+ return total_kl
+
+
+class ControlNetLatentDiffusionInferer(ControlNetDiffusionInferer):
+ """
+ ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet,
+ and a scheduler, and can be used to perform a signal forward pass for a training iteration, and sample from
+ the model.
+
+ Args:
+ scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents.
+ scale_factor: scale factor to multiply the values of the latent representation before processing it by the
+ second stage.
+ ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape.
+ autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a
+ difference between the autoencoder's latent shape and the DM shape.
+ """
+
+ def __init__(
+ self,
+ scheduler: Scheduler,
+ scale_factor: float = 1.0,
+ ldm_latent_shape: list | None = None,
+ autoencoder_latent_shape: list | None = None,
+ ) -> None:
+ super().__init__(scheduler=scheduler)
+ self.scale_factor = scale_factor
+ if (ldm_latent_shape is None) ^ (autoencoder_latent_shape is None):
+ raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None" "and vice versa.")
+ self.ldm_latent_shape = ldm_latent_shape
+ self.autoencoder_latent_shape = autoencoder_latent_shape
+ if self.ldm_latent_shape is not None and self.autoencoder_latent_shape is not None:
+ self.ldm_resizer = SpatialPad(spatial_size=self.ldm_latent_shape)
+ self.autoencoder_resizer = CenterSpatialCrop(roi_size=self.autoencoder_latent_shape)
+
+ def __call__( # type: ignore[override]
+ self,
+ inputs: torch.Tensor,
+ autoencoder_model: AutoencoderKL | VQVAE,
+ diffusion_model: DiffusionModelUNet,
+ controlnet: ControlNet,
+ noise: torch.Tensor,
+ timesteps: torch.Tensor,
+ cn_cond: torch.Tensor,
+ condition: torch.Tensor | None = None,
+ mode: str = "crossattn",
+ seg: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Implements the forward pass for a supervised training iteration.
+
+ Args:
+ inputs: input image to which the latent representation will be extracted and noise is added.
+ autoencoder_model: first stage model.
+ diffusion_model: diffusion model.
+ controlnet: instance of ControlNet model
+ noise: random noise, of the same shape as the latent representation.
+ timesteps: random timesteps.
+ cn_cond: conditioning tensor for the ControlNet network
+ condition: conditioning for network input.
+ mode: Conditioning mode for the network.
+ seg: if the diffusion model is an instance of SPADEDiffusionModelUNet, the segmentation must be provided.
+ """
+ with torch.no_grad():
+ latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+
+ if self.ldm_latent_shape is not None:
+ latent = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latent)], 0)
+
+ if cn_cond.shape[2:] != latent.shape[2:]:
+ cn_cond = F.interpolate(cn_cond, latent.shape[2:])
+
+ prediction = super().__call__(
+ inputs=latent,
+ diffusion_model=diffusion_model,
+ controlnet=controlnet,
+ noise=noise,
+ timesteps=timesteps,
+ cn_cond=cn_cond,
+ condition=condition,
+ mode=mode,
+ seg=seg,
+ )
+
+ return prediction
+
+ @torch.no_grad()
+ def sample( # type: ignore[override]
+ self,
+ input_noise: torch.Tensor,
+ autoencoder_model: AutoencoderKL | VQVAE,
+ diffusion_model: DiffusionModelUNet,
+ controlnet: ControlNet,
+ cn_cond: torch.Tensor,
+ scheduler: Scheduler | None = None,
+ save_intermediates: bool | None = False,
+ intermediate_steps: int | None = 100,
+ conditioning: torch.Tensor | None = None,
+ mode: str = "crossattn",
+ verbose: bool = True,
+ seg: torch.Tensor | None = None,
+ ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+ """
+ Args:
+ input_noise: random noise, of the same shape as the desired latent representation.
+ autoencoder_model: first stage model.
+ diffusion_model: model to sample from.
+ controlnet: instance of ControlNet model.
+ cn_cond: conditioning tensor for the ControlNet network.
+ scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
+ save_intermediates: whether to return intermediates along the sampling chain
+ intermediate_steps: if save_intermediates is True, saves every n steps
+ conditioning: Conditioning for network input.
+ mode: Conditioning mode for the network.
+ verbose: if true, prints the progress bar of the sampling process.
+ seg: if the diffusion model is an instance of SPADEDiffusionModelUNet, or autoencoder_model
+ is an instance of SPADEAutoencoderKL, the segmentation must be provided.
+ """
+
+ if (
+ isinstance(autoencoder_model, SPADEAutoencoderKL)
+ and isinstance(diffusion_model, SPADEDiffusionModelUNet)
+ and autoencoder_model.decoder.label_nc != diffusion_model.label_nc
+ ):
+ raise ValueError(
+ "If both autoencoder_model and diffusion_model implement SPADE, the number of semantic"
+ "labels for each must be compatible. Got {autoencoder_model.decoder.label_nc} and {diffusion_model.label_nc}"
+ )
+
+ if cn_cond.shape[2:] != input_noise.shape[2:]:
+ cn_cond = F.interpolate(cn_cond, input_noise.shape[2:])
+
+ outputs = super().sample(
+ input_noise=input_noise,
+ diffusion_model=diffusion_model,
+ controlnet=controlnet,
+ cn_cond=cn_cond,
+ scheduler=scheduler,
+ save_intermediates=save_intermediates,
+ intermediate_steps=intermediate_steps,
+ conditioning=conditioning,
+ mode=mode,
+ verbose=verbose,
+ seg=seg,
+ )
+
+ if save_intermediates:
+ latent, latent_intermediates = outputs
+ else:
+ latent = outputs
+
+ if self.autoencoder_latent_shape is not None:
+ latent = torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(latent)], 0)
+ latent_intermediates = [
+ torch.stack([self.autoencoder_resizer(i) for i in decollate_batch(l)], 0) for l in latent_intermediates
+ ]
+
+ decode = autoencoder_model.decode_stage_2_outputs
+ if isinstance(autoencoder_model, SPADEAutoencoderKL):
+ decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
+
+ image = decode(latent / self.scale_factor)
+
+ if save_intermediates:
+ intermediates = []
+ for latent_intermediate in latent_intermediates:
+ decode = autoencoder_model.decode_stage_2_outputs
+ if isinstance(autoencoder_model, SPADEAutoencoderKL):
+ decode = partial(autoencoder_model.decode_stage_2_outputs, seg=seg)
+ intermediates.append(decode(latent_intermediate / self.scale_factor))
+ return image, intermediates
+
+ else:
+ return image
+
+ @torch.no_grad()
+ def get_likelihood( # type: ignore[override]
+ self,
+ inputs: torch.Tensor,
+ autoencoder_model: AutoencoderKL | VQVAE,
+ diffusion_model: DiffusionModelUNet,
+ controlnet: ControlNet,
+ cn_cond: torch.Tensor,
+ scheduler: Scheduler | None = None,
+ save_intermediates: bool | None = False,
+ conditioning: torch.Tensor | None = None,
+ mode: str = "crossattn",
+ original_input_range: tuple | None = (0, 255),
+ scaled_input_range: tuple | None = (0, 1),
+ verbose: bool = True,
+ resample_latent_likelihoods: bool = False,
+ resample_interpolation_mode: str = "nearest",
+ seg: torch.Tensor | None = None,
+ ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
+ """
+ Computes the log-likelihoods of the latent representations of the input.
+
+ Args:
+ inputs: input images, NxCxHxW[xD]
+ autoencoder_model: first stage model.
+ diffusion_model: model to compute likelihood from
+ controlnet: instance of ControlNet model.
+ cn_cond: conditioning tensor for the ControlNet network.
+ scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
+ save_intermediates: save the intermediate spatial KL maps
+ conditioning: Conditioning for network input.
+ mode: Conditioning mode for the network.
+ original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
+ scaled_input_range: the [min,max] intensity range of the input data after scaling.
+ verbose: if true, prints the progress bar of the sampling process.
+ resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
+ dimension as the input images.
+ resample_interpolation_mode: if using resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
+ or 'trilinear'.
+ seg: if the diffusion model is an instance of SPADEDiffusionModelUNet, or autoencoder_model
+ is an instance of SPADEAutoencoderKL, the segmentation must be provided.
+ """
+ if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
+ raise ValueError(
+ f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
+ )
+
+ latents = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
+
+ if cn_cond.shape[2:] != latents.shape[2:]:
+ cn_cond = F.interpolate(cn_cond, latents.shape[2:])
+
+ if self.ldm_latent_shape is not None:
+ latents = torch.stack([self.ldm_resizer(i) for i in decollate_batch(latents)], 0)
+
+ outputs = super().get_likelihood(
+ inputs=latents,
+ diffusion_model=diffusion_model,
+ controlnet=controlnet,
+ cn_cond=cn_cond,
+ scheduler=scheduler,
+ save_intermediates=save_intermediates,
+ conditioning=conditioning,
+ mode=mode,
+ verbose=verbose,
+ seg=seg,
+ )
+
+ if save_intermediates and resample_latent_likelihoods:
+ intermediates = outputs[1]
+ resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)
+ intermediates = [resizer(x) for x in intermediates]
+ outputs = (outputs[0], intermediates)
+ return outputs
+
+
+class VQVAETransformerInferer(nn.Module):
+ """
+ Class to perform inference with a VQVAE + Transformer model.
+ """
+
+ def __init__(self) -> None:
+ Inferer.__init__(self)
+
+ def __call__(
+ self,
+ inputs: torch.Tensor,
+ vqvae_model: VQVAE,
+ transformer_model: DecoderOnlyTransformer,
+ ordering: Ordering,
+ condition: torch.Tensor | None = None,
+ return_latent: bool = False,
+ ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]:
+ """
+ Implements the forward pass for a supervised training iteration.
+
+ Args:
+ inputs: input image to which the latent representation will be extracted.
+ vqvae_model: first stage model.
+ transformer_model: autoregressive transformer model.
+ ordering: ordering of the quantised latent representation.
+ return_latent: also return latent sequence and spatial dim of the latent.
+ condition: conditioning for network input.
+ """
+ with torch.no_grad():
+ latent = vqvae_model.index_quantize(inputs)
+
+ latent_spatial_dim = tuple(latent.shape[1:])
+ latent = latent.reshape(latent.shape[0], -1)
+ latent = latent[:, ordering.get_sequence_ordering()]
+
+ # get the targets for the loss
+ target = latent.clone()
+ # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
+ # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.
+ latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
+ # crop the last token as we do not need the probability of the token that follows it
+ latent = latent[:, :-1]
+ latent = latent.long()
+
+ # train on a part of the sequence if it is longer than max_seq_length
+ seq_len = latent.shape[1]
+ max_seq_len = transformer_model.max_seq_len
+ if max_seq_len < seq_len:
+ start = int(torch.randint(low=0, high=seq_len + 1 - max_seq_len, size=(1,)).item())
+ else:
+ start = 0
+ prediction: torch.Tensor = transformer_model(x=latent[:, start : start + max_seq_len], context=condition)
+ if return_latent:
+ return prediction, target[:, start : start + max_seq_len], latent_spatial_dim
+ else:
+ return prediction
+
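+ # Training sketch (illustrative): `vqvae`, `transformer` and `ordering` are assumed to be built
+ # elsewhere (e.g. ordering = Ordering(ordering_type="raster_scan", spatial_dims=2, dimensions=(1, 8, 8))).
+ #
+ #   inferer = VQVAETransformerInferer()
+ #   logits, target, _ = inferer(
+ #       inputs=images, vqvae_model=vqvae, transformer_model=transformer, ordering=ordering, return_latent=True
+ #   )
+ #   loss = F.cross_entropy(logits.transpose(1, 2), target)
+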
+ @torch.no_grad()
+ def sample(
+ self,
+ latent_spatial_dim: tuple[int, int, int] | tuple[int, int],
+ starting_tokens: torch.Tensor,
+ vqvae_model: VQVAE,
+ transformer_model: DecoderOnlyTransformer,
+ ordering: Ordering,
+ conditioning: torch.Tensor | None = None,
+ temperature: float = 1.0,
+ top_k: int | None = None,
+ verbose: bool = True,
+ ) -> torch.Tensor:
+ """
+ Sampling function for the VQVAE + Transformer model.
+
+ Args:
+ latent_spatial_dim: shape of the sampled image.
+ starting_tokens: starting tokens for the sampling. Their value must be vqvae_model.num_embeddings (the BOS token).
+ vqvae_model: first stage model.
+ transformer_model: model to sample from.
+ conditioning: Conditioning for network input.
+ temperature: temperature for sampling.
+ top_k: top k sampling.
+ verbose: if true, prints the progress bar of the sampling process.
+ """
+ seq_len = math.prod(latent_spatial_dim)
+
+ if verbose and has_tqdm:
+ progress_bar = tqdm(range(seq_len))
+ else:
+ progress_bar = iter(range(seq_len))
+
+ latent_seq = starting_tokens.long()
+ for _ in progress_bar:
+ # if the sequence context is growing too long, we must crop it at max_seq_len
+ if latent_seq.size(1) <= transformer_model.max_seq_len:
+ idx_cond = latent_seq
+ else:
+ idx_cond = latent_seq[:, -transformer_model.max_seq_len :]
+
+ # forward the model to get the logits for the index in the sequence
+ logits = transformer_model(x=idx_cond, context=conditioning)
+ # pluck the logits at the final step and scale by desired temperature
+ logits = logits[:, -1, :] / temperature
+ # optionally crop the logits to only the top k options
+ if top_k is not None:
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+ logits[logits < v[:, [-1]]] = -float("Inf")
+ # apply softmax to convert logits to (normalized) probabilities
+ probs = F.softmax(logits, dim=-1)
+ # remove the chance of sampling the BOS token
+ probs[:, vqvae_model.num_embeddings] = 0
+ # sample from the distribution
+ idx_next = torch.multinomial(probs, num_samples=1)
+ # append sampled index to the running sequence and continue
+ latent_seq = torch.cat((latent_seq, idx_next), dim=1)
+
+ latent_seq = latent_seq[:, 1:]
+ latent_seq = latent_seq[:, ordering.get_revert_sequence_ordering()]
+ latent = latent_seq.reshape((starting_tokens.shape[0],) + latent_spatial_dim)
+
+ return vqvae_model.decode_samples(latent)
+
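+ # Sampling sketch (illustrative, continuing the training sketch above): the starting tokens are a
+ # batch of BOS tokens, whose value is vqvae_model.num_embeddings (matching the padding used above).
+ #
+ #   bos = torch.full((batch_size, 1), fill_value=vqvae.num_embeddings).long()
+ #   samples = inferer.sample(
+ #       latent_spatial_dim=(8, 8), starting_tokens=bos, vqvae_model=vqvae,
+ #       transformer_model=transformer, ordering=ordering,
+ #   )
+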
+ @torch.no_grad()
+ def get_likelihood(
+ self,
+ inputs: torch.Tensor,
+ vqvae_model: VQVAE,
+ transformer_model: DecoderOnlyTransformer,
+ ordering: Ordering,
+ condition: torch.Tensor | None = None,
+ resample_latent_likelihoods: bool = False,
+ resample_interpolation_mode: str = "nearest",
+ verbose: bool = False,
+ ) -> torch.Tensor:
+ """
+ Computes the log-likelihoods of the latent representations of the input.
+
+ Args:
+ inputs: input images, NxCxHxW[xD]
+ vqvae_model: first stage model.
+ transformer_model: autoregressive transformer model.
+ ordering: ordering of the quantised latent representation.
+ condition: conditioning for network input.
+ resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
+ dimension as the input images.
+ resample_interpolation_mode: if using resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
+ or 'trilinear'.
+ verbose: if true, prints the progress bar of the sampling process.
+
+ """
+ if resample_latent_likelihoods and resample_interpolation_mode not in ("nearest", "bilinear", "trilinear"):
+ raise ValueError(
+ f"resample_interpolation mode should be either nearest, bilinear, or trilinear, got {resample_interpolation_mode}"
+ )
+
+ with torch.no_grad():
+ latent = vqvae_model.index_quantize(inputs)
+
+ latent_spatial_dim = tuple(latent.shape[1:])
+ latent = latent.reshape(latent.shape[0], -1)
+ latent = latent[:, ordering.get_sequence_ordering()]
+ seq_len = math.prod(latent_spatial_dim)
+
+ # Use the value from vqvae_model's num_embeddings as the starting token, the "Begin Of Sentence" (BOS) token.
+ # Note the transformer_model must have vqvae_model.num_embeddings + 1 defined as num_tokens.
+ latent = F.pad(latent, (1, 0), "constant", vqvae_model.num_embeddings)
+ latent = latent.long()
+
+ # get the first batch, up to max_seq_length, efficiently
+ logits = transformer_model(x=latent[:, : transformer_model.max_seq_len], context=condition)
+ probs = F.softmax(logits, dim=-1)
+ # target token for each set of logits is the next token along
+ target = latent[:, 1:]
+ probs = torch.gather(probs, 2, target[:, : transformer_model.max_seq_len].unsqueeze(2)).squeeze(2)
+
+ # if we have not covered the full sequence we continue with inefficient looping
+ if probs.shape[1] < target.shape[1]:
+ if verbose and has_tqdm:
+ progress_bar = tqdm(range(transformer_model.max_seq_len, seq_len))
+ else:
+ progress_bar = iter(range(transformer_model.max_seq_len, seq_len))
+
+ for i in progress_bar:
+ idx_cond = latent[:, i + 1 - transformer_model.max_seq_len : i + 1]
+ # forward the model to get the logits for the index in the sequence
+ logits = transformer_model(x=idx_cond, context=condition)
+ # pluck the logits at the final step
+ logits = logits[:, -1, :]
+ # apply softmax to convert logits to (normalized) probabilities
+ p = F.softmax(logits, dim=-1)
+ # select correct values and append
+ p = torch.gather(p, 1, target[:, i].unsqueeze(1))
+
+ probs = torch.cat((probs, p), dim=1)
+
+ # convert to log-likelihood
+ probs = torch.log(probs)
+
+ # reshape
+ probs = probs[:, ordering.get_revert_sequence_ordering()]
+ probs_reshaped = probs.reshape((inputs.shape[0],) + latent_spatial_dim)
+ if resample_latent_likelihoods:
+ resizer = nn.Upsample(size=inputs.shape[2:], mode=resample_interpolation_mode)
+ probs_reshaped = resizer(probs_reshaped[:, None, ...])
+
+ return probs_reshaped
diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py
index a080284e7c..edaf736091 100644
--- a/monai/inferers/utils.py
+++ b/monai/inferers/utils.py
@@ -12,8 +12,8 @@
from __future__ import annotations
import itertools
-from collections.abc import Callable, Mapping, Sequence
-from typing import Any, Iterable
+from collections.abc import Callable, Iterable, Mapping, Sequence
+from typing import Any
import numpy as np
import torch
@@ -300,6 +300,7 @@ def sliding_window_inference(
# remove padding if image_size smaller than roi_size
if any(pad_size):
+ kwargs.update({"pad_size": pad_size})
for ss, output_i in enumerate(output_image_list):
zoom_scale = [_shape_d / _roi_size_d for _shape_d, _roi_size_d in zip(output_i.shape[2:], roi_size)]
final_slicing: list[slice] = []
diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py
index e937b53fa4..41935be204 100644
--- a/monai/losses/__init__.py
+++ b/monai/losses/__init__.py
@@ -37,6 +37,7 @@
from .hausdorff_loss import HausdorffDTLoss, LogHausdorffDTLoss
from .image_dissimilarity import GlobalMutualInformationLoss, LocalNormalizedCrossCorrelationLoss
from .multi_scale import MultiScaleLoss
+from .nacl_loss import NACLLoss
from .perceptual import PerceptualLoss
from .spatial_mask import MaskedLoss
from .spectral_loss import JukeboxLoss
diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index b3c0f57c6e..3f02fae6b8 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -24,7 +24,7 @@
from monai.losses.focal_loss import FocalLoss
from monai.losses.spatial_mask import MaskedLoss
from monai.networks import one_hot
-from monai.utils import DiceCEReduction, LossReduction, Weight, deprecated_arg, look_up_option, pytorch_after
+from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after
class DiceLoss(_Loss):
@@ -646,9 +646,6 @@ class DiceCELoss(_Loss):
"""
- @deprecated_arg(
- "ce_weight", since="1.2", removed="1.4", new_name="weight", msg_suffix="please use `weight` instead."
- )
def __init__(
self,
include_background: bool = True,
@@ -662,10 +659,10 @@ def __init__(
smooth_nr: float = 1e-5,
smooth_dr: float = 1e-5,
batch: bool = False,
- ce_weight: torch.Tensor | None = None,
weight: torch.Tensor | None = None,
lambda_dice: float = 1.0,
lambda_ce: float = 1.0,
+ label_smoothing: float = 0.0,
) -> None:
"""
Args:
@@ -704,11 +701,13 @@ def __init__(
Defaults to 1.0.
lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
Defaults to 1.0.
+ label_smoothing: a value in [0, 1] range. If > 0, the labels are smoothed
+ by the given factor to reduce overfitting.
+ Defaults to 0.0.
"""
super().__init__()
reduction = look_up_option(reduction, DiceCEReduction).value
- weight = ce_weight if ce_weight is not None else weight
dice_weight: torch.Tensor | None
if weight is not None and not include_background:
dice_weight = weight[1:]
@@ -728,7 +727,12 @@ def __init__(
batch=batch,
weight=dice_weight,
)
- self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
+ if pytorch_after(1, 10):
+ self.cross_entropy = nn.CrossEntropyLoss(
+ weight=weight, reduction=reduction, label_smoothing=label_smoothing
+ )
+ else:
+ self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction)
if lambda_dice < 0.0:
raise ValueError("lambda_dice should be no less than 0.0.")
@@ -778,12 +782,22 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Raises:
ValueError: When number of dimensions for input and target are different.
- ValueError: When number of channels for target is neither 1 nor the same as input.
+ ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.
+
+ Returns:
+ torch.Tensor: value of the loss.
"""
- if len(input.shape) != len(target.shape):
+ if input.dim() != target.dim():
raise ValueError(
"the number of dimensions for input and target should be the same, "
+ f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
+ "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
+ )
+
+ if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
+ raise ValueError(
+ "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
f"got shape {input.shape} and {target.shape}."
)
@@ -801,14 +815,11 @@ class DiceFocalLoss(_Loss):
The details of Focal Loss is shown in ``monai.losses.FocalLoss``.
``gamma`` and ``lambda_focal`` are only used for the focal loss.
- ``include_background``, ``weight`` and ``reduction`` are used for both losses
+ ``include_background``, ``weight``, ``reduction``, and ``alpha`` are used for both losses,
and other parameters are only used for dice loss.
"""
- @deprecated_arg(
- "focal_weight", since="1.2", removed="1.4", new_name="weight", msg_suffix="please use `weight` instead."
- )
def __init__(
self,
include_background: bool = True,
@@ -823,10 +834,10 @@ def __init__(
smooth_dr: float = 1e-5,
batch: bool = False,
gamma: float = 2.0,
- focal_weight: Sequence[float] | float | int | torch.Tensor | None = None,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
lambda_dice: float = 1.0,
lambda_focal: float = 1.0,
+ alpha: float | None = None,
) -> None:
"""
Args:
@@ -861,10 +872,10 @@ def __init__(
Defaults to 1.0.
lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0.
Defaults to 1.0.
-
+ alpha: value of the alpha in the definition of the alpha-balanced Focal loss. The value should be in
+ [0, 1]. Defaults to None.
"""
super().__init__()
- weight = focal_weight if focal_weight is not None else weight
self.dice = DiceLoss(
include_background=include_background,
to_onehot_y=False,
@@ -880,7 +891,12 @@ def __init__(
weight=weight,
)
self.focal = FocalLoss(
- include_background=include_background, to_onehot_y=False, gamma=gamma, weight=weight, reduction=reduction
+ include_background=include_background,
+ to_onehot_y=False,
+ gamma=gamma,
+ weight=weight,
+ alpha=alpha,
+ reduction=reduction,
)
if lambda_dice < 0.0:
raise ValueError("lambda_dice should be no less than 0.0.")
@@ -899,14 +915,24 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Raises:
ValueError: When number of dimensions for input and target are different.
- ValueError: When number of channels for target is neither 1 nor the same as input.
+ ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.
+ Returns:
+ torch.Tensor: value of the loss.
"""
- if len(input.shape) != len(target.shape):
+ if input.dim() != target.dim():
raise ValueError(
"the number of dimensions for input and target should be the same, "
+ f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
+ "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
+ )
+
+ if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
+ raise ValueError(
+ "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
f"got shape {input.shape} and {target.shape}."
)
+
if self.to_onehot_y:
n_pred_ch = input.shape[1]
if n_pred_ch == 1:
@@ -958,9 +984,6 @@ class GeneralizedDiceFocalLoss(_Loss):
ValueError: if either `lambda_gdl` or `lambda_focal` is less than 0.
"""
- @deprecated_arg(
- "focal_weight", since="1.2", removed="1.4", new_name="weight", msg_suffix="please use `weight` instead."
- )
def __init__(
self,
include_background: bool = True,
@@ -974,7 +997,6 @@ def __init__(
smooth_dr: float = 1e-5,
batch: bool = False,
gamma: float = 2.0,
- focal_weight: Sequence[float] | float | int | torch.Tensor | None = None,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
lambda_gdl: float = 1.0,
lambda_focal: float = 1.0,
@@ -992,7 +1014,6 @@ def __init__(
smooth_dr=smooth_dr,
batch=batch,
)
- weight = focal_weight if focal_weight is not None else weight
self.focal = FocalLoss(
include_background=include_background,
to_onehot_y=to_onehot_y,
@@ -1015,15 +1036,23 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
target (torch.Tensor): the shape should be BNH[WD] or B1H[WD].
Raises:
- ValueError: When the input and target tensors have different numbers of dimensions, or the target
- channel isn't either one-hot encoded or categorical with the same shape of the input.
+ ValueError: When number of dimensions for input and target are different.
+ ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.
Returns:
torch.Tensor: value of the loss.
"""
if input.dim() != target.dim():
raise ValueError(
- f"Input - {input.shape} - and target - {target.shape} - must have the same number of dimensions."
+ "the number of dimensions for input and target should be the same, "
+ f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
+ "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
+ )
+
+ if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
+ raise ValueError(
+ "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
+ f"got shape {input.shape} and {target.shape}."
)
gdl_loss = self.generalized_dice(input, target)
diff --git a/monai/losses/ds_loss.py b/monai/losses/ds_loss.py
index 57fcff6b87..aacc16874d 100644
--- a/monai/losses/ds_loss.py
+++ b/monai/losses/ds_loss.py
@@ -33,7 +33,7 @@ def __init__(self, loss: _Loss, weight_mode: str = "exp", weights: list[float] |
weight_mode: {``"same"``, ``"exp"``, ``"two"``}
Specifies the weights calculation for each image level. Defaults to ``"exp"``.
- ``"same"``: all weights are equal to 1.
- - ``"exp"``: exponentially decreasing weights by a power of 2: 0, 0.5, 0.25, 0.125, etc .
+ - ``"exp"``: exponentially decreasing weights by a power of 2: 1, 0.5, 0.25, 0.125, etc .
- ``"two"``: equal smaller weights for lower levels: 1, 0.5, 0.5, 0.5, 0.5, etc
weights: a list of weights to apply to each deeply supervised sub-loss, if provided, this will be used
regardless of the weight_mode
diff --git a/monai/losses/hausdorff_loss.py b/monai/losses/hausdorff_loss.py
index 6117f27741..c58be2d253 100644
--- a/monai/losses/hausdorff_loss.py
+++ b/monai/losses/hausdorff_loss.py
@@ -79,7 +79,7 @@ def __init__(
Incompatible values.
"""
- super(HausdorffDTLoss, self).__init__(reduction=LossReduction(reduction).value)
+ super().__init__(reduction=LossReduction(reduction).value)
if other_act is not None and not callable(other_act):
raise TypeError(f"other_act must be None or callable but is {type(other_act).__name__}.")
if int(sigmoid) + int(softmax) > 1:
diff --git a/monai/losses/nacl_loss.py b/monai/losses/nacl_loss.py
new file mode 100644
index 0000000000..27a712d308
--- /dev/null
+++ b/monai/losses/nacl_loss.py
@@ -0,0 +1,139 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.loss import _Loss
+
+from monai.networks.layers import GaussianFilter, MeanFilter
+
+
+class NACLLoss(_Loss):
+ """
+ Neighbor-Aware Calibration Loss (NACL) is primarily intended for training calibrated models for image segmentation.
+ NACL computes the standard cross-entropy loss with a linear penalty that enforces the logit distributions
+ to match a soft class proportion of the surrounding pixels.
+
+ Murugesan, Balamurali, et al.
+ "Trust your neighbours: Penalty-based constraints for model calibration."
+ International Conference on Medical Image Computing and Computer-Assisted Intervention, MICCAI 2023.
+ https://arxiv.org/abs/2303.06268
+
+ Murugesan, Balamurali, et al.
+ "Neighbor-Aware Calibration of Segmentation Networks with Penalty-Based Constraints."
+ https://arxiv.org/abs/2401.14487
+ """
+
+ def __init__(
+ self,
+ classes: int,
+ dim: int,
+ kernel_size: int = 3,
+ kernel_ops: str = "mean",
+ distance_type: str = "l1",
+ alpha: float = 0.1,
+ sigma: float = 1.0,
+ ) -> None:
+ """
+ Args:
+ classes: number of classes
+ dim: dimension of data (supports 2d and 3d)
+ kernel_size: size of the spatial kernel
+ kernel_ops: type of smoothing kernel, either "mean" or "gaussian"
+ distance_type: l1/l2 distance between the spatial kernel and predicted logits
+ alpha: weighting between the cross entropy and the logit constraint
+ sigma: sigma of gaussian
+ """
+
+ super().__init__()
+
+ if kernel_ops not in ["mean", "gaussian"]:
+ raise ValueError("Kernel ops must be either mean or gaussian")
+
+ if dim not in [2, 3]:
+ raise ValueError(f"Support 2d and 3d, got dim={dim}.")
+
+ if distance_type not in ["l1", "l2"]:
+ raise ValueError(f"Distance type must be either L1 or L2, got {distance_type}")
+
+ self.nc = classes
+ self.dim = dim
+ self.cross_entropy = nn.CrossEntropyLoss()
+ self.distance_type = distance_type
+ self.alpha = alpha
+ self.ks = kernel_size
+ self.svls_layer: Any
+
+ if kernel_ops == "mean":
+ self.svls_layer = MeanFilter(spatial_dims=dim, size=kernel_size)
+ self.svls_layer.filter = self.svls_layer.filter / (kernel_size**dim)
+ if kernel_ops == "gaussian":
+ self.svls_layer = GaussianFilter(spatial_dims=dim, sigma=sigma)
+
+ def get_constr_target(self, mask: torch.Tensor) -> torch.Tensor:
+ """
+ Converts the mask to a one-hot representation and smooths it with the selected spatial filter.
+
+ Args:
+ mask: the shape should be BH[WD].
+
+ Returns:
+ torch.Tensor: the shape would be BNH[WD], N being number of classes.
+ """
+ rmask: torch.Tensor
+
+ if self.dim == 2:
+ oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 3, 1, 2).contiguous().float()
+ rmask = self.svls_layer(oh_labels)
+
+ if self.dim == 3:
+ oh_labels = F.one_hot(mask.to(torch.int64), num_classes=self.nc).permute(0, 4, 1, 2, 3).contiguous().float()
+ rmask = self.svls_layer(oh_labels)
+
+ return rmask
+
+ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the standard cross-entropy loss combined with the neighbor-aware logit penalty.
+
+ Args:
+ inputs: the shape should be BNH[WD], where N is the number of classes.
+ targets: the shape should be BH[WD].
+
+ Returns:
+ torch.Tensor: value of the loss.
+
+ Example:
+ >>> import torch
+ >>> from monai.losses import NACLLoss
+ >>> B, N, H, W = 8, 3, 64, 64
+ >>> input = torch.rand(B, N, H, W)
+ >>> target = torch.randint(0, N, (B, H, W))
+ >>> criterion = NACLLoss(classes = N, dim = 2)
+ >>> loss = criterion(input, target)
+ """
+
+ loss_ce = self.cross_entropy(inputs, targets)
+
+ utargets = self.get_constr_target(targets)
+
+ if self.distance_type == "l1":
+ loss_conf = utargets.sub(inputs).abs_().mean()
+ elif self.distance_type == "l2":
+ loss_conf = utargets.sub(inputs).pow_(2).abs_().mean()
+
+ loss: torch.Tensor = loss_ce + self.alpha * loss_conf
+
+ return loss
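A minimal usage sketch for the NACLLoss added above, complementing the 2D doctest in its forward docstring with the 3D case. The tensor shapes and the choice of the Gaussian kernel are illustrative assumptions only.

import torch
from monai.losses import NACLLoss

B, N, H, W, D = 2, 3, 32, 32, 32
logits = torch.rand(B, N, H, W, D)            # unnormalised network outputs, BNHWD
labels = torch.randint(0, N, (B, H, W, D))    # integer class mask, BHWD
criterion = NACLLoss(classes=N, dim=3, kernel_ops="gaussian", sigma=1.0)
loss = criterion(logits, labels)              # cross entropy + alpha * neighbour penalty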
diff --git a/monai/losses/perceptual.py b/monai/losses/perceptual.py
index fd61603b03..a8ae90993a 100644
--- a/monai/losses/perceptual.py
+++ b/monai/losses/perceptual.py
@@ -45,6 +45,7 @@ class PerceptualLoss(nn.Module):
The fake 3D implementation is based on a 2.5D approach where we calculate the 2D perceptual loss on slices from all
three axes and average. The full 3D approach uses a 3D network to calculate the perceptual loss.
+ MedicalNet networks are only compatible with 3D inputs and support channel-wise loss.
Args:
spatial_dims: number of spatial dimensions.
@@ -62,6 +63,8 @@ class PerceptualLoss(nn.Module):
pretrained_state_dict_key: if `pretrained_path` is not `None`, this argument is used to
extract the expected state dict. This argument only works when ``"network_type"`` is "resnet50".
Defaults to `None`.
+ channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
+ Defaults to ``False``.
"""
def __init__(
@@ -74,6 +77,7 @@ def __init__(
pretrained: bool = True,
pretrained_path: str | None = None,
pretrained_state_dict_key: str | None = None,
+ channel_wise: bool = False,
):
super().__init__()
@@ -86,6 +90,9 @@ def __init__(
"Argument is_fake_3d must be set to False."
)
+ if channel_wise and "medicalnet_" not in network_type:
+ raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.")
+
if network_type.lower() not in list(PercetualNetworkType):
raise ValueError(
"Unrecognised criterion entered for Adversarial Loss. Must be one in: %s"
@@ -102,7 +109,9 @@ def __init__(
self.spatial_dims = spatial_dims
self.perceptual_function: nn.Module
if spatial_dims == 3 and is_fake_3d is False:
- self.perceptual_function = MedicalNetPerceptualSimilarity(net=network_type, verbose=False)
+ self.perceptual_function = MedicalNetPerceptualSimilarity(
+ net=network_type, verbose=False, channel_wise=channel_wise
+ )
elif "radimagenet_" in network_type:
self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False)
elif network_type == "resnet50":
@@ -116,6 +125,7 @@ def __init__(
self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False)
self.is_fake_3d = is_fake_3d
self.fake_3d_ratio = fake_3d_ratio
+ self.channel_wise = channel_wise
def _calculate_axis_loss(self, input: torch.Tensor, target: torch.Tensor, spatial_axis: int) -> torch.Tensor:
"""
@@ -172,7 +182,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# 2D and real 3D cases
loss = self.perceptual_function(input, target)
- return torch.mean(loss)
+ if self.channel_wise:
+ loss = torch.mean(loss.squeeze(), dim=0)
+ else:
+ loss = torch.mean(loss)
+
+ return loss
class MedicalNetPerceptualSimilarity(nn.Module):
@@ -185,14 +200,20 @@ class MedicalNetPerceptualSimilarity(nn.Module):
net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``}
Specifies the network architecture to use. Defaults to ``"medicalnet_resnet10_23datasets"``.
verbose: if false, mute messages from torch Hub load function.
+ channel_wise: if True, the loss is returned per channel. Otherwise the loss is averaged over the channels.
+ Defaults to ``False``.
"""
- def __init__(self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False) -> None:
+ def __init__(
+ self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False
+ ) -> None:
super().__init__()
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose)
self.eval()
+ self.channel_wise = channel_wise
+
for param in self.parameters():
param.requires_grad = False
@@ -206,20 +227,42 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Args:
input: 3D input tensor with shape BCDHW.
target: 3D target tensor with shape BCDHW.
+
"""
input = medicalnet_intensity_normalisation(input)
target = medicalnet_intensity_normalisation(target)
# Get model outputs
- outs_input = self.model.forward(input)
- outs_target = self.model.forward(target)
+ feats_per_ch = 0
+ for ch_idx in range(input.shape[1]):
+ input_channel = input[:, ch_idx, ...].unsqueeze(1)
+ target_channel = target[:, ch_idx, ...].unsqueeze(1)
+
+ if ch_idx == 0:
+ outs_input = self.model.forward(input_channel)
+ outs_target = self.model.forward(target_channel)
+ feats_per_ch = outs_input.shape[1]
+ else:
+ outs_input = torch.cat([outs_input, self.model.forward(input_channel)], dim=1)
+ outs_target = torch.cat([outs_target, self.model.forward(target_channel)], dim=1)
# Normalise through the channels
feats_input = normalize_tensor(outs_input)
feats_target = normalize_tensor(outs_target)
- results: torch.Tensor = (feats_input - feats_target) ** 2
- results = spatial_average_3d(results.sum(dim=1, keepdim=True), keepdim=True)
+ feats_diff: torch.Tensor = (feats_input - feats_target) ** 2
+ if self.channel_wise:
+ results = torch.zeros(
+ feats_diff.shape[0], input.shape[1], feats_diff.shape[2], feats_diff.shape[3], feats_diff.shape[4]
+ )
+ for i in range(input.shape[1]):
+ l_idx = i * feats_per_ch
+ r_idx = (i + 1) * feats_per_ch
+ results[:, i, ...] = feats_diff[:, l_idx:r_idx, ...].sum(dim=1)
+ else:
+ results = feats_diff.sum(dim=1, keepdim=True)
+
+ results = spatial_average_3d(results, keepdim=True)
return results
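A hedged sketch of the new channel_wise option, which returns one perceptual value per input channel when a MedicalNet backbone is used on real 3D inputs. The shapes below are illustrative, and the backbone weights are fetched from torch Hub on first use.

import torch
from monai.losses import PerceptualLoss

loss_fn = PerceptualLoss(
    spatial_dims=3,
    network_type="medicalnet_resnet10_23datasets",
    is_fake_3d=False,        # MedicalNet backbones require real 3D inputs
    channel_wise=True,       # return one value per channel instead of the mean
)
pred = torch.rand(2, 2, 32, 32, 32)
target = torch.rand(2, 2, 32, 32, 32)
per_channel_loss = loss_fn(pred, target)  # shape (2,): one value per input channel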
diff --git a/monai/metrics/cumulative_average.py b/monai/metrics/cumulative_average.py
index e55e7b8576..dccf7b094b 100644
--- a/monai/metrics/cumulative_average.py
+++ b/monai/metrics/cumulative_average.py
@@ -65,6 +65,7 @@ def get_current(self, to_numpy: bool = True) -> NdarrayOrTensor:
if self.val is None:
return 0
+ val: NdarrayOrTensor
val = self.val.clone()
val[~torch.isfinite(val)] = 0
@@ -96,6 +97,7 @@ def aggregate(self, to_numpy: bool = True) -> NdarrayOrTensor:
dist.all_reduce(sum)
dist.all_reduce(count)
+ val: NdarrayOrTensor
val = torch.where(count > 0, sum / count, sum)
if to_numpy:
diff --git a/monai/metrics/fid.py b/monai/metrics/fid.py
index d655ac1bee..596f9aef7c 100644
--- a/monai/metrics/fid.py
+++ b/monai/metrics/fid.py
@@ -82,7 +82,7 @@ def _cov(input_data: torch.Tensor, rowvar: bool = True) -> torch.Tensor:
def _sqrtm(input_data: torch.Tensor) -> torch.Tensor:
"""Compute the square root of a matrix."""
- scipy_res, _ = scipy.linalg.sqrtm(input_data.detach().cpu().numpy().astype(np.float_), disp=False)
+ scipy_res, _ = scipy.linalg.sqrtm(input_data.detach().cpu().numpy().astype(np.float64), disp=False)
return torch.from_numpy(scipy_res)
diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py
index e56bd46592..516021949b 100644
--- a/monai/metrics/generalized_dice.py
+++ b/monai/metrics/generalized_dice.py
@@ -14,34 +14,47 @@
import torch
from monai.metrics.utils import do_metric_reduction, ignore_background
-from monai.utils import MetricReduction, Weight, look_up_option
+from monai.utils import MetricReduction, Weight, deprecated_arg, deprecated_arg_default, look_up_option
from .metric import CumulativeIterationMetric
class GeneralizedDiceScore(CumulativeIterationMetric):
- """Compute the Generalized Dice Score metric between tensors, as the complement of the Generalized Dice Loss defined in:
+ """
+ Compute the Generalized Dice Score metric between tensors.
+ This metric is the complement of the Generalized Dice Loss defined in:
Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning
- loss function for highly unbalanced segmentations. DLMIA 2017.
+ loss function for highly unbalanced segmentations. DLMIA 2017.
- The inputs `y_pred` and `y` are expected to be one-hot, binarized channel-first
- or batch-first tensors, i.e., CHW[D] or BCHW[D].
+ The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D].
Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
Args:
- include_background (bool, optional): whether to include the background class (assumed to be in channel 0), in the
+ include_background: Whether to include the background class (assumed to be in channel 0) in the
score computation. Defaults to True.
- reduction (str, optional): define mode of reduction to the metrics. Available reduction modes:
- {``"none"``, ``"mean_batch"``, ``"sum_batch"``}. Default to ``"mean_batch"``. If "none", will not do reduction.
- weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
+ reduction: Define mode of reduction to the metrics. Available reduction modes:
+ {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
+ ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
+ weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
ground truth volume into a weight factor. Defaults to ``"square"``.
Raises:
- ValueError: when the `weight_type` is not one of {``"none"``, ``"mean"``, ``"sum"``}.
+ ValueError: When `reduction` is not a valid member of the ``MetricReduction`` enum.
"""
+ @deprecated_arg_default(
+ "reduction",
+ old_default=MetricReduction.MEAN_BATCH,
+ new_default=MetricReduction.MEAN,
+ since="1.4.0",
+ replaced="1.5.0",
+ msg_suffix=(
+ "Old versions computed `mean` when `mean_batch` was provided due to a bug in the reduction. "
+ "If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'."
+ ),
+ )
def __init__(
self,
include_background: bool = True,
@@ -50,79 +63,90 @@ def __init__(
) -> None:
super().__init__()
self.include_background = include_background
- reduction_options = [
- "none",
- "mean_batch",
- "sum_batch",
- MetricReduction.NONE,
- MetricReduction.MEAN_BATCH,
- MetricReduction.SUM_BATCH,
- ]
- self.reduction = reduction
- if self.reduction not in reduction_options:
- raise ValueError(f"reduction must be one of {reduction_options}")
+ self.reduction = look_up_option(reduction, MetricReduction)
self.weight_type = look_up_option(weight_type, Weight)
+ self.sum_over_classes = self.reduction in {
+ MetricReduction.SUM,
+ MetricReduction.MEAN,
+ MetricReduction.MEAN_CHANNEL,
+ MetricReduction.SUM_CHANNEL,
+ }
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
- """Computes the Generalized Dice Score and returns a tensor with its per image values.
+ """
+ Computes the Generalized Dice Score and returns a tensor with its per image values.
Args:
- y_pred (torch.Tensor): binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
+ y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.
- y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
+ y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
+
+ Returns:
+ torch.Tensor: Generalized Dice Score per batch and per class, or per batch only when the chosen reduction sums over classes.
Raises:
- ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
+ ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
"""
return compute_generalized_dice(
- y_pred=y_pred, y=y, include_background=self.include_background, weight_type=self.weight_type
+ y_pred=y_pred,
+ y=y,
+ include_background=self.include_background,
+ weight_type=self.weight_type,
+ sum_over_classes=self.sum_over_classes,
)
+ @deprecated_arg(
+ "reduction",
+ since="1.3.3",
+ removed="1.7.0",
+ msg_suffix="Reduction will be ignored. Set the reduction during initialization, as the Generalized Dice Score computation needs it.",
+ )
def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor:
"""
Execute reduction logic for the output of `compute_generalized_dice`.
- Args:
- reduction (Union[MetricReduction, str, None], optional): define mode of reduction to the metrics.
- Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``}.
- Defaults to ``"mean"``. If "none", will not do reduction.
+ Returns:
+ torch.Tensor: Aggregated metric value.
+
+ Raises:
+ ValueError: If the data to aggregate is not a PyTorch Tensor.
"""
data = self.get_buffer()
if not isinstance(data, torch.Tensor):
raise ValueError("The data to aggregate must be a PyTorch Tensor.")
- # Validate reduction argument if specified
- if reduction is not None:
- reduction_options = ["none", "mean", "sum", "mean_batch", "sum_batch"]
- if reduction not in reduction_options:
- raise ValueError(f"reduction must be one of {reduction_options}")
-
# Do metric reduction and return
- f, _ = do_metric_reduction(data, reduction or self.reduction)
+ f, _ = do_metric_reduction(data, self.reduction)
return f
def compute_generalized_dice(
- y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, weight_type: Weight | str = Weight.SQUARE
+ y_pred: torch.Tensor,
+ y: torch.Tensor,
+ include_background: bool = True,
+ weight_type: Weight | str = Weight.SQUARE,
+ sum_over_classes: bool = False,
) -> torch.Tensor:
- """Computes the Generalized Dice Score and returns a tensor with its per image values.
+ """
+ Computes the Generalized Dice Score and returns a tensor with its per image values.
Args:
- y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format
+ y_pred (torch.Tensor): Binarized segmentation model output. It should be binarized, in one-hot format
and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
remaining are the spatial dimensions.
- y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
- include_background (bool, optional): whether to include score computation on the first channel of the
+ y (torch.Tensor): Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
+ include_background: Whether to include score computation on the first channel of the
predicted output. Defaults to True.
weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
transform ground truth volume into a weight factor. Defaults to ``"square"``.
+ sum_over_classes (bool): Whether to sum the numerator and denominator across all classes before the final computation.
Returns:
- torch.Tensor: per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
+ torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].
Raises:
- ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
+ ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
or `y_pred` and `y` don't have the same shape.
"""
# Ensure tensors have at least 3 dimensions and have the same shape
@@ -158,16 +182,21 @@ def compute_generalized_dice(
b[infs] = 0
b[infs] = torch.max(b)
- # Compute the weighted numerator and denominator, summing along the class axis
- numer = 2.0 * (intersection * w).sum(dim=1)
- denom = (denominator * w).sum(dim=1)
+ # Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True
+ if sum_over_classes:
+ numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True)
+ denom = (denominator * w).sum(dim=1, keepdim=True)
+ y_pred_o = y_pred_o.sum(dim=-1, keepdim=True)
+ else:
+ numer = 2.0 * (intersection * w)
+ denom = denominator * w
+ y_pred_o = y_pred_o
# Compute the score
generalized_dice_score = numer / denom
# Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1.
# Where denom == 0 but the prediction volume is not 0, score is 0
- y_pred_o = y_pred_o.sum(dim=-1)
denom_zeros = denom == 0
generalized_dice_score[denom_zeros] = torch.where(
(y_pred_o == 0)[denom_zeros],
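A short sketch of the metric with the reworked reduction handling. Inputs are one-hot, batch-first tensors as documented above; the label values here are random placeholders.

import torch
import torch.nn.functional as F
from monai.metrics import GeneralizedDiceScore

num_classes = 3
pred_labels = torch.randint(0, num_classes, (2, 16, 16))
true_labels = torch.randint(0, num_classes, (2, 16, 16))
y_pred = F.one_hot(pred_labels, num_classes).permute(0, 3, 1, 2).float()  # NCHW, one-hot
y = F.one_hot(true_labels, num_classes).permute(0, 3, 1, 2).float()

metric = GeneralizedDiceScore(include_background=True, reduction="mean_batch", weight_type="square")
metric(y_pred=y_pred, y=y)    # accumulate per-image values
score = metric.aggregate()    # per-class scores under the "mean_batch" reduction
metric.reset()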
diff --git a/monai/metrics/panoptic_quality.py b/monai/metrics/panoptic_quality.py
index 05175ba0fb..7c9d59c264 100644
--- a/monai/metrics/panoptic_quality.py
+++ b/monai/metrics/panoptic_quality.py
@@ -274,7 +274,7 @@ def _get_paired_iou(
return paired_iou, paired_true, paired_pred
- pairwise_iou = pairwise_iou.cpu().numpy()
+ pairwise_iou = pairwise_iou.cpu().numpy() # type: ignore[assignment]
paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
paired_iou = pairwise_iou[paired_true, paired_pred]
paired_true = torch.as_tensor(list(paired_true[paired_iou > match_iou_threshold] + 1), device=device)
diff --git a/monai/metrics/regression.py b/monai/metrics/regression.py
index 9d29654ee3..4c8b8aa71b 100644
--- a/monai/metrics/regression.py
+++ b/monai/metrics/regression.py
@@ -303,7 +303,7 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
if self.spatial_dims == 3 and dims != 5:
raise ValueError(
- f"y_pred should have 4 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}"
+ f"y_pred should have 5 dimensions (batch, channel, height, width, depth) when using {self.spatial_dims}"
f" spatial dimensions, got {dims}."
)
diff --git a/monai/metrics/rocauc.py b/monai/metrics/rocauc.py
index 56d9faa9dd..57a8a072b4 100644
--- a/monai/metrics/rocauc.py
+++ b/monai/metrics/rocauc.py
@@ -88,8 +88,8 @@ def _calculate(y_pred: torch.Tensor, y: torch.Tensor) -> float:
n = len(y)
indices = y_pred.argsort()
- y = y[indices].cpu().numpy()
- y_pred = y_pred[indices].cpu().numpy()
+ y = y[indices].cpu().numpy() # type: ignore[assignment]
+ y_pred = y_pred[indices].cpu().numpy() # type: ignore[assignment]
nneg = auc = tmp_pos = tmp_neg = 0.0
for i in range(n):
diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py
index e7057256fb..96d60c9098 100644
--- a/monai/metrics/utils.py
+++ b/monai/metrics/utils.py
@@ -12,9 +12,10 @@
from __future__ import annotations
import warnings
+from collections.abc import Iterable, Sequence
from functools import lru_cache, partial
from types import ModuleType
-from typing import Any, Iterable, Sequence
+from typing import Any
import numpy as np
import torch
@@ -35,9 +36,9 @@
optional_import,
)
-binary_erosion, _ = optional_import("scipy.ndimage.morphology", name="binary_erosion")
-distance_transform_edt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_edt")
-distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")
+binary_erosion, _ = optional_import("scipy.ndimage", name="binary_erosion")
+distance_transform_edt, _ = optional_import("scipy.ndimage", name="distance_transform_edt")
+distance_transform_cdt, _ = optional_import("scipy.ndimage", name="distance_transform_cdt")
__all__ = [
"ignore_background",
diff --git a/monai/networks/__init__.py b/monai/networks/__init__.py
index 4c429ae813..5a240021d6 100644
--- a/monai/networks/__init__.py
+++ b/monai/networks/__init__.py
@@ -11,7 +11,9 @@
from __future__ import annotations
+from .trt_compiler import trt_compile
from .utils import (
+ add_casts_around_norms,
convert_to_onnx,
convert_to_torchscript,
convert_to_trt,
diff --git a/monai/networks/blocks/__init__.py b/monai/networks/blocks/__init__.py
index e67cb3376f..499caf2e0f 100644
--- a/monai/networks/blocks/__init__.py
+++ b/monai/networks/blocks/__init__.py
@@ -17,6 +17,7 @@
from .backbone_fpn_utils import BackboneWithFPN
from .convolutions import Convolution, ResidualUnit
from .crf import CRF
+from .crossattention import CrossAttentionBlock
from .denseblock import ConvDenseBlock, DenseBlock
from .dints_block import ActiConvNormBlock, FactorizedIncreaseBlock, FactorizedReduceBlock, P3DActiConvNormBlock
from .downsample import MaxAvgPool
@@ -25,11 +26,14 @@
from .fcn import FCN, GCN, MCFCN, Refine
from .feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork, LastLevelMaxPool, LastLevelP6P7
from .localnet_block import LocalNetDownSampleBlock, LocalNetFeatureExtractorBlock, LocalNetUpSampleBlock
+from .mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
from .mlp import MLPBlock
from .patchembedding import PatchEmbed, PatchEmbeddingBlock
from .regunet_block import RegistrationDownSampleBlock, RegistrationExtractionBlock, RegistrationResidualConvBlock
from .segresnet_block import ResBlock
from .selfattention import SABlock
+from .spade_norm import SPADE
+from .spatialattention import SpatialAttentionBlock
from .squeeze_and_excitation import (
ChannelSELayer,
ResidualSELayer,
diff --git a/monai/networks/blocks/attention_utils.py b/monai/networks/blocks/attention_utils.py
new file mode 100644
index 0000000000..8c9002a16e
--- /dev/null
+++ b/monai/networks/blocks/attention_utils.py
@@ -0,0 +1,128 @@
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
+ """
+ Get relative positional embeddings according to the relative positions of
+ query and key sizes.
+
+ Args:
+ q_size (int): size of query q.
+ k_size (int): size of key k.
+ rel_pos (Tensor): relative position embeddings (L, C).
+
+ Returns:
+ Extracted positional embeddings according to relative positions.
+ """
+ rel_pos_resized: torch.Tensor = torch.Tensor()
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
+ # Interpolate rel pos if needed.
+ if rel_pos.shape[0] != max_rel_dist:
+ # Interpolate rel pos.
+ rel_pos_resized = F.interpolate(
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear"
+ )
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
+ else:
+ rel_pos_resized = rel_pos
+
+ # Scale the coords with short length if shapes for q and k are different.
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
+
+ return rel_pos_resized[relative_coords.long()]
+
+
+def add_decomposed_rel_pos(
+ attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple
+) -> torch.Tensor:
+ r"""
+ Calculate decomposed Relative Positional Embeddings from mvitv2 implementation:
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
+
+ Only 2D and 3D are supported.
+
+ Encoding the relative position of tokens in the attention matrix: tokens spaced a distance
+ `d` apart will have the same embedding value (unlike absolute positional embedding).
+
+ .. math::
+ Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale
+
+ where
+
+ .. math::
+ E_{ij}^{(rel)} = Q_{i}.R_{p(i), p(j)}
+
+ with :math:`R_{p(i), p(j)} \in R^{dim}` and :math:`p(i), p(j)`,
+ respectively spatial positions of element :math:`i` and :math:`j`
+
+ When using "decomposed" relative positional embedding, the positional embedding is defined ("decomposed") as follows:
+
+ .. math::
+ R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)}
+
+ with :math:`n = 1...dim`
+
+ Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1*...*dn)` to
+ :math:`\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding.
+
+ Args:
+ attn (Tensor): attention map.
+ q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C).
+ rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis.
+ q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n).
+ k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n).
+
+ Returns:
+ attn (Tensor): attention logits with added relative positional embeddings.
+ """
+ rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0])
+ rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1])
+
+ batch, _, dim = q.shape
+
+ if len(rel_pos_lst) == 2:
+ q_h, q_w = q_size[:2]
+ k_h, k_w = k_size[:2]
+ r_q = q.reshape(batch, q_h, q_w, dim)
+ rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh)
+ rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw)
+
+ attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
+ batch, q_h * q_w, k_h * k_w
+ )
+ elif len(rel_pos_lst) == 3:
+ q_h, q_w, q_d = q_size[:3]
+ k_h, k_w, k_d = k_size[:3]
+
+ rd = get_rel_pos(q_d, k_d, rel_pos_lst[2])
+
+ r_q = q.reshape(batch, q_h, q_w, q_d, dim)
+ rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh)
+ rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw)
+ rel_d = torch.einsum("bhwdc,dkc->bhwdk", r_q, rd)
+
+ attn = (
+ attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d)
+ + rel_h[:, :, :, :, None, None]
+ + rel_w[:, :, :, None, :, None]
+ + rel_d[:, :, :, None, None, :]
+ ).view(batch, q_h * q_w * q_d, k_h * k_w * k_d)
+
+ return attn
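A self-contained sketch of the 2D decomposed embedding path defined above; the grid size, per-head dimension, and zero-initialised relative-position tables are arbitrary choices for illustration.

import torch
from torch import nn
from monai.networks.blocks.attention_utils import add_decomposed_rel_pos

h = w = 4      # query/key spatial grid
dim = 8        # per-head channel dimension
q = torch.rand(1, h * w, dim)          # (batch * heads, tokens, dim)
attn = torch.rand(1, h * w, h * w)     # raw attention logits
rel_pos_lst = nn.ParameterList(
    [nn.Parameter(torch.zeros(2 * h - 1, dim)), nn.Parameter(torch.zeros(2 * w - 1, dim))]
)
attn = add_decomposed_rel_pos(attn, q, rel_pos_lst, (h, w), (h, w))  # still (1, 16, 16)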
diff --git a/monai/networks/blocks/crossattention.py b/monai/networks/blocks/crossattention.py
new file mode 100644
index 0000000000..bdecf63168
--- /dev/null
+++ b/monai/networks/blocks/crossattention.py
@@ -0,0 +1,190 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from monai.networks.layers.utils import get_rel_pos_embedding_layer
+from monai.utils import optional_import, pytorch_after
+
+Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
+
+
+class CrossAttentionBlock(nn.Module):
+ """
+ A cross-attention block, based on: "Dosovitskiy et al.,
+ An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale".
+ One can set up relative positional embedding as described in https://arxiv.org/abs/2112.01526.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ dropout_rate: float = 0.0,
+ hidden_input_size: int | None = None,
+ context_input_size: int | None = None,
+ dim_head: int | None = None,
+ qkv_bias: bool = False,
+ save_attn: bool = False,
+ causal: bool = False,
+ sequence_length: int | None = None,
+ rel_pos_embedding: Optional[str] = None,
+ input_size: Optional[Tuple] = None,
+ attention_dtype: Optional[torch.dtype] = None,
+ use_flash_attention: bool = False,
+ ) -> None:
+ """
+ Args:
+ hidden_size (int): dimension of hidden layer.
+ num_heads (int): number of attention heads.
+ dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
+ hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.
+ context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size.
+ dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
+ qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
+ save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+ causal (bool, optional): whether to use causal attention.
+ sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
+ rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only
+ "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
+ input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
+ parameter size.
+ attention_dtype: cast attention operations to this dtype.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ super().__init__()
+
+ if not (0 <= dropout_rate <= 1):
+ raise ValueError("dropout_rate should be between 0 and 1.")
+
+ if dim_head:
+ inner_size = num_heads * dim_head
+ self.head_dim = dim_head
+ else:
+ if hidden_size % num_heads != 0:
+ raise ValueError("hidden size should be divisible by num_heads.")
+ inner_size = hidden_size
+ self.head_dim = hidden_size // num_heads
+
+ if causal and sequence_length is None:
+ raise ValueError("sequence_length is necessary for causal attention.")
+
+ if use_flash_attention and not pytorch_after(major=2, minor=0):
+ raise ValueError(
+ "use_flash_attention is only supported for PyTorch versions >= 2.0. "
+ "Upgrade your PyTorch or set the flag to False."
+ )
+ if use_flash_attention and save_attn:
+ raise ValueError(
+ "save_attn has been set to True, but use_flash_attention is also set "
+ "to True. save_attn can only be used if use_flash_attention is False."
+ )
+
+ if use_flash_attention and rel_pos_embedding is not None:
+ raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")
+
+ self.num_heads = num_heads
+ self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
+ self.context_input_size = context_input_size if context_input_size else hidden_size
+ self.out_proj = nn.Linear(inner_size, self.hidden_input_size)
+ # key, query, value projections
+ self.to_q = nn.Linear(self.hidden_input_size, inner_size, bias=qkv_bias)
+ self.to_k = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
+ self.to_v = nn.Linear(self.context_input_size, inner_size, bias=qkv_bias)
+ self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)
+
+ self.out_rearrange = Rearrange("b l h d -> b h (l d)")
+ self.drop_output = nn.Dropout(dropout_rate)
+ self.drop_weights = nn.Dropout(dropout_rate)
+ self.dropout_rate = dropout_rate
+
+ self.scale = self.head_dim**-0.5
+ self.save_attn = save_attn
+ self.attention_dtype = attention_dtype
+
+ self.causal = causal
+ self.sequence_length = sequence_length
+ self.use_flash_attention = use_flash_attention
+
+ if causal and sequence_length is not None:
+ # causal mask to ensure that attention is only applied to the left in the input sequence
+ self.register_buffer(
+ "causal_mask",
+ torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
+ )
+ self.causal_mask: torch.Tensor
+ else:
+ self.causal_mask = torch.Tensor()
+
+ self.att_mat = torch.Tensor()
+ self.rel_positional_embedding = (
+ get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads)
+ if rel_pos_embedding is not None
+ else None
+ )
+ self.input_size = input_size
+
+ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
+ """
+ Args:
+ x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+ context (torch.Tensor, optional): context tensor. B x (s_dim_1 * ... * s_dim_n) x C
+
+ Return:
+ torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
+ """
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ b, t, c = x.size() # batch size, sequence length, embedding dimensionality (hidden_size)
+
+ q = self.input_rearrange(self.to_q(x))
+ kv = context if context is not None else x
+ _, kv_t, _ = kv.size()
+ k = self.input_rearrange(self.to_k(kv))
+ v = self.input_rearrange(self.to_v(kv))
+
+ if self.attention_dtype is not None:
+ q = q.to(self.attention_dtype)
+ k = k.to(self.attention_dtype)
+
+ if self.use_flash_attention:
+ x = torch.nn.functional.scaled_dot_product_attention(
+ query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+ )
+ else:
+ att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+ # apply relative positional embedding if defined
+ if self.rel_positional_embedding is not None:
+ att_mat = self.rel_positional_embedding(x, att_mat, q)
+
+ if self.causal:
+ att_mat = att_mat.masked_fill(self.causal_mask[:, :, :t, :kv_t] == 0, float("-inf"))
+
+ att_mat = att_mat.softmax(dim=-1)
+
+ if self.save_attn:
+ # no gradients and new tensor;
+ # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+ self.att_mat = att_mat.detach()
+
+ att_mat = self.drop_weights(att_mat)
+ x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+
+ x = self.out_rearrange(x)
+ x = self.out_proj(x)
+ x = self.drop_output(x)
+ return x
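A brief sketch of cross-attending a query sequence to a longer context sequence with the block above. einops is required, and the batch, token, and hidden sizes are illustrative.

import torch
from monai.networks.blocks import CrossAttentionBlock

block = CrossAttentionBlock(hidden_size=64, num_heads=4, dropout_rate=0.1)
x = torch.rand(2, 16, 64)          # B x query tokens x hidden_size
context = torch.rand(2, 32, 64)    # context defaults to the same hidden_size
out = block(x, context=context)    # same shape as x: (2, 16, 64)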
diff --git a/monai/networks/blocks/mednext_block.py b/monai/networks/blocks/mednext_block.py
new file mode 100644
index 0000000000..0aa2bb6b58
--- /dev/null
+++ b/monai/networks/blocks/mednext_block.py
@@ -0,0 +1,309 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this code are derived from the original repository at:
+# https://github.com/MIC-DKFZ/MedNeXt
+# and are used under the terms of the Apache License, Version 2.0.
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+__all__ = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"]
+
+
+def get_conv_layer(spatial_dim: int = 3, transpose: bool = False):
+ if spatial_dim == 2:
+ return nn.ConvTranspose2d if transpose else nn.Conv2d
+ else: # spatial_dim == 3
+ return nn.ConvTranspose3d if transpose else nn.Conv3d
+
+
+class MedNeXtBlock(nn.Module):
+ """
+ MedNeXtBlock class for the MedNeXt model.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
+ kernel_size (int): Kernel size for convolutions. Defaults to 7.
+ use_residual_connection (bool): Whether to use residual connection. Defaults to True.
+ norm_type (str): Type of normalization to use. Defaults to "group".
+ dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
+ global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expansion_ratio: int = 4,
+ kernel_size: int = 7,
+ use_residual_connection: bool = True,
+ norm_type: str = "group",
+ dim: str = "3d",
+ global_resp_norm: bool = False,
+ ):
+
+ super().__init__()
+
+ self.do_res = use_residual_connection
+
+ self.dim = dim
+ conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
+ global_resp_norm_param_shape = (1,) * (2 if dim == "2d" else 3)
+ # First convolution layer with DepthWise Convolutions
+ self.conv1 = conv(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ groups=in_channels,
+ )
+
+ # Normalization Layer. GroupNorm is used by default.
+ if norm_type == "group":
+ self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels) # type: ignore
+ elif norm_type == "layer":
+ self.norm = nn.LayerNorm(
+ normalized_shape=[in_channels] + [kernel_size] * (2 if dim == "2d" else 3) # type: ignore
+ )
+ # Second convolution (Expansion) layer with Conv3D 1x1x1
+ self.conv2 = conv(
+ in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ # GeLU activations
+ self.act = nn.GELU()
+
+ # Third convolution (Compression) layer with Conv3D 1x1x1
+ self.conv3 = conv(
+ in_channels=expansion_ratio * in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
+ )
+
+ self.global_resp_norm = global_resp_norm
+ if self.global_resp_norm:
+ global_resp_norm_param_shape = (1, expansion_ratio * in_channels) + global_resp_norm_param_shape
+ self.global_resp_beta = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True)
+ self.global_resp_gamma = nn.Parameter(torch.zeros(global_resp_norm_param_shape), requires_grad=True)
+
+ def forward(self, x):
+ """
+ Forward pass of the MedNeXtBlock.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Output tensor.
+ """
+ x1 = x
+ x1 = self.conv1(x1)
+ x1 = self.act(self.conv2(self.norm(x1)))
+
+ if self.global_resp_norm:
+ # gamma, beta: learnable affine transform parameters
+ # X: input of shape (N,C,H,W,D)
+ if self.dim == "2d":
+ gx = torch.norm(x1, p=2, dim=(-2, -1), keepdim=True)
+ else:
+ gx = torch.norm(x1, p=2, dim=(-3, -2, -1), keepdim=True)
+ nx = gx / (gx.mean(dim=1, keepdim=True) + 1e-6)
+ x1 = self.global_resp_gamma * (x1 * nx) + self.global_resp_beta + x1
+ x1 = self.conv3(x1)
+ if self.do_res:
+ x1 = x + x1
+ return x1
+
+
+class MedNeXtDownBlock(MedNeXtBlock):
+ """
+ MedNeXtDownBlock class for downsampling in the MedNeXt model.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
+ kernel_size (int): Kernel size for convolutions. Defaults to 7.
+ use_residual_connection (bool): Whether to use residual connection. Defaults to False.
+ norm_type (str): Type of normalization to use. Defaults to "group".
+ dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
+ global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expansion_ratio: int = 4,
+ kernel_size: int = 7,
+ use_residual_connection: bool = False,
+ norm_type: str = "group",
+ dim: str = "3d",
+ global_resp_norm: bool = False,
+ ):
+
+ super().__init__(
+ in_channels,
+ out_channels,
+ expansion_ratio,
+ kernel_size,
+ use_residual_connection=False,
+ norm_type=norm_type,
+ dim=dim,
+ global_resp_norm=global_resp_norm,
+ )
+
+ conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3)
+ self.resample_do_res = use_residual_connection
+ if use_residual_connection:
+ self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)
+
+ self.conv1 = conv(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=kernel_size,
+ stride=2,
+ padding=kernel_size // 2,
+ groups=in_channels,
+ )
+
+ def forward(self, x):
+ """
+ Forward pass of the MedNeXtDownBlock.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Output tensor.
+ """
+ x1 = super().forward(x)
+
+ if self.resample_do_res:
+ res = self.res_conv(x)
+ x1 = x1 + res
+
+ return x1
+
+
+class MedNeXtUpBlock(MedNeXtBlock):
+ """
+ MedNeXtUpBlock class for upsampling in the MedNeXt model.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ expansion_ratio (int): Expansion ratio for the block. Defaults to 4.
+ kernel_size (int): Kernel size for convolutions. Defaults to 7.
+ use_residual_connection (bool): Whether to use residual connection. Defaults to False.
+ norm_type (str): Type of normalization to use. Defaults to "group".
+ dim (str): Dimension of the input. Can be "2d" or "3d". Defaults to "3d".
+ global_resp_norm (bool): Whether to use global response normalization. Defaults to False.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expansion_ratio: int = 4,
+ kernel_size: int = 7,
+ use_residual_connection: bool = False,
+ norm_type: str = "group",
+ dim: str = "3d",
+ global_resp_norm: bool = False,
+ ):
+ super().__init__(
+ in_channels,
+ out_channels,
+ expansion_ratio,
+ kernel_size,
+ use_residual_connection=False,
+ norm_type=norm_type,
+ dim=dim,
+ global_resp_norm=global_resp_norm,
+ )
+
+ self.resample_do_res = use_residual_connection
+
+ self.dim = dim
+ conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True)
+ if use_residual_connection:
+ self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2)
+
+ self.conv1 = conv(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=kernel_size,
+ stride=2,
+ padding=kernel_size // 2,
+ groups=in_channels,
+ )
+
+ def forward(self, x):
+ """
+ Forward pass of the MedNeXtUpBlock.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Output tensor.
+ """
+ x1 = super().forward(x)
+ # Asymmetry but necessary to match shape
+
+ if self.dim == "2d":
+ x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0))
+ else:
+ x1 = torch.nn.functional.pad(x1, (1, 0, 1, 0, 1, 0))
+
+ if self.resample_do_res:
+ res = self.res_conv(x)
+ if self.dim == "2d":
+ res = torch.nn.functional.pad(res, (1, 0, 1, 0))
+ else:
+ res = torch.nn.functional.pad(res, (1, 0, 1, 0, 1, 0))
+ x1 = x1 + res
+
+ return x1
+
+
+class MedNeXtOutBlock(nn.Module):
+ """
+ MedNeXtOutBlock class for the output block in the MedNeXt model.
+
+ Args:
+ in_channels (int): Number of input channels.
+ n_classes (int): Number of output classes.
+ dim (str): Dimension of the input. Can be "2d" or "3d".
+ """
+
+ def __init__(self, in_channels, n_classes, dim):
+ super().__init__()
+
+ conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True)
+ self.conv_out = conv(in_channels, n_classes, kernel_size=1)
+
+ def forward(self, x):
+ """
+ Forward pass of the MedNeXtOutBlock.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Output tensor.
+ """
+ return self.conv_out(x)
diff --git a/monai/networks/blocks/mlp.py b/monai/networks/blocks/mlp.py
index d3510b64d3..8771711d25 100644
--- a/monai/networks/blocks/mlp.py
+++ b/monai/networks/blocks/mlp.py
@@ -11,12 +11,15 @@
from __future__ import annotations
+from typing import Union
+
import torch.nn as nn
from monai.networks.layers import get_act_layer
+from monai.networks.layers.factories import split_args
from monai.utils import look_up_option
-SUPPORTED_DROPOUT_MODE = {"vit", "swin"}
+SUPPORTED_DROPOUT_MODE = {"vit", "swin", "vista3d"}
class MLPBlock(nn.Module):
@@ -39,7 +42,7 @@ def __init__(
https://github.com/google-research/vision_transformer/blob/main/vit_jax/models.py#L87
"swin" corresponds to one instance as implemented in
https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_mlp.py#L23
-
+ "vista3d" mode does not use dropout.
"""
@@ -48,15 +51,24 @@ def __init__(
if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")
mlp_dim = mlp_dim or hidden_size
- self.linear1 = nn.Linear(hidden_size, mlp_dim) if act != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
+ act_name, _ = split_args(act)
+ self.linear1 = nn.Linear(hidden_size, mlp_dim) if act_name != "GEGLU" else nn.Linear(hidden_size, mlp_dim * 2)
self.linear2 = nn.Linear(mlp_dim, hidden_size)
self.fn = get_act_layer(act)
- self.drop1 = nn.Dropout(dropout_rate)
+ # Use Union[nn.Dropout, nn.Identity] for type annotations
+ self.drop1: Union[nn.Dropout, nn.Identity]
+ self.drop2: Union[nn.Dropout, nn.Identity]
+
dropout_opt = look_up_option(dropout_mode, SUPPORTED_DROPOUT_MODE)
if dropout_opt == "vit":
+ self.drop1 = nn.Dropout(dropout_rate)
self.drop2 = nn.Dropout(dropout_rate)
elif dropout_opt == "swin":
+ self.drop1 = nn.Dropout(dropout_rate)
self.drop2 = self.drop1
+ elif dropout_opt == "vista3d":
+ self.drop1 = nn.Identity()
+ self.drop2 = nn.Identity()
else:
raise ValueError(f"dropout_mode should be one of {SUPPORTED_DROPOUT_MODE}")
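A minimal sketch of the new "vista3d" dropout mode, which replaces both dropout layers with nn.Identity; the sizes are illustrative.

import torch
from monai.networks.blocks import MLPBlock

mlp = MLPBlock(hidden_size=64, mlp_dim=128, dropout_rate=0.0, dropout_mode="vista3d")
x = torch.rand(2, 16, 64)
y = mlp(x)   # (2, 16, 64); no dropout is applied in this mode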
diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py
index 44774ce5da..fca566591a 100644
--- a/monai/networks/blocks/patchembedding.py
+++ b/monai/networks/blocks/patchembedding.py
@@ -21,7 +21,7 @@
from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
from monai.networks.layers import Conv, trunc_normal_
-from monai.utils import deprecated_arg, ensure_tuple_rep, optional_import
+from monai.utils import ensure_tuple_rep, optional_import
from monai.utils.module import look_up_option
Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
@@ -42,9 +42,6 @@ class PatchEmbeddingBlock(nn.Module):
"""
- @deprecated_arg(
- name="pos_embed", since="1.2", removed="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead."
- )
def __init__(
self,
in_channels: int,
@@ -52,7 +49,6 @@ def __init__(
patch_size: Sequence[int] | int,
hidden_size: int,
num_heads: int,
- pos_embed: str = "conv",
proj_type: str = "conv",
pos_embed_type: str = "learnable",
dropout_rate: float = 0.0,
@@ -69,8 +65,6 @@ def __init__(
pos_embed_type: position embedding layer type.
dropout_rate: fraction of the input units to drop.
spatial_dims: number of spatial dimensions.
- .. deprecated:: 1.4
- ``pos_embed`` is deprecated in favor of ``proj_type``.
"""
super().__init__()
@@ -120,10 +114,7 @@ def __init__(
for in_size, pa_size in zip(img_size, patch_size):
grid_size.append(in_size // pa_size)
- with torch.no_grad():
- pos_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
- self.position_embeddings.data.copy_(pos_embeddings.float())
- self.position_embeddings.requires_grad = False
+ self.position_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
else:
raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")
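With the deprecated pos_embed argument removed, the projection type is selected via proj_type and the (now directly assigned) sincos table via pos_embed_type. A 2D sketch with illustrative sizes follows; note that 2D sincos embeddings need a hidden size divisible by 4.

import torch
from monai.networks.blocks import PatchEmbeddingBlock

embed = PatchEmbeddingBlock(
    in_channels=1,
    img_size=(32, 32),
    patch_size=(8, 8),
    hidden_size=96,
    num_heads=8,
    proj_type="conv",
    pos_embed_type="sincos",
    spatial_dims=2,
)
x = torch.rand(2, 1, 32, 32)
tokens = embed(x)   # (2, 16, 96): a 4 x 4 patch grid with hidden_size channels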
diff --git a/monai/networks/blocks/rel_pos_embedding.py b/monai/networks/blocks/rel_pos_embedding.py
new file mode 100644
index 0000000000..e53e5841b0
--- /dev/null
+++ b/monai/networks/blocks/rel_pos_embedding.py
@@ -0,0 +1,56 @@
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Iterable, Tuple
+
+import torch
+from torch import nn
+
+from monai.networks.blocks.attention_utils import add_decomposed_rel_pos
+from monai.utils.misc import ensure_tuple_size
+
+
+class DecomposedRelativePosEmbedding(nn.Module):
+ def __init__(self, s_input_dims: Tuple[int, int] | Tuple[int, int, int], c_dim: int, num_heads: int) -> None:
+ """
+ Args:
+ s_input_dims (Tuple): input spatial dimension. (H, W) or (H, W, D)
+ c_dim (int): channel dimension
+ num_heads(int): number of attention heads
+ """
+ super().__init__()
+
+ # validate inputs
+ if not isinstance(s_input_dims, Iterable) or len(s_input_dims) not in [2, 3]:
+ raise ValueError("s_input_dims must be set as follows: (H, W) or (H, W, D)")
+
+ self.s_input_dims = s_input_dims
+ self.c_dim = c_dim
+ self.num_heads = num_heads
+ self.rel_pos_arr = nn.ParameterList(
+ [nn.Parameter(torch.zeros(2 * dim_input_size - 1, c_dim)) for dim_input_size in s_input_dims]
+ )
+
+ def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
+ """Add the decomposed relative positional embeddings to the attention map and return the updated map."""
+ batch = x.shape[0]
+ h, w, d = ensure_tuple_size(self.s_input_dims, 3, 1)
+
+ att_mat = add_decomposed_rel_pos(
+ att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d),
+ q.contiguous().view(batch * self.num_heads, h * w * d, -1),
+ self.rel_pos_arr,
+ (h, w) if d == 1 else (h, w, d),
+ (h, w) if d == 1 else (h, w, d),
+ )
+
+ att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d)
+ return att_mat
diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py
index 7c81c1704f..ac96b077bd 100644
--- a/monai/networks/blocks/selfattention.py
+++ b/monai/networks/blocks/selfattention.py
@@ -11,10 +11,14 @@
from __future__ import annotations
+from typing import Tuple, Union
+
import torch
import torch.nn as nn
+import torch.nn.functional as F
-from monai.utils import optional_import
+from monai.networks.layers.utils import get_rel_pos_embedding_layer
+from monai.utils import optional_import, pytorch_after
Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
@@ -32,6 +36,16 @@ def __init__(
dropout_rate: float = 0.0,
qkv_bias: bool = False,
save_attn: bool = False,
+ dim_head: int | None = None,
+ hidden_input_size: int | None = None,
+ causal: bool = False,
+ sequence_length: int | None = None,
+ rel_pos_embedding: str | None = None,
+ input_size: Tuple | None = None,
+ attention_dtype: torch.dtype | None = None,
+ include_fc: bool = True,
+ use_combined_linear: bool = True,
+ use_flash_attention: bool = False,
) -> None:
"""
Args:
@@ -40,6 +54,19 @@ def __init__(
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+ dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
+ hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.
+ causal: whether to use causal attention (see https://arxiv.org/abs/1706.03762).
+ sequence_length: if causal is True, it is necessary to specify the sequence length.
+ rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
+ For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
+ input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
+ positional parameter size.
+ attention_dtype: cast attention operations to this dtype.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
"""
@@ -51,30 +78,128 @@ def __init__(
if hidden_size % num_heads != 0:
raise ValueError("hidden size should be divisible by num_heads.")
+ if dim_head:
+ self.inner_dim = num_heads * dim_head
+ self.dim_head = dim_head
+ else:
+ if hidden_size % num_heads != 0:
+ raise ValueError("hidden size should be divisible by num_heads.")
+ self.inner_dim = hidden_size
+ self.dim_head = hidden_size // num_heads
+
+ if causal and sequence_length is None:
+ raise ValueError("sequence_length is necessary for causal attention.")
+
+ if use_flash_attention and not pytorch_after(major=2, minor=0):
+ raise ValueError(
+ "use_flash_attention is only supported for PyTorch versions >= 2.0. "
+ "Upgrade your PyTorch or set the flag to False."
+ )
+ if use_flash_attention and save_attn:
+ raise ValueError(
+ "save_attn has been set to True, but use_flash_attention is also set "
+ "to True. save_attn can only be used if use_flash_attention is False."
+ )
+
+ if use_flash_attention and rel_pos_embedding is not None:
+ raise ValueError("rel_pos_embedding must be None if you are using flash_attention.")
+
self.num_heads = num_heads
- self.out_proj = nn.Linear(hidden_size, hidden_size)
- self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
- self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
- self.out_rearrange = Rearrange("b h l d -> b l (h d)")
+ self.hidden_input_size = hidden_input_size if hidden_input_size else hidden_size
+ self.out_proj = nn.Linear(self.inner_dim, self.hidden_input_size)
+
+ self.qkv: Union[nn.Linear, nn.Identity]
+ self.to_q: Union[nn.Linear, nn.Identity]
+ self.to_k: Union[nn.Linear, nn.Identity]
+ self.to_v: Union[nn.Linear, nn.Identity]
+
+ if use_combined_linear:
+ self.qkv = nn.Linear(self.hidden_input_size, self.inner_dim * 3, bias=qkv_bias)
+ self.to_q = self.to_k = self.to_v = nn.Identity() # add to enable torchscript
+ self.input_rearrange = Rearrange("b h (qkv l d) -> qkv b l h d", qkv=3, l=num_heads)
+ else:
+ self.to_q = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
+ self.to_k = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
+ self.to_v = nn.Linear(self.hidden_input_size, self.inner_dim, bias=qkv_bias)
+ self.qkv = nn.Identity() # add to enable torchscript
+ self.input_rearrange = Rearrange("b h (l d) -> b l h d", l=num_heads)
+ self.out_rearrange = Rearrange("b l h d -> b h (l d)")
self.drop_output = nn.Dropout(dropout_rate)
self.drop_weights = nn.Dropout(dropout_rate)
- self.head_dim = hidden_size // num_heads
- self.scale = self.head_dim**-0.5
+ self.dropout_rate = dropout_rate
+ self.scale = self.dim_head**-0.5
self.save_attn = save_attn
self.att_mat = torch.Tensor()
+ self.attention_dtype = attention_dtype
+ self.causal = causal
+ self.sequence_length = sequence_length
+ self.include_fc = include_fc
+ self.use_combined_linear = use_combined_linear
+ self.use_flash_attention = use_flash_attention
+
+ if causal and sequence_length is not None:
+ # causal mask to ensure that attention is only applied to the left in the input sequence
+ self.register_buffer(
+ "causal_mask",
+ torch.tril(torch.ones(sequence_length, sequence_length)).view(1, 1, sequence_length, sequence_length),
+ )
+ self.causal_mask: torch.Tensor
+ else:
+ self.causal_mask = torch.Tensor()
+
+ self.rel_positional_embedding = (
+ get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.dim_head, self.num_heads)
+ if rel_pos_embedding is not None
+ else None
+ )
+ self.input_size = input_size
def forward(self, x):
- output = self.input_rearrange(self.qkv(x))
- q, k, v = output[0], output[1], output[2]
- att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
- if self.save_attn:
- # no gradients and new tensor;
- # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
- self.att_mat = att_mat.detach()
-
- att_mat = self.drop_weights(att_mat)
- x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+ """
+ Args:
+ x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
+
+ Return:
+ torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
+ """
+ if self.use_combined_linear:
+ output = self.input_rearrange(self.qkv(x))
+ q, k, v = output[0], output[1], output[2]
+ else:
+ q = self.input_rearrange(self.to_q(x))
+ k = self.input_rearrange(self.to_k(x))
+ v = self.input_rearrange(self.to_v(x))
+
+ if self.attention_dtype is not None:
+ q = q.to(self.attention_dtype)
+ k = k.to(self.attention_dtype)
+
+ if self.use_flash_attention:
+ x = F.scaled_dot_product_attention(
+ query=q, key=k, value=v, scale=self.scale, dropout_p=self.dropout_rate, is_causal=self.causal
+ )
+ else:
+ att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale
+
+ # apply relative positional embedding if defined
+ if self.rel_positional_embedding is not None:
+ att_mat = self.rel_positional_embedding(x, att_mat, q)
+
+ if self.causal:
+ att_mat = att_mat.masked_fill(self.causal_mask[:, :, : x.shape[-2], : x.shape[-2]] == 0, float("-inf"))
+
+ att_mat = att_mat.softmax(dim=-1)
+
+ if self.save_attn:
+ # no gradients and new tensor;
+ # https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
+ self.att_mat = att_mat.detach()
+
+ att_mat = self.drop_weights(att_mat)
+ x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
+
x = self.out_rearrange(x)
- x = self.out_proj(x)
+ if self.include_fc:
+ x = self.out_proj(x)
x = self.drop_output(x)
return x
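
For reference, a minimal usage sketch of the reworked SABlock (illustrative only; the tensor sizes and keyword values below are assumptions chosen for demonstration, while the keyword names follow the calls visible elsewhere in this diff):

import torch

from monai.networks.blocks import SABlock

# separate to_q/to_k/to_v projections (use_combined_linear=False) using the einsum path
block = SABlock(
    hidden_size=128,
    num_heads=4,
    dropout_rate=0.0,
    qkv_bias=True,
    use_combined_linear=False,
    use_flash_attention=False,  # True would route through F.scaled_dot_product_attention
)
x = torch.randn(2, 64, 128)  # B x sequence length x hidden_size
y = block(x)                 # same shape as the input
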
diff --git a/monai/networks/blocks/spade_norm.py b/monai/networks/blocks/spade_norm.py
new file mode 100644
index 0000000000..343dfa9ec0
--- /dev/null
+++ b/monai/networks/blocks/spade_norm.py
@@ -0,0 +1,95 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from monai.networks.blocks import Convolution
+from monai.networks.layers.utils import get_norm_layer
+
+
+class SPADE(nn.Module):
+ """
+ Spatially Adaptive Normalization (SPADE) block, allowing for normalization of activations conditioned on a
+ semantic map. This block is used in SPADE-based image-to-image translation models, as described in
+ Semantic Image Synthesis with Spatially-Adaptive Normalization (https://arxiv.org/abs/1903.07291).
+
+ Args:
+ label_nc: number of semantic labels
+ norm_nc: number of output channels
+ kernel_size: kernel size
+ spatial_dims: number of spatial dimensions
+ hidden_channels: number of channels in the intermediate gamma and beta layers
+ norm: type of base normalisation used before applying the SPADE normalisation
+ norm_params: parameters for the base normalisation
+ """
+
+ def __init__(
+ self,
+ label_nc: int,
+ norm_nc: int,
+ kernel_size: int = 3,
+ spatial_dims: int = 2,
+ hidden_channels: int = 64,
+ norm: str | tuple = "INSTANCE",
+ norm_params: dict | None = None,
+ ) -> None:
+ super().__init__()
+
+ if norm_params is None:
+ norm_params = {}
+ if len(norm_params) != 0:
+ norm = (norm, norm_params)
+ self.param_free_norm = get_norm_layer(norm, spatial_dims=spatial_dims, channels=norm_nc)
+ self.mlp_shared = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=label_nc,
+ out_channels=hidden_channels,
+ kernel_size=kernel_size,
+ norm=None,
+ act="LEAKYRELU",
+ )
+ self.mlp_gamma = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=hidden_channels,
+ out_channels=norm_nc,
+ kernel_size=kernel_size,
+ act=None,
+ )
+ self.mlp_beta = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=hidden_channels,
+ out_channels=norm_nc,
+ kernel_size=kernel_size,
+ act=None,
+ )
+
+ def forward(self, x: torch.Tensor, segmap: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+            x: input tensor with shape (B, C, [spatial-dimensions]) where C is the number of feature channels
+                being normalised (i.e. norm_nc).
+            segmap: input segmentation map (B, C, [spatial-dimensions]) where C is the number of semantic channels
+                (i.e. label_nc). The map will be interpolated to the spatial dimensions of x internally.
+ """
+
+ # Part 1. generate parameter-free normalized activations
+ normalized = self.param_free_norm(x.contiguous())
+
+ # Part 2. produce scaling and bias conditioned on semantic map
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
+ actv = self.mlp_shared(segmap)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+ out: torch.Tensor = normalized * (1 + gamma) + beta
+ return out
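
A small usage sketch of the new SPADE block (illustrative only; the channel counts and spatial sizes are assumptions):

import torch

from monai.networks.blocks.spade_norm import SPADE

spade = SPADE(label_nc=3, norm_nc=32, spatial_dims=2, hidden_channels=64)
feat = torch.randn(1, 32, 16, 16)   # feature map to normalise (norm_nc channels)
segmap = torch.randn(1, 3, 64, 64)  # semantic map (label_nc channels), resized internally
out = spade(feat, segmap)           # same shape as feat
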
diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py
new file mode 100644
index 0000000000..60a89a7840
--- /dev/null
+++ b/monai/networks/blocks/spatialattention.py
@@ -0,0 +1,80 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks import SABlock
+
+
+class SpatialAttentionBlock(nn.Module):
+ """Perform spatial self-attention on the input tensor.
+
+ The input tensor is reshaped to B x (x_dim * y_dim [ * z_dim]) x C, where C is the number of channels, and then
+ self-attention is performed on the reshaped tensor. The output tensor is reshaped back to the original shape.
+
+ Args:
+ spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
+ num_channels: number of input channels. Must be divisible by num_head_channels.
+ num_head_channels: number of channels per head.
+ norm_num_groups: Number of groups for the group norm layer.
+ norm_eps: Epsilon for the normalization.
+ attention_dtype: cast attention operations to this dtype.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ num_channels: int,
+ num_head_channels: int | None = None,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ attention_dtype: Optional[torch.dtype] = None,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.spatial_dims = spatial_dims
+ self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True)
+        # check that num_channels is divisible by num_head_channels
+ if num_head_channels is not None and num_channels % num_head_channels != 0:
+ raise ValueError("num_channels must be divisible by num_head_channels")
+ num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
+ self.attn = SABlock(
+ hidden_size=num_channels,
+ num_heads=num_heads,
+ qkv_bias=True,
+ attention_dtype=attention_dtype,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+
+ def forward(self, x: torch.Tensor):
+ residual = x
+ shape = x.shape
+ x = self.norm(x)
+ x = x.reshape(*shape[:2], -1).transpose(1, 2) # "b c h w d -> b (h w d) c"
+ x = self.attn(x)
+ x = x.transpose(1, 2).reshape(shape) # "b (h w d) c -> b c h w d"
+ x = x + residual
+ return x
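
An illustrative sketch of how the block is used (the channel and spatial sizes are assumptions): the spatial dimensions are flattened, SABlock self-attention is applied, and the residual is added back at the original shape.

import torch

from monai.networks.blocks.spatialattention import SpatialAttentionBlock

attn = SpatialAttentionBlock(spatial_dims=3, num_channels=32, num_head_channels=8)
x = torch.randn(1, 32, 8, 8, 8)  # B x C x spatial dims
y = attn(x)                      # residual output, same shape as x
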
diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index ddf959dad2..05eb3b07ab 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -11,10 +11,12 @@
from __future__ import annotations
+from typing import Optional
+
+import torch
import torch.nn as nn
-from monai.networks.blocks.mlp import MLPBlock
-from monai.networks.blocks.selfattention import SABlock
+from monai.networks.blocks import CrossAttentionBlock, MLPBlock, SABlock
class TransformerBlock(nn.Module):
@@ -31,6 +33,12 @@ def __init__(
dropout_rate: float = 0.0,
qkv_bias: bool = False,
save_attn: bool = False,
+ causal: bool = False,
+ sequence_length: int | None = None,
+ with_cross_attention: bool = False,
+ use_flash_attention: bool = False,
+ include_fc: bool = True,
+ use_combined_linear: bool = True,
) -> None:
"""
Args:
@@ -38,8 +46,12 @@ def __init__(
mlp_dim (int): dimension of feedforward layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
- qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
+ qkv_bias(bool, optional): apply bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
"""
@@ -53,10 +65,34 @@ def __init__(
self.mlp = MLPBlock(hidden_size, mlp_dim, dropout_rate)
self.norm1 = nn.LayerNorm(hidden_size)
- self.attn = SABlock(hidden_size, num_heads, dropout_rate, qkv_bias, save_attn)
+ self.attn = SABlock(
+ hidden_size,
+ num_heads,
+ dropout_rate,
+ qkv_bias=qkv_bias,
+ save_attn=save_attn,
+ causal=causal,
+ sequence_length=sequence_length,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
self.norm2 = nn.LayerNorm(hidden_size)
+ self.with_cross_attention = with_cross_attention
+
+ self.norm_cross_attn = nn.LayerNorm(hidden_size)
+ self.cross_attn = CrossAttentionBlock(
+ hidden_size=hidden_size,
+ num_heads=num_heads,
+ dropout_rate=dropout_rate,
+ qkv_bias=qkv_bias,
+ causal=False,
+ use_flash_attention=use_flash_attention,
+ )
- def forward(self, x):
+ def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
x = x + self.attn(self.norm1(x))
+ if self.with_cross_attention:
+ x = x + self.cross_attn(self.norm_cross_attn(x), context=context)
x = x + self.mlp(self.norm2(x))
return x
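
A minimal sketch of the extended TransformerBlock with the optional cross-attention path (illustrative only; the sequence lengths and hidden size are assumptions, and the context is assumed to share the block's hidden size):

import torch

from monai.networks.blocks import TransformerBlock

block = TransformerBlock(hidden_size=128, mlp_dim=256, num_heads=4, with_cross_attention=True)
x = torch.randn(2, 64, 128)        # B x seq_len x hidden_size
context = torch.randn(2, 10, 128)  # conditioning sequence
y = block(x, context=context)
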
diff --git a/monai/networks/blocks/upsample.py b/monai/networks/blocks/upsample.py
index dee9966919..50fd39a70b 100644
--- a/monai/networks/blocks/upsample.py
+++ b/monai/networks/blocks/upsample.py
@@ -17,8 +17,8 @@
import torch.nn as nn
from monai.networks.layers.factories import Conv, Pad, Pool
-from monai.networks.utils import icnr_init, pixelshuffle
-from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option
+from monai.networks.utils import CastTempType, icnr_init, pixelshuffle
+from monai.utils import InterpolateMode, UpsampleMode, ensure_tuple_rep, look_up_option, pytorch_after
__all__ = ["Upsample", "UpSample", "SubpixelUpsample", "Subpixelupsample", "SubpixelUpSample"]
@@ -50,6 +50,7 @@ def __init__(
size: tuple[int] | int | None = None,
mode: UpsampleMode | str = UpsampleMode.DECONV,
pre_conv: nn.Module | str | None = "default",
+ post_conv: nn.Module | None = None,
interp_mode: str = InterpolateMode.LINEAR,
align_corners: bool | None = True,
bias: bool = True,
@@ -71,6 +72,7 @@ def __init__(
pre_conv: a conv block applied before upsampling. Defaults to "default".
              When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized.
Only used in the "nontrainable" or "pixelshuffle" mode.
+ post_conv: a conv block applied after upsampling. Defaults to None. Only used in the "nontrainable" mode.
interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
Only used in the "nontrainable" mode.
If ends with ``"linear"`` will use ``spatial dims`` to determine the correct interpolation.
@@ -154,15 +156,25 @@ def __init__(
linear_mode = [InterpolateMode.LINEAR, InterpolateMode.BILINEAR, InterpolateMode.TRILINEAR]
if interp_mode in linear_mode: # choose mode based on dimensions
interp_mode = linear_mode[spatial_dims - 1]
- self.add_module(
- "upsample_non_trainable",
- nn.Upsample(
- size=size,
- scale_factor=None if size else scale_factor_,
- mode=interp_mode.value,
- align_corners=align_corners,
- ),
+
+ upsample = nn.Upsample(
+ size=size,
+ scale_factor=None if size else scale_factor_,
+ mode=interp_mode.value,
+ align_corners=align_corners,
)
+
+ # Cast to float32 as 'upsample_nearest2d_out_frame' op does not support bfloat16
+ # https://github.com/pytorch/pytorch/issues/86679. This issue is solved in PyTorch 2.1
+ if pytorch_after(major=2, minor=1):
+ self.add_module("upsample_non_trainable", upsample)
+ else:
+ self.add_module(
+ "upsample_non_trainable",
+ CastTempType(initial_type=torch.bfloat16, temporary_type=torch.float32, submodule=upsample),
+ )
+ if post_conv:
+ self.add_module("postconv", post_conv)
elif up_mode == UpsampleMode.PIXELSHUFFLE:
self.add_module(
"pixelshuffle",
diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py
index 3a6e4aa554..48c10270b1 100644
--- a/monai/networks/layers/__init__.py
+++ b/monai/networks/layers/__init__.py
@@ -14,7 +14,7 @@
from .conjugate_gradient import ConjugateGradient
from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding
from .drop_path import DropPath
-from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args
+from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, RelPosEmbedding, split_args
from .filtering import BilateralFilter, PHLFilter, TrainableBilateralFilter, TrainableJointBilateralFilter
from .gmm import GaussianMixtureModel
from .simplelayers import (
@@ -38,4 +38,5 @@
)
from .spatial_transforms import AffineTransform, grid_count, grid_grad, grid_pull, grid_push
from .utils import get_act_layer, get_dropout_layer, get_norm_layer, get_pool_layer
+from .vector_quantizer import EMAQuantizer, VectorQuantizer
from .weight_init import _no_grad_trunc_normal_, trunc_normal_
diff --git a/monai/networks/layers/factories.py b/monai/networks/layers/factories.py
index 4fc2c16f73..29b72a4f37 100644
--- a/monai/networks/layers/factories.py
+++ b/monai/networks/layers/factories.py
@@ -70,7 +70,7 @@ def use_factory(fact_args):
from monai.networks.utils import has_nvfuser_instance_norm
from monai.utils import ComponentStore, look_up_option, optional_import
-__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]
+__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "RelPosEmbedding", "split_args"]
class LayerFactory(ComponentStore):
@@ -201,6 +201,10 @@ def split_args(args):
Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.")
Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.")
Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.")
+RelPosEmbedding = LayerFactory(
+ name="Relative positional embedding layers",
+    description="Factory for creating relative positional embedding layers.",
+)
@Dropout.factory_function("dropout")
@@ -468,3 +472,10 @@ def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d |
"""
types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)
return types[dim - 1]
+
+
+@RelPosEmbedding.factory_function("decomposed")
+def decomposed_rel_pos_embedding() -> type[nn.Module]:
+ from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding
+
+ return DecomposedRelativePosEmbedding
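
The new factory behaves like the existing Conv/Pool/Pad factories, resolving registered names lazily; for example:

from monai.networks.layers.factories import RelPosEmbedding

embedding_cls = RelPosEmbedding["decomposed"]  # -> DecomposedRelativePosEmbedding
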
diff --git a/monai/networks/layers/filtering.py b/monai/networks/layers/filtering.py
index 0ff1187dcc..c48c77cf98 100644
--- a/monai/networks/layers/filtering.py
+++ b/monai/networks/layers/filtering.py
@@ -51,6 +51,8 @@ def forward(ctx, input, spatial_sigma=5, color_sigma=0.5, fast_approx=True):
ctx.cs = color_sigma
ctx.fa = fast_approx
output_data = _C.bilateral_filter(input, spatial_sigma, color_sigma, fast_approx)
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
return output_data
@staticmethod
@@ -139,7 +141,8 @@ def forward(ctx, input_img, sigma_x, sigma_y, sigma_z, color_sigma):
do_dsig_y,
do_dsig_z,
)
-
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
return output_tensor
@staticmethod
@@ -301,7 +304,8 @@ def forward(ctx, input_img, guidance_img, sigma_x, sigma_y, sigma_z, color_sigma
do_dsig_z,
guidance_img,
)
-
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
return output_tensor
@staticmethod
diff --git a/monai/networks/layers/simplelayers.py b/monai/networks/layers/simplelayers.py
index 4ac621967f..4acd4a3622 100644
--- a/monai/networks/layers/simplelayers.py
+++ b/monai/networks/layers/simplelayers.py
@@ -452,7 +452,7 @@ def get_binary_kernel(window_size: Sequence[int], dtype=torch.float, device=None
def median_filter(
in_tensor: torch.Tensor,
- kernel_size: Sequence[int] = (3, 3, 3),
+ kernel_size: Sequence[int] | int = (3, 3, 3),
spatial_dims: int = 3,
kernel: torch.Tensor | None = None,
**kwargs,
diff --git a/monai/networks/layers/utils.py b/monai/networks/layers/utils.py
index ace1af27b6..8676f74638 100644
--- a/monai/networks/layers/utils.py
+++ b/monai/networks/layers/utils.py
@@ -11,9 +11,11 @@
from __future__ import annotations
+from typing import Optional
+
import torch.nn
-from monai.networks.layers.factories import Act, Dropout, Norm, Pool, split_args
+from monai.networks.layers.factories import Act, Dropout, Norm, Pool, RelPosEmbedding, split_args
from monai.utils import has_option
__all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer", "get_pool_layer"]
@@ -124,3 +126,14 @@ def get_pool_layer(name: tuple | str, spatial_dims: int | None = 1):
pool_name, pool_args = split_args(name)
pool_type = Pool[pool_name, spatial_dims]
return pool_type(**pool_args)
+
+
+def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: Optional[tuple], c_dim: int, num_heads: int):
+ embedding_name, embedding_args = split_args(name)
+ embedding_type = RelPosEmbedding[embedding_name]
+ # create a dictionary with the default values which can be overridden by embedding_args
+ kw_args = {"s_input_dims": s_input_dims, "c_dim": c_dim, "num_heads": num_heads, **embedding_args}
+ # filter out unused argument names
+ kw_args = {k: v for k, v in kw_args.items() if has_option(embedding_type, k)}
+
+ return embedding_type(**kw_args)
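
An illustrative call of the new helper (the spatial size, head dimension and head count below are assumptions; unsupported keyword arguments are filtered out via has_option):

from monai.networks.layers.utils import get_rel_pos_embedding_layer

rel_pos = get_rel_pos_embedding_layer(
    name="decomposed", s_input_dims=(16, 16), c_dim=8, num_heads=4
)
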
diff --git a/monai/networks/layers/vector_quantizer.py b/monai/networks/layers/vector_quantizer.py
new file mode 100644
index 0000000000..9c354e1009
--- /dev/null
+++ b/monai/networks/layers/vector_quantizer.py
@@ -0,0 +1,233 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Sequence, Tuple
+
+import torch
+from torch import nn
+
+__all__ = ["VectorQuantizer", "EMAQuantizer"]
+
+
+class EMAQuantizer(nn.Module):
+ """
+ Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on Neural
+ Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation
+ that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit
+ 58d9a2746493717a7c9252938da7efa6006f3739.
+
+    This module is not compatible with TorchScript when used inside a DistributedDataParallel module, due to the
+    lack of TorchScript support for the torch.distributed module (see https://github.com/pytorch/pytorch/issues/41353,
+    as of 22/10/2022). If you want to TorchScript your model, please set `ddp_sync` to False.
+
+ Args:
+ spatial_dims: number of spatial dimensions of the input.
+ num_embeddings: number of atomic elements in the codebook.
+ embedding_dim: number of channels of the input and atomic elements.
+ commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25.
+ decay: EMA decay. Defaults to 0.99.
+ epsilon: epsilon value. Defaults to 1e-5.
+ embedding_init: initialization method for the codebook. Defaults to "normal".
+ ddp_sync: whether to synchronize the codebook across processes. Defaults to True.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ num_embeddings: int,
+ embedding_dim: int,
+ commitment_cost: float = 0.25,
+ decay: float = 0.99,
+ epsilon: float = 1e-5,
+ embedding_init: str = "normal",
+ ddp_sync: bool = True,
+ ):
+ super().__init__()
+ self.spatial_dims: int = spatial_dims
+ self.embedding_dim: int = embedding_dim
+ self.num_embeddings: int = num_embeddings
+
+        if self.spatial_dims not in (2, 3):
+            raise ValueError(
+                f"EMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims {spatial_dims}."
+            )
+
+ self.embedding: torch.nn.Embedding = torch.nn.Embedding(self.num_embeddings, self.embedding_dim)
+ if embedding_init == "normal":
+ # Initialization is passed since the default one is normal inside the nn.Embedding
+ pass
+ elif embedding_init == "kaiming_uniform":
+ torch.nn.init.kaiming_uniform_(self.embedding.weight.data, mode="fan_in", nonlinearity="linear")
+ self.embedding.weight.requires_grad = False
+
+ self.commitment_cost: float = commitment_cost
+
+ self.register_buffer("ema_cluster_size", torch.zeros(self.num_embeddings))
+ self.register_buffer("ema_w", self.embedding.weight.data.clone())
+ # declare types for mypy
+ self.ema_cluster_size: torch.Tensor
+ self.ema_w: torch.Tensor
+ self.decay: float = decay
+ self.epsilon: float = epsilon
+
+ self.ddp_sync: bool = ddp_sync
+
+ # Precalculating required permutation shapes
+ self.flatten_permutation = [0] + list(range(2, self.spatial_dims + 2)) + [1]
+ self.quantization_permutation: Sequence[int] = [0, self.spatial_dims + 1] + list(
+ range(1, self.spatial_dims + 1)
+ )
+
+ def quantize(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+        Given an input, projects it to the quantized space and returns the additional tensors needed for the EMA loss.
+
+ Args:
+ inputs: Encoding space tensors of shape [B, C, H, W, D].
+
+ Returns:
+ torch.Tensor: Flatten version of the input of shape [B*H*W*D, C].
+ torch.Tensor: One-hot representation of the quantization indices of shape [B*H*W*D, self.num_embeddings].
+ torch.Tensor: Quantization indices of shape [B,H,W,D,1]
+
+ """
+ with torch.cuda.amp.autocast(enabled=False):
+ encoding_indices_view = list(inputs.shape)
+ del encoding_indices_view[1]
+
+ inputs = inputs.float()
+
+ # Converting to channel last format
+ flat_input = inputs.permute(self.flatten_permutation).contiguous().view(-1, self.embedding_dim)
+
+ # Calculate Euclidean distances
+ distances = (
+ (flat_input**2).sum(dim=1, keepdim=True)
+ + (self.embedding.weight.t() ** 2).sum(dim=0, keepdim=True)
+ - 2 * torch.mm(flat_input, self.embedding.weight.t())
+ )
+
+ # Mapping distances to indexes
+ encoding_indices = torch.max(-distances, dim=1)[1]
+ encodings = torch.nn.functional.one_hot(encoding_indices, self.num_embeddings).float()
+
+ # Quantize and reshape
+ encoding_indices = encoding_indices.view(encoding_indices_view)
+
+ return flat_input, encodings, encoding_indices
+
+ def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
+ """
+        Given encoding indices of shape [B,D,H,W,1], embeds them in the quantized space
+ [B, D, H, W, self.embedding_dim] and reshapes them to [B, self.embedding_dim, D, H, W] to be fed to the
+ decoder.
+
+ Args:
+ embedding_indices: Tensor in channel last format which holds indices referencing atomic
+ elements from self.embedding
+
+ Returns:
+ torch.Tensor: Quantize space representation of encoding_indices in channel first format.
+ """
+ with torch.cuda.amp.autocast(enabled=False):
+ embedding: torch.Tensor = (
+ self.embedding(embedding_indices).permute(self.quantization_permutation).contiguous()
+ )
+ return embedding
+
+ def distributed_synchronization(self, encodings_sum: torch.Tensor, dw: torch.Tensor) -> None:
+ """
+        TorchScript does not support torch.distributed.all_reduce. This function is a workaround based on the
+        example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused
+
+ Args:
+            encodings_sum: the per-code sum of the one-hot encodings selected at each position.
+            dw: the matrix product of the transposed one-hot encodings with the flattened input.
+
+ Returns:
+ None
+ """
+ if self.ddp_sync and torch.distributed.is_initialized():
+ torch.distributed.all_reduce(tensor=encodings_sum, op=torch.distributed.ReduceOp.SUM)
+ torch.distributed.all_reduce(tensor=dw, op=torch.distributed.ReduceOp.SUM)
+ else:
+ pass
+
+ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ flat_input, encodings, encoding_indices = self.quantize(inputs)
+ quantized = self.embed(encoding_indices)
+
+ # Use EMA to update the embedding vectors
+ if self.training:
+ with torch.no_grad():
+ encodings_sum = encodings.sum(0)
+ dw = torch.mm(encodings.t(), flat_input)
+
+ if self.ddp_sync:
+ self.distributed_synchronization(encodings_sum, dw)
+
+ self.ema_cluster_size.data.mul_(self.decay).add_(torch.mul(encodings_sum, 1 - self.decay))
+
+ # Laplace smoothing of the cluster size
+ n = self.ema_cluster_size.sum()
+ weights = (self.ema_cluster_size + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n
+ self.ema_w.data.mul_(self.decay).add_(torch.mul(dw, 1 - self.decay))
+ self.embedding.weight.data.copy_(self.ema_w / weights.unsqueeze(1))
+
+ # Encoding Loss
+ loss = self.commitment_cost * torch.nn.functional.mse_loss(quantized.detach(), inputs)
+
+ # Straight Through Estimator
+ quantized = inputs + (quantized - inputs).detach()
+
+ return quantized, loss, encoding_indices
+
+
+class VectorQuantizer(torch.nn.Module):
+ """
+    Vector quantization wrapper, needed as a workaround for AMP, which isolates the parts of the quantization that
+    are not fp16 compatible in their own class.
+
+ Args:
+ quantizer (torch.nn.Module): Quantizer module that needs to return its quantized representation, loss and index
+ based quantized representation.
+ """
+
+ def __init__(self, quantizer: EMAQuantizer):
+ super().__init__()
+
+ self.quantizer: EMAQuantizer = quantizer
+
+ self.perplexity: torch.Tensor = torch.rand(1)
+
+ def forward(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ quantized, loss, encoding_indices = self.quantizer(inputs)
+ # Perplexity calculations
+ avg_probs = (
+ torch.histc(encoding_indices.float(), bins=self.quantizer.num_embeddings, max=self.quantizer.num_embeddings)
+ .float()
+ .div(encoding_indices.numel())
+ )
+
+ self.perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
+
+ return loss, quantized
+
+ def embed(self, embedding_indices: torch.Tensor) -> torch.Tensor:
+ return self.quantizer.embed(embedding_indices=embedding_indices)
+
+ def quantize(self, encodings: torch.Tensor) -> torch.Tensor:
+ output = self.quantizer(encodings)
+ encoding_indices: torch.Tensor = output[2]
+ return encoding_indices
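
A usage sketch combining the two classes (illustrative only; the codebook size, embedding dimension and spatial sizes are assumptions):

import torch

from monai.networks.layers import EMAQuantizer, VectorQuantizer

quantizer = VectorQuantizer(EMAQuantizer(spatial_dims=2, num_embeddings=256, embedding_dim=32))
z = torch.randn(1, 32, 8, 8)     # encoder output in channel-first format
loss, quantized = quantizer(z)   # commitment loss and straight-through quantized latents
indices = quantizer.quantize(z)  # B x H x W codebook indices
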
diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py
index 9247aaee85..b876e6a3fc 100644
--- a/monai/networks/nets/__init__.py
+++ b/monai/networks/nets/__init__.py
@@ -14,9 +14,11 @@
from .ahnet import AHnet, Ahnet, AHNet
from .attentionunet import AttentionUnet
from .autoencoder import AutoEncoder
+from .autoencoderkl import AutoencoderKL
from .basic_unet import BasicUNet, BasicUnet, Basicunet, basicunet
from .basic_unetplusplus import BasicUNetPlusPlus, BasicUnetPlusPlus, BasicunetPlusPlus, basicunetplusplus
from .classifier import Classifier, Critic, Discriminator
+from .controlnet import ControlNet
from .daf3d import DAF3D
from .densenet import (
DenseNet,
@@ -34,6 +36,7 @@
densenet201,
densenet264,
)
+from .diffusion_model_unet import DiffusionModelUNet
from .dints import DiNTS, TopologyConstruction, TopologyInstance, TopologySearch
from .dynunet import DynUNet, DynUnet, Dynunet
from .efficientnet import (
@@ -50,8 +53,28 @@
from .generator import Generator
from .highresnet import HighResBlock, HighResNet
from .hovernet import Hovernet, HoVernet, HoVerNet, HoverNet
+from .mednext import (
+ MedNeXt,
+ MedNext,
+ MedNextB,
+ MedNeXtB,
+ MedNextBase,
+ MedNextL,
+ MedNeXtL,
+ MedNeXtLarge,
+ MedNextLarge,
+ MedNextM,
+ MedNeXtM,
+ MedNeXtMedium,
+ MedNextMedium,
+ MedNextS,
+ MedNeXtS,
+ MedNeXtSmall,
+ MedNextSmall,
+)
from .milmodel import MILModel
from .netadapter import NetAdapter
+from .patchgan_discriminator import MultiScalePatchDiscriminator, PatchDiscriminator
from .quicknat import Quicknat
from .regressor import Regressor
from .regunet import GlobalNet, LocalNet, RegUNet
@@ -59,6 +82,8 @@
ResNet,
ResNetBlock,
ResNetBottleneck,
+ ResNetEncoder,
+ ResNetFeatures,
get_medicalnet_pretrained_resnet_args,
get_pretrained_resnet_medicalnet,
resnet10,
@@ -70,7 +95,7 @@
resnet200,
)
from .segresnet import SegResNet, SegResNetVAE
-from .segresnet_ds import SegResNetDS
+from .segresnet_ds import SegResNetDS, SegResNetDS2
from .senet import (
SENet,
SEnet,
@@ -102,13 +127,19 @@
seresnext50,
seresnext101,
)
+from .spade_autoencoderkl import SPADEAutoencoderKL
+from .spade_diffusion_model_unet import SPADEDiffusionModelUNet
+from .spade_network import SPADENet
from .swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
from .torchvision_fc import TorchVisionFCModel
from .transchex import BertAttention, BertMixedLayer, BertOutput, BertPreTrainedModel, MultiModal, Pooler, Transchex
+from .transformer import DecoderOnlyTransformer
from .unet import UNet, Unet
from .unetr import UNETR
from .varautoencoder import VarAutoEncoder
+from .vista3d import VISTA3D, vista3d132
from .vit import ViT
from .vitautoenc import ViTAutoEnc
from .vnet import VNet
from .voxelmorph import VoxelMorph, VoxelMorphUNet
+from .vqvae import VQVAE
diff --git a/monai/networks/nets/attentionunet.py b/monai/networks/nets/attentionunet.py
index 5689cf1071..fdf31d9701 100644
--- a/monai/networks/nets/attentionunet.py
+++ b/monai/networks/nets/attentionunet.py
@@ -29,7 +29,7 @@ def __init__(
spatial_dims: int,
in_channels: int,
out_channels: int,
- kernel_size: int = 3,
+ kernel_size: Sequence[int] | int = 3,
strides: int = 1,
dropout=0.0,
):
@@ -219,7 +219,13 @@ def __init__(
self.kernel_size = kernel_size
self.dropout = dropout
- head = ConvBlock(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=channels[0], dropout=dropout)
+ head = ConvBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=channels[0],
+ dropout=dropout,
+ kernel_size=self.kernel_size,
+ )
reduce_channels = Convolution(
spatial_dims=spatial_dims,
in_channels=channels[0],
@@ -245,6 +251,7 @@ def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module:
out_channels=channels[1],
strides=strides[0],
dropout=self.dropout,
+ kernel_size=self.kernel_size,
),
subblock,
),
@@ -271,6 +278,7 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) -
out_channels=out_channels,
strides=strides,
dropout=self.dropout,
+ kernel_size=self.kernel_size,
),
up_kernel_size=self.up_kernel_size,
strides=strides,
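
With this change the kernel_size argument is propagated to every ConvBlock in the network rather than only the head. A minimal sketch (illustrative only; the channel and stride configuration is an assumption):

import torch

from monai.networks.nets import AttentionUnet

net = AttentionUnet(
    spatial_dims=2,
    in_channels=1,
    out_channels=2,
    channels=(8, 16, 32),
    strides=(2, 2),
    kernel_size=5,  # forwarded to every ConvBlock instead of the hard-coded default of 3
)
y = net(torch.randn(1, 1, 64, 64))  # 1 x 2 x 64 x 64
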
diff --git a/monai/networks/nets/autoencoderkl.py b/monai/networks/nets/autoencoderkl.py
new file mode 100644
index 0000000000..af191e748b
--- /dev/null
+++ b/monai/networks/nets/autoencoderkl.py
@@ -0,0 +1,735 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import List
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from monai.networks.blocks import Convolution, SpatialAttentionBlock, Upsample
+from monai.utils import ensure_tuple_rep, optional_import
+
+Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
+
+__all__ = ["AutoencoderKL"]
+
+
+class AsymmetricPad(nn.Module):
+ """
+ Pad the input tensor asymmetrically along every spatial dimension.
+
+ Args:
+ spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
+ """
+
+ def __init__(self, spatial_dims: int) -> None:
+ super().__init__()
+ self.pad = (0, 1) * spatial_dims
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = nn.functional.pad(x, self.pad, mode="constant", value=0.0)
+ return x
+
+
+class AEKLDownsample(nn.Module):
+ """
+ Convolution-based downsampling layer.
+
+ Args:
+ spatial_dims: number of spatial dimensions (1D, 2D, 3D).
+ in_channels: number of input channels.
+ """
+
+ def __init__(self, spatial_dims: int, in_channels: int) -> None:
+ super().__init__()
+ self.pad = AsymmetricPad(spatial_dims=spatial_dims)
+
+ self.conv = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=in_channels,
+ strides=2,
+ kernel_size=3,
+ padding=0,
+ conv_only=True,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.pad(x)
+ x = self.conv(x)
+ return x
+
+
+class AEKLResBlock(nn.Module):
+ """
+    Residual block consisting of a cascade of two (normalisation, activation, convolution) blocks, with a
+    residual connection between the input and the output.
+
+ Args:
+ spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
+ in_channels: input channels to the layer.
+ norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of
+ channels is divisible by this number.
+ norm_eps: epsilon for the normalisation.
+ out_channels: number of output channels.
+ """
+
+ def __init__(
+ self, spatial_dims: int, in_channels: int, norm_num_groups: int, norm_eps: float, out_channels: int
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+
+ self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)
+ self.conv1 = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=norm_eps, affine=True)
+ self.conv2 = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+
+ self.nin_shortcut: nn.Module
+ if self.in_channels != self.out_channels:
+ self.nin_shortcut = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+ else:
+ self.nin_shortcut = nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ h = x
+ h = self.norm1(h)
+ h = F.silu(h)
+ h = self.conv1(h)
+
+ h = self.norm2(h)
+ h = F.silu(h)
+ h = self.conv2(h)
+
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class Encoder(nn.Module):
+ """
+ Convolutional cascade that downsamples the image into a spatial latent space.
+
+ Args:
+ spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
+ in_channels: number of input channels.
+ channels: sequence of block output channels.
+ out_channels: number of channels in the bottom layer (latent space) of the autoencoder.
+ num_res_blocks: number of residual blocks (see _ResBlock) per level.
+ norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
+ norm_eps: epsilon for the normalization.
+ attention_levels: indicate which level from num_channels contain an attention block.
+ with_nonlocal_attn: if True use non-local attention block.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ channels: Sequence[int],
+ out_channels: int,
+ num_res_blocks: Sequence[int],
+ norm_num_groups: int,
+ norm_eps: float,
+ attention_levels: Sequence[bool],
+ with_nonlocal_attn: bool = True,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ self.spatial_dims = spatial_dims
+ self.in_channels = in_channels
+ self.channels = channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.norm_num_groups = norm_num_groups
+ self.norm_eps = norm_eps
+ self.attention_levels = attention_levels
+
+ blocks: List[nn.Module] = []
+ # Initial convolution
+ blocks.append(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=channels[0],
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ )
+
+ # Residual and downsampling blocks
+ output_channel = channels[0]
+ for i in range(len(channels)):
+ input_channel = output_channel
+ output_channel = channels[i]
+ is_final_block = i == len(channels) - 1
+
+ for _ in range(self.num_res_blocks[i]):
+ blocks.append(
+ AEKLResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=input_channel,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=output_channel,
+ )
+ )
+ input_channel = output_channel
+ if attention_levels[i]:
+ blocks.append(
+ SpatialAttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=input_channel,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+
+ if not is_final_block:
+ blocks.append(AEKLDownsample(spatial_dims=spatial_dims, in_channels=input_channel))
+ # Non-local attention block
+ if with_nonlocal_attn is True:
+ blocks.append(
+ AEKLResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=channels[-1],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=channels[-1],
+ )
+ )
+
+ blocks.append(
+ SpatialAttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=channels[-1],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+ blocks.append(
+ AEKLResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=channels[-1],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=channels[-1],
+ )
+ )
+ # Normalise and convert to latent size
+ blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[-1], eps=norm_eps, affine=True))
+ blocks.append(
+ Convolution(
+ spatial_dims=self.spatial_dims,
+ in_channels=channels[-1],
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ )
+
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for block in self.blocks:
+ x = block(x)
+ return x
+
+
+class Decoder(nn.Module):
+ """
+ Convolutional cascade upsampling from a spatial latent space into an image space.
+
+ Args:
+ spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
+ channels: sequence of block output channels.
+ in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
+ out_channels: number of output channels.
+ num_res_blocks: number of residual blocks (see _ResBlock) per level.
+ norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
+ norm_eps: epsilon for the normalization.
+ attention_levels: indicate which level from num_channels contain an attention block.
+ with_nonlocal_attn: if True use non-local attention block.
+ use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ channels: Sequence[int],
+ in_channels: int,
+ out_channels: int,
+ num_res_blocks: Sequence[int],
+ norm_num_groups: int,
+ norm_eps: float,
+ attention_levels: Sequence[bool],
+ with_nonlocal_attn: bool = True,
+ use_convtranspose: bool = False,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ self.spatial_dims = spatial_dims
+ self.channels = channels
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.norm_num_groups = norm_num_groups
+ self.norm_eps = norm_eps
+ self.attention_levels = attention_levels
+
+ reversed_block_out_channels = list(reversed(channels))
+
+ blocks: List[nn.Module] = []
+
+ # Initial convolution
+ blocks.append(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=reversed_block_out_channels[0],
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ )
+
+ # Non-local attention block
+ if with_nonlocal_attn is True:
+ blocks.append(
+ AEKLResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=reversed_block_out_channels[0],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=reversed_block_out_channels[0],
+ )
+ )
+ blocks.append(
+ SpatialAttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=reversed_block_out_channels[0],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+ blocks.append(
+ AEKLResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=reversed_block_out_channels[0],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=reversed_block_out_channels[0],
+ )
+ )
+
+ reversed_attention_levels = list(reversed(attention_levels))
+ reversed_num_res_blocks = list(reversed(num_res_blocks))
+ block_out_ch = reversed_block_out_channels[0]
+ for i in range(len(reversed_block_out_channels)):
+ block_in_ch = block_out_ch
+ block_out_ch = reversed_block_out_channels[i]
+ is_final_block = i == len(channels) - 1
+
+ for _ in range(reversed_num_res_blocks[i]):
+ blocks.append(
+ AEKLResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=block_in_ch,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=block_out_ch,
+ )
+ )
+ block_in_ch = block_out_ch
+
+ if reversed_attention_levels[i]:
+ blocks.append(
+ SpatialAttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=block_in_ch,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+
+ if not is_final_block:
+ if use_convtranspose:
+ blocks.append(
+ Upsample(
+ spatial_dims=spatial_dims, mode="deconv", in_channels=block_in_ch, out_channels=block_in_ch
+ )
+ )
+ else:
+ post_conv = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=block_in_ch,
+ out_channels=block_in_ch,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ blocks.append(
+ Upsample(
+ spatial_dims=spatial_dims,
+ mode="nontrainable",
+ in_channels=block_in_ch,
+ out_channels=block_in_ch,
+ interp_mode="nearest",
+ scale_factor=2.0,
+ post_conv=post_conv,
+ align_corners=None,
+ )
+ )
+
+ blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
+ blocks.append(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=block_in_ch,
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ )
+
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for block in self.blocks:
+ x = block(x)
+ return x
+
+
+class AutoencoderKL(nn.Module):
+ """
+ Autoencoder model with KL-regularized latent space based on
+ Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
+ and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162
+
+ Args:
+ spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
+ in_channels: number of input channels.
+ out_channels: number of output channels.
+ num_res_blocks: number of residual blocks (see _ResBlock) per level.
+ channels: number of output channels for each block.
+ attention_levels: sequence of levels to add attention.
+ latent_channels: latent embedding dimension.
+ norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
+ norm_eps: epsilon for the normalization.
+ with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
+ with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
+ use_checkpoint: if True, use activation checkpoint to save memory.
+ use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
+ include_fc: whether to include the final linear layer in the attention block. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+ channels: Sequence[int] = (32, 64, 64, 64),
+ attention_levels: Sequence[bool] = (False, False, True, True),
+ latent_channels: int = 3,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ with_encoder_nonlocal_attn: bool = True,
+ with_decoder_nonlocal_attn: bool = True,
+ use_checkpoint: bool = False,
+ use_convtranspose: bool = False,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+
+        # All channel counts must be multiples of norm_num_groups
+ if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
+ raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups")
+
+ if len(channels) != len(attention_levels):
+ raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels")
+
+ if isinstance(num_res_blocks, int):
+ num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))
+
+ if len(num_res_blocks) != len(channels):
+ raise ValueError(
+ "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
+ "`num_channels`."
+ )
+
+ self.encoder: nn.Module = Encoder(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ channels=channels,
+ out_channels=latent_channels,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ attention_levels=attention_levels,
+ with_nonlocal_attn=with_encoder_nonlocal_attn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ self.decoder: nn.Module = Decoder(
+ spatial_dims=spatial_dims,
+ channels=channels,
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ attention_levels=attention_levels,
+ with_nonlocal_attn=with_decoder_nonlocal_attn,
+ use_convtranspose=use_convtranspose,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ self.quant_conv_mu = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=latent_channels,
+ out_channels=latent_channels,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+ self.quant_conv_log_sigma = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=latent_channels,
+ out_channels=latent_channels,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+ self.post_quant_conv = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=latent_channels,
+ out_channels=latent_channels,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+ self.latent_channels = latent_channels
+ self.use_checkpoint = use_checkpoint
+
+ def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations.
+
+ Args:
+ x: BxCx[SPATIAL DIMS] tensor
+
+ """
+ if self.use_checkpoint:
+ h = torch.utils.checkpoint.checkpoint(self.encoder, x, use_reentrant=False)
+ else:
+ h = self.encoder(x)
+
+ z_mu = self.quant_conv_mu(h)
+ z_log_var = self.quant_conv_log_sigma(h)
+ z_log_var = torch.clamp(z_log_var, -30.0, 20.0)
+ z_sigma = torch.exp(z_log_var / 2)
+
+ return z_mu, z_sigma
+
+ def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:
+ """
+        From the mean and sigma representations produced by encoding an image into the latent space,
+        draws a latent sample by scaling standard Gaussian noise by sigma (the standard deviation) and
+        adding the mean (the reparameterization trick).
+
+ Args:
+ z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image
+            z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] standard deviation vector obtained by the encoder when you encode an image
+
+ Returns:
+ sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE]
+ """
+ eps = torch.randn_like(z_sigma)
+ z_vae = z_mu + eps * z_sigma
+ return z_vae
+
+ def reconstruct(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Encodes and decodes an input image.
+
+ Args:
+ x: BxCx[SPATIAL DIMENSIONS] tensor.
+
+ Returns:
+ reconstructed image, of the same shape as input
+ """
+ z_mu, _ = self.encode(x)
+ reconstruction = self.decode(z_mu)
+ return reconstruction
+
+ def decode(self, z: torch.Tensor) -> torch.Tensor:
+ """
+        Forwards a latent space sample through the decoder.
+
+ Args:
+ z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE]
+
+ Returns:
+ decoded image tensor
+ """
+ z = self.post_quant_conv(z)
+ dec: torch.Tensor
+ if self.use_checkpoint:
+ dec = torch.utils.checkpoint.checkpoint(self.decoder, z, use_reentrant=False)
+ else:
+ dec = self.decoder(z)
+ return dec
+
+ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ z_mu, z_sigma = self.encode(x)
+ z = self.sampling(z_mu, z_sigma)
+ reconstruction = self.decode(z)
+ return reconstruction, z_mu, z_sigma
+
+ def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
+ z_mu, z_sigma = self.encode(x)
+ z = self.sampling(z_mu, z_sigma)
+ return z
+
+ def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor:
+ image = self.decode(z)
+ return image
+
+ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
+ """
+ Load a state dict from an AutoencoderKL trained with [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).
+
+ Args:
+ old_state_dict: state dict from the old AutoencoderKL model.
+ """
+
+ new_state_dict = self.state_dict()
+ # if all keys match, just load the state dict
+ if all(k in new_state_dict for k in old_state_dict):
+ print("All keys match, loading state dict.")
+ self.load_state_dict(old_state_dict)
+ return
+
+ if verbose:
+ # print all new_state_dict keys that are not in old_state_dict
+ for k in new_state_dict:
+ if k not in old_state_dict:
+ print(f"key {k} not found in old state dict")
+ # and vice versa
+ print("----------------------------------------------")
+ for k in old_state_dict:
+ if k not in new_state_dict:
+ print(f"key {k} not found in new state dict")
+
+ # copy over all matching keys
+ for k in new_state_dict:
+ if k in old_state_dict:
+ new_state_dict[k] = old_state_dict.pop(k)
+
+ # fix the attention blocks
+ attention_blocks = [k.replace(".attn.to_q.weight", "") for k in new_state_dict if "attn.to_q.weight" in k]
+ for block in attention_blocks:
+ new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight")
+ new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight")
+ new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight")
+ new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias")
+ new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
+ new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
+
+ # old version did not have a projection so set these to the identity
+ new_state_dict[f"{block}.attn.out_proj.weight"] = torch.eye(
+ new_state_dict[f"{block}.attn.out_proj.weight"].shape[0]
+ )
+ new_state_dict[f"{block}.attn.out_proj.bias"] = torch.zeros(
+ new_state_dict[f"{block}.attn.out_proj.bias"].shape
+ )
+
+ # fix the upsample conv blocks which were renamed postconv
+ for k in new_state_dict:
+ if "postconv" in k:
+ old_name = k.replace("postconv", "conv")
+ new_state_dict[k] = old_state_dict.pop(old_name)
+ if verbose:
+ # print all remaining keys in old_state_dict
+ print("remaining keys in old_state_dict:", old_state_dict.keys())
+ self.load_state_dict(new_state_dict, strict=True)
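
A minimal usage sketch of the new AutoencoderKL (illustrative only; the channel, attention-level and input-size configuration below is an assumption chosen so that all constructor checks pass):

import torch

from monai.networks.nets import AutoencoderKL

model = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 64),
    attention_levels=(False, False, True),
    num_res_blocks=1,
    latent_channels=3,
    norm_num_groups=32,
)
x = torch.randn(1, 1, 64, 64)
reconstruction, z_mu, z_sigma = model(x)
z = model.encode_stage_2_inputs(x)  # latent sample, e.g. for a latent diffusion stage
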
diff --git a/monai/networks/nets/cell_sam_wrapper.py b/monai/networks/nets/cell_sam_wrapper.py
new file mode 100644
index 0000000000..308c3a6bcb
--- /dev/null
+++ b/monai/networks/nets/cell_sam_wrapper.py
@@ -0,0 +1,92 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from monai.utils import optional_import
+
+build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")
+
+__all__ = ["CellSamWrapper"]
+
+
+class CellSamWrapper(torch.nn.Module):
+ """
+    CellSamWrapper is a thin wrapper around the SAM model (https://github.com/facebookresearch/segment-anything)
+    with an image-only decoder, which can be used for segmentation tasks.
+
+
+ Args:
+ auto_resize_inputs: whether to resize inputs before passing to the network.
+            (usually they need to be resized, unless they are already at the expected size)
+ network_resize_roi: expected input size for the network.
+ (currently SAM expects 1024x1024)
+ checkpoint: checkpoint file to load the SAM weights from.
+ (this can be downloaded from SAM repo https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
+ return_features: whether to return features from SAM encoder
+ (without using decoder/upsampling to the original input size)
+
+ """
+
+ def __init__(
+ self,
+ auto_resize_inputs=True,
+ network_resize_roi=(1024, 1024),
+ checkpoint="sam_vit_b_01ec64.pth",
+ return_features=False,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(*args, **kwargs)
+
+ self.network_resize_roi = network_resize_roi
+ self.auto_resize_inputs = auto_resize_inputs
+ self.return_features = return_features
+
+ if not has_sam:
+ raise ValueError(
+ "SAM is not installed, please run: pip install git+https://github.com/facebookresearch/segment-anything.git"
+ )
+
+ model = build_sam_vit_b(checkpoint=checkpoint)
+
+ # the prompt encoder is not used; replace SAM's mask decoder with a small image-only decoder
+ model.prompt_encoder = None
+
+ model.mask_decoder = nn.Sequential(
+ nn.BatchNorm2d(num_features=256),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
+ nn.BatchNorm2d(num_features=128),
+ nn.ReLU(inplace=True),
+ nn.ConvTranspose2d(128, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
+ )
+
+ self.model = model
+
+ def forward(self, x):
+ sh = x.shape[2:]
+
+ if self.auto_resize_inputs:
+ x = F.interpolate(x, size=self.network_resize_roi, mode="bilinear")
+
+ x = self.model.image_encoder(x)
+
+ if not self.return_features:
+ x = self.model.mask_decoder(x)
+ if self.auto_resize_inputs:
+ x = F.interpolate(x, size=sh, mode="bilinear")
+
+ return x
diff --git a/monai/networks/nets/controlnet.py b/monai/networks/nets/controlnet.py
new file mode 100644
index 0000000000..8b8813597f
--- /dev/null
+++ b/monai/networks/nets/controlnet.py
@@ -0,0 +1,467 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/huggingface/diffusers
+# which has the following license:
+# https://github.com/huggingface/diffusers/blob/main/LICENSE
+#
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+from torch import nn
+
+from monai.networks.blocks import Convolution
+from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding
+from monai.utils import ensure_tuple_rep
+
+
+class ControlNetConditioningEmbedding(nn.Module):
+ """
+ Network to encode the conditioning into a latent space.
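+
+ With ``len(channels)`` channel entries, the conditioning is downsampled by a factor of
+ ``2 ** (len(channels) - 1)``, so its spatial size is expected to be that multiple of the
+ spatial size of the noise input it is added to.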
+ """
+
+ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int]):
+ super().__init__()
+
+ self.conv_in = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=channels[0],
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ adn_ordering="A",
+ act="SWISH",
+ )
+
+ self.blocks = nn.ModuleList([])
+
+ for i in range(len(channels) - 1):
+ channel_in = channels[i]
+ channel_out = channels[i + 1]
+ self.blocks.append(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=channel_in,
+ out_channels=channel_in,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ adn_ordering="A",
+ act="SWISH",
+ )
+ )
+
+ self.blocks.append(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=channel_in,
+ out_channels=channel_out,
+ strides=2,
+ kernel_size=3,
+ padding=1,
+ adn_ordering="A",
+ act="SWISH",
+ )
+ )
+
+ self.conv_out = zero_module(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=channels[-1],
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ )
+
+ def forward(self, conditioning):
+ embedding = self.conv_in(conditioning)
+
+ for block in self.blocks:
+ embedding = block(embedding)
+
+ embedding = self.conv_out(embedding)
+
+ return embedding
+
+
+def zero_module(module: nn.Module) -> nn.Module:
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
+
+
+class ControlNet(nn.Module):
+ """
+ Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
+ Diffusion Models" (https://arxiv.org/abs/2302.05543)
+
+ Args:
+ spatial_dims: number of spatial dimensions.
+ in_channels: number of input channels.
+ num_res_blocks: number of residual blocks (see ResnetBlock) per level.
+ channels: tuple of block output channels.
+ attention_levels: list of levels to add attention.
+ norm_num_groups: number of groups for the normalization.
+ norm_eps: epsilon for the normalization.
+ resblock_updown: if True use residual blocks for up/downsampling.
+ num_head_channels: number of channels in each attention head.
+ with_conditioning: if True add spatial transformers to perform conditioning.
+ transformer_num_layers: number of layers of Transformer blocks to use.
+ cross_attention_dim: number of context dimensions to use.
+ num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
+ classes.
+ upcast_attention: if True, upcast attention operations to full precision.
+ conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
+ conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
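+
+ Example (a minimal shape sketch; the hyper-parameters below are illustrative
+ placeholders, not recommended settings):
+
+ .. code-block:: python
+
+ net = ControlNet(
+ spatial_dims=2,
+ in_channels=1,
+ conditioning_embedding_num_channels=(16,),
+ )
+ x = torch.randn(1, 1, 32, 32)
+ t = torch.randint(0, 1000, (1,)).long()
+ cond = torch.randn(1, 1, 32, 32)
+ down_samples, mid_sample = net(x, t, cond)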
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+ channels: Sequence[int] = (32, 64, 64, 64),
+ attention_levels: Sequence[bool] = (False, False, True, True),
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ resblock_updown: bool = False,
+ num_head_channels: int | Sequence[int] = 8,
+ with_conditioning: bool = False,
+ transformer_num_layers: int = 1,
+ cross_attention_dim: int | None = None,
+ num_class_embeds: int | None = None,
+ upcast_attention: bool = False,
+ conditioning_embedding_in_channels: int = 1,
+ conditioning_embedding_num_channels: Sequence[int] = (16, 32, 96, 256),
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ if with_conditioning is True and cross_attention_dim is None:
+ raise ValueError(
+ "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
+ "to be specified when with_conditioning=True."
+ )
+ if cross_attention_dim is not None and with_conditioning is False:
+ raise ValueError("ControlNet expects with_conditioning=True when specifying the cross_attention_dim.")
+
+ # All channel counts must be a multiple of norm_num_groups
+ if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
+ raise ValueError(
+ f"ControlNet expects all channels to be a multiple of norm_num_groups, but got"
+ f" channels={channels} and norm_num_groups={norm_num_groups}"
+ )
+
+ if len(channels) != len(attention_levels):
+ raise ValueError(
+ f"ControlNet expects channels to have the same length as attention_levels, but got "
+ f"channels={channels} and attention_levels={attention_levels}"
+ )
+
+ if isinstance(num_head_channels, int):
+ num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
+
+ if len(num_head_channels) != len(attention_levels):
+ raise ValueError(
+ f"num_head_channels should have the same length as attention_levels, but got "
+ f"num_head_channels={num_head_channels} and attention_levels={attention_levels}."
+ " For levels without attention, i.e. `attention_levels[i]=False`, num_head_channels[i] is ignored."
+ )
+
+ if isinstance(num_res_blocks, int):
+ num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))
+
+ if len(num_res_blocks) != len(channels):
+ raise ValueError(
+ f"`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
+ f"`channels`, but got num_res_blocks={num_res_blocks} and channels={channels}."
+ )
+
+ self.in_channels = in_channels
+ self.block_out_channels = channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_levels = attention_levels
+ self.num_head_channels = num_head_channels
+ self.with_conditioning = with_conditioning
+
+ # input
+ self.conv_in = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=channels[0],
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+
+ # time
+ time_embed_dim = channels[0] * 4
+ self.time_embed = nn.Sequential(
+ nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
+ )
+
+ # class embedding
+ self.num_class_embeds = num_class_embeds
+ if num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
+ # control net conditioning embedding
+ self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+ spatial_dims=spatial_dims,
+ in_channels=conditioning_embedding_in_channels,
+ channels=conditioning_embedding_num_channels,
+ out_channels=channels[0],
+ )
+
+ # down
+ self.down_blocks = nn.ModuleList([])
+ self.controlnet_down_blocks = nn.ModuleList([])
+ output_channel = channels[0]
+
+ controlnet_block = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=output_channel,
+ out_channels=output_channel,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+ controlnet_block = zero_module(controlnet_block.conv)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ for i in range(len(channels)):
+ input_channel = output_channel
+ output_channel = channels[i]
+ is_final_block = i == len(channels) - 1
+
+ down_block = get_down_block(
+ spatial_dims=spatial_dims,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ num_res_blocks=num_res_blocks[i],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_downsample=not is_final_block,
+ resblock_updown=resblock_updown,
+ with_attn=(attention_levels[i] and not with_conditioning),
+ with_cross_attn=(attention_levels[i] and with_conditioning),
+ num_head_channels=num_head_channels[i],
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+
+ self.down_blocks.append(down_block)
+
+ for _ in range(num_res_blocks[i]):
+ controlnet_block = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=output_channel,
+ out_channels=output_channel,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ if not is_final_block:
+ controlnet_block = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=output_channel,
+ out_channels=output_channel,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_down_blocks.append(controlnet_block)
+
+ # mid
+ mid_block_channel = channels[-1]
+
+ self.middle_block = get_mid_block(
+ spatial_dims=spatial_dims,
+ in_channels=mid_block_channel,
+ temb_channels=time_embed_dim,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ with_conditioning=with_conditioning,
+ num_head_channels=num_head_channels[-1],
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+
+ controlnet_block = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=output_channel,
+ out_channels=output_channel,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_mid_block = controlnet_block
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ timesteps: torch.Tensor,
+ controlnet_cond: torch.Tensor,
+ conditioning_scale: float = 1.0,
+ context: torch.Tensor | None = None,
+ class_labels: torch.Tensor | None = None,
+ ) -> tuple[list[torch.Tensor], torch.Tensor]:
+ """
+ Args:
+ x: input tensor (N, C, H, W, [D]).
+ timesteps: timestep tensor (N,).
+ controlnet_cond: controlnet conditioning tensor (N, C, H, W, [D]).
+ conditioning_scale: conditioning scale.
+ context: context tensor (N, 1, cross_attention_dim), where cross_attention_dim is specified in the model init.
+ class_labels: class labels tensor (N,).
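+
+ Returns:
+ a tuple of the down block residual samples (each scaled by ``conditioning_scale``)
+ and the mid block residual sample.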
+ """
+ # 1. time
+ t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
+
+ # `get_timestep_embedding` always returns float32 tensors, but `time_embed` may be
+ # running in fp16, so cast the embedding to the input dtype before projecting it.
+ t_emb = t_emb.to(dtype=x.dtype)
+ emb = self.time_embed(t_emb)
+
+ # 2. class
+ if self.num_class_embeds is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+ class_emb = self.class_embedding(class_labels)
+ class_emb = class_emb.to(dtype=x.dtype)
+ emb = emb + class_emb
+
+ # 3. initial convolution
+ h = self.conv_in(x)
+
+ controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+
+ h += controlnet_cond
+
+ # 4. down
+ if context is not None and self.with_conditioning is False:
+ raise ValueError("model should have with_conditioning = True if context is provided")
+ down_block_res_samples: list[torch.Tensor] = [h]
+ for downsample_block in self.down_blocks:
+ h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
+ for residual in res_samples:
+ down_block_res_samples.append(residual)
+
+ # 5. mid
+ h = self.middle_block(hidden_states=h, temb=emb, context=context)
+
+ # 6. Control net blocks
+ controlnet_down_block_res_samples = []
+
+ for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
+ down_block_res_sample = controlnet_block(down_block_res_sample)
+ controlnet_down_block_res_samples.append(down_block_res_sample)
+
+ down_block_res_samples = controlnet_down_block_res_samples
+
+ mid_block_res_sample: torch.Tensor = self.controlnet_mid_block(h)
+
+ # 7. scaling
+ down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples]
+ mid_block_res_sample *= conditioning_scale
+
+ return down_block_res_samples, mid_block_res_sample
+
+ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
+ """
+ Load a state dict from a ControlNet trained with
+ [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).
+
+ Args:
+ old_state_dict: state dict from the old ControlNet model.
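+ verbose: if True, print keys that are not matched between the two state dicts.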
+ """
+
+ new_state_dict = self.state_dict()
+ # if all keys match, just load the state dict
+ if all(k in new_state_dict for k in old_state_dict):
+ print("All keys match, loading state dict.")
+ self.load_state_dict(old_state_dict)
+ return
+
+ if verbose:
+ # print all new_state_dict keys that are not in old_state_dict
+ for k in new_state_dict:
+ if k not in old_state_dict:
+ print(f"key {k} not found in old state dict")
+ # and vice versa
+ print("----------------------------------------------")
+ for k in old_state_dict:
+ if k not in new_state_dict:
+ print(f"key {k} not found in new state dict")
+
+ # copy over all matching keys
+ for k in new_state_dict:
+ if k in old_state_dict:
+ new_state_dict[k] = old_state_dict.pop(k)
+
+ # fix the attention blocks
+ attention_blocks = [k.replace(".out_proj.weight", "") for k in new_state_dict if "out_proj.weight" in k]
+ for block in attention_blocks:
+ # projection
+ new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight")
+ new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias")
+
+ if verbose:
+ # print all remaining keys in old_state_dict
+ print("remaining keys in old_state_dict:", old_state_dict.keys())
+ self.load_state_dict(new_state_dict)
diff --git a/monai/networks/nets/daf3d.py b/monai/networks/nets/daf3d.py
index c9a18c746a..02e5bb022a 100644
--- a/monai/networks/nets/daf3d.py
+++ b/monai/networks/nets/daf3d.py
@@ -13,6 +13,7 @@
from collections import OrderedDict
from collections.abc import Callable, Sequence
+from functools import partial
import torch
import torch.nn as nn
@@ -25,6 +26,7 @@
from monai.networks.blocks.convolutions import Convolution
from monai.networks.blocks.feature_pyramid_network import ExtraFPNBlock, FeaturePyramidNetwork
from monai.networks.layers.factories import Conv, Norm
+from monai.networks.layers.utils import get_norm_layer
from monai.networks.nets.resnet import ResNet, ResNetBottleneck
__all__ = [
@@ -170,33 +172,37 @@ class Daf3dResNetBottleneck(ResNetBottleneck):
spatial_dims: number of spatial dimensions of the input image.
stride: stride to use for second conv layer.
downsample: which downsample layer to use.
+ norm: which normalization layer to use. Defaults to group.
"""
expansion = 2
- def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None):
- norm_type: Callable = Norm[Norm.GROUP, spatial_dims]
+ def __init__(
+ self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=("group", {"num_groups": 32})
+ ):
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
+ norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)
+
# in case downsample uses batch norm, change to group norm
if isinstance(downsample, nn.Sequential):
downsample = nn.Sequential(
conv_type(in_planes, planes * self.expansion, kernel_size=1, stride=stride, bias=False),
- norm_type(num_groups=32, num_channels=planes * self.expansion),
+ norm_layer(channels=planes * self.expansion),
)
super().__init__(in_planes, planes, spatial_dims, stride, downsample)
# change norm from batch to group norm
- self.bn1 = norm_type(num_groups=32, num_channels=planes)
- self.bn2 = norm_type(num_groups=32, num_channels=planes)
- self.bn3 = norm_type(num_groups=32, num_channels=planes * self.expansion)
+ self.bn1 = norm_layer(channels=planes)
+ self.bn2 = norm_layer(channels=planes)
+ self.bn3 = norm_layer(channels=planes * self.expansion)
# adapt second convolution to work with groups
self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, stride=stride, groups=32, bias=False)
# adapt activation function
- self.relu = nn.PReLU() # type: ignore
+ self.relu = nn.PReLU()
class Daf3dResNetDilatedBottleneck(Daf3dResNetBottleneck):
@@ -212,8 +218,10 @@ class Daf3dResNetDilatedBottleneck(Daf3dResNetBottleneck):
downsample: which downsample layer to use.
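+ norm: which normalization layer to use. Defaults to group.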
"""
- def __init__(self, in_planes, planes, spatial_dims=3, stride=1, downsample=None):
- super().__init__(in_planes, planes, spatial_dims, stride, downsample)
+ def __init__(
+ self, in_planes, planes, spatial_dims=3, stride=1, downsample=None, norm=("group", {"num_groups": 32})
+ ):
+ super().__init__(in_planes, planes, spatial_dims, stride, downsample, norm)
# add dilation in second convolution
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
@@ -287,7 +295,7 @@ def __init__(
n_input_channels, self.in_planes, kernel_size=7, stride=(1, 2, 2), padding=(3, 3, 3), bias=False
)
self.bn1 = norm_type(32, 64)
- self.relu = nn.PReLU() # type: ignore
+ self.relu = nn.PReLU()
# adapt layers to our needs
self.layer1 = self._make_layer(Daf3dResNetBottleneck, block_inplanes[0], layers[0], spatial_dims, shortcut_type)
diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py
new file mode 100644
index 0000000000..65d6053acc
--- /dev/null
+++ b/monai/networks/nets/diffusion_model_unet.py
@@ -0,0 +1,2053 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/huggingface/diffusers
+# which has the following license:
+# https://github.com/huggingface/diffusers/blob/main/LICENSE
+#
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from __future__ import annotations
+
+import math
+from collections.abc import Sequence
+
+import torch
+from torch import nn
+
+from monai.networks.blocks import Convolution, CrossAttentionBlock, MLPBlock, SABlock, SpatialAttentionBlock, Upsample
+from monai.networks.layers.factories import Pool
+from monai.utils import ensure_tuple_rep, optional_import
+
+Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
+
+__all__ = ["DiffusionModelUNet"]
+
+
+def zero_module(module: nn.Module) -> nn.Module:
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+class DiffusionUNetTransformerBlock(nn.Module):
+ """
+ A Transformer block that allows for the input dimension to differ from the hidden dimension.
+
+ Args:
+ num_channels: number of channels in the input and output.
+ num_attention_heads: number of heads to use for multi-head attention.
+ num_head_channels: number of channels in each attention head.
+ dropout: dropout probability to use.
+ cross_attention_dim: size of the context vector for cross attention.
+ upcast_attention: if True, upcast attention operations to full precision.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+
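+ The forward pass expects ``x`` of shape ``(B, sequence_length, num_channels)`` and an
+ optional ``context`` of shape ``(B, context_length, cross_attention_dim)``; if no
+ context is given, the cross-attention layer falls back to self-attention.
+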
+ """
+
+ def __init__(
+ self,
+ num_channels: int,
+ num_attention_heads: int,
+ num_head_channels: int,
+ dropout: float = 0.0,
+ cross_attention_dim: int | None = None,
+ upcast_attention: bool = False,
+ use_flash_attention: bool = False,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ ) -> None:
+ super().__init__()
+ self.attn1 = SABlock(
+ hidden_size=num_attention_heads * num_head_channels,
+ hidden_input_size=num_channels,
+ num_heads=num_attention_heads,
+ dim_head=num_head_channels,
+ dropout_rate=dropout,
+ attention_dtype=torch.float if upcast_attention else None,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ self.ff = MLPBlock(hidden_size=num_channels, mlp_dim=num_channels * 4, act="GEGLU", dropout_rate=dropout)
+ self.attn2 = CrossAttentionBlock(
+ hidden_size=num_attention_heads * num_head_channels,
+ num_heads=num_attention_heads,
+ hidden_input_size=num_channels,
+ context_input_size=cross_attention_dim,
+ dim_head=num_head_channels,
+ dropout_rate=dropout,
+ attention_dtype=torch.float if upcast_attention else None,
+ use_flash_attention=use_flash_attention,
+ )
+ self.norm1 = nn.LayerNorm(num_channels)
+ self.norm2 = nn.LayerNorm(num_channels)
+ self.norm3 = nn.LayerNorm(num_channels)
+
+ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
+ # 1. Self-Attention
+ x = self.attn1(self.norm1(x)) + x
+
+ # 2. Cross-Attention
+ x = self.attn2(self.norm2(x), context=context) + x
+
+ # 3. Feed-forward
+ x = self.ff(self.norm3(x)) + x
+ return x
+
+
+class SpatialTransformer(nn.Module):
+ """
+ Transformer block for image-like data. First, project the input (aka embedding) and reshape to (b, t, d). Then apply
+ the standard transformer blocks. Finally, reshape back to an image.
+
+ Args:
+ spatial_dims: number of spatial dimensions.
+ in_channels: number of channels in the input and output.
+ num_attention_heads: number of heads to use for multi-head attention.
+ num_head_channels: number of channels in each attention head.
+ num_layers: number of layers of Transformer blocks to use.
+ dropout: dropout probability to use.
+ norm_num_groups: number of groups for the normalization.
+ norm_eps: epsilon for the normalization.
+ cross_attention_dim: number of context dimensions to use.
+ upcast_attention: if True, upcast attention operations to full precision.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+
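+ For example, with ``spatial_dims=2`` an input of shape ``(B, C, H, W)`` is projected to
+ ``num_attention_heads * num_head_channels`` channels, flattened to ``(B, H*W, inner_dim)``
+ for the transformer blocks, and reshaped back to ``(B, C, H, W)`` before the residual
+ connection is added.
+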
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ num_attention_heads: int,
+ num_head_channels: int,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ cross_attention_dim: int | None = None,
+ upcast_attention: bool = False,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ self.spatial_dims = spatial_dims
+ self.in_channels = in_channels
+ inner_dim = num_attention_heads * num_head_channels
+
+ self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)
+
+ self.proj_in = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=inner_dim,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+
+ self.transformer_blocks = nn.ModuleList(
+ [
+ DiffusionUNetTransformerBlock(
+ num_channels=inner_dim,
+ num_attention_heads=num_attention_heads,
+ num_head_channels=num_head_channels,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ self.proj_out = zero_module(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=inner_dim,
+ out_channels=in_channels,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+ )
+
+ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
+ # note: if no context is given, cross-attention defaults to self-attention
+ batch = channel = height = width = depth = -1
+ if self.spatial_dims == 2:
+ batch, channel, height, width = x.shape
+ if self.spatial_dims == 3:
+ batch, channel, height, width, depth = x.shape
+
+ residual = x
+ x = self.norm(x)
+ x = self.proj_in(x)
+
+ inner_dim = x.shape[1]
+
+ if self.spatial_dims == 2:
+ x = x.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ if self.spatial_dims == 3:
+ x = x.permute(0, 2, 3, 4, 1).reshape(batch, height * width * depth, inner_dim)
+
+ for block in self.transformer_blocks:
+ x = block(x, context=context)
+
+ if self.spatial_dims == 2:
+ x = x.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+ if self.spatial_dims == 3:
+ x = x.reshape(batch, height, width, depth, inner_dim).permute(0, 4, 1, 2, 3).contiguous()
+
+ x = self.proj_out(x)
+ return x + residual
+
+
+def get_timestep_embedding(timesteps: torch.Tensor, embedding_dim: int, max_period: int = 10000) -> torch.Tensor:
+ """
+ Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic
+ Models" https://arxiv.org/abs/2006.11239.
+
+ Args:
+ timesteps: a 1-D Tensor of N indices, one per batch element.
+ embedding_dim: the dimension of the output.
+ max_period: controls the minimum frequency of the embeddings.
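+
+ For example, ``get_timestep_embedding(torch.tensor([0, 1]), 4)`` returns a ``(2, 4)``
+ tensor whose first row is ``[1., 1., 0., 0.]`` (cosines followed by sines).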
+ """
+ if timesteps.ndim != 1:
+ raise ValueError("Timesteps should be a 1d-array")
+
+ half_dim = embedding_dim // 2
+ exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
+ freqs = torch.exp(exponent / half_dim)
+
+ args = timesteps[:, None].float() * freqs[None, :]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+
+ # zero pad
+ if embedding_dim % 2 == 1:
+ embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))
+
+ return embedding
+
+
+class DiffusionUnetDownsample(nn.Module):
+ """
+ Downsampling layer.
+
+ Args:
+ spatial_dims: number of spatial dimensions.
+ num_channels: number of input channels.
+ use_conv: if True, uses a strided Convolution instead of average pooling to perform downsampling. If use_conv is
+ False, the number of output channels must be the same as the number of input channels.
+ out_channels: number of output channels.
+ padding: controls the amount of implicit zero-paddings on both sides for padding number of points
+ for each dimension.
+ """
+
+ def __init__(
+ self, spatial_dims: int, num_channels: int, use_conv: bool, out_channels: int | None = None, padding: int = 1
+ ) -> None:
+ super().__init__()
+ self.num_channels = num_channels
+ self.out_channels = out_channels or num_channels
+ self.use_conv = use_conv
+ if use_conv:
+ self.op = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=self.num_channels,
+ out_channels=self.out_channels,
+ strides=2,
+ kernel_size=3,
+ padding=padding,
+ conv_only=True,
+ )
+ else:
+ if self.num_channels != self.out_channels:
+ raise ValueError("num_channels and out_channels must be equal when use_conv=False")
+ self.op = Pool[Pool.AVG, spatial_dims](kernel_size=2, stride=2)
+
+ def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
+ del emb
+ if x.shape[1] != self.num_channels:
+ raise ValueError(
+ f"Input number of channels ({x.shape[1]}) is not equal to expected number of channels "
+ f"({self.num_channels})"
+ )
+ output: torch.Tensor = self.op(x)
+ return output
+
+
+class WrappedUpsample(Upsample):
+ """
+ Wraps MONAI upsample block to allow for calling with timestep embeddings.
+ """
+
+ def forward(self, x: torch.Tensor, emb: torch.Tensor | None = None) -> torch.Tensor:
+ del emb
+ upsampled: torch.Tensor = super().forward(x)
+ return upsampled
+
+
+class DiffusionUNetResnetBlock(nn.Module):
+ """
+ Residual block with timestep conditioning.
+
+ Args:
+ spatial_dims: The number of spatial dimensions.
+ in_channels: number of input channels.
+ temb_channels: number of timestep embedding channels.
+ out_channels: number of output channels.
+ up: if True, performs upsampling.
+ down: if True, performs downsampling.
+ norm_num_groups: number of groups for the group normalization.
+ norm_eps: epsilon for the group normalization.
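+
+ The timestep embedding ``emb`` of shape ``(B, temb_channels)`` is projected to
+ ``out_channels`` and broadcast-added to the feature map after the first convolution.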
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ temb_channels: int,
+ out_channels: int | None = None,
+ up: bool = False,
+ down: bool = False,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ ) -> None:
+ super().__init__()
+ self.spatial_dims = spatial_dims
+ self.channels = in_channels
+ self.emb_channels = temb_channels
+ self.out_channels = out_channels or in_channels
+ self.up = up
+ self.down = down
+
+ self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=norm_eps, affine=True)
+ self.nonlinearity = nn.SiLU()
+ self.conv1 = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+
+ self.upsample = self.downsample = None
+ if self.up:
+ self.upsample = WrappedUpsample(
+ spatial_dims=spatial_dims,
+ mode="nontrainable",
+ in_channels=in_channels,
+ out_channels=in_channels,
+ interp_mode="nearest",
+ scale_factor=2.0,
+ align_corners=None,
+ )
+ elif down:
+ self.downsample = DiffusionUnetDownsample(spatial_dims, in_channels, use_conv=False)
+
+ self.time_emb_proj = nn.Linear(temb_channels, self.out_channels)
+
+ self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=self.out_channels, eps=norm_eps, affine=True)
+ self.conv2 = zero_module(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ )
+ self.skip_connection: nn.Module
+ if self.out_channels == in_channels:
+ self.skip_connection = nn.Identity()
+ else:
+ self.skip_connection = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+
+ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
+ h = x
+ h = self.norm1(h)
+ h = self.nonlinearity(h)
+
+ if self.upsample is not None:
+ x = self.upsample(x)
+ h = self.upsample(h)
+ elif self.downsample is not None:
+ x = self.downsample(x)
+ h = self.downsample(h)
+
+ h = self.conv1(h)
+
+ if self.spatial_dims == 2:
+ temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None]
+ else:
+ temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None]
+ h = h + temb
+
+ h = self.norm2(h)
+ h = self.nonlinearity(h)
+ h = self.conv2(h)
+ output: torch.Tensor = self.skip_connection(x) + h
+ return output
+
+
+class DownBlock(nn.Module):
+ """
+ Unet's down block containing resnet and downsampler blocks.
+
+ Args:
+ spatial_dims: The number of spatial dimensions.
+ in_channels: number of input channels.
+ out_channels: number of output channels.
+ temb_channels: number of timestep embedding channels.
+ num_res_blocks: number of residual blocks.
+ norm_num_groups: number of groups for the group normalization.
+ norm_eps: epsilon for the group normalization.
+ add_downsample: if True add downsample block.
+ resblock_updown: if True use residual blocks for downsampling.
+ downsample_padding: padding used in the downsampling block.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ num_res_blocks: int = 1,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ add_downsample: bool = True,
+ resblock_updown: bool = False,
+ downsample_padding: int = 1,
+ ) -> None:
+ super().__init__()
+ self.resblock_updown = resblock_updown
+
+ resnets = []
+
+ for i in range(num_res_blocks):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsampler: nn.Module | None
+ if resblock_updown:
+ self.downsampler = DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ down=True,
+ )
+ else:
+ self.downsampler = DiffusionUnetDownsample(
+ spatial_dims=spatial_dims,
+ num_channels=out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ )
+ else:
+ self.downsampler = None
+
+ def forward(
+ self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+ del context
+ output_states = []
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb)
+ output_states.append(hidden_states)
+
+ if self.downsampler is not None:
+ hidden_states = self.downsampler(hidden_states, temb)
+ output_states.append(hidden_states)
+
+ return hidden_states, output_states
+
+
+class AttnDownBlock(nn.Module):
+ """
+ Unet's down block containing resnet, downsamplers and self-attention blocks.
+
+ Args:
+ spatial_dims: The number of spatial dimensions.
+ in_channels: number of input channels.
+ out_channels: number of output channels.
+ temb_channels: number of timestep embedding channels.
+ num_res_blocks: number of residual blocks.
+ norm_num_groups: number of groups for the group normalization.
+ norm_eps: epsilon for the group normalization.
+ add_downsample: if True add downsample block.
+ resblock_updown: if True use residual blocks for downsampling.
+ downsample_padding: padding used in the downsampling block.
+ num_head_channels: number of channels in each attention head.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ num_res_blocks: int = 1,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ add_downsample: bool = True,
+ resblock_updown: bool = False,
+ downsample_padding: int = 1,
+ num_head_channels: int = 1,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ self.resblock_updown = resblock_updown
+
+ resnets = []
+ attentions = []
+
+ for i in range(num_res_blocks):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ )
+ )
+ attentions.append(
+ SpatialAttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=out_channels,
+ num_head_channels=num_head_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.downsampler: nn.Module | None
+ if add_downsample:
+ if resblock_updown:
+ self.downsampler = DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ down=True,
+ )
+ else:
+ self.downsampler = DiffusionUnetDownsample(
+ spatial_dims=spatial_dims,
+ num_channels=out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ )
+ else:
+ self.downsampler = None
+
+ def forward(
+ self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+ del context
+ output_states = []
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states).contiguous()
+ output_states.append(hidden_states)
+
+ if self.downsampler is not None:
+ hidden_states = self.downsampler(hidden_states, temb)
+ output_states.append(hidden_states)
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlock(nn.Module):
+ """
+ Unet's down block containing resnet, downsamplers and cross-attention blocks.
+
+ Args:
+ spatial_dims: number of spatial dimensions.
+ in_channels: number of input channels.
+ out_channels: number of output channels.
+ temb_channels: number of timestep embedding channels.
+ num_res_blocks: number of residual blocks.
+ norm_num_groups: number of groups for the group normalization.
+ norm_eps: epsilon for the group normalization.
+ add_downsample: if True add downsample block.
+ resblock_updown: if True use residual blocks for downsampling.
+ downsample_padding: padding used in the downsampling block.
+ num_head_channels: number of channels in each attention head.
+ transformer_num_layers: number of layers of Transformer blocks to use.
+ cross_attention_dim: number of context dimensions to use.
+ upcast_attention: if True, upcast attention operations to full precision.
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ num_res_blocks: int = 1,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ add_downsample: bool = True,
+ resblock_updown: bool = False,
+ downsample_padding: int = 1,
+ num_head_channels: int = 1,
+ transformer_num_layers: int = 1,
+ cross_attention_dim: int | None = None,
+ upcast_attention: bool = False,
+ dropout_cattn: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ self.resblock_updown = resblock_updown
+
+ resnets = []
+ attentions = []
+
+ for i in range(num_res_blocks):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ )
+ )
+
+ attentions.append(
+ SpatialTransformer(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ num_attention_heads=out_channels // num_head_channels,
+ num_head_channels=num_head_channels,
+ num_layers=transformer_num_layers,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ dropout=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.downsampler: nn.Module | None
+ if add_downsample:
+ if resblock_updown:
+ self.downsampler = DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ down=True,
+ )
+ else:
+ self.downsampler = DiffusionUnetDownsample(
+ spatial_dims=spatial_dims,
+ num_channels=out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ )
+ else:
+ self.downsampler = None
+
+ def forward(
+ self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
+ ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+ output_states = []
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, context=context).contiguous()
+ output_states.append(hidden_states)
+
+ if self.downsampler is not None:
+ hidden_states = self.downsampler(hidden_states, temb)
+ output_states.append(hidden_states)
+
+ return hidden_states, output_states
+
+
+class AttnMidBlock(nn.Module):
+ """
+ Unet's mid block containing resnet and self-attention blocks.
+
+ Args:
+ spatial_dims: The number of spatial dimensions.
+ in_channels: number of input channels.
+ temb_channels: number of timestep embedding channels.
+ norm_num_groups: number of groups for the group normalization.
+ norm_eps: epsilon for the group normalization.
+ num_head_channels: number of channels in each attention head.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ temb_channels: int,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ num_head_channels: int = 1,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.resnet_1 = DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ )
+ self.attention = SpatialAttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=in_channels,
+ num_head_channels=num_head_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+
+ self.resnet_2 = DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ )
+
+ def forward(
+ self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
+ ) -> torch.Tensor:
+ del context
+ hidden_states = self.resnet_1(hidden_states, temb)
+ hidden_states = self.attention(hidden_states).contiguous()
+ hidden_states = self.resnet_2(hidden_states, temb)
+
+ return hidden_states
+
+
+class CrossAttnMidBlock(nn.Module):
+ """
+ Unet's mid block containing resnet and cross-attention blocks.
+
+ Args:
+ spatial_dims: The number of spatial dimensions.
+ in_channels: number of input channels.
+ temb_channels: number of timestep embedding channels.
+ norm_num_groups: number of groups for the group normalization.
+ norm_eps: epsilon for the group normalization.
+ num_head_channels: number of channels in each attention head.
+ transformer_num_layers: number of layers of Transformer blocks to use.
+ cross_attention_dim: number of context dimensions to use.
+ upcast_attention: if True, upcast attention operations to full precision.
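+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.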
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ temb_channels: int,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ num_head_channels: int = 1,
+ transformer_num_layers: int = 1,
+ cross_attention_dim: int | None = None,
+ upcast_attention: bool = False,
+ dropout_cattn: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+
+ self.resnet_1 = DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ )
+ self.attention = SpatialTransformer(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ num_attention_heads=in_channels // num_head_channels,
+ num_head_channels=num_head_channels,
+ num_layers=transformer_num_layers,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ dropout=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ self.resnet_2 = DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ )
+
+ def forward(
+ self, hidden_states: torch.Tensor, temb: torch.Tensor, context: torch.Tensor | None = None
+ ) -> torch.Tensor:
+ hidden_states = self.resnet_1(hidden_states, temb)
+ hidden_states = self.attention(hidden_states, context=context)
+ hidden_states = self.resnet_2(hidden_states, temb)
+
+ return hidden_states
+
+
+class UpBlock(nn.Module):
+ """
+ Unet's up block containing resnet and upsampler blocks.
+
+ Args:
+ spatial_dims: The number of spatial dimensions.
+ in_channels: number of input channels.
+ prev_output_channel: number of channels from residual connection.
+ out_channels: number of output channels.
+ temb_channels: number of timestep embedding channels.
+ num_res_blocks: number of residual blocks.
+ norm_num_groups: number of groups for the group normalization.
+ norm_eps: epsilon for the group normalization.
+ add_upsample: if True add upsample block.
+ resblock_updown: if True use residual blocks for upsampling.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ num_res_blocks: int = 1,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ add_upsample: bool = True,
+ resblock_updown: bool = False,
+ ) -> None:
+ super().__init__()
+ self.resblock_updown = resblock_updown
+ resnets = []
+
+ for i in range(num_res_blocks):
+ res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ self.upsampler: nn.Module | None
+ if add_upsample:
+ if resblock_updown:
+ self.upsampler = DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ up=True,
+ )
+ else:
+ post_conv = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ self.upsampler = WrappedUpsample(
+ spatial_dims=spatial_dims,
+ mode="nontrainable",
+ in_channels=out_channels,
+ out_channels=out_channels,
+ interp_mode="nearest",
+ scale_factor=2.0,
+ post_conv=post_conv,
+ align_corners=None,
+ )
+
+ else:
+ self.upsampler = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ res_hidden_states_list: list[torch.Tensor],
+ temb: torch.Tensor,
+ context: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ del context
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_list[-1]
+ res_hidden_states_list = res_hidden_states_list[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsampler is not None:
+ hidden_states = self.upsampler(hidden_states, temb)
+
+ return hidden_states
+
+
+class AttnUpBlock(nn.Module):
+ """
+ Unet's up block containing resnet, upsamplers, and self-attention blocks.
+
+ Args:
+ spatial_dims: The number of spatial dimensions.
+ in_channels: number of input channels.
+ prev_output_channel: number of channels from residual connection.
+ out_channels: number of output channels.
+ temb_channels: number of timestep embedding channels.
+ num_res_blocks: number of residual blocks.
+ norm_num_groups: number of groups for the group normalization.
+ norm_eps: epsilon for the group normalization.
+ add_upsample: if True add upsample block.
+ resblock_updown: if True use residual blocks for upsampling.
+ num_head_channels: number of channels in each attention head.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ num_res_blocks: int = 1,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ add_upsample: bool = True,
+ resblock_updown: bool = False,
+ num_head_channels: int = 1,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ self.resblock_updown = resblock_updown
+
+ resnets = []
+ attentions = []
+
+ for i in range(num_res_blocks):
+ res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ )
+ )
+ attentions.append(
+ SpatialAttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=out_channels,
+ num_head_channels=num_head_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ self.upsampler: nn.Module | None
+ if add_upsample:
+ if resblock_updown:
+ self.upsampler = DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ up=True,
+ )
+ else:
+
+ post_conv = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ self.upsampler = WrappedUpsample(
+ spatial_dims=spatial_dims,
+ mode="nontrainable",
+ in_channels=out_channels,
+ out_channels=out_channels,
+ interp_mode="nearest",
+ scale_factor=2.0,
+ post_conv=post_conv,
+ align_corners=None,
+ )
+ else:
+ self.upsampler = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ res_hidden_states_list: list[torch.Tensor],
+ temb: torch.Tensor,
+ context: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ del context
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_list[-1]
+ res_hidden_states_list = res_hidden_states_list[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states).contiguous()
+
+ if self.upsampler is not None:
+ hidden_states = self.upsampler(hidden_states, temb)
+
+ return hidden_states
+
+
+class CrossAttnUpBlock(nn.Module):
+ """
+ Unet's up block containing resnet, upsamplers, and cross-attention blocks.
+
+ Args:
+ spatial_dims: The number of spatial dimensions.
+ in_channels: number of input channels.
+ prev_output_channel: number of channels from residual connection.
+ out_channels: number of output channels.
+ temb_channels: number of timestep embedding channels.
+ num_res_blocks: number of residual blocks.
+ norm_num_groups: number of groups for the group normalization.
+ norm_eps: epsilon for the group normalization.
+        add_upsample: if True add upsample block.
+ resblock_updown: if True use residual blocks for upsampling.
+ num_head_channels: number of channels in each attention head.
+ transformer_num_layers: number of layers of Transformer blocks to use.
+ cross_attention_dim: number of context dimensions to use.
+ upcast_attention: if True, upcast attention operations to full precision.
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ num_res_blocks: int = 1,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ add_upsample: bool = True,
+ resblock_updown: bool = False,
+ num_head_channels: int = 1,
+ transformer_num_layers: int = 1,
+ cross_attention_dim: int | None = None,
+ upcast_attention: bool = False,
+ dropout_cattn: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ self.resblock_updown = resblock_updown
+
+ resnets = []
+ attentions = []
+
+ for i in range(num_res_blocks):
+ res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ )
+ )
+ attentions.append(
+ SpatialTransformer(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ num_attention_heads=out_channels // num_head_channels,
+ num_head_channels=num_head_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ dropout=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.upsampler: nn.Module | None
+ if add_upsample:
+ if resblock_updown:
+ self.upsampler = DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ up=True,
+ )
+ else:
+
+ post_conv = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ self.upsampler = WrappedUpsample(
+ spatial_dims=spatial_dims,
+ mode="nontrainable",
+ in_channels=out_channels,
+ out_channels=out_channels,
+ interp_mode="nearest",
+ scale_factor=2.0,
+ post_conv=post_conv,
+ align_corners=None,
+ )
+ else:
+ self.upsampler = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ res_hidden_states_list: list[torch.Tensor],
+ temb: torch.Tensor,
+ context: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_list[-1]
+ res_hidden_states_list = res_hidden_states_list[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, context=context)
+
+ if self.upsampler is not None:
+ hidden_states = self.upsampler(hidden_states, temb)
+
+ return hidden_states
+
+
+def get_down_block(
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ num_res_blocks: int,
+ norm_num_groups: int,
+ norm_eps: float,
+ add_downsample: bool,
+ resblock_updown: bool,
+ with_attn: bool,
+ with_cross_attn: bool,
+ num_head_channels: int,
+ transformer_num_layers: int,
+ cross_attention_dim: int | None,
+ upcast_attention: bool = False,
+ dropout_cattn: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+) -> nn.Module:
+ if with_attn:
+ return AttnDownBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_downsample=add_downsample,
+ resblock_updown=resblock_updown,
+ num_head_channels=num_head_channels,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ elif with_cross_attn:
+ return CrossAttnDownBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_downsample=add_downsample,
+ resblock_updown=resblock_updown,
+ num_head_channels=num_head_channels,
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ dropout_cattn=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ else:
+ return DownBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_downsample=add_downsample,
+ resblock_updown=resblock_updown,
+ )
+
+
+def get_mid_block(
+ spatial_dims: int,
+ in_channels: int,
+ temb_channels: int,
+ norm_num_groups: int,
+ norm_eps: float,
+ with_conditioning: bool,
+ num_head_channels: int,
+ transformer_num_layers: int,
+ cross_attention_dim: int | None,
+ upcast_attention: bool = False,
+ dropout_cattn: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+) -> nn.Module:
+ if with_conditioning:
+ return CrossAttnMidBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ num_head_channels=num_head_channels,
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ dropout_cattn=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ else:
+ return AttnMidBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ num_head_channels=num_head_channels,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+
+
+def get_up_block(
+ spatial_dims: int,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ num_res_blocks: int,
+ norm_num_groups: int,
+ norm_eps: float,
+ add_upsample: bool,
+ resblock_updown: bool,
+ with_attn: bool,
+ with_cross_attn: bool,
+ num_head_channels: int,
+ transformer_num_layers: int,
+ cross_attention_dim: int | None,
+ upcast_attention: bool = False,
+ dropout_cattn: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+) -> nn.Module:
+ if with_attn:
+ return AttnUpBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ prev_output_channel=prev_output_channel,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_upsample=add_upsample,
+ resblock_updown=resblock_updown,
+ num_head_channels=num_head_channels,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ elif with_cross_attn:
+ return CrossAttnUpBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ prev_output_channel=prev_output_channel,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_upsample=add_upsample,
+ resblock_updown=resblock_updown,
+ num_head_channels=num_head_channels,
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ dropout_cattn=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ else:
+ return UpBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ prev_output_channel=prev_output_channel,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_upsample=add_upsample,
+ resblock_updown=resblock_updown,
+ )
+
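# Illustrative sketch (not part of the diff above): the three factories dispatch on the
# attention flags only -- with_attn selects a self-attention block, with_cross_attn selects
# a cross-attention block (and then needs cross_attention_dim), otherwise a plain resnet
# block is returned. Minimal example, assuming this module is importable as
# monai.networks.nets.diffusion_model_unet:
import torch
from monai.networks.nets.diffusion_model_unet import get_down_block

plain = get_down_block(
    spatial_dims=2, in_channels=32, out_channels=64, temb_channels=128,
    num_res_blocks=1, norm_num_groups=32, norm_eps=1e-6, add_downsample=True,
    resblock_updown=False, with_attn=False, with_cross_attn=False,
    num_head_channels=8, transformer_num_layers=1, cross_attention_dim=None,
)
cross = get_down_block(
    spatial_dims=2, in_channels=32, out_channels=64, temb_channels=128,
    num_res_blocks=1, norm_num_groups=32, norm_eps=1e-6, add_downsample=True,
    resblock_updown=False, with_attn=False, with_cross_attn=True,
    num_head_channels=8, transformer_num_layers=1, cross_attention_dim=16,
)
x, temb = torch.randn(1, 32, 64, 64), torch.randn(1, 128)
h, skips = plain(hidden_states=x, temb=temb)                                 # DownBlock, ignores context
h, skips = cross(hidden_states=x, temb=temb, context=torch.randn(1, 1, 16))  # CrossAttnDownBlock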
+
+class DiffusionModelUNet(nn.Module):
+ """
+ Unet network with timestep embedding and attention mechanisms for conditioning based on
+ Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
+ and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162
+
+ Args:
+ spatial_dims: number of spatial dimensions.
+ in_channels: number of input channels.
+ out_channels: number of output channels.
+        num_res_blocks: number of residual blocks (see DiffusionUNetResnetBlock) per level.
+ channels: tuple of block output channels.
+ attention_levels: list of levels to add attention.
+ norm_num_groups: number of groups for the normalization.
+ norm_eps: epsilon for the normalization.
+ resblock_updown: if True use residual blocks for up/downsampling.
+ num_head_channels: number of channels in each attention head.
+ with_conditioning: if True add spatial transformers to perform conditioning.
+ transformer_num_layers: number of layers of Transformer blocks to use.
+ cross_attention_dim: number of context dimensions to use.
+ num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
+ classes.
+ upcast_attention: if True, upcast attention operations to full precision.
+ dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
+ include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+ channels: Sequence[int] = (32, 64, 64, 64),
+ attention_levels: Sequence[bool] = (False, False, True, True),
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ resblock_updown: bool = False,
+ num_head_channels: int | Sequence[int] = 8,
+ with_conditioning: bool = False,
+ transformer_num_layers: int = 1,
+ cross_attention_dim: int | None = None,
+ num_class_embeds: int | None = None,
+ upcast_attention: bool = False,
+ dropout_cattn: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ if with_conditioning is True and cross_attention_dim is None:
+ raise ValueError(
+ "DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
+ "when using with_conditioning."
+ )
+ if cross_attention_dim is not None and with_conditioning is False:
+ raise ValueError(
+ "DiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
+ )
+ if dropout_cattn > 1.0 or dropout_cattn < 0.0:
+ raise ValueError("Dropout cannot be negative or >1.0!")
+
+ # All number of channels should be multiple of num_groups
+ if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
+ raise ValueError("DiffusionModelUNet expects all num_channels being multiple of norm_num_groups")
+
+ if len(channels) != len(attention_levels):
+ raise ValueError("DiffusionModelUNet expects num_channels being same size of attention_levels")
+
+ if isinstance(num_head_channels, int):
+ num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
+
+ if len(num_head_channels) != len(attention_levels):
+ raise ValueError(
+ "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
+ " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
+ )
+
+ if isinstance(num_res_blocks, int):
+ num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))
+
+ if len(num_res_blocks) != len(channels):
+ raise ValueError(
+ "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
+ "`num_channels`."
+ )
+
+ self.in_channels = in_channels
+ self.block_out_channels = channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_levels = attention_levels
+ self.num_head_channels = num_head_channels
+ self.with_conditioning = with_conditioning
+
+ # input
+ self.conv_in = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=channels[0],
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+
+ # time
+ time_embed_dim = channels[0] * 4
+ self.time_embed = nn.Sequential(
+ nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
+ )
+
+ # class embedding
+ self.num_class_embeds = num_class_embeds
+ if num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
+ # down
+ self.down_blocks = nn.ModuleList([])
+ output_channel = channels[0]
+ for i in range(len(channels)):
+ input_channel = output_channel
+ output_channel = channels[i]
+ is_final_block = i == len(channels) - 1
+
+ down_block = get_down_block(
+ spatial_dims=spatial_dims,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ num_res_blocks=num_res_blocks[i],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_downsample=not is_final_block,
+ resblock_updown=resblock_updown,
+ with_attn=(attention_levels[i] and not with_conditioning),
+ with_cross_attn=(attention_levels[i] and with_conditioning),
+ num_head_channels=num_head_channels[i],
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ dropout_cattn=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.middle_block = get_mid_block(
+ spatial_dims=spatial_dims,
+ in_channels=channels[-1],
+ temb_channels=time_embed_dim,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ with_conditioning=with_conditioning,
+ num_head_channels=num_head_channels[-1],
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ dropout_cattn=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+
+ # up
+ self.up_blocks = nn.ModuleList([])
+ reversed_block_out_channels = list(reversed(channels))
+ reversed_num_res_blocks = list(reversed(num_res_blocks))
+ reversed_attention_levels = list(reversed(attention_levels))
+ reversed_num_head_channels = list(reversed(num_head_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i in range(len(reversed_block_out_channels)):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)]
+
+ is_final_block = i == len(channels) - 1
+
+ up_block = get_up_block(
+ spatial_dims=spatial_dims,
+ in_channels=input_channel,
+ prev_output_channel=prev_output_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ num_res_blocks=reversed_num_res_blocks[i] + 1,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_upsample=not is_final_block,
+ resblock_updown=resblock_updown,
+ with_attn=(reversed_attention_levels[i] and not with_conditioning),
+ with_cross_attn=(reversed_attention_levels[i] and with_conditioning),
+ num_head_channels=reversed_num_head_channels[i],
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ dropout_cattn=dropout_cattn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+
+ self.up_blocks.append(up_block)
+
+ # out
+ self.out = nn.Sequential(
+ nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True),
+ nn.SiLU(),
+ zero_module(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=channels[0],
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ ),
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ timesteps: torch.Tensor,
+ context: torch.Tensor | None = None,
+ class_labels: torch.Tensor | None = None,
+ down_block_additional_residuals: tuple[torch.Tensor] | None = None,
+ mid_block_additional_residual: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x: input tensor (N, C, SpatialDims).
+ timesteps: timestep tensor (N,).
+ context: context tensor (N, 1, ContextDim).
+            class_labels: class labels tensor (N,).
+ down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims).
+ mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims).
+ """
+ # 1. time
+ t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
+
+        # the timestep embedding has no weights and is always computed in float32,
+        # but time_embed might actually be running in fp16, so we need to cast here.
+        # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=x.dtype)
+ emb = self.time_embed(t_emb)
+
+ # 2. class
+ if self.num_class_embeds is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+ class_emb = self.class_embedding(class_labels)
+ class_emb = class_emb.to(dtype=x.dtype)
+ emb = emb + class_emb
+
+ # 3. initial convolution
+ h = self.conv_in(x)
+
+ # 4. down
+ if context is not None and self.with_conditioning is False:
+ raise ValueError("model should have with_conditioning = True if context is provided")
+ down_block_res_samples: list[torch.Tensor] = [h]
+ for downsample_block in self.down_blocks:
+ h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
+ for residual in res_samples:
+ down_block_res_samples.append(residual)
+
+        # Additional residual connections for ControlNets
+ if down_block_additional_residuals is not None:
+ new_down_block_res_samples: list[torch.Tensor] = []
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples += [down_block_res_sample]
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 5. mid
+ h = self.middle_block(hidden_states=h, temb=emb, context=context)
+
+        # Additional residual connections for ControlNets
+ if mid_block_additional_residual is not None:
+ h = h + mid_block_additional_residual
+
+ # 6. up
+ for upsample_block in self.up_blocks:
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+ h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)
+
+ # 7. output block
+ output: torch.Tensor = self.out(h)
+
+ return output
+
+ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
+ """
+ Load a state dict from a DiffusionModelUNet trained with
+ [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).
+
+ Args:
+            old_state_dict: state dict from the old DiffusionModelUNet model.
+ """
+
+ new_state_dict = self.state_dict()
+ # if all keys match, just load the state dict
+ if all(k in new_state_dict for k in old_state_dict):
+ print("All keys match, loading state dict.")
+ self.load_state_dict(old_state_dict)
+ return
+
+ if verbose:
+ # print all new_state_dict keys that are not in old_state_dict
+ for k in new_state_dict:
+ if k not in old_state_dict:
+ print(f"key {k} not found in old state dict")
+ # and vice versa
+ print("----------------------------------------------")
+ for k in old_state_dict:
+ if k not in new_state_dict:
+ print(f"key {k} not found in new state dict")
+
+ # copy over all matching keys
+ for k in new_state_dict:
+ if k in old_state_dict:
+ new_state_dict[k] = old_state_dict.pop(k)
+
+ # fix the attention blocks
+ attention_blocks = [k.replace(".attn.to_k.weight", "") for k in new_state_dict if "attn.to_k.weight" in k]
+ for block in attention_blocks:
+ new_state_dict[f"{block}.attn.to_q.weight"] = old_state_dict.pop(f"{block}.to_q.weight")
+ new_state_dict[f"{block}.attn.to_k.weight"] = old_state_dict.pop(f"{block}.to_k.weight")
+ new_state_dict[f"{block}.attn.to_v.weight"] = old_state_dict.pop(f"{block}.to_v.weight")
+ new_state_dict[f"{block}.attn.to_q.bias"] = old_state_dict.pop(f"{block}.to_q.bias")
+ new_state_dict[f"{block}.attn.to_k.bias"] = old_state_dict.pop(f"{block}.to_k.bias")
+ new_state_dict[f"{block}.attn.to_v.bias"] = old_state_dict.pop(f"{block}.to_v.bias")
+
+ # projection
+ new_state_dict[f"{block}.attn.out_proj.weight"] = old_state_dict.pop(f"{block}.proj_attn.weight")
+ new_state_dict[f"{block}.attn.out_proj.bias"] = old_state_dict.pop(f"{block}.proj_attn.bias")
+
+ # fix the cross attention blocks
+ cross_attention_blocks = [
+ k.replace(".out_proj.weight", "")
+ for k in new_state_dict
+ if "out_proj.weight" in k and "transformer_blocks" in k
+ ]
+ for block in cross_attention_blocks:
+ new_state_dict[f"{block}.out_proj.weight"] = old_state_dict.pop(f"{block}.to_out.0.weight")
+ new_state_dict[f"{block}.out_proj.bias"] = old_state_dict.pop(f"{block}.to_out.0.bias")
+
+ # fix the upsample conv blocks which were renamed postconv
+ for k in new_state_dict:
+ if "postconv" in k:
+ old_name = k.replace("postconv", "conv")
+ new_state_dict[k] = old_state_dict.pop(old_name)
+ if verbose:
+ # print all remaining keys in old_state_dict
+ print("remaining keys in old_state_dict:", old_state_dict.keys())
+ self.load_state_dict(new_state_dict)
+
+
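# Illustrative sketch (not part of the diff above): minimal use of DiffusionModelUNet for
# 2D noise prediction, unconditionally and with cross-attention conditioning. Assumes the
# class is exposed as monai.networks.nets.DiffusionModelUNet; the channel and level
# settings below are arbitrary small values chosen for illustration.
import torch
from monai.networks.nets import DiffusionModelUNet

model = DiffusionModelUNet(
    spatial_dims=2, in_channels=1, out_channels=1,
    channels=(32, 64, 64), attention_levels=(False, True, True),
    num_res_blocks=1, num_head_channels=8,
)
x = torch.randn(2, 1, 64, 64)             # noisy images
t = torch.randint(0, 1000, (2,)).long()   # diffusion timesteps
noise_pred = model(x, timesteps=t)        # -> (2, 1, 64, 64)

# conditional variant: with_conditioning=True requires cross_attention_dim
cond_model = DiffusionModelUNet(
    spatial_dims=2, in_channels=1, out_channels=1,
    channels=(32, 64, 64), attention_levels=(False, True, True),
    num_res_blocks=1, num_head_channels=8,
    with_conditioning=True, cross_attention_dim=16,
)
context = torch.randn(2, 1, 16)           # (N, seq_len, cross_attention_dim)
noise_pred = cond_model(x, timesteps=t, context=context)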
+class DiffusionModelEncoder(nn.Module):
+ """
+    Classification network based on the encoder of the diffusion model, followed by fully connected layers, as described in
+    Wolleb et al. "Diffusion Models for Medical Anomaly Detection" (https://arxiv.org/abs/2203.04306).
+
+ Args:
+ spatial_dims: number of spatial dimensions.
+ in_channels: number of input channels.
+ out_channels: number of output channels.
+        num_res_blocks: number of residual blocks (see DiffusionUNetResnetBlock) per level.
+ channels: tuple of block output channels.
+ attention_levels: list of levels to add attention.
+ norm_num_groups: number of groups for the normalization.
+ norm_eps: epsilon for the normalization.
+ resblock_updown: if True use residual blocks for downsampling.
+ num_head_channels: number of channels in each attention head.
+ with_conditioning: if True add spatial transformers to perform conditioning.
+ transformer_num_layers: number of layers of Transformer blocks to use.
+ cross_attention_dim: number of context dimensions to use.
+ num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
+        upcast_attention: if True, upcast attention operations to full precision.
+        include_fc: whether to include the final linear layer. Default to True.
+        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+ channels: Sequence[int] = (32, 64, 64, 64),
+ attention_levels: Sequence[bool] = (False, False, True, True),
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ resblock_updown: bool = False,
+ num_head_channels: int | Sequence[int] = 8,
+ with_conditioning: bool = False,
+ transformer_num_layers: int = 1,
+ cross_attention_dim: int | None = None,
+ num_class_embeds: int | None = None,
+ upcast_attention: bool = False,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ if with_conditioning is True and cross_attention_dim is None:
+ raise ValueError(
+ "DiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) "
+ "when using with_conditioning."
+ )
+ if cross_attention_dim is not None and with_conditioning is False:
+ raise ValueError(
+ "DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim."
+ )
+
+ # All number of channels should be multiple of num_groups
+ if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
+ raise ValueError("DiffusionModelEncoder expects all num_channels being multiple of norm_num_groups")
+ if len(channels) != len(attention_levels):
+ raise ValueError("DiffusionModelEncoder expects num_channels being same size of attention_levels")
+
+ if isinstance(num_head_channels, int):
+ num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
+
+ if isinstance(num_res_blocks, int):
+ num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))
+
+ if len(num_head_channels) != len(attention_levels):
+ raise ValueError(
+ "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
+ " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
+ )
+
+ self.in_channels = in_channels
+ self.block_out_channels = channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_levels = attention_levels
+ self.num_head_channels = num_head_channels
+ self.with_conditioning = with_conditioning
+
+ # input
+ self.conv_in = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=channels[0],
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+
+ # time
+ time_embed_dim = channels[0] * 4
+ self.time_embed = nn.Sequential(
+ nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
+ )
+
+ # class embedding
+ self.num_class_embeds = num_class_embeds
+ if num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
+ # down
+ self.down_blocks = nn.ModuleList([])
+ output_channel = channels[0]
+ for i in range(len(channels)):
+ input_channel = output_channel
+ output_channel = channels[i]
+ is_final_block = i == len(channels) # - 1
+
+ down_block = get_down_block(
+ spatial_dims=spatial_dims,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ num_res_blocks=num_res_blocks[i],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_downsample=not is_final_block,
+ resblock_updown=resblock_updown,
+ with_attn=(attention_levels[i] and not with_conditioning),
+ with_cross_attn=(attention_levels[i] and with_conditioning),
+ num_head_channels=num_head_channels[i],
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+
+ self.down_blocks.append(down_block)
+
+        # the encoder output must flatten to 4096 features (e.g. 128x128 2D or 64^3 3D inputs with the default channels)
+        self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels))
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ timesteps: torch.Tensor,
+ context: torch.Tensor | None = None,
+ class_labels: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x: input tensor (N, C, SpatialDims).
+ timesteps: timestep tensor (N,).
+ context: context tensor (N, 1, ContextDim).
+            class_labels: class labels tensor (N,).
+ """
+ # 1. time
+ t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
+
+        # the timestep embedding has no weights and is always computed in float32,
+        # but time_embed might actually be running in fp16, so we need to cast here.
+        # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=x.dtype)
+ emb = self.time_embed(t_emb)
+
+ # 2. class
+ if self.num_class_embeds is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+ class_emb = self.class_embedding(class_labels)
+ class_emb = class_emb.to(dtype=x.dtype)
+ emb = emb + class_emb
+
+ # 3. initial convolution
+ h = self.conv_in(x)
+
+ # 4. down
+ if context is not None and self.with_conditioning is False:
+ raise ValueError("model should have with_conditioning = True if context is provided")
+ for downsample_block in self.down_blocks:
+ h, _ = downsample_block(hidden_states=h, temb=emb, context=context)
+
+ h = h.reshape(h.shape[0], -1)
+ output: torch.Tensor = self.out(h)
+
+ return output
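# Illustrative sketch (not part of the diff above): DiffusionModelEncoder flattens its last
# feature map into a fixed nn.Linear(4096, 512) head, so the input size must match. With the
# default channels=(32, 64, 64, 64) every level downsamples (note `is_final_block` above is
# never True), i.e. four halvings: a 2D 128x128 input gives 64 * 8 * 8 = 4096 features.
# Assumes the class is exposed as monai.networks.nets.DiffusionModelEncoder.
import torch
from monai.networks.nets import DiffusionModelEncoder

encoder = DiffusionModelEncoder(spatial_dims=2, in_channels=1, out_channels=2)
x = torch.randn(4, 1, 128, 128)
t = torch.randint(0, 1000, (4,)).long()
logits = encoder(x, timesteps=t)   # -> (4, 2) class logits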
diff --git a/monai/networks/nets/flexible_unet.py b/monai/networks/nets/flexible_unet.py
index ac2124b5f9..c27b0fc17b 100644
--- a/monai/networks/nets/flexible_unet.py
+++ b/monai/networks/nets/flexible_unet.py
@@ -24,6 +24,7 @@
from monai.networks.layers.utils import get_act_layer
from monai.networks.nets import EfficientNetEncoder
from monai.networks.nets.basic_unet import UpCat
+from monai.networks.nets.resnet import ResNetEncoder
from monai.utils import InterpolateMode, optional_import
__all__ = ["FlexibleUNet", "FlexUNet", "FLEXUNET_BACKBONE", "FlexUNetEncoderRegister"]
@@ -78,6 +79,7 @@ def register_class(self, name: type[Any] | str):
FLEXUNET_BACKBONE = FlexUNetEncoderRegister()
FLEXUNET_BACKBONE.register_class(EfficientNetEncoder)
+FLEXUNET_BACKBONE.register_class(ResNetEncoder)
class UNetDecoder(nn.Module):
@@ -238,7 +240,7 @@ def __init__(
) -> None:
"""
A flexible implement of UNet, in which the backbone/encoder can be replaced with
- any efficient network. Currently the input must have a 2 or 3 spatial dimension
+ any efficient or residual network. Currently the input must have a 2 or 3 spatial dimension
and the spatial size of each dimension must be a multiple of 32 if is_pad parameter
is False.
Please notice each output of backbone must be 2x downsample in spatial dimension
@@ -248,10 +250,11 @@ def __init__(
Args:
in_channels: number of input channels.
out_channels: number of output channels.
- backbone: name of backbones to initialize, only support efficientnet right now,
- can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2].
- pretrained: whether to initialize pretrained ImageNet weights, only available
- for spatial_dims=2 and batch norm is used, default to False.
+ backbone: name of backbones to initialize, only support efficientnet and resnet right now,
+ can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2, resnet10, ..., resnet200].
+ pretrained: whether to initialize pretrained weights. ImageNet weights are available for efficient networks
+ if spatial_dims=2 and batch norm is used. MedicalNet weights are available for residual networks
+ if spatial_dims=3 and in_channels=1. Default to False.
decoder_channels: number of output channels for all feature maps in decoder.
`len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default
to (256, 128, 64, 32, 16).
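# Illustrative sketch (not part of the diff above): with ResNetEncoder registered in
# FLEXUNET_BACKBONE, a residual backbone can be selected by name. The constructor keywords
# below follow FlexibleUNet's existing signature and are an assumption of this sketch;
# pretrained MedicalNet weights would only apply for spatial_dims=3 and in_channels=1
# (see the docstring above).
import torch
from monai.networks.nets import FlexibleUNet

net = FlexibleUNet(in_channels=1, out_channels=2, backbone="resnet10", spatial_dims=3, pretrained=False)
y = net(torch.randn(1, 1, 64, 64, 64))   # spatial sizes kept at multiples of 32
print(y.shape)                           # segmentation logits, expected to match the input grid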
diff --git a/monai/networks/nets/hovernet.py b/monai/networks/nets/hovernet.py
index 5f340c9be6..3745b66bb5 100644
--- a/monai/networks/nets/hovernet.py
+++ b/monai/networks/nets/hovernet.py
@@ -43,7 +43,7 @@
from monai.networks.layers.factories import Conv, Dropout
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils.enums import HoVerNetBranch, HoVerNetMode, InterpolateMode, UpsampleMode
-from monai.utils.module import export, look_up_option
+from monai.utils.module import look_up_option
__all__ = ["HoVerNet", "Hovernet", "HoVernet", "HoVerNet"]
@@ -409,7 +409,6 @@ def forward(self, xin: torch.Tensor, short_cuts: list[torch.Tensor]) -> torch.Te
return x
-@export("monai.networks.nets")
class HoVerNet(nn.Module):
"""HoVerNet model
diff --git a/monai/networks/nets/mednext.py b/monai/networks/nets/mednext.py
new file mode 100644
index 0000000000..427572ba60
--- /dev/null
+++ b/monai/networks/nets/mednext.py
@@ -0,0 +1,354 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this code are derived from the original repository at:
+# https://github.com/MIC-DKFZ/MedNeXt
+# and are used under the terms of the Apache License, Version 2.0.
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks.mednext_block import MedNeXtBlock, MedNeXtDownBlock, MedNeXtOutBlock, MedNeXtUpBlock
+
+__all__ = [
+ "MedNeXt",
+ "MedNeXtSmall",
+ "MedNeXtBase",
+ "MedNeXtMedium",
+ "MedNeXtLarge",
+ "MedNext",
+ "MedNextS",
+ "MedNeXtS",
+ "MedNextSmall",
+ "MedNextB",
+ "MedNeXtB",
+ "MedNextBase",
+ "MedNextM",
+ "MedNeXtM",
+ "MedNextMedium",
+ "MedNextL",
+ "MedNeXtL",
+ "MedNextLarge",
+]
+
+
+class MedNeXt(nn.Module):
+ """
+ MedNeXt model class from paper: https://arxiv.org/pdf/2303.09975
+
+ Args:
+ spatial_dims: spatial dimension of the input data. Defaults to 3.
+ init_filters: number of output channels for initial convolution layer. Defaults to 32.
+ in_channels: number of input channels for the network. Defaults to 1.
+ out_channels: number of output channels for the network. Defaults to 2.
+ encoder_expansion_ratio: expansion ratio for encoder blocks. Defaults to 2.
+ decoder_expansion_ratio: expansion ratio for decoder blocks. Defaults to 2.
+ bottleneck_expansion_ratio: expansion ratio for bottleneck blocks. Defaults to 2.
+ kernel_size: kernel size for convolutions. Defaults to 7.
+ deep_supervision: whether to use deep supervision. Defaults to False.
+ use_residual_connection: whether to use residual connections in standard, down and up blocks. Defaults to False.
+ blocks_down: number of blocks in each encoder stage. Defaults to [2, 2, 2, 2].
+ blocks_bottleneck: number of blocks in bottleneck stage. Defaults to 2.
+ blocks_up: number of blocks in each decoder stage. Defaults to [2, 2, 2, 2].
+ norm_type: type of normalization layer. Defaults to 'group'.
+ global_resp_norm: whether to use Global Response Normalization. Defaults to False. Refer: https://arxiv.org/abs/2301.00808
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int = 3,
+ init_filters: int = 32,
+ in_channels: int = 1,
+ out_channels: int = 2,
+ encoder_expansion_ratio: Sequence[int] | int = 2,
+ decoder_expansion_ratio: Sequence[int] | int = 2,
+ bottleneck_expansion_ratio: int = 2,
+ kernel_size: int = 7,
+ deep_supervision: bool = False,
+ use_residual_connection: bool = False,
+ blocks_down: Sequence[int] = (2, 2, 2, 2),
+ blocks_bottleneck: int = 2,
+ blocks_up: Sequence[int] = (2, 2, 2, 2),
+ norm_type: str = "group",
+ global_resp_norm: bool = False,
+ ):
+ """
+ Initialize the MedNeXt model.
+
+ This method sets up the architecture of the model, including:
+ - Stem convolution
+ - Encoder stages and downsampling blocks
+ - Bottleneck blocks
+ - Decoder stages and upsampling blocks
+ - Output blocks for deep supervision (if enabled)
+ """
+ super().__init__()
+
+ self.do_ds = deep_supervision
+ assert spatial_dims in [2, 3], "`spatial_dims` can only be 2 or 3."
+ spatial_dims_str = f"{spatial_dims}d"
+ enc_kernel_size = dec_kernel_size = kernel_size
+
+ if isinstance(encoder_expansion_ratio, int):
+ encoder_expansion_ratio = [encoder_expansion_ratio] * len(blocks_down)
+
+ if isinstance(decoder_expansion_ratio, int):
+ decoder_expansion_ratio = [decoder_expansion_ratio] * len(blocks_up)
+
+ conv = nn.Conv2d if spatial_dims_str == "2d" else nn.Conv3d
+
+ self.stem = conv(in_channels, init_filters, kernel_size=1)
+
+ enc_stages = []
+ down_blocks = []
+
+ for i, num_blocks in enumerate(blocks_down):
+ enc_stages.append(
+ nn.Sequential(
+ *[
+ MedNeXtBlock(
+ in_channels=init_filters * (2**i),
+ out_channels=init_filters * (2**i),
+ expansion_ratio=encoder_expansion_ratio[i],
+ kernel_size=enc_kernel_size,
+ use_residual_connection=use_residual_connection,
+ norm_type=norm_type,
+ dim=spatial_dims_str,
+ global_resp_norm=global_resp_norm,
+ )
+ for _ in range(num_blocks)
+ ]
+ )
+ )
+
+ down_blocks.append(
+ MedNeXtDownBlock(
+ in_channels=init_filters * (2**i),
+ out_channels=init_filters * (2 ** (i + 1)),
+ expansion_ratio=encoder_expansion_ratio[i],
+ kernel_size=enc_kernel_size,
+ use_residual_connection=use_residual_connection,
+ norm_type=norm_type,
+ dim=spatial_dims_str,
+ )
+ )
+
+ self.enc_stages = nn.ModuleList(enc_stages)
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ self.bottleneck = nn.Sequential(
+ *[
+ MedNeXtBlock(
+ in_channels=init_filters * (2 ** len(blocks_down)),
+ out_channels=init_filters * (2 ** len(blocks_down)),
+ expansion_ratio=bottleneck_expansion_ratio,
+ kernel_size=dec_kernel_size,
+ use_residual_connection=use_residual_connection,
+ norm_type=norm_type,
+ dim=spatial_dims_str,
+ global_resp_norm=global_resp_norm,
+ )
+ for _ in range(blocks_bottleneck)
+ ]
+ )
+
+ up_blocks = []
+ dec_stages = []
+ for i, num_blocks in enumerate(blocks_up):
+ up_blocks.append(
+ MedNeXtUpBlock(
+ in_channels=init_filters * (2 ** (len(blocks_up) - i)),
+ out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),
+ expansion_ratio=decoder_expansion_ratio[i],
+ kernel_size=dec_kernel_size,
+ use_residual_connection=use_residual_connection,
+ norm_type=norm_type,
+ dim=spatial_dims_str,
+ global_resp_norm=global_resp_norm,
+ )
+ )
+
+ dec_stages.append(
+ nn.Sequential(
+ *[
+ MedNeXtBlock(
+ in_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),
+ out_channels=init_filters * (2 ** (len(blocks_up) - i - 1)),
+ expansion_ratio=decoder_expansion_ratio[i],
+ kernel_size=dec_kernel_size,
+ use_residual_connection=use_residual_connection,
+ norm_type=norm_type,
+ dim=spatial_dims_str,
+ global_resp_norm=global_resp_norm,
+ )
+ for _ in range(num_blocks)
+ ]
+ )
+ )
+
+ self.up_blocks = nn.ModuleList(up_blocks)
+ self.dec_stages = nn.ModuleList(dec_stages)
+
+ self.out_0 = MedNeXtOutBlock(in_channels=init_filters, n_classes=out_channels, dim=spatial_dims_str)
+
+ if deep_supervision:
+ out_blocks = [
+ MedNeXtOutBlock(in_channels=init_filters * (2**i), n_classes=out_channels, dim=spatial_dims_str)
+ for i in range(1, len(blocks_up) + 1)
+ ]
+
+ out_blocks.reverse()
+ self.out_blocks = nn.ModuleList(out_blocks)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor | Sequence[torch.Tensor]:
+ """
+ Forward pass of the MedNeXt model.
+
+ This method performs the forward pass through the model, including:
+ - Stem convolution
+ - Encoder stages and downsampling
+ - Bottleneck blocks
+ - Decoder stages and upsampling with skip connections
+ - Output blocks for deep supervision (if enabled)
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor or Sequence[torch.Tensor]: Output tensor(s).
+ """
+ # Apply stem convolution
+ x = self.stem(x)
+
+ # Encoder forward pass
+ enc_outputs = []
+ for enc_stage, down_block in zip(self.enc_stages, self.down_blocks):
+ x = enc_stage(x)
+ enc_outputs.append(x)
+ x = down_block(x)
+
+ # Bottleneck forward pass
+ x = self.bottleneck(x)
+
+ # Initialize deep supervision outputs if enabled
+ if self.do_ds:
+ ds_outputs = []
+
+ # Decoder forward pass with skip connections
+ for i, (up_block, dec_stage) in enumerate(zip(self.up_blocks, self.dec_stages)):
+ if self.do_ds and i < len(self.out_blocks):
+ ds_outputs.append(self.out_blocks[i](x))
+
+ x = up_block(x)
+ x = x + enc_outputs[-(i + 1)]
+ x = dec_stage(x)
+
+ # Final output block
+ x = self.out_0(x)
+
+ # Return output(s)
+ if self.do_ds and self.training:
+ return (x, *ds_outputs[::-1])
+ else:
+ return x
+
+
+# Define the MedNeXt variants as reported in 10.48550/arXiv.2303.09975
+def create_mednext(
+ variant: str,
+ spatial_dims: int = 3,
+ in_channels: int = 1,
+ out_channels: int = 2,
+ kernel_size: int = 3,
+ deep_supervision: bool = False,
+) -> MedNeXt:
+ """
+ Factory method to create MedNeXt variants.
+
+ Args:
+ variant (str): The MedNeXt variant to create ('S', 'B', 'M', or 'L').
+ spatial_dims (int): Number of spatial dimensions. Defaults to 3.
+ in_channels (int): Number of input channels. Defaults to 1.
+ out_channels (int): Number of output channels. Defaults to 2.
+ kernel_size (int): Kernel size for convolutions. Defaults to 3.
+ deep_supervision (bool): Whether to use deep supervision. Defaults to False.
+
+ Returns:
+ MedNeXt: The specified MedNeXt variant.
+
+ Raises:
+ ValueError: If an invalid variant is specified.
+ """
+ common_args = {
+ "spatial_dims": spatial_dims,
+ "in_channels": in_channels,
+ "out_channels": out_channels,
+ "kernel_size": kernel_size,
+ "deep_supervision": deep_supervision,
+ "use_residual_connection": True,
+ "norm_type": "group",
+ "global_resp_norm": False,
+ "init_filters": 32,
+ }
+
+ if variant.upper() == "S":
+ return MedNeXt(
+ encoder_expansion_ratio=2,
+ decoder_expansion_ratio=2,
+ bottleneck_expansion_ratio=2,
+ blocks_down=(2, 2, 2, 2),
+ blocks_bottleneck=2,
+ blocks_up=(2, 2, 2, 2),
+ **common_args, # type: ignore
+ )
+ elif variant.upper() == "B":
+ return MedNeXt(
+ encoder_expansion_ratio=(2, 3, 4, 4),
+ decoder_expansion_ratio=(4, 4, 3, 2),
+ bottleneck_expansion_ratio=4,
+ blocks_down=(2, 2, 2, 2),
+ blocks_bottleneck=2,
+ blocks_up=(2, 2, 2, 2),
+ **common_args, # type: ignore
+ )
+ elif variant.upper() == "M":
+ return MedNeXt(
+ encoder_expansion_ratio=(2, 3, 4, 4),
+ decoder_expansion_ratio=(4, 4, 3, 2),
+ bottleneck_expansion_ratio=4,
+ blocks_down=(3, 4, 4, 4),
+ blocks_bottleneck=4,
+ blocks_up=(4, 4, 4, 3),
+ **common_args, # type: ignore
+ )
+ elif variant.upper() == "L":
+ return MedNeXt(
+ encoder_expansion_ratio=(3, 4, 8, 8),
+ decoder_expansion_ratio=(8, 8, 4, 3),
+ bottleneck_expansion_ratio=8,
+ blocks_down=(3, 4, 8, 8),
+ blocks_bottleneck=8,
+ blocks_up=(8, 8, 4, 3),
+ **common_args, # type: ignore
+ )
+ else:
+ raise ValueError(f"Invalid MedNeXt variant: {variant}")
+
+
+MedNext = MedNeXt
+MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall = lambda **kwargs: create_mednext("S", **kwargs)
+MedNextB = MedNeXtB = MedNextBase = MedNeXtBase = lambda **kwargs: create_mednext("B", **kwargs)
+MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium = lambda **kwargs: create_mednext("M", **kwargs)
+MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge = lambda **kwargs: create_mednext("L", **kwargs)
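# Illustrative sketch (not part of the diff above): creating a variant via the factory.
# With deep_supervision=True the model returns the full-resolution prediction plus one extra
# output per decoder stage while in training mode, and a single tensor in eval mode.
import torch
from monai.networks.nets.mednext import create_mednext

net = create_mednext("S", spatial_dims=3, in_channels=1, out_channels=2, deep_supervision=True)
x = torch.randn(1, 1, 64, 64, 64)

net.train()
outputs = net(x)      # tuple: final prediction followed by the deep supervision outputs
print(len(outputs))   # 5 for the four decoder stages of the "S" variant

net.eval()
with torch.no_grad():
    y = net(x)        # single tensor of shape (1, 2, 64, 64, 64)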
diff --git a/monai/networks/nets/patchgan_discriminator.py b/monai/networks/nets/patchgan_discriminator.py
new file mode 100644
index 0000000000..74da917694
--- /dev/null
+++ b/monai/networks/nets/patchgan_discriminator.py
@@ -0,0 +1,230 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks import Convolution
+from monai.networks.layers import Act
+from monai.networks.utils import normal_init
+
+
+class MultiScalePatchDiscriminator(nn.Sequential):
+ """
+ Multi-scale Patch-GAN discriminator based on Pix2PixHD:
+ High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585)
+
+    The multi-scale discriminator is made up of several PatchGAN discriminators that process the images
+ at different spatial scales.
+
+ Args:
+ num_d: number of discriminators
+        num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the first
+            discriminator. Each subsequent discriminator has `num_layers_d` more layers than the previous one,
+            so its output feature map is downsampled further.
+ spatial_dims: number of spatial dimensions (1D, 2D etc.)
+ channels: number of filters in the first convolutional layer (doubled for each subsequent layer)
+ in_channels: number of input channels
+ out_channels: number of output channels in each discriminator
+ kernel_size: kernel size of the convolution layers
+ activation: activation layer type
+ norm: normalisation type
+        bias: whether to have a bias term in convolution blocks.
+ dropout: probability of dropout applied, defaults to 0.
+ minimum_size_im: minimum spatial size of the input image. Introduced to make sure the architecture
+            requested isn't going to downsample the input image below a spatial size of 1.
+ last_conv_kernel_size: kernel size of the last convolutional layer.
+ """
+
+ def __init__(
+ self,
+ num_d: int,
+ num_layers_d: int,
+ spatial_dims: int,
+ channels: int,
+ in_channels: int,
+ out_channels: int = 1,
+ kernel_size: int = 4,
+ activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+ norm: str | tuple = "BATCH",
+ bias: bool = False,
+ dropout: float | tuple = 0.0,
+ minimum_size_im: int = 256,
+ last_conv_kernel_size: int = 1,
+ ) -> None:
+ super().__init__()
+ self.num_d = num_d
+ self.num_layers_d = num_layers_d
+ self.num_channels = channels
+ self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims)
+ for i_ in range(self.num_d):
+ num_layers_d_i = self.num_layers_d * (i_ + 1)
+ output_size = float(minimum_size_im) / (2**num_layers_d_i)
+ if output_size < 1:
+ raise AssertionError(
+                    f"Your image size is too small to take in up to {i_} discriminators with num_layers = {num_layers_d_i}. "
+                    "Please reduce num_layers, reduce num_d or enter bigger images."
+ )
+ subnet_d = PatchDiscriminator(
+ spatial_dims=spatial_dims,
+ channels=self.num_channels,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_layers_d=num_layers_d_i,
+ kernel_size=kernel_size,
+ activation=activation,
+ norm=norm,
+ bias=bias,
+ padding=self.padding,
+ dropout=dropout,
+ last_conv_kernel_size=last_conv_kernel_size,
+ )
+
+ self.add_module("discriminator_%d" % i_, subnet_d)
+
+ def forward(self, i: torch.Tensor) -> tuple[list[torch.Tensor], list[list[torch.Tensor]]]:
+ """
+ Args:
+ i: Input tensor
+
+ Returns:
+ list of outputs and another list of lists with the intermediate features
+ of each discriminator.
+ """
+
+ out: list[torch.Tensor] = []
+ intermediate_features: list[list[torch.Tensor]] = []
+ for disc in self.children():
+ out_d: list[torch.Tensor] = disc(i)
+ out.append(out_d[-1])
+ intermediate_features.append(out_d[:-1])
+
+ return out, intermediate_features
+
+
+class PatchDiscriminator(nn.Sequential):
+ """
+ Patch-GAN discriminator based on Pix2PixHD:
+ High-Resolution Image Synthesis and Semantic Manipulation with Conditional GANs (https://arxiv.org/abs/1711.11585)
+
+
+ Args:
+ spatial_dims: number of spatial dimensions (1D, 2D etc.)
+ channels: number of filters in the first convolutional layer (doubled for each subsequent layer)
+ in_channels: number of input channels
+ out_channels: number of output channels
+ num_layers_d: number of Convolution layers (Conv + activation + normalisation + [dropout]) in the discriminator.
+ kernel_size: kernel size of the convolution layers
+        activation: activation type and arguments. Defaults to LeakyReLU.
+ norm: feature normalization type and arguments. Defaults to batch norm.
+ bias: whether to have a bias term in convolution blocks. Defaults to False.
+ padding: padding to be applied to the convolutional layers
+ dropout: proportion of dropout applied, defaults to 0.
+ last_conv_kernel_size: kernel size of the last convolutional layer.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ channels: int,
+ in_channels: int,
+ out_channels: int = 1,
+ num_layers_d: int = 3,
+ kernel_size: int = 4,
+ activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+ norm: str | tuple = "BATCH",
+ bias: bool = False,
+ padding: int | Sequence[int] = 1,
+ dropout: float | tuple = 0.0,
+ last_conv_kernel_size: int | None = None,
+ ) -> None:
+ super().__init__()
+ self.num_layers_d = num_layers_d
+ self.num_channels = channels
+ if last_conv_kernel_size is None:
+ last_conv_kernel_size = kernel_size
+
+ self.add_module(
+ "initial_conv",
+ Convolution(
+ spatial_dims=spatial_dims,
+ kernel_size=kernel_size,
+ in_channels=in_channels,
+ out_channels=channels,
+ act=activation,
+ bias=True,
+ norm=None,
+ dropout=dropout,
+ padding=padding,
+ strides=2,
+ ),
+ )
+
+ input_channels = channels
+ output_channels = channels * 2
+
+        # Intermediate layers
+ for l_ in range(self.num_layers_d):
+ if l_ == self.num_layers_d - 1:
+ stride = 1
+ else:
+ stride = 2
+ layer = Convolution(
+ spatial_dims=spatial_dims,
+ kernel_size=kernel_size,
+ in_channels=input_channels,
+ out_channels=output_channels,
+ act=activation,
+ bias=bias,
+ norm=norm,
+ dropout=dropout,
+ padding=padding,
+ strides=stride,
+ )
+ self.add_module("%d" % l_, layer)
+ input_channels = output_channels
+ output_channels = output_channels * 2
+
+ # Final layer
+ self.add_module(
+ "final_conv",
+ Convolution(
+ spatial_dims=spatial_dims,
+ kernel_size=last_conv_kernel_size,
+ in_channels=input_channels,
+ out_channels=out_channels,
+ bias=True,
+ conv_only=True,
+ padding=int((last_conv_kernel_size - 1) / 2),
+ dropout=0.0,
+ strides=1,
+ ),
+ )
+
+ self.apply(normal_init)
+
+ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
+ """
+ Args:
+ x: input tensor
+
+ Returns:
+ list of intermediate features, with the last element being the output.
+ """
+ out = [x]
+ for submodel in self.children():
+ intermediate_output = submodel(out[-1])
+ out.append(intermediate_output)
+
+ return out[1:]
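# Illustrative sketch (not part of the diff above): two PatchGAN discriminators applied at
# different effective receptive fields to the same 2D image. Each discriminator returns its
# per-patch logits plus its intermediate feature maps (often used for a feature-matching
# loss). Assumes the classes are importable from monai.networks.nets.patchgan_discriminator.
import torch
from monai.networks.nets.patchgan_discriminator import MultiScalePatchDiscriminator

disc = MultiScalePatchDiscriminator(
    num_d=2, num_layers_d=3, spatial_dims=2, channels=32, in_channels=1, minimum_size_im=256
)
img = torch.randn(2, 1, 256, 256)
logits, features = disc(img)
print([p.shape for p in logits])    # one per-patch logit map per discriminator
print([len(f) for f in features])   # intermediate feature maps per discriminator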
diff --git a/monai/networks/nets/quicknat.py b/monai/networks/nets/quicknat.py
index cbcccf24d7..7e0f9c6b38 100644
--- a/monai/networks/nets/quicknat.py
+++ b/monai/networks/nets/quicknat.py
@@ -42,7 +42,7 @@ class SkipConnectionWithIdx(SkipConnection):
Inherits from SkipConnection but provides the indizes with each forward pass.
"""
- def forward(self, input, indices):
+ def forward(self, input, indices): # type: ignore[override]
return super().forward(input), indices
@@ -57,7 +57,7 @@ class SequentialWithIdx(nn.Sequential):
def __init__(self, *args):
super().__init__(*args)
- def forward(self, input, indices):
+ def forward(self, input, indices): # type: ignore[override]
for module in self:
input, indices = module(input, indices)
return input, indices
@@ -165,9 +165,11 @@ def _get_layer(self, in_channels, out_channels, dilation):
)
return nn.Sequential(conv.get_submodule("adn"), conv.get_submodule("conv"))
- def forward(self, input, _):
+ def forward(self, input, _): # type: ignore[override]
i = 0
result = input
+ result1 = input # this will not stay this value, needed here for pylint/mypy
+
for l in self.children():
# ignoring the max (un-)pool and droupout already added in the initial initialization step
if isinstance(l, (nn.MaxPool2d, nn.MaxUnpool2d, nn.Dropout2d)):
@@ -213,7 +215,7 @@ def __init__(self, in_channels: int, max_pool, se_layer, dropout, kernel_size, n
super().__init__(in_channels, se_layer, dropout, kernel_size, num_filters)
self.max_pool = max_pool
- def forward(self, input, indices=None):
+ def forward(self, input, indices=None): # type: ignore[override]
input, indices = self.max_pool(input)
out_block, _ = super().forward(input, None)
@@ -241,7 +243,7 @@ def __init__(self, in_channels: int, un_pool, se_layer, dropout, kernel_size, nu
super().__init__(in_channels, se_layer, dropout, kernel_size, num_filters)
self.un_pool = un_pool
- def forward(self, input, indices):
+ def forward(self, input, indices): # type: ignore[override]
out_block, _ = super().forward(input, None)
out_block = self.un_pool(out_block, indices)
return out_block, None
@@ -268,7 +270,7 @@ def __init__(self, in_channels: int, se_layer, dropout, max_pool, un_pool, kerne
self.max_pool = max_pool
self.un_pool = un_pool
- def forward(self, input, indices):
+ def forward(self, input, indices): # type: ignore[override]
out_block, indices = self.max_pool(input)
out_block, _ = super().forward(out_block, None)
out_block = self.un_pool(out_block, indices)
diff --git a/monai/networks/nets/resnet.py b/monai/networks/nets/resnet.py
index 34a4b7057e..d62722478e 100644
--- a/monai/networks/nets/resnet.py
+++ b/monai/networks/nets/resnet.py
@@ -21,8 +21,9 @@
import torch
import torch.nn as nn
-from monai.networks.layers.factories import Conv, Norm, Pool
-from monai.networks.layers.utils import get_pool_layer
+from monai.networks.blocks.encoder import BaseEncoder
+from monai.networks.layers.factories import Conv, Pool
+from monai.networks.layers.utils import get_act_layer, get_norm_layer, get_pool_layer
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option, optional_import
@@ -45,6 +46,17 @@
"resnet200",
]
+resnet_params = {
+ # model_name: (block, layers, shortcut_type, bias_downsample, datasets23)
+ "resnet10": ("basic", [1, 1, 1, 1], "B", False, True),
+ "resnet18": ("basic", [2, 2, 2, 2], "A", True, True),
+ "resnet34": ("basic", [3, 4, 6, 3], "A", True, True),
+ "resnet50": ("bottleneck", [3, 4, 6, 3], "B", False, True),
+ "resnet101": ("bottleneck", [3, 4, 23, 3], "B", False, False),
+ "resnet152": ("bottleneck", [3, 8, 36, 3], "B", False, False),
+ "resnet200": ("bottleneck", [3, 24, 36, 3], "B", False, False),
+}
+
logger = logging.getLogger(__name__)
@@ -66,6 +78,8 @@ def __init__(
spatial_dims: int = 3,
stride: int = 1,
downsample: nn.Module | partial | None = None,
+ act: str | tuple = ("relu", {"inplace": True}),
+ norm: str | tuple = "batch",
) -> None:
"""
Args:
@@ -74,17 +88,18 @@ def __init__(
spatial_dims: number of spatial dimensions of the input image.
stride: stride to use for first conv layer.
downsample: which downsample layer to use.
+ act: activation type and arguments. Defaults to relu.
+ norm: feature normalization type and arguments. Defaults to batch norm.
"""
super().__init__()
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
- norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
self.conv1 = conv_type(in_planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
- self.bn1 = norm_type(planes)
- self.relu = nn.ReLU(inplace=True)
+ self.bn1 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes)
+ self.act = get_act_layer(name=act)
self.conv2 = conv_type(planes, planes, kernel_size=3, padding=1, bias=False)
- self.bn2 = norm_type(planes)
+ self.bn2 = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes)
self.downsample = downsample
self.stride = stride
@@ -93,7 +108,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
out: torch.Tensor = self.conv1(x)
out = self.bn1(out)
- out = self.relu(out)
+ out = self.act(out)
out = self.conv2(out)
out = self.bn2(out)
@@ -102,7 +117,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = self.downsample(x)
out += residual
- out = self.relu(out)
+ out = self.act(out)
return out
@@ -117,6 +132,8 @@ def __init__(
spatial_dims: int = 3,
stride: int = 1,
downsample: nn.Module | partial | None = None,
+ act: str | tuple = ("relu", {"inplace": True}),
+ norm: str | tuple = "batch",
) -> None:
"""
Args:
@@ -125,20 +142,22 @@ def __init__(
spatial_dims: number of spatial dimensions of the input image.
stride: stride to use for second conv layer.
downsample: which downsample layer to use.
+ act: activation type and arguments. Defaults to relu.
+ norm: feature normalization type and arguments. Defaults to batch norm.
"""
super().__init__()
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
- norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
+ norm_layer = partial(get_norm_layer, name=norm, spatial_dims=spatial_dims)
self.conv1 = conv_type(in_planes, planes, kernel_size=1, bias=False)
- self.bn1 = norm_type(planes)
+ self.bn1 = norm_layer(channels=planes)
self.conv2 = conv_type(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
- self.bn2 = norm_type(planes)
+ self.bn2 = norm_layer(channels=planes)
self.conv3 = conv_type(planes, planes * self.expansion, kernel_size=1, bias=False)
- self.bn3 = norm_type(planes * self.expansion)
- self.relu = nn.ReLU(inplace=True)
+ self.bn3 = norm_layer(channels=planes * self.expansion)
+ self.act = get_act_layer(name=act)
self.downsample = downsample
self.stride = stride
@@ -147,11 +166,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
out: torch.Tensor = self.conv1(x)
out = self.bn1(out)
- out = self.relu(out)
+ out = self.act(out)
out = self.conv2(out)
out = self.bn2(out)
- out = self.relu(out)
+ out = self.act(out)
out = self.conv3(out)
out = self.bn3(out)
@@ -160,7 +179,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = self.downsample(x)
out += residual
- out = self.relu(out)
+ out = self.act(out)
return out
@@ -190,6 +209,8 @@ class ResNet(nn.Module):
num_classes: number of output (classifications).
feed_forward: whether to add the FC layer for the output, default to `True`.
bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.
+ act: activation type and arguments. Defaults to relu.
+ norm: feature normalization type and arguments. Defaults to batch norm.
"""
@@ -208,6 +229,8 @@ def __init__(
num_classes: int = 400,
feed_forward: bool = True,
bias_downsample: bool = True, # for backwards compatibility (also see PR #5477)
+ act: str | tuple = ("relu", {"inplace": True}),
+ norm: str | tuple = "batch",
) -> None:
super().__init__()
@@ -220,7 +243,6 @@ def __init__(
raise ValueError("Unknown block '%s', use basic or bottleneck" % block)
conv_type: type[nn.Conv1d | nn.Conv2d | nn.Conv3d] = Conv[Conv.CONV, spatial_dims]
- norm_type: type[nn.BatchNorm1d | nn.BatchNorm2d | nn.BatchNorm3d] = Norm[Norm.BATCH, spatial_dims]
pool_type: type[nn.MaxPool1d | nn.MaxPool2d | nn.MaxPool3d] = Pool[Pool.MAX, spatial_dims]
avgp_type: type[nn.AdaptiveAvgPool1d | nn.AdaptiveAvgPool2d | nn.AdaptiveAvgPool3d] = Pool[
Pool.ADAPTIVEAVG, spatial_dims
@@ -244,8 +266,10 @@ def __init__(
padding=tuple(k // 2 for k in conv1_kernel_size),
bias=False,
)
- self.bn1 = norm_type(self.in_planes)
- self.relu = nn.ReLU(inplace=True)
+
+ norm_layer = get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=self.in_planes)
+ self.bn1 = norm_layer
+ self.act = get_act_layer(name=act)
self.maxpool = pool_type(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, block_inplanes[0], layers[0], spatial_dims, shortcut_type)
self.layer2 = self._make_layer(block, block_inplanes[1], layers[1], spatial_dims, shortcut_type, stride=2)
@@ -257,7 +281,7 @@ def __init__(
for m in self.modules():
if isinstance(m, conv_type):
nn.init.kaiming_normal_(torch.as_tensor(m.weight), mode="fan_out", nonlinearity="relu")
- elif isinstance(m, norm_type):
+ elif isinstance(m, type(norm_layer)):
nn.init.constant_(torch.as_tensor(m.weight), 1)
nn.init.constant_(torch.as_tensor(m.bias), 0)
elif isinstance(m, nn.Linear):
@@ -277,9 +301,9 @@ def _make_layer(
spatial_dims: int,
shortcut_type: str,
stride: int = 1,
+ norm: str | tuple = "batch",
) -> nn.Sequential:
conv_type: Callable = Conv[Conv.CONV, spatial_dims]
- norm_type: Callable = Norm[Norm.BATCH, spatial_dims]
downsample: nn.Module | partial | None = None
if stride != 1 or self.in_planes != planes * block.expansion:
@@ -299,25 +323,30 @@ def _make_layer(
stride=stride,
bias=self.bias_downsample,
),
- norm_type(planes * block.expansion),
+ get_norm_layer(name=norm, spatial_dims=spatial_dims, channels=planes * block.expansion),
)
layers = [
block(
- in_planes=self.in_planes, planes=planes, spatial_dims=spatial_dims, stride=stride, downsample=downsample
+ in_planes=self.in_planes,
+ planes=planes,
+ spatial_dims=spatial_dims,
+ stride=stride,
+ downsample=downsample,
+ norm=norm,
)
]
self.in_planes = planes * block.expansion
for _i in range(1, blocks):
- layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims))
+ layers.append(block(self.in_planes, planes, spatial_dims=spatial_dims, norm=norm))
return nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv1(x)
x = self.bn1(x)
- x = self.relu(x)
+ x = self.act(x)
if not self.no_max_pool:
x = self.maxpool(x)
@@ -335,6 +364,120 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x
+class ResNetFeatures(ResNet):
+
+ def __init__(self, model_name: str, pretrained: bool = True, spatial_dims: int = 3, in_channels: int = 1) -> None:
+        """Initialize resnet10 to resnet200 models as a backbone; the backbone can be used as an encoder for
+        segmentation and object detection models.
+
+        Compared with the `ResNet` class, the only difference is the forward function.
+
+ Args:
+ model_name: name of model to initialize, can be from [resnet10, ..., resnet200].
+ pretrained: whether to initialize pretrained MedicalNet weights,
+ only available for spatial_dims=3 and in_channels=1.
+ spatial_dims: number of spatial dimensions of the input image.
+ in_channels: number of input channels for first convolutional layer.
+ """
+ if model_name not in resnet_params:
+ model_name_string = ", ".join(resnet_params.keys())
+ raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ")
+
+ block, layers, shortcut_type, bias_downsample, datasets23 = resnet_params[model_name]
+
+ super().__init__(
+ block=block,
+ layers=layers,
+ block_inplanes=get_inplanes(),
+ spatial_dims=spatial_dims,
+ n_input_channels=in_channels,
+ conv1_t_stride=2,
+ shortcut_type=shortcut_type,
+ feed_forward=False,
+ bias_downsample=bias_downsample,
+ )
+ if pretrained:
+ if spatial_dims == 3 and in_channels == 1:
+ _load_state_dict(self, model_name, datasets23=datasets23)
+ else:
+ raise ValueError("Pretrained resnet models are only available for in_channels=1 and spatial_dims=3.")
+
+ def forward(self, inputs: torch.Tensor):
+ """
+ Args:
+ inputs: input should have spatially N dimensions
+ ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`.
+
+ Returns:
+ a list of torch Tensors.
+ """
+ x = self.conv1(inputs)
+ x = self.bn1(x)
+ x = self.act(x)
+
+ features = []
+ features.append(x)
+
+ if not self.no_max_pool:
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ features.append(x)
+
+ x = self.layer2(x)
+ features.append(x)
+
+ x = self.layer3(x)
+ features.append(x)
+
+ x = self.layer4(x)
+ features.append(x)
+
+ return features
+
+
+class ResNetEncoder(ResNetFeatures, BaseEncoder):
+    """Wrap the original resnet as an encoder for FlexibleUNet."""
+
+ backbone_names = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"]
+
+ @classmethod
+ def get_encoder_parameters(cls) -> list[dict]:
+        """Get the initialization parameters for resnet backbones."""
+ parameter_list = []
+ for backbone_name in cls.backbone_names:
+ parameter_list.append(
+ {"model_name": backbone_name, "pretrained": True, "spatial_dims": 3, "in_channels": 1}
+ )
+ return parameter_list
+
+ @classmethod
+ def num_channels_per_output(cls) -> list[tuple[int, ...]]:
+        """Get the number of output feature map channels for each resnet backbone."""
+ return [
+ (64, 64, 128, 256, 512),
+ (64, 64, 128, 256, 512),
+ (64, 64, 128, 256, 512),
+ (64, 256, 512, 1024, 2048),
+ (64, 256, 512, 1024, 2048),
+ (64, 256, 512, 1024, 2048),
+ (64, 256, 512, 1024, 2048),
+ ]
+
+ @classmethod
+ def num_outputs(cls) -> list[int]:
+        """Get the number of output feature maps for each resnet backbone.
+
+        Since every backbone produces the same 5 output feature maps, the list is `[5] * 7`.
+ """
+ return [5] * 7
+
+ @classmethod
+ def get_encoder_names(cls) -> list[str]:
+ """Get names of resnet backbones."""
+ return cls.backbone_names
+
+
def _resnet(
arch: str,
block: type[ResNetBlock | ResNetBottleneck],
@@ -367,7 +510,7 @@ def _resnet(
# Check model bias_downsample and shortcut_type
bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)
if shortcut_type == kwargs.get("shortcut_type", "B") and (
- bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True
+ bias_downsample == kwargs.get("bias_downsample", True)
):
# Download the MedicalNet pretrained model
model_state_dict = get_pretrained_resnet_medicalnet(
@@ -375,8 +518,7 @@ def _resnet(
)
else:
raise NotImplementedError(
- f"Please set shortcut_type to {shortcut_type} and bias_downsample to"
- f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'}"
+ f"Please set shortcut_type to {shortcut_type} and bias_downsample to {bias_downsample} "
f"when using pretrained MedicalNet resnet{resnet_depth}"
)
else:
@@ -477,7 +619,7 @@ def resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True):
"""
- Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet
+ Download resnet pretrained weights from https://huggingface.co/TencentMedicalNet
Args:
resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200
@@ -533,11 +675,24 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat
def get_medicalnet_pretrained_resnet_args(resnet_depth: int):
"""
Return correct shortcut_type and bias_downsample
- for pretrained MedicalNet weights according to resnet depth
+ for pretrained MedicalNet weights according to resnet depth.
"""
# After testing
# False: 10, 50, 101, 152, 200
# Any: 18, 34
- bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34
+ bias_downsample = resnet_depth in (18, 34)
shortcut_type = "A" if resnet_depth in [18, 34] else "B"
return bias_downsample, shortcut_type
+
+
+def _load_state_dict(model: nn.Module, model_name: str, datasets23: bool = True) -> None:
+ search_res = re.search(r"resnet(\d+)", model_name)
+ if search_res:
+ resnet_depth = int(search_res.group(1))
+ datasets23 = model_name.endswith("_23datasets")
+ else:
+ raise ValueError("model_name argument should contain resnet depth. Example: resnet18 or resnet18_23datasets.")
+
+ model_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device="cpu", datasets23=datasets23)
+ model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()}
+ model.load_state_dict(model_state_dict)
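
# Editor's sketch (not part of the patch): minimal usage of the new ResNetFeatures
# backbone added above. It assumes the class is importable from
# monai.networks.nets.resnet once this change is applied, uses pretrained=False so
# no MedicalNet weights are downloaded, and the input shape is illustrative only.
import torch

from monai.networks.nets.resnet import ResNetFeatures

# resnet50 backbone returning the five intermediate feature maps
# (stem + layer1..layer4) instead of a classification vector
backbone = ResNetFeatures("resnet50", pretrained=False, spatial_dims=3, in_channels=1)
x = torch.randn(1, 1, 64, 64, 64)
features = backbone(x)
# per-level channels for resnet50 should match ResNetEncoder.num_channels_per_output():
# (64, 256, 512, 1024, 2048)
print([f.shape[1] for f in features])
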
diff --git a/monai/networks/nets/segresnet_ds.py b/monai/networks/nets/segresnet_ds.py
index 6430f5fdc9..098e490511 100644
--- a/monai/networks/nets/segresnet_ds.py
+++ b/monai/networks/nets/segresnet_ds.py
@@ -11,6 +11,7 @@
from __future__ import annotations
+import copy
from collections.abc import Callable
from typing import Union
@@ -23,7 +24,7 @@
from monai.networks.layers.utils import get_act_layer, get_norm_layer
from monai.utils import UpsampleMode, has_option
-__all__ = ["SegResNetDS"]
+__all__ = ["SegResNetDS", "SegResNetDS2"]
def scales_for_resolution(resolution: tuple | list, n_stages: int | None = None):
@@ -425,3 +426,130 @@ def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tens
def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]:
return self._forward(x)
+
+
+class SegResNetDS2(SegResNetDS):
+ """
+    SegResNetDS2 adds an additional decoder branch to SegResNetDS and is the image encoder of VISTA3D.
+
+ Args:
+ spatial_dims: spatial dimension of the input data. Defaults to 3.
+ init_filters: number of output channels for initial convolution layer. Defaults to 32.
+ in_channels: number of input channels for the network. Defaults to 1.
+ out_channels: number of output channels for the network. Defaults to 2.
+ act: activation type and arguments. Defaults to ``RELU``.
+ norm: feature normalization type and arguments. Defaults to ``BATCH``.
+ blocks_down: number of downsample blocks in each layer. Defaults to ``[1,2,2,4]``.
+ blocks_up: number of upsample blocks (optional).
+ dsdepth: number of levels for deep supervision. This will be the length of the list of outputs at each scale level.
+            At dsdepth==1, only a single output is returned.
+ preprocess: optional callable function to apply before the model's forward pass
+ resolution: optional input image resolution. When provided, the network will first use non-isotropic kernels to bring
+ image spacing into an approximately isotropic space.
+ Otherwise, by default, the kernel size and downsampling is always isotropic.
+
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int = 3,
+ init_filters: int = 32,
+ in_channels: int = 1,
+ out_channels: int = 2,
+ act: tuple | str = "relu",
+ norm: tuple | str = "batch",
+ blocks_down: tuple = (1, 2, 2, 4),
+ blocks_up: tuple | None = None,
+ dsdepth: int = 1,
+ preprocess: nn.Module | Callable | None = None,
+ upsample_mode: UpsampleMode | str = "deconv",
+ resolution: tuple | None = None,
+ ):
+ super().__init__(
+ spatial_dims=spatial_dims,
+ init_filters=init_filters,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ act=act,
+ norm=norm,
+ blocks_down=blocks_down,
+ blocks_up=blocks_up,
+ dsdepth=dsdepth,
+ preprocess=preprocess,
+ upsample_mode=upsample_mode,
+ resolution=resolution,
+ )
+
+ self.up_layers_auto = nn.ModuleList([copy.deepcopy(layer) for layer in self.up_layers])
+
+ def forward( # type: ignore
+ self, x: torch.Tensor, with_point: bool = True, with_label: bool = True
+ ) -> tuple[Union[None, torch.Tensor, list[torch.Tensor]], Union[None, torch.Tensor, list[torch.Tensor]]]:
+ """
+ Args:
+ x: input tensor.
+ with_point: if true, return the point branch output.
+ with_label: if true, return the label branch output.
+ """
+ if self.preprocess is not None:
+ x = self.preprocess(x)
+
+ if not self.is_valid_shape(x):
+ raise ValueError(f"Input spatial dims {x.shape} must be divisible by {self.shape_factor()}")
+
+ x_down = self.encoder(x)
+
+ x_down.reverse()
+ x = x_down.pop(0)
+
+ if len(x_down) == 0:
+ x_down = [torch.zeros(1, device=x.device, dtype=x.dtype)]
+
+ outputs: list[torch.Tensor] = []
+ outputs_auto: list[torch.Tensor] = []
+ x_ = x
+ if with_point:
+ if with_label:
+ x_ = x.clone()
+ i = 0
+ for level in self.up_layers:
+ x = level["upsample"](x)
+ x = x + x_down[i]
+ x = level["blocks"](x)
+
+ if len(self.up_layers) - i <= self.dsdepth:
+ outputs.append(level["head"](x))
+ i = i + 1
+
+ outputs.reverse()
+ x = x_
+ if with_label:
+ i = 0
+ for level in self.up_layers_auto:
+ x = level["upsample"](x)
+ x = x + x_down[i]
+ x = level["blocks"](x)
+
+ if len(self.up_layers) - i <= self.dsdepth:
+ outputs_auto.append(level["head"](x))
+ i = i + 1
+
+ outputs_auto.reverse()
+
+ return outputs[0] if len(outputs) == 1 else outputs, outputs_auto[0] if len(outputs_auto) == 1 else outputs_auto
+
+ def set_auto_grad(self, auto_freeze=False, point_freeze=False):
+ """
+ Args:
+ auto_freeze: if true, freeze the image encoder and the auto-branch.
+ point_freeze: if true, freeze the image encoder and the point-branch.
+ """
+ for param in self.encoder.parameters():
+ param.requires_grad = (not auto_freeze) and (not point_freeze)
+
+ for param in self.up_layers_auto.parameters():
+ param.requires_grad = not auto_freeze
+
+ for param in self.up_layers.parameters():
+ param.requires_grad = not point_freeze
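
# Editor's sketch (not part of the patch): the two-branch forward of the new
# SegResNetDS2 and its set_auto_grad helper, assuming default constructor
# arguments; the input size is illustrative and must be divisible by the
# network's shape factor.
import torch

from monai.networks.nets.segresnet_ds import SegResNetDS2

net = SegResNetDS2(spatial_dims=3, init_filters=32, in_channels=1, out_channels=2, dsdepth=1)
x = torch.randn(1, 1, 64, 64, 64)

# With dsdepth=1 each branch returns a single tensor; with dsdepth>1 each branch
# returns a list of deep-supervision outputs instead.
point_out, auto_out = net(x, with_point=True, with_label=True)

# Freeze the shared encoder and the auto (label) branch so that only the point
# branch remains trainable.
net.set_auto_grad(auto_freeze=True, point_freeze=False)
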
diff --git a/monai/networks/nets/spade_autoencoderkl.py b/monai/networks/nets/spade_autoencoderkl.py
new file mode 100644
index 0000000000..cc8909194a
--- /dev/null
+++ b/monai/networks/nets/spade_autoencoderkl.py
@@ -0,0 +1,502 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from monai.networks.blocks import Convolution, SpatialAttentionBlock, Upsample
+from monai.networks.blocks.spade_norm import SPADE
+from monai.networks.nets.autoencoderkl import Encoder
+from monai.utils import ensure_tuple_rep
+
+__all__ = ["SPADEAutoencoderKL"]
+
+
+class SPADEResBlock(nn.Module):
+ """
+    Residual block consisting of a cascade of two normalisation + activation + convolution stages, and a
+    residual connection between input and output.
+    Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE)
+
+ Args:
+ spatial_dims: number of spatial dimensions (1D, 2D, 3D).
+ in_channels: input channels to the layer.
+ norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of
+ channels is divisible by this number.
+ norm_eps: epsilon for the normalisation.
+ out_channels: number of output channels.
+ label_nc: number of semantic channels for SPADE normalisation
+ spade_intermediate_channels: number of intermediate channels for SPADE block layer
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ norm_num_groups: int,
+ norm_eps: float,
+ out_channels: int,
+ label_nc: int,
+ spade_intermediate_channels: int,
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = in_channels if out_channels is None else out_channels
+ self.norm1 = SPADE(
+ label_nc=label_nc,
+ norm_nc=in_channels,
+ norm="GROUP",
+ norm_params={"num_groups": norm_num_groups, "affine": False, "eps": norm_eps},
+ hidden_channels=spade_intermediate_channels,
+ kernel_size=3,
+ spatial_dims=spatial_dims,
+ )
+ self.conv1 = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ self.norm2 = SPADE(
+ label_nc=label_nc,
+ norm_nc=out_channels,
+ norm="GROUP",
+ norm_params={"num_groups": norm_num_groups, "affine": False, "eps": norm_eps},
+ hidden_channels=spade_intermediate_channels,
+ kernel_size=3,
+ spatial_dims=spatial_dims,
+ )
+ self.conv2 = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+
+ self.nin_shortcut: nn.Module
+ if self.in_channels != self.out_channels:
+ self.nin_shortcut = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+ else:
+ self.nin_shortcut = nn.Identity()
+
+ def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+ h = x
+ h = self.norm1(h, seg)
+ h = F.silu(h)
+ h = self.conv1(h)
+ h = self.norm2(h, seg)
+ h = F.silu(h)
+ h = self.conv2(h)
+
+ x = self.nin_shortcut(x)
+
+ return x + h
+
+
+class SPADEDecoder(nn.Module):
+ """
+ Convolutional cascade upsampling from a spatial latent space into an image space.
+    Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE)
+
+ Args:
+ spatial_dims: number of spatial dimensions (1D, 2D, 3D).
+ channels: sequence of block output channels.
+ in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
+ out_channels: number of output channels.
+ num_res_blocks: number of residual blocks (see ResBlock) per level.
+ norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
+ norm_eps: epsilon for the normalization.
+ attention_levels: indicate which level from channels contain an attention block.
+ label_nc: number of semantic channels for SPADE normalisation.
+ with_nonlocal_attn: if True use non-local attention block.
+ spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ channels: Sequence[int],
+ in_channels: int,
+ out_channels: int,
+ num_res_blocks: Sequence[int],
+ norm_num_groups: int,
+ norm_eps: float,
+ attention_levels: Sequence[bool],
+ label_nc: int,
+ with_nonlocal_attn: bool = True,
+ spade_intermediate_channels: int = 128,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ self.spatial_dims = spatial_dims
+ self.channels = channels
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.norm_num_groups = norm_num_groups
+ self.norm_eps = norm_eps
+ self.attention_levels = attention_levels
+ self.label_nc = label_nc
+
+ reversed_block_out_channels = list(reversed(channels))
+
+ blocks: list[nn.Module] = []
+
+ # Initial convolution
+ blocks.append(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=reversed_block_out_channels[0],
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ )
+
+ # Non-local attention block
+ if with_nonlocal_attn is True:
+ blocks.append(
+ SPADEResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=reversed_block_out_channels[0],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=reversed_block_out_channels[0],
+ label_nc=label_nc,
+ spade_intermediate_channels=spade_intermediate_channels,
+ )
+ )
+ blocks.append(
+ SpatialAttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=reversed_block_out_channels[0],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+ blocks.append(
+ SPADEResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=reversed_block_out_channels[0],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=reversed_block_out_channels[0],
+ label_nc=label_nc,
+ spade_intermediate_channels=spade_intermediate_channels,
+ )
+ )
+
+ reversed_attention_levels = list(reversed(attention_levels))
+ reversed_num_res_blocks = list(reversed(num_res_blocks))
+ block_out_ch = reversed_block_out_channels[0]
+ for i in range(len(reversed_block_out_channels)):
+ block_in_ch = block_out_ch
+ block_out_ch = reversed_block_out_channels[i]
+ is_final_block = i == len(channels) - 1
+
+ for _ in range(reversed_num_res_blocks[i]):
+ blocks.append(
+ SPADEResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=block_in_ch,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ out_channels=block_out_ch,
+ label_nc=label_nc,
+ spade_intermediate_channels=spade_intermediate_channels,
+ )
+ )
+ block_in_ch = block_out_ch
+
+ if reversed_attention_levels[i]:
+ blocks.append(
+ SpatialAttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=block_in_ch,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+
+ if not is_final_block:
+ post_conv = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=block_in_ch,
+ out_channels=block_in_ch,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ blocks.append(
+ Upsample(
+ spatial_dims=spatial_dims,
+ mode="nontrainable",
+ in_channels=block_in_ch,
+ out_channels=block_in_ch,
+ interp_mode="nearest",
+ scale_factor=2.0,
+ post_conv=post_conv,
+ align_corners=None,
+ )
+ )
+
+ blocks.append(nn.GroupNorm(num_groups=norm_num_groups, num_channels=block_in_ch, eps=norm_eps, affine=True))
+ blocks.append(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=block_in_ch,
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ )
+
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+ for block in self.blocks:
+ if isinstance(block, SPADEResBlock):
+ x = block(x, seg)
+ else:
+ x = block(x)
+ return x
+
+
+class SPADEAutoencoderKL(nn.Module):
+ """
+ Autoencoder model with KL-regularized latent space based on
+ Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
+ and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162
+    Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE)
+
+ Args:
+ spatial_dims: number of spatial dimensions (1D, 2D, 3D).
+ label_nc: number of semantic channels for SPADE normalisation.
+ in_channels: number of input channels.
+ out_channels: number of output channels.
+ num_res_blocks: number of residual blocks (see ResBlock) per level.
+ channels: sequence of block output channels.
+ attention_levels: sequence of levels to add attention.
+ latent_channels: latent embedding dimension.
+ norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
+ norm_eps: epsilon for the normalization.
+ with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
+ with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
+ spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ label_nc: int,
+ in_channels: int = 1,
+ out_channels: int = 1,
+ num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+ channels: Sequence[int] = (32, 64, 64, 64),
+ attention_levels: Sequence[bool] = (False, False, True, True),
+ latent_channels: int = 3,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ with_encoder_nonlocal_attn: bool = True,
+ with_decoder_nonlocal_attn: bool = True,
+ spade_intermediate_channels: int = 128,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+
+ # All number of channels should be multiple of num_groups
+ if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
+ raise ValueError("SPADEAutoencoderKL expects all channels being multiple of norm_num_groups")
+
+ if len(channels) != len(attention_levels):
+ raise ValueError("SPADEAutoencoderKL expects channels being same size of attention_levels")
+
+ if isinstance(num_res_blocks, int):
+ num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))
+
+ if len(num_res_blocks) != len(channels):
+ raise ValueError(
+ "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
+ "`channels`."
+ )
+
+ self.encoder = Encoder(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ channels=channels,
+ out_channels=latent_channels,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ attention_levels=attention_levels,
+ with_nonlocal_attn=with_encoder_nonlocal_attn,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ self.decoder = SPADEDecoder(
+ spatial_dims=spatial_dims,
+ channels=channels,
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ attention_levels=attention_levels,
+ label_nc=label_nc,
+ with_nonlocal_attn=with_decoder_nonlocal_attn,
+ spade_intermediate_channels=spade_intermediate_channels,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ self.quant_conv_mu = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=latent_channels,
+ out_channels=latent_channels,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+ self.quant_conv_log_sigma = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=latent_channels,
+ out_channels=latent_channels,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+ self.post_quant_conv = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=latent_channels,
+ out_channels=latent_channels,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+ self.latent_channels = latent_channels
+
+ def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations.
+
+ Args:
+ x: BxCx[SPATIAL DIMS] tensor
+
+ """
+ h = self.encoder(x)
+ z_mu = self.quant_conv_mu(h)
+ z_log_var = self.quant_conv_log_sigma(h)
+ z_log_var = torch.clamp(z_log_var, -30.0, 20.0)
+ z_sigma = torch.exp(z_log_var / 2)
+
+ return z_mu, z_sigma
+
+ def sampling(self, z_mu: torch.Tensor, z_sigma: torch.Tensor) -> torch.Tensor:
+ """
+        From the mean and sigma representations obtained by encoding an image into the latent space, draws a
+        sample by scaling standard Gaussian noise by sigma (the standard deviation) and adding the mean
+        (the reparameterization trick).
+
+ Args:
+ z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image
+            z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] standard deviation vector obtained by the encoder when you encode an image
+
+ Returns:
+ sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE]
+ """
+ eps = torch.randn_like(z_sigma)
+ z_vae = z_mu + eps * z_sigma
+ return z_vae
+
+ def reconstruct(self, x: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+ """
+ Encodes and decodes an input image.
+
+ Args:
+ x: BxCx[SPATIAL DIMENSIONS] tensor.
+ seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.
+ Returns:
+ reconstructed image, of the same shape as input
+ """
+ z_mu, _ = self.encode(x)
+ reconstruction = self.decode(z_mu, seg)
+ return reconstruction
+
+ def decode(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+ """
+        Forwards a latent space sample through the decoder.
+
+ Args:
+ z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE]
+ seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.
+ Returns:
+ decoded image tensor
+ """
+ z = self.post_quant_conv(z)
+ dec: torch.Tensor = self.decoder(z, seg)
+ return dec
+
+ def forward(self, x: torch.Tensor, seg: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ z_mu, z_sigma = self.encode(x)
+ z = self.sampling(z_mu, z_sigma)
+ reconstruction = self.decode(z, seg)
+ return reconstruction, z_mu, z_sigma
+
+ def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
+ z_mu, z_sigma = self.encode(x)
+ z = self.sampling(z_mu, z_sigma)
+ return z
+
+ def decode_stage_2_outputs(self, z: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+ image = self.decode(z, seg)
+ return image
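
# Editor's sketch (not part of the patch): minimal use of SPADEAutoencoderKL.
# The decoder is conditioned on a semantic map, so forward/decode/reconstruct all
# take a seg tensor in addition to the image or latent. Channel and level choices
# below are illustrative assumptions.
import torch

from monai.networks.nets.spade_autoencoderkl import SPADEAutoencoderKL

net = SPADEAutoencoderKL(
    spatial_dims=2,
    label_nc=3,  # number of semantic channels consumed by the SPADE norm layers
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 64),
    attention_levels=(False, False, True),
    num_res_blocks=(1, 1, 1),
    latent_channels=3,
)
x = torch.randn(1, 1, 64, 64)  # image
seg = torch.randn(1, 3, 64, 64)  # Bx[label_nc]xHxW semantic map
reconstruction, z_mu, z_sigma = net(x, seg)
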
diff --git a/monai/networks/nets/spade_diffusion_model_unet.py b/monai/networks/nets/spade_diffusion_model_unet.py
new file mode 100644
index 0000000000..a9609b1d39
--- /dev/null
+++ b/monai/networks/nets/spade_diffusion_model_unet.py
@@ -0,0 +1,971 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/huggingface/diffusers
+# which has the following license:
+# https://github.com/huggingface/diffusers/blob/main/LICENSE
+#
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+from torch import nn
+
+from monai.networks.blocks import Convolution, SpatialAttentionBlock
+from monai.networks.blocks.spade_norm import SPADE
+from monai.networks.nets.diffusion_model_unet import (
+ DiffusionUnetDownsample,
+ DiffusionUNetResnetBlock,
+ SpatialTransformer,
+ WrappedUpsample,
+ get_down_block,
+ get_mid_block,
+ get_timestep_embedding,
+ zero_module,
+)
+from monai.utils import ensure_tuple_rep
+
+__all__ = ["SPADEDiffusionModelUNet"]
+
+
+class SPADEDiffResBlock(nn.Module):
+ """
+ Residual block with timestep conditioning and SPADE norm.
+    Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE)
+
+ Args:
+ spatial_dims: The number of spatial dimensions.
+ in_channels: number of input channels.
+ temb_channels: number of timestep embedding channels.
+ label_nc: number of semantic channels for SPADE normalisation.
+ out_channels: number of output channels.
+ up: if True, performs upsampling.
+ down: if True, performs downsampling.
+ norm_num_groups: number of groups for the group normalization.
+ norm_eps: epsilon for the group normalization.
+ spade_intermediate_channels: number of intermediate channels for SPADE block layer
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ temb_channels: int,
+ label_nc: int,
+ out_channels: int | None = None,
+ up: bool = False,
+ down: bool = False,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ spade_intermediate_channels: int = 128,
+ ) -> None:
+ super().__init__()
+ self.spatial_dims = spatial_dims
+ self.channels = in_channels
+ self.emb_channels = temb_channels
+ self.out_channels = out_channels or in_channels
+ self.up = up
+ self.down = down
+
+ self.norm1 = SPADE(
+ label_nc=label_nc,
+ norm_nc=in_channels,
+ norm="GROUP",
+ norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True},
+ hidden_channels=spade_intermediate_channels,
+ kernel_size=3,
+ spatial_dims=spatial_dims,
+ )
+
+ self.nonlinearity = nn.SiLU()
+ self.conv1 = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+
+ self.upsample = self.downsample = None
+ if self.up:
+ self.upsample = WrappedUpsample(
+ spatial_dims=spatial_dims,
+ mode="nontrainable",
+ in_channels=in_channels,
+ out_channels=in_channels,
+ interp_mode="nearest",
+ scale_factor=2.0,
+ align_corners=None,
+ )
+ elif down:
+ self.downsample = DiffusionUnetDownsample(spatial_dims, in_channels, use_conv=False)
+
+ self.time_emb_proj = nn.Linear(temb_channels, self.out_channels)
+
+ self.norm2 = SPADE(
+ label_nc=label_nc,
+ norm_nc=self.out_channels,
+ norm="GROUP",
+ norm_params={"num_groups": norm_num_groups, "eps": norm_eps, "affine": True},
+ hidden_channels=spade_intermediate_channels,
+ kernel_size=3,
+ spatial_dims=spatial_dims,
+ )
+ self.conv2 = zero_module(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ )
+ self.skip_connection: nn.Module
+
+ if self.out_channels == in_channels:
+ self.skip_connection = nn.Identity()
+ else:
+ self.skip_connection = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=1,
+ padding=0,
+ conv_only=True,
+ )
+
+ def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torch.Tensor:
+ h = x
+ h = self.norm1(h, seg)
+ h = self.nonlinearity(h)
+
+ if self.upsample is not None:
+ x = self.upsample(x)
+ h = self.upsample(h)
+ elif self.downsample is not None:
+ x = self.downsample(x)
+ h = self.downsample(h)
+
+ h = self.conv1(h)
+
+ if self.spatial_dims == 2:
+ temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None]
+ else:
+ temb = self.time_emb_proj(self.nonlinearity(emb))[:, :, None, None, None]
+ h = h + temb
+
+ h = self.norm2(h, seg)
+ h = self.nonlinearity(h)
+ h = self.conv2(h)
+ output: torch.Tensor = self.skip_connection(x) + h
+ return output
+
+
+class SPADEUpBlock(nn.Module):
+ """
+    Unet's up block containing resnet and upsampler blocks.
+    Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE)
+
+ Args:
+ spatial_dims: The number of spatial dimensions.
+ in_channels: number of input channels.
+ prev_output_channel: number of channels from residual connection.
+ out_channels: number of output channels.
+ temb_channels: number of timestep embedding channels.
+ label_nc: number of semantic channels for SPADE normalisation.
+ num_res_blocks: number of residual blocks.
+ norm_num_groups: number of groups for the group normalization.
+ norm_eps: epsilon for the group normalization.
+        add_upsample: if True add upsample block.
+ resblock_updown: if True use residual blocks for upsampling.
+ spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ label_nc: int,
+ num_res_blocks: int = 1,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ add_upsample: bool = True,
+ resblock_updown: bool = False,
+ spade_intermediate_channels: int = 128,
+ ) -> None:
+ super().__init__()
+ self.resblock_updown = resblock_updown
+ resnets = []
+
+ for i in range(num_res_blocks):
+ res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ SPADEDiffResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ label_nc=label_nc,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ spade_intermediate_channels=spade_intermediate_channels,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ self.upsampler: nn.Module | None
+ if add_upsample:
+ if resblock_updown:
+ self.upsampler = DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ up=True,
+ )
+ else:
+ post_conv = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ self.upsampler = WrappedUpsample(
+ spatial_dims=spatial_dims,
+ mode="nontrainable",
+ in_channels=out_channels,
+ out_channels=out_channels,
+ interp_mode="nearest",
+ scale_factor=2.0,
+ post_conv=post_conv,
+ align_corners=None,
+ )
+ else:
+ self.upsampler = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ res_hidden_states_list: list[torch.Tensor],
+ temb: torch.Tensor,
+ seg: torch.Tensor,
+ context: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ del context
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_list[-1]
+ res_hidden_states_list = res_hidden_states_list[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb, seg)
+
+ if self.upsampler is not None:
+ hidden_states = self.upsampler(hidden_states, temb)
+
+ return hidden_states
+
+
+class SPADEAttnUpBlock(nn.Module):
+ """
+    Unet's up block containing resnet, upsampler, and self-attention blocks.
+    Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE)
+
+ Args:
+ spatial_dims: The number of spatial dimensions.
+ in_channels: number of input channels.
+ prev_output_channel: number of channels from residual connection.
+ out_channels: number of output channels.
+ temb_channels: number of timestep embedding channels.
+ label_nc: number of semantic channels for SPADE normalisation
+ num_res_blocks: number of residual blocks.
+ norm_num_groups: number of groups for the group normalization.
+ norm_eps: epsilon for the group normalization.
+        add_upsample: if True add upsample block.
+ resblock_updown: if True use residual blocks for upsampling.
+ num_head_channels: number of channels in each attention head.
+ spade_intermediate_channels: number of intermediate channels for SPADE block layer
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ label_nc: int,
+ num_res_blocks: int = 1,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ add_upsample: bool = True,
+ resblock_updown: bool = False,
+ num_head_channels: int = 1,
+ spade_intermediate_channels: int = 128,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ self.resblock_updown = resblock_updown
+ resnets = []
+ attentions = []
+
+ for i in range(num_res_blocks):
+ res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ SPADEDiffResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ label_nc=label_nc,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ spade_intermediate_channels=spade_intermediate_channels,
+ )
+ )
+ attentions.append(
+ SpatialAttentionBlock(
+ spatial_dims=spatial_dims,
+ num_channels=out_channels,
+ num_head_channels=num_head_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ self.upsampler: nn.Module | None
+ if add_upsample:
+ if resblock_updown:
+ self.upsampler = DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ up=True,
+ )
+ else:
+ post_conv = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ self.upsampler = WrappedUpsample(
+ spatial_dims=spatial_dims,
+ mode="nontrainable",
+ in_channels=out_channels,
+ out_channels=out_channels,
+ interp_mode="nearest",
+ scale_factor=2.0,
+ post_conv=post_conv,
+ align_corners=None,
+ )
+ else:
+ self.upsampler = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ res_hidden_states_list: list[torch.Tensor],
+ temb: torch.Tensor,
+ seg: torch.Tensor,
+ context: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ del context
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_list[-1]
+ res_hidden_states_list = res_hidden_states_list[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb, seg)
+ hidden_states = attn(hidden_states).contiguous()
+
+ if self.upsampler is not None:
+ hidden_states = self.upsampler(hidden_states, temb)
+
+ return hidden_states
+
+
+class SPADECrossAttnUpBlock(nn.Module):
+ """
+    Unet's up block containing resnet, upsampler, and cross-attention blocks.
+    Enables SPADE normalisation for semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE)
+
+ Args:
+ spatial_dims: The number of spatial dimensions.
+ in_channels: number of input channels.
+ prev_output_channel: number of channels from residual connection.
+ out_channels: number of output channels.
+ temb_channels: number of timestep embedding channels.
+ label_nc: number of semantic channels for SPADE normalisation.
+ num_res_blocks: number of residual blocks.
+ norm_num_groups: number of groups for the group normalization.
+ norm_eps: epsilon for the group normalization.
+        add_upsample: if True add upsample block.
+ resblock_updown: if True use residual blocks for upsampling.
+ num_head_channels: number of channels in each attention head.
+ transformer_num_layers: number of layers of Transformer blocks to use.
+ cross_attention_dim: number of context dimensions to use.
+ upcast_attention: if True, upcast attention operations to full precision.
+ spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism.
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ label_nc: int,
+ num_res_blocks: int = 1,
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ add_upsample: bool = True,
+ resblock_updown: bool = False,
+ num_head_channels: int = 1,
+ transformer_num_layers: int = 1,
+ cross_attention_dim: int | None = None,
+ upcast_attention: bool = False,
+ spade_intermediate_channels: int = 128,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ self.resblock_updown = resblock_updown
+ resnets = []
+ attentions = []
+
+ for i in range(num_res_blocks):
+ res_skip_channels = in_channels if (i == num_res_blocks - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ SPADEDiffResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ label_nc=label_nc,
+ spade_intermediate_channels=spade_intermediate_channels,
+ )
+ )
+ attentions.append(
+ SpatialTransformer(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ num_attention_heads=out_channels // num_head_channels,
+ num_head_channels=num_head_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.upsampler: nn.Module | None
+ if add_upsample:
+ if resblock_updown:
+ self.upsampler = DiffusionUNetResnetBlock(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ up=True,
+ )
+ else:
+ post_conv = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=out_channels,
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ self.upsampler = WrappedUpsample(
+ spatial_dims=spatial_dims,
+ mode="nontrainable",
+ in_channels=out_channels,
+ out_channels=out_channels,
+ interp_mode="nearest",
+ scale_factor=2.0,
+ post_conv=post_conv,
+ align_corners=None,
+ )
+ else:
+ self.upsampler = None
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ res_hidden_states_list: list[torch.Tensor],
+ temb: torch.Tensor,
+ seg: torch.Tensor | None = None,
+ context: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_list[-1]
+ res_hidden_states_list = res_hidden_states_list[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+ hidden_states = resnet(hidden_states, temb, seg)
+ hidden_states = attn(hidden_states, context=context).contiguous()
+
+ if self.upsampler is not None:
+ hidden_states = self.upsampler(hidden_states, temb)
+
+ return hidden_states
+
+
+def get_spade_up_block(
+ spatial_dims: int,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ num_res_blocks: int,
+ norm_num_groups: int,
+ norm_eps: float,
+ add_upsample: bool,
+ resblock_updown: bool,
+ with_attn: bool,
+ with_cross_attn: bool,
+ num_head_channels: int,
+ transformer_num_layers: int,
+ label_nc: int,
+ cross_attention_dim: int | None,
+ upcast_attention: bool = False,
+ spade_intermediate_channels: int = 128,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+) -> nn.Module:
+ if with_attn:
+ return SPADEAttnUpBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ prev_output_channel=prev_output_channel,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ label_nc=label_nc,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_upsample=add_upsample,
+ resblock_updown=resblock_updown,
+ num_head_channels=num_head_channels,
+ spade_intermediate_channels=spade_intermediate_channels,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ elif with_cross_attn:
+ return SPADECrossAttnUpBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ prev_output_channel=prev_output_channel,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ label_nc=label_nc,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_upsample=add_upsample,
+ resblock_updown=resblock_updown,
+ num_head_channels=num_head_channels,
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ spade_intermediate_channels=spade_intermediate_channels,
+ use_flash_attention=use_flash_attention,
+ )
+ else:
+ return SPADEUpBlock(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ prev_output_channel=prev_output_channel,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ label_nc=label_nc,
+ num_res_blocks=num_res_blocks,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_upsample=add_upsample,
+ resblock_updown=resblock_updown,
+ spade_intermediate_channels=spade_intermediate_channels,
+ )
+
+
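# Editor's sketch (not part of the patch): get_spade_up_block above picks
# SPADEAttnUpBlock for levels with self-attention, SPADECrossAttnUpBlock when
# cross-attention conditioning is enabled, and SPADEUpBlock otherwise. A minimal
# construction of the SPADEDiffusionModelUNet defined below, with illustrative
# argument values, looks like this:
#
#     from monai.networks.nets.spade_diffusion_model_unet import SPADEDiffusionModelUNet
#
#     unet = SPADEDiffusionModelUNet(
#         spatial_dims=2,
#         in_channels=1,
#         out_channels=1,
#         label_nc=3,
#         channels=(32, 64, 64),
#         attention_levels=(False, False, True),
#         num_res_blocks=1,
#         num_head_channels=32,
#     )
#
# The forward pass (continued beyond this excerpt) additionally consumes the
# diffusion timestep tensor and the seg semantic map used by the SPADE layers.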
+class SPADEDiffusionModelUNet(nn.Module):
+ """
+ UNet network with timestep embedding and attention mechanisms for conditioning, with added SPADE normalization for
+    semantic conditioning (Park et al. (2019): https://github.com/NVlabs/SPADE). An example tutorial can be found at
+ https://github.com/Project-MONAI/GenerativeModels/tree/main/tutorials/generative/2d_spade_ldm
+
+ Args:
+ spatial_dims: number of spatial dimensions.
+ in_channels: number of input channels.
+ out_channels: number of output channels.
+ label_nc: number of semantic channels for SPADE normalisation.
+ num_res_blocks: number of residual blocks (see ResnetBlock) per level.
+ channels: tuple of block output channels.
+ attention_levels: list of levels to add attention.
+ norm_num_groups: number of groups for the normalization.
+ norm_eps: epsilon for the normalization.
+ resblock_updown: if True use residual blocks for up/downsampling.
+ num_head_channels: number of channels in each attention head.
+ with_conditioning: if True add spatial transformers to perform conditioning.
+ transformer_num_layers: number of layers of Transformer blocks to use.
+ cross_attention_dim: number of context dimensions to use.
+ num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
+ classes.
+ upcast_attention: if True, upcast attention operations to full precision.
+ spade_intermediate_channels: number of intermediate channels for SPADE block layer.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ label_nc: int,
+ num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
+ channels: Sequence[int] = (32, 64, 64, 64),
+ attention_levels: Sequence[bool] = (False, False, True, True),
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-6,
+ resblock_updown: bool = False,
+ num_head_channels: int | Sequence[int] = 8,
+ with_conditioning: bool = False,
+ transformer_num_layers: int = 1,
+ cross_attention_dim: int | None = None,
+ num_class_embeds: int | None = None,
+ upcast_attention: bool = False,
+ spade_intermediate_channels: int = 128,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ if with_conditioning is True and cross_attention_dim is None:
+ raise ValueError(
+ "SPADEDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
+ "when using with_conditioning."
+ )
+ if cross_attention_dim is not None and with_conditioning is False:
+ raise ValueError(
+ "SPADEDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim."
+ )
+
+ # All number of channels should be multiple of num_groups
+ if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
+ raise ValueError("SPADEDiffusionModelUNet expects all num_channels being multiple of norm_num_groups")
+
+ if len(channels) != len(attention_levels):
+ raise ValueError("SPADEDiffusionModelUNet expects num_channels being same size of attention_levels")
+
+ if isinstance(num_head_channels, int):
+ num_head_channels = ensure_tuple_rep(num_head_channels, len(attention_levels))
+
+ if len(num_head_channels) != len(attention_levels):
+ raise ValueError(
+ "num_head_channels should have the same length as attention_levels. For the i levels without attention,"
+ " i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored."
+ )
+
+ if isinstance(num_res_blocks, int):
+ num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))
+
+ if len(num_res_blocks) != len(channels):
+ raise ValueError(
+ "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
+ "`num_channels`."
+ )
+
+ self.in_channels = in_channels
+ self.block_out_channels = channels
+ self.out_channels = out_channels
+ self.num_res_blocks = num_res_blocks
+ self.attention_levels = attention_levels
+ self.num_head_channels = num_head_channels
+ self.with_conditioning = with_conditioning
+ self.label_nc = label_nc
+
+ # input
+ self.conv_in = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=channels[0],
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+
+ # time
+ time_embed_dim = channels[0] * 4
+ self.time_embed = nn.Sequential(
+ nn.Linear(channels[0], time_embed_dim), nn.SiLU(), nn.Linear(time_embed_dim, time_embed_dim)
+ )
+
+ # class embedding
+ self.num_class_embeds = num_class_embeds
+ if num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+
+ # down
+ self.down_blocks = nn.ModuleList([])
+ output_channel = channels[0]
+ for i in range(len(channels)):
+ input_channel = output_channel
+ output_channel = channels[i]
+ is_final_block = i == len(channels) - 1
+
+ down_block = get_down_block(
+ spatial_dims=spatial_dims,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ num_res_blocks=num_res_blocks[i],
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_downsample=not is_final_block,
+ resblock_updown=resblock_updown,
+ with_attn=(attention_levels[i] and not with_conditioning),
+ with_cross_attn=(attention_levels[i] and with_conditioning),
+ num_head_channels=num_head_channels[i],
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.middle_block = get_mid_block(
+ spatial_dims=spatial_dims,
+ in_channels=channels[-1],
+ temb_channels=time_embed_dim,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ with_conditioning=with_conditioning,
+ num_head_channels=num_head_channels[-1],
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+
+ # up
+ self.up_blocks = nn.ModuleList([])
+ reversed_block_out_channels = list(reversed(channels))
+ reversed_num_res_blocks = list(reversed(num_res_blocks))
+ reversed_attention_levels = list(reversed(attention_levels))
+ reversed_num_head_channels = list(reversed(num_head_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i in range(len(reversed_block_out_channels)):
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(channels) - 1)]
+
+ is_final_block = i == len(channels) - 1
+
+ up_block = get_spade_up_block(
+ spatial_dims=spatial_dims,
+ in_channels=input_channel,
+ prev_output_channel=prev_output_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ num_res_blocks=reversed_num_res_blocks[i] + 1,
+ norm_num_groups=norm_num_groups,
+ norm_eps=norm_eps,
+ add_upsample=not is_final_block,
+ resblock_updown=resblock_updown,
+ with_attn=(reversed_attention_levels[i] and not with_conditioning),
+ with_cross_attn=(reversed_attention_levels[i] and with_conditioning),
+ num_head_channels=reversed_num_head_channels[i],
+ transformer_num_layers=transformer_num_layers,
+ cross_attention_dim=cross_attention_dim,
+ upcast_attention=upcast_attention,
+ label_nc=label_nc,
+ spade_intermediate_channels=spade_intermediate_channels,
+ use_flash_attention=use_flash_attention,
+ )
+
+ self.up_blocks.append(up_block)
+
+ # out
+ self.out = nn.Sequential(
+ nn.GroupNorm(num_groups=norm_num_groups, num_channels=channels[0], eps=norm_eps, affine=True),
+ nn.SiLU(),
+ zero_module(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=channels[0],
+ out_channels=out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ ),
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ timesteps: torch.Tensor,
+ seg: torch.Tensor,
+ context: torch.Tensor | None = None,
+ class_labels: torch.Tensor | None = None,
+ down_block_additional_residuals: tuple[torch.Tensor] | None = None,
+ mid_block_additional_residual: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ """
+ Args:
+ x: input tensor (N, C, SpatialDims).
+ timesteps: timestep tensor (N,).
+ seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.
+ context: context tensor (N, 1, ContextDim).
+ class_labels: context tensor (N, ).
+ down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims).
+ mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims).
+ """
+ # 1. time
+ t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=x.dtype)
+ emb = self.time_embed(t_emb)
+
+ # 2. class
+ if self.num_class_embeds is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+ class_emb = self.class_embedding(class_labels)
+ class_emb = class_emb.to(dtype=x.dtype)
+ emb = emb + class_emb
+
+ # 3. initial convolution
+ h = self.conv_in(x)
+
+ # 4. down
+ if context is not None and self.with_conditioning is False:
+ raise ValueError("model should have with_conditioning = True if context is provided")
+ down_block_res_samples: list[torch.Tensor] = [h]
+ for downsample_block in self.down_blocks:
+ h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
+ for residual in res_samples:
+ down_block_res_samples.append(residual)
+
+ # Additional residual conections for Controlnets
+ if down_block_additional_residuals is not None:
+ new_down_block_res_samples: list[torch.Tensor] = [h]
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples.append(down_block_res_sample)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 5. mid
+ h = self.middle_block(hidden_states=h, temb=emb, context=context)
+
+ # Additional residual conections for Controlnets
+ if mid_block_additional_residual is not None:
+ h = h + mid_block_additional_residual
+
+ # 6. up
+ for upsample_block in self.up_blocks:
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+ h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, seg=seg, temb=emb, context=context)
+
+ # 7. output block
+ output: torch.Tensor = self.out(h)
+
+ return output
diff --git a/monai/networks/nets/spade_network.py b/monai/networks/nets/spade_network.py
new file mode 100644
index 0000000000..9164541f27
--- /dev/null
+++ b/monai/networks/nets/spade_network.py
@@ -0,0 +1,435 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Sequence
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from monai.networks.blocks import Convolution
+from monai.networks.blocks.spade_norm import SPADE
+from monai.networks.layers import Act
+from monai.networks.layers.utils import get_act_layer
+from monai.utils.enums import StrEnum
+
+__all__ = ["SPADENet"]
+
+
+class UpsamplingModes(StrEnum):
+ bicubic = "bicubic"
+ nearest = "nearest"
+ bilinear = "bilinear"
+
+
+class SPADENetResBlock(nn.Module):
+ """
+ Creates a Residual Block with SPADE normalisation.
+
+ Args:
+ spatial_dims: number of spatial dimensions
+ in_channels: number of input channels
+ out_channels: number of output channels
+ label_nc: number of semantic channels that will be taken into account in SPADE normalisation blocks
+ spade_intermediate_channels: number of intermediate channels in the middle conv. layers in SPADE normalisation blocks
+ norm: base normalisation type used on top of SPADE
+ kernel_size: convolutional kernel size
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ label_nc: int,
+ spade_intermediate_channels: int = 128,
+ norm: str | tuple = "INSTANCE",
+ act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+ kernel_size: int = 3,
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.int_channels = min(self.in_channels, self.out_channels)
+ self.learned_shortcut = self.in_channels != self.out_channels
+ self.conv_0 = Convolution(
+ spatial_dims=spatial_dims, in_channels=self.in_channels, out_channels=self.int_channels, act=None, norm=None
+ )
+ self.conv_1 = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=self.int_channels,
+ out_channels=self.out_channels,
+ act=None,
+ norm=None,
+ )
+ self.activation = get_act_layer(act)
+ self.norm_0 = SPADE(
+ label_nc=label_nc,
+ norm_nc=self.in_channels,
+ kernel_size=kernel_size,
+ spatial_dims=spatial_dims,
+ hidden_channels=spade_intermediate_channels,
+ norm=norm,
+ )
+ self.norm_1 = SPADE(
+ label_nc=label_nc,
+ norm_nc=self.int_channels,
+ kernel_size=kernel_size,
+ spatial_dims=spatial_dims,
+ hidden_channels=spade_intermediate_channels,
+ norm=norm,
+ )
+
+ if self.learned_shortcut:
+ self.conv_s = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=self.in_channels,
+ out_channels=self.out_channels,
+ act=None,
+ norm=None,
+ kernel_size=1,
+ )
+ self.norm_s = SPADE(
+ label_nc=label_nc,
+ norm_nc=self.in_channels,
+ kernel_size=kernel_size,
+ spatial_dims=spatial_dims,
+ hidden_channels=spade_intermediate_channels,
+ norm=norm,
+ )
+
+ def forward(self, x, seg):
+ x_s = self.shortcut(x, seg)
+ dx = self.conv_0(self.activation(self.norm_0(x, seg)))
+ dx = self.conv_1(self.activation(self.norm_1(dx, seg)))
+ out = x_s + dx
+ return out
+
+ def shortcut(self, x, seg):
+ if self.learned_shortcut:
+ x_s = self.conv_s(self.norm_s(x, seg))
+ else:
+ x_s = x
+ return x_s
+
+
+class SPADEEncoder(nn.Module):
+ """
+ Encoding branch of a VAE compatible with a SPADE-like generator
+
+ Args:
+ spatial_dims: number of spatial dimensions
+ in_channels: number of input channels
+ z_dim: latent space dimension of the VAE containing the image sytle information
+ channels: number of output after each downsampling block
+ input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers
+ of the autoencoder (HxWx[D])
+ kernel_size: convolutional kernel size
+ norm: normalisation layer type
+ act: activation type
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ z_dim: int,
+ channels: Sequence[int],
+ input_shape: Sequence[int],
+ kernel_size: int = 3,
+ norm: str | tuple = "INSTANCE",
+ act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+ ):
+ super().__init__()
+ self.in_channels = in_channels
+ self.z_dim = z_dim
+ self.channels = channels
+ if len(input_shape) != spatial_dims:
+ raise ValueError("Length of parameter input shape must match spatial_dims; got %s" % (input_shape))
+ for s_ind, s_ in enumerate(input_shape):
+ if s_ / (2 ** len(channels)) != s_ // (2 ** len(channels)):
+ raise ValueError(
+ "Each dimension of your input must be divisible by 2 ** (autoencoder depth)."
+ "The shape in position %d, %d is not divisible by %d. " % (s_ind, s_, len(channels))
+ )
+ self.input_shape = input_shape
+ self.latent_spatial_shape = [s_ // (2 ** len(self.channels)) for s_ in self.input_shape]
+ blocks = []
+ ch_init = self.in_channels
+ for _, ch_value in enumerate(channels):
+ blocks.append(
+ Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=ch_init,
+ out_channels=ch_value,
+ strides=2,
+ kernel_size=kernel_size,
+ norm=norm,
+ act=act,
+ )
+ )
+ ch_init = ch_value
+
+ self.blocks = nn.ModuleList(blocks)
+ self.fc_mu = nn.Linear(
+ in_features=np.prod(self.latent_spatial_shape) * self.channels[-1], out_features=self.z_dim
+ )
+ self.fc_var = nn.Linear(
+ in_features=np.prod(self.latent_spatial_shape) * self.channels[-1], out_features=self.z_dim
+ )
+
+ def forward(self, x):
+ for block in self.blocks:
+ x = block(x)
+ x = x.view(x.size(0), -1)
+ mu = self.fc_mu(x)
+ logvar = self.fc_var(x)
+ return mu, logvar
+
+ def encode(self, x):
+ for block in self.blocks:
+ x = block(x)
+ x = x.view(x.size(0), -1)
+ mu = self.fc_mu(x)
+ logvar = self.fc_var(x)
+ return self.reparameterize(mu, logvar)
+
+ def reparameterize(self, mu, logvar):
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn_like(std)
+ return eps.mul(std) + mu
+
+
+class SPADEDecoder(nn.Module):
+ """
+ Decoder branch of a SPADE-like generator. It can be used independently, without an encoding branch,
+ behaving like a GAN, or coupled to a SPADE encoder.
+
+ Args:
+ label_nc: number of semantic labels
+ spatial_dims: number of spatial dimensions
+ out_channels: number of output channels
+ label_nc: number of semantic channels used for the SPADE normalisation blocks
+ input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers
+ channels: number of output after each downsampling block
+ z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used)
+ is_vae: whether the decoder is going to be coupled to an autoencoder or not (true: yes, false: no)
+ spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks
+ norm: base normalisation type
+ act: activation layer type
+ last_act: activation layer type for the last layer of the network (can differ from previous)
+ kernel_size: convolutional kernel size
+ upsampling_mode: upsampling mode (nearest, bilinear etc.)
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ out_channels: int,
+ label_nc: int,
+ input_shape: Sequence[int],
+ channels: list[int],
+ z_dim: int | None = None,
+ is_vae: bool = True,
+ spade_intermediate_channels: int = 128,
+ norm: str | tuple = "INSTANCE",
+ act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+ last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+ kernel_size: int = 3,
+ upsampling_mode: str = UpsamplingModes.nearest.value,
+ ):
+ super().__init__()
+ self.is_vae = is_vae
+ self.out_channels = out_channels
+ self.label_nc = label_nc
+ self.num_channels = channels
+ if len(input_shape) != spatial_dims:
+ raise ValueError("Length of parameter input shape must match spatial_dims; got %s" % (input_shape))
+ for s_ind, s_ in enumerate(input_shape):
+ if s_ / (2 ** len(channels)) != s_ // (2 ** len(channels)):
+ raise ValueError(
+ "Each dimension of your input must be divisible by 2 ** (autoencoder depth)."
+ "The shape in position %d, %d is not divisible by %d. " % (s_ind, s_, len(channels))
+ )
+ self.latent_spatial_shape = [s_ // (2 ** len(self.num_channels)) for s_ in input_shape]
+
+ if not self.is_vae:
+ self.conv_init = Convolution(
+ spatial_dims=spatial_dims, in_channels=label_nc, out_channels=channels[0], kernel_size=kernel_size
+ )
+ elif self.is_vae and z_dim is None:
+ raise ValueError(
+ "If the network is used in VAE-GAN mode, parameter z_dim "
+ "(number of latent channels in the VAE) must be populated."
+ )
+ else:
+ self.fc = nn.Linear(z_dim, np.prod(self.latent_spatial_shape) * channels[0])
+
+ self.z_dim = z_dim
+ blocks = []
+ channels.append(self.out_channels)
+ self.upsampling = torch.nn.Upsample(scale_factor=2, mode=upsampling_mode)
+ for ch_ind, ch_value in enumerate(channels[:-1]):
+ blocks.append(
+ SPADENetResBlock(
+ spatial_dims=spatial_dims,
+ in_channels=ch_value,
+ out_channels=channels[ch_ind + 1],
+ label_nc=label_nc,
+ spade_intermediate_channels=spade_intermediate_channels,
+ norm=norm,
+ kernel_size=kernel_size,
+ act=act,
+ )
+ )
+
+ self.blocks = torch.nn.ModuleList(blocks)
+ self.last_conv = Convolution(
+ spatial_dims=spatial_dims,
+ in_channels=channels[-1],
+ out_channels=out_channels,
+ padding=(kernel_size - 1) // 2,
+ kernel_size=kernel_size,
+ norm=None,
+ act=last_act,
+ )
+
+ def forward(self, seg, z: torch.Tensor | None = None):
+ """
+ Args:
+ seg: input BxCxHxW[xD] semantic map on which the output is conditioned on
+ z: latent vector output by the encoder if self.is_vae is True. When is_vae is
+ False, z is a random noise vector.
+
+ Returns:
+
+ """
+ if not self.is_vae:
+ x = F.interpolate(seg, size=tuple(self.latent_spatial_shape))
+ x = self.conv_init(x)
+ else:
+ if (
+ z is None and self.z_dim is not None
+ ): # Even though this network is a VAE (self.is_vae), you should be able to sample from noise as well.
+ z = torch.randn(seg.size(0), self.z_dim, dtype=torch.float32, device=seg.get_device())
+ x = self.fc(z)
+ x = x.view(*[-1, self.num_channels[0]] + self.latent_spatial_shape)
+
+ for res_block in self.blocks:
+ x = res_block(x, seg)
+ x = self.upsampling(x)
+
+ x = self.last_conv(x)
+ return x
+
+
+class SPADENet(nn.Module):
+ """
+ SPADE Network, implemented based on the code by Park, T et al. in
+ "Semantic Image Synthesis with Spatially-Adaptive Normalization"
+ (https://github.com/NVlabs/SPADE)
+
+ Args:
+ spatial_dims: number of spatial dimensions
+ in_channels: number of input channels
+ out_channels: number of output channels
+ label_nc: number of semantic channels used for the SPADE normalisation blocks
+ input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers
+ channels: number of output after each downsampling block
+ z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used)
+ is_vae: whether the decoder is going to be coupled to an autoencoder (true) or not (false)
+ spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks
+ norm: base normalisation type
+ act: activation layer type
+ last_act: activation layer type for the last layer of the network (can differ from previous)
+ kernel_size: convolutional kernel size
+ upsampling_mode: upsampling mode (nearest, bilinear etc.)
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ label_nc: int,
+ input_shape: Sequence[int],
+ channels: list[int],
+ z_dim: int | None = None,
+ is_vae: bool = True,
+ spade_intermediate_channels: int = 128,
+ norm: str | tuple = "INSTANCE",
+ act: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+ last_act: str | tuple | None = (Act.LEAKYRELU, {"negative_slope": 0.2}),
+ kernel_size: int = 3,
+ upsampling_mode: str = UpsamplingModes.nearest.value,
+ ):
+ super().__init__()
+ self.is_vae = is_vae
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.channels = channels
+ self.label_nc = label_nc
+ self.input_shape = input_shape
+
+ if self.is_vae:
+ if z_dim is None:
+ ValueError("The latent space dimension mapped by parameter z_dim cannot be None is is_vae is True.")
+ else:
+ self.encoder = SPADEEncoder(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ z_dim=z_dim,
+ channels=channels,
+ input_shape=input_shape,
+ kernel_size=kernel_size,
+ norm=norm,
+ act=act,
+ )
+
+ decoder_channels = channels
+ decoder_channels.reverse()
+
+ self.decoder = SPADEDecoder(
+ spatial_dims=spatial_dims,
+ out_channels=out_channels,
+ label_nc=label_nc,
+ input_shape=input_shape,
+ channels=decoder_channels,
+ z_dim=z_dim,
+ is_vae=is_vae,
+ spade_intermediate_channels=spade_intermediate_channels,
+ norm=norm,
+ act=act,
+ last_act=last_act,
+ kernel_size=kernel_size,
+ upsampling_mode=upsampling_mode,
+ )
+
+ def forward(self, seg: torch.Tensor, x: torch.Tensor | None = None):
+ z = None
+ if self.is_vae:
+ z_mu, z_logvar = self.encoder(x)
+ z = self.encoder.reparameterize(z_mu, z_logvar)
+ return self.decoder(seg, z), z_mu, z_logvar
+ else:
+ return (self.decoder(seg, z),)
+
+ def encode(self, x: torch.Tensor):
+ if self.is_vae:
+ return self.encoder.encode(x)
+ else:
+ return None
+
+ def decode(self, seg: torch.Tensor, z: torch.Tensor | None = None):
+ return self.decoder(seg, z)
diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py
index 6f96dfd291..832135ad06 100644
--- a/monai/networks/nets/swin_unetr.py
+++ b/monai/networks/nets/swin_unetr.py
@@ -13,6 +13,7 @@
import itertools
from collections.abc import Sequence
+from typing import Final
import numpy as np
import torch
@@ -20,7 +21,6 @@
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.nn import LayerNorm
-from typing_extensions import Final
from monai.networks.blocks import MLPBlock as Mlp
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
@@ -320,7 +320,7 @@ def _check_input_size(self, spatial_shape):
)
def forward(self, x_in):
- if not torch.jit.is_scripting():
+ if not torch.jit.is_scripting() and not torch.jit.is_tracing():
self._check_input_size(x_in.shape[2:])
hidden_states_out = self.swinViT(x_in, self.normalize)
enc0 = self.encoder1(x_in)
@@ -347,7 +347,7 @@ def window_partition(x, window_size):
x: input tensor.
window_size: local window size.
"""
- x_shape = x.size()
+ x_shape = x.size() # length 4 or 5 only
if len(x_shape) == 5:
b, d, h, w, c = x_shape
x = x.view(
@@ -363,10 +363,11 @@ def window_partition(x, window_size):
windows = (
x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0] * window_size[1] * window_size[2], c)
)
- elif len(x_shape) == 4:
+ else: # if len(x_shape) == 4:
b, h, w, c = x.shape
x = x.view(b, h // window_size[0], window_size[0], w // window_size[1], window_size[1], c)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0] * window_size[1], c)
+
return windows
@@ -613,7 +614,7 @@ def forward_part1(self, x, mask_matrix):
_, dp, hp, wp, _ = x.shape
dims = [b, dp, hp, wp]
- elif len(x_shape) == 4:
+ else: # elif len(x_shape) == 4
b, h, w, c = x.shape
window_size, shift_size = get_window_size((h, w), self.window_size, self.shift_size)
pad_l = pad_t = 0
@@ -1045,14 +1046,14 @@ def __init__(
def proj_out(self, x, normalize=False):
if normalize:
- x_shape = x.size()
+ x_shape = x.shape
+ # Force trace() to generate a constant by casting to int
+ ch = int(x_shape[1])
if len(x_shape) == 5:
- n, ch, d, h, w = x_shape
x = rearrange(x, "n c d h w -> n d h w c")
x = F.layer_norm(x, [ch])
x = rearrange(x, "n d h w c -> n c d h w")
elif len(x_shape) == 4:
- n, ch, h, w = x_shape
x = rearrange(x, "n c h w -> n h w c")
x = F.layer_norm(x, [ch])
x = rearrange(x, "n h w c -> n c h w")
diff --git a/monai/networks/nets/transformer.py b/monai/networks/nets/transformer.py
new file mode 100644
index 0000000000..3a278c112a
--- /dev/null
+++ b/monai/networks/nets/transformer.py
@@ -0,0 +1,157 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks import TransformerBlock
+
+__all__ = ["DecoderOnlyTransformer"]
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+ """Absolute positional embedding.
+
+ Args:
+ max_seq_len: Maximum sequence length.
+ embedding_dim: Dimensionality of the embedding.
+ """
+
+ def __init__(self, max_seq_len: int, embedding_dim: int) -> None:
+ super().__init__()
+ self.max_seq_len = max_seq_len
+ self.embedding_dim = embedding_dim
+ self.embedding = nn.Embedding(max_seq_len, embedding_dim)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ batch_size, seq_len = x.size()
+ positions = torch.arange(seq_len, device=x.device).repeat(batch_size, 1)
+ embedding: torch.Tensor = self.embedding(positions)
+ return embedding
+
+
+class DecoderOnlyTransformer(nn.Module):
+ """Decoder-only (Autoregressive) Transformer model.
+
+ Args:
+ num_tokens: Number of tokens in the vocabulary.
+ max_seq_len: Maximum sequence length.
+ attn_layers_dim: Dimensionality of the attention layers.
+ attn_layers_depth: Number of attention layers.
+ attn_layers_heads: Number of attention heads.
+ with_cross_attention: Whether to use cross attention for conditioning.
+ embedding_dropout_rate: Dropout rate for the embedding.
+ include_fc: whether to include the final linear layer. Default to True.
+ use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
+ use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
+ (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
+ """
+
+ def __init__(
+ self,
+ num_tokens: int,
+ max_seq_len: int,
+ attn_layers_dim: int,
+ attn_layers_depth: int,
+ attn_layers_heads: int,
+ with_cross_attention: bool = False,
+ embedding_dropout_rate: float = 0.0,
+ include_fc: bool = True,
+ use_combined_linear: bool = False,
+ use_flash_attention: bool = False,
+ ) -> None:
+ super().__init__()
+ self.num_tokens = num_tokens
+ self.max_seq_len = max_seq_len
+ self.attn_layers_dim = attn_layers_dim
+ self.attn_layers_depth = attn_layers_depth
+ self.attn_layers_heads = attn_layers_heads
+ self.with_cross_attention = with_cross_attention
+
+ self.token_embeddings = nn.Embedding(num_tokens, attn_layers_dim)
+ self.position_embeddings = AbsolutePositionalEmbedding(max_seq_len=max_seq_len, embedding_dim=attn_layers_dim)
+ self.embedding_dropout = nn.Dropout(embedding_dropout_rate)
+
+ self.blocks = nn.ModuleList(
+ [
+ TransformerBlock(
+ hidden_size=attn_layers_dim,
+ mlp_dim=attn_layers_dim * 4,
+ num_heads=attn_layers_heads,
+ dropout_rate=0.0,
+ qkv_bias=False,
+ causal=True,
+ sequence_length=max_seq_len,
+ with_cross_attention=with_cross_attention,
+ include_fc=include_fc,
+ use_combined_linear=use_combined_linear,
+ use_flash_attention=use_flash_attention,
+ )
+ for _ in range(attn_layers_depth)
+ ]
+ )
+
+ self.to_logits = nn.Linear(attn_layers_dim, num_tokens)
+
+ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch.Tensor:
+ tok_emb = self.token_embeddings(x)
+ pos_emb = self.position_embeddings(x)
+ x = self.embedding_dropout(tok_emb + pos_emb)
+
+ for block in self.blocks:
+ x = block(x, context=context)
+ logits: torch.Tensor = self.to_logits(x)
+ return logits
+
+ def load_old_state_dict(self, old_state_dict: dict, verbose=False) -> None:
+ """
+ Load a state dict from a DecoderOnlyTransformer trained with
+ [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).
+
+ Args:
+ old_state_dict: state dict from the old DecoderOnlyTransformer model.
+ """
+
+ new_state_dict = self.state_dict()
+ # if all keys match, just load the state dict
+ if all(k in new_state_dict for k in old_state_dict):
+ print("All keys match, loading state dict.")
+ self.load_state_dict(old_state_dict)
+ return
+
+ if verbose:
+ # print all new_state_dict keys that are not in old_state_dict
+ for k in new_state_dict:
+ if k not in old_state_dict:
+ print(f"key {k} not found in old state dict")
+ # and vice versa
+ print("----------------------------------------------")
+ for k in old_state_dict:
+ if k not in new_state_dict:
+ print(f"key {k} not found in new state dict")
+
+ # copy over all matching keys
+ for k in new_state_dict:
+ if k in old_state_dict:
+ new_state_dict[k] = old_state_dict.pop(k)
+
+ # fix the renamed norm blocks first norm2 -> norm_cross_attention , norm3 -> norm2
+ for k in list(old_state_dict.keys()):
+ if "norm2" in k:
+ new_state_dict[k.replace("norm2", "norm_cross_attn")] = old_state_dict.pop(k)
+ if "norm3" in k:
+ new_state_dict[k.replace("norm3", "norm2")] = old_state_dict.pop(k)
+ if verbose:
+ # print all remaining keys in old_state_dict
+ print("remaining keys in old_state_dict:", old_state_dict.keys())
+ self.load_state_dict(new_state_dict)
diff --git a/monai/networks/nets/unet.py b/monai/networks/nets/unet.py
index 7b16b6c923..eac0ddab39 100644
--- a/monai/networks/nets/unet.py
+++ b/monai/networks/nets/unet.py
@@ -20,13 +20,10 @@
from monai.networks.blocks.convolutions import Convolution, ResidualUnit
from monai.networks.layers.factories import Act, Norm
from monai.networks.layers.simplelayers import SkipConnection
-from monai.utils import alias, export
__all__ = ["UNet", "Unet"]
-@export("monai.networks.nets")
-@alias("Unet")
class UNet(nn.Module):
"""
Enhanced version of UNet which has residual units implemented with the ResidualUnit class.
diff --git a/monai/networks/nets/unetr.py b/monai/networks/nets/unetr.py
index a88e5a92fd..79ea0e23f7 100644
--- a/monai/networks/nets/unetr.py
+++ b/monai/networks/nets/unetr.py
@@ -18,7 +18,7 @@
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
from monai.networks.nets.vit import ViT
-from monai.utils import deprecated_arg, ensure_tuple_rep
+from monai.utils import ensure_tuple_rep
class UNETR(nn.Module):
@@ -27,9 +27,6 @@ class UNETR(nn.Module):
UNETR: Transformers for 3D Medical Image Segmentation "
"""
- @deprecated_arg(
- name="pos_embed", since="1.2", removed="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead."
- )
def __init__(
self,
in_channels: int,
@@ -39,7 +36,6 @@ def __init__(
hidden_size: int = 768,
mlp_dim: int = 3072,
num_heads: int = 12,
- pos_embed: str = "conv",
proj_type: str = "conv",
norm_name: tuple | str = "instance",
conv_block: bool = True,
@@ -67,9 +63,6 @@ def __init__(
qkv_bias: apply the bias term for the qkv linear layer in self attention block. Defaults to False.
save_attn: to make accessible the attention in self attention block. Defaults to False.
- .. deprecated:: 1.4
- ``pos_embed`` is deprecated in favor of ``proj_type``.
-
Examples::
# for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm
diff --git a/monai/networks/nets/vista3d.py b/monai/networks/nets/vista3d.py
new file mode 100644
index 0000000000..6313b7812d
--- /dev/null
+++ b/monai/networks/nets/vista3d.py
@@ -0,0 +1,948 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import math
+from typing import Any, Callable, Optional, Sequence, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+import monai
+from monai.networks.blocks import MLPBlock, UnetrBasicBlock
+from monai.networks.nets import SegResNetDS2
+from monai.transforms.utils import convert_points_to_disc
+from monai.transforms.utils import keep_merge_components_with_points as lcc
+from monai.transforms.utils import sample_points_from_label
+from monai.utils import optional_import, unsqueeze_left, unsqueeze_right
+
+rearrange, _ = optional_import("einops", name="rearrange")
+
+__all__ = ["VISTA3D", "vista3d132"]
+
+
+def vista3d132(encoder_embed_dim: int = 48, in_channels: int = 1):
+ """
+ Exact VISTA3D network configuration used in https://arxiv.org/abs/2406.05285>`_.
+ The model treats class index larger than 132 as zero-shot.
+
+ Args:
+ encoder_embed_dim: hidden dimension for encoder.
+ in_channels: input channel number.
+ """
+ segresnet = SegResNetDS2(
+ in_channels=in_channels,
+ blocks_down=(1, 2, 2, 4, 4),
+ norm="instance",
+ out_channels=encoder_embed_dim,
+ init_filters=encoder_embed_dim,
+ dsdepth=1,
+ )
+ point_head = PointMappingSAM(feature_size=encoder_embed_dim, n_classes=512, last_supported=132)
+ class_head = ClassMappingClassify(n_classes=512, feature_size=encoder_embed_dim, use_mlp=True)
+ vista = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head)
+ return vista
+
+
+class VISTA3D(nn.Module):
+ """
+ VISTA3D based on:
+ `VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography
+ `_.
+
+ Args:
+ image_encoder: image encoder backbone for feature extraction.
+ class_head: class head used for class index based segmentation
+ point_head: point head used for interactive segmetnation
+ """
+
+ def __init__(self, image_encoder: nn.Module, class_head: nn.Module, point_head: nn.Module):
+ super().__init__()
+ self.image_encoder = image_encoder
+ self.class_head = class_head
+ self.point_head = point_head
+ self.image_embeddings = None
+ self.auto_freeze = False
+ self.point_freeze = False
+ self.NINF_VALUE = -9999
+ self.PINF_VALUE = 9999
+
+ def update_slidingwindow_padding(
+ self,
+ pad_size: list | None,
+ labels: torch.Tensor | None,
+ prev_mask: torch.Tensor | None,
+ point_coords: torch.Tensor | None,
+ ):
+ """
+ Image has been padded by sliding window inferer.
+ The related padding need to be performed outside of slidingwindow inferer.
+
+ Args:
+ pad_size: padding size passed from sliding window inferer.
+ labels: image label ground truth.
+ prev_mask: previous segmentation mask.
+ point_coords: point click coordinates.
+ """
+ if pad_size is None:
+ return labels, prev_mask, point_coords
+ if labels is not None:
+ labels = F.pad(labels, pad=pad_size, mode="constant", value=0)
+ if prev_mask is not None:
+ prev_mask = F.pad(prev_mask, pad=pad_size, mode="constant", value=0)
+ if point_coords is not None:
+ point_coords = point_coords + torch.tensor(
+ [pad_size[-2], pad_size[-4], pad_size[-6]], device=point_coords.device
+ )
+ return labels, prev_mask, point_coords
+
+ def get_foreground_class_count(self, class_vector: torch.Tensor | None, point_coords: torch.Tensor | None) -> int:
+ """Get number of foreground classes based on class and point prompt."""
+ if class_vector is None:
+ if point_coords is None:
+ raise ValueError("class_vector and point_coords cannot be both None.")
+ return point_coords.shape[0]
+ else:
+ return class_vector.shape[0]
+
+ def convert_point_label(
+ self,
+ point_label: torch.Tensor,
+ label_set: Sequence[int] | None = None,
+ special_index: Sequence[int] = (23, 24, 25, 26, 27, 57, 128),
+ ):
+ """
+ Convert point label based on its class prompt. For special classes defined in special index,
+ the positive/negative point label will be converted from 1/0 to 3/2. The purpose is to separate those
+ classes with ambiguous classes.
+
+ Args:
+ point_label: the point label tensor, [B, N].
+ label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
+ this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
+ evaluation, this label_set should be the original index.
+ special_index: the special class index that needs to be converted.
+ """
+ if label_set is None:
+ return point_label
+ if not point_label.shape[0] == len(label_set):
+ raise ValueError("point_label and label_set must have the same length.")
+
+ for i in range(len(label_set)):
+ if label_set[i] in special_index:
+ for j in range(len(point_label[i])):
+ point_label[i, j] = point_label[i, j] + 2 if point_label[i, j] > -1 else point_label[i, j]
+ return point_label
+
+ def sample_points_patch_val(
+ self,
+ labels: torch.Tensor,
+ patch_coords: Sequence[slice],
+ label_set: Sequence[int],
+ use_center: bool = True,
+ mapped_label_set: Sequence[int] | None = None,
+ max_ppoint: int = 1,
+ max_npoint: int = 0,
+ ):
+ """
+ Sample points for patch during sliding window validation. Only used for point only validation.
+
+ Args:
+ labels: shape [1, 1, H, W, D].
+ patch_coords: a sequence of sliding window slice objects.
+ label_set: local index, must match values in labels.
+ use_center: sample points from the center.
+ mapped_label_set: global index, it is used to identify special classes and is the global index
+ for the sampled points.
+ max_ppoint/max_npoint: positive points and negative points to sample.
+ """
+ point_coords, point_labels = sample_points_from_label(
+ labels[patch_coords],
+ label_set,
+ max_ppoint=max_ppoint,
+ max_npoint=max_npoint,
+ device=labels.device,
+ use_center=use_center,
+ )
+ point_labels = self.convert_point_label(point_labels, mapped_label_set)
+ return (point_coords, point_labels, torch.tensor(label_set).to(point_coords.device).unsqueeze(-1))
+
+ def update_point_to_patch(
+ self, patch_coords: Sequence[slice], point_coords: torch.Tensor, point_labels: torch.Tensor
+ ):
+ """
+ Update point_coords with respect to patch coords.
+ If point is outside of the patch, remove the coordinates and set label to -1.
+
+ Args:
+ patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.
+ This value is passed from sliding_window_inferer.
+ point_coords: point coordinates, [B, N, 3].
+ point_labels: point labels, [B, N].
+ """
+ patch_ends = [patch_coords[-3].stop, patch_coords[-2].stop, patch_coords[-1].stop]
+ patch_starts = [patch_coords[-3].start, patch_coords[-2].start, patch_coords[-1].start]
+ # update point coords
+ patch_starts_tensor = unsqueeze_left(torch.tensor(patch_starts, device=point_coords.device), 2)
+ patch_ends_tensor = unsqueeze_left(torch.tensor(patch_ends, device=point_coords.device), 2)
+ # [1 N 1]
+ indices = torch.logical_and(
+ ((point_coords - patch_starts_tensor) > 0).all(2), ((patch_ends_tensor - point_coords) > 0).all(2)
+ )
+ # check if it's within patch coords
+ point_coords = point_coords.clone() - patch_starts_tensor
+ point_labels = point_labels.clone()
+ if indices.any():
+ point_labels[~indices] = -1
+ point_coords[~indices] = 0
+ # also remove padded points, mainly used for inference.
+ not_pad_indices = (point_labels != -1).any(0)
+ point_coords = point_coords[:, not_pad_indices]
+ point_labels = point_labels[:, not_pad_indices]
+ return point_coords, point_labels
+ return None, None
+
+ def connected_components_combine(
+ self,
+ logits: torch.Tensor,
+ point_logits: torch.Tensor,
+ point_coords: torch.Tensor,
+ point_labels: torch.Tensor,
+ mapping_index: torch.Tensor,
+ thred: float = 0.5,
+ ):
+ """
+ Combine auto results with point click response. The auto results have shape [B, 1, H, W, D] which means B foreground masks
+ from a single image patch.
+ Out of those B foreground masks, user may add points to a subset of B1 foreground masks for editing.
+ mapping_index represents the correspondence between B and B1.
+ For mapping_index with point clicks, NaN values in logits will be replaced with point_logits. Meanwhile, the added/removed
+ region in point clicks must be updated by the lcc function.
+ Notice, if a positive point is within logits/prev_mask, the components containing the positive point will be added.
+
+ Args:
+ logits: automatic branch results, [B, 1, H, W, D].
+ point_logits: point branch results, [B1, 1, H, W, D].
+ point_coords: point coordinates, [B1, N, 3].
+ point_labels: point labels, [B1, N].
+ mapping_index: [B].
+ thred: the threshold to convert logits to binary.
+ """
+ logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits
+ _logits = logits[mapping_index]
+ inside = []
+ for i in range(_logits.shape[0]):
+ inside.append(
+ np.any(
+ [
+ _logits[i, 0, p[0], p[1], p[2]].item() > 0
+ for p in point_coords[i].cpu().numpy().round().astype(int)
+ ]
+ )
+ )
+ inside_tensor = torch.tensor(inside).to(logits.device)
+ nan_mask = torch.isnan(_logits)
+ # _logits are converted to binary [B1, 1, H, W, D]
+ _logits = torch.nan_to_num(_logits, nan=self.NINF_VALUE).sigmoid()
+ pos_region = point_logits.sigmoid() > thred
+ diff_pos = torch.logical_and(torch.logical_or(_logits <= thred, unsqueeze_right(inside_tensor, 5)), pos_region)
+ diff_neg = torch.logical_and((_logits > thred), ~pos_region)
+ cc = lcc(diff_pos, diff_neg, point_coords=point_coords, point_labels=point_labels)
+ # cc is the region that can be updated by point_logits.
+ cc = cc.to(logits.device)
+ # Need to replace NaN with point_logits. diff_neg will never lie in nan_mask,
+ # only remove unconnected positive region.
+ uc_pos_region = torch.logical_and(pos_region, ~cc)
+ fill_mask = torch.logical_and(nan_mask, uc_pos_region)
+ if fill_mask.any():
+ # fill in the mean negative value
+ point_logits[fill_mask] = -1
+ # replace logits nan value and cc with point_logits
+ cc = torch.logical_or(nan_mask, cc).to(logits.dtype)
+ logits[mapping_index] *= 1 - cc
+ logits[mapping_index] += cc * point_logits
+ return logits
+
+ def gaussian_combine(
+ self,
+ logits: torch.Tensor,
+ point_logits: torch.Tensor,
+ point_coords: torch.Tensor,
+ point_labels: torch.Tensor,
+ mapping_index: torch.Tensor,
+ radius: int | None = None,
+ ):
+ """
+ Combine point results with auto results using gaussian.
+
+ Args:
+ logits: automatic branch results, [B, 1, H, W, D].
+ point_logits: point branch results, [B1, 1, H, W, D].
+ point_coords: point coordinates, [B1, N, 3].
+ point_labels: point labels, [B1, N].
+ mapping_index: [B].
+ radius: gaussian ball radius.
+ """
+ if radius is None:
+ radius = min(point_logits.shape[-3:]) // 5 # empirical value 5
+ weight = 1 - convert_points_to_disc(point_logits.shape[-3:], point_coords, point_labels, radius=radius).sum(
+ 1, keepdims=True
+ )
+ weight[weight < 0] = 0
+ logits = logits.as_tensor() if isinstance(logits, monai.data.MetaTensor) else logits
+ logits[mapping_index] *= weight
+ logits[mapping_index] += (1 - weight) * point_logits
+ return logits
+
+ def set_auto_grad(self, auto_freeze: bool = False, point_freeze: bool = False):
+ """
+ Freeze auto-branch or point-branch.
+
+ Args:
+ auto_freeze: whether to freeze the auto branch.
+ point_freeze: whether to freeze the point branch.
+ """
+ if auto_freeze != self.auto_freeze:
+ if hasattr(self.image_encoder, "set_auto_grad"):
+ self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
+ else:
+ for param in self.image_encoder.parameters():
+ param.requires_grad = (not auto_freeze) and (not point_freeze)
+ for param in self.class_head.parameters():
+ param.requires_grad = not auto_freeze
+ self.auto_freeze = auto_freeze
+
+ if point_freeze != self.point_freeze:
+ if hasattr(self.image_encoder, "set_auto_grad"):
+ self.image_encoder.set_auto_grad(auto_freeze=auto_freeze, point_freeze=point_freeze)
+ else:
+ for param in self.image_encoder.parameters():
+ param.requires_grad = (not auto_freeze) and (not point_freeze)
+ for param in self.point_head.parameters():
+ param.requires_grad = not point_freeze
+ self.point_freeze = point_freeze
+
+ def forward(
+ self,
+ input_images: torch.Tensor,
+ patch_coords: list[Sequence[slice]] | None = None,
+ point_coords: torch.Tensor | None = None,
+ point_labels: torch.Tensor | None = None,
+ class_vector: torch.Tensor | None = None,
+ prompt_class: torch.Tensor | None = None,
+ labels: torch.Tensor | None = None,
+ label_set: Sequence[int] | None = None,
+ prev_mask: torch.Tensor | None = None,
+ radius: int | None = None,
+ val_point_sampler: Callable | None = None,
+ transpose: bool = False,
+ **kwargs,
+ ):
+ """
+ The forward function for VISTA3D. We only support single patch in training and inference.
+ One exception is allowing sliding window batch size > 1 for automatic segmentation only case.
+ B represents number of objects, N represents number of points for each objects.
+
+ Args:
+ input_images: [1, 1, H, W, D]
+ point_coords: [B, N, 3]
+ point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class.
+ 2/3 means negative/postive ponits for special supported class like tumor.
+ class_vector: [B, 1], the global class index.
+ prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if
+ the points are for zero-shot or supported class. When class_vector and point_coords are both
+ provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
+ will be considered novel class.
+ patch_coords: a list of sequence of the python slice objects representing the patch coordinates during sliding window
+ inference. This value is passed from sliding_window_inferer.
+ This is an indicator for training phase or validation phase.
+ Notice for sliding window batch size > 1 (only supported by automatic segmentation), patch_coords will inlcude
+ coordinates of multiple patches. If point prompts are included, the batch size can only be one and all the
+ functions using patch_coords will by default use patch_coords[0].
+ labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation
+ label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
+ this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
+ evaluation, this label_set should be the original index.
+ prev_mask: [B, N, H_fullsize, W_fullsize, D_fullsize].
+ This is the transposed raw output from sliding_window_inferer before any postprocessing.
+ When user click points to perform auto-results correction, this can be the auto-results.
+ radius: single float value controling the gaussian blur when combining point and auto results.
+ The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes.
+ val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation.
+ transpose: bool. If true, the output will be transposed to be [1, B, H, W, D]. Required to be true if calling from
+ sliding window inferer/point inferer.
+ """
+ labels, prev_mask, point_coords = self.update_slidingwindow_padding(
+ kwargs.get("pad_size", None), labels, prev_mask, point_coords
+ )
+ image_size = input_images.shape[-3:]
+ device = input_images.device
+ if point_coords is None and class_vector is None:
+ return self.NINF_VALUE + torch.zeros([1, 1, *image_size], device=device)
+
+ bs = self.get_foreground_class_count(class_vector, point_coords)
+ if patch_coords is not None:
+ # if during validation and perform enable based point-validation.
+ if labels is not None and label_set is not None:
+ # if labels is not None, sample from labels for each patch.
+ if val_point_sampler is None:
+ # TODO: think about how to refactor this part.
+ val_point_sampler = self.sample_points_patch_val
+ point_coords, point_labels, prompt_class = val_point_sampler(labels, patch_coords[0], label_set)
+ if prompt_class[0].item() == 0: # type: ignore
+ point_labels[0] = -1 # type: ignore
+ labels, prev_mask = None, None
+ elif point_coords is not None:
+ # If not performing patch-based point only validation, use user provided click points for inference.
+ # the point clicks is in original image space, convert it to current patch-coordinate space.
+ point_coords, point_labels = self.update_point_to_patch(patch_coords[0], point_coords, point_labels) # type: ignore
+
+ if point_coords is not None and point_labels is not None:
+ # remove points that used for padding purposes (point_label = -1)
+ mapping_index = ((point_labels != -1).sum(1) > 0).to(torch.bool)
+ if mapping_index.any():
+ point_coords = point_coords[mapping_index]
+ point_labels = point_labels[mapping_index]
+ if prompt_class is not None:
+ prompt_class = prompt_class[mapping_index]
+ else:
+ if self.auto_freeze or (class_vector is None and patch_coords is None):
+ # if auto_freeze, point prompt must exist to allow loss backward
+ # in training, class_vector and point cannot both be None due to loss.backward()
+ mapping_index.fill_(True)
+ else:
+ point_coords, point_labels = None, None
+
+ if point_coords is None and class_vector is None:
+ logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device)
+ if transpose:
+ logits = logits.transpose(1, 0)
+ return logits
+
+ if self.image_embeddings is not None and kwargs.get("keep_cache", False) and class_vector is None:
+ out, out_auto = self.image_embeddings, None
+ else:
+ out, out_auto = self.image_encoder(
+ input_images, with_point=point_coords is not None, with_label=class_vector is not None
+ )
+ # release memory
+ input_images = None # type: ignore
+
+ # force releasing memories that set to None
+ torch.cuda.empty_cache()
+ if class_vector is not None:
+ logits, _ = self.class_head(out_auto, class_vector)
+ if point_coords is not None:
+ point_logits = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
+ if patch_coords is None:
+ logits = self.gaussian_combine(
+ logits, point_logits, point_coords, point_labels, mapping_index, radius # type: ignore
+ )
+ else:
+ # during validation use largest component
+ logits = self.connected_components_combine(
+ logits, point_logits, point_coords, point_labels, mapping_index # type: ignore
+ )
+ else:
+ logits = self.NINF_VALUE + torch.zeros([bs, 1, *image_size], device=device, dtype=out.dtype)
+ logits[mapping_index] = self.point_head(out, point_coords, point_labels, class_vector=prompt_class)
+ if prev_mask is not None and patch_coords is not None:
+ logits = self.connected_components_combine(
+ prev_mask[patch_coords[0]].transpose(1, 0).to(logits.device),
+ logits[mapping_index],
+ point_coords, # type: ignore
+ point_labels, # type: ignore
+ mapping_index,
+ )
+ if kwargs.get("keep_cache", False) and class_vector is None:
+ self.image_embeddings = out.detach()
+ if transpose:
+ logits = logits.transpose(1, 0)
+ return logits
+
+
+class PointMappingSAM(nn.Module):
+ def __init__(self, feature_size: int, max_prompt: int = 32, n_classes: int = 512, last_supported: int = 132):
+ """Interactive point head used for VISTA3D.
+ Adapted from segment anything:
+ `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`.
+
+ Args:
+ feature_size: feature channel from encoder.
+ max_prompt: max prompt number in each forward iteration.
+ n_classes: number of classes the model can potentially support. This is the maximum number of class embeddings.
+ last_supported: number of classes the model support, this value should match the trained model weights.
+ """
+ super().__init__()
+ transformer_dim = feature_size
+ self.max_prompt = max_prompt
+ self.feat_downsample = nn.Sequential(
+ nn.Conv3d(in_channels=feature_size, out_channels=feature_size, kernel_size=3, stride=2, padding=1),
+ nn.InstanceNorm3d(feature_size),
+ nn.GELU(),
+ nn.Conv3d(in_channels=feature_size, out_channels=transformer_dim, kernel_size=3, stride=1, padding=1),
+ nn.InstanceNorm3d(feature_size),
+ )
+
+ self.mask_downsample = nn.Conv3d(in_channels=2, out_channels=2, kernel_size=3, stride=2, padding=1)
+
+ self.transformer = TwoWayTransformer(depth=2, embedding_dim=transformer_dim, mlp_dim=512, num_heads=4)
+ self.pe_layer = PositionEmbeddingRandom(transformer_dim // 2)
+ self.point_embeddings = nn.ModuleList([nn.Embedding(1, transformer_dim), nn.Embedding(1, transformer_dim)])
+ self.not_a_point_embed = nn.Embedding(1, transformer_dim)
+ self.special_class_embed = nn.Embedding(1, transformer_dim)
+ self.mask_tokens = nn.Embedding(1, transformer_dim)
+
+ self.output_upscaling = nn.Sequential(
+ nn.ConvTranspose3d(transformer_dim, transformer_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
+ nn.InstanceNorm3d(transformer_dim),
+ nn.GELU(),
+ nn.Conv3d(transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1),
+ )
+
+ self.output_hypernetworks_mlps = MLP(transformer_dim, transformer_dim, transformer_dim, 3)
+ # class embedding
+ self.n_classes = n_classes
+ self.last_supported = last_supported
+ self.class_embeddings = nn.Embedding(n_classes, feature_size)
+ self.zeroshot_embed = nn.Embedding(1, transformer_dim)
+ self.supported_embed = nn.Embedding(1, transformer_dim)
+
+ def forward(
+ self,
+ out: torch.Tensor,
+ point_coords: torch.Tensor,
+ point_labels: torch.Tensor,
+ class_vector: torch.Tensor | None = None,
+ ):
+ """Args:
+ out: feature from encoder, [1, C, H, W, C]
+ point_coords: point coordinates, [B, N, 3]
+ point_labels: point labels, [B, N]
+ class_vector: class prompts, [B]
+ """
+ # downsample out
+ out_low = self.feat_downsample(out)
+ out_shape = tuple(out.shape[-3:])
+ # release memory
+ out = None # type: ignore
+ torch.cuda.empty_cache()
+ # embed points
+ points = point_coords + 0.5 # Shift to center of pixel
+ point_embedding = self.pe_layer.forward_with_coords(points, out_shape) # type: ignore
+ point_embedding[point_labels == -1] = 0.0
+ point_embedding[point_labels == -1] += self.not_a_point_embed.weight
+ point_embedding[point_labels == 0] += self.point_embeddings[0].weight
+ point_embedding[point_labels == 1] += self.point_embeddings[1].weight
+ point_embedding[point_labels == 2] += self.point_embeddings[0].weight + self.special_class_embed.weight
+ point_embedding[point_labels == 3] += self.point_embeddings[1].weight + self.special_class_embed.weight
+ output_tokens = self.mask_tokens.weight
+
+ output_tokens = output_tokens.unsqueeze(0).expand(point_embedding.size(0), -1, -1)
+ if class_vector is None:
+ tokens_all = torch.cat(
+ (
+ output_tokens,
+ point_embedding,
+ self.supported_embed.weight.unsqueeze(0).expand(point_embedding.size(0), -1, -1),
+ ),
+ dim=1,
+ )
+ # tokens_all = torch.cat((output_tokens, point_embedding), dim=1)
+ else:
+ class_embeddings = []
+ for i in class_vector:
+ if i > self.last_supported:
+ class_embeddings.append(self.zeroshot_embed.weight)
+ else:
+ class_embeddings.append(self.supported_embed.weight)
+ tokens_all = torch.cat((output_tokens, point_embedding, torch.stack(class_embeddings)), dim=1)
+ # cross attention
+ masks = []
+ max_prompt = self.max_prompt
+ for i in range(int(np.ceil(tokens_all.shape[0] / max_prompt))):
+ # remove variables in previous for loops to save peak memory for self.transformer
+ src, upscaled_embedding, hyper_in = None, None, None
+ torch.cuda.empty_cache()
+ idx = (i * max_prompt, min((i + 1) * max_prompt, tokens_all.shape[0]))
+ tokens = tokens_all[idx[0] : idx[1]]
+ src = torch.repeat_interleave(out_low, tokens.shape[0], dim=0)
+ pos_src = torch.repeat_interleave(self.pe_layer(out_low.shape[-3:]).unsqueeze(0), tokens.shape[0], dim=0)
+ b, c, h, w, d = src.shape
+ hs, src = self.transformer(src, pos_src, tokens)
+ mask_tokens_out = hs[:, :1, :]
+ hyper_in = self.output_hypernetworks_mlps(mask_tokens_out)
+ src = src.transpose(1, 2).view(b, c, h, w, d) # type: ignore
+ upscaled_embedding = self.output_upscaling(src)
+ b, c, h, w, d = upscaled_embedding.shape
+ mask = hyper_in @ upscaled_embedding.view(b, c, h * w * d)
+ masks.append(mask.view(-1, 1, h, w, d))
+
+ return torch.vstack(masks)
+
+
+class ClassMappingClassify(nn.Module):
+ """Class head that performs automatic segmentation based on class vector."""
+
+ def __init__(self, n_classes: int, feature_size: int, use_mlp: bool = True):
+ """Args:
+ n_classes: maximum number of class embedding.
+ feature_size: class embedding size.
+ use_mlp: use mlp to further map class embedding.
+ """
+ super().__init__()
+ self.use_mlp = use_mlp
+ if use_mlp:
+ self.mlp = nn.Sequential(
+ nn.Linear(feature_size, feature_size),
+ nn.InstanceNorm1d(1),
+ nn.GELU(),
+ nn.Linear(feature_size, feature_size),
+ )
+ self.class_embeddings = nn.Embedding(n_classes, feature_size)
+ self.image_post_mapping = nn.Sequential(
+ UnetrBasicBlock(
+ spatial_dims=3,
+ in_channels=feature_size,
+ out_channels=feature_size,
+ kernel_size=3,
+ stride=1,
+ norm_name="instance",
+ res_block=True,
+ ),
+ UnetrBasicBlock(
+ spatial_dims=3,
+ in_channels=feature_size,
+ out_channels=feature_size,
+ kernel_size=3,
+ stride=1,
+ norm_name="instance",
+ res_block=True,
+ ),
+ )
+
+ def forward(self, src: torch.Tensor, class_vector: torch.Tensor):
+ b, c, h, w, d = src.shape
+ src = self.image_post_mapping(src)
+ class_embedding = self.class_embeddings(class_vector)
+ if self.use_mlp:
+ class_embedding = self.mlp(class_embedding)
+        # class_embedding ([n_class, feat]) @ flattened src ([b, feat, h*w*d]); after the transpose below the class dimension leads.
+ masks_embedding = class_embedding.squeeze() @ src.view(b, c, h * w * d)
+ masks_embedding = masks_embedding.view(b, -1, h, w, d).transpose(0, 1)
+
+ return masks_embedding, class_embedding
+
+
+class TwoWayTransformer(nn.Module):
+ def __init__(
+ self,
+ depth: int,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int,
+ activation: tuple | str = "relu",
+ attention_downsample_rate: int = 2,
+ ) -> None:
+ """
+ A transformer decoder that attends to an input image using
+ queries whose positional embedding is supplied.
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.
+
+ Args:
+ depth: number of layers in the transformer.
+ embedding_dim: the channel dimension for the input embeddings.
+ num_heads: the number of heads for multihead attention. Must divide embedding_dim.
+ mlp_dim: the channel dimension internal to the MLP block.
+ activation: the activation to use in the MLP block.
+            attention_downsample_rate: the rate at which to downsample the embedding dimension in the attention layers.
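+
+        Example (illustrative sketch; the shapes below are assumptions)::
+
+            transformer = TwoWayTransformer(depth=2, embedding_dim=256, num_heads=8, mlp_dim=512)
+            image_embedding = torch.randn(1, 256, 4, 4, 4)  # B x embedding_dim x h x w x d
+            image_pe = torch.randn(1, 256, 4, 4, 4)         # positional encoding, same shape
+            point_embedding = torch.randn(1, 6, 256)        # B x N_points x embedding_dim
+            queries, keys = transformer(image_embedding, image_pe, point_embedding)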
+ """
+ super().__init__()
+ self.depth = depth
+ self.embedding_dim = embedding_dim
+ self.num_heads = num_heads
+ self.mlp_dim = mlp_dim
+ self.layers = nn.ModuleList()
+
+ for i in range(depth):
+ self.layers.append(
+ TwoWayAttentionBlock(
+ embedding_dim=embedding_dim,
+ num_heads=num_heads,
+ mlp_dim=mlp_dim,
+ activation=activation,
+ attention_downsample_rate=attention_downsample_rate,
+ skip_first_layer_pe=(i == 0),
+ )
+ )
+
+ self.final_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
+ self.norm_final_attn = nn.LayerNorm(embedding_dim)
+
+ def forward(
+ self, image_embedding: torch.Tensor, image_pe: torch.Tensor, point_embedding: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Args:
+            image_embedding: image to attend to. Should be shape
+                B x embedding_dim x h x w x d for any h, w and d.
+ image_pe: the positional encoding to add to the image. Must
+ have the same shape as image_embedding.
+ point_embedding: the embedding to add to the query points.
+ Must have shape B x N_points x embedding_dim for any N_points.
+
+ Returns:
+ torch.Tensor: the processed point_embedding.
+ torch.Tensor: the processed image_embedding.
+ """
+        # B x C x H x W x D -> B x (HWD) x C == B x N_image_tokens x C
+ image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
+ image_pe = image_pe.flatten(2).permute(0, 2, 1)
+
+ # Prepare queries
+ queries = point_embedding
+ keys = image_embedding
+
+ # Apply transformer blocks and final layernorm
+ for layer in self.layers:
+ queries, keys = layer(queries=queries, keys=keys, query_pe=point_embedding, key_pe=image_pe)
+
+ # Apply the final attention layer from the points to the image
+ q = queries + point_embedding
+ k = keys + image_pe
+ attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm_final_attn(queries)
+
+ return queries, keys
+
+
+class TwoWayAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ embedding_dim: int,
+ num_heads: int,
+ mlp_dim: int = 2048,
+ activation: tuple | str = "relu",
+ attention_downsample_rate: int = 2,
+ skip_first_layer_pe: bool = False,
+ ) -> None:
+ """
+ A transformer block with four layers: (1) self-attention of sparse
+ inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
+ block on sparse inputs, and (4) cross attention of dense inputs to sparse
+ inputs.
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.
+
+ Args:
+ embedding_dim: the channel dimension of the embeddings.
+ num_heads: the number of heads in the attention layers.
+ mlp_dim: the hidden dimension of the mlp block.
+            activation: the activation of the mlp block.
+            attention_downsample_rate: the rate at which to downsample the embedding dimension in the cross-attention layers.
+            skip_first_layer_pe: skip the PE on the first layer.
+ """
+ super().__init__()
+ self.self_attn = Attention(embedding_dim, num_heads)
+ self.norm1 = nn.LayerNorm(embedding_dim)
+
+ self.cross_attn_token_to_image = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
+ self.norm2 = nn.LayerNorm(embedding_dim)
+
+ self.mlp = MLPBlock(hidden_size=embedding_dim, mlp_dim=mlp_dim, act=activation, dropout_mode="vista3d")
+ self.norm3 = nn.LayerNorm(embedding_dim)
+
+ self.norm4 = nn.LayerNorm(embedding_dim)
+ self.cross_attn_image_to_token = Attention(embedding_dim, num_heads, downsample_rate=attention_downsample_rate)
+
+ self.skip_first_layer_pe = skip_first_layer_pe
+
+ def forward(
+ self, queries: torch.Tensor, keys: torch.Tensor, query_pe: torch.Tensor, key_pe: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ # Self attention block
+ if self.skip_first_layer_pe:
+ queries = self.self_attn(q=queries, k=queries, v=queries)
+ else:
+ q = queries + query_pe
+ attn_out = self.self_attn(q=q, k=q, v=queries)
+ queries = queries + attn_out
+ queries = self.norm1(queries)
+
+ # Cross attention block, tokens attending to image embedding
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
+ queries = queries + attn_out
+ queries = self.norm2(queries)
+
+ # MLP block
+ mlp_out = self.mlp(queries)
+ queries = queries + mlp_out
+ queries = self.norm3(queries)
+
+ # Cross attention block, image embedding attending to tokens
+ q = queries + query_pe
+ k = keys + key_pe
+ attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
+ keys = keys + attn_out
+ keys = self.norm4(keys)
+
+ return queries, keys
+
+
+class Attention(nn.Module):
+ """
+ An attention layer that allows for downscaling the size of the embedding
+ after projection to queries, keys, and values.
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.
+
+ Args:
+ embedding_dim: the channel dimension of the embeddings.
+ num_heads: the number of heads in the attention layers.
+        downsample_rate: the factor by which to downsample the embedding dimension after projection.
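+
+    Example (illustrative sketch; the shapes below are assumptions)::
+
+        attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
+        q = k = v = torch.randn(2, 16, 256)  # B x N_tokens x embedding_dim
+        out = attn(q=q, k=k, v=v)            # B x N_tokens x embedding_dim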
+ """
+
+ def __init__(self, embedding_dim: int, num_heads: int, downsample_rate: int = 1) -> None:
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.internal_dim = embedding_dim // downsample_rate
+ self.num_heads = num_heads
+        if self.internal_dim % num_heads != 0:
+            raise ValueError("num_heads must evenly divide internal_dim (embedding_dim // downsample_rate).")
+
+ self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.v_proj = nn.Linear(embedding_dim, self.internal_dim)
+ self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
+
+ def _separate_heads(self, x: torch.Tensor, num_heads: int) -> torch.Tensor:
+ b, n, c = x.shape
+ x = x.reshape(b, n, num_heads, c // num_heads)
+ # B x N_heads x N_tokens x C_per_head
+ return x.transpose(1, 2)
+
+ def _recombine_heads(self, x: torch.Tensor) -> torch.Tensor:
+ b, n_heads, n_tokens, c_per_head = x.shape
+ x = x.transpose(1, 2)
+ # B x N_tokens x C
+ return x.reshape(b, n_tokens, n_heads * c_per_head)
+
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+ # Input projections
+ q = self.q_proj(q)
+ k = self.k_proj(k)
+ v = self.v_proj(v)
+
+ # Separate into heads
+ q = self._separate_heads(q, self.num_heads)
+ k = self._separate_heads(k, self.num_heads)
+ v = self._separate_heads(v, self.num_heads)
+
+ # Attention
+ _, _, _, c_per_head = q.shape
+ attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens
+ attn = attn / math.sqrt(c_per_head)
+ attn = torch.softmax(attn, dim=-1)
+
+ # Get output
+ out = attn @ v
+ out = self._recombine_heads(out)
+ out = self.out_proj(out)
+
+ return out
+
+
+class PositionEmbeddingRandom(nn.Module):
+ """
+ Positional encoding using random spatial frequencies.
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py`.
+
+ Args:
+ num_pos_feats: the number of positional encoding features.
+ scale: the scale of the positional encoding.
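+
+    Example (illustrative sketch; the grid size and coordinates are assumptions)::
+
+        pe_layer = PositionEmbeddingRandom(num_pos_feats=128)
+        dense_pe = pe_layer((8, 8, 8))    # (2 * num_pos_feats) x h x w x d
+        coords = torch.rand(1, 2, 3) * 8  # B x N x 3 unnormalized coordinates
+        point_pe = pe_layer.forward_with_coords(coords, (8, 8, 8))  # B x N x (2 * num_pos_feats)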
+ """
+
+ def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
+ super().__init__()
+ if scale is None or scale <= 0.0:
+ scale = 1.0
+ self.register_buffer("positional_encoding_gaussian_matrix", scale * torch.randn((3, num_pos_feats)))
+
+    def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
+        """Positionally encode points that are normalized to [0,1]."""
+        # assuming coords are in the [0, 1]^3 cube and have d_1 x ... x d_n x 3 shape
+        coords = 2 * coords - 1
+        # [bs=1, N=2, 3] @ [3, 128]
+        # -> [bs=1, N=2, 128]
+ coords = coords @ self.positional_encoding_gaussian_matrix
+ coords = 2 * np.pi * coords
+ # outputs d_1 x ... x d_n x C shape
+ # [bs=1, N=2, 128+128=256]
+ return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
+
+    def forward(self, size: Tuple[int, int, int]) -> torch.Tensor:
+ """Generate positional encoding for a grid of the specified size."""
+ h, w, d = size
+ device: Any = self.positional_encoding_gaussian_matrix.device
+ grid = torch.ones((h, w, d), device=device, dtype=torch.float32)
+ x_embed = grid.cumsum(dim=0) - 0.5
+ y_embed = grid.cumsum(dim=1) - 0.5
+ z_embed = grid.cumsum(dim=2) - 0.5
+ x_embed = x_embed / h
+ y_embed = y_embed / w
+ z_embed = z_embed / d
+ pe = self._pe_encoding(torch.stack([x_embed, y_embed, z_embed], dim=-1))
+        # C x H x W x D
+ return pe.permute(3, 0, 1, 2)
+
+    def forward_with_coords(
+        self, coords_input: torch.Tensor, image_size: Tuple[int, int, int]
+    ) -> torch.Tensor:
+ """Positionally encode points that are not normalized to [0,1]."""
+ coords = coords_input.clone()
+ coords[:, :, 0] = coords[:, :, 0] / image_size[0]
+ coords[:, :, 1] = coords[:, :, 1] / image_size[1]
+ coords[:, :, 2] = coords[:, :, 2] / image_size[2]
+ # B x N x C
+ return self._pe_encoding(coords.to(torch.float))
+
+
+class MLP(nn.Module):
+ """
+ Multi-layer perceptron. This class is only used for `PointMappingSAM`.
+ Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`.
+
+ Args:
+ input_dim: the input dimension.
+ hidden_dim: the hidden dimension.
+ output_dim: the output dimension.
+ num_layers: the number of layers.
+ sigmoid_output: whether to apply a sigmoid activation to the output.
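+
+    Example (illustrative sketch; the dimensions below are assumptions)::
+
+        mlp = MLP(input_dim=256, hidden_dim=256, output_dim=32, num_layers=3)
+        out = mlp(torch.randn(2, 8, 256))  # -> (2, 8, 32)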
+ """
+
+ def __init__(
+ self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
+ ) -> None:
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+ self.sigmoid_output = sigmoid_output
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ if self.sigmoid_output:
+ x = F.sigmoid(x)
+ return x
diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py
index 4eada6aa76..07c5147cb2 100644
--- a/monai/networks/nets/vit.py
+++ b/monai/networks/nets/vit.py
@@ -18,7 +18,6 @@
from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.transformerblock import TransformerBlock
-from monai.utils import deprecated_arg
__all__ = ["ViT"]
@@ -31,9 +30,6 @@ class ViT(nn.Module):
ViT supports Torchscript but only works for Pytorch after 1.8.
"""
- @deprecated_arg(
- name="pos_embed", since="1.2", removed="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead."
- )
def __init__(
self,
in_channels: int,
@@ -43,7 +39,6 @@ def __init__(
mlp_dim: int = 3072,
num_layers: int = 12,
num_heads: int = 12,
- pos_embed: str = "conv",
proj_type: str = "conv",
pos_embed_type: str = "learnable",
classification: bool = False,
@@ -75,9 +70,6 @@ def __init__(
qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.
- .. deprecated:: 1.4
- ``pos_embed`` is deprecated in favor of ``proj_type``.
-
Examples::
# for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
diff --git a/monai/networks/nets/vitautoenc.py b/monai/networks/nets/vitautoenc.py
index d69f5df4be..3c20f9a784 100644
--- a/monai/networks/nets/vitautoenc.py
+++ b/monai/networks/nets/vitautoenc.py
@@ -20,7 +20,7 @@
from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.networks.layers import Conv
-from monai.utils import deprecated_arg, ensure_tuple_rep, is_sqrt
+from monai.utils import ensure_tuple_rep, is_sqrt
__all__ = ["ViTAutoEnc"]
@@ -33,9 +33,6 @@ class ViTAutoEnc(nn.Module):
Modified to also give same dimension outputs as the input size of the image
"""
- @deprecated_arg(
- name="pos_embed", since="1.2", removed="1.4", new_name="proj_type", msg_suffix="please use `proj_type` instead."
- )
def __init__(
self,
in_channels: int,
@@ -47,7 +44,6 @@ def __init__(
mlp_dim: int = 3072,
num_layers: int = 12,
num_heads: int = 12,
- pos_embed: str = "conv",
proj_type: str = "conv",
dropout_rate: float = 0.0,
spatial_dims: int = 3,
@@ -71,9 +67,6 @@ def __init__(
qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.
save_attn: to make accessible the attention in self attention block. Defaults to False. Defaults to False.
- .. deprecated:: 1.4
- ``pos_embed`` is deprecated in favor of ``proj_type``.
-
Examples::
# for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
diff --git a/monai/networks/nets/voxelmorph.py b/monai/networks/nets/voxelmorph.py
index 0496cfc8f8..4923b6ad60 100644
--- a/monai/networks/nets/voxelmorph.py
+++ b/monai/networks/nets/voxelmorph.py
@@ -21,13 +21,10 @@
from monai.networks.blocks.upsample import UpSample
from monai.networks.blocks.warp import DVF2DDF, Warp
from monai.networks.layers.simplelayers import SkipConnection
-from monai.utils import alias, export
__all__ = ["VoxelMorphUNet", "voxelmorphunet", "VoxelMorph", "voxelmorph"]
-@export("monai.networks.nets")
-@alias("voxelmorphunet")
class VoxelMorphUNet(nn.Module):
"""
The backbone network used in VoxelMorph. See :py:class:`monai.networks.nets.VoxelMorph` for more details.
@@ -340,8 +337,6 @@ def forward(self, concatenated_pairs: torch.Tensor) -> torch.Tensor:
voxelmorphunet = VoxelMorphUNet
-@export("monai.networks.nets")
-@alias("voxelmorph")
class VoxelMorph(nn.Module):
"""
A re-implementation of VoxelMorph framework for medical image registration as described in
diff --git a/monai/networks/nets/vqvae.py b/monai/networks/nets/vqvae.py
new file mode 100644
index 0000000000..f198bfbb2b
--- /dev/null
+++ b/monai/networks/nets/vqvae.py
@@ -0,0 +1,472 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from monai.networks.blocks import Convolution
+from monai.networks.layers import Act
+from monai.networks.layers.vector_quantizer import EMAQuantizer, VectorQuantizer
+from monai.utils import ensure_tuple_rep
+
+__all__ = ["VQVAE"]
+
+
+class VQVAEResidualUnit(nn.Module):
+ """
+ Implementation of the ResidualLayer used in the VQVAE network as originally used in Morphology-preserving
+ Autoregressive 3D Generative Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf).
+
+    The original implementation can be found at
+ https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L150.
+
+ Args:
+        spatial_dims: number of spatial dimensions of the input data.
+ in_channels: number of input channels.
+ num_res_channels: number of channels in the residual layers.
+ act: activation type and arguments. Defaults to RELU.
+ dropout: dropout ratio. Defaults to no dropout.
+ bias: whether to have a bias term. Defaults to True.
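+
+    Example (illustrative sketch; the shapes below are assumptions)::
+
+        unit = VQVAEResidualUnit(spatial_dims=3, in_channels=96, num_res_channels=96)
+        out = unit(torch.randn(1, 96, 16, 16, 16))  # output has the same shape as the input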
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ num_res_channels: int,
+ act: tuple | str | None = Act.RELU,
+ dropout: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+
+ self.spatial_dims = spatial_dims
+ self.in_channels = in_channels
+ self.num_res_channels = num_res_channels
+ self.act = act
+ self.dropout = dropout
+ self.bias = bias
+
+ self.conv1 = Convolution(
+ spatial_dims=self.spatial_dims,
+ in_channels=self.in_channels,
+ out_channels=self.num_res_channels,
+ adn_ordering="DA",
+ act=self.act,
+ dropout=self.dropout,
+ bias=self.bias,
+ )
+
+ self.conv2 = Convolution(
+ spatial_dims=self.spatial_dims,
+ in_channels=self.num_res_channels,
+ out_channels=self.in_channels,
+ bias=self.bias,
+ conv_only=True,
+ )
+
+ def forward(self, x):
+ return torch.nn.functional.relu(x + self.conv2(self.conv1(x)), True)
+
+
+class Encoder(nn.Module):
+ """
+ Encoder module for VQ-VAE.
+
+ Args:
+        spatial_dims: number of spatial dimensions.
+ in_channels: number of input channels.
+ out_channels: number of channels in the latent space (embedding_dim).
+ channels: sequence containing the number of channels at each level of the encoder.
+ num_res_layers: number of sequential residual layers at each level.
+ num_res_channels: number of channels in the residual layers at each level.
+        downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the
+            following information: stride (int), kernel_size (int), dilation (int) and padding (int).
+ dropout: dropout ratio.
+ act: activation type and arguments.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ channels: Sequence[int],
+ num_res_layers: int,
+ num_res_channels: Sequence[int],
+ downsample_parameters: Sequence[Tuple[int, int, int, int]],
+ dropout: float,
+ act: tuple | str | None,
+ ) -> None:
+ super().__init__()
+ self.spatial_dims = spatial_dims
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.channels = channels
+ self.num_res_layers = num_res_layers
+ self.num_res_channels = num_res_channels
+ self.downsample_parameters = downsample_parameters
+ self.dropout = dropout
+ self.act = act
+
+ blocks: list[nn.Module] = []
+
+ for i in range(len(self.channels)):
+ blocks.append(
+ Convolution(
+ spatial_dims=self.spatial_dims,
+ in_channels=self.in_channels if i == 0 else self.channels[i - 1],
+ out_channels=self.channels[i],
+ strides=self.downsample_parameters[i][0],
+ kernel_size=self.downsample_parameters[i][1],
+ adn_ordering="DA",
+ act=self.act,
+ dropout=None if i == 0 else self.dropout,
+ dropout_dim=1,
+ dilation=self.downsample_parameters[i][2],
+ padding=self.downsample_parameters[i][3],
+ )
+ )
+
+ for _ in range(self.num_res_layers):
+ blocks.append(
+ VQVAEResidualUnit(
+ spatial_dims=self.spatial_dims,
+ in_channels=self.channels[i],
+ num_res_channels=self.num_res_channels[i],
+ act=self.act,
+ dropout=self.dropout,
+ )
+ )
+
+ blocks.append(
+ Convolution(
+ spatial_dims=self.spatial_dims,
+ in_channels=self.channels[len(self.channels) - 1],
+ out_channels=self.out_channels,
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ )
+
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for block in self.blocks:
+ x = block(x)
+ return x
+
+
+class Decoder(nn.Module):
+ """
+ Decoder module for VQ-VAE.
+
+ Args:
+        spatial_dims: number of spatial dimensions.
+ in_channels: number of channels in the latent space (embedding_dim).
+ out_channels: number of output channels.
+ channels: sequence containing the number of channels at each level of the decoder.
+ num_res_layers: number of sequential residual layers at each level.
+ num_res_channels: number of channels in the residual layers at each level.
+        upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the
+            following information: stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int).
+ dropout: dropout ratio.
+ act: activation type and arguments.
+ output_act: activation type and arguments for the output.
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ channels: Sequence[int],
+ num_res_layers: int,
+ num_res_channels: Sequence[int],
+ upsample_parameters: Sequence[Tuple[int, int, int, int, int]],
+ dropout: float,
+ act: tuple | str | None,
+ output_act: tuple | str | None,
+ ) -> None:
+ super().__init__()
+ self.spatial_dims = spatial_dims
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.channels = channels
+ self.num_res_layers = num_res_layers
+ self.num_res_channels = num_res_channels
+ self.upsample_parameters = upsample_parameters
+ self.dropout = dropout
+ self.act = act
+ self.output_act = output_act
+
+ reversed_num_channels = list(reversed(self.channels))
+
+ blocks: list[nn.Module] = []
+ blocks.append(
+ Convolution(
+ spatial_dims=self.spatial_dims,
+ in_channels=self.in_channels,
+ out_channels=reversed_num_channels[0],
+ strides=1,
+ kernel_size=3,
+ padding=1,
+ conv_only=True,
+ )
+ )
+
+ reversed_num_res_channels = list(reversed(self.num_res_channels))
+ for i in range(len(self.channels)):
+ for _ in range(self.num_res_layers):
+ blocks.append(
+ VQVAEResidualUnit(
+ spatial_dims=self.spatial_dims,
+ in_channels=reversed_num_channels[i],
+ num_res_channels=reversed_num_res_channels[i],
+ act=self.act,
+ dropout=self.dropout,
+ )
+ )
+
+ blocks.append(
+ Convolution(
+ spatial_dims=self.spatial_dims,
+ in_channels=reversed_num_channels[i],
+ out_channels=self.out_channels if i == len(self.channels) - 1 else reversed_num_channels[i + 1],
+ strides=self.upsample_parameters[i][0],
+ kernel_size=self.upsample_parameters[i][1],
+ adn_ordering="DA",
+ act=self.act,
+ dropout=self.dropout if i != len(self.channels) - 1 else None,
+ norm=None,
+ dilation=self.upsample_parameters[i][2],
+ conv_only=i == len(self.channels) - 1,
+ is_transposed=True,
+ padding=self.upsample_parameters[i][3],
+ output_padding=self.upsample_parameters[i][4],
+ )
+ )
+
+ if self.output_act:
+ blocks.append(Act[self.output_act]())
+
+ self.blocks = nn.ModuleList(blocks)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ for block in self.blocks:
+ x = block(x)
+ return x
+
+
+class VQVAE(nn.Module):
+ """
+ Vector-Quantised Variational Autoencoder (VQ-VAE) used in Morphology-preserving Autoregressive 3D Generative
+ Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf)
+
+ The original implementation can be found at
+ https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/
+
+ Args:
+        spatial_dims: number of spatial dimensions.
+ in_channels: number of input channels.
+ out_channels: number of output channels.
+        downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the
+            following information: stride (int), kernel_size (int), dilation (int) and padding (int).
+        upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the
+            following information: stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int).
+ num_res_layers: number of sequential residual layers at each level.
+ channels: number of channels at each level.
+ num_res_channels: number of channels in the residual layers at each level.
+        num_embeddings: number of atomic elements (codes) in the VectorQuantization codebook.
+        embedding_dim: number of channels of the encoded input and of the codebook's atomic elements.
+ commitment_cost: VectorQuantization commitment_cost.
+ decay: VectorQuantization decay.
+ epsilon: VectorQuantization epsilon.
+ act: activation type and arguments.
+ dropout: dropout ratio.
+ output_act: activation type and arguments for the output.
+ ddp_sync: whether to synchronize the codebook across processes.
+        use_checkpointing: if True, use activation checkpointing to save memory.
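+
+    Example (illustrative sketch; the default configuration and input shape are assumptions)::
+
+        net = VQVAE(spatial_dims=3, in_channels=1, out_channels=1)
+        image = torch.randn(1, 1, 32, 32, 32)
+        reconstruction, quantization_loss = net(image)  # reconstruction matches the input shape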
+ """
+
+ def __init__(
+ self,
+ spatial_dims: int,
+ in_channels: int,
+ out_channels: int,
+ channels: Sequence[int] = (96, 96, 192),
+ num_res_layers: int = 3,
+ num_res_channels: Sequence[int] | int = (96, 96, 192),
+ downsample_parameters: Sequence[Tuple[int, int, int, int]] | Tuple[int, int, int, int] = (
+ (2, 4, 1, 1),
+ (2, 4, 1, 1),
+ (2, 4, 1, 1),
+ ),
+ upsample_parameters: Sequence[Tuple[int, int, int, int, int]] | Tuple[int, int, int, int, int] = (
+ (2, 4, 1, 1, 0),
+ (2, 4, 1, 1, 0),
+ (2, 4, 1, 1, 0),
+ ),
+ num_embeddings: int = 32,
+ embedding_dim: int = 64,
+ embedding_init: str = "normal",
+ commitment_cost: float = 0.25,
+ decay: float = 0.5,
+ epsilon: float = 1e-5,
+ dropout: float = 0.0,
+ act: tuple | str | None = Act.RELU,
+ output_act: tuple | str | None = None,
+ ddp_sync: bool = True,
+ use_checkpointing: bool = False,
+ ):
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.spatial_dims = spatial_dims
+ self.channels = channels
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ self.use_checkpointing = use_checkpointing
+
+ if isinstance(num_res_channels, int):
+ num_res_channels = ensure_tuple_rep(num_res_channels, len(channels))
+
+ if len(num_res_channels) != len(channels):
+ raise ValueError(
+ "`num_res_channels` should be a single integer or a tuple of integers with the same length as "
+                "`channels`."
+ )
+ if all(isinstance(values, int) for values in upsample_parameters):
+ upsample_parameters_tuple: Sequence = (upsample_parameters,) * len(channels)
+ else:
+ upsample_parameters_tuple = upsample_parameters
+
+ if all(isinstance(values, int) for values in downsample_parameters):
+ downsample_parameters_tuple: Sequence = (downsample_parameters,) * len(channels)
+ else:
+ downsample_parameters_tuple = downsample_parameters
+
+        if not all(all(isinstance(value, int) for value in sub_item) for sub_item in downsample_parameters_tuple):
+            raise ValueError("`downsample_parameters` should be a single tuple of integers or a tuple of tuples.")
+
+        # check if upsample_parameters is a tuple of ints or a tuple of tuples of ints
+        if not all(all(isinstance(value, int) for value in sub_item) for sub_item in upsample_parameters_tuple):
+            raise ValueError("`upsample_parameters` should be a single tuple of integers or a tuple of tuples.")
+
+ for parameter in downsample_parameters_tuple:
+ if len(parameter) != 4:
+ raise ValueError("`downsample_parameters` should be a tuple of tuples with 4 integers.")
+
+ for parameter in upsample_parameters_tuple:
+ if len(parameter) != 5:
+ raise ValueError("`upsample_parameters` should be a tuple of tuples with 5 integers.")
+
+ if len(downsample_parameters_tuple) != len(channels):
+ raise ValueError(
+                "`downsample_parameters` should be a tuple of tuples with the same length as `channels`."
+ )
+
+ if len(upsample_parameters_tuple) != len(channels):
+ raise ValueError(
+                "`upsample_parameters` should be a tuple of tuples with the same length as `channels`."
+ )
+
+ self.num_res_layers = num_res_layers
+ self.num_res_channels = num_res_channels
+
+ self.encoder = Encoder(
+ spatial_dims=spatial_dims,
+ in_channels=in_channels,
+ out_channels=embedding_dim,
+ channels=channels,
+ num_res_layers=num_res_layers,
+ num_res_channels=num_res_channels,
+ downsample_parameters=downsample_parameters_tuple,
+ dropout=dropout,
+ act=act,
+ )
+
+ self.decoder = Decoder(
+ spatial_dims=spatial_dims,
+ in_channels=embedding_dim,
+ out_channels=out_channels,
+ channels=channels,
+ num_res_layers=num_res_layers,
+ num_res_channels=num_res_channels,
+ upsample_parameters=upsample_parameters_tuple,
+ dropout=dropout,
+ act=act,
+ output_act=output_act,
+ )
+
+ self.quantizer = VectorQuantizer(
+ quantizer=EMAQuantizer(
+ spatial_dims=spatial_dims,
+ num_embeddings=num_embeddings,
+ embedding_dim=embedding_dim,
+ commitment_cost=commitment_cost,
+ decay=decay,
+ epsilon=epsilon,
+ embedding_init=embedding_init,
+ ddp_sync=ddp_sync,
+ )
+ )
+
+ def encode(self, images: torch.Tensor) -> torch.Tensor:
+ output: torch.Tensor
+ if self.use_checkpointing:
+ output = torch.utils.checkpoint.checkpoint(self.encoder, images, use_reentrant=False)
+ else:
+ output = self.encoder(images)
+ return output
+
+ def quantize(self, encodings: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ x_loss, x = self.quantizer(encodings)
+ return x, x_loss
+
+ def decode(self, quantizations: torch.Tensor) -> torch.Tensor:
+ output: torch.Tensor
+
+ if self.use_checkpointing:
+ output = torch.utils.checkpoint.checkpoint(self.decoder, quantizations, use_reentrant=False)
+ else:
+ output = self.decoder(quantizations)
+ return output
+
+ def index_quantize(self, images: torch.Tensor) -> torch.Tensor:
+ return self.quantizer.quantize(self.encode(images=images))
+
+ def decode_samples(self, embedding_indices: torch.Tensor) -> torch.Tensor:
+ return self.decode(self.quantizer.embed(embedding_indices))
+
+ def forward(self, images: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+ quantizations, quantization_losses = self.quantize(self.encode(images))
+ reconstruction = self.decode(quantizations)
+
+ return reconstruction, quantization_losses
+
+ def encode_stage_2_inputs(self, x: torch.Tensor) -> torch.Tensor:
+ z = self.encode(x)
+ e, _ = self.quantize(z)
+ return e
+
+ def decode_stage_2_outputs(self, z: torch.Tensor) -> torch.Tensor:
+ e, _ = self.quantize(z)
+ image = self.decode(e)
+ return image
diff --git a/monai/networks/schedulers/__init__.py b/monai/networks/schedulers/__init__.py
new file mode 100644
index 0000000000..29e9020d65
--- /dev/null
+++ b/monai/networks/schedulers/__init__.py
@@ -0,0 +1,17 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from .ddim import DDIMScheduler
+from .ddpm import DDPMScheduler
+from .pndm import PNDMScheduler
+from .scheduler import NoiseSchedules, Scheduler
diff --git a/monai/networks/schedulers/ddim.py b/monai/networks/schedulers/ddim.py
new file mode 100644
index 0000000000..50a680336d
--- /dev/null
+++ b/monai/networks/schedulers/ddim.py
@@ -0,0 +1,294 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/huggingface/diffusers
+# which has the following license:
+# https://github.com/huggingface/diffusers/blob/main/LICENSE
+#
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from .ddpm import DDPMPredictionType
+from .scheduler import Scheduler
+
+DDIMPredictionType = DDPMPredictionType
+
+
+class DDIMScheduler(Scheduler):
+ """
+    Denoising Diffusion Implicit Models (DDIM) is a scheduler that extends the denoising procedure introduced in
+    denoising diffusion probabilistic models (DDPMs) with non-Markovian guidance. Based on: Song et al. "Denoising
+    Diffusion Implicit Models" https://arxiv.org/abs/2010.02502
+
+ Args:
+ num_train_timesteps: number of diffusion steps used to train the model.
+ schedule: member of NoiseSchedules, name of noise schedule function in component store
+ clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.
+ set_alpha_to_one: each diffusion step uses the value of alphas product at that step and at the previous one.
+ For the final step there is no previous alpha. When this option is `True` the previous alpha product is
+ fixed to `1`, otherwise it uses the value of alpha at step 0.
+ steps_offset: an offset added to the inference steps. You can use a combination of `steps_offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ prediction_type: member of DDPMPredictionType
+ clip_sample_min: minimum clipping value when clip_sample equals True
+ clip_sample_max: maximum clipping value when clip_sample equals True
+ schedule_args: arguments to pass to the schedule function
+
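+    Example (illustrative sketch; ``model`` is a placeholder for any trained diffusion network)::
+
+        scheduler = DDIMScheduler(num_train_timesteps=1000)
+        scheduler.set_timesteps(num_inference_steps=50)
+        model = lambda x, t: torch.randn_like(x)  # placeholder noise predictor
+        sample = torch.randn(1, 1, 32, 32, 32)
+        for t in scheduler.timesteps:
+            sample, _ = scheduler.step(model(sample, t), int(t), sample)
+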
+ """
+
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ schedule: str = "linear_beta",
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = DDIMPredictionType.EPSILON,
+ clip_sample_min: float = -1.0,
+ clip_sample_max: float = 1.0,
+ **schedule_args,
+ ) -> None:
+ super().__init__(num_train_timesteps, schedule, **schedule_args)
+
+ if prediction_type not in DDIMPredictionType.__members__.values():
+ raise ValueError("Argument `prediction_type` must be a member of DDIMPredictionType")
+
+ self.prediction_type = prediction_type
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ self.timesteps = torch.from_numpy(np.arange(0, self.num_train_timesteps)[::-1].astype(np.int64))
+
+ self.clip_sample = clip_sample
+ self.clip_sample_values = [clip_sample_min, clip_sample_max]
+ self.steps_offset = steps_offset
+
+ # default the number of inference timesteps to the number of train steps
+ self.num_inference_steps: int
+ self.set_timesteps(self.num_train_timesteps)
+
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
+ device: target device to put the data.
+ """
+ if num_inference_steps > self.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
+ f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+ step_ratio = self.num_train_timesteps // self.num_inference_steps
+ if self.steps_offset >= step_ratio:
+ raise ValueError(
+ f"`steps_offset`: {self.steps_offset} cannot be greater than or equal to "
+ f"`num_train_timesteps // num_inference_steps : {step_ratio}` as this will cause timesteps to exceed"
+ f" the max train timestep."
+ )
+
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+ self.timesteps += self.steps_offset
+
+ def _get_variance(self, timestep: int, prev_timestep: int) -> torch.Tensor:
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance: torch.Tensor = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ def step(
+ self,
+ model_output: torch.Tensor,
+ timestep: int,
+ sample: torch.Tensor,
+ eta: float = 0.0,
+ generator: torch.Generator | None = None,
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output: direct output from learned diffusion model.
+ timestep: current discrete timestep in the diffusion chain.
+ sample: current instance of sample being created by diffusion process.
+ eta: weight of noise for added noise in diffusion step.
+ generator: random number generator.
+
+ Returns:
+ pred_prev_sample: Predicted previous sample
+ pred_original_sample: Predicted original sample
+ """
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+        # Ideally, read the DDIM paper in detail to understand the notation used below
+
+        # Notation (<variable name> -> <name in paper>)
+ # - model_output -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+        # predefinitions to satisfy pylint/mypy; these values won't ultimately be used
+ pred_original_sample = sample
+ pred_epsilon = model_output
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ if self.prediction_type == DDIMPredictionType.EPSILON:
+ pred_original_sample = (sample - (beta_prod_t**0.5) * model_output) / (alpha_prod_t**0.5)
+ pred_epsilon = model_output
+ elif self.prediction_type == DDIMPredictionType.SAMPLE:
+ pred_original_sample = model_output
+ pred_epsilon = (sample - (alpha_prod_t**0.5) * pred_original_sample) / (beta_prod_t**0.5)
+ elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+
+ # 4. Clip "predicted x_0"
+ if self.clip_sample:
+ pred_original_sample = torch.clamp(
+ pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
+ )
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance**0.5
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** 0.5 * pred_epsilon
+
+ # 7. compute x_t-1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_prev_sample = alpha_prod_t_prev**0.5 * pred_original_sample + pred_sample_direction
+
+ if eta > 0:
+ # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
+ device: torch.device = torch.device(model_output.device if torch.is_tensor(model_output) else "cpu")
+ noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator, device=device)
+ variance = self._get_variance(timestep, prev_timestep) ** 0.5 * eta * noise
+
+ pred_prev_sample = pred_prev_sample + variance
+
+ return pred_prev_sample, pred_original_sample
+
+ def reversed_step(
+ self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predict the sample at the next timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output: direct output from learned diffusion model.
+ timestep: current discrete timestep in the diffusion chain.
+ sample: current instance of sample being created by diffusion process.
+
+ Returns:
+ pred_prev_sample: Predicted previous sample
+ pred_original_sample: Predicted original sample
+ """
+ # See Appendix F at https://arxiv.org/pdf/2105.05233.pdf, or Equation (6) in https://arxiv.org/pdf/2203.04306.pdf
+
+        # Notation (<variable name> -> <name in paper>)
+ # - model_output -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_post_sample -> "x_t+1"
+
+ # 1. get previous step value (=t+1)
+ prev_timestep = timestep + self.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas at timestep t+1
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+        # predefinitions to satisfy pylint/mypy; these values won't ultimately be used
+ pred_original_sample = sample
+ pred_epsilon = model_output
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+
+ if self.prediction_type == DDIMPredictionType.EPSILON:
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ pred_epsilon = model_output
+ elif self.prediction_type == DDIMPredictionType.SAMPLE:
+ pred_original_sample = model_output
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+ elif self.prediction_type == DDIMPredictionType.V_PREDICTION:
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+
+ # 4. Clip "predicted x_0"
+ if self.clip_sample:
+ pred_original_sample = torch.clamp(
+ pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
+ )
+
+ # 5. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev) ** (0.5) * pred_epsilon
+
+ # 6. compute x_t+1 without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_post_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ return pred_post_sample, pred_original_sample
diff --git a/monai/networks/schedulers/ddpm.py b/monai/networks/schedulers/ddpm.py
new file mode 100644
index 0000000000..d64e11d379
--- /dev/null
+++ b/monai/networks/schedulers/ddpm.py
@@ -0,0 +1,254 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/huggingface/diffusers
+# which has the following license:
+# https://github.com/huggingface/diffusers/blob/main/LICENSE
+#
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+
+from monai.utils import StrEnum
+
+from .scheduler import Scheduler
+
+
+class DDPMVarianceType(StrEnum):
+ """
+ Valid names for DDPM Scheduler's `variance_type` argument. Options to clip the variance used when adding noise
+ to the denoised sample.
+ """
+
+ FIXED_SMALL = "fixed_small"
+ FIXED_LARGE = "fixed_large"
+ LEARNED = "learned"
+ LEARNED_RANGE = "learned_range"
+
+
+class DDPMPredictionType(StrEnum):
+ """
+ Set of valid prediction type names for the DDPM scheduler's `prediction_type` argument.
+
+ epsilon: predicting the noise of the diffusion process
+ sample: directly predicting the noisy sample
+ v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
+ """
+
+ EPSILON = "epsilon"
+ SAMPLE = "sample"
+ V_PREDICTION = "v_prediction"
+
+
+class DDPMScheduler(Scheduler):
+ """
+ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
+ Langevin dynamics sampling. Based on: Ho et al., "Denoising Diffusion Probabilistic Models"
+ https://arxiv.org/abs/2006.11239
+
+ Args:
+ num_train_timesteps: number of diffusion steps used to train the model.
+ schedule: member of NoiseSchedules, name of noise schedule function in component store
+ variance_type: member of DDPMVarianceType
+ clip_sample: option to clip predicted sample between -1 and 1 for numerical stability.
+ prediction_type: member of DDPMPredictionType
+ clip_sample_min: minimum clipping value when clip_sample equals True
+ clip_sample_max: maximum clipping value when clip_sample equals True
+ schedule_args: arguments to pass to the schedule function
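+
+    Example (illustrative sketch; ``model`` is a placeholder for any trained diffusion network)::
+
+        scheduler = DDPMScheduler(num_train_timesteps=1000)
+        scheduler.set_timesteps(num_inference_steps=1000)
+        model = lambda x, t: torch.randn_like(x)  # placeholder noise predictor
+        sample = torch.randn(1, 1, 32, 32, 32)
+        for t in scheduler.timesteps:
+            sample, _ = scheduler.step(model(sample, t), int(t), sample)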
+ """
+
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ schedule: str = "linear_beta",
+ variance_type: str = DDPMVarianceType.FIXED_SMALL,
+ clip_sample: bool = True,
+ prediction_type: str = DDPMPredictionType.EPSILON,
+ clip_sample_min: float = -1.0,
+ clip_sample_max: float = 1.0,
+ **schedule_args,
+ ) -> None:
+ super().__init__(num_train_timesteps, schedule, **schedule_args)
+
+ if variance_type not in DDPMVarianceType.__members__.values():
+ raise ValueError("Argument `variance_type` must be a member of `DDPMVarianceType`")
+
+ if prediction_type not in DDPMPredictionType.__members__.values():
+ raise ValueError("Argument `prediction_type` must be a member of `DDPMPredictionType`")
+
+ self.clip_sample = clip_sample
+ self.clip_sample_values = [clip_sample_min, clip_sample_max]
+ self.variance_type = variance_type
+ self.prediction_type = prediction_type
+
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
+ device: target device to put the data.
+ """
+ if num_inference_steps > self.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
+ f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+ step_ratio = self.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_step is power of 3
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].astype(np.int64)
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+
+ def _get_mean(self, timestep: int, x_0: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the mean of the posterior at timestep t.
+
+ Args:
+ timestep: current timestep.
+            x_0: the noise-free input.
+ x_t: the input noised to timestep t.
+
+ Returns:
+ Returns the mean
+ """
+ # these attributes are used for calculating the posterior, q(x_{t-1}|x_t,x_0),
+ # (see formula (5-7) from https://arxiv.org/pdf/2006.11239.pdf)
+ alpha_t = self.alphas[timestep]
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
+
+ x_0_coefficient = alpha_prod_t_prev.sqrt() * self.betas[timestep] / (1 - alpha_prod_t)
+ x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)
+
+ mean: torch.Tensor = x_0_coefficient * x_0 + x_t_coefficient * x_t
+
+ return mean
+
+ def _get_variance(self, timestep: int, predicted_variance: torch.Tensor | None = None) -> torch.Tensor:
+ """
+ Compute the variance of the posterior at timestep t.
+
+ Args:
+ timestep: current timestep.
+ predicted_variance: variance predicted by the model.
+
+ Returns:
+ Returns the variance
+ """
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
+
+ # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
+ # and sample from it to get previous sample
+ # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
+ variance: torch.Tensor = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * self.betas[timestep]
+ # hacks - were probably added for training stability
+ if self.variance_type == DDPMVarianceType.FIXED_SMALL:
+ variance = torch.clamp(variance, min=1e-20)
+ elif self.variance_type == DDPMVarianceType.FIXED_LARGE:
+ variance = self.betas[timestep]
+ elif self.variance_type == DDPMVarianceType.LEARNED and predicted_variance is not None:
+ return predicted_variance
+ elif self.variance_type == DDPMVarianceType.LEARNED_RANGE and predicted_variance is not None:
+ min_log = variance
+ max_log = self.betas[timestep]
+ frac = (predicted_variance + 1) / 2
+ variance = frac * max_log + (1 - frac) * min_log
+
+ return variance
+
+ def step(
+ self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor, generator: torch.Generator | None = None
+ ) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output: direct output from learned diffusion model.
+ timestep: current discrete timestep in the diffusion chain.
+ sample: current instance of sample being created by diffusion process.
+ generator: random number generator.
+
+ Returns:
+ pred_prev_sample: Predicted previous sample
+ """
+ if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+ model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
+ else:
+ predicted_variance = None
+
+ # 1. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else self.one
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ # 2. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
+ if self.prediction_type == DDPMPredictionType.EPSILON:
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ elif self.prediction_type == DDPMPredictionType.SAMPLE:
+ pred_original_sample = model_output
+ elif self.prediction_type == DDPMPredictionType.V_PREDICTION:
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+
+ # 3. Clip "predicted x_0"
+ if self.clip_sample:
+ pred_original_sample = torch.clamp(
+ pred_original_sample, self.clip_sample_values[0], self.clip_sample_values[1]
+ )
+
+ # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * self.betas[timestep]) / beta_prod_t
+ current_sample_coeff = self.alphas[timestep] ** (0.5) * beta_prod_t_prev / beta_prod_t
+
+ # 5. Compute predicted previous sample µ_t
+ # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+ pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
+
+ # 6. Add noise
+ variance = 0
+ if timestep > 0:
+ noise = torch.randn(
+ model_output.size(),
+ dtype=model_output.dtype,
+ layout=model_output.layout,
+ generator=generator,
+ device=model_output.device,
+ )
+ variance = (self._get_variance(timestep, predicted_variance=predicted_variance) ** 0.5) * noise
+
+ pred_prev_sample = pred_prev_sample + variance
+
+ return pred_prev_sample, pred_original_sample
diff --git a/monai/networks/schedulers/pndm.py b/monai/networks/schedulers/pndm.py
new file mode 100644
index 0000000000..c0728bbdff
--- /dev/null
+++ b/monai/networks/schedulers/pndm.py
@@ -0,0 +1,316 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/huggingface/diffusers
+# which has the following license:
+# https://github.com/huggingface/diffusers/blob/main/LICENSE
+#
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+import torch
+
+from monai.utils import StrEnum
+
+from .scheduler import Scheduler
+
+
+class PNDMPredictionType(StrEnum):
+ """
+ Set of valid prediction type names for the PNDM scheduler's `prediction_type` argument.
+
+ epsilon: predicting the noise of the diffusion process
+ v_prediction: velocity prediction, see section 2.4 https://imagen.research.google/video/paper.pdf
+ """
+
+ EPSILON = "epsilon"
+ V_PREDICTION = "v_prediction"
+
+
+class PNDMScheduler(Scheduler):
+ """
+ Pseudo numerical methods for diffusion models (PNDM) propose using more advanced ODE integration techniques,
+ namely the Runge-Kutta method and a linear multi-step method. Based on: Liu et al.,
+ "Pseudo Numerical Methods for Diffusion Models on Manifolds" https://arxiv.org/abs/2202.09778
+
+ Args:
+ num_train_timesteps: number of diffusion steps used to train the model.
+ schedule: member of NoiseSchedules, name of noise schedule function in component store
+ skip_prk_steps:
+ allows the scheduler to skip the Runge-Kutta steps that are defined in the original paper as being required
+ before the PLMS steps.
+ set_alpha_to_one:
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the value of alpha at step 0.
+ prediction_type: member of PNDMPredictionType
+ steps_offset:
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ schedule_args: arguments to pass to the schedule function
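+
+ Example (an illustrative sketch; `model` stands for a trained diffusion network and is assumed, not defined here):
+
+ .. code-block:: python
+
+ import torch
+ from monai.networks.schedulers import PNDMScheduler
+
+ scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True)
+ scheduler.set_timesteps(50)
+ sample = torch.randn(1, 1, 64, 64)  # start from pure noise
+ for t in scheduler.timesteps:
+ model_output = model(sample, torch.as_tensor((t,)))  # predicted noise (assumed model signature)
+ sample, _ = scheduler.step(model_output, t, sample)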
+ """
+
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ schedule: str = "linear_beta",
+ skip_prk_steps: bool = False,
+ set_alpha_to_one: bool = False,
+ prediction_type: str = PNDMPredictionType.EPSILON,
+ steps_offset: int = 0,
+ **schedule_args,
+ ) -> None:
+ super().__init__(num_train_timesteps, schedule, **schedule_args)
+
+ if prediction_type not in PNDMPredictionType.__members__.values():
+ raise ValueError("Argument `prediction_type` must be a member of PNDMPredictionType")
+
+ self.prediction_type = prediction_type
+
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # For now we only support F-PNDM, i.e. the Runge-Kutta method
+ # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
+ # mainly at formulas (9), (12), (13) and Algorithm 2.
+ self.pndm_order = 4
+
+ self.skip_prk_steps = skip_prk_steps
+ self.steps_offset = steps_offset
+
+ # running values
+ self.cur_model_output = torch.Tensor()
+ self.counter = 0
+ self.cur_sample = torch.Tensor()
+ self.ets: list = []
+
+ # default the number of inference timesteps to the number of train steps
+ self.set_timesteps(num_train_timesteps)
+
+ def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps: number of diffusion steps used when generating samples with a pre-trained model.
+ device: target device to put the data.
+ """
+ if num_inference_steps > self.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.num_train_timesteps`:"
+ f" {self.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+ step_ratio = self.num_train_timesteps // self.num_inference_steps
+ # creates integer timesteps by multiplying by ratio
+ # casting to int to avoid issues when num_inference_steps is a power of 3
+ self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
+ self._timesteps += self.steps_offset
+
+ if self.skip_prk_steps:
+ # for some models like stable diffusion the prk steps can/should be skipped to
+ # produce better results. When using PNDM with `self.skip_prk_steps` the implementation
+ # is based on crowsonkb's PLMS sampler implementation: https://github.com/CompVis/latent-diffusion/pull/51
+ self.prk_timesteps = np.array([])
+ self.plms_timesteps = self._timesteps[::-1]
+
+ else:
+ prk_timesteps = np.array(self._timesteps[-self.pndm_order :]).repeat(2) + np.tile(
+ np.array([0, self.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
+ )
+ self.prk_timesteps = (prk_timesteps[:-1].repeat(2)[1:-1])[::-1].copy()
+ self.plms_timesteps = self._timesteps[:-3][
+ ::-1
+ ].copy() # we copy to avoid having negative strides which are not supported by torch.from_numpy
+
+ timesteps = np.concatenate([self.prk_timesteps, self.plms_timesteps]).astype(np.int64)
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+ # update num_inference_steps - necessary if we use prk steps
+ self.num_inference_steps = len(self.timesteps)
+
+ self.ets = []
+ self.counter = 0
+
+ def step(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> tuple[torch.Tensor, Any]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+ This function calls `step_prk()` or `step_plms()` depending on the internal variable `counter`.
+
+ Args:
+ model_output: direct output from learned diffusion model.
+ timestep: current discrete timestep in the diffusion chain.
+ sample: current instance of sample being created by diffusion process.
+ Returns:
+ pred_prev_sample: Predicted previous sample
+ """
+ # return a tuple for consistency with samplers that return (previous pred, original sample pred)
+
+ if self.counter < len(self.prk_timesteps) and not self.skip_prk_steps:
+ return self.step_prk(model_output=model_output, timestep=timestep, sample=sample), None
+ else:
+ return self.step_plms(model_output=model_output, timestep=timestep, sample=sample), None
+
+ def step_prk(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> torch.Tensor:
+ """
+ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+ solution to the differential equation.
+
+ Args:
+ model_output: direct output from learned diffusion model.
+ timestep: current discrete timestep in the diffusion chain.
+ sample: current instance of sample being created by diffusion process.
+
+ Returns:
+ pred_prev_sample: Predicted previous sample
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ diff_to_prev = 0 if self.counter % 2 else self.num_train_timesteps // self.num_inference_steps // 2
+ prev_timestep = timestep - diff_to_prev
+ timestep = self.prk_timesteps[self.counter // 4 * 4]
+
+ if self.counter % 4 == 0:
+ self.cur_model_output = 1 / 6 * model_output
+ self.ets.append(model_output)
+ self.cur_sample = sample
+ elif (self.counter - 1) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 2) % 4 == 0:
+ self.cur_model_output += 1 / 3 * model_output
+ elif (self.counter - 3) % 4 == 0:
+ model_output = self.cur_model_output + 1 / 6 * model_output
+ self.cur_model_output = torch.Tensor()
+
+ # cur_sample should not be an empty torch.Tensor()
+ cur_sample = self.cur_sample if self.cur_sample.numel() != 0 else sample
+
+ prev_sample: torch.Tensor = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
+ self.counter += 1
+
+ return prev_sample
+
+ def step_plms(self, model_output: torch.Tensor, timestep: int, sample: torch.Tensor) -> Any:
+ """
+ Step function propagating the sample with the linear multi-step method. This requires only one forward pass per
+ step and reuses the model outputs from previous timesteps to approximate the solution.
+
+ Args:
+ model_output: direct output from learned diffusion model.
+ timestep: current discrete timestep in the diffusion chain.
+ sample: current instance of sample being created by diffusion process.
+
+ Returns:
+ pred_prev_sample: Predicted previous sample
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ if not self.skip_prk_steps and len(self.ets) < 3:
+ raise ValueError(
+ f"{self.__class__} can only be run AFTER scheduler has been run "
+ "in 'prk' mode for at least 12 iterations "
+ )
+
+ prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
+
+ if self.counter != 1:
+ self.ets = self.ets[-3:]
+ self.ets.append(model_output)
+ else:
+ prev_timestep = timestep
+ timestep = timestep + self.num_train_timesteps // self.num_inference_steps
+
+ if len(self.ets) == 1 and self.counter == 0:
+ model_output = model_output
+ self.cur_sample = sample
+ elif len(self.ets) == 1 and self.counter == 1:
+ model_output = (model_output + self.ets[-1]) / 2
+ sample = self.cur_sample
+ self.cur_sample = torch.Tensor()
+ elif len(self.ets) == 2:
+ model_output = (3 * self.ets[-1] - self.ets[-2]) / 2
+ elif len(self.ets) == 3:
+ model_output = (23 * self.ets[-1] - 16 * self.ets[-2] + 5 * self.ets[-3]) / 12
+ else:
+ model_output = (1 / 24) * (55 * self.ets[-1] - 59 * self.ets[-2] + 37 * self.ets[-3] - 9 * self.ets[-4])
+
+ prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
+ self.counter += 1
+
+ return prev_sample
+
+ def _get_prev_sample(self, sample: torch.Tensor, timestep: int, prev_timestep: int, model_output: torch.Tensor):
+ # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
+ # this function computes x_(t−δ) using formula (9)
+ # Note that x_t needs to be added to both sides of the equation
+
+ # Notation (<variable name> -> <name in paper>)
+ # alpha_prod_t -> α_t
+ # alpha_prod_t_prev -> α_(t−δ)
+ # beta_prod_t -> (1 - α_t)
+ # beta_prod_t_prev -> (1 - α_(t−δ))
+ # sample -> x_t
+ # model_output -> e_θ(x_t, t)
+ # prev_sample -> x_(t−δ)
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ if self.prediction_type == PNDMPredictionType.V_PREDICTION:
+ model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+
+ # corresponds to (α_(t−δ) - α_t) divided by
+ # the denominator of x_t in formula (9), plus 1
+ # Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) + 1 =
+ # sqrt(α_(t−δ)) / sqrt(α_t)
+ sample_coeff = (alpha_prod_t_prev / alpha_prod_t) ** (0.5)
+
+ # corresponds to denominator of e_θ(x_t, t) in formula (9)
+ model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev ** (0.5) + (
+ alpha_prod_t * beta_prod_t * alpha_prod_t_prev
+ ) ** (0.5)
+
+ # full formula (9)
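+ # x_(t−δ) = (α_(t−δ) / α_t)^0.5 * x_t - (α_(t−δ) - α_t) * e_θ(x_t, t)
+ #           / (α_t * (1 - α_(t−δ))^0.5 + (α_t * (1 - α_t) * α_(t−δ))^0.5)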
+ prev_sample = (
+ sample_coeff * sample - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff
+ )
+
+ return prev_sample
diff --git a/monai/networks/schedulers/scheduler.py b/monai/networks/schedulers/scheduler.py
new file mode 100644
index 0000000000..acdccc60de
--- /dev/null
+++ b/monai/networks/schedulers/scheduler.py
@@ -0,0 +1,205 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# =========================================================================
+# Adapted from https://github.com/huggingface/diffusers
+# which has the following license:
+# https://github.com/huggingface/diffusers/blob/main/LICENSE
+#
+# Copyright 2022 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =========================================================================
+
+
+from __future__ import annotations
+
+import torch
+import torch.nn as nn
+
+from monai.utils import ComponentStore, unsqueeze_right
+
+NoiseSchedules = ComponentStore("NoiseSchedules", "Functions to generate noise schedules")
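+# a registered schedule can be retrieved by name and called, e.g. NoiseSchedules["linear_beta"](1000) returns a tensor of 1000 beta values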
+
+
+@NoiseSchedules.add_def("linear_beta", "Linear beta schedule")
+def _linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
+ """
+ Linear beta noise schedule function.
+
+ Args:
+ num_train_timesteps: number of timesteps
+ beta_start: start of beta range, default 1e-4
+ beta_end: end of beta range, default 2e-2
+
+ Returns:
+ betas: beta schedule tensor
+ """
+ return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+
+
+@NoiseSchedules.add_def("scaled_linear_beta", "Scaled linear beta schedule")
+def _scaled_linear_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2):
+ """
+ Scaled linear beta noise schedule function.
+
+ Args:
+ num_train_timesteps: number of timesteps
+ beta_start: start of beta range, default 1e-4
+ beta_end: end of beta range, default 2e-2
+
+ Returns:
+ betas: beta schedule tensor
+ """
+ return torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+
+
+@NoiseSchedules.add_def("sigmoid_beta", "Sigmoid beta schedule")
+def _sigmoid_beta(num_train_timesteps: int, beta_start: float = 1e-4, beta_end: float = 2e-2, sig_range: float = 6):
+ """
+ Sigmoid beta noise schedule function.
+
+ Args:
+ num_train_timesteps: number of timesteps
+ beta_start: start of beta range, default 1e-4
+ beta_end: end of beta range, default 2e-2
+ sig_range: pos/neg range of sigmoid input, default 6
+
+ Returns:
+ betas: beta schedule tensor
+ """
+ betas = torch.linspace(-sig_range, sig_range, num_train_timesteps)
+ return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
+
+
+@NoiseSchedules.add_def("cosine", "Cosine schedule")
+def _cosine_beta(num_train_timesteps: int, s: float = 8e-3):
+ """
+ Cosine noise schedule, see https://arxiv.org/abs/2102.09672
+
+ Args:
+ num_train_timesteps: number of timesteps
+ s: smoothing factor, default 8e-3 (see referenced paper)
+
+ Returns:
+ (betas, alphas, alpha_cumprod) values
+ """
+ x = torch.linspace(0, num_train_timesteps, num_train_timesteps + 1)
+ alphas_cumprod = torch.cos(((x / num_train_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
+ alphas_cumprod /= alphas_cumprod[0].item()
+ alphas = torch.clip(alphas_cumprod[1:] / alphas_cumprod[:-1], 0.0001, 0.9999)
+ betas = 1.0 - alphas
+ return betas, alphas, alphas_cumprod[:-1]
+
+
+class Scheduler(nn.Module):
+ """
+ Base class for other schedulers based on a noise schedule function.
+
+ This class is meant as the base for other schedulers which implement their own way of sampling or stepping. Here
+ the class defines beta, alpha, and alpha_cumprod values from a noise schedule function named with `schedule`,
+ which is the name of a component in NoiseSchedules. These components must all be callables which return either
+ the beta schedule alone or a triple containing (betas, alphas, alphas_cumprod) values. New schedule functions
+ can be provided by using the NoiseSchedules.add_def, for example:
+
+ .. code-block:: python
+
+ from monai.networks.schedulers import NoiseSchedules, DDPMScheduler
+
+ @NoiseSchedules.add_def("my_beta_schedule", "Some description of your function")
+ def _beta_function(num_train_timesteps, beta_start=1e-4, beta_end=2e-2):
+ return torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+
+ scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="my_beta_schedule")
+
+ All such functions should have an initial positional integer argument `num_train_timesteps` stating the number of
+ timesteps the schedule is for; any other arguments can be given, and these will be passed by keyword through
+ the constructor's `schedule_args` value. To see what noise functions are available, print the object NoiseSchedules
+ to get a listing of stored objects with their docstring descriptions.
+
+ Note: in previous versions of the schedulers the argument `schedule_beta` was used to state the beta schedule
+ type; this is now replaced with `schedule`, and most names used with the previous argument now have "_beta" appended
+ to them, e.g. 'schedule_beta="linear"' -> 'schedule="linear_beta"'. The `beta_start` and `beta_end` arguments are
+ still used for some schedules but these are now provided as keyword arguments.
+
+ Args:
+ num_train_timesteps: number of diffusion steps used to train the model.
+ schedule: member of NoiseSchedules,
+ a named function returning the beta tensor or (betas, alphas, alphas_cumprod) triple
+ schedule_args: arguments to pass to the schedule function
+ """
+
+ def __init__(self, num_train_timesteps: int = 1000, schedule: str = "linear_beta", **schedule_args) -> None:
+ super().__init__()
+ schedule_args["num_train_timesteps"] = num_train_timesteps
+ noise_sched = NoiseSchedules[schedule](**schedule_args)
+
+ # set betas, alphas, alphas_cumprod based off return value from noise function
+ if isinstance(noise_sched, tuple):
+ self.betas, self.alphas, self.alphas_cumprod = noise_sched
+ else:
+ self.betas = noise_sched
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ self.num_train_timesteps = num_train_timesteps
+ self.one = torch.tensor(1.0)
+
+ # settable values
+ self.num_inference_steps: int | None = None
+ self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1)
+
+ def add_noise(self, original_samples: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
+ """
+ Add noise to the original samples.
+
+ Args:
+ original_samples: original samples
+ noise: noise to add to samples
+ timesteps: timesteps tensor indicating the timestep to be computed for each sample.
+
+ Returns:
+ noisy_samples: sample with added noise
+ """
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
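+ # forward diffusion: x_t = sqrt(ᾱ_t) * x_0 + sqrt(1 - ᾱ_t) * ε, see formula (4) of https://arxiv.org/pdf/2006.11239.pdf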
+ sqrt_alpha_cumprod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, original_samples.ndim)
+ sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right(
+ (1 - self.alphas_cumprod[timesteps]) ** 0.5, original_samples.ndim
+ )
+
+ noisy_samples = sqrt_alpha_cumprod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
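+ """
+ Compute the velocity target for v-prediction training:
+ v_t = sqrt(ᾱ_t) * noise - sqrt(1 - ᾱ_t) * sample
+ (see section 2.4 of https://imagen.research.google/video/paper.pdf).
+
+ Args:
+ sample: original (clean) samples.
+ noise: noise added to the samples.
+ timesteps: timesteps tensor indicating the timestep to be computed for each sample.
+
+ Returns:
+ velocity: velocity tensor with the same shape as the sample.
+ """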
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
+
+ sqrt_alpha_prod: torch.Tensor = unsqueeze_right(self.alphas_cumprod[timesteps] ** 0.5, sample.ndim)
+ sqrt_one_minus_alpha_prod: torch.Tensor = unsqueeze_right(
+ (1 - self.alphas_cumprod[timesteps]) ** 0.5, sample.ndim
+ )
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
diff --git a/monai/networks/trt_compiler.py b/monai/networks/trt_compiler.py
new file mode 100644
index 0000000000..d2d05fae22
--- /dev/null
+++ b/monai/networks/trt_compiler.py
@@ -0,0 +1,571 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import inspect
+import os
+import tempfile
+import threading
+from collections import OrderedDict
+from pathlib import Path
+from types import MethodType
+from typing import Any, Dict, List, Union
+
+import torch
+
+from monai.apps.utils import get_logger
+from monai.networks.utils import add_casts_around_norms, convert_to_onnx, convert_to_torchscript, get_profile_shapes
+from monai.utils.module import optional_import
+
+polygraphy, polygraphy_imported = optional_import("polygraphy")
+if polygraphy_imported:
+ from polygraphy.backend.common import bytes_from_path
+ from polygraphy.backend.trt import (
+ CreateConfig,
+ Profile,
+ engine_bytes_from_network,
+ engine_from_bytes,
+ network_from_onnx_path,
+ )
+
+trt, trt_imported = optional_import("tensorrt")
+torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")
+cudart, _ = optional_import("cuda.cudart")
+
+
+lock_sm = threading.Lock()
+
+
+# Map of TRT dtype -> Torch dtype
+def trt_to_torch_dtype_dict():
+ return {
+ trt.int32: torch.int32,
+ trt.float32: torch.float32,
+ trt.float16: torch.float16,
+ trt.bfloat16: torch.float16,
+ trt.int64: torch.int64,
+ trt.int8: torch.int8,
+ trt.bool: torch.bool,
+ }
+
+
+def get_dynamic_axes(profiles):
+ """
+ This method calculates dynamic_axes to use in torch.onnx.export().
+ Args:
+ profiles: list of profiles; each profile is a map of the form {"input id": [min_shape, opt_shape, max_shape], ...}
+ """
+ dynamic_axes: dict[str, list[int]] = {}
+ if not profiles:
+ return dynamic_axes
+ for profile in profiles:
+ for key in profile:
+ axes = []
+ vals = profile[key]
+ for i in range(len(vals[0])):
+ if vals[0][i] != vals[2][i]:
+ axes.append(i)
+ if len(axes) > 0:
+ dynamic_axes[key] = axes
+ return dynamic_axes
+
+
+def cuassert(cuda_ret):
+ """
+ Error reporting method for CUDA calls.
+ Args:
+ cuda_ret: CUDA return code.
+ """
+ err = cuda_ret[0]
+ if err != 0:
+ raise RuntimeError(f"CUDA ERROR: {err}")
+ if len(cuda_ret) > 1:
+ return cuda_ret[1]
+ return None
+
+
+class ShapeError(Exception):
+ """
+ Exception class to report errors from setting TRT plan input shapes
+ """
+
+ pass
+
+
+class TRTEngine:
+ """
+ An auxiliary class to implement running of TRT optimized engines
+
+ """
+
+ def __init__(self, plan_path, logger=None):
+ """
+ Loads serialized engine, creates execution context and activates it
+ Args:
+ plan_path: path to serialized TRT engine.
+ logger: optional logger object
+ """
+ self.plan_path = plan_path
+ self.logger = logger or get_logger("monai.networks.trt_compiler")
+ self.logger.info(f"Loading TensorRT engine: {self.plan_path}")
+ self.engine = engine_from_bytes(bytes_from_path(self.plan_path))
+ self.tensors = OrderedDict()
+ self.cuda_graph_instance = None # cuda graph
+ self.context = self.engine.create_execution_context()
+ self.input_names = []
+ self.output_names = []
+ self.dtypes = []
+ self.cur_profile = 0
+ dtype_dict = trt_to_torch_dtype_dict()
+ for idx in range(self.engine.num_io_tensors):
+ binding = self.engine[idx]
+ if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
+ self.input_names.append(binding)
+ elif self.engine.get_tensor_mode(binding) == trt.TensorIOMode.OUTPUT:
+ self.output_names.append(binding)
+ dtype = dtype_dict[self.engine.get_tensor_dtype(binding)]
+ self.dtypes.append(dtype)
+
+ def allocate_buffers(self, device):
+ """
+ Allocates outputs to run TRT engine
+ Args:
+ device: GPU device to allocate memory on
+ """
+ ctx = self.context
+
+ for i, binding in enumerate(self.output_names):
+ shape = list(ctx.get_tensor_shape(binding))
+ if binding not in self.tensors or list(self.tensors[binding].shape) != shape:
+ t = torch.empty(shape, dtype=self.dtypes[i], device=device).contiguous()
+ self.tensors[binding] = t
+ ctx.set_tensor_address(binding, t.data_ptr())
+
+ def set_inputs(self, feed_dict, stream):
+ """
+ Sets input bindings for TRT engine according to feed_dict
+ Args:
+ feed_dict: a dictionary [str->Tensor]
+ stream: CUDA stream to use
+ """
+ e = self.engine
+ ctx = self.context
+
+ last_profile = self.cur_profile
+
+ def try_set_inputs():
+ for binding, t in feed_dict.items():
+ if t is not None:
+ t = t.contiguous()
+ shape = t.shape
+ ctx.set_input_shape(binding, shape)
+ ctx.set_tensor_address(binding, t.data_ptr())
+
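+ # cycle through the engine's optimization profiles until one accepts the current input shapes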
+ while True:
+ try:
+ try_set_inputs()
+ break
+ except ShapeError:
+ next_profile = (self.cur_profile + 1) % e.num_optimization_profiles
+ if next_profile == last_profile:
+ raise
+ self.cur_profile = next_profile
+ ctx.set_optimization_profile_async(self.cur_profile, stream)
+
+ left = ctx.infer_shapes()
+ assert len(left) == 0
+
+ def infer(self, stream, use_cuda_graph=False):
+ """
+ Runs TRT engine.
+ Args:
+ stream: CUDA stream to run on
+ use_cuda_graph: use CUDA graph. Note: requires all inputs to stay at the same GPU memory addresses between calls.
+ """
+ if use_cuda_graph:
+ if self.cuda_graph_instance is not None:
+ cuassert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream))
+ cuassert(cudart.cudaStreamSynchronize(stream))
+ else:
+ # do inference before CUDA graph capture
+ noerror = self.context.execute_async_v3(stream)
+ if not noerror:
+ raise ValueError("ERROR: inference failed.")
+ # capture cuda graph
+ cuassert(
+ cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal)
+ )
+ self.context.execute_async_v3(stream)
+ graph = cuassert(cudart.cudaStreamEndCapture(stream))
+ self.cuda_graph_instance = cuassert(cudart.cudaGraphInstantiate(graph, 0))
+ self.logger.info("CUDA Graph captured!")
+ else:
+ noerror = self.context.execute_async_v3(stream)
+ cuassert(cudart.cudaStreamSynchronize(stream))
+ if not noerror:
+ raise ValueError("ERROR: inference failed.")
+
+ return self.tensors
+
+
+class TrtCompiler:
+ """
+ This class implements:
+ - TRT lazy persistent export
+ - Running TRT with optional fallback to Torch
+ (for TRT engines with limited profiles)
+ """
+
+ def __init__(
+ self,
+ model,
+ plan_path,
+ precision="fp16",
+ method="onnx",
+ input_names=None,
+ output_names=None,
+ export_args=None,
+ build_args=None,
+ input_profiles=None,
+ dynamic_batchsize=None,
+ use_cuda_graph=False,
+ timestamp=None,
+ fallback=False,
+ logger=None,
+ ):
+ """
+ Initialization method:
+ Tries to load persistent serialized TRT engine
+ Saves its arguments for lazy TRT build on first forward() call
+ Args:
+ model: Model to "wrap".
+ plan_path : Path where to save persistent serialized TRT engine.
+ precision: TRT builder precision of the engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'.
+ method: One of 'onnx'|'torch_trt'.
+ Default is 'onnx' (torch.onnx.export()->TRT). This is the most stable and efficient option.
+ 'torch_trt' may not work for some nets. Also AMP must be turned off for it to work.
+ input_names: Optional list of input names. If None, will be read from the function signature.
+ output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary.
+ export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details.
+ build_args: Optional args to pass to TRT builder. See polygraphy.Config for details.
+ input_profiles: Optional list of profiles for TRT builder and ONNX export.
+ Each profile is a map of the form : {"input id" : [min_shape, opt_shape, max_shape], ...}.
+ dynamic_batchsize: A sequence with three elements to define the batch size range of the input for the model to be
+ converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH].
+ [note]: If neither input_profiles nor dynamic_batchsize is specified, static shapes will be used to build the TRT engine.
+ use_cuda_graph: Use CUDA Graph for inference. Note: all inputs have to stay at the same GPU memory addresses between calls!
+ timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes).
+ fallback: Allow falling back to PyTorch when TRT inference fails (e.g., shapes exceed the max profile).
+ logger: Optional logger object to use for diagnostics.
+ """
+
+ method_vals = ["onnx", "torch_trt"]
+ if method not in method_vals:
+ raise ValueError(f"trt_compile(): 'method' should be one of {method_vals}, got: {method}.")
+ precision_vals = ["fp32", "tf32", "fp16", "bf16"]
+ if precision not in precision_vals:
+ raise ValueError(f"trt_compile(): 'precision' should be one of {precision_vals}, got: {precision}.")
+
+ self.plan_path = plan_path
+ self.precision = precision
+ self.method = method
+ self.return_dict = output_names is not None
+ self.output_names = output_names or []
+ self.profiles = input_profiles or []
+ self.dynamic_batchsize = dynamic_batchsize
+ self.export_args = export_args or {}
+ self.build_args = build_args or {}
+ self.engine: TRTEngine | None = None
+ self.use_cuda_graph = use_cuda_graph
+ self.fallback = fallback
+ self.disabled = False
+
+ self.logger = logger or get_logger("monai.networks.trt_compiler")
+
+ # Normally we read input_names from forward(), but they can be overridden
+ if input_names is None:
+ argspec = inspect.getfullargspec(model.forward)
+ input_names = argspec.args[1:]
+ self.input_names = input_names
+ self.old_forward = model.forward
+
+ # Force engine rebuild if older than the timestamp
+ if timestamp is not None and os.path.exists(self.plan_path) and os.path.getmtime(self.plan_path) < timestamp:
+ os.remove(self.plan_path)
+
+ def _inputs_to_dict(self, input_example):
+ trt_inputs = {}
+ for i, inp in enumerate(input_example):
+ input_name = self.input_names[i]
+ trt_inputs[input_name] = inp
+ return trt_inputs
+
+ def _load_engine(self):
+ """
+ Loads TRT plan from disk and activates its execution context.
+ """
+ try:
+ self.engine = TRTEngine(self.plan_path, self.logger)
+ self.input_names = self.engine.input_names
+ except Exception as e:
+ self.logger.debug(f"Exception while loading the engine:\n{e}")
+
+ def forward(self, model, argv, kwargs):
+ """
+ Main forward method:
+ Builds TRT engine if not available yet.
+ Tries to run TRT engine
+ If an exception is thrown and self.fallback==True: falls back to the original PyTorch forward()
+
+ Args: passes through whatever args the wrapped module's forward() has.
+ Returns: passes through the wrapped module's forward() return value(s).
+
+ """
+ if self.engine is None and not self.disabled:
+ # Restore original forward for export
+ new_forward = model.forward
+ model.forward = self.old_forward
+ try:
+ self._load_engine()
+ if self.engine is None:
+ build_args = kwargs.copy()
+ if len(argv) > 0:
+ build_args.update(self._inputs_to_dict(argv))
+ self._build_and_save(model, build_args)
+ # This will reassign input_names from the engine
+ self._load_engine()
+ assert self.engine is not None
+ except Exception as e:
+ if self.fallback:
+ self.logger.info(f"Failed to build engine: {e}")
+ self.disabled = True
+ else:
+ raise e
+ if not self.disabled and not self.fallback:
+ # Delete all parameters
+ for param in model.parameters():
+ del param
+ # Call empty_cache to release GPU memory
+ torch.cuda.empty_cache()
+ model.forward = new_forward
+ # Run the engine
+ try:
+ if len(argv) > 0:
+ kwargs.update(self._inputs_to_dict(argv))
+ argv = ()
+
+ if self.engine is not None:
+ # forward_trt is not thread safe as we do not use per-thread execution contexts
+ with lock_sm:
+ device = torch.cuda.current_device()
+ stream = torch.cuda.Stream(device=device)
+ self.engine.set_inputs(kwargs, stream.cuda_stream)
+ self.engine.allocate_buffers(device=device)
+ # Need this to synchronize with Torch stream
+ stream.wait_stream(torch.cuda.current_stream())
+ ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph)
+ # if output_names is not None, return dictionary
+ if not self.return_dict:
+ ret = list(ret.values())
+ if len(ret) == 1:
+ ret = ret[0]
+ return ret
+ except Exception as e:
+ if model is not None:
+ self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...")
+ else:
+ raise e
+ return self.old_forward(*argv, **kwargs)
+
+ def _onnx_to_trt(self, onnx_path):
+ """
+ Builds TRT engine from ONNX file at onnx_path and saves to self.plan_path
+ """
+
+ profiles = []
+ if self.profiles:
+ for input_profile in self.profiles:
+ if isinstance(input_profile, Profile):
+ profiles.append(input_profile)
+ else:
+ p = Profile()
+ for name, dims in input_profile.items():
+ assert len(dims) == 3
+ p.add(name, min=dims[0], opt=dims[1], max=dims[2])
+ profiles.append(p)
+
+ build_args = self.build_args.copy()
+ build_args["tf32"] = self.precision != "fp32"
+ if self.precision == "fp16":
+ build_args["fp16"] = True
+ elif self.precision == "bf16":
+ build_args["bf16"] = True
+
+ self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}")
+ network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
+ return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args))
+
+ def _build_and_save(self, model, input_example):
+ """
+ If TRT engine is not ready, exports model to ONNX,
+ builds TRT engine and saves serialized TRT engine to the disk.
+ Args:
+ model: model to export and build the engine from.
+ input_example: example inputs passed to onnx.export()
+ """
+
+ if self.engine is not None:
+ return
+
+ export_args = self.export_args
+
+ add_casts_around_norms(model)
+
+ if self.method == "torch_trt":
+ enabled_precisions = [torch.float32]
+ if self.precision == "fp16":
+ enabled_precisions.append(torch.float16)
+ elif self.precision == "bf16":
+ enabled_precisions.append(torch.bfloat16)
+ inputs = list(input_example.values())
+ ir_model = convert_to_torchscript(model, inputs=inputs, use_trace=True)
+
+ def get_torch_trt_input(input_shape, dynamic_batchsize):
+ min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize)
+ return torch_tensorrt.Input(
+ min_shape=min_input_shape, opt_shape=opt_input_shape, max_shape=max_input_shape
+ )
+
+ tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs]
+ engine_bytes = torch_tensorrt.convert_method_to_trt_engine(
+ ir_model,
+ "forward",
+ inputs=tt_inputs,
+ ir="torchscript",
+ enabled_precisions=enabled_precisions,
+ **export_args,
+ )
+ else:
+ dbs = self.dynamic_batchsize
+ if dbs:
+ if len(self.profiles) > 0:
+ raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!")
+ if len(dbs) != 3:
+ raise ValueError("dynamic_batchsize has to have len ==3 ")
+ profiles = {}
+ for id, val in input_example.items():
+ sh = val.shape[1:]
+ profiles[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]
+ self.profiles = [profiles]
+
+ if len(self.profiles) > 0:
+ export_args.update({"dynamic_axes": get_dynamic_axes(self.profiles)})
+
+ # Use temporary directory for easy cleanup in case of external weights
+ with tempfile.TemporaryDirectory() as tmpdir:
+ onnx_path = Path(tmpdir) / "model.onnx"
+ self.logger.info(
+ f"Exporting to {onnx_path}:\n\toutput_names={self.output_names}\n\texport args: {export_args}"
+ )
+ convert_to_onnx(
+ model,
+ input_example,
+ filename=str(onnx_path),
+ input_names=self.input_names,
+ output_names=self.output_names,
+ **export_args,
+ )
+ self.logger.info("Export to ONNX successful.")
+ engine_bytes = self._onnx_to_trt(str(onnx_path))
+
+ open(self.plan_path, "wb").write(engine_bytes)
+
+
+def trt_forward(self, *argv, **kwargs):
+ """
+ Patch function to replace original model's forward() with.
+ Redirects to TrtCompiler.forward()
+ """
+ return self._trt_compiler.forward(self, argv, kwargs)
+
+
+def trt_compile(
+ model: torch.nn.Module,
+ base_path: str,
+ args: Dict[str, Any] | None = None,
+ submodule: Union[str, List[str]] | None = None,
+ logger: Any | None = None,
+) -> torch.nn.Module:
+ """
+ Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook.
+ Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x.
+ NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5.
+ Review the TensorRT Support Matrix for which GPUs are supported.
+ Args:
+ model: module to patch with TrtCompiler object.
+ base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path.
+ dirname(base_path) must exist; base_path itself does not have to.
+ If base_path points to an existing file (e.g. an associated checkpoint),
+ that file becomes a dependency - its mtime is added to args["timestamp"].
+ args: Optional dict : unpacked and passed to TrtCompiler() - see TrtCompiler above for details.
+ submodule: Optional hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder']
+ If None, TrtCompiler patch is applied to the whole model.
+ Otherwise, submodule (or list of) is being patched.
+ logger: Optional logger for diagnostics.
+ Returns:
+ Always returns same model passed in as argument. This is for ease of use in configs.
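+
+ Example (an illustrative sketch; the network configuration and plan path below are assumed):
+
+ .. code-block:: python
+
+ from monai.networks.nets import UNet
+ from monai.networks.trt_compiler import trt_compile
+
+ model = UNet(spatial_dims=2, in_channels=1, out_channels=2, channels=(8, 16), strides=(2,)).cuda()
+ model = trt_compile(model, "/tmp/unet_trt", args={"precision": "fp16"})
+ # the TRT engine is built lazily on the first forward() call and cached as "/tmp/unet_trt.plan"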
+ """
+
+ default_args: Dict[str, Any] = {
+ "method": "onnx",
+ "precision": "fp16",
+ "build_args": {"builder_optimization_level": 5, "precision_constraints": "obey"},
+ }
+
+ default_args.update(args or {})
+ args = default_args
+
+ if trt_imported and polygraphy_imported and torch.cuda.is_available():
+ # if "path" filename point to existing file (e.g. checkpoint)
+ # it's also treated as dependency
+ if os.path.exists(base_path):
+ timestamp = int(os.path.getmtime(base_path))
+ if "timestamp" in args:
+ timestamp = max(int(args["timestamp"]), timestamp)
+ args["timestamp"] = timestamp
+
+ def wrap(model, path):
+ wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args)
+ model._trt_compiler = wrapper
+ model.forward = MethodType(trt_forward, model)
+
+ def find_sub(parent, submodule):
+ idx = submodule.find(".")
+ # if there is "." in name, call recursively
+ if idx != -1:
+ parent_name = submodule[:idx]
+ parent = getattr(parent, parent_name)
+ submodule = submodule[idx + 1 :]
+ return find_sub(parent, submodule)
+ return parent, submodule
+
+ if submodule is not None:
+ if isinstance(submodule, str):
+ submodule = [submodule]
+ for s in submodule:
+ parent, sub = find_sub(model, s)
+ wrap(getattr(parent, sub), base_path + "." + s)
+ else:
+ wrap(model, base_path)
+ else:
+ logger = logger or get_logger("monai.networks.trt_compiler")
+ logger.warning("TensorRT and/or polygraphy packages are not available! trt_compile() has no effect.")
+
+ return model
diff --git a/monai/networks/utils.py b/monai/networks/utils.py
index 4e6699f16b..cfad0364c3 100644
--- a/monai/networks/utils.py
+++ b/monai/networks/utils.py
@@ -16,6 +16,7 @@
import io
import re
+import tempfile
import warnings
from collections import OrderedDict
from collections.abc import Callable, Mapping, Sequence
@@ -36,12 +37,15 @@
onnx, _ = optional_import("onnx")
onnxreference, _ = optional_import("onnx.reference")
onnxruntime, _ = optional_import("onnxruntime")
+polygraphy, polygraphy_imported = optional_import("polygraphy")
+torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")
__all__ = [
"one_hot",
"predict_segmentation",
"normalize_transform",
"to_norm_affine",
+ "CastTempType",
"normal_init",
"icnr_init",
"pixelshuffle",
@@ -60,6 +64,7 @@
"look_up_named_module",
"set_named_module",
"has_nvfuser_instance_norm",
+ "get_profile_shapes",
]
logger = get_logger(module_name=__name__)
@@ -67,6 +72,26 @@
_has_nvfuser = None
+def get_profile_shapes(input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None):
+ """
+ Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize.
+ """
+
+ def scale_batch_size(input_shape: Sequence[int], scale_num: int):
+ scale_shape = [*input_shape]
+ scale_shape[0] = scale_num
+ return scale_shape
+
+ # Use the dynamic batchsize range to generate the min, opt and max model input shape
+ if dynamic_batchsize:
+ min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0])
+ opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1])
+ max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2])
+ else:
+ min_input_shape = opt_input_shape = max_input_shape = input_shape
+ return min_input_shape, opt_input_shape, max_input_shape
+
+
def has_nvfuser_instance_norm():
"""whether the current environment has InstanceNorm3dNVFuser
https://github.com/NVIDIA/apex/blob/23.05-devel/apex/normalization/instance_norm.py#L15-L16
@@ -605,6 +630,9 @@ def convert_to_onnx(
rtol: float = 1e-4,
atol: float = 0.0,
use_trace: bool = True,
+ do_constant_folding: bool = True,
+ constant_size_threshold: int = 16 * 1024 * 1024 * 1024,
+ dynamo=False,
**kwargs,
):
"""
@@ -631,7 +659,10 @@ def convert_to_onnx(
rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model.
atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.
use_trace: whether to use `torch.jit.trace` to export the torchscript model.
- kwargs: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
+ do_constant_folding: passed to torch.onnx.export(). If True, an extra polygraphy constant folding pass is also applied.
+ constant_size_threshold: size threshold passed to the polygraphy constant folding pass, default 16 GiB.
+ kwargs: if use_trace=True, additional arguments to pass to torch.onnx.export();
+ otherwise, other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
https://pytorch.org/docs/master/generated/torch.jit.script.html.
"""
@@ -641,6 +672,7 @@ def convert_to_onnx(
if use_trace:
# let torch.onnx.export to trace the model.
mode_to_export = model
+ torch_versioned_kwargs = kwargs
else:
if not pytorch_after(1, 10):
if "example_outputs" not in kwargs:
@@ -653,31 +685,34 @@ def convert_to_onnx(
del kwargs["example_outputs"]
mode_to_export = torch.jit.script(model, **kwargs)
+ if torch.is_tensor(inputs) or isinstance(inputs, dict):
+ onnx_inputs = (inputs,)
+ else:
+ onnx_inputs = tuple(inputs)
+ temp_file = None
if filename is None:
- f = io.BytesIO()
- torch.onnx.export(
- mode_to_export,
- tuple(inputs),
- f=f,
- input_names=input_names,
- output_names=output_names,
- dynamic_axes=dynamic_axes,
- opset_version=opset_version,
- **torch_versioned_kwargs,
- )
- onnx_model = onnx.load_model_from_string(f.getvalue())
+ temp_file = tempfile.NamedTemporaryFile()
+ f = temp_file.name
else:
- torch.onnx.export(
- mode_to_export,
- tuple(inputs),
- f=filename,
- input_names=input_names,
- output_names=output_names,
- dynamic_axes=dynamic_axes,
- opset_version=opset_version,
- **torch_versioned_kwargs,
- )
- onnx_model = onnx.load(filename)
+ f = filename
+
+ torch.onnx.export(
+ mode_to_export,
+ onnx_inputs,
+ f=f,
+ input_names=input_names,
+ output_names=output_names,
+ dynamic_axes=dynamic_axes,
+ opset_version=opset_version,
+ do_constant_folding=do_constant_folding,
+ **torch_versioned_kwargs,
+ )
+ onnx_model = onnx.load(f)
+
+ if do_constant_folding and polygraphy_imported:
+ from polygraphy.backend.onnx.loader import fold_constants
+
+ fold_constants(onnx_model, size_threshold=constant_size_threshold)
if verify:
if device is None:
@@ -813,7 +848,6 @@ def _onnx_trt_compile(
"""
trt, _ = optional_import("tensorrt", "8.5.3")
- torch_tensorrt, _ = optional_import("torch_tensorrt", "1.4.0")
input_shapes = (min_shape, opt_shape, max_shape)
# default to an empty list to fit the `torch_tensorrt.ts.embed_engine_in_new_module` function.
@@ -821,7 +855,7 @@ def _onnx_trt_compile(
output_names = [] if not output_names else output_names
# set up the TensorRT builder
- torch_tensorrt.set_device(device)
+ torch.cuda.set_device(device)
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
@@ -840,7 +874,6 @@ def _onnx_trt_compile(
# set up the conversion configuration
config = builder.create_builder_config()
- config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 31)
config.add_optimization_profile(profile)
if precision == "fp16":
config.set_flag(trt.BuilderFlag.FP16)
@@ -851,7 +884,7 @@ def _onnx_trt_compile(
# wrap the serialized TensorRT engine back to a TorchScript module.
trt_model = torch_tensorrt.ts.embed_engine_in_new_module(
f.getvalue(),
- device=torch.device(f"cuda:{device}"),
+ device=torch_tensorrt.Device(f"cuda:{device}"),
input_binding_names=input_names,
output_binding_names=output_names,
)
@@ -916,8 +949,6 @@ def convert_to_trt(
to compile model, for more details: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html#torch-tensorrt-py.
"""
- torch_tensorrt, _ = optional_import("torch_tensorrt", version="1.4.0")
-
if not torch.cuda.is_available():
raise Exception("Cannot find any GPU devices.")
@@ -931,27 +962,13 @@ def convert_to_trt(
warnings.warn(f"The dynamic batch range sequence should have 3 elements, but got {dynamic_batchsize} elements.")
device = device if device else 0
- target_device = torch.device(f"cuda:{device}") if device else torch.device("cuda:0")
+ target_device = torch.device(f"cuda:{device}")
convert_precision = torch.float32 if precision == "fp32" else torch.half
inputs = [torch.rand(ensure_tuple(input_shape)).to(target_device)]
- def scale_batch_size(input_shape: Sequence[int], scale_num: int):
- scale_shape = [*input_shape]
- scale_shape[0] *= scale_num
- return scale_shape
-
- # Use the dynamic batchsize range to generate the min, opt and max model input shape
- if dynamic_batchsize:
- min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0])
- opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1])
- max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2])
- else:
- min_input_shape = opt_input_shape = max_input_shape = input_shape
-
# convert the torch model to a TorchScript model on target device
model = model.eval().to(target_device)
- ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace)
- ir_model.eval()
+ min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize)
if use_onnx:
# set the batch dim as dynamic
@@ -960,7 +977,6 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int):
ir_model = convert_to_onnx(
model, inputs, onnx_input_names, onnx_output_names, use_trace=use_trace, dynamic_axes=dynamic_axes
)
-
# convert the model through the ONNX-TensorRT way
trt_model = _onnx_trt_compile(
ir_model,
@@ -973,6 +989,8 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int):
output_names=onnx_output_names,
)
else:
+ ir_model = convert_to_torchscript(model, device=target_device, inputs=inputs, use_trace=use_trace)
+ ir_model.eval()
# convert the model through the Torch-TensorRT way
ir_model.to(target_device)
with torch.no_grad():
@@ -986,7 +1004,8 @@ def scale_batch_size(input_shape: Sequence[int], scale_num: int):
ir_model,
inputs=input_placeholder,
enabled_precisions=convert_precision,
- device=target_device,
+ device=torch_tensorrt.Device(f"cuda:{device}"),
+ ir="torchscript",
**kwargs,
)
@@ -1167,3 +1186,189 @@ def freeze_layers(model: nn.Module, freeze_vars=None, exclude_vars=None):
warnings.warn(f"The exclude_vars includes {param}, but requires_grad is False, change it to True.")
logger.info(f"{len(frozen_keys)} of {len(src_dict)} variables frozen.")
+
+
+class CastTempType(nn.Module):
+ """
+ Cast the input tensor to a temporary type before applying the submodule, and then cast it back to the initial type.
+ """
+
+ def __init__(self, initial_type, temporary_type, submodule):
+ super().__init__()
+ self.initial_type = initial_type
+ self.temporary_type = temporary_type
+ self.submodule = submodule
+
+ def forward(self, x):
+ dtype = x.dtype
+ if dtype == self.initial_type:
+ x = x.to(self.temporary_type)
+ x = self.submodule(x)
+ if dtype == self.initial_type:
+ x = x.to(self.initial_type)
+ return x
+
+
+def cast_tensor(x, from_dtype=torch.float16, to_dtype=torch.float32):
+ """
+ Utility function to cast a single tensor from from_dtype to to_dtype
+ """
+ return x.to(dtype=to_dtype) if x.dtype == from_dtype else x
+
+
+def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
+ """
+ Utility function to cast all tensors in a tuple from from_dtype to to_dtype
+ """
+ if isinstance(x, torch.Tensor):
+ return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)
+ else:
+ if isinstance(x, dict):
+ new_dict = {}
+ for k in x.keys():
+ new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)
+ return new_dict
+ elif isinstance(x, tuple):
+ return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
+
+
+class CastToFloat(torch.nn.Module):
+ """
+ Class used to add autocast protection for ONNX export
+ for forward methods with a single return value
+ """
+
+ def __init__(self, mod):
+ super().__init__()
+ self.mod = mod
+
+ def forward(self, x):
+ dtype = x.dtype
+ with torch.amp.autocast("cuda", enabled=False):
+ ret = self.mod.forward(x.to(torch.float32)).to(dtype)
+ return ret
+
+
+class CastToFloatAll(torch.nn.Module):
+ """
+ Class used to add autocast protection for ONNX export
+ for forward methods with multiple return values
+ """
+
+ def __init__(self, mod):
+ super().__init__()
+ self.mod = mod
+
+ def forward(self, *args):
+ from_dtype = args[0].dtype
+ with torch.amp.autocast("cuda", enabled=False):
+ ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
+ return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
+
+
+def wrap_module(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]:
+ """
+ Generic function generator to replace base_t module with dest_t wrapper.
+ Args:
+ base_t : module type to replace
+ dest_t : destination module type
+ Returns:
+ swap function to replace base_t module with dest_t
+ """
+
+ def expansion_fn(mod: nn.Module) -> nn.Module | None:
+ out = dest_t(mod)
+ return out
+
+ return expansion_fn
+
+
+def simple_replace(base_t: type[nn.Module], dest_t: type[nn.Module]) -> Callable[[nn.Module], nn.Module | None]:
+ """
+ Generic function generator to replace base_t module with dest_t.
+ base_t and dest_t should have the same attributes. No weights are copied.
+ Args:
+ base_t : module type to replace
+ dest_t : destination module type
+ Returns:
+ swap function to replace base_t module with dest_t
+ """
+
+ def expansion_fn(mod: nn.Module) -> nn.Module | None:
+ if not isinstance(mod, base_t):
+ return None
+ args = [getattr(mod, name, None) for name in mod.__constants__]
+ out = dest_t(*args)
+ return out
+
+ return expansion_fn
+
+
+def _swap_modules(model: nn.Module, mapping: dict[str, nn.Module]) -> nn.Module:
+ """
+ This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows
+ for swapping nested modules through arbitrary levels of children.
+
+ NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.
+
+ """
+ for path, new_mod in mapping.items():
+ expanded_path = path.split(".")
+ parent_mod = model
+ for sub_path in expanded_path[:-1]:
+ submod = parent_mod._modules[sub_path]
+ if submod is None:
+ break
+ else:
+ parent_mod = submod
+ parent_mod._modules[expanded_path[-1]] = new_mod
+
+ return model
+
+
+def replace_modules_by_type(
+ model: nn.Module, expansions: dict[str, Callable[[nn.Module], nn.Module | None]]
+) -> nn.Module:
+ """
+ Top-level function to replace modules in model, specified by class name with a desired replacement.
+ NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.
+ Args:
+ model : top level module
+ expansions : replacement dictionary: module class name -> replacement function generator
+ Returns:
+ model, possibly modified in-place
+ """
+ mapping: dict[str, nn.Module] = {}
+ for name, m in model.named_modules():
+ m_type = type(m).__name__
+ if m_type in expansions:
+ # print (f"Found {m_type} in expansions ...")
+ swapped = expansions[m_type](m)
+ if swapped:
+ mapping[name] = swapped
+
+ print(f"Swapped {len(mapping)} modules")
+ _swap_modules(model, mapping)
+ return model
+
+
+def add_casts_around_norms(model: nn.Module) -> nn.Module:
+ """
+ Top-level function to add cast wrappers around modules known to cause issues for FP16/autocast ONNX export
+ NOTE: This occurs in place; if you want to preserve the model, make sure to copy it first.
+ Args:
+ model : top level module
+ Returns:
+ model, possibly modified in-place
+ """
+ print("Adding casts around norms...")
+ cast_replacements = {
+ "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
+ "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
+ "BatchNorm3d": wrap_module(nn.BatchNorm2d, CastToFloat),
+ "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
+ "InstanceNorm1d": wrap_module(nn.InstanceNorm1d, CastToFloat),
+ "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat),
+ }
+ replace_modules_by_type(model, cast_replacements)
+ return model
diff --git a/monai/optimizers/lr_finder.py b/monai/optimizers/lr_finder.py
index 045135628d..aa2e4567b3 100644
--- a/monai/optimizers/lr_finder.py
+++ b/monai/optimizers/lr_finder.py
@@ -524,7 +524,7 @@ def plot(
# Plot the LR with steepest gradient
if steepest_lr:
lr_at_steepest_grad, loss_at_steepest_grad = self.get_steepest_gradient(skip_start, skip_end)
- if lr_at_steepest_grad is not None:
+ if lr_at_steepest_grad is not None and loss_at_steepest_grad is not None:
ax.scatter(
lr_at_steepest_grad,
loss_at_steepest_grad,
diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index ef76862617..5065366ecf 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -92,6 +92,7 @@
from .croppad.functional import crop_func, crop_or_pad_nd, pad_func, pad_nd
from .intensity.array import (
AdjustContrast,
+ ClipIntensityPercentiles,
ComputeHoVerMaps,
DetectEnvelope,
ForegroundMask,
@@ -135,6 +136,9 @@
AdjustContrastd,
AdjustContrastD,
AdjustContrastDict,
+ ClipIntensityPercentilesd,
+ ClipIntensityPercentilesD,
+ ClipIntensityPercentilesDict,
ComputeHoVerMapsd,
ComputeHoVerMapsD,
ComputeHoVerMapsDict,
@@ -234,8 +238,18 @@
)
from .inverse import InvertibleTransform, TraceableTransform
from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
-from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
-from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
+from .io.array import SUPPORTED_READERS, LoadImage, SaveImage, WriteFileMapping
+from .io.dictionary import (
+ LoadImaged,
+ LoadImageD,
+ LoadImageDict,
+ SaveImaged,
+ SaveImageD,
+ SaveImageDict,
+ WriteFileMappingd,
+ WriteFileMappingD,
+ WriteFileMappingDict,
+)
from .lazy.array import ApplyPending
from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict
from .lazy.functional import apply_pending
@@ -382,6 +396,8 @@
from .spatial.array import (
Affine,
AffineGrid,
+ ConvertBoxToPoints,
+ ConvertPointsToBoxes,
Flip,
GridDistortion,
GridPatch,
@@ -413,6 +429,12 @@
Affined,
AffineD,
AffineDict,
+ ConvertBoxToPointsd,
+ ConvertBoxToPointsD,
+ ConvertBoxToPointsDict,
+ ConvertPointsToBoxesd,
+ ConvertPointsToBoxesD,
+ ConvertPointsToBoxesDict,
Flipd,
FlipD,
FlipDict,
@@ -489,6 +511,7 @@
from .utility.array import (
AddCoordinateChannels,
AddExtremePointsChannel,
+ ApplyTransformToPoints,
AsChannelLast,
CastToType,
ClassesToIndices,
@@ -529,6 +552,9 @@
AddExtremePointsChanneld,
AddExtremePointsChannelD,
AddExtremePointsChannelDict,
+ ApplyTransformToPointsd,
+ ApplyTransformToPointsD,
+ ApplyTransformToPointsDict,
AsChannelLastd,
AsChannelLastD,
AsChannelLastDict,
@@ -671,6 +697,7 @@
in_bounds,
is_empty,
is_positive,
+ map_and_generate_sampling_centers,
map_binary_to_indices,
map_classes_to_indices,
map_spatial_axes,
@@ -687,6 +714,7 @@
weighted_patch_samples,
zero_margins,
)
+from .utils_morphological_ops import dilate, erode
from .utils_pytorch_numpy_unification import (
allclose,
any_np_pt,
diff --git a/monai/transforms/adaptors.py b/monai/transforms/adaptors.py
index f5f1a4fc18..5a0c24c7f6 100644
--- a/monai/transforms/adaptors.py
+++ b/monai/transforms/adaptors.py
@@ -125,12 +125,9 @@ def __call__(self, img, seg):
from typing import Callable
-from monai.utils import export as _monai_export
-
__all__ = ["adaptor", "apply_alias", "to_kwargs", "FunctionSignature"]
-@_monai_export("monai.transforms")
def adaptor(function, outputs, inputs=None):
def must_be_types_or_none(variable_name, variable, types):
@@ -215,7 +212,6 @@ def _inner(ditems):
return _inner
-@_monai_export("monai.transforms")
def apply_alias(fn, name_map):
def _inner(data):
@@ -236,7 +232,6 @@ def _inner(data):
return _inner
-@_monai_export("monai.transforms")
def to_kwargs(fn):
def _inner(data):
diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py
index ce3701b263..813f8c1d44 100644
--- a/monai/transforms/croppad/array.py
+++ b/monai/transforms/croppad/array.py
@@ -362,10 +362,10 @@ def __init__(self, lazy: bool = False):
@staticmethod
def compute_slices(
- roi_center: Sequence[int] | NdarrayOrTensor | None = None,
- roi_size: Sequence[int] | NdarrayOrTensor | None = None,
- roi_start: Sequence[int] | NdarrayOrTensor | None = None,
- roi_end: Sequence[int] | NdarrayOrTensor | None = None,
+ roi_center: Sequence[int] | int | NdarrayOrTensor | None = None,
+ roi_size: Sequence[int] | int | NdarrayOrTensor | None = None,
+ roi_start: Sequence[int] | int | NdarrayOrTensor | None = None,
+ roi_end: Sequence[int] | int | NdarrayOrTensor | None = None,
roi_slices: Sequence[slice] | None = None,
) -> tuple[slice]:
"""
@@ -459,10 +459,10 @@ class SpatialCrop(Crop):
def __init__(
self,
- roi_center: Sequence[int] | NdarrayOrTensor | None = None,
- roi_size: Sequence[int] | NdarrayOrTensor | None = None,
- roi_start: Sequence[int] | NdarrayOrTensor | None = None,
- roi_end: Sequence[int] | NdarrayOrTensor | None = None,
+ roi_center: Sequence[int] | int | NdarrayOrTensor | None = None,
+ roi_size: Sequence[int] | int | NdarrayOrTensor | None = None,
+ roi_start: Sequence[int] | int | NdarrayOrTensor | None = None,
+ roi_end: Sequence[int] | int | NdarrayOrTensor | None = None,
roi_slices: Sequence[slice] | None = None,
lazy: bool = False,
) -> None:
diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py
index be9441dc4a..cea11d9676 100644
--- a/monai/transforms/croppad/dictionary.py
+++ b/monai/transforms/croppad/dictionary.py
@@ -438,10 +438,10 @@ class SpatialCropd(Cropd):
def __init__(
self,
keys: KeysCollection,
- roi_center: Sequence[int] | None = None,
- roi_size: Sequence[int] | None = None,
- roi_start: Sequence[int] | None = None,
- roi_end: Sequence[int] | None = None,
+ roi_center: Sequence[int] | int | None = None,
+ roi_size: Sequence[int] | int | None = None,
+ roi_start: Sequence[int] | int | None = None,
+ roi_end: Sequence[int] | int | None = None,
roi_slices: Sequence[slice] | None = None,
allow_missing_keys: bool = False,
lazy: bool = False,
diff --git a/monai/transforms/croppad/functional.py b/monai/transforms/croppad/functional.py
index a8286fb90c..361ec48dcd 100644
--- a/monai/transforms/croppad/functional.py
+++ b/monai/transforms/croppad/functional.py
@@ -48,7 +48,7 @@ def _np_pad(img: NdarrayTensor, pad_width: list[tuple[int, int]], mode: str, **k
warnings.warn(f"Padding: moving img {img.shape} from cuda to cpu for dtype={img.dtype} mode={mode}.")
img_np = img.detach().cpu().numpy()
else:
- img_np = img
+ img_np = np.asarray(img)
mode = convert_pad_mode(dst=img_np, mode=mode).value
if mode == "constant" and "value" in kwargs:
kwargs["constant_values"] = kwargs.pop("value")
diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py
index 0085050ee3..20000c52c4 100644
--- a/monai/transforms/intensity/array.py
+++ b/monai/transforms/intensity/array.py
@@ -30,7 +30,7 @@
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.networks.layers import GaussianFilter, HilbertTransform, MedianFilter, SavitzkyGolayFilter
from monai.transforms.transform import RandomizableTransform, Transform
-from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array
+from monai.transforms.utils import Fourier, equalize_hist, is_positive, rescale_array, soft_clip
from monai.transforms.utils_pytorch_numpy_unification import clip, percentile, where
from monai.utils.enums import TransformBackends
from monai.utils.misc import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple
@@ -54,6 +54,7 @@
"NormalizeIntensity",
"ThresholdIntensity",
"ScaleIntensityRange",
+ "ClipIntensityPercentiles",
"AdjustContrast",
"RandAdjustContrast",
"ScaleIntensityRangePercentiles",
@@ -1007,6 +1008,151 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
return ret
+class ClipIntensityPercentiles(Transform):
+ """
+ Clip the intensity values of the input image based on its intensity distribution (percentiles).
+ If `sharpness_factor` is provided, the intensity values will be soft clipped according to
+ f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv))
+ From https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291
+
+ Soft clipping preserves the order of the values and maintains the gradient everywhere.
+ For example:
+
+ .. code-block:: python
+ :emphasize-lines: 11, 22
+
+ image = torch.Tensor(
+ [[[1, 2, 3, 4, 5],
+ [1, 2, 3, 4, 5],
+ [1, 2, 3, 4, 5],
+ [1, 2, 3, 4, 5],
+ [1, 2, 3, 4, 5],
+ [1, 2, 3, 4, 5]]])
+
+ # Hard clipping from lower and upper image intensity percentiles
+ hard_clipper = ClipIntensityPercentiles(30, 70)
+ print(hard_clipper(image))
+ metatensor([[[2., 2., 3., 4., 4.],
+ [2., 2., 3., 4., 4.],
+ [2., 2., 3., 4., 4.],
+ [2., 2., 3., 4., 4.],
+ [2., 2., 3., 4., 4.],
+ [2., 2., 3., 4., 4.]]])
+
+
+ # Soft clipping from lower and upper image intensity percentiles
+ soft_clipper = ClipIntensityPercentiles(30, 70, 10.)
+ print(soft_clipper(image))
+ metatensor([[[2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
+ [2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
+ [2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
+ [2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
+ [2.0000, 2.0693, 3.0000, 3.9307, 4.0000],
+ [2.0000, 2.0693, 3.0000, 3.9307, 4.0000]]])
+
+ See Also:
+
+ - :py:class:`monai.transforms.ScaleIntensityRangePercentiles`
+ """
+
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+ def __init__(
+ self,
+ lower: float | None,
+ upper: float | None,
+ sharpness_factor: float | None = None,
+ channel_wise: bool = False,
+ return_clipping_values: bool = False,
+ dtype: DtypeLike = np.float32,
+ ) -> None:
+ """
+ Args:
+ lower: lower intensity percentile. For hard clipping, None has the same effect as 0: the lowest input
+ values are not clipped. For soft clipping, None and 0 behave differently: None does not apply any
+ clipping to the low values, whereas 0 still transforms them according to the soft clipping equation.
+ See https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291 for more details.
+ upper: upper intensity percentile, analogous to `lower` but for the highest values. For hard clipping,
+ None has the same effect as 100. For soft clipping, None leaves the upper values untouched, whereas
+ 100 still transforms them according to the soft clipping equation.
+ sharpness_factor: if not None, the intensity values will be soft clipped according to
+ f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv)).
+ defaults to None.
+ channel_wise: if True, compute the intensity percentiles and clip every channel separately.
+ defaults to False.
+ return_clipping_values: whether to return the calculated percentiles in tensor meta information.
+ If soft clipping and requested percentile is None, return None as the corresponding clipping
+ values in meta information. Clipping values are stored in a list with each element corresponding
+ to a channel if channel_wise is set to True. defaults to False.
+ dtype: output data type, if None, same as input image. defaults to float32.
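+
+ For example, a minimal sketch of ``return_clipping_values`` (percentile values are illustrative):
+
+ .. code-block:: python
+
+ clipper = ClipIntensityPercentiles(lower=5, upper=95, return_clipping_values=True)
+ out = clipper(img)  # img: an image tensor/array, e.g. shape (C, H, W)
+ out.meta["clipping_values"]  # [(lower_value, upper_value)]; available when the output is a MetaTensor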
+ """
+ if lower is None and upper is None:
+ raise ValueError("lower or upper percentiles must be provided")
+ if lower is not None and (lower < 0.0 or lower > 100.0):
+ raise ValueError("Percentiles must be in the range [0, 100]")
+ if upper is not None and (upper < 0.0 or upper > 100.0):
+ raise ValueError("Percentiles must be in the range [0, 100]")
+ if upper is not None and lower is not None and upper < lower:
+ raise ValueError("upper must be greater than or equal to lower")
+ if sharpness_factor is not None and sharpness_factor <= 0:
+ raise ValueError("sharpness_factor must be greater than 0")
+
+ self.lower = lower
+ self.upper = upper
+ self.sharpness_factor = sharpness_factor
+ self.channel_wise = channel_wise
+ if return_clipping_values:
+ self.clipping_values: list[tuple[float | None, float | None]] = []
+ self.return_clipping_values = return_clipping_values
+ self.dtype = dtype
+
+ def _clip(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
+ if self.sharpness_factor is not None:
+ lower_percentile = percentile(img, self.lower) if self.lower is not None else None
+ upper_percentile = percentile(img, self.upper) if self.upper is not None else None
+ img = soft_clip(img, self.sharpness_factor, lower_percentile, upper_percentile, self.dtype)
+ else:
+ lower_percentile = percentile(img, self.lower) if self.lower is not None else percentile(img, 0)
+ upper_percentile = percentile(img, self.upper) if self.upper is not None else percentile(img, 100)
+ img = clip(img, lower_percentile, upper_percentile)
+
+ if self.return_clipping_values:
+ self.clipping_values.append(
+ (
+ (
+ lower_percentile
+ if lower_percentile is None
+ else lower_percentile.item() if hasattr(lower_percentile, "item") else lower_percentile
+ ),
+ (
+ upper_percentile
+ if upper_percentile is None
+ else upper_percentile.item() if hasattr(upper_percentile, "item") else upper_percentile
+ ),
+ )
+ )
+ img = convert_to_tensor(img, track_meta=False)
+ return img
+
+ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
+ """
+ Apply the transform to `img`.
+ """
+ img = convert_to_tensor(img, track_meta=get_track_meta())
+ img_t = convert_to_tensor(img, track_meta=False)
+ if self.channel_wise:
+ img_t = torch.stack([self._clip(img=d) for d in img_t]) # type: ignore
+ else:
+ img_t = self._clip(img=img_t)
+
+ img = convert_to_dst_type(img_t, dst=img)[0]
+ if self.return_clipping_values:
+ img.meta["clipping_values"] = self.clipping_values # type: ignore
+
+ return img
+
+
class AdjustContrast(Transform):
"""
Changes image intensity with gamma transform. Each pixel/voxel intensity is updated as::
@@ -1265,7 +1411,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
else:
img_t = self._normalize(img=img_t)
- return convert_to_dst_type(img_t, dst=img)[0]
+ return convert_to_dst_type(img_t, dst=img, dtype=self.dtype)[0]
class MaskIntensity(Transform):
@@ -2643,6 +2789,9 @@ class UltrasoundConfidenceMapTransform(Transform):
It generates a confidence map by setting source and sink points in the image and computing the probability
for random walks to reach the source for each pixel.
+ The official code is available at:
+ https://campar.in.tum.de/Main/AthanasiosKaramalisCode
+
Args:
alpha (float, optional): Alpha parameter. Defaults to 2.0.
beta (float, optional): Beta parameter. Defaults to 90.0.
@@ -2650,14 +2799,32 @@ class UltrasoundConfidenceMapTransform(Transform):
mode (str, optional): 'RF' or 'B' mode data. Defaults to 'B'.
sink_mode (str, optional): Sink mode. Defaults to 'all'. If 'mask' is selected, a mask must be provided
when calling the transform. Can be one of 'all', 'mid', 'min', 'mask'.
+ use_cg (bool, optional): Use Conjugate Gradient method for solving the linear system. Defaults to False.
+ cg_tol (float, optional): Tolerance for the Conjugate Gradient method. Defaults to 1e-6.
+ Will be used only if `use_cg` is True.
+ cg_maxiter (int, optional): Maximum number of iterations for the Conjugate Gradient method. Defaults to 200.
+ Will be used only if `use_cg` is True.
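+
+ A minimal usage sketch (the input is illustrative):
+
+ .. code-block:: python
+
+ transform = UltrasoundConfidenceMapTransform(sink_mode="all", use_cg=True, cg_tol=1.0e-6, cg_maxiter=200)
+ confidence_map = transform(us_image)  # us_image: a channel-first B-mode ultrasound frame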
"""
- def __init__(self, alpha: float = 2.0, beta: float = 90.0, gamma: float = 0.05, mode="B", sink_mode="all") -> None:
+ def __init__(
+ self,
+ alpha: float = 2.0,
+ beta: float = 90.0,
+ gamma: float = 0.05,
+ mode="B",
+ sink_mode="all",
+ use_cg=False,
+ cg_tol: float = 1.0e-6,
+ cg_maxiter: int = 200,
+ ):
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.mode = mode
self.sink_mode = sink_mode
+ self.use_cg = use_cg
+ self.cg_tol = cg_tol
+ self.cg_maxiter = cg_maxiter
if self.mode not in ["B", "RF"]:
raise ValueError(f"Unknown mode: {self.mode}. Supported modes are 'B' and 'RF'.")
@@ -2667,7 +2834,9 @@ def __init__(self, alpha: float = 2.0, beta: float = 90.0, gamma: float = 0.05,
f"Unknown sink mode: {self.sink_mode}. Supported modes are 'all', 'mid', 'min' and 'mask'."
)
- self._compute_conf_map = UltrasoundConfidenceMap(self.alpha, self.beta, self.gamma, self.mode, self.sink_mode)
+ self._compute_conf_map = UltrasoundConfidenceMap(
+ self.alpha, self.beta, self.gamma, self.mode, self.sink_mode, self.use_cg, self.cg_tol, self.cg_maxiter
+ )
def __call__(self, img: NdarrayOrTensor, mask: NdarrayOrTensor | None = None) -> NdarrayOrTensor:
"""Compute confidence map from an ultrasound image.
diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py
index 5b911904b0..f2b1a2fd40 100644
--- a/monai/transforms/intensity/dictionary.py
+++ b/monai/transforms/intensity/dictionary.py
@@ -17,7 +17,8 @@
from __future__ import annotations
-from typing import Callable, Hashable, Mapping, Sequence
+from collections.abc import Hashable, Mapping, Sequence
+from typing import Callable
import numpy as np
@@ -26,6 +27,7 @@
from monai.data.meta_obj import get_track_meta
from monai.transforms.intensity.array import (
AdjustContrast,
+ ClipIntensityPercentiles,
ComputeHoVerMaps,
ForegroundMask,
GaussianSharpen,
@@ -77,6 +79,7 @@
"NormalizeIntensityd",
"ThresholdIntensityd",
"ScaleIntensityRanged",
+ "ClipIntensityPercentilesd",
"AdjustContrastd",
"RandAdjustContrastd",
"ScaleIntensityRangePercentilesd",
@@ -122,6 +125,8 @@
"ThresholdIntensityDict",
"ScaleIntensityRangeD",
"ScaleIntensityRangeDict",
+ "ClipIntensityPercentilesD",
+ "ClipIntensityPercentilesDict",
"AdjustContrastD",
"AdjustContrastDict",
"RandAdjustContrastD",
@@ -886,6 +891,36 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d
+class ClipIntensityPercentilesd(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.transforms.ClipIntensityPercentiles`.
+ Clip the intensity values of the input image to a specific range based on its intensity distribution.
+ If `sharpness_factor` is provided, the intensity values will be soft clipped according to
+ f(x) = x + (1/sharpness_factor) * softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv))
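+
+ A minimal usage sketch (keys and percentile values are illustrative):
+
+ .. code-block:: python
+
+ clip = ClipIntensityPercentilesd(keys="image", lower=5, upper=95, sharpness_factor=10.0)
+ data = clip({"image": image})  # soft-clips data["image"] between its 5th and 95th percentiles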
+ """
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ lower: float | None,
+ upper: float | None,
+ sharpness_factor: float | None = None,
+ channel_wise: bool = False,
+ dtype: DtypeLike = np.float32,
+ allow_missing_keys: bool = False,
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.scaler = ClipIntensityPercentiles(
+ lower=lower, upper=upper, sharpness_factor=sharpness_factor, channel_wise=channel_wise, dtype=dtype
+ )
+
+ def __call__(self, data: dict) -> dict:
+ d = dict(data)
+ for key in self.key_iterator(d):
+ d[key] = self.scaler(d[key])
+ return d
+
+
class AdjustContrastd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.AdjustContrast`.
@@ -1929,6 +1964,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
NormalizeIntensityD = NormalizeIntensityDict = NormalizeIntensityd
ThresholdIntensityD = ThresholdIntensityDict = ThresholdIntensityd
ScaleIntensityRangeD = ScaleIntensityRangeDict = ScaleIntensityRanged
+ClipIntensityPercentilesD = ClipIntensityPercentilesDict = ClipIntensityPercentilesd
AdjustContrastD = AdjustContrastDict = AdjustContrastd
RandAdjustContrastD = RandAdjustContrastDict = RandAdjustContrastd
ScaleIntensityRangePercentilesD = ScaleIntensityRangePercentilesDict = ScaleIntensityRangePercentilesd
diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py
index 7222a26fc3..4e71870fc9 100644
--- a/monai/transforms/io/array.py
+++ b/monai/transforms/io/array.py
@@ -15,6 +15,7 @@
from __future__ import annotations
import inspect
+import json
import logging
import sys
import traceback
@@ -45,11 +46,19 @@
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import GridSamplePadMode
from monai.utils import ImageMetaKey as Key
-from monai.utils import OptionalImportError, convert_to_dst_type, ensure_tuple, look_up_option, optional_import
+from monai.utils import (
+ MetaKeys,
+ OptionalImportError,
+ convert_to_dst_type,
+ ensure_tuple,
+ look_up_option,
+ optional_import,
+)
nib, _ = optional_import("nibabel")
Image, _ = optional_import("PIL.Image")
nrrd, _ = optional_import("nrrd")
+FileLock, has_filelock = optional_import("filelock", name="FileLock")
__all__ = ["LoadImage", "SaveImage", "SUPPORTED_READERS"]
@@ -86,7 +95,7 @@ def switch_endianness(data, new="<"):
if new not in ("<", ">"):
raise NotImplementedError(f"Not implemented option new={new}.")
if current_ != new:
- data = data.byteswap().newbyteorder(new)
+ data = data.byteswap().view(data.dtype.newbyteorder(new))
elif isinstance(data, tuple):
data = tuple(switch_endianness(x, new) for x in data)
elif isinstance(data, list):
@@ -307,11 +316,11 @@ class SaveImage(Transform):
Args:
output_dir: output image directory.
- Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.
+ Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.
output_postfix: a string appended to all output file names, default to `trans`.
- Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.
+ Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.
output_ext: output file extension name.
- Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.
+ Handled by ``folder_layout`` instead, if ``folder_layout`` is not ``None``.
output_dtype: data type (if not None) for saving data. Defaults to ``np.float32``.
resample: whether to resample image (if needed) before saving the data array,
based on the ``"spatial_shape"`` (and ``"original_affine"``) from metadata.
@@ -505,7 +514,7 @@ def __call__(
else:
self._data_index += 1
if self.savepath_in_metadict and meta_data is not None:
- meta_data["saved_to"] = filename
+ meta_data[MetaKeys.SAVED_TO] = filename
return img
msg = "\n".join([f"{e}" for e in err])
raise RuntimeError(
@@ -514,3 +523,50 @@ def __call__(
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n"
f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}"
)
+
+
+class WriteFileMapping(Transform):
+ """
+ Writes a JSON file that logs the mapping between input image paths and their corresponding output paths.
+ This class uses FileLock to ensure safe writing to the JSON file in a multiprocess environment.
+
+ Args:
+ mapping_file_path (Path or str): Path to the JSON file where the mappings will be saved.
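+
+ A minimal usage sketch (paths are illustrative; ``img`` is assumed to be a MetaTensor whose metadata
+ carries the source file path):
+
+ .. code-block:: python
+
+ save = SaveImage(output_dir="./out", savepath_in_metadict=True)
+ write_mapping = WriteFileMapping(mapping_file_path="mapping.json")
+ write_mapping(save(img))  # appends {"input": <source path>, "output": <saved path>} to mapping.json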
+ """
+
+ def __init__(self, mapping_file_path: Path | str = "mapping.json"):
+ self.mapping_file_path = Path(mapping_file_path)
+
+ def __call__(self, img: NdarrayOrTensor):
+ """
+ Args:
+ img: The input image with metadata.
+ """
+ if isinstance(img, MetaTensor):
+ meta_data = img.meta
+
+ if MetaKeys.SAVED_TO not in meta_data:
+ raise KeyError(
+ "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True."
+ )
+
+ input_path = meta_data[Key.FILENAME_OR_OBJ]
+ output_path = meta_data[MetaKeys.SAVED_TO]
+ log_data = {"input": input_path, "output": output_path}
+
+ if has_filelock:
+ with FileLock(str(self.mapping_file_path) + ".lock"):
+ self._write_to_file(log_data)
+ else:
+ self._write_to_file(log_data)
+ return img
+
+ def _write_to_file(self, log_data):
+ try:
+ with self.mapping_file_path.open("r") as f:
+ existing_log_data = json.load(f)
+ except (FileNotFoundError, json.JSONDecodeError):
+ existing_log_data = []
+ existing_log_data.append(log_data)
+ with self.mapping_file_path.open("w") as f:
+ json.dump(existing_log_data, f, indent=4)
diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py
index 4da1d422ca..be1e78db8a 100644
--- a/monai/transforms/io/dictionary.py
+++ b/monai/transforms/io/dictionary.py
@@ -17,16 +17,17 @@
from __future__ import annotations
+from collections.abc import Hashable, Mapping
from pathlib import Path
from typing import Callable
import numpy as np
import monai
-from monai.config import DtypeLike, KeysCollection
+from monai.config import DtypeLike, KeysCollection, NdarrayOrTensor
from monai.data import image_writer
from monai.data.image_reader import ImageReader
-from monai.transforms.io.array import LoadImage, SaveImage
+from monai.transforms.io.array import LoadImage, SaveImage, WriteFileMapping
from monai.transforms.transform import MapTransform, Transform
from monai.utils import GridSamplePadMode, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import PostFix
@@ -320,5 +321,31 @@ def __call__(self, data):
return d
+class WriteFileMappingd(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.transforms.WriteFileMapping`.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ See also: :py:class:`monai.transforms.compose.MapTransform`
+ mapping_file_path: Path to the JSON file where the mappings will be saved.
+ Defaults to "mapping.json".
+ allow_missing_keys: don't raise exception if key is missing.
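+
+ A minimal pipeline sketch (keys and paths are illustrative):
+
+ .. code-block:: python
+
+ transforms = Compose([
+ LoadImaged(keys="image"),
+ SaveImaged(keys="image", output_dir="./out", savepath_in_metadict=True),
+ WriteFileMappingd(keys="image", mapping_file_path="mapping.json"),
+ ])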
+ """
+
+ def __init__(
+ self, keys: KeysCollection, mapping_file_path: Path | str = "mapping.json", allow_missing_keys: bool = False
+ ) -> None:
+ super().__init__(keys, allow_missing_keys)
+ self.mapping = WriteFileMapping(mapping_file_path)
+
+ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
+ d = dict(data)
+ for key in self.key_iterator(d):
+ d[key] = self.mapping(d[key])
+ return d
+
+
LoadImageD = LoadImageDict = LoadImaged
SaveImageD = SaveImageDict = SaveImaged
+WriteFileMappingD = WriteFileMappingDict = WriteFileMappingd
diff --git a/monai/transforms/lazy/functional.py b/monai/transforms/lazy/functional.py
index 6b95027832..a33d76807c 100644
--- a/monai/transforms/lazy/functional.py
+++ b/monai/transforms/lazy/functional.py
@@ -11,7 +11,8 @@
from __future__ import annotations
-from typing import Any, Mapping, Sequence
+from collections.abc import Mapping, Sequence
+from typing import Any
import torch
diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py
index da9b23ce57..2e733c4f6c 100644
--- a/monai/transforms/post/array.py
+++ b/monai/transforms/post/array.py
@@ -211,7 +211,8 @@ def __call__(
raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.")
img = convert_to_tensor(img, track_meta=get_track_meta())
img_t, *_ = convert_data_type(img, torch.Tensor)
- if argmax or self.argmax:
+ argmax = self.argmax if argmax is None else argmax
+ if argmax:
img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True))
to_onehot = self.to_onehot if to_onehot is None else to_onehot
diff --git a/monai/transforms/regularization/array.py b/monai/transforms/regularization/array.py
index 6c9022d647..66a5116c1a 100644
--- a/monai/transforms/regularization/array.py
+++ b/monai/transforms/regularization/array.py
@@ -16,12 +16,16 @@
import torch
+from monai.data.meta_obj import get_track_meta
+from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
+
from ..transform import RandomizableTransform
__all__ = ["MixUp", "CutMix", "CutOut", "Mixer"]
class Mixer(RandomizableTransform):
+
def __init__(self, batch_size: int, alpha: float = 1.0) -> None:
"""
Mixer is a base class providing the basic logic for the mixup-class of
@@ -52,9 +56,11 @@ def randomize(self, data=None) -> None:
as needed. You need to call this method every time you apply the transform to a new
batch.
"""
+ super().randomize(None)
self._params = (
torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32),
self.R.permutation(self.batch_size),
+ [torch.from_numpy(self.R.randint(0, d, size=(1,))) for d in data.shape[2:]] if data is not None else [],
)
@@ -68,7 +74,7 @@ class MixUp(Mixer):
"""
def apply(self, data: torch.Tensor):
- weight, perm = self._params
+ weight, perm, _ = self._params
nsamples, *dims = data.shape
if len(weight) != nsamples:
raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}")
@@ -79,11 +85,20 @@ def apply(self, data: torch.Tensor):
mixweight = weight[(Ellipsis,) + (None,) * len(dims)]
return mixweight * data + (1 - mixweight) * data[perm, ...]
- def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
- self.randomize()
+ def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True):
+ data_t = convert_to_tensor(data, track_meta=get_track_meta())
+ labels_t = data_t # will not stay this value, needed to satisfy pylint/mypy
+ if labels is not None:
+ labels_t = convert_to_tensor(labels, track_meta=get_track_meta())
+ if randomize:
+ self.randomize()
if labels is None:
- return self.apply(data)
- return self.apply(data), self.apply(labels)
+ return convert_to_dst_type(self.apply(data_t), dst=data)[0]
+
+ return (
+ convert_to_dst_type(self.apply(data_t), dst=data)[0],
+ convert_to_dst_type(self.apply(labels_t), dst=labels)[0],
+ )
class CutMix(Mixer):
@@ -97,6 +112,11 @@ class CutMix(Mixer):
the mixing weight but also the size of the random rectangles used for mixing.
Please refer to the paper for details.
+ Please note that there is a change in behavior starting from version 1.4.0. In the previous
+ implementation, the transform would generate a different label each time it was called.
+ To ensure determinism, the new implementation will now generate the same label for
+ the same input image when using the same operation.
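+
+ For deterministic results across calls, the transform's random state can be seeded, for example:
+
+ .. code-block:: python
+
+ cutmix = CutMix(batch_size=8, alpha=0.5).set_random_state(seed=0)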
+
The most common use case is something close to:
.. code-block:: python
@@ -112,14 +132,13 @@ class CutMix(Mixer):
"""
def apply(self, data: torch.Tensor):
- weights, perm = self._params
+ weights, perm, coords = self._params
nsamples, _, *dims = data.shape
if len(weights) != nsamples:
raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
mask = torch.ones_like(data)
for s, weight in enumerate(weights):
- coords = [torch.randint(0, d, size=(1,)) for d in dims]
lengths = [d * sqrt(1 - weight) for d in dims]
idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
mask[s][idx] = 0
@@ -127,7 +146,7 @@ def apply(self, data: torch.Tensor):
return mask * data + (1 - mask) * data[perm, ...]
def apply_on_labels(self, labels: torch.Tensor):
- weights, perm = self._params
+ weights, perm, _ = self._params
nsamples, *dims = labels.shape
if len(weights) != nsamples:
raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
@@ -135,10 +154,18 @@ def apply_on_labels(self, labels: torch.Tensor):
mixweight = weights[(Ellipsis,) + (None,) * len(dims)]
return mixweight * labels + (1 - mixweight) * labels[perm, ...]
- def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
- self.randomize()
- augmented = self.apply(data)
- return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented
+ def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True):
+ data_t = convert_to_tensor(data, track_meta=get_track_meta())
+ augmented_label = None
+ if labels is not None:
+ labels_t = convert_to_tensor(labels, track_meta=get_track_meta())
+ if randomize:
+ self.randomize(data)
+ augmented = convert_to_dst_type(self.apply(data_t), dst=data)[0]
+
+ if labels is not None:
+ augmented_label = convert_to_dst_type(self.apply(labels_t), dst=labels)[0]
+ return (augmented, augmented_label) if labels is not None else augmented
class CutOut(Mixer):
@@ -154,20 +181,21 @@ class CutOut(Mixer):
"""
def apply(self, data: torch.Tensor):
- weights, _ = self._params
+ weights, _, coords = self._params
nsamples, _, *dims = data.shape
if len(weights) != nsamples:
raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
mask = torch.ones_like(data)
for s, weight in enumerate(weights):
- coords = [torch.randint(0, d, size=(1,)) for d in dims]
lengths = [d * sqrt(1 - weight) for d in dims]
idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
mask[s][idx] = 0
return mask * data
- def __call__(self, data: torch.Tensor):
- self.randomize()
- return self.apply(data)
+ def __call__(self, data: torch.Tensor, randomize=True):
+ data_t = convert_to_tensor(data, track_meta=get_track_meta())
+ if randomize:
+ self.randomize(data)
+ return convert_to_dst_type(self.apply(data_t), dst=data)[0]
diff --git a/monai/transforms/regularization/dictionary.py b/monai/transforms/regularization/dictionary.py
index 373913da99..d8815e47b9 100644
--- a/monai/transforms/regularization/dictionary.py
+++ b/monai/transforms/regularization/dictionary.py
@@ -11,16 +11,23 @@
from __future__ import annotations
+from collections.abc import Hashable
+
+import numpy as np
+
from monai.config import KeysCollection
+from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.meta_obj import get_track_meta
+from monai.utils import convert_to_tensor
from monai.utils.misc import ensure_tuple
-from ..transform import MapTransform
+from ..transform import MapTransform, RandomizableTransform
from .array import CutMix, CutOut, MixUp
__all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"]
-class MixUpd(MapTransform):
+class MixUpd(MapTransform, RandomizableTransform):
"""
Dictionary-based version :py:class:`monai.transforms.MixUp`.
@@ -31,18 +38,24 @@ class MixUpd(MapTransform):
def __init__(
self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False
) -> None:
- super().__init__(keys, allow_missing_keys)
+ MapTransform.__init__(self, keys, allow_missing_keys)
self.mixup = MixUp(batch_size, alpha)
+ def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> MixUpd:
+ super().set_random_state(seed, state)
+ self.mixup.set_random_state(seed, state)
+ return self
+
def __call__(self, data):
- self.mixup.randomize()
- result = dict(data)
- for k in self.keys:
- result[k] = self.mixup.apply(data[k])
- return result
+ d = dict(data)
+ # all the keys share the same random state
+ self.mixup.randomize(None)
+ for k in self.key_iterator(d):
+ d[k] = self.mixup(data[k], randomize=False)
+ return d
-class CutMixd(MapTransform):
+class CutMixd(MapTransform, RandomizableTransform):
"""
Dictionary-based version :py:class:`monai.transforms.CutMix`.
@@ -63,17 +76,27 @@ def __init__(
self.mixer = CutMix(batch_size, alpha)
self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []
- def __call__(self, data):
- self.mixer.randomize()
- result = dict(data)
- for k in self.keys:
- result[k] = self.mixer.apply(data[k])
- for k in self.label_keys:
- result[k] = self.mixer.apply_on_labels(data[k])
- return result
-
+ def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> CutMixd:
+ super().set_random_state(seed, state)
+ self.mixer.set_random_state(seed, state)
+ return self
-class CutOutd(MapTransform):
+ def __call__(self, data):
+ d = dict(data)
+ first_key: Hashable = self.first_key(d)
+ if first_key == ():
+ out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
+ return out
+ self.mixer.randomize(d[first_key])
+ for key, label_key in self.key_iterator(d, self.label_keys):
+ ret = self.mixer(data[key], data.get(label_key, None), randomize=False)
+ d[key] = ret[0]
+ if label_key in d:
+ d[label_key] = ret[1]
+ return d
+
+
+class CutOutd(MapTransform, RandomizableTransform):
"""
Dictionary-based version :py:class:`monai.transforms.CutOut`.
@@ -84,12 +107,21 @@ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bo
super().__init__(keys, allow_missing_keys)
self.cutout = CutOut(batch_size)
+ def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> CutOutd:
+ super().set_random_state(seed, state)
+ self.cutout.set_random_state(seed, state)
+ return self
+
def __call__(self, data):
- result = dict(data)
- self.cutout.randomize()
- for k in self.keys:
- result[k] = self.cutout(data[k])
- return result
+ d = dict(data)
+ first_key: Hashable = self.first_key(d)
+ if first_key == ():
+ out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
+ return out
+ self.cutout.randomize(d[first_key])
+ for k in self.key_iterator(d):
+ d[k] = self.cutout(data[k], randomize=False)
+ return d
MixUpD = MixUpDict = MixUpd
diff --git a/monai/transforms/signal/array.py b/monai/transforms/signal/array.py
index 938f42192c..97df04f233 100644
--- a/monai/transforms/signal/array.py
+++ b/monai/transforms/signal/array.py
@@ -28,7 +28,7 @@
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import convert_data_type, convert_to_tensor
-shift, has_shift = optional_import("scipy.ndimage.interpolation", name="shift")
+shift, has_shift = optional_import("scipy.ndimage", name="shift")
iirnotch, has_iirnotch = optional_import("scipy.signal", name="iirnotch")
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning) # project-monai/monai#5204
diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py
index 094afdd3c4..e4ed196eff 100644
--- a/monai/transforms/spatial/array.py
+++ b/monai/transforms/spatial/array.py
@@ -15,16 +15,17 @@
from __future__ import annotations
import warnings
-from collections.abc import Callable
+from collections.abc import Callable, Sequence
from copy import deepcopy
from itertools import zip_longest
-from typing import Any, Optional, Sequence, Tuple, Union, cast
+from typing import Any, Optional, Union, cast
import numpy as np
import torch
from monai.config import USE_COMPILED, DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import BoxMode, StandardMode
from monai.data.meta_obj import get_track_meta, set_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
@@ -34,6 +35,8 @@
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.functional import (
affine_func,
+ convert_box_to_points,
+ convert_points_to_box,
flip,
orientation,
resize,
@@ -113,7 +116,7 @@
"RandSimulateLowResolution",
]
-RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]]
+RandRange = Optional[Union[Sequence[Union[tuple[float, float], float]], float]]
class SpatialResample(InvertibleTransform, LazyTransform):
@@ -3441,7 +3444,7 @@ def filter_count(self, image_np: NdarrayOrTensor, locations: np.ndarray) -> tupl
idx = self.R.permutation(image_np.shape[0])
idx = idx[: self.num_patches]
idx_np = convert_data_type(idx, np.ndarray)[0]
- image_np = image_np[idx]
+ image_np = image_np[idx] # type: ignore[index]
locations = locations[idx_np]
return image_np, locations
elif self.sort_fn not in (None, GridPatchSort.MIN, GridPatchSort.MAX):
@@ -3544,3 +3547,44 @@ def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
else:
return img
+
+
+class ConvertBoxToPoints(Transform):
+ """
+ Converts axis-aligned bounding boxes to corner points, interpreting the boxes according to the given box mode.
+ Bounding boxes have the shape (N, C) for N boxes, where C is [x1, y1, x2, y2] for 2D or [x1, y1, z1, x2, y2, z2] for 3D.
+ The return shape is (N, 4, 2) for 2D or (N, 8, 3) for 3D.
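+
+ A minimal usage sketch (box values are illustrative):
+
+ .. code-block:: python
+
+ boxes = torch.tensor([[0.0, 0.0, 2.0, 1.0]])  # one 2D box in StandardMode ([x1, y1, x2, y2])
+ points = ConvertBoxToPoints()(boxes)  # corner points of shape (1, 4, 2)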
+ """
+
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+ def __init__(self, mode: str | BoxMode | type[BoxMode] | None = None) -> None:
+ """
+ Args:
+ mode: the mode of the box, can be a string, a BoxMode instance or a BoxMode class. Defaults to StandardMode.
+ """
+ super().__init__()
+ self.mode = StandardMode if mode is None else mode
+
+ def __call__(self, data: Any):
+ data = convert_to_tensor(data, track_meta=get_track_meta())
+ points = convert_box_to_points(data, mode=self.mode)
+ return convert_to_dst_type(points, data)[0]
+
+
+class ConvertPointsToBoxes(Transform):
+ """
+ Converts corner points to an axis-aligned bounding box.
+ The input points represent the corners of the bounding box, with shape (N, 8, 3) for the 8 corners of a
+ 3D cuboid or (N, 4, 2) for the 4 corners of a 2D rectangle.
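+
+ A minimal usage sketch (point values are illustrative):
+
+ .. code-block:: python
+
+ corners = torch.tensor([[[0.0, 0.0], [2.0, 0.0], [2.0, 1.0], [0.0, 1.0]]])  # (1, 4, 2)
+ boxes = ConvertPointsToBoxes()(corners)  # tensor([[0., 0., 2., 1.]]), i.e. [x1, y1, x2, y2]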
+ """
+
+ backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
+
+ def __init__(self) -> None:
+ super().__init__()
+
+ def __call__(self, data: Any):
+ data = convert_to_tensor(data, track_meta=get_track_meta())
+ box = convert_points_to_box(data)
+ return convert_to_dst_type(box, data)[0]
diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py
index 01fadcfb69..2b80034a07 100644
--- a/monai/transforms/spatial/dictionary.py
+++ b/monai/transforms/spatial/dictionary.py
@@ -26,6 +26,7 @@
from monai.config import DtypeLike, KeysCollection, SequenceStr
from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import BoxMode, StandardMode
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.networks.layers.simplelayers import GaussianFilter
@@ -33,6 +34,8 @@
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.spatial.array import (
Affine,
+ ConvertBoxToPoints,
+ ConvertPointsToBoxes,
Flip,
GridDistortion,
GridPatch,
@@ -2585,6 +2588,7 @@ def set_random_state(
self, seed: int | None = None, state: np.random.RandomState | None = None
) -> RandSimulateLowResolutiond:
super().set_random_state(seed, state)
+ self.sim_lowres_tfm.set_random_state(seed, state)
return self
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
@@ -2611,6 +2615,61 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d
+class ConvertBoxToPointsd(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.transforms.ConvertBoxToPoints`.
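+
+ A minimal usage sketch (keys are illustrative):
+
+ .. code-block:: python
+
+ to_points = ConvertBoxToPointsd(keys="boxes", point_key="points")
+ data = to_points({"boxes": boxes})  # adds data["points"] with shape (N, 4, 2) or (N, 8, 3)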
+ """
+
+ backend = ConvertBoxToPoints.backend
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ point_key="points",
+ mode: str | BoxMode | type[BoxMode] | None = StandardMode,
+ allow_missing_keys: bool = False,
+ ):
+ """
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ point_key: key to store the point data.
+ mode: the mode of the input boxes. Defaults to StandardMode.
+ allow_missing_keys: don't raise exception if key is missing.
+ """
+ super().__init__(keys, allow_missing_keys)
+ self.point_key = point_key
+ self.converter = ConvertBoxToPoints(mode=mode)
+
+ def __call__(self, data):
+ d = dict(data)
+ for key in self.key_iterator(d):
+ d[self.point_key] = self.converter(d[key])
+ return d
+
+
+class ConvertPointsToBoxesd(MapTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.transforms.ConvertPointsToBoxes`.
+ """
+
+ def __init__(self, keys: KeysCollection, box_key="box", allow_missing_keys: bool = False):
+ """
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ box_key: key to store the box data.
+ allow_missing_keys: don't raise exception if key is missing.
+ """
+ super().__init__(keys, allow_missing_keys)
+ self.box_key = box_key
+ self.converter = ConvertPointsToBoxes()
+
+ def __call__(self, data):
+ d = dict(data)
+ for key in self.key_iterator(d):
+ d[self.box_key] = self.converter(d[key])
+ return d
+
+
SpatialResampleD = SpatialResampleDict = SpatialResampled
ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
SpacingD = SpacingDict = Spacingd
@@ -2635,3 +2694,5 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
GridPatchD = GridPatchDict = GridPatchd
RandGridPatchD = RandGridPatchDict = RandGridPatchd
RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond
+ConvertBoxToPointsD = ConvertBoxToPointsDict = ConvertBoxToPointsd
+ConvertPointsToBoxesD = ConvertPointsToBoxesDict = ConvertPointsToBoxesd
diff --git a/monai/transforms/spatial/functional.py b/monai/transforms/spatial/functional.py
index add4e7f5ea..b693e7d023 100644
--- a/monai/transforms/spatial/functional.py
+++ b/monai/transforms/spatial/functional.py
@@ -24,6 +24,7 @@
import monai
from monai.config import USE_COMPILED
from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.box_utils import get_boxmode
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd
@@ -32,7 +33,7 @@
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.inverse import TraceableTransform
from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
-from monai.transforms.utils_pytorch_numpy_unification import allclose
+from monai.transforms.utils_pytorch_numpy_unification import allclose, concatenate, stack
from monai.utils import (
LazyAttr,
TraceKeys,
@@ -373,7 +374,7 @@ def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, l
if output_shape is None:
corners = np.asarray(np.meshgrid(*[(0, dim) for dim in im_shape], indexing="ij")).reshape((len(im_shape), -1))
corners = transform[:-1, :-1] @ corners # type: ignore
- output_shape = np.asarray(corners.ptp(axis=1) + 0.5, dtype=int)
+ output_shape = np.asarray(np.ptp(corners, axis=1) + 0.5, dtype=int)
else:
output_shape = np.asarray(output_shape, dtype=int)
shift = create_translate(input_ndim, ((np.array(im_shape) - 1) / 2).tolist())
@@ -610,3 +611,71 @@ def affine_func(
out = _maybe_new_metatensor(img, dtype=torch.float32, device=resampler.device)
out = out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
return out if image_only else (out, affine)
+
+
+def convert_box_to_points(bbox, mode):
+ """
+ Converts an axis-aligned bounding box to points.
+
+ Args:
+ bbox: Bounding boxes of the shape (N, C) for N boxes. C is [x1, y1, x2, y2] for 2D or
+ [x1, y1, z1, x2, y2, z2] for 3D for each box.
+ mode: The mode specifying how to interpret the bounding box.
+
+ Returns:
+ Points representing the corners of the bounding boxes, of shape (N, 4, 2) for 2D or (N, 8, 3) for 3D.
+ """
+
+ mode = get_boxmode(mode)
+
+ points_list = []
+ for _num in range(bbox.shape[0]):
+ corners = mode.boxes_to_corners(bbox[_num : _num + 1])
+ if len(corners) == 4:
+ points_list.append(
+ concatenate(
+ [
+ concatenate([corners[0], corners[1]], axis=1),
+ concatenate([corners[2], corners[1]], axis=1),
+ concatenate([corners[2], corners[3]], axis=1),
+ concatenate([corners[0], corners[3]], axis=1),
+ ],
+ axis=0,
+ )
+ )
+ else:
+ points_list.append(
+ concatenate(
+ [
+ concatenate([corners[0], corners[1], corners[2]], axis=1),
+ concatenate([corners[3], corners[1], corners[2]], axis=1),
+ concatenate([corners[3], corners[4], corners[2]], axis=1),
+ concatenate([corners[0], corners[4], corners[2]], axis=1),
+ concatenate([corners[0], corners[1], corners[5]], axis=1),
+ concatenate([corners[3], corners[1], corners[5]], axis=1),
+ concatenate([corners[3], corners[4], corners[5]], axis=1),
+ concatenate([corners[0], corners[4], corners[5]], axis=1),
+ ],
+ axis=0,
+ )
+ )
+
+ return stack(points_list, dim=0)
+
+
+def convert_points_to_box(points):
+ """
+ Converts points to an axis-aligned bounding box.
+
+ Args:
+ points: Points representing the corners of the bounding box. Shape (N, 8, 3) for the 8 corners of
+ a 3D cuboid or (N, 4, 2) for the 4 corners of a 2D rectangle.
+ """
+ from monai.transforms.utils_pytorch_numpy_unification import max, min
+
+ mins = min(points, dim=1)
+ maxs = max(points, dim=1)
+ # Concatenate the min and max values to get the bounding boxes
+ bboxes = concatenate([mins, maxs], axis=1)
+
+ return bboxes
diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py
index 3d09cea545..15c2499a73 100644
--- a/monai/transforms/transform.py
+++ b/monai/transforms/transform.py
@@ -203,8 +203,8 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState
"""
if seed is not None:
- _seed = id(seed) if not isinstance(seed, (int, np.integer)) else seed
- _seed = _seed % MAX_SEED
+ _seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed)
+ _seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64
self.R = np.random.RandomState(_seed)
return self
diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py
index 85895f6daf..9d67e69033 100644
--- a/monai/transforms/utility/array.py
+++ b/monai/transforms/utility/array.py
@@ -31,7 +31,7 @@
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
-from monai.data.utils import is_no_channel, no_collation
+from monai.data.utils import is_no_channel, no_collation, orientation_ras_lps
from monai.networks.layers.simplelayers import (
ApplyFilter,
EllipticalFilter,
@@ -42,16 +42,17 @@
SharpenFilter,
median_filter,
)
-from monai.transforms.inverse import InvertibleTransform
+from monai.transforms.inverse import InvertibleTransform, TraceableTransform
from monai.transforms.traits import MultiSampleTrait
from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform
from monai.transforms.utils import (
+ apply_affine_to_points,
extreme_points_to_image,
get_extreme_points,
map_binary_to_indices,
map_classes_to_indices,
)
-from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, moveaxis, unravel_indices
+from monai.transforms.utils_pytorch_numpy_unification import concatenate, in1d, linalg_inv, moveaxis, unravel_indices
from monai.utils import (
MetaKeys,
TraceKeys,
@@ -66,7 +67,7 @@
)
from monai.utils.enums import TransformBackends
from monai.utils.misc import is_module_ver_at_least
-from monai.utils.type_conversion import convert_to_dst_type, get_equivalent_dtype
+from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype
PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
@@ -107,6 +108,7 @@
"ToCupy",
"ImageFilter",
"RandImageFilter",
+ "ApplyTransformToPoints",
]
@@ -655,6 +657,7 @@ def __init__(
data_shape: bool = True,
value_range: bool = True,
data_value: bool = False,
+ meta_info: bool = False,
additional_info: Callable | None = None,
name: str = "DataStats",
) -> None:
@@ -666,6 +669,7 @@ def __init__(
value_range: whether to show the value range of input data.
data_value: whether to show the raw value of input data.
a typical example is to print some properties of Nifti image: affine, pixdim, etc.
+ meta_info: whether to show the data of MetaTensor.
additional_info: user can define callable function to extract additional info from input data.
name: identifier of `logging.logger` to use, defaulting to "DataStats".
@@ -680,6 +684,7 @@ def __init__(
self.data_shape = data_shape
self.value_range = value_range
self.data_value = data_value
+ self.meta_info = meta_info
if additional_info is not None and not callable(additional_info):
raise TypeError(f"additional_info must be None or callable but is {type(additional_info).__name__}.")
self.additional_info = additional_info
@@ -706,6 +711,7 @@ def __call__(
data_shape: bool | None = None,
value_range: bool | None = None,
data_value: bool | None = None,
+ meta_info: bool | None = None,
additional_info: Callable | None = None,
) -> NdarrayOrTensor:
"""
@@ -726,6 +732,9 @@ def __call__(
lines.append(f"Value range: (not a PyTorch or Numpy array, type: {type(img)})")
if self.data_value if data_value is None else data_value:
lines.append(f"Value: {img}")
+ if self.meta_info if meta_info is None else meta_info:
+ metadata = getattr(img, "meta", "(input is not a MetaTensor)")
+ lines.append(f"Meta info: {repr(metadata)}")
additional_info = self.additional_info if additional_info is None else additional_info
if additional_info is not None:
lines.append(f"Additional info: {additional_info(img)}")
@@ -1640,9 +1649,9 @@ def _check_all_values_uneven(self, x: tuple) -> None:
def _check_filter_format(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | None = None) -> None:
if isinstance(filter, str):
- if not filter_size:
+ if filter != "gauss" and not filter_size: # Gauss is the only filter that does not require `filter_size`
raise ValueError("`filter_size` must be specified when specifying filters by string.")
- if filter_size % 2 == 0:
+ if filter_size and filter_size % 2 == 0:
raise ValueError("`filter_size` should be a single uneven integer.")
if filter not in self.supported_filters:
raise NotImplementedError(f"{filter}. Supported filters are {self.supported_filters}.")
@@ -1755,3 +1764,143 @@ def __call__(self, img: NdarrayOrTensor, meta_dict: Mapping | None = None) -> Nd
if self._do_transform:
img = self.filter(img)
return img
+
+
+class ApplyTransformToPoints(InvertibleTransform, Transform):
+ """
+ Transform points between image coordinates and world coordinates.
+ The input coordinates are assumed to be in the shape (C, N, 2 or 3), where C represents the number of channels
+ and N denotes the number of points. It will return a tensor with the same shape as the input.
+
+ Args:
+ dtype: The desired data type for the output.
+ affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
+ from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
+ Z dimension. While a 4x4 matrix is required for 3D transformations, it's important to note that when
+ applying a 4x4 matrix to 2D points, the additional dimensions are handled accordingly.
+ The matrix is always converted to float64 for computation, which can be computationally
+ expensive when applied to a large number of points.
+ If None, will try to use the affine matrix from the input data.
+ invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``.
+ Typically, the affine matrix is derived from an image and represents its location in world space,
+ while the points are in world coordinates. A value of ``True`` represents transforming these
+ world space coordinates to the image's coordinate space, and ``False`` the inverse of this operation.
+ affine_lps_to_ras: Defaults to ``False``. Set to `True` if your point data is in the RAS coordinate system
+ or you're using `ITKReader` with `affine_lps_to_ras=True`.
+ This ensures the correct application of the affine transformation between LPS (left-posterior-superior)
+ and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine
+ matrix are in the same coordinate system.
+
+ Use Cases:
+ - Transforming points between world space and image space, and vice versa.
+ - Automatically handling inverse transformations between image space and world space.
+ - If points have an existing affine transformation, the class computes and
+ applies the required delta affine transformation.
+
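+ A minimal usage sketch (the affine and point values are illustrative):
+
+ .. code-block:: python
+
+ points = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]])  # (C=1, N=2, 3)
+ affine = torch.tensor(
+ [[2.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
+ )  # e.g. the affine of the reference image
+ transform = ApplyTransformToPoints(affine=affine, invert_affine=True)
+ image_points = transform(points)  # world space -> image space
+ world_points = transform.inverse(image_points)  # back to world space (uses the recorded metadata)
+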
+ """
+
+ def __init__(
+ self,
+ dtype: DtypeLike | torch.dtype | None = None,
+ affine: torch.Tensor | None = None,
+ invert_affine: bool = True,
+ affine_lps_to_ras: bool = False,
+ ) -> None:
+ self.dtype = dtype
+ self.affine = affine
+ self.invert_affine = invert_affine
+ self.affine_lps_to_ras = affine_lps_to_ras
+
+ def _compute_final_affine(self, affine: torch.Tensor, applied_affine: torch.Tensor | None = None) -> torch.Tensor:
+ """
+ Compute the final affine transformation matrix to apply to the point data.
+
+ Args:
+ affine: 3x3 or 4x4 affine transformation matrix.
+ applied_affine: the affine transformation matrix that has already been applied to the point data, if any.
+
+ Returns:
+ Final affine transformation matrix.
+ """
+
+ affine = convert_data_type(affine, dtype=torch.float64)[0]
+
+ if self.affine_lps_to_ras:
+ affine = orientation_ras_lps(affine)
+
+ if self.invert_affine:
+ affine = linalg_inv(affine)
+ if applied_affine is not None:
+ affine = affine @ applied_affine
+
+ return affine
+
+ def transform_coordinates(
+ self, data: torch.Tensor, affine: torch.Tensor | None = None
+ ) -> tuple[torch.Tensor, dict]:
+ """
+ Transform coordinates using an affine transformation matrix.
+
+ Args:
+ data: The input coordinates are assumed to be in the shape (C, N, 2 or 3),
+ where C represents the number of channels and N denotes the number of points.
+ affine: 3x3 or 4x4 affine transformation matrix. The matrix is always converted to float64 for computation,
+ which can be computationally expensive when applied to a large number of points.
+
+ Returns:
+ Transformed coordinates and the associated transform trace metadata.
+ """
+ data = convert_to_tensor(data, track_meta=get_track_meta())
+ if affine is None and self.invert_affine:
+ raise ValueError("affine must be provided when invert_affine is True.")
+ # applied_affine is the affine transformation matrix that has already been applied to the point data
+ applied_affine: torch.Tensor | None = getattr(data, "affine", None)
+ affine = applied_affine if affine is None else affine
+ if affine is None:
+ raise ValueError("affine must be provided if data does not have an affine matrix.")
+
+ final_affine = self._compute_final_affine(affine, applied_affine)
+ out = apply_affine_to_points(data, final_affine, dtype=self.dtype)
+
+ extra_info = {
+ "invert_affine": self.invert_affine,
+ "dtype": get_dtype_string(self.dtype),
+ "image_affine": affine,
+ "affine_lps_to_ras": self.affine_lps_to_ras,
+ }
+
+ xform = orientation_ras_lps(linalg_inv(final_affine)) if self.affine_lps_to_ras else linalg_inv(final_affine)
+ meta_info = TraceableTransform.track_transform_meta(
+ data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info()
+ )
+
+ return out, meta_info
+
+ def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None):
+ """
+ Args:
+ data: The input coordinates are assumed to be in the shape (C, N, 2 or 3),
+ where C represents the number of channels and N denotes the number of points.
+ affine: A 3x3 or 4x4 affine transformation matrix, this argument will take precedence over ``self.affine``.
+ """
+ if data.ndim != 3 or data.shape[-1] not in (2, 3):
+ raise ValueError(f"data should be in shape (C, N, 2 or 3), got {data.shape}.")
+ affine = self.affine if affine is None else affine
+ if affine is not None and affine.shape not in ((3, 3), (4, 4)):
+ raise ValueError(f"affine should be in shape (3, 3) or (4, 4), got {affine.shape}.")
+
+ out, meta_info = self.transform_coordinates(data, affine)
+
+ return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
+
+ def inverse(self, data: torch.Tensor) -> torch.Tensor:
+ transform = self.pop_transform(data)
+ inverse_transform = ApplyTransformToPoints(
+ dtype=transform[TraceKeys.EXTRA_INFO]["dtype"],
+ invert_affine=not transform[TraceKeys.EXTRA_INFO]["invert_affine"],
+ affine_lps_to_ras=transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"],
+ )
+ with inverse_transform.trace_transform(False):
+ data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO]["image_affine"])
+
+ return data
diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py
index 92cfba3cdb..f29119d348 100644
--- a/monai/transforms/utility/dictionary.py
+++ b/monai/transforms/utility/dictionary.py
@@ -18,9 +18,9 @@
from __future__ import annotations
import re
-from collections.abc import Callable, Hashable, Mapping
+from collections.abc import Callable, Hashable, Mapping, Sequence
from copy import deepcopy
-from typing import Any, Sequence, cast
+from typing import Any, cast
import numpy as np
import torch
@@ -35,6 +35,7 @@
from monai.transforms.utility.array import (
AddCoordinateChannels,
AddExtremePointsChannel,
+ ApplyTransformToPoints,
AsChannelLast,
CastToType,
ClassesToIndices,
@@ -184,6 +185,9 @@
"ClassesToIndicesd",
"ClassesToIndicesD",
"ClassesToIndicesDict",
+ "ApplyTransformToPointsd",
+ "ApplyTransformToPointsD",
+ "ApplyTransformToPointsDict",
]
DEFAULT_POST_FIX = PostFix.meta()
@@ -793,6 +797,7 @@ def __init__(
data_shape: Sequence[bool] | bool = True,
value_range: Sequence[bool] | bool = True,
data_value: Sequence[bool] | bool = False,
+ meta_info: Sequence[bool] | bool = False,
additional_info: Sequence[Callable] | Callable | None = None,
name: str = "DataStats",
allow_missing_keys: bool = False,
@@ -812,6 +817,8 @@ def __init__(
data_value: whether to show the raw value of input data.
it also can be a sequence of bool, each element corresponds to a key in ``keys``.
a typical example is to print some properties of Nifti image: affine, pixdim, etc.
+ meta_info: whether to show the metadata of the input if it is a MetaTensor.
+ it also can be a sequence of bool, each element corresponds to a key in ``keys``.
additional_info: user can define callable function to extract
additional info from input data. it also can be a sequence of string, each element
corresponds to a key in ``keys``.
@@ -825,15 +832,34 @@ def __init__(
self.data_shape = ensure_tuple_rep(data_shape, len(self.keys))
self.value_range = ensure_tuple_rep(value_range, len(self.keys))
self.data_value = ensure_tuple_rep(data_value, len(self.keys))
+ self.meta_info = ensure_tuple_rep(meta_info, len(self.keys))
self.additional_info = ensure_tuple_rep(additional_info, len(self.keys))
self.printer = DataStats(name=name)
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
- for key, prefix, data_type, data_shape, value_range, data_value, additional_info in self.key_iterator(
- d, self.prefix, self.data_type, self.data_shape, self.value_range, self.data_value, self.additional_info
+ for (
+ key,
+ prefix,
+ data_type,
+ data_shape,
+ value_range,
+ data_value,
+ meta_info,
+ additional_info,
+ ) in self.key_iterator(
+ d,
+ self.prefix,
+ self.data_type,
+ self.data_shape,
+ self.value_range,
+ self.data_value,
+ self.meta_info,
+ self.additional_info,
):
- d[key] = self.printer(d[key], prefix, data_type, data_shape, value_range, data_value, additional_info)
+ d[key] = self.printer(
+ d[key], prefix, data_type, data_shape, value_range, data_value, meta_info, additional_info
+ )
return d
@@ -1765,6 +1791,10 @@ class RandImageFilterd(MapTransform, RandomizableTransform):
Probability the transform is applied to the data
allow_missing_keys:
Don't raise exception if key is missing.
+
+ Note:
+ - This transform does not scale output image values automatically to match the range of the input.
+ The output should be scaled by later transforms to match the input if this is desired.
"""
backend = ImageFilter.backend
@@ -1791,6 +1821,77 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d
+class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
+ """
+ Dictionary-based wrapper of :py:class:`monai.transforms.ApplyTransformToPoints`.
+ The input coordinates are assumed to be in the shape (C, N, 2 or 3),
+ where C represents the number of channels and N denotes the number of points.
+ The output has the same shape as the input.
+
+ Args:
+ keys: keys of the corresponding items to be transformed.
+ See also: monai.transforms.MapTransform
+ refer_keys: The key of the reference item used for transformation.
+ It can directly refer to an affine or an image from which the affine can be derived. It can also be a
+ sequence of keys, in which case each refers to the affine applied to the matching points in `keys`.
+ dtype: The desired data type for the output.
+ affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
+ from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
+ Z dimension. While a 4x4 matrix is required for 3D transformations, note that when a 4x4 matrix is
+ applied to 2D points it is internally reduced to the matching 2D (3x3) form before use.
+ The matrix is always converted to float64 for computation, which can be computationally
+ expensive when applied to a large number of points.
+ If None, will try to use the affine matrix from the refer data.
+ invert_affine: Whether to invert the affine transformation matrix applied to the points. Defaults to ``True``.
+ Typically, the affine matrix is derived from the image, while the points are in world coordinates.
+ If you want to align the points with the image, set this to ``True``. Otherwise, set it to ``False``.
+ affine_lps_to_ras: Defaults to ``False``. Set to `True` if your point data is in the RAS coordinate system
+ or you're using `ITKReader` with `affine_lps_to_ras=True`.
+ This ensures the correct application of the affine transformation between LPS (left-posterior-superior)
+ and RAS (right-anterior-superior) coordinate systems. This argument ensures the points and the affine
+ matrix are in the same coordinate system.
+ allow_missing_keys: Don't raise exception if key is missing.
+ """
+
+ def __init__(
+ self,
+ keys: KeysCollection,
+ refer_keys: KeysCollection | None = None,
+ dtype: DtypeLike | torch.dtype = torch.float64,
+ affine: torch.Tensor | None = None,
+ invert_affine: bool = True,
+ affine_lps_to_ras: bool = False,
+ allow_missing_keys: bool = False,
+ ):
+ MapTransform.__init__(self, keys, allow_missing_keys)
+ self.refer_keys = ensure_tuple_rep(refer_keys, len(self.keys))
+ self.converter = ApplyTransformToPoints(
+ dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
+ )
+
+ def __call__(self, data: Mapping[Hashable, torch.Tensor]):
+ d = dict(data)
+ for key, refer_key in self.key_iterator(d, self.refer_keys):
+ coords = d[key]
+ affine = None # represents using affine given in constructor
+ if refer_key is not None:
+ if refer_key in d:
+ refer_data = d[refer_key]
+ else:
+ raise KeyError(f"The refer_key '{refer_key}' is not found in the data.")
+
+ # use the "affine" member of refer_data, or refer_data itself, as the affine matrix
+ affine = getattr(refer_data, "affine", refer_data)
+ d[key] = self.converter(coords, affine)
+ return d
+
+ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
+ d = dict(data)
+ for key in self.key_iterator(d):
+ d[key] = self.converter.inverse(d[key])
+ return d
+
+
RandImageFilterD = RandImageFilterDict = RandImageFilterd
ImageFilterD = ImageFilterDict = ImageFilterd
IdentityD = IdentityDict = Identityd
@@ -1832,3 +1933,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
RandCuCIMD = RandCuCIMDict = RandCuCIMd
AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd
+ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd
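A sketch of the dictionary version, pulling the affine from a reference image via `refer_keys` (key names and values are illustrative):

import torch
from monai.data import MetaTensor
from monai.transforms import ApplyTransformToPointsd

image = MetaTensor(torch.zeros(1, 64, 64, 64), affine=torch.diag(torch.tensor([2.0, 2.0, 2.0, 1.0])))
data = {
    "image": image,
    "points": torch.tensor([[[2.0, 4.0, 6.0]]]),  # (C=1, N=1, 3) in world coordinates
}

xform = ApplyTransformToPointsd(keys="points", refer_keys="image", invert_affine=True)
out = xform(data)              # out["points"] is now in the image's voxel space
restored = xform.inverse(out)  # maps the points back to world coordinates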
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index e282ecff24..e7e1616e13 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -22,22 +22,27 @@
import numpy as np
import torch
+from torch import Tensor
import monai
from monai.config import DtypeLike, IndexSelection
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
+from monai.data.utils import to_affine_nd
from monai.networks.layers import GaussianFilter
from monai.networks.utils import meshgrid_ij
from monai.transforms.compose import Compose
from monai.transforms.transform import MapTransform, Transform, apply_transform
+from monai.transforms.utils_morphological_ops import erode
from monai.transforms.utils_pytorch_numpy_unification import (
any_np_pt,
ascontiguousarray,
+ concatenate,
cumsum,
isfinite,
nonzero,
ravel,
searchsorted,
+ softplus,
unique,
unravel_index,
where,
@@ -64,6 +69,8 @@
min_version,
optional_import,
pytorch_after,
+ unsqueeze_left,
+ unsqueeze_right,
)
from monai.utils.enums import TransformBackends
from monai.utils.type_conversion import (
@@ -102,11 +109,15 @@
"generate_spatial_bounding_box",
"get_extreme_points",
"get_largest_connected_component_mask",
+ "keep_merge_components_with_points",
+ "keep_components_with_positive_points",
+ "convert_points_to_disc",
"remove_small_objects",
"img_bounds",
"in_bounds",
"is_empty",
"is_positive",
+ "map_and_generate_sampling_centers",
"map_binary_to_indices",
"map_classes_to_indices",
"map_spatial_axes",
@@ -131,9 +142,45 @@
"resolves_modes",
"has_status_keys",
"distance_transform_edt",
+ "soft_clip",
]
+def soft_clip(
+ arr: NdarrayOrTensor,
+ sharpness_factor: float = 1.0,
+ minv: NdarrayOrTensor | float | int | None = None,
+ maxv: NdarrayOrTensor | float | int | None = None,
+ dtype: DtypeLike | torch.dtype = np.float32,
+) -> NdarrayOrTensor:
+ """
+ Apply soft clip to the input array or tensor.
+ The intensity values will be soft clipped according to
+ f(x) = x + (1/c) * softplus(-c * (x - minv)) - (1/c) * softplus(c * (x - maxv)), where c = sharpness_factor.
+ From https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291
+
+ To perform one-sided clipping, set either minv or maxv to None.
+ Args:
+ arr: input array to clip.
+ sharpness_factor: the sharpness of the soft clip function, default to 1.
+ minv: minimum value of target clipped array.
+ maxv: maximum value of target clipped array.
+ dtype: if not None, convert input array to dtype before computation.
+
+ """
+
+ if dtype is not None:
+ arr, *_ = convert_data_type(arr, dtype=dtype)
+
+ v = arr
+ if minv is not None:
+ v = v + softplus(-sharpness_factor * (arr - minv)) / sharpness_factor
+ if maxv is not None:
+ v = v - softplus(sharpness_factor * (arr - maxv)) / sharpness_factor
+
+ return v
+
+
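A quick numeric sketch of the soft clip behaviour (values well inside [minv, maxv] stay nearly unchanged, values outside are compressed smoothly rather than cut off):

import numpy as np
from monai.transforms.utils import soft_clip

x = np.array([-5.0, 0.0, 5.0, 10.0, 15.0])
y = soft_clip(x, sharpness_factor=1.0, minv=0.0, maxv=10.0)
# y ≈ [0.0067, 0.693, 5.0, 9.307, 9.993]: the midpoint is untouched,
# while out-of-range values approach the bounds asymptotically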
def rand_choice(prob: float = 0.5) -> bool:
"""
Returns True if a randomly chosen number is less than or equal to `prob`, by default this is a 50/50 chance.
@@ -331,6 +378,70 @@ def check_non_lazy_pending_ops(
warnings.warn(msg)
+def map_and_generate_sampling_centers(
+ label: NdarrayOrTensor,
+ spatial_size: Sequence[int] | int,
+ num_samples: int,
+ label_spatial_shape: Sequence[int] | None = None,
+ num_classes: int | None = None,
+ image: NdarrayOrTensor | None = None,
+ image_threshold: float = 0.0,
+ max_samples_per_class: int | None = None,
+ ratios: list[float | int] | None = None,
+ rand_state: np.random.RandomState | None = None,
+ allow_smaller: bool = False,
+ warn: bool = True,
+) -> tuple[tuple]:
+ """
+ Combine "map_classes_to_indices" and "generate_label_classes_crop_centers" functions, return crop center coordinates.
+ This calls `map_classes_to_indices` to get the indices of every class from `label`, takes the spatial shape from
+ `label_spatial_shape` if given (otherwise from the label itself), calls `generate_label_classes_crop_centers`, and returns its results.
+
+ Args:
+ label: use the label data to get the indices of every class.
+ spatial_size: spatial size of the ROIs to be sampled.
+ num_samples: total sample centers to be generated.
+ label_spatial_shape: spatial shape of the original label data to unravel selected centers.
+ num_classes: number of classes for argmax label, not necessary for One-Hot label.
+ image: if image is not None, only return the indices of every class that are within the valid
+ region of the image (``image > image_threshold``).
+ image_threshold: if `image` is provided, use ``image > image_threshold`` to
+ determine the valid image content area and select class indices only in this area.
+ max_samples_per_class: maximum length of indices in each class to reduce memory consumption.
+ Default is None, no subsampling.
+ ratios: ratios of every class in the label to generate crop centers, including background class.
+ if None, every class will have the same ratio to generate crop centers.
+ rand_state: numpy randomState object to align with other modules.
+ allow_smaller: if `False`, an exception will be raised if the image is smaller than
+ the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
+ match the cropped size (i.e., no cropping in that dimension).
+ warn: if `True` prints a warning if a class is not present in the label.
+ Returns:
+ Tuple of crop centers.
+ """
+ if label is None:
+ raise ValueError("label must not be None.")
+ indices = map_classes_to_indices(label, num_classes, image, image_threshold, max_samples_per_class)
+
+ if label_spatial_shape is not None:
+ _shape = label_spatial_shape
+ elif isinstance(label, monai.data.MetaTensor):
+ _shape = label.peek_pending_shape()
+ else:
+ _shape = label.shape[1:]
+
+ if _shape is None:
+ raise ValueError(
+ "label_spatial_shape or label with a known shape must be provided to infer the output spatial shape."
+ )
+ centers = generate_label_classes_crop_centers(
+ spatial_size, num_samples, _shape, indices, ratios, rand_state, allow_smaller, warn
+ )
+
+ return ensure_tuple(centers)
+
+
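A small sketch of how this helper ties the two existing functions together (the toy label and arguments are illustrative):

import numpy as np
from monai.transforms.utils import map_and_generate_sampling_centers

# a (1, 8, 8) argmax label with three classes: 0 (background), 1 and 2
label = np.zeros((1, 8, 8), dtype=np.int64)
label[0, :2, :2] = 1
label[0, 5:, 5:] = 2

centers = map_and_generate_sampling_centers(
    label,
    spatial_size=(4, 4),
    num_samples=6,
    num_classes=3,
    ratios=[0, 1, 1],  # sample only around classes 1 and 2
    rand_state=np.random.RandomState(0),
)
# `centers` is a tuple of six (y, x) crop centers, each valid for a (4, 4) ROI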
def map_binary_to_indices(
label: NdarrayOrTensor, image: NdarrayOrTensor | None = None, image_threshold: float = 0.0
) -> tuple[NdarrayOrTensor, NdarrayOrTensor]:
@@ -471,7 +582,8 @@ def weighted_patch_samples(
if not v[-1] or not isfinite(v[-1]) or v[-1] < 0: # uniform sampling
idx = r_state.randint(0, len(v), size=n_samples)
else:
- r, *_ = convert_to_dst_type(r_state.random(n_samples), v)
+ r_samples = r_state.random(n_samples)
+ r, *_ = convert_to_dst_type(r_samples, v, dtype=r_samples.dtype)
idx = searchsorted(v, r * v[-1], right=True) # type: ignore
idx, *_ = convert_to_dst_type(idx, v, dtype=torch.int) # type: ignore
# compensate 'valid' mode
@@ -625,9 +737,12 @@ def generate_label_classes_crop_centers(
for i, array in enumerate(indices):
if len(array) == 0:
- ratios_[i] = 0
- if warn:
- warnings.warn(f"no available indices of class {i} to crop, set the crop ratio of this class to zero.")
+ if ratios_[i] != 0:
+ ratios_[i] = 0
+ if warn:
+ warnings.warn(
+ f"no available indices of class {i} to crop, setting the crop ratio of this class to zero."
+ )
centers = []
classes = rand_state.choice(len(ratios_), size=num_samples, p=np.asarray(ratios_) / np.sum(ratios_))
@@ -1067,6 +1182,227 @@ def get_largest_connected_component_mask(
return convert_to_dst_type(out, dst=img, dtype=out.dtype)[0]
+def keep_merge_components_with_points(
+ img_pos: NdarrayTensor,
+ img_neg: NdarrayTensor,
+ point_coords: NdarrayTensor,
+ point_labels: NdarrayTensor,
+ pos_val: Sequence[int] = (1, 3),
+ neg_val: Sequence[int] = (0, 2),
+ margins: int = 3,
+) -> NdarrayTensor:
+ """
+ Keep connected regions of img_pos and img_neg that include the positive points and
+ negative points separately. The function is used for merging automatic results with interactive
+ results in VISTA3D.
+
+ Args:
+ img_pos: bool type tensor, shape [B, 1, H, W, D], where B is the number of foreground masks from a single 3D image.
+ img_neg: same format as img_pos but corresponds to negative points.
+ pos_val: positive point label values.
+ neg_val: negative point label values.
+ point_coords: the coordinates of each point, shape [B, N, 3], where N means the number of points.
+ point_labels: the label of each point, shape [B, N].
+ margins: include points outside of the region but within the margin.
+ """
+
+ cucim_skimage, has_cucim = optional_import("cucim.skimage")
+
+ use_cp = has_cp and has_cucim and isinstance(img_pos, torch.Tensor) and img_pos.device != torch.device("cpu")
+ if use_cp:
+ img_pos_ = convert_to_cupy(img_pos.short()) # type: ignore
+ img_neg_ = convert_to_cupy(img_neg.short()) # type: ignore
+ label = cucim_skimage.measure.label
+ lib = cp
+ else:
+ if not has_measure:
+ raise RuntimeError("skimage.measure required.")
+ img_pos_, *_ = convert_data_type(img_pos, np.ndarray)
+ img_neg_, *_ = convert_data_type(img_neg, np.ndarray)
+ # for skimage.measure.label, the input must be bool type
+ if img_pos_.dtype != bool or img_neg_.dtype != bool:
+ raise ValueError("img_pos and img_neg must be bool type.")
+ label = measure.label
+ lib = np
+
+ features_pos, _ = label(img_pos_, connectivity=3, return_num=True)
+ features_neg, _ = label(img_neg_, connectivity=3, return_num=True)
+
+ outs = np.zeros_like(img_pos_)
+ for bs in range(point_coords.shape[0]):
+ for i, p in enumerate(point_coords[bs]):
+ if point_labels[bs, i] in pos_val:
+ features = features_pos
+ elif point_labels[bs, i] in neg_val:
+ features = features_neg
+ else:
+ # if -1 padding point, skip
+ continue
+ for margin in range(margins):
+ if isinstance(p, np.ndarray):
+ x, y, z = np.round(p).astype(int).tolist()
+ else:
+ x, y, z = p.float().round().int().tolist()
+ l, r = max(x - margin, 0), min(x + margin + 1, features.shape[-3])
+ t, d = max(y - margin, 0), min(y + margin + 1, features.shape[-2])
+ f, b = max(z - margin, 0), min(z + margin + 1, features.shape[-1])
+ if (features[bs, 0, l:r, t:d, f:b] > 0).any():
+ index = features[bs, 0, l:r, t:d, f:b].max()
+ outs[[bs]] += lib.isin(features[[bs]], index)
+ break
+ outs[outs > 1] = 1
+ return convert_to_dst_type(outs, dst=img_pos, dtype=outs.dtype)[0]
+
+
+def keep_components_with_positive_points(
+ img: torch.Tensor, point_coords: torch.Tensor, point_labels: torch.Tensor
+) -> torch.Tensor:
+ """
+ Keep connected regions that include the positive points. Used for point-only inference postprocessing to remove
+ regions without positive points.
+ Args:
+ img: [1, B, H, W, D]. Output prediction from VISTA3D. Values are pre-sigmoid logits and may contain NaN values.
+ point_coords: [B, N, 3]. Point click coordinates
+ point_labels: [B, N]. Point click labels.
+ """
+ if not has_measure:
+ raise RuntimeError("skimage.measure required.")
+ outs = torch.zeros_like(img)
+ for c in range(len(point_coords)):
+ if not ((point_labels[c] == 3).any() or (point_labels[c] == 1).any()):
+ # skip if no positive points.
+ continue
+ coords = point_coords[c, point_labels[c] == 3].tolist() + point_coords[c, point_labels[c] == 1].tolist()
+ not_nan_mask = ~torch.isnan(img[0, c])
+ img_ = torch.nan_to_num(img[0, c] > 0, 0)
+ img_, *_ = convert_data_type(img_, np.ndarray) # type: ignore
+ label = measure.label
+ features = label(img_, connectivity=3)
+ pos_mask = torch.from_numpy(img_).to(img.device) > 0
+ # move the connected-component label map to the same device as the input image
+ features = torch.from_numpy(features).to(img.device)
+ # generate a map with all pos points
+ idx = []
+ for p in coords:
+ idx.append(features[round(p[0]), round(p[1]), round(p[2])].item())
+ idx = list(set(idx))
+ for i in idx:
+ if i == 0:
+ continue
+ outs[0, c] += features == i
+ outs = outs > 0
+ # find negative mean value
+ fill_in = img[0, c][torch.logical_and(~outs[0, c], not_nan_mask)].mean()
+ img[0, c][torch.logical_and(pos_mask, ~outs[0, c])] = fill_in
+ return img
+
+
+def convert_points_to_disc(
+ image_size: Sequence[int], point: Tensor, point_label: Tensor, radius: int = 2, disc: bool = False
+):
+ """
+ Convert 3D point coordinates into an image mask. The returned mask has the same spatial
+ size as `image_size`, and its batch dimension matches that of `point`.
+ Each point is converted to a ball with the radius defined by `radius`. The output
+ contains two channels: the first for negative points and the second for positive points.
+
+ Args:
+ image_size: The output size of the converted mask. It should be a 3D tuple.
+ point: [B, N, 3], 3D point coordinates.
+ point_label: [B, N], 0 or 2 means negative points, 1 or 3 means positive points.
+ radius: disc ball radius size.
+ disc: If True, use a hard disc; otherwise use a Gaussian falloff.
+ """
+ masks = torch.zeros([point.shape[0], 2, image_size[0], image_size[1], image_size[2]], device=point.device)
+ _array = [
+ torch.arange(start=0, end=image_size[i], step=1, dtype=torch.float32, device=point.device) for i in range(3)
+ ]
+ coord_rows, coord_cols, coord_z = torch.meshgrid(_array[0], _array[1], _array[2])
+ # [1, 3, h, w, d] -> [b, 2, 3, h, w, d]
+ coords = unsqueeze_left(torch.stack((coord_rows, coord_cols, coord_z), dim=0), 6)
+ coords = coords.repeat(point.shape[0], 2, 1, 1, 1, 1)
+ for b, n in np.ndindex(*point.shape[:2]):
+ point_bn = unsqueeze_right(point[b, n], 4)
+ if point_label[b, n] > -1:
+ channel = 0 if (point_label[b, n] == 0 or point_label[b, n] == 2) else 1
+ pow_diff = torch.pow(coords[b, channel] - point_bn, 2)
+ if disc:
+ masks[b, channel] += pow_diff.sum(0) < radius**2
+ else:
+ masks[b, channel] += torch.exp(-pow_diff.sum(0) / (2 * radius**2))
+ return masks
+
+
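A brief sketch of turning point clicks into the two-channel mask described above (coordinates and labels are illustrative):

import torch
from monai.transforms.utils import convert_points_to_disc

points = torch.tensor([[[8.0, 8.0, 8.0], [2.0, 2.0, 2.0]]])  # [B=1, N=2, 3]
labels = torch.tensor([[1, 0]])                               # first point positive, second negative
mask = convert_points_to_disc(image_size=(16, 16, 16), point=points, point_label=labels, radius=2)
print(mask.shape)  # torch.Size([1, 2, 16, 16, 16]); channel 0 = negative points, channel 1 = positive points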
+def sample_points_from_label(
+ labels: Tensor,
+ label_set: Sequence[int],
+ max_ppoint: int = 1,
+ max_npoint: int = 0,
+ device: torch.device | str | None = "cpu",
+ use_center: bool = False,
+):
+ """Sample points from labels.
+
+ Args:
+ labels: [1, 1, H, W, D]
+ label_set: local index, must match values in labels.
+ max_ppoint: maximum positive point samples.
+ max_npoint: maximum negative point samples.
+ device: returned tensor device.
+ use_center: whether to sample points from center.
+
+ Returns:
+ point: point coordinates of shape [B, N, 3], where B equals the length of label_set.
+ point_label: [B, N], 0 for negative points, 1 for positive points, and -1 for padding.
+ """
+ if not labels.shape[0] == 1:
+ raise ValueError("labels must have batch size 1.")
+
+ if device is None:
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ labels = labels[0, 0]
+ unique_labels = labels.unique().cpu().numpy().tolist()
+ _point = []
+ _point_label = []
+ for id in label_set:
+ if id in unique_labels:
+ plabels = labels == int(id)
+ nlabels = ~plabels
+ _plabels = get_largest_connected_component_mask(erode(plabels.unsqueeze(0).unsqueeze(0))[0, 0])
+ plabelpoints = torch.nonzero(_plabels).to(device)
+ if len(plabelpoints) == 0:
+ plabelpoints = torch.nonzero(plabels).to(device)
+ nlabelpoints = torch.nonzero(nlabels).to(device)
+ num_p = min(len(plabelpoints), max_ppoint)
+ num_n = min(len(nlabelpoints), max_npoint)
+ pad = max_ppoint + max_npoint - num_p - num_n
+ if use_center:
+ pmean = plabelpoints.float().mean(0)
+ pdis = ((plabelpoints - pmean) ** 2).sum(-1)
+ _, sorted_indices_tensor = torch.sort(pdis)
+ sorted_indices = sorted_indices_tensor.cpu().tolist()
+ else:
+ sorted_indices = list(range(len(plabelpoints)))
+ random.shuffle(sorted_indices)
+ _point.append(
+ torch.stack(
+ [plabelpoints[sorted_indices[i]] for i in range(num_p)]
+ + random.choices(nlabelpoints, k=num_n)
+ + [torch.tensor([0, 0, 0], device=device)] * pad
+ )
+ )
+ _point_label.append(torch.tensor([1] * num_p + [0] * num_n + [-1] * pad).to(device))
+ else:
+ # pad the background labels
+ _point.append(torch.zeros(max_ppoint + max_npoint, 3).to(device))
+ _point_label.append(torch.zeros(max_ppoint + max_npoint).to(device) - 1)
+ point = torch.stack(_point)
+ point_label = torch.stack(_point_label)
+
+ return point, point_label
+
+
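A sketch of sampling points from a toy label volume (assumes scikit-image is available for the connected-component step used internally; shapes and values are illustrative):

import torch
from monai.transforms.utils import sample_points_from_label

labels = torch.zeros(1, 1, 32, 32, 32, dtype=torch.long)  # [1, 1, H, W, D]
labels[0, 0, 4:12, 4:12, 4:12] = 1
labels[0, 0, 20:28, 20:28, 20:28] = 2

point, point_label = sample_points_from_label(labels, label_set=[1, 2], max_ppoint=2, max_npoint=1)
print(point.shape, point_label.shape)  # torch.Size([2, 3, 3]) torch.Size([2, 3])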
def remove_small_objects(
img: NdarrayTensor,
min_size: int = 64,
@@ -1528,7 +1864,7 @@ class Fourier:
"""
@staticmethod
- def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
+ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int, as_contiguous: bool = False) -> NdarrayOrTensor:
"""
Applies fourier transform and shifts the zero-frequency component to the
center of the spectrum. Only the spatial dimensions get transformed.
@@ -1536,6 +1872,7 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
Args:
x: Image to transform.
spatial_dims: Number of spatial dimensions.
+ as_contiguous: Whether to make the returned NumPy array or PyTorch tensor contiguous in memory.
Returns
k: K-space data.
@@ -1550,10 +1887,12 @@ def shift_fourier(x: NdarrayOrTensor, spatial_dims: int) -> NdarrayOrTensor:
k = np.fft.fftshift(np.fft.fftn(x.cpu().numpy(), axes=dims), axes=dims)
else:
k = np.fft.fftshift(np.fft.fftn(x, axes=dims), axes=dims)
- return k
+ return ascontiguousarray(k) if as_contiguous else k
@staticmethod
- def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None) -> NdarrayOrTensor:
+ def inv_shift_fourier(
+ k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None = None, as_contiguous: bool = False
+ ) -> NdarrayOrTensor:
"""
Applies inverse shift and fourier transform. Only the spatial
dimensions are transformed.
@@ -1561,6 +1900,7 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None
Args:
k: K-space data.
spatial_dims: Number of spatial dimensions.
+ as_contiguous: Whether to make the returned NumPy array or PyTorch tensor contiguous in memory.
Returns:
x: Tensor in image space.
@@ -1575,7 +1915,7 @@ def inv_shift_fourier(k: NdarrayOrTensor, spatial_dims: int, n_dims: int | None
out = np.fft.ifftn(np.fft.ifftshift(k.cpu().numpy(), axes=dims), axes=dims).real
else:
out = np.fft.ifftn(np.fft.ifftshift(k, axes=dims), axes=dims).real
- return out
+ return ascontiguousarray(out) if as_contiguous else out
def get_number_image_type_conversions(transform: Compose, test_data: Any, key: Hashable | None = None) -> int:
@@ -2150,7 +2490,7 @@ def distance_transform_edt(
if return_distances:
dtype = torch.float64 if float64_distances else torch.float32
if distances is None:
- distances = torch.zeros_like(img, dtype=dtype) # type: ignore
+ distances = torch.zeros_like(img, memory_format=torch.contiguous_format, dtype=dtype) # type: ignore
else:
if not isinstance(distances, torch.Tensor) and distances.device != img.device:
raise TypeError("distances must be a torch.Tensor on the same device as img")
@@ -2179,6 +2519,7 @@ def distance_transform_edt(
block_params=block_params,
float64_distances=float64_distances,
)
+ torch.cuda.synchronize()
else:
if not has_ndimage:
raise RuntimeError("scipy.ndimage required if cupy is not available")
@@ -2212,7 +2553,7 @@ def distance_transform_edt(
r_vals = []
if return_distances and distances_original is None:
- r_vals.append(distances)
+ r_vals.append(distances_ if use_cp else distances)
if return_indices and indices_original is None:
r_vals.append(indices)
if not r_vals:
@@ -2221,5 +2562,26 @@ def distance_transform_edt(
return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0]
+def apply_affine_to_points(data: torch.Tensor, affine: torch.Tensor, dtype: DtypeLike | torch.dtype | None = None):
+ """
+ Apply an affine transformation to a set of points.
+
+ Args:
+ data: input data to apply affine transformation, should be a tensor of shape (C, N, 2 or 3),
+ where C represents the number of channels and N denotes the number of points.
+ affine: affine matrix to be applied, should be a tensor of shape (3, 3) or (4, 4).
+ dtype: output data dtype.
+ """
+ data_: torch.Tensor = convert_to_tensor(data, track_meta=False, dtype=torch.float64)
+ affine = to_affine_nd(data_.shape[-1], affine)
+
+ homogeneous: torch.Tensor = concatenate((data_, torch.ones((data_.shape[0], data_.shape[1], 1), device=data_.device)), axis=2) # type: ignore
+ transformed_homogeneous = torch.matmul(homogeneous, affine.T)
+ transformed_coordinates = transformed_homogeneous[:, :, :-1]
+ out, *_ = convert_to_dst_type(transformed_coordinates, data, dtype=dtype)
+
+ return out
+
+
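The helper works in homogeneous coordinates: points are extended with a trailing 1, multiplied by the transposed affine, and the extra coordinate is dropped. A small 2D sketch (values are illustrative):

import torch
from monai.transforms.utils import apply_affine_to_points

points = torch.tensor([[[10.0, 20.0], [30.0, 40.0]]])  # (C=1, N=2, 2)
affine = torch.tensor([[1.0, 0.0, 5.0],
                       [0.0, 1.0, -5.0],
                       [0.0, 0.0, 1.0]])               # translate by (+5, -5)
out = apply_affine_to_points(points, affine, dtype=torch.float32)
print(out)  # tensor([[[15., 15.], [35., 35.]]])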
if __name__ == "__main__":
print_transform_backends()
diff --git a/monai/transforms/utils_create_transform_ims.py b/monai/transforms/utils_create_transform_ims.py
index 4b5990abd3..a29fd4dbf9 100644
--- a/monai/transforms/utils_create_transform_ims.py
+++ b/monai/transforms/utils_create_transform_ims.py
@@ -269,11 +269,9 @@ def update_docstring(code_path, transform_name):
def pre_process_data(data, ndim, is_map, is_post):
- """If transform requires 2D data, then convert to 2D"""
+ """If transform requires 2D data, then convert to 2D by selecting the middle of the last dimension."""
if ndim == 2:
- for k in keys:
- data[k] = data[k][..., data[k].shape[-1] // 2]
-
+ data = {k: v[..., v.shape[-1] // 2] for k, v in data.items()}
if is_map:
return data
return data[CommonKeys.LABEL] if is_post else data[CommonKeys.IMAGE]
diff --git a/monai/transforms/utils_morphological_ops.py b/monai/transforms/utils_morphological_ops.py
new file mode 100644
index 0000000000..61d3c5b858
--- /dev/null
+++ b/monai/transforms/utils_morphological_ops.py
@@ -0,0 +1,172 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+
+import torch
+import torch.nn.functional as F
+from torch import Tensor
+
+from monai.config import NdarrayOrTensor
+from monai.utils import convert_data_type, convert_to_dst_type, ensure_tuple_rep
+
+__all__ = ["erode", "dilate"]
+
+
+def erode(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> NdarrayOrTensor:
+ """
+ Erode 2D/3D binary mask.
+
+ Args:
+ mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.
+ filter_size: erosion filter size, has to be odd numbers, default to be 3.
+ pad_value: the value used to pad the input before filtering, so that the output keeps
+ the same size as the input. The default value should normally be kept.
+
+ Return:
+ eroded mask, same shape and data type as input.
+
+ Example:
+
+ .. code-block:: python
+
+ # define a naive mask
+ mask = torch.zeros(3,2,3,3,3)
+ mask[:,:,1,1,1] = 1.0
+ filter_size = 3
+ erode_result = erode(mask, filter_size) # expect torch.zeros(3,2,3,3,3)
+ dilate_result = dilate(mask, filter_size) # expect torch.ones(3,2,3,3,3)
+ """
+ mask_t, *_ = convert_data_type(mask, torch.Tensor)
+ res_mask_t = erode_t(mask_t, filter_size=filter_size, pad_value=pad_value)
+ res_mask: NdarrayOrTensor
+ res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask)
+ return res_mask
+
+
+def dilate(mask: NdarrayOrTensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> NdarrayOrTensor:
+ """
+ Dilate 2D/3D binary mask.
+
+ Args:
+ mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.
+ filter_size: dilation filter size, has to be odd numbers, default to be 3.
+ pad_value: the value used to pad the input before filtering, so that the output keeps
+ the same size as the input. The default value should normally be kept.
+
+ Return:
+ dilated mask, same shape and data type as input.
+
+ Example:
+
+ .. code-block:: python
+
+ # define a naive mask
+ mask = torch.zeros(3,2,3,3,3)
+ mask[:,:,1,1,1] = 1.0
+ filter_size = 3
+ erode_result = erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3)
+ dilate_result = dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3)
+ """
+ mask_t, *_ = convert_data_type(mask, torch.Tensor)
+ res_mask_t = dilate_t(mask_t, filter_size=filter_size, pad_value=pad_value)
+ res_mask: NdarrayOrTensor
+ res_mask, *_ = convert_to_dst_type(src=res_mask_t, dst=mask)
+ return res_mask
+
+
+def get_morphological_filter_result_t(mask_t: Tensor, filter_size: int | Sequence[int], pad_value: float) -> Tensor:
+ """
+ Apply a morphological filter to a 2D/3D binary mask tensor.
+
+ Args:
+ mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
+ filter_size: morphological filter size, has to be odd numbers.
+ pad_value: the value used to pad the input before filtering, so that the output keeps
+ the same size as the input.
+
+ Return:
+ Tensor: Morphological filter result mask, same shape as input.
+ """
+ spatial_dims = len(mask_t.shape) - 2
+ if spatial_dims not in [2, 3]:
+ raise ValueError(
+ f"spatial_dims must be either 2 or 3, "
+ f"got spatial_dims={spatial_dims} for mask tensor with shape of {mask_t.shape}."
+ )
+
+ # Define the structuring element
+ filter_size = ensure_tuple_rep(filter_size, spatial_dims)
+ if any(size % 2 == 0 for size in filter_size):
+ raise ValueError(f"All dimensions in filter_size must be odd numbers, got {filter_size}.")
+
+ structuring_element = torch.ones((mask_t.shape[1], mask_t.shape[1]) + filter_size).to(mask_t.device)
+
+ # Pad the input tensor to handle border pixels
+ # Calculate padding size
+ pad_size = [size // 2 for size in filter_size for _ in range(2)]
+
+ input_padded = F.pad(mask_t.float(), pad_size, mode="constant", value=pad_value)
+
+ # Apply filter operation
+ conv_fn = F.conv2d if spatial_dims == 2 else F.conv3d
+ output = conv_fn(input_padded, structuring_element, padding=0) / torch.sum(structuring_element[0, ...])
+
+ return output
+
+
+def erode_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 1.0) -> Tensor:
+ """
+ Erode 2D/3D binary mask with data type as torch tensor.
+
+ Args:
+ mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
+ filter_size: erosion filter size, has to be odd numbers, default to be 3.
+ pad_value: the value used to pad the input before filtering, so that the output keeps
+ the same size as the input. The default value should normally be kept.
+
+ Return:
+ Tensor: eroded mask, same shape as input.
+ """
+
+ output = get_morphological_filter_result_t(mask_t, filter_size, pad_value)
+
+ # Set output values based on the minimum value within the structuring element
+ output = torch.where(torch.abs(output - 1.0) < 1e-7, 1.0, 0.0)
+
+ return output
+
+
+def dilate_t(mask_t: Tensor, filter_size: int | Sequence[int] = 3, pad_value: float = 0.0) -> Tensor:
+ """
+ Dilate 2D/3D binary mask with data type as torch tensor.
+
+ Args:
+ mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
+ filter_size: dilation filter size, has to be odd numbers, default to be 3.
+ pad_value: the value used to pad the input before filtering, so that the output keeps
+ the same size as the input. The default value should normally be kept.
+
+ Return:
+ Tensor: dilated mask, same shape as input.
+ """
+ output = get_morphological_filter_result_t(mask_t, filter_size, pad_value)
+
+ # Set output to one wherever any value within the structuring element is positive
+ output = torch.where(output > 0, 1.0, 0.0)
+
+ return output
diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py
index 0774d50314..365bd1eab5 100644
--- a/monai/transforms/utils_pytorch_numpy_unification.py
+++ b/monai/transforms/utils_pytorch_numpy_unification.py
@@ -52,9 +52,24 @@
"median",
"mean",
"std",
+ "softplus",
]
+def softplus(x: NdarrayOrTensor) -> NdarrayOrTensor:
+ """stable softplus through `np.logaddexp` with equivalent implementation for torch.
+
+ Args:
+ x: array/tensor.
+
+ Returns:
+ Softplus of the input.
+ """
+ if isinstance(x, np.ndarray):
+ return np.logaddexp(np.zeros_like(x), x)
+ return torch.logaddexp(torch.zeros_like(x), x)
+
+
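Why logaddexp: softplus(x) = log(1 + exp(x)) overflows for large x if computed naively, while log(exp(0) + exp(x)) evaluated via logaddexp stays finite. A tiny check:

import numpy as np
import torch
from monai.transforms.utils_pytorch_numpy_unification import softplus

x = np.array([-1000.0, 0.0, 1000.0])
print(softplus(x))                    # [0.0, 0.6931..., 1000.0] with no overflow warnings
print(softplus(torch.as_tensor(x)))   # same values as a torch tensor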
def allclose(a: NdarrayTensor, b: NdarrayOrTensor, rtol=1e-5, atol=1e-8, equal_nan=False) -> bool:
"""`np.allclose` with equivalent implementation for torch."""
b, *_ = convert_to_dst_type(b, a, wrap_sequence=True)
@@ -73,7 +88,7 @@ def moveaxis(x: NdarrayOrTensor, src: int | Sequence[int], dst: int | Sequence[i
def in1d(x, y):
"""`np.in1d` with equivalent implementation for torch."""
if isinstance(x, np.ndarray):
- return np.in1d(x, y)
+ return np.isin(x, y)
return (x[..., None] == torch.tensor(y, device=x.device)).any(-1).view(-1)
@@ -465,7 +480,7 @@ def max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe
else:
ret = torch.max(x, int(dim), **kwargs) # type: ignore
- return ret
+ return ret[0] if isinstance(ret, tuple) else ret
def mean(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:
@@ -531,7 +546,7 @@ def min(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe
else:
ret = torch.min(x, int(dim), **kwargs) # type: ignore
- return ret
+ return ret[0] if isinstance(ret, tuple) else ret
def std(x: NdarrayTensor, dim: int | tuple | None = None, unbiased: bool = False) -> NdarrayTensor:
diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py
index 2c32eb2cf4..8f2f400b5d 100644
--- a/monai/utils/__init__.py
+++ b/monai/utils/__init__.py
@@ -11,8 +11,6 @@
from __future__ import annotations
-# have to explicitly bring these in here to resolve circular import issues
-from .aliases import alias, resolve_name
from .component_store import ComponentStore
from .decorators import MethodReplacer, RestartGenerator
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
@@ -40,6 +38,7 @@
GridSamplePadMode,
HoVerNetBranch,
HoVerNetMode,
+ IgniteInfo,
InterpolateMode,
JITMetadataKeys,
LazyAttr,
@@ -79,6 +78,7 @@
ensure_tuple_size,
fall_back_tuple,
first,
+ flatten_dict,
get_seed,
has_option,
is_immutable,
@@ -107,9 +107,9 @@
InvalidPyTorchVersionError,
OptionalImportError,
allow_missing_reference,
+ compute_capabilities_after,
damerau_levenshtein_distance,
exact_version,
- export,
get_full_type_name,
get_package_version,
get_torch_version_tuple,
@@ -126,6 +126,7 @@
version_leq,
)
from .nvtx import Range
+from .ordering import Ordering
from .profiling import (
PerfContext,
ProfileHandler,
@@ -147,7 +148,10 @@
dtype_numpy_to_torch,
dtype_torch_to_numpy,
get_dtype,
+ get_dtype_string,
get_equivalent_dtype,
get_numpy_dtype_from_string,
get_torch_dtype_from_string,
)
+
+# have to explicitly bring these in here to resolve circular import issues
diff --git a/monai/utils/aliases.py b/monai/utils/aliases.py
deleted file mode 100644
index 2974eec2eb..0000000000
--- a/monai/utils/aliases.py
+++ /dev/null
@@ -1,103 +0,0 @@
-# Copyright (c) MONAI Consortium
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-# http://www.apache.org/licenses/LICENSE-2.0
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""
-This module is written for configurable workflow, not currently in use.
-"""
-
-from __future__ import annotations
-
-import importlib
-import inspect
-import sys
-import threading
-
-alias_lock = threading.RLock()
-GlobalAliases = {}
-
-__all__ = ["alias", "resolve_name"]
-
-
-def alias(*names):
- """
- Stores the decorated function or class in the global aliases table under the given names and as the `__aliases__`
- member of the decorated object. This new member will contain all alias names declared for that object.
- """
-
- def _outer(obj):
- for n in names:
- with alias_lock:
- GlobalAliases[n] = obj
-
- # set the member list __aliases__ to contain the alias names defined by the decorator for `obj`
- obj.__aliases__ = getattr(obj, "__aliases__", ()) + tuple(names)
-
- return obj
-
- return _outer
-
-
-def resolve_name(name):
- """
- Search for the declaration (function or class) with the given name. This will first search the list of aliases to
- see if it was declared with this aliased name, then search treating `name` as a fully qualified name, then search
- the loaded modules for one having a declaration with the given name. If no declaration is found, raise ValueError.
-
- Raises:
- ValueError: When the module is not found.
- ValueError: When the module does not have the specified member.
- ValueError: When multiple modules with the declaration name are found.
- ValueError: When no module with the specified member is found.
-
- """
- # attempt to resolve an alias
- with alias_lock:
- obj = GlobalAliases.get(name)
-
- if name in GlobalAliases and obj is None:
- raise AssertionError
-
- # attempt to resolve a qualified name
- if obj is None and "." in name:
- modname, declname = name.rsplit(".", 1)
-
- try:
- mod = importlib.import_module(modname)
- obj = getattr(mod, declname, None)
- except ModuleNotFoundError as not_found_err:
- raise ValueError(f"Module {modname!r} not found.") from not_found_err
-
- if obj is None:
- raise ValueError(f"Module {modname!r} does not have member {declname!r}.")
-
- # attempt to resolve a simple name
- if obj is None:
- # Get all modules having the declaration/import, need to check here that getattr returns something which doesn't
- # equate to False since in places __getattr__ returns 0 incorrectly:
- # https://github.com/tensorflow/tensorboard/blob/a22566561d2b4fea408755a951ac9eaf3a156f8e/
- # tensorboard/compat/tensorflow_stub/pywrap_tensorflow.py#L35
- mods = [m for m in list(sys.modules.values()) if getattr(m, name, None)]
-
- if len(mods) > 0: # found modules with this declaration or import
- if len(mods) > 1: # found multiple modules, need to determine if ambiguous or just multiple imports
- foundmods = set(filter(None, {inspect.getmodule(getattr(m, name)) for m in mods})) # resolve imports
-
- if len(foundmods) > 1: # found multiple declarations with the same name
- modnames = [m.__name__ for m in foundmods]
- msg = f"Multiple modules ({modnames!r}) with declaration name {name!r} found, resolution is ambiguous."
- raise ValueError(msg)
- mods = list(foundmods)
-
- obj = getattr(mods[0], name)
-
- if obj is None:
- raise ValueError(f"No module with member {name!r} found.")
-
- return obj
diff --git a/monai/utils/component_store.py b/monai/utils/component_store.py
index d1e71eaebf..bf0d632ddd 100644
--- a/monai/utils/component_store.py
+++ b/monai/utils/component_store.py
@@ -12,9 +12,10 @@
from __future__ import annotations
from collections import namedtuple
+from collections.abc import Iterable
from keyword import iskeyword
from textwrap import dedent, indent
-from typing import Any, Callable, Iterable, TypeVar
+from typing import Any, Callable, TypeVar
T = TypeVar("T")
diff --git a/monai/utils/decorators.py b/monai/utils/decorators.py
index 1c064468e8..a784510c64 100644
--- a/monai/utils/decorators.py
+++ b/monai/utils/decorators.py
@@ -15,7 +15,8 @@
__all__ = ["RestartGenerator", "MethodReplacer"]
-from typing import Callable, Generator
+from collections.abc import Generator
+from typing import Callable
class RestartGenerator:
diff --git a/monai/utils/dist.py b/monai/utils/dist.py
index 2418b43591..47da2bee6e 100644
--- a/monai/utils/dist.py
+++ b/monai/utils/dist.py
@@ -11,20 +11,15 @@
from __future__ import annotations
-import sys
import warnings
from collections.abc import Callable
from logging import Filter
-
-if sys.version_info >= (3, 8):
- from typing import Literal
-
-from typing import overload
+from typing import Literal, overload
import torch
import torch.distributed as dist
-from monai.config import IgniteInfo
+from monai.utils.enums import IgniteInfo
from monai.utils.module import min_version, optional_import
idist, has_ignite = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")
diff --git a/monai/utils/enums.py b/monai/utils/enums.py
index b786e92151..1fbf3ffa05 100644
--- a/monai/utils/enums.py
+++ b/monai/utils/enums.py
@@ -15,8 +15,6 @@
from enum import Enum
from typing import TYPE_CHECKING
-from monai.config import IgniteInfo
-from monai.utils import deprecated
from monai.utils.module import min_version, optional_import
__all__ = [
@@ -56,13 +54,13 @@
"DataStatsKeys",
"ImageStatsKeys",
"LabelStatsKeys",
- "AlgoEnsembleKeys",
"HoVerNetMode",
"HoVerNetBranch",
"LazyAttr",
"BundleProperty",
"BundlePropertyConfig",
"AlgoKeys",
+ "IgniteInfo",
]
@@ -91,14 +89,6 @@ def __repr__(self):
return self.value
-if TYPE_CHECKING:
- from ignite.engine import EventEnum
-else:
- EventEnum, _ = optional_import(
- "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base"
- )
-
-
class NumpyPadMode(StrEnum):
"""
See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
@@ -543,6 +533,7 @@ class MetaKeys(StrEnum):
SPATIAL_SHAPE = "spatial_shape" # optional key for the length in each spatial dimension
SPACE = "space" # possible values of space type are defined in `SpaceKeys`
ORIGINAL_CHANNEL_DIM = "original_channel_dim" # an integer or float("nan")
+ SAVED_TO = "saved_to"
class ColorOrder(StrEnum):
@@ -614,17 +605,6 @@ class LabelStatsKeys(StrEnum):
LABEL_NCOMP = "ncomponents"
-@deprecated(since="1.2", removed="1.4", msg_suffix="please use `AlgoKeys` instead.")
-class AlgoEnsembleKeys(StrEnum):
- """
- Default keys for Mixed Ensemble
- """
-
- ID = "identifier"
- ALGO = "infer_algo"
- SCORE = "best_metric"
-
-
class HoVerNetMode(StrEnum):
"""
Modes for HoVerNet model:
@@ -729,6 +709,35 @@ class AdversarialKeys(StrEnum):
DISCRIMINATOR_LOSS = "discriminator_loss"
+class OrderingType(StrEnum):
+ RASTER_SCAN = "raster_scan"
+ S_CURVE = "s_curve"
+ RANDOM = "random"
+
+
+class OrderingTransformations(StrEnum):
+ ROTATE_90 = "rotate_90"
+ TRANSPOSE = "transpose"
+ REFLECT = "reflect"
+
+
+class IgniteInfo(StrEnum):
+ """
+ Config information of the PyTorch ignite package.
+
+ """
+
+ OPT_IMPORT_VERSION = "0.4.11"
+
+
+if TYPE_CHECKING:
+ from ignite.engine import EventEnum
+else:
+ EventEnum, _ = optional_import(
+ "ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum", as_type="base"
+ )
+
+
class AdversarialIterationEvents(EventEnum):
"""
Keys used to define events as used in the AdversarialTrainer.
@@ -745,15 +754,3 @@ class AdversarialIterationEvents(EventEnum):
DISCRIMINATOR_LOSS_COMPLETED = "discriminator_loss_completed"
DISCRIMINATOR_BACKWARD_COMPLETED = "discriminator_backward_completed"
DISCRIMINATOR_MODEL_COMPLETED = "discriminator_model_completed"
-
-
-class OrderingType(StrEnum):
- RASTER_SCAN = "raster_scan"
- S_CURVE = "s_curve"
- RANDOM = "random"
-
-
-class OrderingTransformations(StrEnum):
- ROTATE_90 = "rotate_90"
- TRANSPOSE = "transpose"
- REFLECT = "reflect"
diff --git a/monai/utils/jupyter_utils.py b/monai/utils/jupyter_utils.py
index 7dcd0e62cd..b1b43a6767 100644
--- a/monai/utils/jupyter_utils.py
+++ b/monai/utils/jupyter_utils.py
@@ -24,7 +24,7 @@
import numpy as np
import torch
-from monai.config import IgniteInfo
+from monai.utils import IgniteInfo
from monai.utils.module import min_version, optional_import
try:
diff --git a/monai/utils/misc.py b/monai/utils/misc.py
index c30eb0904d..b96a48ad7e 100644
--- a/monai/utils/misc.py
+++ b/monai/utils/misc.py
@@ -24,7 +24,6 @@
import warnings
from ast import literal_eval
from collections.abc import Callable, Iterable, Sequence
-from distutils.util import strtobool
from math import log10
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
@@ -78,6 +77,25 @@
"run_cmd",
]
+
+def _strtobool(val: str) -> bool:
+ """
+ Replaces the `distutils.util.strtobool` function,
+ which was removed in Python 3.12.
+
+ True values are y, yes, t, true, on and 1;
+ False values are n, no, f, false, off and 0.
+ Raises ValueError if val is anything else.
+ """
+ val = val.lower()
+ if val in ("y", "yes", "t", "true", "on", "1"):
+ return True
+ elif val in ("n", "no", "f", "false", "off", "0"):
+ return False
+ else:
+ raise ValueError(f"invalid truth value {val}")
+
+
_seed = None
_flag_deterministic = torch.backends.cudnn.deterministic
_flag_cudnn_benchmark = torch.backends.cudnn.benchmark
@@ -100,6 +118,7 @@ def star_zip_with(op, *vals):
T = TypeVar("T")
+NT = TypeVar("NT", np.ndarray, torch.Tensor)
@overload
@@ -400,7 +419,7 @@ def _parse_var(s):
d[key] = literal_eval(value)
except ValueError:
try:
- d[key] = bool(strtobool(str(value)))
+ d[key] = bool(_strtobool(str(value)))
except ValueError:
d[key] = value
return d
@@ -527,7 +546,7 @@ def doc_images() -> str | None:
@staticmethod
def algo_hash() -> str | None:
- return os.environ.get("MONAI_ALGO_HASH", "c51bc6a")
+ return os.environ.get("MONAI_ALGO_HASH", "e4cf5a1")
@staticmethod
def trace_transform() -> str | None:
@@ -796,7 +815,7 @@ def __init__(self, input_unit: str, target_unit: str) -> None:
"Both input and target units should be from the same quantity. "
f"Input quantity is {input_base} while target quantity is {target_base}"
)
- self._calculate_conversion_factor()
+ self.conversion_factor = self._calculate_conversion_factor()
def _get_valid_unit_and_base(self, unit):
unit = str(unit).lower()
@@ -823,7 +842,7 @@ def _calculate_conversion_factor(self):
return 1.0
input_power = self._get_unit_power(self.input_unit)
target_power = self._get_unit_power(self.target_unit)
- self.conversion_factor = 10 ** (input_power - target_power)
+ return 10 ** (input_power - target_power)
def __call__(self, value: int | float) -> Any:
return float(value) * self.conversion_factor
@@ -868,7 +887,7 @@ def run_cmd(cmd_list: list[str], **kwargs: Any) -> subprocess.CompletedProcess:
if kwargs.pop("run_cmd_verbose", False):
import monai
- monai.apps.utils.get_logger("run_cmd").info(f"{cmd_list}")
+ monai.apps.utils.get_logger("monai.utils.run_cmd").info(f"{cmd_list}") # type: ignore[attr-defined]
try:
return subprocess.run(cmd_list, **kwargs)
except subprocess.CalledProcessError as e:
@@ -889,11 +908,24 @@ def is_sqrt(num: Sequence[int] | int) -> bool:
return ensure_tuple(ret) == num
-def unsqueeze_right(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor:
+def unsqueeze_right(arr: NT, ndim: int) -> NT:
"""Append 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
return arr[(...,) + (None,) * (ndim - arr.ndim)]
-def unsqueeze_left(arr: NdarrayOrTensor, ndim: int) -> NdarrayOrTensor:
+def unsqueeze_left(arr: NT, ndim: int) -> NT:
"""Prepend 1-sized dimensions to `arr` to create a result with `ndim` dimensions."""
return arr[(None,) * (ndim - arr.ndim)]
+
+
+def flatten_dict(metrics: dict[str, Any]) -> dict[str, Any]:
+ """
+ Flatten the nested dictionary to a flat dictionary.
+ """
+ result = {}
+ for key, value in metrics.items():
+ if isinstance(value, dict):
+ result.update(flatten_dict(value))
+ else:
+ result[key] = value
+ return result
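A sketch of the flattening behaviour: only leaf values are kept and the nesting keys are discarded, so colliding leaf names would overwrite one another (the metric names are illustrative):

from monai.utils import flatten_dict

metrics = {"val": {"dice": 0.85, "loss": 0.12}, "epoch": 3}
print(flatten_dict(metrics))  # {'dice': 0.85, 'loss': 0.12, 'epoch': 3}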
diff --git a/monai/utils/module.py b/monai/utils/module.py
index 6f301d8067..d3f2ff09f2 100644
--- a/monai/utils/module.py
+++ b/monai/utils/module.py
@@ -18,14 +18,14 @@
import re
import sys
import warnings
-from collections.abc import Callable, Collection, Hashable, Mapping
+from collections.abc import Callable, Collection, Hashable, Iterable, Mapping
from functools import partial, wraps
from importlib import import_module
from pkgutil import walk_packages
from pydoc import locate
from re import match
from types import FunctionType, ModuleType
-from typing import Any, Iterable, cast
+from typing import Any, cast
import torch
@@ -43,13 +43,11 @@
"InvalidPyTorchVersionError",
"OptionalImportError",
"exact_version",
- "export",
"damerau_levenshtein_distance",
"look_up_option",
"min_version",
"optional_import",
"require_pkg",
- "load_submodules",
"instantiate",
"get_full_type_name",
"get_package_version",
@@ -172,28 +170,6 @@ def damerau_levenshtein_distance(s1: str, s2: str) -> int:
return d[string_1_length - 1, string_2_length - 1]
-def export(modname):
- """
- Make the decorated object a member of the named module. This will also add the object under its aliases if it has
- a `__aliases__` member, thus this decorator should be before the `alias` decorator to pick up those names. Alias
- names which conflict with package names or existing members will be ignored.
- """
-
- def _inner(obj):
- mod = import_module(modname)
- if not hasattr(mod, obj.__name__):
- setattr(mod, obj.__name__, obj)
-
- # add the aliases for `obj` to the target module
- for alias in getattr(obj, "__aliases__", ()):
- if not hasattr(mod, alias):
- setattr(mod, alias, obj)
-
- return obj
-
- return _inner
-
-
def load_submodules(
basemod: ModuleType, load_all: bool = True, exclude_pattern: str = "(.*[tT]est.*)|(_.*)"
) -> tuple[list[ModuleType], list[str]]:
@@ -209,8 +185,11 @@ def load_submodules(
if (is_pkg or load_all) and name not in sys.modules and match(exclude_pattern, name) is None:
try:
mod = import_module(name)
- importer.find_spec(name).loader.load_module(name) # type: ignore
- submodules.append(mod)
+ mod_spec = importer.find_spec(name) # type: ignore
+ if mod_spec and mod_spec.loader:
+ loader = mod_spec.loader
+ loader.exec_module(mod)
+ submodules.append(mod)
except OptionalImportError:
pass # could not import the optional deps., they are ignored
except ImportError as e:
@@ -561,7 +540,7 @@ def version_leq(lhs: str, rhs: str) -> bool:
"""
lhs, rhs = str(lhs), str(rhs)
- pkging, has_ver = optional_import("pkg_resources", name="packaging")
+ pkging, has_ver = optional_import("packaging.Version")
if has_ver:
try:
return cast(bool, pkging.version.Version(lhs) <= pkging.version.Version(rhs))
@@ -588,7 +567,8 @@ def version_geq(lhs: str, rhs: str) -> bool:
"""
lhs, rhs = str(lhs), str(rhs)
- pkging, has_ver = optional_import("pkg_resources", name="packaging")
+ pkging, has_ver = optional_import("packaging.Version")
+
if has_ver:
try:
return cast(bool, pkging.version.Version(lhs) >= pkging.version.Version(rhs))
@@ -626,7 +606,7 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st
if current_ver_string is None:
_env_var = os.environ.get("PYTORCH_VER", "")
current_ver_string = _env_var if _env_var else torch.__version__
- ver, has_ver = optional_import("pkg_resources", name="parse_version")
+ ver, has_ver = optional_import("packaging.version", name="parse")
if has_ver:
return ver(".".join((f"{major}", f"{minor}", f"{patch}"))) <= ver(f"{current_ver_string}") # type: ignore
parts = f"{current_ver_string}".split("+", 1)[0].split(".", 3)
@@ -654,3 +634,44 @@ def pytorch_after(major: int, minor: int, patch: int = 0, current_ver_string: st
if is_prerelease:
return False
return True
+
+
+@functools.lru_cache(None)
+def compute_capabilities_after(major: int, minor: int = 0, current_ver_string: str | None = None) -> bool:
+ """
+ Compute whether the current system GPU CUDA compute capability is after or equal to the specified version.
+ The current system GPU CUDA compute capability is determined by the first GPU in the system.
+ The compared version is a string in the form of "major.minor".
+
+ Args:
+ major: major version number to be compared with.
+ minor: minor version number to be compared with. Defaults to 0.
+ current_ver_string: if None, the current system GPU CUDA compute capability will be used.
+
+ Returns:
+ True if the current system GPU CUDA compute capability is greater than or equal to the specified version.
+ """
+ if current_ver_string is None:
+ cuda_available = torch.cuda.is_available()
+ pynvml, has_pynvml = optional_import("pynvml")
+ if not has_pynvml: # assume the user has an Ampere or later GPU
+ return True
+ if not cuda_available:
+ return False
+ else:
+ pynvml.nvmlInit()
+ handle = pynvml.nvmlDeviceGetHandleByIndex(0) # get the first GPU
+ major_c, minor_c = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
+ current_ver_string = f"{major_c}.{minor_c}"
+ pynvml.nvmlShutdown()
+
+ ver, has_ver = optional_import("packaging.version", name="parse")
+ if has_ver:
+ return ver(".".join((f"{major}", f"{minor}"))) <= ver(f"{current_ver_string}") # type: ignore
+ parts = f"{current_ver_string}".split("+", 1)[0].split(".", 2)
+ while len(parts) < 2:
+ parts += ["0"]
+ c_major, c_minor = parts[:2]
+ c_mn = int(c_major), int(c_minor)
+ mn = int(major), int(minor)
+ return c_mn >= mn
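A brief sketch of the version comparison (passing an explicit current_ver_string avoids any GPU query; without it the first visible GPU is queried via pynvml):

from monai.utils.module import compute_capabilities_after

print(compute_capabilities_after(8, 0, current_ver_string="8.6"))  # True, 8.6 >= 8.0
print(compute_capabilities_after(9, 0, current_ver_string="8.6"))  # False
# compute_capabilities_after(7, 0)  # would query the first visible GPU, if any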
diff --git a/monai/utils/ordering.py b/monai/utils/ordering.py
new file mode 100644
index 0000000000..1be61f98ab
--- /dev/null
+++ b/monai/utils/ordering.py
@@ -0,0 +1,207 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import numpy as np
+
+from monai.utils.enums import OrderingTransformations, OrderingType
+
+
+class Ordering:
+ """
+ Ordering class that projects a 2D or 3D image into a 1D sequence. It also allows the image to be transformed with
+ one of the following transformations:
+ Reflection (see np.flip for more details).
+ Transposition (see np.transpose for more details).
+ 90-degree rotation (see np.rot90 for more details).
+
+ The transformations are applied in the order specified by the transformation_order parameter.
+
+ Args:
+ ordering_type: The ordering type. One of the following:
+ - 'raster_scan': The image is projected into a 1D sequence by scanning the image from left to right and from
+            top to bottom. Also called row-major ordering.
+            - 's_curve': The image is projected into a 1D sequence by scanning the image in a snake-like
+            pattern, traversing rows alternately from left to right and from right to left.
+            - 'random': The image is projected into a 1D sequence by randomly shuffling the image.
+ spatial_dims: The number of spatial dimensions of the image.
+ dimensions: The dimensions of the image.
+ reflected_spatial_dims: A tuple of booleans indicating whether to reflect the image along each spatial dimension.
+ transpositions_axes: A tuple of tuples indicating the axes to transpose the image along.
+ rot90_axes: A tuple of tuples indicating the axes to rotate the image along.
+ transformation_order: The order in which to apply the transformations.
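+
+    Example (a minimal illustrative sketch; assumes a single-channel 2x3 image flattened to length 6)::
+
+        import numpy as np
+
+        ordering = Ordering(ordering_type="raster_scan", spatial_dims=2, dimensions=(1, 2, 3))
+        flat = np.arange(6)  # flattened 2x3 image
+        seq = ordering(flat)  # pixels re-ordered into a 1D sequence
+        restored = seq[ordering.get_revert_sequence_ordering()]  # undo the ordering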
+ """
+
+ def __init__(
+ self,
+ ordering_type: str,
+ spatial_dims: int,
+ dimensions: tuple[int, int, int] | tuple[int, int, int, int],
+ reflected_spatial_dims: tuple[bool, bool] | None = None,
+ transpositions_axes: tuple[tuple[int, int], ...] | tuple[tuple[int, int, int], ...] | None = None,
+ rot90_axes: tuple[tuple[int, int], ...] | None = None,
+ transformation_order: tuple[str, ...] = (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ ) -> None:
+ super().__init__()
+ self.ordering_type = ordering_type
+
+ if self.ordering_type not in list(OrderingType):
+ raise ValueError(
+ f"ordering_type must be one of the following {list(OrderingType)}, but got {self.ordering_type}."
+ )
+
+ self.spatial_dims = spatial_dims
+ self.dimensions = dimensions
+
+ if len(dimensions) != self.spatial_dims + 1:
+ raise ValueError(f"dimensions must be of length {self.spatial_dims + 1}, but got {len(dimensions)}.")
+
+ self.reflected_spatial_dims = reflected_spatial_dims
+ self.transpositions_axes = transpositions_axes
+ self.rot90_axes = rot90_axes
+ if len(set(transformation_order)) != len(transformation_order):
+ raise ValueError(f"No duplicates are allowed. Received {transformation_order}.")
+
+ for transformation in transformation_order:
+ if transformation not in list(OrderingTransformations):
+ raise ValueError(
+ f"Valid transformations are {list(OrderingTransformations)} but received {transformation}."
+ )
+ self.transformation_order = transformation_order
+
+ self.template = self._create_template()
+ self._sequence_ordering = self._create_ordering()
+ self._revert_sequence_ordering = np.argsort(self._sequence_ordering)
+
+ def __call__(self, x: np.ndarray) -> np.ndarray:
+ x = x[self._sequence_ordering]
+
+ return x
+
+ def get_sequence_ordering(self) -> np.ndarray:
+ return self._sequence_ordering
+
+ def get_revert_sequence_ordering(self) -> np.ndarray:
+ return self._revert_sequence_ordering
+
+ def _create_ordering(self) -> np.ndarray:
+ self.template = self._transform_template()
+ order = self._order_template(template=self.template)
+
+ return order
+
+ def _create_template(self) -> np.ndarray:
+ spatial_dimensions = self.dimensions[1:]
+ template = np.arange(np.prod(spatial_dimensions)).reshape(*spatial_dimensions)
+
+ return template
+
+ def _transform_template(self) -> np.ndarray:
+ for transformation in self.transformation_order:
+ if transformation == OrderingTransformations.TRANSPOSE.value:
+ self.template = self._transpose_template(template=self.template)
+ elif transformation == OrderingTransformations.ROTATE_90.value:
+ self.template = self._rot90_template(template=self.template)
+ elif transformation == OrderingTransformations.REFLECT.value:
+ self.template = self._flip_template(template=self.template)
+
+ return self.template
+
+ def _transpose_template(self, template: np.ndarray) -> np.ndarray:
+ if self.transpositions_axes is not None:
+ for axes in self.transpositions_axes:
+ template = np.transpose(template, axes=axes)
+
+ return template
+
+ def _flip_template(self, template: np.ndarray) -> np.ndarray:
+ if self.reflected_spatial_dims is not None:
+ for axis, to_reflect in enumerate(self.reflected_spatial_dims):
+ template = np.flip(template, axis=axis) if to_reflect else template
+
+ return template
+
+ def _rot90_template(self, template: np.ndarray) -> np.ndarray:
+ if self.rot90_axes is not None:
+ for axes in self.rot90_axes:
+ template = np.rot90(template, axes=axes)
+
+ return template
+
+ def _order_template(self, template: np.ndarray) -> np.ndarray:
+ depths = None
+ if self.spatial_dims == 2:
+ rows, columns = template.shape[0], template.shape[1]
+ else:
+ rows, columns, depths = (template.shape[0], template.shape[1], template.shape[2])
+
+        sequence = getattr(self, f"{self.ordering_type}_idx")(rows, columns, depths)
+
+ ordering = np.array([template[tuple(e)] for e in sequence])
+
+ return ordering
+
+ @staticmethod
+ def raster_scan_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray:
+ idx: list[tuple] = []
+
+ for r in range(rows):
+ for c in range(cols):
+ if depths is not None:
+ for d in range(depths):
+ idx.append((r, c, d))
+ else:
+ idx.append((r, c))
+
+ idx_np = np.array(idx)
+
+ return idx_np
+
+ @staticmethod
+ def s_curve_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray:
+ idx: list[tuple] = []
+
+ for r in range(rows):
+ col_idx = range(cols) if r % 2 == 0 else range(cols - 1, -1, -1)
+ for c in col_idx:
+ if depths:
+ depth_idx = range(depths) if c % 2 == 0 else range(depths - 1, -1, -1)
+
+ for d in depth_idx:
+ idx.append((r, c, d))
+ else:
+ idx.append((r, c))
+
+ idx_np = np.array(idx)
+
+ return idx_np
+
+ @staticmethod
+ def random_idx(rows: int, cols: int, depths: int | None = None) -> np.ndarray:
+ idx: list[tuple] = []
+
+ for r in range(rows):
+ for c in range(cols):
+ if depths:
+ for d in range(depths):
+ idx.append((r, c, d))
+ else:
+ idx.append((r, c))
+
+ idx_np = np.array(idx)
+ np.random.shuffle(idx_np)
+
+ return idx_np
diff --git a/monai/utils/state_cacher.py b/monai/utils/state_cacher.py
index d37e7abde4..60a074544b 100644
--- a/monai/utils/state_cacher.py
+++ b/monai/utils/state_cacher.py
@@ -15,8 +15,9 @@
import os
import pickle
import tempfile
+from collections.abc import Hashable
from types import ModuleType
-from typing import Any, Hashable
+from typing import Any
import torch
from torch.serialization import DEFAULT_PROTOCOL
diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py
index e4f97fc4a6..420e935b33 100644
--- a/monai/utils/type_conversion.py
+++ b/monai/utils/type_conversion.py
@@ -33,6 +33,7 @@
"get_equivalent_dtype",
"convert_data_type",
"get_dtype",
+ "get_dtype_string",
"convert_to_cupy",
"convert_to_numpy",
"convert_to_tensor",
@@ -102,6 +103,13 @@ def get_dtype(data: Any) -> DtypeLike | torch.dtype:
return type(data)
+def get_dtype_string(dtype: DtypeLike | torch.dtype) -> str:
+ """Get a string representation of the dtype."""
+ if isinstance(dtype, torch.dtype):
+        return str(dtype)[6:]  # e.g. "torch.float32" -> "float32"
+ return str(dtype)[3:]
+
+
def convert_to_tensor(
data: Any,
dtype: DtypeLike | torch.dtype = None,
diff --git a/monai/visualize/class_activation_maps.py b/monai/visualize/class_activation_maps.py
index 6d1e8dfd03..489a563818 100644
--- a/monai/visualize/class_activation_maps.py
+++ b/monai/visualize/class_activation_maps.py
@@ -290,7 +290,7 @@ def __init__(
)
self.fc_layers = fc_layers
- def compute_map(self, x, class_idx=None, layer_idx=-1, **kwargs):
+ def compute_map(self, x, class_idx=None, layer_idx=-1, **kwargs): # type: ignore[override]
logits, acti, _ = self.nn_module(x, **kwargs)
acti = acti[layer_idx]
if class_idx is None:
@@ -302,7 +302,7 @@ def compute_map(self, x, class_idx=None, layer_idx=-1, **kwargs):
output = torch.stack([output[i, b : b + 1] for i, b in enumerate(class_idx)], dim=0)
return output.reshape(b, 1, *spatial) # resume the spatial dims on the selected class
- def __call__(self, x, class_idx=None, layer_idx=-1, **kwargs):
+ def __call__(self, x, class_idx=None, layer_idx=-1, **kwargs): # type: ignore[override]
"""
Compute the activation map with upsampling and postprocessing.
@@ -361,7 +361,7 @@ class GradCAM(CAMBase):
"""
- def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1, **kwargs):
+ def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1, **kwargs): # type: ignore[override]
_, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph, **kwargs)
acti, grad = acti[layer_idx], grad[layer_idx]
b, c, *spatial = grad.shape
@@ -369,7 +369,7 @@ def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1, **kwa
acti_map = (weights * acti).sum(1, keepdim=True)
return F.relu(acti_map)
- def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False, **kwargs):
+ def __call__(self, x, class_idx=None, layer_idx=-1, retain_graph=False, **kwargs): # type: ignore[override]
"""
Compute the activation map with upsampling and postprocessing.
@@ -401,7 +401,7 @@ class GradCAMpp(GradCAM):
"""
- def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1, **kwargs):
+ def compute_map(self, x, class_idx=None, retain_graph=False, layer_idx=-1, **kwargs): # type: ignore[override]
_, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph, **kwargs)
acti, grad = acti[layer_idx], grad[layer_idx]
b, c, *spatial = grad.shape
diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py
index e7884e9b1f..677640bd04 100644
--- a/monai/visualize/img2tensorboard.py
+++ b/monai/visualize/img2tensorboard.py
@@ -176,7 +176,9 @@ def plot_2d_or_3d_image(
# as the `d` data has no batch dim, reduce the spatial dim index if positive
frame_dim = frame_dim - 1 if frame_dim > 0 else frame_dim
- d: np.ndarray = data_index.detach().cpu().numpy() if isinstance(data_index, torch.Tensor) else data_index
+ d: np.ndarray = (
+ data_index.detach().cpu().numpy() if isinstance(data_index, torch.Tensor) else np.asarray(data_index)
+ )
if d.ndim == 2:
d = rescale_array(d, 0, 1) # type: ignore
diff --git a/monai/visualize/utils.py b/monai/visualize/utils.py
index f6718fe7a5..88c9a0d66a 100644
--- a/monai/visualize/utils.py
+++ b/monai/visualize/utils.py
@@ -24,11 +24,9 @@
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type
if TYPE_CHECKING:
- from matplotlib import cm
from matplotlib import pyplot as plt
else:
plt, _ = optional_import("matplotlib", name="pyplot")
- cm, _ = optional_import("matplotlib", name="cm")
__all__ = ["matshow3d", "blend_images"]
@@ -210,7 +208,7 @@ def blend_images(
image = repeat(image, 3, axis=0)
def get_label_rgb(cmap: str, label: NdarrayOrTensor) -> NdarrayOrTensor:
- _cmap = cm.get_cmap(cmap)
+ _cmap = plt.colormaps.get_cmap(cmap)
label_np, *_ = convert_data_type(label, np.ndarray)
label_rgb_np = _cmap(label_np[0])
label_rgb_np = np.moveaxis(label_rgb_np, -1, 0)[:3]
diff --git a/pyproject.toml b/pyproject.toml
index cd8a510b04..9dc9cf619b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,6 +4,7 @@ requires = [
"setuptools",
"torch>=1.9",
"ninja",
+ "packaging"
]
[tool.black]
@@ -38,8 +39,18 @@ exclude = "monai/bundle/__main__.py"
[tool.ruff]
line-length = 133
-ignore-init-module-imports = true
-ignore = ["F401", "E741"]
+target-version = "py39"
+
+[tool.ruff.lint]
+select = [
+ "E", "F", "W", # flake8
+ "NPY", # NumPy specific rules
+]
+extend-ignore = [
+ "E741", # ambiguous variable name
+ "F401", # unused import
+ "NPY002", # numpy-legacy-random
+]
[tool.pytype]
# Space-separated list of files or directories to exclude.
diff --git a/requirements-dev.txt b/requirements-dev.txt
index f7f9a6db45..bffe304df4 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -2,7 +2,7 @@
-r requirements-min.txt
pytorch-ignite==0.4.11
gdown>=4.7.3
-scipy>=1.7.1
+scipy>=1.12.0; python_version >= '3.9'
itk>=5.2
nibabel
pillow!=8.3.0 # https://github.com/python-pillow/Pillow/issues/5571
@@ -11,7 +11,7 @@ scikit-image>=0.19.0
tqdm>=4.47.0
lmdb
flake8>=3.8.1
-flake8-bugbear
+flake8-bugbear<=24.2.6 # https://github.com/Project-MONAI/MONAI/issues/7690
flake8-comprehensions
mccabe
pep8-naming
@@ -21,8 +21,8 @@ black>=22.12
isort>=5.1
ruff
pytype>=2020.6.1; platform_system != "Windows"
-types-pkg_resources
-mypy>=1.5.0
+types-setuptools
+mypy>=1.5.0, <1.12.0
ninja
torchio
torchvision
@@ -34,10 +34,10 @@ tifffile; platform_system == "Linux" or platform_system == "Darwin"
pandas
requests
einops
-transformers>=4.36.0
-mlflow>=1.28.0
+transformers>=4.36.0, <4.41.0; python_version <= '3.10'
+mlflow>=2.12.2
clearml>=1.10.0rc0
-matplotlib!=3.5.0
+matplotlib>=3.6.3
tensorboardX
types-PyYAML
pyyaml
@@ -47,14 +47,18 @@ pynrrd
pre-commit
pydicom
h5py
-nni; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine
+nni==2.10.1; platform_system == "Linux" and "arm" not in platform_machine and "aarch" not in platform_machine
optuna
git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
onnx>=1.13.0
onnxruntime; python_version <= '3.10'
typeguard<3 # https://github.com/microsoft/nni/issues/5457
-filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523
+filelock<3.12.0 # https://github.com/microsoft/nni/issues/5523
zarr
lpips==0.1.4
nvidia-ml-py
huggingface_hub
+pyamg>=5.0.0
+git+https://github.com/facebookresearch/segment-anything.git@6fdee8f2727f4506cfbbe553e23b895e27956588
+onnx_graphsurgeon
+polygraphy
diff --git a/requirements-min.txt b/requirements-min.txt
index ad0bb1ef20..21cf9d5e5c 100644
--- a/requirements-min.txt
+++ b/requirements-min.txt
@@ -1,5 +1,7 @@
# Requirements for minimal tests
-r requirements.txt
-setuptools>=50.3.0,<66.0.0,!=60.6.0
+setuptools>=50.3.0,<66.0.0,!=60.6.0 ; python_version < "3.12"
+setuptools>=70.2.0; python_version >= "3.12"
coverage>=5.5
parameterized
+packaging
diff --git a/requirements.txt b/requirements.txt
index 1569646794..e184322c13 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
torch>=1.9
-numpy>=1.20
+numpy>=1.24,<2.0
diff --git a/runtests.sh b/runtests.sh
index 0b3e20ce49..65e3a2bb6b 100755
--- a/runtests.sh
+++ b/runtests.sh
@@ -167,7 +167,7 @@ function clang_format {
}
function is_pip_installed() {
- return $("${PY_EXE}" -c "import sys, pkgutil; sys.exit(0 if pkgutil.find_loader(sys.argv[1]) else 1)" $1)
+ return $("${PY_EXE}" -c "import sys, importlib.util; sys.exit(0 if importlib.util.find_spec(sys.argv[1]) else 1)" $1)
}
function clean_py {
diff --git a/setup.cfg b/setup.cfg
index 9e7a8fdada..ecfd717aff 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -21,9 +21,9 @@ classifiers =
Intended Audience :: Healthcare Industry
Programming Language :: C++
Programming Language :: Python :: 3
- Programming Language :: Python :: 3.8
Programming Language :: Python :: 3.9
Programming Language :: Python :: 3.10
+ Programming Language :: Python :: 3.11
Topic :: Scientific/Engineering
Topic :: Scientific/Engineering :: Artificial Intelligence
Topic :: Scientific/Engineering :: Medical Science Apps.
@@ -33,23 +33,24 @@ classifiers =
Typing :: Typed
[options]
-python_requires = >= 3.8
+python_requires = >= 3.9
# for compiling and develop setup only
# no need to specify the versions so that we could
# compile for multiple targeted versions.
setup_requires =
torch
ninja
+ packaging
install_requires =
torch>=1.9
- numpy>=1.20
+ numpy>=1.24,<2.0
[options.extras_require]
all =
nibabel
ninja
scikit-image>=0.14.2
- scipy>=1.7.1
+ scipy>=1.12.0; python_version >= '3.9'
pillow
tensorboard
gdown>=4.7.3
@@ -66,10 +67,10 @@ all =
imagecodecs
pandas
einops
- transformers<4.22; python_version <= '3.10'
- mlflow>=1.28.0
+ transformers>=4.36.0, <4.41.0; python_version <= '3.10'
+ mlflow>=2.12.2
clearml>=1.10.0rc0
- matplotlib
+ matplotlib>=3.6.3
tensorboardX
pyyaml
fire
@@ -85,6 +86,7 @@ all =
lpips==0.1.4
nvidia-ml-py
huggingface_hub
+ pyamg>=5.0.0
nibabel =
nibabel
ninja =
@@ -92,7 +94,7 @@ ninja =
skimage =
scikit-image>=0.14.2
scipy =
- scipy>=1.7.1
+ scipy>=1.12.0; python_version >= '3.9'
pillow =
pillow!=8.3.0
tensorboard =
@@ -126,11 +128,11 @@ pandas =
einops =
einops
transformers =
- transformers<4.22; python_version <= '3.10'
+ transformers>=4.36.0, <4.41.0; python_version <= '3.10'
mlflow =
- mlflow
+ mlflow>=2.12.2
matplotlib =
- matplotlib
+ matplotlib>=3.6.3
clearml =
clearml
tensorboardX =
@@ -139,6 +141,8 @@ pyyaml =
pyyaml
fire =
fire
+packaging =
+ packaging
jsonschema =
jsonschema
pynrrd =
@@ -160,11 +164,18 @@ lpips =
lpips==0.1.4
pynvml =
nvidia-ml-py
+polygraphy =
+ polygraphy
+
# # workaround https://github.com/Project-MONAI/MONAI/issues/5882
# MetricsReloaded =
-# MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
+ # MetricsReloaded @ git+https://github.com/Project-MONAI/MetricsReloaded@monai-support#egg=MetricsReloaded
huggingface_hub =
huggingface_hub
+pyamg =
+ pyamg>=5.0.0
+# segment-anything =
+# segment-anything @ git+https://github.com/facebookresearch/segment-anything@6fdee8f2727f4506cfbbe553e23b895e27956588#egg=segment-anything
[flake8]
select = B,C,E,F,N,P,T4,W,B9
diff --git a/setup.py b/setup.py
index b90d9d0976..576743c1f7 100644
--- a/setup.py
+++ b/setup.py
@@ -17,7 +17,7 @@
import sys
import warnings
-import pkg_resources
+from packaging import version
from setuptools import find_packages, setup
import versioneer
@@ -40,7 +40,7 @@
BUILD_CUDA = FORCE_CUDA or (torch.cuda.is_available() and (CUDA_HOME is not None))
- _pt_version = pkg_resources.parse_version(torch.__version__).release
+ _pt_version = version.parse(torch.__version__).release
if _pt_version is None or len(_pt_version) < 3:
raise AssertionError("unknown torch version")
TORCH_VERSION = int(_pt_version[0]) * 10000 + int(_pt_version[1]) * 100 + int(_pt_version[2])
diff --git a/tests/hvd_evenly_divisible_all_gather.py b/tests/hvd_evenly_divisible_all_gather.py
index 78c6ca06bc..732ad13b83 100644
--- a/tests/hvd_evenly_divisible_all_gather.py
+++ b/tests/hvd_evenly_divisible_all_gather.py
@@ -30,10 +30,10 @@ def test_data(self):
self._run()
def _run(self):
- if hvd.rank() == 0:
- data1 = torch.tensor([[1, 2], [3, 4]])
- data2 = torch.tensor([[1.0, 2.0]])
- data3 = torch.tensor(7)
+ # if hvd.rank() == 0:
+ data1 = torch.tensor([[1, 2], [3, 4]])
+ data2 = torch.tensor([[1.0, 2.0]])
+ data3 = torch.tensor(7)
if hvd.rank() == 1:
data1 = torch.tensor([[5, 6]])
diff --git a/tests/min_tests.py b/tests/min_tests.py
index 8128bb7b84..f39d3f9843 100644
--- a/tests/min_tests.py
+++ b/tests/min_tests.py
@@ -154,6 +154,7 @@ def run_testsuit():
"test_plot_2d_or_3d_image",
"test_png_rw",
"test_prepare_batch_default",
+ "test_prepare_batch_diffusion",
"test_prepare_batch_extra_input",
"test_prepare_batch_hovernet",
"test_rand_grid_patch",
@@ -185,6 +186,7 @@ def run_testsuit():
"test_torchvisiond",
"test_transchex",
"test_transformerblock",
+ "test_trt_compile",
"test_unetr",
"test_unetr_block",
"test_vit",
@@ -208,6 +210,9 @@ def run_testsuit():
"test_zarr_avg_merger",
"test_perceptual_loss",
"test_ultrasound_confidence_map_transform",
+ "test_vista3d_utils",
+ "test_vista3d_transforms",
+ "test_matshow3d",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"
diff --git a/tests/ngc_bundle_download.py b/tests/ngc_bundle_download.py
index 01dc044870..107114861c 100644
--- a/tests/ngc_bundle_download.py
+++ b/tests/ngc_bundle_download.py
@@ -127,7 +127,7 @@ def test_loading_mmar(self, item):
in_channels=1,
img_size=(96, 96, 96),
patch_size=(16, 16, 16),
- pos_embed="conv",
+ proj_type="conv",
hidden_size=768,
mlp_dim=3072,
)
diff --git a/tests/nonconfig_workflow.py b/tests/nonconfig_workflow.py
index 7b5328bf72..b2c44c12c6 100644
--- a/tests/nonconfig_workflow.py
+++ b/tests/nonconfig_workflow.py
@@ -36,8 +36,8 @@ class NonConfigWorkflow(BundleWorkflow):
"""
- def __init__(self, filename, output_dir):
- super().__init__(workflow_type="inference")
+ def __init__(self, filename, output_dir, meta_file=None, logging_file=None):
+ super().__init__(workflow_type="inference", meta_file=meta_file, logging_file=logging_file)
self.filename = filename
self.output_dir = output_dir
self._bundle_root = "will override"
diff --git a/tests/test_affine_transform.py b/tests/test_affine_transform.py
index 6ea036bce8..11464070e0 100644
--- a/tests/test_affine_transform.py
+++ b/tests/test_affine_transform.py
@@ -133,28 +133,17 @@ def test_to_norm_affine_ill(self, affine, src_size, dst_size, align_corners):
class TestAffineTransform(unittest.TestCase):
- def test_affine_shift(self):
- affine = torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]])
- image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])
- out = AffineTransform(align_corners=False)(image, affine)
- out = out.detach().cpu().numpy()
- expected = [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]]
- np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)
-
- def test_affine_shift_1(self):
- affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]])
- image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])
- out = AffineTransform(align_corners=False)(image, affine)
- out = out.detach().cpu().numpy()
- expected = [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]]
- np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)
-
- def test_affine_shift_2(self):
- affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
+ @parameterized.expand(
+ [
+ (torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]]), [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]]),
+ (torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]), [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]]),
+ (torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]), [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]]),
+ ]
+ )
+ def test_affine_transforms(self, affine, expected):
image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])
out = AffineTransform(align_corners=False)(image, affine)
out = out.detach().cpu().numpy()
- expected = [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]]
np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)
def test_zoom(self):
diff --git a/tests/test_apply_transform_to_points.py b/tests/test_apply_transform_to_points.py
new file mode 100644
index 0000000000..0c16603996
--- /dev/null
+++ b/tests/test_apply_transform_to_points.py
@@ -0,0 +1,81 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.data import MetaTensor
+from monai.transforms.utility.array import ApplyTransformToPoints
+from monai.utils import set_determinism
+
+set_determinism(seed=0)
+
+DATA_2D = torch.rand(1, 64, 64)
+DATA_3D = torch.rand(1, 64, 64, 64)
+POINT_2D_WORLD = torch.tensor([[[2, 2], [2, 4], [4, 6]]])
+POINT_2D_IMAGE = torch.tensor([[[1, 1], [1, 2], [2, 3]]])
+POINT_2D_IMAGE_RAS = torch.tensor([[[-1, -1], [-1, -2], [-2, -3]]])
+POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]])
+POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]])
+POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]])
+AFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
+AFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])
+
+TEST_CASES = [
+ [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE],
+ [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD],
+ [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD],
+ [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, True, POINT_2D_IMAGE_RAS],
+ [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],
+ [
+ MetaTensor(DATA_3D, affine=AFFINE_2),
+ MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2),
+ None,
+ False,
+ False,
+ POINT_3D_WORLD,
+ ],
+ [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],
+]
+
+TEST_CASES_WRONG = [
+ [POINT_2D_WORLD, True, None],
+ [POINT_2D_WORLD.unsqueeze(0), False, None],
+ [POINT_3D_WORLD[..., 0:1], False, None],
+ [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])],
+]
+
+
+class TestCoordinateTransform(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_transform_coordinates(self, image, points, affine, invert_affine, affine_lps_to_ras, expected_output):
+ transform = ApplyTransformToPoints(
+ dtype=torch.int64, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
+ )
+ affine = image.affine if image is not None else None
+ output = transform(points, affine)
+ self.assertTrue(torch.allclose(output, expected_output))
+ invert_out = transform.inverse(output)
+ self.assertTrue(torch.allclose(invert_out, points))
+
+ @parameterized.expand(TEST_CASES_WRONG)
+ def test_wrong_input(self, input, invert_affine, affine):
+ transform = ApplyTransformToPoints(dtype=torch.int64, invert_affine=invert_affine)
+ with self.assertRaises(ValueError):
+ transform(input, affine)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_apply_transform_to_pointsd.py b/tests/test_apply_transform_to_pointsd.py
new file mode 100644
index 0000000000..978113931c
--- /dev/null
+++ b/tests/test_apply_transform_to_pointsd.py
@@ -0,0 +1,185 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.data import MetaTensor
+from monai.transforms.utility.dictionary import ApplyTransformToPointsd
+from monai.utils import set_determinism
+
+set_determinism(seed=0)
+
+DATA_2D = torch.rand(1, 64, 64)
+DATA_3D = torch.rand(1, 64, 64, 64)
+POINT_2D_WORLD = torch.tensor([[[2, 2], [2, 4], [4, 6]]])
+POINT_2D_IMAGE = torch.tensor([[[1, 1], [1, 2], [2, 3]]])
+POINT_2D_IMAGE_RAS = torch.tensor([[[-1, -1], [-1, -2], [-2, -3]]])
+POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]])
+POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]])
+POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]])
+AFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
+AFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])
+
+TEST_CASES = [
+ [MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE], # use image affine
+ [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD], # use point affine
+ [None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD], # use input affine
+ [None, POINT_2D_WORLD, AFFINE_1, True, False, POINT_2D_IMAGE], # use input affine
+ [
+ MetaTensor(DATA_2D, affine=AFFINE_1),
+ POINT_2D_WORLD,
+ None,
+ True,
+ True,
+ POINT_2D_IMAGE_RAS,
+ ], # test affine_lps_to_ras
+ [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],
+ ["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], # use refer_data itself
+ [
+ MetaTensor(DATA_3D, affine=AFFINE_2),
+ MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2),
+ None,
+ False,
+ False,
+ POINT_3D_WORLD,
+ ],
+ [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],
+ [MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],
+]
+TEST_CASES_SEQUENCE = [
+ [
+ (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
+ [POINT_2D_WORLD, POINT_3D_WORLD],
+ None,
+ True,
+ False,
+ ["image_1", "image_2"],
+ [POINT_2D_IMAGE, POINT_3D_IMAGE],
+ ], # use image affine
+ [
+ (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
+ [POINT_2D_WORLD, POINT_3D_WORLD],
+ None,
+ True,
+ True,
+ ["image_1", "image_2"],
+ [POINT_2D_IMAGE_RAS, POINT_3D_IMAGE_RAS],
+ ], # test affine_lps_to_ras
+ [
+ (None, None),
+ [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],
+ None,
+ False,
+ False,
+ None,
+ [POINT_2D_WORLD, POINT_3D_WORLD],
+ ], # use point affine
+ [
+ (None, None),
+ [POINT_2D_WORLD, POINT_2D_WORLD],
+ AFFINE_1,
+ True,
+ False,
+ None,
+ [POINT_2D_IMAGE, POINT_2D_IMAGE],
+ ], # use input affine
+ [
+ (MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
+ [MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],
+ None,
+ False,
+ False,
+ ["image_1", "image_2"],
+ [POINT_2D_WORLD, POINT_3D_WORLD],
+ ],
+]
+
+TEST_CASES_WRONG = [
+ [POINT_2D_WORLD, True, None, None],
+ [POINT_2D_WORLD.unsqueeze(0), False, None, None],
+ [POINT_3D_WORLD[..., 0:1], False, None, None],
+ [POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]]), None],
+ [POINT_3D_WORLD, False, None, "image"],
+ [POINT_3D_WORLD, False, None, []],
+]
+
+
+class TestCoordinateTransform(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_transform_coordinates(self, image, points, affine, invert_affine, affine_lps_to_ras, expected_output):
+ data = {
+ "image": image,
+ "point": points,
+ "affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]),
+ }
+ refer_keys = "image" if (image is not None and image != "affine") else image
+ transform = ApplyTransformToPointsd(
+ keys="point",
+ refer_keys=refer_keys,
+ dtype=torch.int64,
+ affine=affine,
+ invert_affine=invert_affine,
+ affine_lps_to_ras=affine_lps_to_ras,
+ )
+ output = transform(data)
+
+ self.assertTrue(torch.allclose(output["point"], expected_output))
+ invert_out = transform.inverse(output)
+ self.assertTrue(torch.allclose(invert_out["point"], points))
+
+ @parameterized.expand(TEST_CASES_SEQUENCE)
+ def test_transform_coordinates_sequences(
+ self, image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output
+ ):
+ data = {"image_1": image[0], "image_2": image[1], "point_1": points[0], "point_2": points[1]}
+ keys = ["point_1", "point_2"]
+ transform = ApplyTransformToPointsd(
+ keys=keys,
+ refer_keys=refer_keys,
+ dtype=torch.int64,
+ affine=affine,
+ invert_affine=invert_affine,
+ affine_lps_to_ras=affine_lps_to_ras,
+ )
+ output = transform(data)
+
+ self.assertTrue(torch.allclose(output["point_1"], expected_output[0]))
+ self.assertTrue(torch.allclose(output["point_2"], expected_output[1]))
+ invert_out = transform.inverse(output)
+ self.assertTrue(torch.allclose(invert_out["point_1"], points[0]))
+
+ @parameterized.expand(TEST_CASES_WRONG)
+ def test_wrong_input(self, input, invert_affine, affine, refer_keys):
+ if refer_keys == []:
+ with self.assertRaises(ValueError):
+ ApplyTransformToPointsd(
+ keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys
+ )
+ else:
+ transform = ApplyTransformToPointsd(
+ keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys
+ )
+ data = {"point": input}
+ if refer_keys == "image":
+ with self.assertRaises(KeyError):
+ transform(data)
+ else:
+ with self.assertRaises(ValueError):
+ transform(data)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_arraydataset.py b/tests/test_arraydataset.py
index efc014a267..03239a9764 100644
--- a/tests/test_arraydataset.py
+++ b/tests/test_arraydataset.py
@@ -40,8 +40,9 @@
class TestCompose(Compose):
+ __test__ = False # indicate to pytest that this class is not intended for collection
- def __call__(self, input_, lazy):
+ def __call__(self, input_, lazy=False):
img = self.transforms[0](input_)
metadata = img.meta
img = self.transforms[1](img)
diff --git a/tests/test_attentionunet.py b/tests/test_attentionunet.py
index 83f6cabc5e..6a577f763f 100644
--- a/tests/test_attentionunet.py
+++ b/tests/test_attentionunet.py
@@ -14,11 +14,17 @@
import unittest
import torch
+import torch.nn as nn
import monai.networks.nets.attentionunet as att
from tests.utils import skip_if_no_cuda, skip_if_quick
+def get_net_parameters(net: nn.Module) -> int:
+ """Returns the total number of parameters in a Module."""
+ return sum(param.numel() for param in net.parameters())
+
+
class TestAttentionUnet(unittest.TestCase):
def test_attention_block(self):
@@ -50,6 +56,20 @@ def test_attentionunet(self):
self.assertEqual(output.shape[0], input.shape[0])
self.assertEqual(output.shape[1], 2)
+ def test_attentionunet_kernel_size(self):
+ args_dict = {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 2,
+ "channels": (3, 4, 5),
+ "up_kernel_size": 5,
+ "strides": (1, 2),
+ }
+ model_a = att.AttentionUnet(**args_dict, kernel_size=5)
+ model_b = att.AttentionUnet(**args_dict, kernel_size=7)
+ self.assertEqual(get_net_parameters(model_a), 3534)
+ self.assertEqual(get_net_parameters(model_b), 5574)
+
@skip_if_no_cuda
def test_attentionunet_gpu(self):
for dims in [2, 3]:
diff --git a/tests/test_auto3dseg.py b/tests/test_auto3dseg.py
index e2097679e2..5273f0663a 100644
--- a/tests/test_auto3dseg.py
+++ b/tests/test_auto3dseg.py
@@ -123,6 +123,8 @@ class TestOperations(Operations):
Test example for user operation
"""
+ __test__ = False # indicate to pytest that this class is not intended for collection
+
def __init__(self) -> None:
self.data = {"max": np.max, "mean": np.mean, "min": np.min}
@@ -132,6 +134,8 @@ class TestAnalyzer(Analyzer):
Test example for a simple Analyzer
"""
+ __test__ = False # indicate to pytest that this class is not intended for collection
+
def __init__(self, key, report_format, stats_name="test"):
self.key = key
super().__init__(stats_name, report_format)
@@ -149,6 +153,8 @@ class TestImageAnalyzer(Analyzer):
Test example for a simple Analyzer
"""
+ __test__ = False # indicate to pytest that this class is not intended for collection
+
def __init__(self, image_key="image", stats_name="test_image"):
self.image_key = image_key
report_format = {"test_stats": None}
@@ -367,7 +373,6 @@ def test_filename_case_analyzer(self):
for batch_data in self.dataset:
d = transform(batch_data[0])
assert DataStatsKeys.BY_CASE_IMAGE_PATH in d
- assert DataStatsKeys.BY_CASE_IMAGE_PATH in d
def test_filename_case_analyzer_image_only(self):
analyzer_image = FilenameStats("image", DataStatsKeys.BY_CASE_IMAGE_PATH)
diff --git a/tests/test_autoencoderkl.py b/tests/test_autoencoderkl.py
new file mode 100644
index 0000000000..d15cb79084
--- /dev/null
+++ b/tests/test_autoencoderkl.py
@@ -0,0 +1,337 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+import tempfile
+import unittest
+from unittest import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.apps import download_url
+from monai.networks import eval_mode
+from monai.networks.nets import AutoencoderKL
+from monai.utils import optional_import
+from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, testing_data_config
+
+tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
+_, has_einops = optional_import("einops")
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+CASES_NO_ATTENTION = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, False),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ },
+ (1, 1, 16, 16),
+ (1, 1, 16, 16),
+ (1, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, False),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ },
+ (1, 1, 16, 16, 16),
+ (1, 1, 16, 16, 16),
+ (1, 4, 4, 4, 4),
+ ],
+]
+
+CASES_ATTENTION = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, False),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ },
+ (1, 1, 16, 16),
+ (1, 1, 16, 16),
+ (1, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, False),
+ "num_res_blocks": (1, 1, 2),
+ "norm_num_groups": 4,
+ },
+ (1, 1, 16, 16),
+ (1, 1, 16, 16),
+ (1, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, False),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ },
+ (1, 1, 16, 16),
+ (1, 1, 16, 16),
+ (1, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, False),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ "with_encoder_nonlocal_attn": False,
+ },
+ (1, 1, 16, 16),
+ (1, 1, 16, 16),
+ (1, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, True),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ },
+ (1, 1, 16, 16),
+ (1, 1, 16, 16),
+ (1, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, True),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ },
+ (1, 1, 16, 16, 16),
+ (1, 1, 16, 16, 16),
+ (1, 4, 4, 4, 4),
+ ],
+]
+
+if has_einops:
+ CASES = CASES_NO_ATTENTION + CASES_ATTENTION
+else:
+ CASES = CASES_NO_ATTENTION
+
+
+class TestAutoEncoderKL(unittest.TestCase):
+ @parameterized.expand(CASES)
+ def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape):
+ net = AutoencoderKL(**input_param).to(device)
+ with eval_mode(net):
+ result = net.forward(torch.randn(input_shape).to(device))
+ self.assertEqual(result[0].shape, expected_shape)
+ self.assertEqual(result[1].shape, expected_latent_shape)
+ self.assertEqual(result[2].shape, expected_latent_shape)
+
+ @parameterized.expand(CASES)
+ @SkipIfBeforePyTorchVersion((1, 11))
+ def test_shape_with_convtranspose_and_checkpointing(
+ self, input_param, input_shape, expected_shape, expected_latent_shape
+ ):
+ input_param = input_param.copy()
+ input_param.update({"use_checkpoint": True, "use_convtranspose": True})
+ net = AutoencoderKL(**input_param).to(device)
+ with eval_mode(net):
+ result = net.forward(torch.randn(input_shape).to(device))
+ self.assertEqual(result[0].shape, expected_shape)
+ self.assertEqual(result[1].shape, expected_latent_shape)
+ self.assertEqual(result[2].shape, expected_latent_shape)
+
+ def test_model_channels_not_multiple_of_norm_num_group(self):
+ with self.assertRaises(ValueError):
+ AutoencoderKL(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ channels=(24, 24, 24),
+ attention_levels=(False, False, False),
+ latent_channels=8,
+ num_res_blocks=1,
+ norm_num_groups=16,
+ )
+
+ def test_model_num_channels_not_same_size_of_attention_levels(self):
+ with self.assertRaises(ValueError):
+ AutoencoderKL(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ channels=(24, 24, 24),
+ attention_levels=(False, False),
+ latent_channels=8,
+ num_res_blocks=1,
+ norm_num_groups=16,
+ )
+
+ def test_model_num_channels_not_same_size_of_num_res_blocks(self):
+ with self.assertRaises(ValueError):
+ AutoencoderKL(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ channels=(24, 24, 24),
+ attention_levels=(False, False, False),
+ latent_channels=8,
+ num_res_blocks=(8, 8),
+ norm_num_groups=16,
+ )
+
+ def test_shape_reconstruction(self):
+ input_param, input_shape, expected_shape, _ = CASES[0]
+ net = AutoencoderKL(**input_param).to(device)
+ with eval_mode(net):
+ result = net.reconstruct(torch.randn(input_shape).to(device))
+ self.assertEqual(result.shape, expected_shape)
+
+ @SkipIfBeforePyTorchVersion((1, 11))
+ def test_shape_reconstruction_with_convtranspose_and_checkpointing(self):
+ input_param, input_shape, expected_shape, _ = CASES[0]
+ input_param = input_param.copy()
+ input_param.update({"use_checkpoint": True, "use_convtranspose": True})
+ net = AutoencoderKL(**input_param).to(device)
+ with eval_mode(net):
+ result = net.reconstruct(torch.randn(input_shape).to(device))
+ self.assertEqual(result.shape, expected_shape)
+
+ def test_shape_encode(self):
+ input_param, input_shape, _, expected_latent_shape = CASES[0]
+ net = AutoencoderKL(**input_param).to(device)
+ with eval_mode(net):
+ result = net.encode(torch.randn(input_shape).to(device))
+ self.assertEqual(result[0].shape, expected_latent_shape)
+ self.assertEqual(result[1].shape, expected_latent_shape)
+
+ @SkipIfBeforePyTorchVersion((1, 11))
+ def test_shape_encode_with_convtranspose_and_checkpointing(self):
+ input_param, input_shape, _, expected_latent_shape = CASES[0]
+ input_param = input_param.copy()
+ input_param.update({"use_checkpoint": True, "use_convtranspose": True})
+ net = AutoencoderKL(**input_param).to(device)
+ with eval_mode(net):
+ result = net.encode(torch.randn(input_shape).to(device))
+ self.assertEqual(result[0].shape, expected_latent_shape)
+ self.assertEqual(result[1].shape, expected_latent_shape)
+
+ def test_shape_sampling(self):
+ input_param, _, _, expected_latent_shape = CASES[0]
+ net = AutoencoderKL(**input_param).to(device)
+ with eval_mode(net):
+ result = net.sampling(
+ torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device)
+ )
+ self.assertEqual(result.shape, expected_latent_shape)
+
+ @SkipIfBeforePyTorchVersion((1, 11))
+ def test_shape_sampling_convtranspose_and_checkpointing(self):
+ input_param, _, _, expected_latent_shape = CASES[0]
+ input_param = input_param.copy()
+ input_param.update({"use_checkpoint": True, "use_convtranspose": True})
+ net = AutoencoderKL(**input_param).to(device)
+ with eval_mode(net):
+ result = net.sampling(
+ torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device)
+ )
+ self.assertEqual(result.shape, expected_latent_shape)
+
+ def test_shape_decode(self):
+ input_param, expected_input_shape, _, latent_shape = CASES[0]
+ net = AutoencoderKL(**input_param).to(device)
+ with eval_mode(net):
+ result = net.decode(torch.randn(latent_shape).to(device))
+ self.assertEqual(result.shape, expected_input_shape)
+
+ @SkipIfBeforePyTorchVersion((1, 11))
+ def test_shape_decode_convtranspose_and_checkpointing(self):
+ input_param, expected_input_shape, _, latent_shape = CASES[0]
+ input_param = input_param.copy()
+ input_param.update({"use_checkpoint": True, "use_convtranspose": True})
+ net = AutoencoderKL(**input_param).to(device)
+ with eval_mode(net):
+ result = net.decode(torch.randn(latent_shape).to(device))
+ self.assertEqual(result.shape, expected_input_shape)
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_compatibility_with_monai_generative(self):
+ # test loading weights from a model saved in MONAI Generative, version 0.2.3
+ with skip_if_downloading_fails():
+ net = AutoencoderKL(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ channels=(4, 4, 4),
+ latent_channels=4,
+ attention_levels=(False, False, True),
+ num_res_blocks=1,
+ norm_num_groups=4,
+ ).to(device)
+
+ tmpdir = tempfile.mkdtemp()
+ key = "autoencoderkl_monai_generative_weights"
+ url = testing_data_config("models", key, "url")
+ hash_type = testing_data_config("models", key, "hash_type")
+ hash_val = testing_data_config("models", key, "hash_val")
+ filename = "autoencoderkl_monai_generative_weights.pt"
+
+ weight_path = os.path.join(tmpdir, filename)
+ download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)
+
+ net.load_old_state_dict(torch.load(weight_path), verbose=False)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_autoencoderkl_maisi.py b/tests/test_autoencoderkl_maisi.py
new file mode 100644
index 0000000000..0e9f427fb6
--- /dev/null
+++ b/tests/test_autoencoderkl_maisi.py
@@ -0,0 +1,225 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.apps.generation.maisi.networks.autoencoderkl_maisi import AutoencoderKlMaisi
+from monai.networks import eval_mode
+from monai.utils import optional_import
+from tests.utils import SkipIfBeforePyTorchVersion
+
+tqdm, has_tqdm = optional_import("tqdm", name="tqdm")
+_, has_einops = optional_import("einops")
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+CASES_NO_ATTENTION = [
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, False),
+ "num_res_blocks": (1, 1, 1),
+ "norm_num_groups": 4,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ "num_splits": 2,
+ "print_info": False,
+ },
+ (1, 1, 32, 32, 32),
+ (1, 1, 32, 32, 32),
+ (1, 4, 8, 8, 8),
+ ]
+]
+
+CASES_ATTENTION = [
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, True),
+ "num_res_blocks": (1, 1, 1),
+ "norm_num_groups": 4,
+ "with_encoder_nonlocal_attn": True,
+ "with_decoder_nonlocal_attn": True,
+ "num_splits": 2,
+ "print_info": False,
+ },
+ (1, 1, 32, 32, 32),
+ (1, 1, 32, 32, 32),
+ (1, 4, 8, 8, 8),
+ ]
+]
+
+if has_einops:
+ CASES = CASES_NO_ATTENTION + CASES_ATTENTION
+else:
+ CASES = CASES_NO_ATTENTION
+
+
+class TestAutoencoderKlMaisi(unittest.TestCase):
+
+ @parameterized.expand(CASES)
+ def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape):
+ net = AutoencoderKlMaisi(**input_param).to(device)
+ with eval_mode(net):
+ result = net.forward(torch.randn(input_shape).to(device))
+ self.assertEqual(result[0].shape, expected_shape)
+ self.assertEqual(result[1].shape, expected_latent_shape)
+ self.assertEqual(result[2].shape, expected_latent_shape)
+
+ @parameterized.expand(CASES)
+ @SkipIfBeforePyTorchVersion((1, 11))
+ def test_shape_with_convtranspose_and_checkpointing(
+ self, input_param, input_shape, expected_shape, expected_latent_shape
+ ):
+ input_param = input_param.copy()
+ input_param.update({"use_checkpointing": True, "use_convtranspose": True})
+ net = AutoencoderKlMaisi(**input_param).to(device)
+ with eval_mode(net):
+ result = net.forward(torch.randn(input_shape).to(device))
+ self.assertEqual(result[0].shape, expected_shape)
+ self.assertEqual(result[1].shape, expected_latent_shape)
+ self.assertEqual(result[2].shape, expected_latent_shape)
+
+ def test_model_channels_not_multiple_of_norm_num_group(self):
+ with self.assertRaises(ValueError):
+ AutoencoderKlMaisi(
+ spatial_dims=3,
+ in_channels=1,
+ out_channels=1,
+ num_channels=(24, 24, 24),
+ attention_levels=(False, False, False),
+ latent_channels=8,
+ num_res_blocks=(1, 1, 1),
+ norm_num_groups=16,
+ num_splits=2,
+ print_info=False,
+ )
+
+ def test_model_num_channels_not_same_size_of_attention_levels(self):
+ with self.assertRaises(ValueError):
+ AutoencoderKlMaisi(
+ spatial_dims=3,
+ in_channels=1,
+ out_channels=1,
+ num_channels=(24, 24, 24),
+ attention_levels=(False, False),
+ latent_channels=8,
+ num_res_blocks=(1, 1, 1),
+ norm_num_groups=16,
+ num_splits=2,
+ print_info=False,
+ )
+
+ def test_model_num_channels_not_same_size_of_num_res_blocks(self):
+ with self.assertRaises(ValueError):
+ AutoencoderKlMaisi(
+ spatial_dims=3,
+ in_channels=1,
+ out_channels=1,
+ num_channels=(24, 24),
+ attention_levels=(False, False, False),
+ latent_channels=8,
+ num_res_blocks=(8, 8, 8),
+ norm_num_groups=16,
+ num_splits=2,
+ print_info=False,
+ )
+
+ def test_shape_reconstruction(self):
+ input_param, input_shape, expected_shape, _ = CASES[0]
+ net = AutoencoderKlMaisi(**input_param).to(device)
+ with eval_mode(net):
+ result = net.reconstruct(torch.randn(input_shape).to(device))
+ self.assertEqual(result.shape, expected_shape)
+
+ @SkipIfBeforePyTorchVersion((1, 11))
+ def test_shape_reconstruction_with_convtranspose_and_checkpointing(self):
+ input_param, input_shape, expected_shape, _ = CASES[0]
+ input_param = input_param.copy()
+ input_param.update({"use_checkpointing": True, "use_convtranspose": True})
+ net = AutoencoderKlMaisi(**input_param).to(device)
+ with eval_mode(net):
+ result = net.reconstruct(torch.randn(input_shape).to(device))
+ self.assertEqual(result.shape, expected_shape)
+
+ def test_shape_encode(self):
+ input_param, input_shape, _, expected_latent_shape = CASES[0]
+ net = AutoencoderKlMaisi(**input_param).to(device)
+ with eval_mode(net):
+ result = net.encode(torch.randn(input_shape).to(device))
+ self.assertEqual(result[0].shape, expected_latent_shape)
+ self.assertEqual(result[1].shape, expected_latent_shape)
+
+ @SkipIfBeforePyTorchVersion((1, 11))
+ def test_shape_encode_with_convtranspose_and_checkpointing(self):
+ input_param, input_shape, _, expected_latent_shape = CASES[0]
+ input_param = input_param.copy()
+ input_param.update({"use_checkpointing": True, "use_convtranspose": True})
+ net = AutoencoderKlMaisi(**input_param).to(device)
+ with eval_mode(net):
+ result = net.encode(torch.randn(input_shape).to(device))
+ self.assertEqual(result[0].shape, expected_latent_shape)
+ self.assertEqual(result[1].shape, expected_latent_shape)
+
+ def test_shape_sampling(self):
+ input_param, _, _, expected_latent_shape = CASES[0]
+ net = AutoencoderKlMaisi(**input_param).to(device)
+ with eval_mode(net):
+ result = net.sampling(
+ torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device)
+ )
+ self.assertEqual(result.shape, expected_latent_shape)
+
+ @SkipIfBeforePyTorchVersion((1, 11))
+ def test_shape_sampling_convtranspose_and_checkpointing(self):
+ input_param, _, _, expected_latent_shape = CASES[0]
+ input_param = input_param.copy()
+ input_param.update({"use_checkpointing": True, "use_convtranspose": True})
+ net = AutoencoderKlMaisi(**input_param).to(device)
+ with eval_mode(net):
+ result = net.sampling(
+ torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device)
+ )
+ self.assertEqual(result.shape, expected_latent_shape)
+
+ def test_shape_decode(self):
+ input_param, expected_input_shape, _, latent_shape = CASES[0]
+ net = AutoencoderKlMaisi(**input_param).to(device)
+ with eval_mode(net):
+ result = net.decode(torch.randn(latent_shape).to(device))
+ self.assertEqual(result.shape, expected_input_shape)
+
+ @SkipIfBeforePyTorchVersion((1, 11))
+ def test_shape_decode_convtranspose_and_checkpointing(self):
+ input_param, expected_input_shape, _, latent_shape = CASES[0]
+ input_param = input_param.copy()
+ input_param.update({"use_checkpointing": True, "use_convtranspose": True})
+ net = AutoencoderKlMaisi(**input_param).to(device)
+ with eval_mode(net):
+ result = net.decode(torch.randn(latent_shape).to(device))
+ self.assertEqual(result.shape, expected_input_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_bundle_ckpt_export.py b/tests/test_bundle_ckpt_export.py
index 8f376a06d5..cfcadcfc4c 100644
--- a/tests/test_bundle_ckpt_export.py
+++ b/tests/test_bundle_ckpt_export.py
@@ -72,9 +72,9 @@ def test_export(self, key_in_ckpt, use_trace):
_, metadata, extra_files = load_net_with_metadata(
ts_file, more_extra_files=["inference.json", "def_args.json"]
)
- self.assertTrue("schema" in metadata)
- self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"]))
- self.assertTrue("network_def" in json.loads(extra_files["inference.json"]))
+ self.assertIn("schema", metadata)
+ self.assertIn("meta_file", json.loads(extra_files["def_args.json"]))
+ self.assertIn("network_def", json.loads(extra_files["inference.json"]))
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_default_value(self, key_in_ckpt, use_trace):
diff --git a/tests/test_bundle_download.py b/tests/test_bundle_download.py
index 89fbe5e8b2..399c61b117 100644
--- a/tests/test_bundle_download.py
+++ b/tests/test_bundle_download.py
@@ -16,6 +16,7 @@
import tempfile
import unittest
from unittest.case import skipUnless
+from unittest.mock import patch
import numpy as np
import torch
@@ -24,6 +25,7 @@
import monai.networks.nets as nets
from monai.apps import check_hash
from monai.bundle import ConfigParser, create_workflow, load
+from monai.bundle.scripts import _examine_monai_version, _list_latest_versions, download
from monai.utils import optional_import
from tests.utils import (
SkipIfBeforePyTorchVersion,
@@ -56,7 +58,7 @@
TEST_CASE_5 = [
["models/model.pt", "models/model.ts", "configs/train.json"],
"brats_mri_segmentation",
- "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/brats_mri_segmentation/versions/0.3.9/files/brats_mri_segmentation_v0.3.9.zip",
+ "https://api.ngc.nvidia.com/v2/models/nvidia/monaihosting/brats_mri_segmentation/versions/0.4.0/files/brats_mri_segmentation_v0.4.0.zip",
]
TEST_CASE_6 = [["models/model.pt", "configs/train.json"], "renalStructures_CECT_segmentation", "0.1.0"]
@@ -87,7 +89,7 @@
TEST_CASE_10 = [
["network.json", "test_output.pt", "test_input.pt", "large_files.yaml"],
"test_bundle",
- "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/test_bundle_v0.1.2.zip",
+ "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/test_bundle_v0.1.3.zip",
{"model.pt": "27952767e2e154e3b0ee65defc5aed38", "model.ts": "97746870fe591f69ac09827175b00675"},
]
@@ -173,6 +175,23 @@ def test_monaihosting_url_download_bundle(self, bundle_files, bundle_name, url):
file_path = os.path.join(tempdir, bundle_name, file)
self.assertTrue(os.path.exists(file_path))
+ @parameterized.expand([TEST_CASE_5])
+ @skip_if_quick
+ def test_ngc_private_source_download_bundle(self, bundle_files, bundle_name, _url):
+ with skip_if_downloading_fails():
+ # download a single file from url, also use `args_file`
+ with tempfile.TemporaryDirectory() as tempdir:
+ def_args = {"name": bundle_name, "bundle_dir": tempdir}
+ def_args_file = os.path.join(tempdir, "def_args.json")
+ parser = ConfigParser()
+ parser.export_config_file(config=def_args, filepath=def_args_file)
+ cmd = ["coverage", "run", "-m", "monai.bundle", "download", "--args_file", def_args_file]
+ cmd += ["--progress", "False", "--source", "ngc_private"]
+ command_line_tests(cmd)
+ for file in bundle_files:
+ file_path = os.path.join(tempdir, bundle_name, file)
+ self.assertTrue(os.path.exists(file_path))
+
@parameterized.expand([TEST_CASE_6])
@skip_if_quick
def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, version):
@@ -190,6 +209,55 @@ def test_monaihosting_source_download_bundle(self, bundle_files, bundle_name, ve
file_path = os.path.join(tempdir, bundle_name, file)
self.assertTrue(os.path.exists(file_path))
+ @patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
+ def test_examine_monai_version(self, mock_get_versions):
+ self.assertTrue(_examine_monai_version("1.1")[0]) # Should return True, compatible
+ self.assertTrue(_examine_monai_version("1.2rc1")[0]) # Should return True, compatible
+ self.assertFalse(_examine_monai_version("1.3")[0]) # Should return False, not compatible
+
+ @patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2rc1"})
+ def test_examine_monai_version_rc(self, mock_get_versions):
+ self.assertTrue(_examine_monai_version("1.2")[0]) # Should return True, compatible
+ self.assertFalse(_examine_monai_version("1.3")[0]) # Should return False, not compatible
+
+ def test_list_latest_versions(self):
+ """Test listing of the latest versions."""
+ data = {
+ "modelVersions": [
+ {"createdDate": "2021-01-01", "versionId": "1.0"},
+ {"createdDate": "2021-01-02", "versionId": "1.1"},
+ {"createdDate": "2021-01-03", "versionId": "1.2"},
+ ]
+ }
+ self.assertEqual(_list_latest_versions(data), ["1.2", "1.1", "1.0"])
+ self.assertEqual(_list_latest_versions(data, max_versions=2), ["1.2", "1.1"])
+ data = {
+ "modelVersions": [
+ {"createdDate": "2021-01-01", "versionId": "1.0"},
+ {"createdDate": "2021-01-02", "versionId": "1.1"},
+ ]
+ }
+ self.assertEqual(_list_latest_versions(data), ["1.1", "1.0"])
+
+ @skip_if_quick
+ @patch("monai.bundle.scripts.get_versions", return_value={"version": "1.2"})
+ def test_download_monaihosting(self, mock_get_versions):
+ """Test checking MONAI version from a metadata file."""
+ with patch("monai.bundle.scripts.logger") as mock_logger:
+ with tempfile.TemporaryDirectory() as tempdir:
+ download(name="spleen_ct_segmentation", bundle_dir=tempdir, source="monaihosting")
+ # Should have a warning message because the latest version is using monai > 1.2
+ mock_logger.warning.assert_called_once()
+
+ @skip_if_quick
+ @patch("monai.bundle.scripts.get_versions", return_value={"version": "1.3"})
+ def test_download_ngc(self, mock_get_versions):
+ """Test checking MONAI version from a metadata file."""
+ with patch("monai.bundle.scripts.logger") as mock_logger:
+ with tempfile.TemporaryDirectory() as tempdir:
+ download(name="spleen_ct_segmentation", bundle_dir=tempdir, source="ngc")
+ mock_logger.warning.assert_not_called()
+
@skip_if_no_cuda
class TestLoad(unittest.TestCase):
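
The new tests above pin down the behaviour of two private helpers in monai.bundle.scripts. For orientation, here is a minimal sketch of the sorting behaviour the _list_latest_versions cases assume (a hypothetical reimplementation, not the library code; the default max_versions value is an assumption):

    from __future__ import annotations

    def list_latest_versions_sketch(data: dict, max_versions: int = 3) -> list[str]:
        # Sort entries newest-first by their "createdDate" field and return the version ids.
        ordered = sorted(data["modelVersions"], key=lambda v: v["createdDate"], reverse=True)
        return [v["versionId"] for v in ordered[:max_versions]]

Run against the dictionaries in test_list_latest_versions, this returns ["1.2", "1.1", "1.0"] and, with max_versions=2, ["1.2", "1.1"].
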
diff --git a/tests/test_bundle_get_data.py b/tests/test_bundle_get_data.py
index 605b3945bb..f84713fbe3 100644
--- a/tests/test_bundle_get_data.py
+++ b/tests/test_bundle_get_data.py
@@ -51,8 +51,8 @@ class TestGetBundleData(unittest.TestCase):
def test_get_all_bundles_list(self, params):
with skip_if_downloading_fails():
output = get_all_bundles_list(**params)
- self.assertTrue(isinstance(output, list))
- self.assertTrue(isinstance(output[0], tuple))
+ self.assertIsInstance(output, list)
+ self.assertIsInstance(output[0], tuple)
self.assertTrue(len(output[0]) == 2)
@parameterized.expand([TEST_CASE_1, TEST_CASE_5])
@@ -60,16 +60,17 @@ def test_get_all_bundles_list(self, params):
def test_get_bundle_versions(self, params):
with skip_if_downloading_fails():
output = get_bundle_versions(**params)
- self.assertTrue(isinstance(output, dict))
- self.assertTrue("latest_version" in output and "all_versions" in output)
- self.assertTrue("0.1.0" in output["all_versions"])
+ self.assertIsInstance(output, dict)
+ self.assertIn("latest_version", output)
+ self.assertIn("all_versions", output)
+ self.assertIn("0.1.0", output["all_versions"])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
@skip_if_quick
def test_get_bundle_info(self, params):
with skip_if_downloading_fails():
output = get_bundle_info(**params)
- self.assertTrue(isinstance(output, dict))
+ self.assertIsInstance(output, dict)
for key in ["id", "name", "size", "download_count", "browser_download_url"]:
self.assertTrue(key in output)
@@ -78,7 +79,7 @@ def test_get_bundle_info(self, params):
def test_get_bundle_info_monaihosting(self, params):
with skip_if_downloading_fails():
output = get_bundle_info(**params)
- self.assertTrue(isinstance(output, dict))
+ self.assertIsInstance(output, dict)
for key in ["name", "browser_download_url"]:
self.assertTrue(key in output)
diff --git a/tests/test_bundle_trt_export.py b/tests/test_bundle_trt_export.py
index 47034852ef..835c8e5c1d 100644
--- a/tests/test_bundle_trt_export.py
+++ b/tests/test_bundle_trt_export.py
@@ -22,7 +22,13 @@
from monai.data import load_net_with_metadata
from monai.networks import save_state
from monai.utils import optional_import
-from tests.utils import command_line_tests, skip_if_no_cuda, skip_if_quick, skip_if_windows
+from tests.utils import (
+ SkipIfBeforeComputeCapabilityVersion,
+ command_line_tests,
+ skip_if_no_cuda,
+ skip_if_quick,
+ skip_if_windows,
+)
_, has_torchtrt = optional_import(
"torch_tensorrt",
@@ -47,6 +53,7 @@
@skip_if_windows
@skip_if_no_cuda
@skip_if_quick
+@SkipIfBeforeComputeCapabilityVersion((7, 0))
class TestTRTExport(unittest.TestCase):
def setUp(self):
@@ -91,9 +98,9 @@ def test_trt_export(self, convert_precision, input_shape, dynamic_batch):
_, metadata, extra_files = load_net_with_metadata(
ts_file, more_extra_files=["inference.json", "def_args.json"]
)
- self.assertTrue("schema" in metadata)
- self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"]))
- self.assertTrue("network_def" in json.loads(extra_files["inference.json"]))
+ self.assertIn("schema", metadata)
+ self.assertIn("meta_file", json.loads(extra_files["def_args.json"]))
+ self.assertIn("network_def", json.loads(extra_files["inference.json"]))
@parameterized.expand([TEST_CASE_3, TEST_CASE_4])
@unittest.skipUnless(
@@ -129,9 +136,9 @@ def test_onnx_trt_export(self, convert_precision, input_shape, dynamic_batch):
_, metadata, extra_files = load_net_with_metadata(
ts_file, more_extra_files=["inference.json", "def_args.json"]
)
- self.assertTrue("schema" in metadata)
- self.assertTrue("meta_file" in json.loads(extra_files["def_args.json"]))
- self.assertTrue("network_def" in json.loads(extra_files["inference.json"]))
+ self.assertIn("schema", metadata)
+ self.assertIn("meta_file", json.loads(extra_files["def_args.json"]))
+ self.assertIn("network_def", json.loads(extra_files["inference.json"]))
if __name__ == "__main__":
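
The TensorRT export tests are now additionally gated on GPU compute capability via SkipIfBeforeComputeCapabilityVersion((7, 0)). A hedged sketch of how such a guard can be written with unittest and torch (hypothetical; the helper in tests/utils.py may differ in detail):

    from __future__ import annotations

    import unittest

    import torch

    def skip_if_before_compute_capability(required: tuple[int, int] = (7, 0)):
        # Hypothetical decorator: skip when no GPU is visible or its compute
        # capability (major, minor) is below the required tuple.
        available = torch.cuda.is_available()
        capability = torch.cuda.get_device_capability(0) if available else (0, 0)
        return unittest.skipUnless(
            available and capability >= required,
            f"requires compute capability >= {required}, got {capability}",
        )
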
diff --git a/tests/test_bundle_workflow.py b/tests/test_bundle_workflow.py
index f7da37acef..1727fcdf53 100644
--- a/tests/test_bundle_workflow.py
+++ b/tests/test_bundle_workflow.py
@@ -35,6 +35,8 @@
TEST_CASE_3 = [os.path.join(os.path.dirname(__file__), "testing_data", "config_fl_train.json")]
+TEST_CASE_NON_CONFIG_WRONG_LOG = [None, "logging.conf", "Cannot find the logging config file: logging.conf."]
+
class TestBundleWorkflow(unittest.TestCase):
@@ -103,6 +105,16 @@ def test_inference_config(self, config_file):
)
self._test_inferer(inferer)
+ # test property path
+ inferer = ConfigWorkflow(
+ config_file=config_file,
+ properties_path=os.path.join(os.path.dirname(__file__), "testing_data", "fl_infer_properties.json"),
+ logging_file=os.path.join(os.path.dirname(__file__), "testing_data", "logging.conf"),
+ **override,
+ )
+ self._test_inferer(inferer)
+ self.assertEqual(inferer.workflow_type, None)
+
@parameterized.expand([TEST_CASE_3])
def test_train_config(self, config_file):
# test standard MONAI model-zoo config workflow
@@ -126,11 +138,11 @@ def test_train_config(self, config_file):
self.assertListEqual(trainer.check_properties(), [])
# test read / write the properties
dataset = trainer.train_dataset
- self.assertTrue(isinstance(dataset, Dataset))
+ self.assertIsInstance(dataset, Dataset)
inferer = trainer.train_inferer
- self.assertTrue(isinstance(inferer, SimpleInferer))
+ self.assertIsInstance(inferer, SimpleInferer)
# test optional properties get
- self.assertTrue(trainer.train_key_metric is None)
+ self.assertIsNone(trainer.train_key_metric)
trainer.train_dataset = deepcopy(dataset)
trainer.train_inferer = deepcopy(inferer)
# test optional properties set
@@ -144,8 +156,14 @@ def test_train_config(self, config_file):
def test_non_config(self):
# test user defined python style workflow
inferer = NonConfigWorkflow(self.filename, self.data_dir)
+ self.assertEqual(inferer.meta_file, None)
self._test_inferer(inferer)
+ @parameterized.expand([TEST_CASE_NON_CONFIG_WRONG_LOG])
+ def test_non_config_wrong_log_cases(self, meta_file, logging_file, expected_error):
+ with self.assertRaisesRegex(FileNotFoundError, expected_error):
+ NonConfigWorkflow(self.filename, self.data_dir, meta_file, logging_file)
+
if __name__ == "__main__":
unittest.main()
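
The added inference case exercises the new properties_path argument: when the workflow properties come from a JSON file, the ConfigWorkflow no longer infers them from a workflow type (the test asserts workflow_type stays None). A hedged usage sketch with assumed example paths:

    from monai.bundle import ConfigWorkflow

    workflow = ConfigWorkflow(
        config_file="configs/inference.json",                 # assumed example path
        properties_path="configs/fl_infer_properties.json",   # assumed example path
        logging_file="configs/logging.conf",                  # assumed example path
    )
    # per the test above, the workflow type is not set when properties come from a file
    assert workflow.workflow_type is None
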
diff --git a/tests/test_cell_sam_wrapper.py b/tests/test_cell_sam_wrapper.py
new file mode 100644
index 0000000000..2f1ee2b901
--- /dev/null
+++ b/tests/test_cell_sam_wrapper.py
@@ -0,0 +1,58 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.nets.cell_sam_wrapper import CellSamWrapper
+from monai.utils import optional_import
+
+build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+TEST_CASE_CELLSEGWRAPPER = []
+for dims in [128, 256, 512, 1024]:
+ test_case = [
+ {"auto_resize_inputs": True, "network_resize_roi": [1024, 1024], "checkpoint": None},
+ (1, 3, *([dims] * 2)),
+ (1, 3, *([dims] * 2)),
+ ]
+ TEST_CASE_CELLSEGWRAPPER.append(test_case)
+
+
+@unittest.skipUnless(has_sam, "Requires SAM installation")
+class TestCellSamWrapper(unittest.TestCase):
+
+ @parameterized.expand(TEST_CASE_CELLSEGWRAPPER)
+ def test_shape(self, input_param, input_shape, expected_shape):
+ net = CellSamWrapper(**input_param).to(device)
+ with eval_mode(net):
+ result = net(torch.randn(input_shape).to(device))
+ self.assertEqual(result.shape, expected_shape, msg=str(input_param))
+
+ def test_ill_arg0(self):
+ with self.assertRaises(RuntimeError):
+ net = CellSamWrapper(auto_resize_inputs=False, checkpoint=None).to(device)
+ net(torch.randn([1, 3, 256, 256]).to(device))
+
+ def test_ill_arg1(self):
+ with self.assertRaises(RuntimeError):
+ net = CellSamWrapper(network_resize_roi=[256, 256], checkpoint=None).to(device)
+ net(torch.randn([1, 3, 1024, 1024]).to(device))
+
+
+if __name__ == "__main__":
+ unittest.main()
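
CellSamWrapper wraps the Segment Anything ViT-B model for MONAI-style inputs; the tests above show that with auto_resize_inputs=True the output keeps the input's spatial size. A hedged usage sketch (requires the segment_anything package):

    import torch

    from monai.networks import eval_mode
    from monai.networks.nets.cell_sam_wrapper import CellSamWrapper

    net = CellSamWrapper(auto_resize_inputs=True, network_resize_roi=[1024, 1024], checkpoint=None)
    with eval_mode(net):
        out = net(torch.randn(1, 3, 256, 256))
    # per the tests above, out.shape == (1, 3, 256, 256)
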
diff --git a/tests/test_clip_intensity_percentiles.py b/tests/test_clip_intensity_percentiles.py
new file mode 100644
index 0000000000..77f811db87
--- /dev/null
+++ b/tests/test_clip_intensity_percentiles.py
@@ -0,0 +1,198 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.transforms import ClipIntensityPercentiles
+from monai.transforms.utils import soft_clip
+from monai.transforms.utils_pytorch_numpy_unification import clip, percentile
+from monai.utils.type_conversion import convert_to_tensor
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose
+
+
+def test_hard_clip_func(im, lower, upper):
+ im_t = convert_to_tensor(im)
+ if lower is None:
+ upper = percentile(im_t, upper)
+ elif upper is None:
+ lower = percentile(im_t, lower)
+ else:
+ lower, upper = percentile(im_t, (lower, upper))
+ return clip(im_t, lower, upper)
+
+
+def test_soft_clip_func(im, lower, upper):
+ im_t = convert_to_tensor(im)
+ if lower is None:
+ upper = percentile(im_t, upper)
+ elif upper is None:
+ lower = percentile(im_t, lower)
+ else:
+ lower, upper = percentile(im_t, (lower, upper))
+ return soft_clip(im_t, minv=lower, maxv=upper, sharpness_factor=1.0, dtype=torch.float32)
+
+
+class TestClipIntensityPercentiles2D(NumpyImageTestCase2D):
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_hard_clipping_two_sided(self, p):
+ hard_clipper = ClipIntensityPercentiles(upper=95, lower=5)
+ im = p(self.imt)
+ result = hard_clipper(im)
+ expected = test_hard_clip_func(im, 5, 95)
+ assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_hard_clipping_one_sided_high(self, p):
+ hard_clipper = ClipIntensityPercentiles(upper=95, lower=None)
+ im = p(self.imt)
+ result = hard_clipper(im)
+ expected = test_hard_clip_func(im, 0, 95)
+ assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_hard_clipping_one_sided_low(self, p):
+ hard_clipper = ClipIntensityPercentiles(upper=None, lower=5)
+ im = p(self.imt)
+ result = hard_clipper(im)
+ expected = test_hard_clip_func(im, 5, 100)
+ assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_soft_clipping_two_sided(self, p):
+ soft_clipper = ClipIntensityPercentiles(upper=95, lower=5, sharpness_factor=1.0)
+ im = p(self.imt)
+ result = soft_clipper(im)
+ expected = test_soft_clip_func(im, 5, 95)
+ # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable across torch and numpy
+ assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_soft_clipping_one_sided_high(self, p):
+ soft_clipper = ClipIntensityPercentiles(upper=95, lower=None, sharpness_factor=1.0)
+ im = p(self.imt)
+ result = soft_clipper(im)
+ expected = test_soft_clip_func(im, None, 95)
+ # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable across torch and numpy
+ assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_soft_clipping_one_sided_low(self, p):
+ soft_clipper = ClipIntensityPercentiles(upper=None, lower=5, sharpness_factor=1.0)
+ im = p(self.imt)
+ result = soft_clipper(im)
+ expected = test_soft_clip_func(im, 5, None)
+ # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable across torch and numpy
+ assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_channel_wise(self, p):
+ clipper = ClipIntensityPercentiles(upper=95, lower=5, channel_wise=True)
+ im = p(self.imt)
+ result = clipper(im)
+ im_t = convert_to_tensor(self.imt)
+ for i, c in enumerate(im_t):
+ lower, upper = percentile(c, (5, 95))
+ expected = clip(c, lower, upper)
+ assert_allclose(result[i], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ def test_ill_sharpness_factor(self):
+ with self.assertRaises(ValueError):
+ ClipIntensityPercentiles(upper=95, lower=5, sharpness_factor=0.0)
+
+ def test_ill_lower_percentile(self):
+ with self.assertRaises(ValueError):
+ ClipIntensityPercentiles(upper=None, lower=-1)
+
+ def test_ill_upper_percentile(self):
+ with self.assertRaises(ValueError):
+ ClipIntensityPercentiles(upper=101, lower=None)
+
+ def test_ill_percentiles(self):
+ with self.assertRaises(ValueError):
+ ClipIntensityPercentiles(upper=95, lower=96)
+
+ def test_ill_both_none(self):
+ with self.assertRaises(ValueError):
+ ClipIntensityPercentiles(upper=None, lower=None)
+
+
+class TestClipIntensityPercentiles3D(NumpyImageTestCase3D):
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_hard_clipping_two_sided(self, p):
+ hard_clipper = ClipIntensityPercentiles(upper=95, lower=5)
+ im = p(self.imt)
+ result = hard_clipper(im)
+ expected = test_hard_clip_func(im, 5, 95)
+ assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_hard_clipping_one_sided_high(self, p):
+ hard_clipper = ClipIntensityPercentiles(upper=95, lower=None)
+ im = p(self.imt)
+ result = hard_clipper(im)
+ expected = test_hard_clip_func(im, 0, 95)
+ assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_hard_clipping_one_sided_low(self, p):
+ hard_clipper = ClipIntensityPercentiles(upper=None, lower=5)
+ im = p(self.imt)
+ result = hard_clipper(im)
+ expected = test_hard_clip_func(im, 5, 100)
+ assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_soft_clipping_two_sided(self, p):
+ soft_clipper = ClipIntensityPercentiles(upper=95, lower=5, sharpness_factor=1.0)
+ im = p(self.imt)
+ result = soft_clipper(im)
+ expected = test_soft_clip_func(im, 5, 95)
+ # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable across torch and numpy
+ assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_soft_clipping_one_sided_high(self, p):
+ soft_clipper = ClipIntensityPercentiles(upper=95, lower=None, sharpness_factor=1.0)
+ im = p(self.imt)
+ result = soft_clipper(im)
+ expected = test_soft_clip_func(im, None, 95)
+ # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable across torch and numpy
+ assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_soft_clipping_one_sided_low(self, p):
+ soft_clipper = ClipIntensityPercentiles(upper=None, lower=5, sharpness_factor=1.0)
+ im = p(self.imt)
+ result = soft_clipper(im)
+ expected = test_soft_clip_func(im, 5, None)
+ # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable across torch and numpy
+ assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_channel_wise(self, p):
+ clipper = ClipIntensityPercentiles(upper=95, lower=5, channel_wise=True)
+ im = p(self.imt)
+ result = clipper(im)
+ im_t = convert_to_tensor(self.imt)
+ for i, c in enumerate(im_t):
+ lower, upper = percentile(c, (5, 95))
+ expected = clip(c, lower, upper)
+ assert_allclose(result[i], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+
+if __name__ == "__main__":
+ unittest.main()
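
Both test classes compare ClipIntensityPercentiles against a reference built from percentile plus either a hard clip or soft_clip. As an illustration of the soft variant, here is a smooth, softplus-based approximation of clamping (a hedged sketch of the idea, not necessarily the exact formula in monai.transforms.utils.soft_clip):

    import torch
    import torch.nn.functional as F

    def illustrative_soft_clip(x: torch.Tensor, minv: float, maxv: float, sharpness: float = 1.0) -> torch.Tensor:
        # Tends to minv for x << minv, to maxv for x >> maxv, and follows x in between;
        # a larger sharpness makes the transition closer to a hard clip.
        s = sharpness
        return minv + F.softplus(s * (x - minv)) / s - F.softplus(s * (x - maxv)) / s
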
diff --git a/tests/test_clip_intensity_percentilesd.py b/tests/test_clip_intensity_percentilesd.py
new file mode 100644
index 0000000000..3e06b18418
--- /dev/null
+++ b/tests/test_clip_intensity_percentilesd.py
@@ -0,0 +1,196 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+from parameterized import parameterized
+
+from monai.transforms import ClipIntensityPercentilesd
+from monai.transforms.utils_pytorch_numpy_unification import clip, percentile
+from monai.utils.type_conversion import convert_to_tensor
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose
+
+from .test_clip_intensity_percentiles import test_hard_clip_func, test_soft_clip_func
+
+
+class TestClipIntensityPercentilesd2D(NumpyImageTestCase2D):
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_hard_clipping_two_sided(self, p):
+ key = "img"
+ hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5)
+ im = p(self.imt)
+ result = hard_clipper({key: im})
+ expected = test_hard_clip_func(im, 5, 95)
+ assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_hard_clipping_one_sided_high(self, p):
+ key = "img"
+ hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None)
+ im = p(self.imt)
+ result = hard_clipper({key: im})
+ expected = test_hard_clip_func(im, 0, 95)
+ assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_hard_clipping_one_sided_low(self, p):
+ key = "img"
+ hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5)
+ im = p(self.imt)
+ result = hard_clipper({key: im})
+ expected = test_hard_clip_func(im, 5, 100)
+ assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_soft_clipping_two_sided(self, p):
+ key = "img"
+ soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, sharpness_factor=1.0)
+ im = p(self.imt)
+ result = soft_clipper({key: im})
+ expected = test_soft_clip_func(im, 5, 95)
+ # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable across torch and numpy
+ assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_soft_clipping_one_sided_high(self, p):
+ key = "img"
+ soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None, sharpness_factor=1.0)
+ im = p(self.imt)
+ result = soft_clipper({key: im})
+ expected = test_soft_clip_func(im, None, 95)
+ # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable across torch and numpy
+ assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_soft_clipping_one_sided_low(self, p):
+ key = "img"
+ soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5, sharpness_factor=1.0)
+ im = p(self.imt)
+ result = soft_clipper({key: im})
+ expected = test_soft_clip_func(im, 5, None)
+ # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable across torch and numpy
+ assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_channel_wise(self, p):
+ key = "img"
+ clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, channel_wise=True)
+ im = p(self.imt)
+ result = clipper({key: im})
+ im_t = convert_to_tensor(self.imt)
+ for i, c in enumerate(im_t):
+ lower, upper = percentile(c, (5, 95))
+ expected = clip(c, lower, upper)
+ assert_allclose(result[key][i], p(expected), type_test="tensor", rtol=1e-3, atol=0)
+
+ def test_ill_sharpness_factor(self):
+ key = "img"
+ with self.assertRaises(ValueError):
+ ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, sharpness_factor=0.0)
+
+ def test_ill_lower_percentile(self):
+ key = "img"
+ with self.assertRaises(ValueError):
+ ClipIntensityPercentilesd(keys=[key], upper=None, lower=-1)
+
+ def test_ill_upper_percentile(self):
+ key = "img"
+ with self.assertRaises(ValueError):
+ ClipIntensityPercentilesd(keys=[key], upper=101, lower=None)
+
+ def test_ill_percentiles(self):
+ key = "img"
+ with self.assertRaises(ValueError):
+ ClipIntensityPercentilesd(keys=[key], upper=95, lower=96)
+
+ def test_ill_both_none(self):
+ key = "img"
+ with self.assertRaises(ValueError):
+ ClipIntensityPercentilesd(keys=[key], upper=None, lower=None)
+
+
+class TestClipIntensityPercentilesd3D(NumpyImageTestCase3D):
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_hard_clipping_two_sided(self, p):
+ key = "img"
+ hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5)
+ im = p(self.imt)
+ result = hard_clipper({key: im})
+ expected = test_hard_clip_func(im, 5, 95)
+ assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_hard_clipping_one_sided_high(self, p):
+ key = "img"
+ hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None)
+ im = p(self.imt)
+ result = hard_clipper({key: im})
+ expected = test_hard_clip_func(im, 0, 95)
+ assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_hard_clipping_one_sided_low(self, p):
+ key = "img"
+ hard_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5)
+ im = p(self.imt)
+ result = hard_clipper({key: im})
+ expected = test_hard_clip_func(im, 5, 100)
+ assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_soft_clipping_two_sided(self, p):
+ key = "img"
+ soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, sharpness_factor=1.0)
+ im = p(self.imt)
+ result = soft_clipper({key: im})
+ expected = test_soft_clip_func(im, 5, 95)
+ # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable across torch and numpy
+ assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_soft_clipping_one_sided_high(self, p):
+ key = "img"
+ soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=None, sharpness_factor=1.0)
+ im = p(self.imt)
+ result = soft_clipper({key: im})
+ expected = test_soft_clip_func(im, None, 95)
+ # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable across torch and numpy
+ assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_soft_clipping_one_sided_low(self, p):
+ key = "img"
+ soft_clipper = ClipIntensityPercentilesd(keys=[key], upper=None, lower=5, sharpness_factor=1.0)
+ im = p(self.imt)
+ result = soft_clipper({key: im})
+ expected = test_soft_clip_func(im, 5, None)
+ # the rtol is set to 1e-4 because the logaddexp function used in softplus is not stable across torch and numpy
+ assert_allclose(result[key], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+ @parameterized.expand([[p] for p in TEST_NDARRAYS])
+ def test_channel_wise(self, p):
+ key = "img"
+ clipper = ClipIntensityPercentilesd(keys=[key], upper=95, lower=5, channel_wise=True)
+ im = p(self.imt)
+ result = clipper({key: im})
+ im_t = convert_to_tensor(im)
+ for i, c in enumerate(im_t):
+ lower, upper = percentile(c, (5, 95))
+ expected = clip(c, lower, upper)
+ assert_allclose(result[key][i], p(expected), type_test="tensor", rtol=1e-4, atol=0)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_component_store.py b/tests/test_component_store.py
index 424eceb3d1..7e7c6dd19d 100644
--- a/tests/test_component_store.py
+++ b/tests/test_component_store.py
@@ -48,17 +48,17 @@ def test_add2(self):
self.cs.add("test_obj2", "Test object", test_obj2)
self.assertEqual(len(self.cs), 2)
- self.assertTrue("test_obj1" in self.cs)
- self.assertTrue("test_obj2" in self.cs)
+ self.assertIn("test_obj1", self.cs)
+ self.assertIn("test_obj2", self.cs)
def test_add_def(self):
- self.assertFalse("test_func" in self.cs)
+ self.assertNotIn("test_func", self.cs)
@self.cs.add_def("test_func", "Test function")
def test_func():
return 123
- self.assertTrue("test_func" in self.cs)
+ self.assertIn("test_func", self.cs)
self.assertEqual(len(self.cs), 1)
self.assertEqual(list(self.cs), [("test_func", test_func)])
diff --git a/tests/test_compose.py b/tests/test_compose.py
index 309767833b..3c53ac4a22 100644
--- a/tests/test_compose.py
+++ b/tests/test_compose.py
@@ -716,15 +716,15 @@ def test_compose_execute_equivalence_with_flags(self, flags, data, pipeline):
for k in actual.keys():
self.assertEqual(expected[k], actual[k])
else:
- self.assertTrue(expected, actual)
+ self.assertEqual(expected, actual)
p = deepcopy(pipeline)
actual = execute_compose(execute_compose(data, p, start=0, end=cutoff, **flags), p, start=cutoff, **flags)
if isinstance(actual, dict):
for k in actual.keys():
- self.assertTrue(expected[k], actual[k])
+ self.assertEqual(expected[k], actual[k])
else:
- self.assertTrue(expected, actual)
+ self.assertEqual(expected, actual)
class TestComposeCallableInput(unittest.TestCase):
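
The assertion swap above fixes a silent no-op: unittest's assertTrue(expected, actual) treats the second argument as the failure message, so it passes whenever expected is truthy. A minimal demonstration:

    import unittest

    class AssertTruePitfall(unittest.TestCase):
        def test_pitfall(self):
            self.assertTrue(1, 2)  # passes: 2 is only the failure message, nothing is compared
            with self.assertRaises(AssertionError):
                self.assertEqual(1, 2)  # the intended comparison fails as expected

    if __name__ == "__main__":
        unittest.main()
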
diff --git a/tests/test_compute_f_beta.py b/tests/test_compute_f_beta.py
index 85997577cf..be2a7fc176 100644
--- a/tests/test_compute_f_beta.py
+++ b/tests/test_compute_f_beta.py
@@ -15,6 +15,7 @@
import numpy as np
import torch
+from parameterized import parameterized
from monai.metrics import FBetaScore
from tests.utils import assert_allclose
@@ -33,26 +34,21 @@ def test_expecting_success_and_device(self):
assert_allclose(result, torch.Tensor([0.714286]), atol=1e-6, rtol=1e-6)
np.testing.assert_equal(result.device, y_pred.device)
- def test_expecting_success2(self):
- metric = FBetaScore(beta=0.5)
- metric(
- y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
- )
- assert_allclose(metric.aggregate()[0], torch.Tensor([0.609756]), atol=1e-6, rtol=1e-6)
-
- def test_expecting_success3(self):
- metric = FBetaScore(beta=2)
- metric(
- y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
- )
- assert_allclose(metric.aggregate()[0], torch.Tensor([0.862069]), atol=1e-6, rtol=1e-6)
-
- def test_denominator_is_zero(self):
- metric = FBetaScore(beta=2)
- metric(
- y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
- )
- assert_allclose(metric.aggregate()[0], torch.Tensor([0.0]), atol=1e-6, rtol=1e-6)
+ @parameterized.expand(
+ [
+ (0.5, torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]), torch.Tensor([0.609756])), # success_beta_0_5
+ (2, torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]), torch.Tensor([0.862069])), # success_beta_2
+ (
+ 2, # success_beta_2, denominator_zero
+ torch.Tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]]),
+ torch.Tensor([0.0]),
+ ),
+ ]
+ )
+ def test_success_and_zero(self, beta, y, expected_score):
+ metric = FBetaScore(beta=beta)
+ metric(y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=y)
+ assert_allclose(metric.aggregate()[0], expected_score, atol=1e-6, rtol=1e-6)
def test_number_of_dimensions_less_than_2_should_raise_error(self):
metric = FBetaScore()
@@ -63,7 +59,7 @@ def test_with_nan_values(self):
metric = FBetaScore(get_not_nans=True)
metric(
y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]),
- y=torch.Tensor([[1, 0, 1], [np.NaN, np.NaN, np.NaN], [1, 0, 1]]),
+ y=torch.Tensor([[1, 0, 1], [np.nan, np.nan, np.nan], [1, 0, 1]]),
)
assert_allclose(metric.aggregate()[0][0], torch.Tensor([0.727273]), atol=1e-6, rtol=1e-6)
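
The np.NaN to np.nan change keeps the test compatible with NumPy 2.0, which removed the capitalised aliases (np.NaN, np.Inf, and friends) from the public namespace; np.nan works on both the 1.x and 2.x series:

    import numpy as np

    x = np.array([1.0, np.nan])  # np.nan is the portable spelling
    assert np.isnan(x[1])
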
diff --git a/tests/test_compute_generalized_dice.py b/tests/test_compute_generalized_dice.py
index e04444e988..985a01e993 100644
--- a/tests/test_compute_generalized_dice.py
+++ b/tests/test_compute_generalized_dice.py
@@ -22,17 +22,17 @@
_device = "cuda:0" if torch.cuda.is_available() else "cpu"
# keep background
-TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1)
+TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1) with compute_generalized_dice
{
"y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device),
"y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device),
"include_background": True,
},
- [0.8],
+ [[0.8]],
]
# remove background
-TEST_CASE_2 = [ # y (2, 1, 2, 2), y_pred (2, 3, 2, 2), expected out (2) (no background)
+TEST_CASE_2 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2) (no background) with GeneralizedDiceScore
{
"y_pred": torch.tensor(
[
@@ -47,32 +47,32 @@
]
),
"include_background": False,
+ "reduction": "mean_batch",
},
- [0.1667, 0.6667],
+ [0.583333, 0.333333],
]
-# should return 0 for both cases
-TEST_CASE_3 = [
+TEST_CASE_3 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (1) with GeneralizedDiceScore
{
"y_pred": torch.tensor(
[
- [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]],
- [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]],
+ [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
+ [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
]
),
"y": torch.tensor(
[
[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
- [[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]],
+ [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
]
),
"include_background": True,
+ "reduction": "mean",
},
- [0.0, 0.0],
+ [0.5454],
]
-TEST_CASE_4 = [
- {"include_background": True, "reduction": "mean_batch"},
+TEST_CASE_4 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (1) with GeneralizedDiceScore
{
"y_pred": torch.tensor(
[
@@ -83,15 +83,36 @@
"y": torch.tensor(
[
[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
- [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
+ [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
]
),
+ "include_background": True,
+ "reduction": "sum",
},
- [0.5455],
+ [1.045455],
+]
+
+TEST_CASE_5 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice
+ {"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))},
+ [[1.0000, 1.0000], [1.0000, 1.0000]],
]
-TEST_CASE_5 = [
- {"include_background": True, "reduction": "sum_batch"},
+TEST_CASE_6 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice
+ {"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))},
+ [[0.0000, 0.0000], [0.0000, 0.0000]],
+]
+
+TEST_CASE_7 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice
+ {"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))},
+ [[0.0000, 0.0000], [0.0000, 0.0000]],
+]
+
+TEST_CASE_8 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice
+ {"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))},
+ [[1.0000, 1.0000], [1.0000, 1.0000]],
+]
+
+TEST_CASE_9 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2) with GeneralizedDiceScore
{
"y_pred": torch.tensor(
[
@@ -102,61 +123,118 @@
"y": torch.tensor(
[
[[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
- [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
+ [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
]
),
+ "include_background": True,
+ "reduction": "mean_channel",
},
- 1.0455,
+ [0.545455, 0.545455],
]
-TEST_CASE_6 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [1.0000, 1.0000]]
-TEST_CASE_7 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [0.0000, 0.0000]]
-
-TEST_CASE_8 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [0.0000, 0.0000]]
+TEST_CASE_10 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2, 3) with compute_generalized_dice
+ # and (3) with GeneralizedDiceScore "mean_batch"
+ {
+ "y_pred": torch.tensor(
+ [
+ [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
+ [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
+ ]
+ ),
+ "y": torch.tensor(
+ [
+ [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
+ [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
+ ]
+ ),
+ "include_background": True,
+ },
+ [[0.857143, 0.0, 0.0], [0.5, 0.4, 0.666667]],
+]
-TEST_CASE_9 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [1.0000, 1.0000]]
+TEST_CASE_11 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2, 1) with compute_generalized_dice (summed over classes)
+ # and (2) with GeneralizedDiceScore "mean_channel"
+ {
+ "y_pred": torch.tensor(
+ [
+ [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]],
+ [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]],
+ ]
+ ),
+ "y": torch.tensor(
+ [
+ [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],
+ [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]],
+ ]
+ ),
+ "include_background": True,
+ "sum_over_classes": True,
+ },
+ [[0.545455], [0.545455]],
+]
class TestComputeGeneralizedDiceScore(unittest.TestCase):
-
@parameterized.expand([TEST_CASE_1])
def test_device(self, input_data, _expected_value):
+ """
+ Test if the result tensor is on the same device as the input tensor.
+ """
result = compute_generalized_dice(**input_data)
np.testing.assert_equal(result.device, input_data["y_pred"].device)
- # Functional part tests
- @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9])
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8])
def test_value(self, input_data, expected_value):
+ """
+ Test if the computed generalized dice score matches the expected value.
+ """
result = compute_generalized_dice(**input_data)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
- # Functional part tests
- @parameterized.expand([TEST_CASE_3])
- def test_nans(self, input_data, expected_value):
- result = compute_generalized_dice(**input_data)
- self.assertTrue(np.allclose(np.isnan(result.cpu().numpy()), expected_value))
-
- # Samplewise tests
- @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+ @parameterized.expand([TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_9])
def test_value_class(self, input_data, expected_value):
- # same test as for compute_meandice
- vals = {}
- vals["y_pred"] = input_data.pop("y_pred")
- vals["y"] = input_data.pop("y")
+ """
+ Test if the GeneralizedDiceScore class computes the correct values.
+ """
+ y_pred = input_data.pop("y_pred")
+ y = input_data.pop("y")
generalized_dice_score = GeneralizedDiceScore(**input_data)
- generalized_dice_score(**vals)
- result = generalized_dice_score.aggregate(reduction="none")
+ generalized_dice_score(y_pred=y_pred, y=y)
+ result = generalized_dice_score.aggregate()
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
- # Aggregation tests
- @parameterized.expand([TEST_CASE_4, TEST_CASE_5])
- def test_nans_class(self, params, input_data, expected_value):
- generalized_dice_score = GeneralizedDiceScore(**params)
- generalized_dice_score(**input_data)
- result = generalized_dice_score.aggregate()
+ @parameterized.expand([TEST_CASE_10])
+ def test_values_compare(self, input_data, expected_value):
+ """
+ Compare the results of the compute_generalized_dice function and the GeneralizedDiceScore class.
+ """
+ result = compute_generalized_dice(**input_data)
np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+ y_pred = input_data.pop("y_pred")
+ y = input_data.pop("y")
+ generalized_dice_score = GeneralizedDiceScore(**input_data, reduction="mean_batch")
+ generalized_dice_score(y_pred=y_pred, y=y)
+ result_class_mean = generalized_dice_score.aggregate()
+ np.testing.assert_allclose(result_class_mean.cpu().numpy(), np.mean(expected_value, axis=0), atol=1e-4)
+
+ @parameterized.expand([TEST_CASE_11])
+ def test_values_compare_sum_over_classes(self, input_data, expected_value):
+ """
+ Compare the results of the compute_generalized_dice function and the GeneralizedDiceScore class when summing over classes.
+ """
+ result = compute_generalized_dice(**input_data)
+ np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4)
+
+ y_pred = input_data.pop("y_pred")
+ y = input_data.pop("y")
+ input_data.pop("sum_over_classes")
+ generalized_dice_score = GeneralizedDiceScore(**input_data, reduction="mean_channel")
+ generalized_dice_score(y_pred=y_pred, y=y)
+ result_class_mean = generalized_dice_score.aggregate()
+ np.testing.assert_allclose(result_class_mean.cpu().numpy(), np.mean(expected_value, axis=1), atol=1e-4)
+
if __name__ == "__main__":
unittest.main()
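
For reference, a hedged usage sketch of the class-based metric exercised above; with all-ones inputs the per-class aggregate is 1.0, matching TEST_CASE_5:

    import torch

    from monai.metrics import GeneralizedDiceScore

    metric = GeneralizedDiceScore(include_background=True, reduction="mean_batch")
    metric(y_pred=torch.ones(2, 2, 3, 3), y=torch.ones(2, 2, 3, 3))
    print(metric.aggregate())  # one value per class, here tensor([1., 1.])
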
diff --git a/tests/test_compute_ho_ver_maps.py b/tests/test_compute_ho_ver_maps.py
index bbd5230f04..6e46cf2b1e 100644
--- a/tests/test_compute_ho_ver_maps.py
+++ b/tests/test_compute_ho_ver_maps.py
@@ -67,8 +67,8 @@ class ComputeHoVerMapsTests(unittest.TestCase):
def test_horizontal_certical_maps(self, in_type, arguments, mask, hv_mask):
input_image = in_type(mask)
result = ComputeHoVerMaps(**arguments)(input_image)
- self.assertTrue(isinstance(result, torch.Tensor))
- self.assertTrue(str(result.dtype).split(".")[1] == arguments.get("dtype", "float32"))
+ self.assertIsInstance(result, torch.Tensor)
+ self.assertEqual(str(result.dtype).split(".")[1], arguments.get("dtype", "float32"))
assert_allclose(result, hv_mask, type_test="tensor")
diff --git a/tests/test_compute_ho_ver_maps_d.py b/tests/test_compute_ho_ver_maps_d.py
index 7b5ac0d9d7..0734e2e731 100644
--- a/tests/test_compute_ho_ver_maps_d.py
+++ b/tests/test_compute_ho_ver_maps_d.py
@@ -71,8 +71,8 @@ def test_horizontal_certical_maps(self, in_type, arguments, mask, hv_mask):
for k in mask.keys():
input_image[k] = in_type(mask[k])
result = ComputeHoVerMapsd(keys="mask", **arguments)(input_image)[hv_key]
- self.assertTrue(isinstance(result, torch.Tensor))
- self.assertTrue(str(result.dtype).split(".")[1] == arguments.get("dtype", "float32"))
+ self.assertIsInstance(result, torch.Tensor)
+ self.assertEqual(str(result.dtype).split(".")[1], arguments.get("dtype", "float32"))
assert_allclose(result, hv_mask[hv_key], type_test="tensor")
diff --git a/tests/test_compute_regression_metrics.py b/tests/test_compute_regression_metrics.py
index a8b7f03e47..c407ab6ba6 100644
--- a/tests/test_compute_regression_metrics.py
+++ b/tests/test_compute_regression_metrics.py
@@ -70,22 +70,24 @@ def test_shape_reduction(self):
mt = mt_fn(reduction="mean")
mt(in_tensor, in_tensor)
out_tensor = mt.aggregate()
- self.assertTrue(len(out_tensor.shape) == 1)
+ self.assertEqual(len(out_tensor.shape), 1)
mt = mt_fn(reduction="sum")
mt(in_tensor, in_tensor)
out_tensor = mt.aggregate()
- self.assertTrue(len(out_tensor.shape) == 0)
+ self.assertEqual(len(out_tensor.shape), 0)
mt = mt_fn(reduction="sum") # test reduction arg overriding
mt(in_tensor, in_tensor)
out_tensor = mt.aggregate(reduction="mean_channel")
- self.assertTrue(len(out_tensor.shape) == 1 and out_tensor.shape[0] == batch)
+ self.assertEqual(len(out_tensor.shape), 1)
+ self.assertEqual(out_tensor.shape[0], batch)
mt = mt_fn(reduction="sum_channel")
mt(in_tensor, in_tensor)
out_tensor = mt.aggregate()
- self.assertTrue(len(out_tensor.shape) == 1 and out_tensor.shape[0] == batch)
+ self.assertEqual(len(out_tensor.shape), 1)
+ self.assertEqual(out_tensor.shape[0], batch)
def test_compare_numpy(self):
set_determinism(seed=123)
diff --git a/tests/test_concat_itemsd.py b/tests/test_concat_itemsd.py
index 64c5d6e255..564ddf5c1f 100644
--- a/tests/test_concat_itemsd.py
+++ b/tests/test_concat_itemsd.py
@@ -30,7 +30,7 @@ def test_tensor_values(self):
"img2": torch.tensor([[0, 1], [1, 2]], device=device),
}
result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data)
- self.assertTrue("cat_img" in result)
+ self.assertIn("cat_img", result)
result["cat_img"] += 1
assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device))
assert_allclose(result["cat_img"], torch.tensor([[1, 2], [2, 3], [1, 2], [2, 3]], device=device))
@@ -42,8 +42,8 @@ def test_metatensor_values(self):
"img2": MetaTensor([[0, 1], [1, 2]], device=device),
}
result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data)
- self.assertTrue("cat_img" in result)
- self.assertTrue(isinstance(result["cat_img"], MetaTensor))
+ self.assertIn("cat_img", result)
+ self.assertIsInstance(result["cat_img"], MetaTensor)
self.assertEqual(result["img1"].meta, result["cat_img"].meta)
result["cat_img"] += 1
assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device))
@@ -52,7 +52,7 @@ def test_metatensor_values(self):
def test_numpy_values(self):
input_data = {"img1": np.array([[0, 1], [1, 2]]), "img2": np.array([[0, 1], [1, 2]])}
result = ConcatItemsd(keys=["img1", "img2"], name="cat_img")(input_data)
- self.assertTrue("cat_img" in result)
+ self.assertIn("cat_img", result)
result["cat_img"] += 1
np.testing.assert_allclose(result["img1"], np.array([[0, 1], [1, 2]]))
np.testing.assert_allclose(result["cat_img"], np.array([[1, 2], [2, 3], [1, 2], [2, 3]]))
diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py
index cc890a0522..2b00c9f9d1 100644
--- a/tests/test_config_parser.py
+++ b/tests/test_config_parser.py
@@ -125,6 +125,22 @@ def __call__(self, a, b):
[0, 4],
]
+TEST_CASE_MERGE_JSON = ["""{"key1": [0], "key2": [0] }""", """{"key1": [1], "+key2": [4] }""", "json", [1], [0, 4]]
+
+TEST_CASE_MERGE_YAML = [
+ """
+ key1: 0
+ key2: [0]
+ """,
+ """
+ key1: 1
+ +key2: [4]
+ """,
+ "yaml",
+ 1,
+ [0, 4],
+]
+
class TestConfigParser(unittest.TestCase):
@@ -185,7 +201,7 @@ def test_function(self, config):
if id in ("compute", "cls_compute"):
parser[f"{id}#_mode_"] = "callable"
func = parser.get_parsed_content(id=id)
- self.assertTrue(id in parser.ref_resolver.resolved_content)
+ self.assertIn(id, parser.ref_resolver.resolved_content)
if id == "error_func":
with self.assertRaises(TypeError):
func(1, 2)
@@ -357,6 +373,22 @@ def test_parse_json_warn(self, config_string, extension, expected_unique_val, ex
self.assertEqual(parser.get_parsed_content("key#unique"), expected_unique_val)
self.assertIn(parser.get_parsed_content("key#duplicate"), expected_duplicate_vals)
+ @parameterized.expand([TEST_CASE_MERGE_JSON, TEST_CASE_MERGE_YAML])
+ @skipUnless(has_yaml, "Requires pyyaml")
+ def test_load_configs(
+ self, config_string, config_string2, extension, expected_overridden_val, expected_merged_vals
+ ):
+ with tempfile.TemporaryDirectory() as tempdir:
+ config_path1 = Path(tempdir) / f"config1.{extension}"
+ config_path2 = Path(tempdir) / f"config2.{extension}"
+ config_path1.write_text(config_string)
+ config_path2.write_text(config_string2)
+
+ parser = ConfigParser.load_config_files([config_path1, config_path2])
+
+ self.assertEqual(parser["key1"], expected_overridden_val)
+ self.assertEqual(parser["key2"], expected_merged_vals)
+
if __name__ == "__main__":
unittest.main()
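
The new merge cases document the ConfigParser.load_config_files semantics: later files override plain keys, while a "+"-prefixed key appends to the value accumulated so far. Restating TEST_CASE_MERGE_JSON as a usage sketch (file contents as in the test):

    from monai.bundle import ConfigParser

    # config1.json: {"key1": [0], "key2": [0]}
    # config2.json: {"key1": [1], "+key2": [4]}
    parser = ConfigParser.load_config_files(["config1.json", "config2.json"])
    print(parser["key1"])  # [1]    -- the later file overrides "key1"
    print(parser["key2"])  # [0, 4] -- "+key2" appends to the earlier list
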
diff --git a/tests/test_controlnet.py b/tests/test_controlnet.py
new file mode 100644
index 0000000000..4746c7ce22
--- /dev/null
+++ b/tests/test_controlnet.py
@@ -0,0 +1,215 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+import tempfile
+import unittest
+from unittest import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.apps import download_url
+from monai.networks import eval_mode
+from monai.networks.nets.controlnet import ControlNet
+from monai.utils import optional_import
+from tests.utils import skip_if_downloading_fails, testing_data_config
+
+_, has_einops = optional_import("einops")
+UNCOND_CASES_2D = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ },
+ (1, 8, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ "resblock_updown": True,
+ },
+ (1, 8, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (4, 4, 4),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 4,
+ },
+ (1, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ "resblock_updown": True,
+ },
+ (1, 8, 4, 4),
+ ],
+]
+
+UNCOND_CASES_3D = [
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ },
+ (1, 8, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (4, 4, 4),
+ "num_head_channels": 4,
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 4,
+ "resblock_updown": True,
+ },
+ (1, 4, 4, 4, 4),
+ ],
+]
+
+COND_CASES_2D = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ },
+ (1, 8, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ "resblock_updown": True,
+ },
+ (1, 8, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ "upcast_attention": True,
+ },
+ (1, 8, 4, 4),
+ ],
+]
+
+
+class TestControlNet(unittest.TestCase):
+ @parameterized.expand(UNCOND_CASES_2D + UNCOND_CASES_3D)
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_unconditioned_models(self, input_param, expected_output_shape):
+ input_param["conditioning_embedding_in_channels"] = input_param["in_channels"]
+ input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],)
+ net = ControlNet(**input_param)
+ with eval_mode(net):
+ x = torch.rand((1, 1) + (16,) * input_param["spatial_dims"])
+ timesteps = torch.randint(0, 1000, (1,)).long()
+ controlnet_cond = torch.rand((1, 1) + (16,) * input_param["spatial_dims"])
+ result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond)
+ self.assertEqual(len(result[0]), 2 * len(input_param["channels"]))
+ self.assertEqual(result[1].shape, expected_output_shape)
+
+ @parameterized.expand(COND_CASES_2D)
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_conditioned_models(self, input_param, expected_output_shape):
+ input_param["conditioning_embedding_in_channels"] = input_param["in_channels"]
+ input_param["conditioning_embedding_num_channels"] = (input_param["channels"][0],)
+ net = ControlNet(**input_param)
+ with eval_mode(net):
+ x = torch.rand((1, 1) + (16,) * input_param["spatial_dims"])
+ timesteps = torch.randint(0, 1000, (1,)).long()
+ controlnet_cond = torch.rand((1, 1) + (16,) * input_param["spatial_dims"])
+ result = net.forward(x, timesteps=timesteps, controlnet_cond=controlnet_cond, context=torch.rand((1, 1, 3)))
+ self.assertEqual(len(result[0]), 2 * len(input_param["channels"]))
+ self.assertEqual(result[1].shape, expected_output_shape)
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_compatibility_with_monai_generative(self):
+ # test loading weights from a model saved in MONAI Generative, version 0.2.3
+ with skip_if_downloading_fails():
+ net = ControlNet(
+ spatial_dims=2,
+ in_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ norm_num_groups=8,
+ with_conditioning=True,
+ transformer_num_layers=1,
+ cross_attention_dim=3,
+ resblock_updown=True,
+ )
+
+ tmpdir = tempfile.mkdtemp()
+ key = "controlnet_monai_generative_weights"
+ url = testing_data_config("models", key, "url")
+ hash_type = testing_data_config("models", key, "hash_type")
+ hash_val = testing_data_config("models", key, "hash_val")
+ filename = "controlnet_monai_generative_weights.pt"
+
+ weight_path = os.path.join(tmpdir, filename)
+ download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)
+
+ net.load_old_state_dict(torch.load(weight_path), verbose=False)
+
+
+if __name__ == "__main__":
+ unittest.main()
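
The ControlNet tests document the forward contract: given an input, timesteps, and a controlnet_cond of matching spatial size, forward returns a list of down-block residual samples (the tests check len == 2 * len(channels)) plus a mid-block sample. A hedged usage sketch based on the first unconditioned 2D case (requires einops):

    import torch

    from monai.networks.nets.controlnet import ControlNet

    net = ControlNet(
        spatial_dims=2,
        in_channels=1,
        num_res_blocks=1,
        channels=(8, 8, 8),
        attention_levels=(False, False, False),
        norm_num_groups=8,
        conditioning_embedding_in_channels=1,
        conditioning_embedding_num_channels=(8,),
    )
    x = torch.rand(1, 1, 16, 16)
    timesteps = torch.randint(0, 1000, (1,)).long()
    cond = torch.rand(1, 1, 16, 16)
    result = net(x, timesteps=timesteps, controlnet_cond=cond)
    down_samples, mid_sample = result[0], result[1]
    # per the test above: len(down_samples) == 6 and mid_sample.shape == (1, 8, 4, 4)
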
diff --git a/tests/test_controlnet_inferers.py b/tests/test_controlnet_inferers.py
new file mode 100644
index 0000000000..e3b0aeb5a2
--- /dev/null
+++ b/tests/test_controlnet_inferers.py
@@ -0,0 +1,1310 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from unittest import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.inferers import ControlNetDiffusionInferer, ControlNetLatentDiffusionInferer
+from monai.networks.nets import (
+ VQVAE,
+ AutoencoderKL,
+ ControlNet,
+ DiffusionModelUNet,
+ SPADEAutoencoderKL,
+ SPADEDiffusionModelUNet,
+)
+from monai.networks.schedulers import DDIMScheduler, DDPMScheduler
+from monai.utils import optional_import
+
+_, has_scipy = optional_import("scipy")
+_, has_einops = optional_import("einops")
+
+
+CNDM_TEST_CASES = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": [8],
+ "norm_num_groups": 8,
+ "attention_levels": [True],
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ },
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "channels": [8],
+ "attention_levels": [True],
+ "norm_num_groups": 8,
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ "conditioning_embedding_num_channels": [16],
+ "conditioning_embedding_in_channels": 1,
+ },
+ (2, 1, 8, 8),
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": [8],
+ "norm_num_groups": 8,
+ "attention_levels": [True],
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ },
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "channels": [8],
+ "attention_levels": [True],
+ "num_res_blocks": 1,
+ "norm_num_groups": 8,
+ "num_head_channels": 8,
+ "conditioning_embedding_num_channels": [16],
+ "conditioning_embedding_in_channels": 1,
+ },
+ (2, 1, 8, 8, 8),
+ ],
+]
+LATENT_CNDM_TEST_CASES = [
+ [
+ "AutoencoderKL",
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "latent_channels": 3,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ "norm_num_groups": 4,
+ },
+ "DiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [4, 4],
+ "norm_num_groups": 4,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 4,
+ },
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "channels": [4, 4],
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ "num_head_channels": 4,
+ "conditioning_embedding_num_channels": [16],
+ "conditioning_embedding_in_channels": 1,
+ },
+ (1, 1, 8, 8),
+ (1, 3, 4, 4),
+ ],
+ [
+ "VQVAE",
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": [4, 4],
+ "num_res_layers": 1,
+ "num_res_channels": [4, 4],
+ "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
+ "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
+ "num_embeddings": 16,
+ "embedding_dim": 3,
+ },
+ "DiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [8, 8],
+ "norm_num_groups": 8,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ },
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "channels": [8, 8],
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "norm_num_groups": 8,
+ "num_head_channels": 8,
+ "conditioning_embedding_num_channels": [16],
+ "conditioning_embedding_in_channels": 1,
+ },
+ (1, 1, 16, 16),
+ (1, 3, 4, 4),
+ ],
+ [
+ "VQVAE",
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": [4, 4],
+ "num_res_layers": 1,
+ "num_res_channels": [4, 4],
+ "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
+ "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
+ "num_embeddings": 16,
+ "embedding_dim": 3,
+ },
+ "DiffusionModelUNet",
+ {
+ "spatial_dims": 3,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [8, 8],
+ "norm_num_groups": 8,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ },
+ {
+ "spatial_dims": 3,
+ "in_channels": 3,
+ "channels": [8, 8],
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "norm_num_groups": 8,
+ "num_head_channels": 8,
+ "conditioning_embedding_num_channels": [16],
+ "conditioning_embedding_in_channels": 1,
+ },
+ (1, 1, 16, 16, 16),
+ (1, 3, 4, 4, 4),
+ ],
+]
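+# Same structure as LATENT_CNDM_TEST_CASES, but with image/latent shapes that differ from the
+# autoencoder's native downsampling, exercising the inferer's latent-shape arguments.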
+LATENT_CNDM_TEST_CASES_DIFF_SHAPES = [
+ [
+ "AutoencoderKL",
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "latent_channels": 3,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ "norm_num_groups": 4,
+ },
+ "DiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [4, 4],
+ "norm_num_groups": 4,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 4,
+ },
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "channels": [4, 4],
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ "num_head_channels": 4,
+ "conditioning_embedding_num_channels": [16],
+ "conditioning_embedding_in_channels": 1,
+ },
+ (1, 1, 12, 12),
+ (1, 3, 8, 8),
+ ],
+ [
+ "VQVAE",
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": [4, 4],
+ "num_res_layers": 1,
+ "num_res_channels": [4, 4],
+ "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
+ "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
+ "num_embeddings": 16,
+ "embedding_dim": 3,
+ },
+ "DiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [8, 8],
+ "norm_num_groups": 8,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ },
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "channels": [8, 8],
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "norm_num_groups": 8,
+ "num_head_channels": 8,
+ "conditioning_embedding_num_channels": [16],
+ "conditioning_embedding_in_channels": 1,
+ },
+ (1, 1, 12, 12),
+ (1, 3, 8, 8),
+ ],
+ [
+ "VQVAE",
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": [4, 4],
+ "num_res_layers": 1,
+ "num_res_channels": [4, 4],
+ "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
+ "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
+ "num_embeddings": 16,
+ "embedding_dim": 3,
+ },
+ "DiffusionModelUNet",
+ {
+ "spatial_dims": 3,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [8, 8],
+ "norm_num_groups": 8,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ },
+ {
+ "spatial_dims": 3,
+ "in_channels": 3,
+ "channels": [8, 8],
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "norm_num_groups": 8,
+ "num_head_channels": 8,
+ "conditioning_embedding_num_channels": [16],
+ "conditioning_embedding_in_channels": 1,
+ },
+ (1, 1, 12, 12, 12),
+ (1, 3, 8, 8, 8),
+ ],
+ [
+ "SPADEAutoencoderKL",
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "latent_channels": 3,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ "norm_num_groups": 4,
+ },
+ "DiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [4, 4],
+ "norm_num_groups": 4,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 4,
+ },
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "channels": [4, 4],
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ "num_head_channels": 4,
+ "conditioning_embedding_num_channels": [16],
+ "conditioning_embedding_in_channels": 1,
+ },
+ (1, 1, 8, 8),
+ (1, 3, 4, 4),
+ ],
+ [
+ "AutoencoderKL",
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "latent_channels": 3,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ "norm_num_groups": 4,
+ },
+ "SPADEDiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [4, 4],
+ "norm_num_groups": 4,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 4,
+ },
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "channels": [4, 4],
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ "num_head_channels": 4,
+ "conditioning_embedding_num_channels": [16],
+ "conditioning_embedding_in_channels": 1,
+ },
+ (1, 1, 8, 8),
+ (1, 3, 4, 4),
+ ],
+ [
+ "SPADEAutoencoderKL",
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "latent_channels": 3,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ "norm_num_groups": 4,
+ },
+ "SPADEDiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [4, 4],
+ "norm_num_groups": 4,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 4,
+ },
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "channels": [4, 4],
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ "num_head_channels": 4,
+ "conditioning_embedding_num_channels": [16],
+ "conditioning_embedding_in_channels": 1,
+ },
+ (1, 1, 8, 8),
+ (1, 3, 4, 4),
+ ],
+]
+
+
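+# Tests for ControlNetDiffusionInferer operating directly in image space.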
+class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
+ @parameterized.expand(CNDM_TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_call(self, model_params, controlnet_params, input_shape):
+ model = DiffusionModelUNet(**model_params)
+ controlnet = ControlNet(**controlnet_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ controlnet.to(device)
+ controlnet.eval()
+ input = torch.randn(input_shape).to(device)
+ mask = torch.randn(input_shape).to(device)
+ noise = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = ControlNetDiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
+ timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+ sample = inferer(
+ inputs=input, noise=noise, diffusion_model=model, controlnet=controlnet, timesteps=timesteps, cn_cond=mask
+ )
+ self.assertEqual(sample.shape, input_shape)
+
+ @parameterized.expand(CNDM_TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_sample_intermediates(self, model_params, controlnet_params, input_shape):
+ model = DiffusionModelUNet(**model_params)
+ controlnet = ControlNet(**controlnet_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ controlnet.to(device)
+ controlnet.eval()
+ noise = torch.randn(input_shape).to(device)
+ mask = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = ControlNetDiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
+ sample, intermediates = inferer.sample(
+ input_noise=noise,
+ diffusion_model=model,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ cn_cond=mask,
+ save_intermediates=True,
+ intermediate_steps=1,
+ )
+ self.assertEqual(len(intermediates), 10)
+
+ @parameterized.expand(CNDM_TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_ddpm_sampler(self, model_params, controlnet_params, input_shape):
+ model = DiffusionModelUNet(**model_params)
+ controlnet = ControlNet(**controlnet_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ controlnet.to(device)
+ controlnet.eval()
+ mask = torch.randn(input_shape).to(device)
+ noise = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=1000)
+ inferer = ControlNetDiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
+ sample, intermediates = inferer.sample(
+ input_noise=noise,
+ diffusion_model=model,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ cn_cond=mask,
+ save_intermediates=True,
+ intermediate_steps=1,
+ )
+ self.assertEqual(len(intermediates), 10)
+
+ @parameterized.expand(CNDM_TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_ddim_sampler(self, model_params, controlnet_params, input_shape):
+ model = DiffusionModelUNet(**model_params)
+ controlnet = ControlNet(**controlnet_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ controlnet.to(device)
+ controlnet.eval()
+ mask = torch.randn(input_shape).to(device)
+ noise = torch.randn(input_shape).to(device)
+ scheduler = DDIMScheduler(num_train_timesteps=1000)
+ inferer = ControlNetDiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
+ sample, intermediates = inferer.sample(
+ input_noise=noise,
+ diffusion_model=model,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ cn_cond=mask,
+ save_intermediates=True,
+ intermediate_steps=1,
+ )
+ self.assertEqual(len(intermediates), 10)
+
+ @parameterized.expand(CNDM_TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
+        # copy model_params so the shared test case dict is not modified
+        model_params = model_params.copy()
+        model_params["with_conditioning"] = True
+        model_params["cross_attention_dim"] = 3
+ model = DiffusionModelUNet(**model_params)
+ controlnet = ControlNet(**controlnet_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ controlnet.to(device)
+ controlnet.eval()
+ mask = torch.randn(input_shape).to(device)
+ noise = torch.randn(input_shape).to(device)
+ scheduler = DDIMScheduler(num_train_timesteps=1000)
+ inferer = ControlNetDiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
+ conditioning = torch.randn([input_shape[0], 1, 3]).to(device)
+ sample, intermediates = inferer.sample(
+ input_noise=noise,
+ diffusion_model=model,
+ controlnet=controlnet,
+ cn_cond=mask,
+ scheduler=scheduler,
+ save_intermediates=True,
+ intermediate_steps=1,
+ conditioning=conditioning,
+ )
+ self.assertEqual(len(intermediates), 10)
+
+ @parameterized.expand(CNDM_TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_get_likelihood(self, model_params, controlnet_params, input_shape):
+ model = DiffusionModelUNet(**model_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ controlnet = ControlNet(**controlnet_params)
+ controlnet.to(device)
+ controlnet.eval()
+ input = torch.randn(input_shape).to(device)
+ mask = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = ControlNetDiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
+ likelihood, intermediates = inferer.get_likelihood(
+ inputs=input,
+ diffusion_model=model,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ cn_cond=mask,
+ save_intermediates=True,
+ )
+ self.assertEqual(intermediates[0].shape, input.shape)
+ self.assertEqual(likelihood.shape[0], input.shape[0])
+
+ @unittest.skipUnless(has_scipy, "Requires scipy library.")
+ def test_normal_cdf(self):
+ from scipy.stats import norm
+
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = ControlNetDiffusionInferer(scheduler=scheduler)
+ x = torch.linspace(-10, 10, 20)
+ cdf_approx = inferer._approx_standard_normal_cdf(x)
+ cdf_true = norm.cdf(x)
+ torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5)
+
+ @parameterized.expand(CNDM_TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_sampler_conditioned_concat(self, model_params, controlnet_params, input_shape):
+ # copy the model_params dict to prevent from modifying test cases
+ model_params = model_params.copy()
+ n_concat_channel = 2
+ model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
+ model_params["cross_attention_dim"] = None
+ model_params["with_conditioning"] = False
+ model = DiffusionModelUNet(**model_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ controlnet = ControlNet(**controlnet_params)
+ controlnet.to(device)
+ controlnet.eval()
+ noise = torch.randn(input_shape).to(device)
+ mask = torch.randn(input_shape).to(device)
+ conditioning_shape = list(input_shape)
+ conditioning_shape[1] = n_concat_channel
+ conditioning = torch.randn(conditioning_shape).to(device)
+ scheduler = DDIMScheduler(num_train_timesteps=1000)
+ inferer = ControlNetDiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
+ sample, intermediates = inferer.sample(
+ input_noise=noise,
+ diffusion_model=model,
+ controlnet=controlnet,
+ cn_cond=mask,
+ scheduler=scheduler,
+ save_intermediates=True,
+ intermediate_steps=1,
+ conditioning=conditioning,
+ mode="concat",
+ )
+ self.assertEqual(len(intermediates), 10)
+
+
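+# Tests for ControlNetLatentDiffusionInferer, which pairs a first-stage autoencoder
+# (AutoencoderKL, VQVAE or SPADEAutoencoderKL) with a (SPADE)DiffusionModelUNet and a ControlNet.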
+class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):
+ @parameterized.expand(LATENT_CNDM_TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_prediction_shape(
+ self,
+ ae_model_type,
+ autoencoder_params,
+ dm_model_type,
+ stage_2_params,
+ controlnet_params,
+ input_shape,
+ latent_shape,
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+ controlnet = ControlNet(**controlnet_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ controlnet.to(device)
+ stage_1.eval()
+ stage_2.eval()
+ controlnet.eval()
+
+ input = torch.randn(input_shape).to(device)
+ mask = torch.randn(input_shape).to(device)
+ noise = torch.randn(latent_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+ timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+
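+        # SPADE-conditioned models take an extra segmentation input with label_nc channels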
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ prediction = inferer(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ controlnet=controlnet,
+ cn_cond=mask,
+ seg=input_seg,
+ noise=noise,
+ timesteps=timesteps,
+ )
+ else:
+ prediction = inferer(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ noise=noise,
+ timesteps=timesteps,
+ controlnet=controlnet,
+ cn_cond=mask,
+ )
+ self.assertEqual(prediction.shape, latent_shape)
+
+ @parameterized.expand(LATENT_CNDM_TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_sample_shape(
+ self,
+ ae_model_type,
+ autoencoder_params,
+ dm_model_type,
+ stage_2_params,
+ controlnet_params,
+ input_shape,
+ latent_shape,
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+ controlnet = ControlNet(**controlnet_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ controlnet.to(device)
+ stage_1.eval()
+ stage_2.eval()
+ controlnet.eval()
+
+ noise = torch.randn(latent_shape).to(device)
+ mask = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ sample = inferer.sample(
+ input_noise=noise,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ controlnet=controlnet,
+ cn_cond=mask,
+ scheduler=scheduler,
+ seg=input_seg,
+ )
+ else:
+ sample = inferer.sample(
+ input_noise=noise,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ cn_cond=mask,
+ )
+ self.assertEqual(sample.shape, input_shape)
+
+ @parameterized.expand(LATENT_CNDM_TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_sample_intermediates(
+ self,
+ ae_model_type,
+ autoencoder_params,
+ dm_model_type,
+ stage_2_params,
+ controlnet_params,
+ input_shape,
+ latent_shape,
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if ae_model_type == "SPADEAutoencoderKL":
+ stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+ controlnet = ControlNet(**controlnet_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ controlnet.to(device)
+ stage_1.eval()
+ stage_2.eval()
+ controlnet.eval()
+
+ noise = torch.randn(latent_shape).to(device)
+ mask = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ sample = inferer.sample(
+ input_noise=noise,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ seg=input_seg,
+ controlnet=controlnet,
+ cn_cond=mask,
+ )
+
+            # TODO: this branch does not request intermediates, so `intermediates` stays None and the
+            # assertions below would fail; none of the current parameterized cases reach this branch,
+            # which is why the test has always passed. Should the SPADE path above also save intermediates?
+            intermediates = None
+ else:
+ sample, intermediates = inferer.sample(
+ input_noise=noise,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ save_intermediates=True,
+ intermediate_steps=1,
+ controlnet=controlnet,
+ cn_cond=mask,
+ )
+
+ self.assertEqual(len(intermediates), 10)
+ self.assertEqual(intermediates[0].shape, input_shape)
+
+ @parameterized.expand(LATENT_CNDM_TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_get_likelihoods(
+ self,
+ ae_model_type,
+ autoencoder_params,
+ dm_model_type,
+ stage_2_params,
+ controlnet_params,
+ input_shape,
+ latent_shape,
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if ae_model_type == "SPADEAutoencoderKL":
+ stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+ controlnet = ControlNet(**controlnet_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ controlnet.to(device)
+ stage_1.eval()
+ stage_2.eval()
+ controlnet.eval()
+
+ input = torch.randn(input_shape).to(device)
+ mask = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ sample, intermediates = inferer.get_likelihood(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ controlnet=controlnet,
+ cn_cond=mask,
+ scheduler=scheduler,
+ save_intermediates=True,
+ seg=input_seg,
+ )
+ else:
+ sample, intermediates = inferer.get_likelihood(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ cn_cond=mask,
+ save_intermediates=True,
+ )
+ self.assertEqual(len(intermediates), 10)
+ self.assertEqual(intermediates[0].shape, latent_shape)
+
+ @parameterized.expand(LATENT_CNDM_TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_resample_likelihoods(
+ self,
+ ae_model_type,
+ autoencoder_params,
+ dm_model_type,
+ stage_2_params,
+ controlnet_params,
+ input_shape,
+ latent_shape,
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if ae_model_type == "SPADEAutoencoderKL":
+ stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+ controlnet = ControlNet(**controlnet_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ controlnet.to(device)
+ stage_1.eval()
+ stage_2.eval()
+ controlnet.eval()
+
+ input = torch.randn(input_shape).to(device)
+ mask = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ sample, intermediates = inferer.get_likelihood(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ cn_cond=mask,
+ save_intermediates=True,
+ resample_latent_likelihoods=True,
+ seg=input_seg,
+ )
+ else:
+ sample, intermediates = inferer.get_likelihood(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ cn_cond=mask,
+ save_intermediates=True,
+ resample_latent_likelihoods=True,
+ )
+ self.assertEqual(len(intermediates), 10)
+ self.assertEqual(intermediates[0].shape[2:], input_shape[2:])
+
+ @parameterized.expand(LATENT_CNDM_TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_prediction_shape_conditioned_concat(
+ self,
+ ae_model_type,
+ autoencoder_params,
+ dm_model_type,
+ stage_2_params,
+ controlnet_params,
+ input_shape,
+ latent_shape,
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if ae_model_type == "SPADEAutoencoderKL":
+ stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+ stage_2_params = stage_2_params.copy()
+ n_concat_channel = 3
+ stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+ controlnet = ControlNet(**controlnet_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ controlnet.to(device)
+ stage_1.eval()
+ stage_2.eval()
+ controlnet.eval()
+
+ input = torch.randn(input_shape).to(device)
+ mask = torch.randn(input_shape).to(device)
+ noise = torch.randn(latent_shape).to(device)
+ conditioning_shape = list(latent_shape)
+ conditioning_shape[1] = n_concat_channel
+ conditioning = torch.randn(conditioning_shape).to(device)
+
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ prediction = inferer(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ noise=noise,
+ controlnet=controlnet,
+ cn_cond=mask,
+ timesteps=timesteps,
+ condition=conditioning,
+ mode="concat",
+ seg=input_seg,
+ )
+ else:
+ prediction = inferer(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ noise=noise,
+ controlnet=controlnet,
+ cn_cond=mask,
+ timesteps=timesteps,
+ condition=conditioning,
+ mode="concat",
+ )
+ self.assertEqual(prediction.shape, latent_shape)
+
+ @parameterized.expand(LATENT_CNDM_TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_sample_shape_conditioned_concat(
+ self,
+ ae_model_type,
+ autoencoder_params,
+ dm_model_type,
+ stage_2_params,
+ controlnet_params,
+ input_shape,
+ latent_shape,
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if ae_model_type == "SPADEAutoencoderKL":
+ stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+ stage_2_params = stage_2_params.copy()
+ n_concat_channel = 3
+ stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+ controlnet = ControlNet(**controlnet_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ controlnet.to(device)
+ stage_1.eval()
+ stage_2.eval()
+ controlnet.eval()
+
+ noise = torch.randn(latent_shape).to(device)
+ mask = torch.randn(input_shape).to(device)
+ conditioning_shape = list(latent_shape)
+ conditioning_shape[1] = n_concat_channel
+ conditioning = torch.randn(conditioning_shape).to(device)
+
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ sample = inferer.sample(
+ input_noise=noise,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ controlnet=controlnet,
+ cn_cond=mask,
+ scheduler=scheduler,
+ conditioning=conditioning,
+ mode="concat",
+ seg=input_seg,
+ )
+ else:
+ sample = inferer.sample(
+ input_noise=noise,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ controlnet=controlnet,
+ cn_cond=mask,
+ scheduler=scheduler,
+ conditioning=conditioning,
+ mode="concat",
+ )
+ self.assertEqual(sample.shape, input_shape)
+
+ @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_sample_shape_different_latents(
+ self,
+ ae_model_type,
+ autoencoder_params,
+ dm_model_type,
+ stage_2_params,
+ controlnet_params,
+ input_shape,
+ latent_shape,
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if ae_model_type == "SPADEAutoencoderKL":
+ stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+ controlnet = ControlNet(**controlnet_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ controlnet.to(device)
+ stage_1.eval()
+ stage_2.eval()
+ controlnet.eval()
+
+ input = torch.randn(input_shape).to(device)
+ noise = torch.randn(latent_shape).to(device)
+ mask = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+        # infer the autoencoder's native latent spatial shape from its downsampling factor
+ autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
+ inferer = ControlNetLatentDiffusionInferer(
+ scheduler=scheduler,
+ scale_factor=1.0,
+ ldm_latent_shape=list(latent_shape[2:]),
+ autoencoder_latent_shape=autoencoder_latent_shape,
+ )
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ prediction = inferer(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ controlnet=controlnet,
+ cn_cond=mask,
+ noise=noise,
+ timesteps=timesteps,
+ seg=input_seg,
+ )
+ else:
+ prediction = inferer(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ noise=noise,
+ controlnet=controlnet,
+ cn_cond=mask,
+ timesteps=timesteps,
+ )
+ self.assertEqual(prediction.shape, latent_shape)
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_incompatible_spade_setup(self):
+ stage_1 = SPADEAutoencoderKL(
+ spatial_dims=2,
+ label_nc=6,
+ in_channels=1,
+ out_channels=1,
+ channels=(4, 4),
+ latent_channels=3,
+ attention_levels=[False, False],
+ num_res_blocks=1,
+ with_encoder_nonlocal_attn=False,
+ with_decoder_nonlocal_attn=False,
+ norm_num_groups=4,
+ )
+ stage_2 = SPADEDiffusionModelUNet(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=3,
+ out_channels=3,
+ channels=[4, 4],
+ norm_num_groups=4,
+ attention_levels=[False, False],
+ num_res_blocks=1,
+ num_head_channels=4,
+ )
+ controlnet = ControlNet(
+ spatial_dims=2,
+ in_channels=1,
+ channels=[4, 4],
+ norm_num_groups=4,
+ attention_levels=[False, False],
+ num_res_blocks=1,
+ num_head_channels=4,
+ conditioning_embedding_num_channels=[16],
+ )
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+        controlnet.to(device)
+ stage_1.eval()
+ stage_2.eval()
+ controlnet.eval()
+ noise = torch.randn((1, 3, 4, 4)).to(device)
+ mask = torch.randn((1, 1, 4, 4)).to(device)
+ input_seg = torch.randn((1, 3, 8, 8)).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ with self.assertRaises(ValueError):
+ _ = inferer.sample(
+ input_noise=noise,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ controlnet=controlnet,
+ cn_cond=mask,
+ seg=input_seg,
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_controlnet_maisi.py b/tests/test_controlnet_maisi.py
new file mode 100644
index 0000000000..bfdf25ec6e
--- /dev/null
+++ b/tests/test_controlnet_maisi.py
@@ -0,0 +1,171 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from unittest import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi
+from monai.networks import eval_mode
+from monai.utils import optional_import
+from tests.utils import SkipIfBeforePyTorchVersion
+
+_, has_einops = optional_import("einops")
+
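+# Each case: ControlNetMaisi kwargs, the expected number of down-block residuals,
+# and the expected shape of the second returned tensor (the mid-block residual).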
+TEST_CASES = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ "conditioning_embedding_in_channels": 1,
+ "conditioning_embedding_num_channels": (8, 8),
+ "use_checkpointing": False,
+ },
+ 6,
+ (1, 8, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ "conditioning_embedding_in_channels": 1,
+ "conditioning_embedding_num_channels": (8, 8),
+ "use_checkpointing": True,
+ },
+ 6,
+ (1, 8, 4, 4, 4),
+ ],
+]
+
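+# Conditional variants: the same configurations with cross-attention conditioning enabled
+# (with_conditioning=True and a cross_attention_dim).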
+TEST_CASES_CONDITIONAL = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ "conditioning_embedding_in_channels": 1,
+ "conditioning_embedding_num_channels": (8, 8),
+ "use_checkpointing": False,
+ "with_conditioning": True,
+ "cross_attention_dim": 2,
+ },
+ 6,
+ (1, 8, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ "conditioning_embedding_in_channels": 1,
+ "conditioning_embedding_num_channels": (8, 8),
+ "use_checkpointing": True,
+ "with_conditioning": True,
+ "cross_attention_dim": 2,
+ },
+ 6,
+ (1, 8, 4, 4, 4),
+ ],
+]
+
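+# Invalid parameter combinations paired with the exact ValueError message expected.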
+TEST_CASES_ERROR = [
+ [
+ {"spatial_dims": 2, "in_channels": 1, "with_conditioning": True, "cross_attention_dim": None},
+ "ControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) "
+ "to be specified when with_conditioning=True.",
+ ],
+ [
+ {"spatial_dims": 2, "in_channels": 1, "with_conditioning": False, "cross_attention_dim": 2},
+ "ControlNet expects with_conditioning=True when specifying the cross_attention_dim.",
+ ],
+ [
+ {"spatial_dims": 2, "in_channels": 1, "num_channels": (8, 16), "norm_num_groups": 16},
+ f"ControlNet expects all channels to be a multiple of norm_num_groups, but got"
+ f" channels={(8, 16)} and norm_num_groups={16}",
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "num_channels": (8, 16),
+ "attention_levels": (True,),
+ "norm_num_groups": 8,
+ },
+ f"ControlNet expects channels to have the same length as attention_levels, but got "
+ f"channels={(8, 16)} and attention_levels={(True,)}",
+ ],
+]
+
+
+@SkipIfBeforePyTorchVersion((2, 0))
+class TestControlNet(unittest.TestCase):
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_unconditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape):
+ net = ControlNetMaisi(**input_param)
+ with eval_mode(net):
+ x = torch.rand((1, 1, 16, 16)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 16, 16, 16))
+ timesteps = torch.randint(0, 1000, (1,)).long()
+ controlnet_cond = (
+ torch.rand((1, 1, 32, 32)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 32, 32, 32))
+ )
+ result = net.forward(x, timesteps, controlnet_cond)
+ self.assertEqual(len(result[0]), expected_num_down_blocks_residuals)
+ self.assertEqual(result[1].shape, expected_shape)
+
+ @parameterized.expand(TEST_CASES_CONDITIONAL)
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_conditioned_models(self, input_param, expected_num_down_blocks_residuals, expected_shape):
+ net = ControlNetMaisi(**input_param)
+ with eval_mode(net):
+ x = torch.rand((1, 1, 16, 16)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 16, 16, 16))
+ timesteps = torch.randint(0, 1000, (1,)).long()
+ controlnet_cond = (
+ torch.rand((1, 1, 32, 32)) if input_param["spatial_dims"] == 2 else torch.rand((1, 1, 32, 32, 32))
+ )
+ context = torch.randn((1, 1, input_param["cross_attention_dim"]))
+ result = net.forward(x, timesteps, controlnet_cond, context=context)
+ self.assertEqual(len(result[0]), expected_num_down_blocks_residuals)
+ self.assertEqual(result[1].shape, expected_shape)
+
+ @parameterized.expand(TEST_CASES_ERROR)
+ def test_error_input(self, input_param, expected_error):
+        with self.assertRaises(ValueError) as context:  # invalid parameter combinations should raise
+            _ = ControlNetMaisi(**input_param)
+        raised_error = context.exception
+        self.assertEqual(str(raised_error), expected_error)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_convert_box_points.py b/tests/test_convert_box_points.py
new file mode 100644
index 0000000000..5e3d7ee645
--- /dev/null
+++ b/tests/test_convert_box_points.py
@@ -0,0 +1,121 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.data.box_utils import convert_box_to_standard_mode
+from monai.transforms.spatial.array import ConvertBoxToPoints, ConvertPointsToBoxes
+from tests.utils import assert_allclose
+
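+# Each case: boxes, their box mode, and the expected corner-point representation
+# (4 corners per 2D box, 8 corners per 3D box).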
+TEST_CASE_POINTS_2D = [
+ [
+ torch.tensor([[10, 20, 30, 40], [50, 60, 70, 80]]),
+ "xyxy",
+ torch.tensor([[[10, 20], [30, 20], [30, 40], [10, 40]], [[50, 60], [70, 60], [70, 80], [50, 80]]]),
+ ],
+ [torch.tensor([[10, 20, 20, 20]]), "ccwh", torch.tensor([[[0, 10], [20, 10], [20, 30], [0, 30]]])],
+]
+TEST_CASE_POINTS_3D = [
+ [
+ torch.tensor([[10, 20, 30, 40, 50, 60], [70, 80, 90, 100, 110, 120]]),
+ "xyzxyz",
+ torch.tensor(
+ [
+ [
+ [10, 20, 30],
+ [40, 20, 30],
+ [40, 50, 30],
+ [10, 50, 30],
+ [10, 20, 60],
+ [40, 20, 60],
+ [40, 50, 60],
+ [10, 50, 60],
+ ],
+ [
+ [70, 80, 90],
+ [100, 80, 90],
+ [100, 110, 90],
+ [70, 110, 90],
+ [70, 80, 120],
+ [100, 80, 120],
+ [100, 110, 120],
+ [70, 110, 120],
+ ],
+ ]
+ ),
+ ],
+ [
+ torch.tensor([[10, 20, 30, 10, 10, 10]]),
+ "cccwhd",
+ torch.tensor(
+ [
+ [
+ [5, 15, 25],
+ [15, 15, 25],
+ [15, 25, 25],
+ [5, 25, 25],
+ [5, 15, 35],
+ [15, 15, 35],
+ [15, 25, 35],
+ [5, 25, 35],
+ ]
+ ]
+ ),
+ ],
+ [
+ torch.tensor([[10, 20, 30, 40, 50, 60]]),
+ "xxyyzz",
+ torch.tensor(
+ [
+ [
+ [10, 30, 50],
+ [20, 30, 50],
+ [20, 40, 50],
+ [10, 40, 50],
+ [10, 30, 60],
+ [20, 30, 60],
+ [20, 40, 60],
+ [10, 40, 60],
+ ]
+ ]
+ ),
+ ],
+]
+
+TEST_CASES = TEST_CASE_POINTS_2D + TEST_CASE_POINTS_3D
+
+
+class TestConvertBoxToPoints(unittest.TestCase):
+
+ @parameterized.expand(TEST_CASES)
+ def test_convert_box_to_points(self, boxes, mode, expected_points):
+ transform = ConvertBoxToPoints(mode=mode)
+ converted_points = transform(boxes)
+ assert_allclose(converted_points, expected_points, type_test=False)
+
+
+class TestConvertPointsToBoxes(unittest.TestCase):
+
+ @parameterized.expand(TEST_CASES)
+ def test_convert_box_to_points(self, boxes, mode, points):
+ transform = ConvertPointsToBoxes()
+ converted_boxes = transform(points)
+ expected_boxes = convert_box_to_standard_mode(boxes, mode)
+ assert_allclose(converted_boxes, expected_boxes, type_test=False)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py
index b95539f4b7..a27a05cf28 100644
--- a/tests/test_convert_data_type.py
+++ b/tests/test_convert_data_type.py
@@ -73,6 +73,7 @@
class TestTensor(torch.Tensor):
+ __test__ = False # indicate to pytest that this class is not intended for collection
pass
diff --git a/tests/test_convert_to_trt.py b/tests/test_convert_to_trt.py
index 5579539764..712d887c3b 100644
--- a/tests/test_convert_to_trt.py
+++ b/tests/test_convert_to_trt.py
@@ -20,7 +20,7 @@
from monai.networks import convert_to_trt
from monai.networks.nets import UNet
from monai.utils import optional_import
-from tests.utils import skip_if_no_cuda, skip_if_quick, skip_if_windows
+from tests.utils import SkipIfBeforeComputeCapabilityVersion, skip_if_no_cuda, skip_if_quick, skip_if_windows
_, has_torchtrt = optional_import(
"torch_tensorrt",
@@ -38,6 +38,7 @@
@skip_if_windows
@skip_if_no_cuda
@skip_if_quick
+@SkipIfBeforeComputeCapabilityVersion((7, 0))
class TestConvertToTRT(unittest.TestCase):
def setUp(self):
diff --git a/tests/test_crossattention.py b/tests/test_crossattention.py
new file mode 100644
index 0000000000..e034e42290
--- /dev/null
+++ b/tests/test_crossattention.py
@@ -0,0 +1,186 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from unittest import skipUnless
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.blocks.crossattention import CrossAttentionBlock
+from monai.networks.layers.factories import RelPosEmbedding
+from monai.utils import optional_import
+from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose
+
+einops, has_einops = optional_import("einops")
+
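+# Build a grid of CrossAttentionBlock configurations; relative positional embedding is
+# dropped when flash attention is enabled because the two are incompatible.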
+TEST_CASE_CABLOCK = []
+for dropout_rate in np.linspace(0, 1, 4):
+ for hidden_size in [360, 480, 600, 768]:
+ for num_heads in [4, 6, 8, 12]:
+ for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
+ for input_size in [(16, 32), (8, 8, 8)]:
+ for flash_attn in [True, False]:
+ test_case = [
+ {
+ "hidden_size": hidden_size,
+ "num_heads": num_heads,
+ "dropout_rate": dropout_rate,
+ "rel_pos_embedding": rel_pos_embedding if not flash_attn else None,
+ "input_size": input_size,
+ "use_flash_attention": flash_attn,
+ },
+ (2, 512, hidden_size),
+ (2, 512, hidden_size),
+ ]
+ TEST_CASE_CABLOCK.append(test_case)
+
+
+class TestResBlock(unittest.TestCase):
+
+ @parameterized.expand(TEST_CASE_CABLOCK)
+ @skipUnless(has_einops, "Requires einops")
+ @SkipIfBeforePyTorchVersion((2, 0))
+ def test_shape(self, input_param, input_shape, expected_shape):
+        # each parameterized case sets use_flash_attention explicitly, so both paths are covered
+ net = CrossAttentionBlock(**input_param)
+ with eval_mode(net):
+ result = net(torch.randn(input_shape), context=torch.randn(2, 512, input_param["hidden_size"]))
+ self.assertEqual(result.shape, expected_shape)
+
+ def test_ill_arg(self):
+ with self.assertRaises(ValueError):
+ CrossAttentionBlock(hidden_size=128, num_heads=12, dropout_rate=6.0)
+
+ with self.assertRaises(ValueError):
+ CrossAttentionBlock(hidden_size=620, num_heads=8, dropout_rate=0.4)
+
+ @SkipIfBeforePyTorchVersion((2, 0))
+ def test_save_attn_with_flash_attention(self):
+ with self.assertRaises(ValueError):
+ CrossAttentionBlock(
+ hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True
+ )
+
+ @SkipIfBeforePyTorchVersion((2, 0))
+ def test_rel_pos_embedding_with_flash_attention(self):
+ with self.assertRaises(ValueError):
+ CrossAttentionBlock(
+ hidden_size=128,
+ num_heads=3,
+ dropout_rate=0.1,
+ use_flash_attention=True,
+ save_attn=False,
+ rel_pos_embedding=RelPosEmbedding.DECOMPOSED,
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_attention_dim_not_multiple_of_heads(self):
+ with self.assertRaises(ValueError):
+ CrossAttentionBlock(hidden_size=128, num_heads=3, dropout_rate=0.1)
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_inner_dim_different(self):
+ CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30)
+
+ def test_causal_no_sequence_length(self):
+ with self.assertRaises(ValueError):
+ CrossAttentionBlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True)
+
+ @skipUnless(has_einops, "Requires einops")
+ @SkipIfBeforePyTorchVersion((2, 0))
+ def test_causal_flash_attention(self):
+ block = CrossAttentionBlock(
+ hidden_size=128,
+ num_heads=1,
+ dropout_rate=0.1,
+ causal=True,
+ sequence_length=16,
+ save_attn=False,
+ use_flash_attention=True,
+ )
+ input_shape = (1, 16, 128)
+ # Check it runs correctly
+ block(torch.randn(input_shape))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_causal(self):
+ block = CrossAttentionBlock(
+ hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True
+ )
+ input_shape = (1, 16, 128)
+ block(torch.randn(input_shape))
+ # check upper triangular part of the attention matrix is zero
+ assert torch.triu(block.att_mat, diagonal=1).sum() == 0
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_context_input(self):
+ block = CrossAttentionBlock(
+ hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12
+ )
+ input_shape = (1, 16, 128)
+ block(torch.randn(input_shape), context=torch.randn(1, 3, 12))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_context_wrong_input_size(self):
+ block = CrossAttentionBlock(
+ hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, context_input_size=12
+ )
+ input_shape = (1, 16, 128)
+ with self.assertRaises(RuntimeError):
+ block(torch.randn(input_shape), context=torch.randn(1, 3, 24))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_access_attn_matrix(self):
+        # shared configuration for both blocks
+        hidden_size = 128
+        num_heads = 2
+        dropout_rate = 0
+        input_shape = (2, 256, hidden_size)
+
+        # without save_attn the attention matrix should not be accessible
+        no_matrix_access_blk = CrossAttentionBlock(
+            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate
+        )
+        no_matrix_access_blk(torch.randn(input_shape))
+        assert isinstance(no_matrix_access_blk.att_mat, torch.Tensor)
+        # the stored attention matrix is empty (zero elements)
+        assert no_matrix_access_blk.att_mat.nelement() == 0
+
+        # with save_attn=True the attention matrix is stored and can be accessed
+        matrix_access_blk = CrossAttentionBlock(
+            hidden_size=hidden_size, num_heads=num_heads, dropout_rate=dropout_rate, save_attn=True
+        )
+        matrix_access_blk(torch.randn(input_shape))
+        assert matrix_access_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1])
+
+ @parameterized.expand([[True], [False]])
+ @skipUnless(has_einops, "Requires einops")
+ @SkipIfBeforePyTorchVersion((2, 0))
+ def test_flash_attention(self, causal):
+ input_param = {"hidden_size": 128, "num_heads": 1, "causal": causal, "sequence_length": 16 if causal else None}
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ block_w_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=True).to(device)
+ block_wo_flash_attention = CrossAttentionBlock(**input_param, use_flash_attention=False).to(device)
+ block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict())
+ test_data = torch.randn(1, 16, 128).to(device)
+
+ out_1 = block_w_flash_attention(test_data)
+ out_2 = block_wo_flash_attention(test_data)
+ assert_allclose(out_1, out_2, atol=1e-4)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_cucim_dict_transform.py b/tests/test_cucim_dict_transform.py
index d2dcc6aa5f..3c5703a34c 100644
--- a/tests/test_cucim_dict_transform.py
+++ b/tests/test_cucim_dict_transform.py
@@ -80,8 +80,8 @@ class TestCuCIMDict(unittest.TestCase):
def test_tramsforms_numpy_single(self, params, input, expected):
input = {"image": input}
output = CuCIMd(keys="image", **params)(input)["image"]
- self.assertTrue(output.dtype == expected.dtype)
- self.assertTrue(isinstance(output, np.ndarray))
+ self.assertEqual(output.dtype, expected.dtype)
+ self.assertIsInstance(output, np.ndarray)
cp.testing.assert_allclose(output, expected)
@parameterized.expand(
@@ -98,8 +98,8 @@ def test_tramsforms_numpy_batch(self, params, input, expected):
input = {"image": input[cp.newaxis, ...]}
expected = expected[cp.newaxis, ...]
output = CuCIMd(keys="image", **params)(input)["image"]
- self.assertTrue(output.dtype == expected.dtype)
- self.assertTrue(isinstance(output, np.ndarray))
+ self.assertEqual(output.dtype, expected.dtype)
+ self.assertIsInstance(output, np.ndarray)
cp.testing.assert_allclose(output, expected)
@parameterized.expand(
@@ -116,8 +116,8 @@ def test_tramsforms_cupy_single(self, params, input, expected):
input = {"image": cp.asarray(input)}
expected = cp.asarray(expected)
output = CuCIMd(keys="image", **params)(input)["image"]
- self.assertTrue(output.dtype == expected.dtype)
- self.assertTrue(isinstance(output, cp.ndarray))
+ self.assertEqual(output.dtype, expected.dtype)
+ self.assertIsInstance(output, cp.ndarray)
cp.testing.assert_allclose(output, expected)
@parameterized.expand(
@@ -134,8 +134,8 @@ def test_tramsforms_cupy_batch(self, params, input, expected):
input = {"image": cp.asarray(input)[cp.newaxis, ...]}
expected = cp.asarray(expected)[cp.newaxis, ...]
output = CuCIMd(keys="image", **params)(input)["image"]
- self.assertTrue(output.dtype == expected.dtype)
- self.assertTrue(isinstance(output, cp.ndarray))
+ self.assertEqual(output.dtype, expected.dtype)
+ self.assertIsInstance(output, cp.ndarray)
cp.testing.assert_allclose(output, expected)
diff --git a/tests/test_cucim_transform.py b/tests/test_cucim_transform.py
index 5f16c11589..162e16b52a 100644
--- a/tests/test_cucim_transform.py
+++ b/tests/test_cucim_transform.py
@@ -79,8 +79,8 @@ class TestCuCIM(unittest.TestCase):
)
def test_tramsforms_numpy_single(self, params, input, expected):
output = CuCIM(**params)(input)
- self.assertTrue(output.dtype == expected.dtype)
- self.assertTrue(isinstance(output, np.ndarray))
+ self.assertEqual(output.dtype, expected.dtype)
+ self.assertIsInstance(output, np.ndarray)
cp.testing.assert_allclose(output, expected)
@parameterized.expand(
@@ -97,8 +97,8 @@ def test_tramsforms_numpy_batch(self, params, input, expected):
input = input[cp.newaxis, ...]
expected = expected[cp.newaxis, ...]
output = CuCIM(**params)(input)
- self.assertTrue(output.dtype == expected.dtype)
- self.assertTrue(isinstance(output, np.ndarray))
+ self.assertEqual(output.dtype, expected.dtype)
+ self.assertIsInstance(output, np.ndarray)
cp.testing.assert_allclose(output, expected)
@parameterized.expand(
@@ -115,8 +115,8 @@ def test_tramsforms_cupy_single(self, params, input, expected):
input = cp.asarray(input)
expected = cp.asarray(expected)
output = CuCIM(**params)(input)
- self.assertTrue(output.dtype == expected.dtype)
- self.assertTrue(isinstance(output, cp.ndarray))
+ self.assertEqual(output.dtype, expected.dtype)
+ self.assertIsInstance(output, cp.ndarray)
cp.testing.assert_allclose(output, expected)
@parameterized.expand(
@@ -133,8 +133,8 @@ def test_tramsforms_cupy_batch(self, params, input, expected):
input = cp.asarray(input)[cp.newaxis, ...]
expected = cp.asarray(expected)[cp.newaxis, ...]
output = CuCIM(**params)(input)
- self.assertTrue(output.dtype == expected.dtype)
- self.assertTrue(isinstance(output, cp.ndarray))
+ self.assertEqual(output.dtype, expected.dtype)
+ self.assertIsInstance(output, cp.ndarray)
cp.testing.assert_allclose(output, expected)
diff --git a/tests/test_data_stats.py b/tests/test_data_stats.py
index 05453b0694..f9b424f8e1 100644
--- a/tests/test_data_stats.py
+++ b/tests/test_data_stats.py
@@ -23,6 +23,7 @@
import torch
from parameterized import parameterized
+from monai.data.meta_tensor import MetaTensor
from monai.transforms import DataStats
TEST_CASE_1 = [
@@ -130,20 +131,55 @@
]
TEST_CASE_8 = [
+ {
+ "prefix": "test data",
+ "data_type": True,
+ "data_shape": True,
+ "value_range": True,
+ "data_value": True,
+ "additional_info": np.mean,
+ "name": "DataStats",
+ },
np.array([[0, 1], [1, 2]]),
"test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n"
"Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n",
]
+TEST_CASE_9 = [
+ np.array([[0, 1], [1, 2]]),
+ "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n"
+ "Value: [[0 1]\n [1 2]]\n"
+ "Meta info: '(input is not a MetaTensor)'\n"
+ "Additional info: 1.0\n",
+]
+
+TEST_CASE_10 = [
+ MetaTensor(
+ torch.tensor([[0, 1], [1, 2]]),
+ affine=torch.as_tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], dtype=torch.float64),
+ meta={"some": "info"},
+ ),
+ "test data statistics:\nType: torch.int64\n"
+ "Shape: torch.Size([2, 2])\nValue range: (0, 2)\n"
+ "Value: tensor([[0, 1],\n [1, 2]])\n"
+ "Meta info: {'some': 'info', affine: tensor([[2., 0., 0., 0.],\n"
+ " [0., 2., 0., 0.],\n"
+ " [0., 0., 2., 0.],\n"
+ " [0., 0., 0., 1.]], dtype=torch.float64), space: RAS}\n"
+ "Additional info: 1.0\n",
+]
+
class TestDataStats(unittest.TestCase):
- @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
+ @parameterized.expand(
+ [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]
+ )
def test_value(self, input_param, input_data, expected_print):
transform = DataStats(**input_param)
_ = transform(input_data)
- @parameterized.expand([TEST_CASE_8])
+ @parameterized.expand([TEST_CASE_9, TEST_CASE_10])
def test_file(self, input_data, expected_print):
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, "test_data_stats.log")
@@ -158,6 +194,7 @@ def test_file(self, input_data, expected_print):
"data_shape": True,
"value_range": True,
"data_value": True,
+ "meta_info": True,
"additional_info": np.mean,
"name": name,
}
diff --git a/tests/test_data_statsd.py b/tests/test_data_statsd.py
index ef88300c10..a28a938c40 100644
--- a/tests/test_data_statsd.py
+++ b/tests/test_data_statsd.py
@@ -21,6 +21,7 @@
import torch
from parameterized import parameterized
+from monai.data.meta_tensor import MetaTensor
from monai.transforms import DataStatsd
TEST_CASE_1 = [
@@ -150,22 +151,70 @@
]
TEST_CASE_9 = [
+ {
+ "keys": "img",
+ "prefix": "test data",
+ "data_shape": True,
+ "value_range": True,
+ "data_value": True,
+ "meta_info": False,
+ "additional_info": np.mean,
+ "name": "DataStats",
+ },
{"img": np.array([[0, 1], [1, 2]])},
"test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n"
"Value: [[0 1]\n [1 2]]\nAdditional info: 1.0\n",
]
+TEST_CASE_10 = [
+ {"img": np.array([[0, 1], [1, 2]])},
+ "test data statistics:\nType: int64\nShape: (2, 2)\nValue range: (0, 2)\n"
+ "Value: [[0 1]\n [1 2]]\n"
+ "Meta info: '(input is not a MetaTensor)'\n"
+ "Additional info: 1.0\n",
+]
+
+TEST_CASE_11 = [
+ {
+ "img": (
+ MetaTensor(
+ torch.tensor([[0, 1], [1, 2]]),
+ affine=torch.as_tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], dtype=torch.float64),
+ meta={"some": "info"},
+ )
+ )
+ },
+ "test data statistics:\nType: torch.int64\n"
+ "Shape: torch.Size([2, 2])\nValue range: (0, 2)\n"
+ "Value: tensor([[0, 1],\n [1, 2]])\n"
+ "Meta info: {'some': 'info', affine: tensor([[2., 0., 0., 0.],\n"
+ " [0., 2., 0., 0.],\n"
+ " [0., 0., 2., 0.],\n"
+ " [0., 0., 0., 1.]], dtype=torch.float64), space: RAS}\n"
+ "Additional info: 1.0\n",
+]
+
class TestDataStatsd(unittest.TestCase):
@parameterized.expand(
- [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]
+ [
+ TEST_CASE_1,
+ TEST_CASE_2,
+ TEST_CASE_3,
+ TEST_CASE_4,
+ TEST_CASE_5,
+ TEST_CASE_6,
+ TEST_CASE_7,
+ TEST_CASE_8,
+ TEST_CASE_9,
+ ]
)
def test_value(self, input_param, input_data, expected_print):
transform = DataStatsd(**input_param)
_ = transform(input_data)
- @parameterized.expand([TEST_CASE_9])
+ @parameterized.expand([TEST_CASE_10, TEST_CASE_11])
def test_file(self, input_data, expected_print):
with tempfile.TemporaryDirectory() as tempdir:
filename = os.path.join(tempdir, "test_stats.log")
@@ -180,6 +229,7 @@ def test_file(self, input_data, expected_print):
"data_shape": True,
"value_range": True,
"data_value": True,
+ "meta_info": True,
"additional_info": np.mean,
"name": name,
}
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 1398009c63..0d37ae2efd 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -23,7 +23,7 @@
from parameterized import parameterized
from monai.data import Dataset
-from monai.transforms import Compose, LoadImaged, SimulateDelayd
+from monai.transforms import Compose, Lambda, LoadImage, LoadImaged, SimulateDelay, SimulateDelayd
from tests.test_compose import TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES, data_from_keys
TEST_CASE_1 = [(128, 128, 128)]
@@ -99,6 +99,72 @@ def test_dataset_lazy_on_call(self):
data[0, 0:2, 0:2] = 1
+class TestTupleDataset(unittest.TestCase):
+
+ @parameterized.expand([TEST_CASE_1])
+ def test_shape(self, expected_shape):
+ test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
+ with tempfile.TemporaryDirectory() as tempdir:
+ nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz"))
+ nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz"))
+ nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz"))
+ nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz"))
+ test_data = [
+ (os.path.join(tempdir, "test_image1.nii.gz"), os.path.join(tempdir, "test_label1.nii.gz")),
+ (os.path.join(tempdir, "test_image2.nii.gz"), os.path.join(tempdir, "test_label2.nii.gz")),
+ ]
+
+ test_transform = Compose([LoadImage(), SimulateDelay(delay_time=1e-5)])
+
+ # Here test_transform is applied element by element for the tuple.
+ dataset = Dataset(data=test_data, transform=test_transform)
+ data1 = dataset[0]
+ data2 = dataset[1]
+
+ # Output is a list/tuple
+ self.assertIsInstance(data1, (list, tuple))
+ self.assertIsInstance(data2, (list, tuple))
+
+ # Number of elements are 2
+ self.assertEqual(len(data1), 2)
+ self.assertEqual(len(data2), 2)
+
+ # Output shapes are as expected
+ self.assertTupleEqual(data1[0].shape, expected_shape)
+ self.assertTupleEqual(data1[1].shape, expected_shape)
+ self.assertTupleEqual(data2[0].shape, expected_shape)
+ self.assertTupleEqual(data2[1].shape, expected_shape)
+
+ # Here test_transform is applied to the tuple as a whole.
+ test_transform = Compose(
+ [
+ # LoadImage creates a channel-stacked image when applied to a tuple
+ LoadImage(),
+ # Get the channel-stacked image and the label
+ Lambda(func=lambda x: (x[0].permute(2, 1, 0), x[1])),
+ ],
+ map_items=False,
+ )
+
+ dataset = Dataset(data=test_data, transform=test_transform)
+ data1 = dataset[0]
+ data2 = dataset[1]
+
+ # Output is a list/tuple
+ self.assertIsInstance(data1, (list, tuple))
+ self.assertIsInstance(data2, (list, tuple))
+
+ # Number of elements are 2
+ self.assertEqual(len(data1), 2)
+ self.assertEqual(len(data2), 2)
+
+ # Output shapes are as expected
+ self.assertTupleEqual(data1[0].shape, expected_shape)
+ self.assertTupleEqual(data1[1].shape, expected_shape)
+ self.assertTupleEqual(data2[0].shape, expected_shape)
+ self.assertTupleEqual(data2[1].shape, expected_shape)
+
+
class TestDatsesetWithLazy(unittest.TestCase):
LOGGER_NAME = "a_logger_name"
diff --git a/tests/test_decathlondataset.py b/tests/test_decathlondataset.py
index d220cd9097..70a2a6c06c 100644
--- a/tests/test_decathlondataset.py
+++ b/tests/test_decathlondataset.py
@@ -80,7 +80,7 @@ def _test_dataset(dataset):
self.assertDictEqual(properties["labels"], {"0": "background", "1": "Anterior", "2": "Posterior"})
shutil.rmtree(os.path.join(testing_dir, "Task04_Hippocampus"))
- try:
+ with self.assertRaisesRegex(RuntimeError, "^Cannot find dataset directory"):
DecathlonDataset(
root_dir=testing_dir,
task="Task04_Hippocampus",
@@ -88,9 +88,6 @@ def _test_dataset(dataset):
section="validation",
download=False,
)
- except RuntimeError as e:
- print(str(e))
- self.assertTrue(str(e).startswith("Cannot find dataset directory"))
if __name__ == "__main__":
diff --git a/tests/test_deepgrow_transforms.py b/tests/test_deepgrow_transforms.py
index a491a8004b..091d00afcd 100644
--- a/tests/test_deepgrow_transforms.py
+++ b/tests/test_deepgrow_transforms.py
@@ -141,6 +141,21 @@
DATA_12 = {"image": np.arange(27).reshape(3, 3, 3), PostFix.meta("image"): {}, "guidance": [[0, 0, 0], [0, 1, 1], 1]}
+DATA_13 = {
+ "image": np.arange(64).reshape((1, 4, 4, 4)),
+ PostFix.meta("image"): {
+ "spatial_shape": [8, 8, 4],
+ "foreground_start_coord": np.array([1, 1, 1]),
+ "foreground_end_coord": np.array([3, 3, 3]),
+ "foreground_original_shape": (1, 4, 4, 4),
+ "foreground_cropped_shape": (1, 2, 2, 2),
+ "original_affine": np.array(
+ [[[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]]
+ ),
+ },
+ "pred": np.array([[[[10, 20], [30, 40]], [[50, 60], [70, 80]]]]),
+}
+
FIND_SLICE_TEST_CASE_1 = [{"label": "label", "sids": "sids"}, DATA_1, [0]]
FIND_SLICE_TEST_CASE_2 = [{"label": "label", "sids": "sids"}, DATA_2, [0, 1]]
@@ -329,6 +344,74 @@
RESTORE_LABEL_TEST_CASE_2 = [{"keys": ["pred"], "ref_image": "image", "mode": "nearest"}, DATA_11, RESULT]
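+# Case 3 (restore_cropping=False): values 1-8 each fill a 5x10x10 octant of the expected 10x20x20 volume.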
+RESTORE_LABEL_TEST_CASE_3_RESULT = np.zeros((10, 20, 20))
+RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 0:10, 0:10] = 1
+RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 0:10, 10:20] = 2
+RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 10:20, 0:10] = 3
+RESTORE_LABEL_TEST_CASE_3_RESULT[:5, 10:20, 10:20] = 4
+RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 0:10, 0:10] = 5
+RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 0:10, 10:20] = 6
+RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 10:20, 0:10] = 7
+RESTORE_LABEL_TEST_CASE_3_RESULT[5:10, 10:20, 10:20] = 8
+
+RESTORE_LABEL_TEST_CASE_3 = [
+ {"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_cropping": False},
+ DATA_11,
+ RESTORE_LABEL_TEST_CASE_3_RESULT,
+]
+
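+# Case 4 (restore_resizing=False): prediction values land on a 4x8x8 grid, each repeated 2x2
+# in-plane within rows/columns 2:6 of depth slices 1 and 2.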
+RESTORE_LABEL_TEST_CASE_4_RESULT = np.zeros((4, 8, 8))
+RESTORE_LABEL_TEST_CASE_4_RESULT[1, 2:6, 2:6] = np.array(
+ [[10.0, 10.0, 20.0, 20.0], [10.0, 10.0, 20.0, 20.0], [30.0, 30.0, 40.0, 40.0], [30.0, 30.0, 40.0, 40.0]]
+)
+RESTORE_LABEL_TEST_CASE_4_RESULT[2, 2:6, 2:6] = np.array(
+ [[50.0, 50.0, 60.0, 60.0], [50.0, 50.0, 60.0, 60.0], [70.0, 70.0, 80.0, 80.0], [70.0, 70.0, 80.0, 80.0]]
+)
+
+RESTORE_LABEL_TEST_CASE_4 = [
+ {"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_resizing": False},
+ DATA_13,
+ RESTORE_LABEL_TEST_CASE_4_RESULT,
+]
+
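+# Case 5 (restore_spacing=False): prediction values are placed directly on a 4x4x4 grid at
+# rows/columns 1:3 of depth slices 1 and 2, with no in-plane upsampling.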
+RESTORE_LABEL_TEST_CASE_5_RESULT = np.zeros((4, 4, 4))
+RESTORE_LABEL_TEST_CASE_5_RESULT[1, 1:3, 1:3] = np.array([[10.0, 20.0], [30.0, 40.0]])
+RESTORE_LABEL_TEST_CASE_5_RESULT[2, 1:3, 1:3] = np.array([[50.0, 60.0], [70.0, 80.0]])
+
+RESTORE_LABEL_TEST_CASE_5 = [
+ {"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_spacing": False},
+ DATA_13,
+ RESTORE_LABEL_TEST_CASE_5_RESULT,
+]
+
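+# Case 6 (restore_slicing=False): same layout as case 4 but with the leading dimension kept,
+# giving a 1x4x8x8 result.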
+RESTORE_LABEL_TEST_CASE_6_RESULT = np.zeros((1, 4, 8, 8))
+RESTORE_LABEL_TEST_CASE_6_RESULT[-1, 1, 2:6, 2:6] = np.array(
+ [[10.0, 10.0, 20.0, 20.0], [10.0, 10.0, 20.0, 20.0], [30.0, 30.0, 40.0, 40.0], [30.0, 30.0, 40.0, 40.0]]
+)
+RESTORE_LABEL_TEST_CASE_6_RESULT[-1, 2, 2:6, 2:6] = np.array(
+ [[50.0, 50.0, 60.0, 60.0], [50.0, 50.0, 60.0, 60.0], [70.0, 70.0, 80.0, 80.0], [70.0, 70.0, 80.0, 80.0]]
+)
+
+RESTORE_LABEL_TEST_CASE_6 = [
+ {"keys": ["pred"], "ref_image": "image", "mode": "nearest", "restore_slicing": False},
+ DATA_13,
+ RESTORE_LABEL_TEST_CASE_6_RESULT,
+]
+
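+# Case 7 disables every restore step, so the transform is expected to return the 2x2x2 prediction values untouched.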
+RESTORE_LABEL_TEST_CASE_7 = [
+ {
+ "keys": ["pred"],
+ "ref_image": "image",
+ "mode": "nearest",
+ "restore_resizing": False,
+ "restore_cropping": False,
+ "restore_spacing": False,
+ "restore_slicing": False,
+ },
+ DATA_11,
+ np.array([[[[1, 2], [3, 4]], [[5, 6], [7, 8]]]]),
+]
+
FETCH_2D_SLICE_TEST_CASE_1 = [
{"keys": ["image"], "guidance": "guidance"},
DATA_12,
@@ -445,7 +528,17 @@ def test_correct_results(self, arguments, input_data, expected_result):
class TestRestoreLabeld(unittest.TestCase):
- @parameterized.expand([RESTORE_LABEL_TEST_CASE_1, RESTORE_LABEL_TEST_CASE_2])
+ @parameterized.expand(
+ [
+ RESTORE_LABEL_TEST_CASE_1,
+ RESTORE_LABEL_TEST_CASE_2,
+ RESTORE_LABEL_TEST_CASE_3,
+ RESTORE_LABEL_TEST_CASE_4,
+ RESTORE_LABEL_TEST_CASE_5,
+ RESTORE_LABEL_TEST_CASE_6,
+ RESTORE_LABEL_TEST_CASE_7,
+ ]
+ )
def test_correct_results(self, arguments, input_data, expected_result):
result = RestoreLabeld(**arguments)(input_data)
np.testing.assert_allclose(result["pred"], expected_result)
diff --git a/tests/test_detect_envelope.py b/tests/test_detect_envelope.py
index e2efefeb77..f9c2b5ac53 100644
--- a/tests/test_detect_envelope.py
+++ b/tests/test_detect_envelope.py
@@ -147,7 +147,7 @@ def test_value_error(self, arguments, image, method):
elif method == "__call__":
self.assertRaises(ValueError, DetectEnvelope(**arguments), image)
else:
- raise ValueError("Expected raising method invalid. Should be __init__ or __call__.")
+ self.fail("Expected raising method invalid. Should be __init__ or __call__.")
@SkipIfModule("torch.fft")
diff --git a/tests/test_dice_ce_loss.py b/tests/test_dice_ce_loss.py
index 225618ed2c..97c7ae5050 100644
--- a/tests/test_dice_ce_loss.py
+++ b/tests/test_dice_ce_loss.py
@@ -93,10 +93,20 @@ def test_result(self, input_param, input_data, expected_val):
result = diceceloss(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
- # def test_ill_shape(self):
- # loss = DiceCELoss()
- # with self.assertRaisesRegex(ValueError, ""):
- # loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+ def test_ill_shape(self):
+ loss = DiceCELoss()
+ with self.assertRaises(AssertionError):
+ loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))
+
+ def test_ill_shape2(self):
+ loss = DiceCELoss()
+ with self.assertRaises(ValueError):
+ loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+
+ def test_ill_shape3(self):
+ loss = DiceCELoss()
+ with self.assertRaises(ValueError):
+ loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))
# def test_ill_reduction(self):
# with self.assertRaisesRegex(ValueError, ""):
diff --git a/tests/test_dice_focal_loss.py b/tests/test_dice_focal_loss.py
index 13899da003..f769aac69f 100644
--- a/tests/test_dice_focal_loss.py
+++ b/tests/test_dice_focal_loss.py
@@ -69,8 +69,18 @@ def test_result_no_onehot_no_bg(self, size, onehot):
def test_ill_shape(self):
loss = DiceFocalLoss()
- with self.assertRaisesRegex(ValueError, ""):
- loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+ with self.assertRaises(AssertionError):
+ loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))
+
+ def test_ill_shape2(self):
+ loss = DiceFocalLoss()
+ with self.assertRaises(ValueError):
+ loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+
+ def test_ill_shape3(self):
+ loss = DiceFocalLoss()
+ with self.assertRaises(ValueError):
+ loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))
def test_ill_lambda(self):
with self.assertRaisesRegex(ValueError, ""):
@@ -81,6 +91,35 @@ def test_script(self):
test_input = torch.ones(2, 1, 8, 8)
test_script_save(loss, test_input, test_input)
+ @parameterized.expand(
+ [
+ ("sum_None_0.5_0.25", "sum", None, 0.5, 0.25),
+ ("sum_weight_0.5_0.25", "sum", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
+ ("sum_weight_tuple_0.5_0.25", "sum", (3, 2.0, 1), 0.5, 0.25),
+ ("mean_None_0.5_0.25", "mean", None, 0.5, 0.25),
+ ("mean_weight_0.5_0.25", "mean", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
+ ("mean_weight_tuple_0.5_0.25", "mean", (3, 2.0, 1), 0.5, 0.25),
+ ("none_None_0.5_0.25", "none", None, 0.5, 0.25),
+ ("none_weight_0.5_0.25", "none", torch.tensor([1.0, 1.0, 2.0]), 0.5, 0.25),
+ ("none_weight_tuple_0.5_0.25", "none", (3, 2.0, 1), 0.5, 0.25),
+ ]
+ )
+ def test_with_alpha(self, name, reduction, weight, lambda_focal, alpha):
+ size = [3, 3, 5, 5]
+ label = torch.randint(low=0, high=2, size=size)
+ pred = torch.randn(size)
+
+ common_params = {"include_background": True, "to_onehot_y": False, "reduction": reduction, "weight": weight}
+
+ dice_focal = DiceFocalLoss(gamma=1.0, lambda_focal=lambda_focal, alpha=alpha, **common_params)
+ dice = DiceLoss(**common_params)
+ focal = FocalLoss(gamma=1.0, alpha=alpha, **common_params)
+
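+ # the combined loss should decompose into dice + lambda_focal * focal (lambda_dice left at its default)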
+ result = dice_focal(pred, label)
+ expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
+
+ np.testing.assert_allclose(result, expected_val, err_msg=f"Failed on case: {name}")
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_diffusion_inferer.py b/tests/test_diffusion_inferer.py
new file mode 100644
index 0000000000..7f37025d3c
--- /dev/null
+++ b/tests/test_diffusion_inferer.py
@@ -0,0 +1,236 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from unittest import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.inferers import DiffusionInferer
+from monai.networks.nets import DiffusionModelUNet
+from monai.networks.schedulers import DDIMScheduler, DDPMScheduler
+from monai.utils import optional_import
+
+_, has_scipy = optional_import("scipy")
+_, has_einops = optional_import("einops")
+
+TEST_CASES = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": [8],
+ "norm_num_groups": 8,
+ "attention_levels": [True],
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ },
+ (2, 1, 8, 8),
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": [8],
+ "norm_num_groups": 8,
+ "attention_levels": [True],
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ },
+ (2, 1, 8, 8, 8),
+ ],
+]
+
+
+class TestDiffusionSamplingInferer(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_call(self, model_params, input_shape):
+ model = DiffusionModelUNet(**model_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ input = torch.randn(input_shape).to(device)
+ noise = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = DiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
+ timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+ sample = inferer(inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps)
+ self.assertEqual(sample.shape, input_shape)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_sample_intermediates(self, model_params, input_shape):
+ model = DiffusionModelUNet(**model_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ noise = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = DiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
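+ # saving every step (intermediate_steps=1) over 10 inference steps should yield 10 intermediates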
+ sample, intermediates = inferer.sample(
+ input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
+ )
+ self.assertEqual(len(intermediates), 10)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_ddpm_sampler(self, model_params, input_shape):
+ model = DiffusionModelUNet(**model_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ noise = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=1000)
+ inferer = DiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
+ sample, intermediates = inferer.sample(
+ input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
+ )
+ self.assertEqual(len(intermediates), 10)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_ddim_sampler(self, model_params, input_shape):
+ model = DiffusionModelUNet(**model_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ noise = torch.randn(input_shape).to(device)
+ scheduler = DDIMScheduler(num_train_timesteps=1000)
+ inferer = DiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
+ sample, intermediates = inferer.sample(
+ input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=1
+ )
+ self.assertEqual(len(intermediates), 10)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_sampler_conditioned(self, model_params, input_shape):
+ model_params["with_conditioning"] = True
+ model_params["cross_attention_dim"] = 3
+ model = DiffusionModelUNet(**model_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ noise = torch.randn(input_shape).to(device)
+ scheduler = DDIMScheduler(num_train_timesteps=1000)
+ inferer = DiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
+ conditioning = torch.randn([input_shape[0], 1, 3]).to(device)
+ sample, intermediates = inferer.sample(
+ input_noise=noise,
+ diffusion_model=model,
+ scheduler=scheduler,
+ save_intermediates=True,
+ intermediate_steps=1,
+ conditioning=conditioning,
+ )
+ self.assertEqual(len(intermediates), 10)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_get_likelihood(self, model_params, input_shape):
+ model = DiffusionModelUNet(**model_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ input = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = DiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
+ likelihood, intermediates = inferer.get_likelihood(
+ inputs=input, diffusion_model=model, scheduler=scheduler, save_intermediates=True
+ )
+ self.assertEqual(intermediates[0].shape, input.shape)
+ self.assertEqual(likelihood.shape[0], input.shape[0])
+
+ @unittest.skipUnless(has_scipy, "Requires scipy library.")
+ def test_normal_cdf(self):
+ from scipy.stats import norm
+
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = DiffusionInferer(scheduler=scheduler)
+
+ x = torch.linspace(-10, 10, 20)
+ cdf_approx = inferer._approx_standard_normal_cdf(x)
+ cdf_true = norm.cdf(x)
+ torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_sampler_conditioned_concat(self, model_params, input_shape):
+ # copy model_params to avoid modifying the shared test case dict
+ model_params = model_params.copy()
+ n_concat_channel = 2
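+ # in "concat" mode the conditioning is stacked along the channel axis, so the UNet needs extra input channels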
+ model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
+ model_params["cross_attention_dim"] = None
+ model_params["with_conditioning"] = False
+ model = DiffusionModelUNet(**model_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ noise = torch.randn(input_shape).to(device)
+ conditioning_shape = list(input_shape)
+ conditioning_shape[1] = n_concat_channel
+ conditioning = torch.randn(conditioning_shape).to(device)
+ scheduler = DDIMScheduler(num_train_timesteps=1000)
+ inferer = DiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
+ sample, intermediates = inferer.sample(
+ input_noise=noise,
+ diffusion_model=model,
+ scheduler=scheduler,
+ save_intermediates=True,
+ intermediate_steps=1,
+ conditioning=conditioning,
+ mode="concat",
+ )
+ self.assertEqual(len(intermediates), 10)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_call_conditioned_concat(self, model_params, input_shape):
+ # copy model_params to avoid modifying the shared test case dict
+ model_params = model_params.copy()
+ n_concat_channel = 2
+ model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
+ model_params["cross_attention_dim"] = None
+ model_params["with_conditioning"] = False
+ model = DiffusionModelUNet(**model_params)
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+ model.eval()
+ input = torch.randn(input_shape).to(device)
+ noise = torch.randn(input_shape).to(device)
+ conditioning_shape = list(input_shape)
+ conditioning_shape[1] = n_concat_channel
+ conditioning = torch.randn(conditioning_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = DiffusionInferer(scheduler=scheduler)
+ scheduler.set_timesteps(num_inference_steps=10)
+ timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+ sample = inferer(
+ inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps, condition=conditioning, mode="concat"
+ )
+ self.assertEqual(sample.shape, input_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_diffusion_model_unet.py b/tests/test_diffusion_model_unet.py
new file mode 100644
index 0000000000..7f764d85de
--- /dev/null
+++ b/tests/test_diffusion_model_unet.py
@@ -0,0 +1,585 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+import tempfile
+import unittest
+from unittest import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.apps import download_url
+from monai.networks import eval_mode
+from monai.networks.nets import DiffusionModelUNet
+from monai.utils import optional_import
+from tests.utils import skip_if_downloading_fails, testing_data_config
+
+_, has_einops = optional_import("einops")
+
+UNCOND_CASES_2D = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": (1, 1, 2),
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ "resblock_updown": True,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ "resblock_updown": True,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, True, True),
+ "num_head_channels": (0, 2, 4),
+ "norm_num_groups": 8,
+ }
+ ],
+]
+
+UNCOND_CASES_3D = [
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ "resblock_updown": True,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ "resblock_updown": True,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": (0, 0, 4),
+ "norm_num_groups": 8,
+ }
+ ],
+]
+
+COND_CASES_2D = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ "resblock_updown": True,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ "upcast_attention": True,
+ }
+ ],
+]
+
+DROPOUT_OK = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ "dropout_cattn": 0.25,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ }
+ ],
+]
+
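+# dropout_cattn=3.0 is out of range and is expected to be rejected at construction time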
+DROPOUT_WRONG = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ "dropout_cattn": 3.0,
+ }
+ ]
+]
+
+
+class TestDiffusionModelUNet2D(unittest.TestCase):
+ @parameterized.expand(UNCOND_CASES_2D)
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_unconditioned_models(self, input_param):
+ net = DiffusionModelUNet(**input_param)
+ with eval_mode(net):
+ result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long())
+ self.assertEqual(result.shape, (1, 1, 16, 16))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_timestep_with_wrong_shape(self):
+ net = DiffusionModelUNet(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, False),
+ norm_num_groups=8,
+ )
+ with self.assertRaises(ValueError):
+ with eval_mode(net):
+ net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long())
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_with_different_in_channel_out_channel(self):
+ in_channels = 6
+ out_channels = 3
+ net = DiffusionModelUNet(
+ spatial_dims=2,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, False),
+ norm_num_groups=8,
+ )
+ with eval_mode(net):
+ result = net.forward(torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long())
+ self.assertEqual(result.shape, (1, out_channels, 16, 16))
+
+ def test_model_channels_not_multiple_of_norm_num_group(self):
+ with self.assertRaises(ValueError):
+ DiffusionModelUNet(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 12),
+ attention_levels=(False, False, False),
+ norm_num_groups=8,
+ )
+
+ def test_attention_levels_with_different_length_num_head_channels(self):
+ with self.assertRaises(ValueError):
+ DiffusionModelUNet(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, False),
+ num_head_channels=(0, 2),
+ norm_num_groups=8,
+ )
+
+ def test_num_res_blocks_with_different_length_channels(self):
+ with self.assertRaises(ValueError):
+ DiffusionModelUNet(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=(1, 1),
+ channels=(8, 8, 8),
+ attention_levels=(False, False, False),
+ norm_num_groups=8,
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_conditioned_models(self):
+ net = DiffusionModelUNet(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ with_conditioning=True,
+ transformer_num_layers=1,
+ cross_attention_dim=3,
+ norm_num_groups=8,
+ num_head_channels=8,
+ )
+ with eval_mode(net):
+ result = net.forward(
+ x=torch.rand((1, 1, 16, 32)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ context=torch.rand((1, 1, 3)),
+ )
+ self.assertEqual(result.shape, (1, 1, 16, 32))
+
+ def test_with_conditioning_cross_attention_dim_none(self):
+ with self.assertRaises(ValueError):
+ DiffusionModelUNet(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ with_conditioning=True,
+ transformer_num_layers=1,
+ cross_attention_dim=None,
+ norm_num_groups=8,
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_context_with_conditioning_none(self):
+ net = DiffusionModelUNet(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ with_conditioning=False,
+ transformer_num_layers=1,
+ norm_num_groups=8,
+ )
+
+ with self.assertRaises(ValueError):
+ with eval_mode(net):
+ net.forward(
+ x=torch.rand((1, 1, 16, 32)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ context=torch.rand((1, 1, 3)),
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_conditioned_models_class_conditioning(self):
+ net = DiffusionModelUNet(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ norm_num_groups=8,
+ num_head_channels=8,
+ num_class_embeds=2,
+ )
+ with eval_mode(net):
+ result = net.forward(
+ x=torch.rand((1, 1, 16, 32)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ class_labels=torch.randint(0, 2, (1,)).long(),
+ )
+ self.assertEqual(result.shape, (1, 1, 16, 32))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_conditioned_models_no_class_labels(self):
+ net = DiffusionModelUNet(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ norm_num_groups=8,
+ num_head_channels=8,
+ num_class_embeds=2,
+ )
+
+ with self.assertRaises(ValueError):
+ net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long())
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_model_channels_not_same_size_of_attention_levels(self):
+ with self.assertRaises(ValueError):
+ DiffusionModelUNet(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False),
+ norm_num_groups=8,
+ num_head_channels=8,
+ num_class_embeds=2,
+ )
+
+ @parameterized.expand(COND_CASES_2D)
+ @skipUnless(has_einops, "Requires einops")
+ def test_conditioned_2d_models_shape(self, input_param):
+ net = DiffusionModelUNet(**input_param)
+ with eval_mode(net):
+ result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3)))
+ self.assertEqual(result.shape, (1, 1, 16, 16))
+
+
+class TestDiffusionModelUNet3D(unittest.TestCase):
+ @parameterized.expand(UNCOND_CASES_3D)
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_unconditioned_models(self, input_param):
+ net = DiffusionModelUNet(**input_param)
+ with eval_mode(net):
+ result = net.forward(torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long())
+ self.assertEqual(result.shape, (1, 1, 16, 16, 16))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_with_different_in_channel_out_channel(self):
+ in_channels = 6
+ out_channels = 3
+ net = DiffusionModelUNet(
+ spatial_dims=3,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ norm_num_groups=4,
+ )
+ with eval_mode(net):
+ result = net.forward(torch.rand((1, in_channels, 16, 16, 16)), torch.randint(0, 1000, (1,)).long())
+ self.assertEqual(result.shape, (1, out_channels, 16, 16, 16))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_conditioned_models(self):
+ net = DiffusionModelUNet(
+ spatial_dims=3,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(16, 16, 16),
+ attention_levels=(False, False, True),
+ norm_num_groups=16,
+ with_conditioning=True,
+ transformer_num_layers=1,
+ cross_attention_dim=3,
+ )
+ with eval_mode(net):
+ result = net.forward(
+ x=torch.rand((1, 1, 16, 16, 16)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ context=torch.rand((1, 1, 3)),
+ )
+ self.assertEqual(result.shape, (1, 1, 16, 16, 16))
+
+ # Test dropout specification for cross-attention blocks
+ @parameterized.expand(DROPOUT_WRONG)
+ def test_wrong_dropout(self, input_param):
+ with self.assertRaises(ValueError):
+ _ = DiffusionModelUNet(**input_param)
+
+ @parameterized.expand(DROPOUT_OK)
+ @skipUnless(has_einops, "Requires einops")
+ def test_right_dropout(self, input_param):
+ _ = DiffusionModelUNet(**input_param)
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_compatibility_with_monai_generative(self):
+ # test loading weights from a model saved in MONAI Generative, version 0.2.3
+ with skip_if_downloading_fails():
+ net = DiffusionModelUNet(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ with_conditioning=True,
+ cross_attention_dim=3,
+ transformer_num_layers=1,
+ norm_num_groups=8,
+ )
+
+ tmpdir = tempfile.mkdtemp()
+ key = "diffusion_model_unet_monai_generative_weights"
+ url = testing_data_config("models", key, "url")
+ hash_type = testing_data_config("models", key, "hash_type")
+ hash_val = testing_data_config("models", key, "hash_val")
+ filename = "diffusion_model_unet_monai_generative_weights.pt"
+
+ weight_path = os.path.join(tmpdir, filename)
+ download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)
+
+ net.load_old_state_dict(torch.load(weight_path), verbose=False)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_diffusion_model_unet_maisi.py b/tests/test_diffusion_model_unet_maisi.py
new file mode 100644
index 0000000000..f9384e6d82
--- /dev/null
+++ b/tests/test_diffusion_model_unet_maisi.py
@@ -0,0 +1,588 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from unittest import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.apps.generation.maisi.networks.diffusion_model_unet_maisi import DiffusionModelUNetMaisi
+from monai.networks import eval_mode
+from monai.utils import optional_import
+
+_, has_einops = optional_import("einops")
+
+UNCOND_CASES_2D = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": (1, 1, 2),
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ "resblock_updown": True,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ "resblock_updown": True,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, True, True),
+ "num_head_channels": (0, 2, 4),
+ "norm_num_groups": 8,
+ }
+ ],
+]
+
+UNCOND_CASES_3D = [
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ "resblock_updown": True,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ "resblock_updown": True,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": (0, 0, 4),
+ "norm_num_groups": 8,
+ }
+ ],
+]
+
+COND_CASES_2D = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ "resblock_updown": True,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ "upcast_attention": True,
+ }
+ ],
+]
+
+DROPOUT_OK = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ "dropout_cattn": 0.25,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ }
+ ],
+]
+
+DROPOUT_WRONG = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "num_channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ "dropout_cattn": 3.0,
+ }
+ ]
+]
+
+
+class TestDiffusionModelUNetMaisi2D(unittest.TestCase):
+
+ @parameterized.expand(UNCOND_CASES_2D)
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_unconditioned_models(self, input_param):
+ net = DiffusionModelUNetMaisi(**input_param)
+ with eval_mode(net):
+ result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long())
+ self.assertEqual(result.shape, (1, 1, 16, 16))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_timestep_with_wrong_shape(self):
+ net = DiffusionModelUNetMaisi(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ num_channels=(8, 8, 8),
+ attention_levels=(False, False, False),
+ norm_num_groups=8,
+ )
+ with self.assertRaises(ValueError):
+ with eval_mode(net):
+ net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long())
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_with_different_in_channel_out_channel(self):
+ in_channels = 6
+ out_channels = 3
+ net = DiffusionModelUNetMaisi(
+ spatial_dims=2,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_res_blocks=1,
+ num_channels=(8, 8, 8),
+ attention_levels=(False, False, False),
+ norm_num_groups=8,
+ )
+ with eval_mode(net):
+ result = net.forward(torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long())
+ self.assertEqual(result.shape, (1, out_channels, 16, 16))
+
+ def test_model_channels_not_multiple_of_norm_num_group(self):
+ with self.assertRaises(ValueError):
+ DiffusionModelUNetMaisi(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ num_channels=(8, 8, 12),
+ attention_levels=(False, False, False),
+ norm_num_groups=8,
+ )
+
+ def test_attention_levels_with_different_length_num_head_channels(self):
+ with self.assertRaises(ValueError):
+ DiffusionModelUNetMaisi(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ num_channels=(8, 8, 8),
+ attention_levels=(False, False, False),
+ num_head_channels=(0, 2),
+ norm_num_groups=8,
+ )
+
+ def test_num_res_blocks_with_different_length_channels(self):
+ with self.assertRaises(ValueError):
+ DiffusionModelUNetMaisi(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=(1, 1),
+ num_channels=(8, 8, 8),
+ attention_levels=(False, False, False),
+ norm_num_groups=8,
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_conditioned_models(self):
+ net = DiffusionModelUNetMaisi(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ num_channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ with_conditioning=True,
+ transformer_num_layers=1,
+ cross_attention_dim=3,
+ norm_num_groups=8,
+ num_head_channels=8,
+ )
+ with eval_mode(net):
+ result = net.forward(
+ x=torch.rand((1, 1, 16, 32)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ context=torch.rand((1, 1, 3)),
+ )
+ self.assertEqual(result.shape, (1, 1, 16, 32))
+
+ def test_with_conditioning_cross_attention_dim_none(self):
+ with self.assertRaises(ValueError):
+ DiffusionModelUNetMaisi(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ num_channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ with_conditioning=True,
+ transformer_num_layers=1,
+ cross_attention_dim=None,
+ norm_num_groups=8,
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_context_with_conditioning_none(self):
+ net = DiffusionModelUNetMaisi(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ num_channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ with_conditioning=False,
+ transformer_num_layers=1,
+ norm_num_groups=8,
+ )
+
+ with self.assertRaises(ValueError):
+ with eval_mode(net):
+ net.forward(
+ x=torch.rand((1, 1, 16, 32)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ context=torch.rand((1, 1, 3)),
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_conditioned_models_class_conditioning(self):
+ net = DiffusionModelUNetMaisi(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ num_channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ norm_num_groups=8,
+ num_head_channels=8,
+ num_class_embeds=2,
+ )
+ with eval_mode(net):
+ result = net.forward(
+ x=torch.rand((1, 1, 16, 32)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ class_labels=torch.randint(0, 2, (1,)).long(),
+ )
+ self.assertEqual(result.shape, (1, 1, 16, 32))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_conditioned_models_no_class_labels(self):
+ net = DiffusionModelUNetMaisi(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ num_channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ norm_num_groups=8,
+ num_head_channels=8,
+ num_class_embeds=2,
+ )
+
+ with self.assertRaises(ValueError):
+ net.forward(x=torch.rand((1, 1, 16, 32)), timesteps=torch.randint(0, 1000, (1,)).long())
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_model_channels_not_same_size_of_attention_levels(self):
+ with self.assertRaises(ValueError):
+ DiffusionModelUNetMaisi(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ num_channels=(8, 8, 8),
+ attention_levels=(False, False),
+ norm_num_groups=8,
+ num_head_channels=8,
+ num_class_embeds=2,
+ )
+
+ @parameterized.expand(COND_CASES_2D)
+ @skipUnless(has_einops, "Requires einops")
+ def test_conditioned_2d_models_shape(self, input_param):
+ net = DiffusionModelUNetMaisi(**input_param)
+ with eval_mode(net):
+ result = net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 1, 3)))
+ self.assertEqual(result.shape, (1, 1, 16, 16))
+
+ @parameterized.expand(UNCOND_CASES_2D)
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_with_additional_inputs(self, input_param):
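+ # MAISI-specific conditioning: body-region index vectors and a voxel-spacing tensor are fed as extra inputs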
+ input_param["include_top_region_index_input"] = True
+ input_param["include_bottom_region_index_input"] = True
+ input_param["include_spacing_input"] = True
+ net = DiffusionModelUNetMaisi(**input_param)
+ with eval_mode(net):
+ result = net.forward(
+ x=torch.rand((1, 1, 16, 16)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ top_region_index_tensor=torch.rand((1, 4)),
+ bottom_region_index_tensor=torch.rand((1, 4)),
+ spacing_tensor=torch.rand((1, 3)),
+ )
+ self.assertEqual(result.shape, (1, 1, 16, 16))
+
+
+class TestDiffusionModelUNetMaisi3D(unittest.TestCase):
+
+ @parameterized.expand(UNCOND_CASES_3D)
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_unconditioned_models(self, input_param):
+ net = DiffusionModelUNetMaisi(**input_param)
+ with eval_mode(net):
+ result = net.forward(torch.rand((1, 1, 16, 16, 16)), torch.randint(0, 1000, (1,)).long())
+ self.assertEqual(result.shape, (1, 1, 16, 16, 16))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_with_different_in_channel_out_channel(self):
+ in_channels = 6
+ out_channels = 3
+ net = DiffusionModelUNetMaisi(
+ spatial_dims=3,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_res_blocks=1,
+ num_channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ norm_num_groups=4,
+ )
+ with eval_mode(net):
+ result = net.forward(torch.rand((1, in_channels, 16, 16, 16)), torch.randint(0, 1000, (1,)).long())
+ self.assertEqual(result.shape, (1, out_channels, 16, 16, 16))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_conditioned_models(self):
+ net = DiffusionModelUNetMaisi(
+ spatial_dims=3,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ num_channels=(16, 16, 16),
+ attention_levels=(False, False, True),
+ norm_num_groups=16,
+ with_conditioning=True,
+ transformer_num_layers=1,
+ cross_attention_dim=3,
+ )
+ with eval_mode(net):
+ result = net.forward(
+ x=torch.rand((1, 1, 16, 16, 16)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ context=torch.rand((1, 1, 3)),
+ )
+ self.assertEqual(result.shape, (1, 1, 16, 16, 16))
+
+ # Test dropout specification for cross-attention blocks
+ @parameterized.expand(DROPOUT_WRONG)
+ def test_wrong_dropout(self, input_param):
+ with self.assertRaises(ValueError):
+ _ = DiffusionModelUNetMaisi(**input_param)
+
+ @parameterized.expand(DROPOUT_OK)
+ @skipUnless(has_einops, "Requires einops")
+ def test_right_dropout(self, input_param):
+ _ = DiffusionModelUNetMaisi(**input_param)
+
+ @parameterized.expand(UNCOND_CASES_3D)
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_with_additional_inputs(self, input_param):
+ input_param["include_top_region_index_input"] = True
+ input_param["include_bottom_region_index_input"] = True
+ input_param["include_spacing_input"] = True
+ net = DiffusionModelUNetMaisi(**input_param)
+ with eval_mode(net):
+ result = net.forward(
+ x=torch.rand((1, 1, 16, 16, 16)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ top_region_index_tensor=torch.rand((1, 4)),
+ bottom_region_index_tensor=torch.rand((1, 4)),
+ spacing_tensor=torch.rand((1, 3)),
+ )
+ self.assertEqual(result.shape, (1, 1, 16, 16, 16))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_dynunet.py b/tests/test_dynunet.py
index f3c982056c..7c4882fcbb 100644
--- a/tests/test_dynunet.py
+++ b/tests/test_dynunet.py
@@ -13,7 +13,8 @@
import platform
import unittest
-from typing import Any, Sequence
+from collections.abc import Sequence
+from typing import Any
import torch
from parameterized import parameterized
diff --git a/tests/test_ensure_channel_first.py b/tests/test_ensure_channel_first.py
index 0c9ad5869e..fe046a4cdf 100644
--- a/tests/test_ensure_channel_first.py
+++ b/tests/test_ensure_channel_first.py
@@ -50,9 +50,10 @@ class TestEnsureChannelFirst(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
@unittest.skipUnless(has_itk, "itk not installed")
def test_load_nifti(self, input_param, filenames, original_channel_dim):
- if original_channel_dim is None:
- test_image = np.random.rand(8, 8, 8)
- elif original_channel_dim == -1:
+ # the default test image has no channel dimension (covers the original_channel_dim is None case)
+ test_image = np.random.rand(8, 8, 8)
+
+ if original_channel_dim == -1:
test_image = np.random.rand(8, 8, 8, 1)
with tempfile.TemporaryDirectory() as tempdir:
diff --git a/tests/test_ensure_channel_firstd.py b/tests/test_ensure_channel_firstd.py
index 63a437894b..e9effad951 100644
--- a/tests/test_ensure_channel_firstd.py
+++ b/tests/test_ensure_channel_firstd.py
@@ -35,9 +35,10 @@ class TestEnsureChannelFirstd(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_load_nifti(self, input_param, filenames, original_channel_dim):
- if original_channel_dim is None:
- test_image = np.random.rand(8, 8, 8)
- elif original_channel_dim == -1:
+ # the default test image has no channel dimension (covers the original_channel_dim is None case)
+ test_image = np.random.rand(8, 8, 8)
+
+ if original_channel_dim == -1:
test_image = np.random.rand(8, 8, 8, 1)
with tempfile.TemporaryDirectory() as tempdir:
diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py
index 09aa1f04b5..fe543347de 100644
--- a/tests/test_ensure_typed.py
+++ b/tests/test_ensure_typed.py
@@ -33,8 +33,8 @@ def test_array_input(self):
keys="data", data_type=dtype, dtype=np.float32 if dtype == "NUMPY" else None, device="cpu"
)({"data": test_data})["data"]
if dtype == "NUMPY":
- self.assertTrue(result.dtype == np.float32)
- self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray))
+ self.assertEqual(result.dtype, np.float32)
+ self.assertIsInstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)
assert_allclose(result, test_data, type_test=False)
self.assertTupleEqual(result.shape, (2, 2))
@@ -45,7 +45,7 @@ def test_single_input(self):
for test_data in test_datas:
for dtype in ("tensor", "numpy"):
result = EnsureTyped(keys="data", data_type=dtype)({"data": test_data})["data"]
- self.assertTrue(isinstance(result, torch.Tensor if dtype == "tensor" else np.ndarray))
+ self.assertIsInstance(result, torch.Tensor if dtype == "tensor" else np.ndarray)
if isinstance(test_data, bool):
self.assertFalse(result)
else:
@@ -56,11 +56,11 @@ def test_string(self):
for dtype in ("tensor", "numpy"):
# string input
result = EnsureTyped(keys="data", data_type=dtype)({"data": "test_string"})["data"]
- self.assertTrue(isinstance(result, str))
+ self.assertIsInstance(result, str)
self.assertEqual(result, "test_string")
# numpy array of string
result = EnsureTyped(keys="data", data_type=dtype)({"data": np.array(["test_string"])})["data"]
- self.assertTrue(isinstance(result, np.ndarray))
+ self.assertIsInstance(result, np.ndarray)
self.assertEqual(result[0], "test_string")
def test_list_tuple(self):
@@ -68,15 +68,15 @@ def test_list_tuple(self):
result = EnsureTyped(keys="data", data_type=dtype, wrap_sequence=False, track_meta=True)(
{"data": [[1, 2], [3, 4]]}
)["data"]
- self.assertTrue(isinstance(result, list))
- self.assertTrue(isinstance(result[0][1], MetaTensor if dtype == "tensor" else np.ndarray))
+ self.assertIsInstance(result, list)
+ self.assertIsInstance(result[0][1], MetaTensor if dtype == "tensor" else np.ndarray)
assert_allclose(result[1][0], torch.as_tensor(3), type_test=False)
# tuple of numpy arrays
result = EnsureTyped(keys="data", data_type=dtype, wrap_sequence=False)(
{"data": (np.array([1, 2]), np.array([3, 4]))}
)["data"]
- self.assertTrue(isinstance(result, tuple))
- self.assertTrue(isinstance(result[0], torch.Tensor if dtype == "tensor" else np.ndarray))
+ self.assertIsInstance(result, tuple)
+ self.assertIsInstance(result[0], torch.Tensor if dtype == "tensor" else np.ndarray)
assert_allclose(result[1], torch.as_tensor([3, 4]), type_test=False)
def test_dict(self):
@@ -92,19 +92,19 @@ def test_dict(self):
)
for key in ("data", "label"):
result = trans[key]
- self.assertTrue(isinstance(result, dict))
- self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray))
- self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray))
+ self.assertIsInstance(result, dict)
+ self.assertIsInstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray)
+ self.assertIsInstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray)
self.assertEqual(result["meta"]["path"], "temp/test")
self.assertEqual(result["extra"], None)
assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False)
assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False)
if dtype == "numpy":
- self.assertTrue(trans["data"]["img"].dtype == np.float32)
- self.assertTrue(trans["label"]["img"].dtype == np.int8)
+ self.assertEqual(trans["data"]["img"].dtype, np.float32)
+ self.assertEqual(trans["label"]["img"].dtype, np.int8)
else:
- self.assertTrue(trans["data"]["img"].dtype == torch.float32)
- self.assertTrue(trans["label"]["img"].dtype == torch.int8)
+ self.assertEqual(trans["data"]["img"].dtype, torch.float32)
+ self.assertEqual(trans["label"]["img"].dtype, torch.int8)
if __name__ == "__main__":
diff --git a/tests/test_evenly_divisible_all_gather_dist.py b/tests/test_evenly_divisible_all_gather_dist.py
index d6d26c7e23..f1d45ba48f 100644
--- a/tests/test_evenly_divisible_all_gather_dist.py
+++ b/tests/test_evenly_divisible_all_gather_dist.py
@@ -27,10 +27,10 @@ def test_data(self):
self._run()
def _run(self):
- if dist.get_rank() == 0:
- data1 = torch.tensor([[1, 2], [3, 4]])
- data2 = torch.tensor([[1.0, 2.0]])
- data3 = torch.tensor(7)
+ # defaults for rank 0; other ranks override these below
+ data1 = torch.tensor([[1, 2], [3, 4]])
+ data2 = torch.tensor([[1.0, 2.0]])
+ data3 = torch.tensor(7)
if dist.get_rank() == 1:
data1 = torch.tensor([[5, 6]])
diff --git a/tests/test_fastmri_reader.py b/tests/test_fastmri_reader.py
index af2eed7db5..06c3954eae 100644
--- a/tests/test_fastmri_reader.py
+++ b/tests/test_fastmri_reader.py
@@ -17,7 +17,7 @@
from parameterized import parameterized
from monai.apps.reconstruction.fastmri_reader import FastMRIReader
-from tests.utils import assert_allclose
+from tests.utils import SkipIfNoModule, assert_allclose
TEST_CASE1 = [
{
@@ -64,6 +64,7 @@
]
+@SkipIfNoModule("h5py")
class TestMRIUtils(unittest.TestCase):
@parameterized.expand([TEST_CASE1, TEST_CASE2])
diff --git a/tests/test_fl_monai_algo.py b/tests/test_fl_monai_algo.py
index 54bec24b98..c8cb3451fc 100644
--- a/tests/test_fl_monai_algo.py
+++ b/tests/test_fl_monai_algo.py
@@ -75,7 +75,7 @@
tracking={
"handlers_id": DEFAULT_HANDLERS_ID,
"configs": {
- "execute_config": f"{_data_dir}/config_executed.json",
+ "save_execute_config": f"{_data_dir}/config_executed.json",
"trainer": {
"_target_": "MLFlowHandler",
"tracking_uri": path_to_uri(_data_dir) + "/mlflow_override",
@@ -201,7 +201,7 @@ def test_train(self, input_params):
algo.finalize()
# test experiment management
- if "execute_config" in algo.train_workflow.parser:
+ if "save_execute_config" in algo.train_workflow.parser:
self.assertTrue(os.path.exists(f"{_data_dir}/mlflow_override"))
shutil.rmtree(f"{_data_dir}/mlflow_override")
self.assertTrue(os.path.exists(f"{_data_dir}/config_executed.json"))
@@ -224,7 +224,7 @@ def test_evaluate(self, input_params):
algo.evaluate(data=data, extra={})
# test experiment management
- if "execute_config" in algo.eval_workflow.parser:
+ if "save_execute_config" in algo.eval_workflow.parser:
self.assertGreater(len(list(glob.glob(f"{_data_dir}/mlflow_*"))), 0)
for f in list(glob.glob(f"{_data_dir}/mlflow_*")):
shutil.rmtree(f)
diff --git a/tests/test_flexible_unet.py b/tests/test_flexible_unet.py
index 404855c9a8..42baa28b71 100644
--- a/tests/test_flexible_unet.py
+++ b/tests/test_flexible_unet.py
@@ -23,12 +23,11 @@
EfficientNetBNFeatures,
FlexibleUNet,
FlexUNetEncoderRegister,
- ResNet,
- ResNetBlock,
- ResNetBottleneck,
+ ResNetEncoder,
+ ResNetFeatures,
)
from monai.utils import optional_import
-from tests.utils import skip_if_downloading_fails, skip_if_quick
+from tests.utils import SkipIfNoModule, skip_if_downloading_fails, skip_if_quick
torchvision, has_torchvision = optional_import("torchvision")
PIL, has_pil = optional_import("PIL")
@@ -59,101 +58,6 @@ def get_encoder_names(cls):
return ["encoder_wrong_channels", "encoder_no_param1", "encoder_no_param2", "encoder_no_param3"]
-class ResNetEncoder(ResNet, BaseEncoder):
- backbone_names = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"]
- output_feature_channels = [(64, 128, 256, 512)] * 3 + [(256, 512, 1024, 2048)] * 4
- parameter_layers = [
- [1, 1, 1, 1],
- [2, 2, 2, 2],
- [3, 4, 6, 3],
- [3, 4, 6, 3],
- [3, 4, 23, 3],
- [3, 8, 36, 3],
- [3, 24, 36, 3],
- ]
-
- def __init__(self, in_channels, pretrained, **kargs):
- super().__init__(**kargs, n_input_channels=in_channels)
- if pretrained:
- # Author of paper zipped the state_dict on googledrive,
- # so would need to download, unzip and read (2.8gb file for a ~150mb state dict).
- # Would like to load dict from url but need somewhere to save the state dicts.
- raise NotImplementedError(
- "Currently not implemented. You need to manually download weights provided by the paper's author"
- " and load then to the model with `state_dict`. See https://github.com/Tencent/MedicalNet"
- )
-
- @staticmethod
- def get_inplanes():
- return [64, 128, 256, 512]
-
- @classmethod
- def get_encoder_parameters(cls) -> list[dict]:
- """
- Get parameter list to initialize encoder networks.
- Each parameter dict must have `spatial_dims`, `in_channels`
- and `pretrained` parameters.
- """
- parameter_list = []
- res_type: type[ResNetBlock] | type[ResNetBottleneck]
- for backbone in range(len(cls.backbone_names)):
- if backbone < 3:
- res_type = ResNetBlock
- else:
- res_type = ResNetBottleneck
- parameter_list.append(
- {
- "block": res_type,
- "layers": cls.parameter_layers[backbone],
- "block_inplanes": ResNetEncoder.get_inplanes(),
- "spatial_dims": 2,
- "in_channels": 3,
- "pretrained": False,
- }
- )
- return parameter_list
-
- @classmethod
- def num_channels_per_output(cls):
- """
- Get number of output features' channel.
- """
- return cls.output_feature_channels
-
- @classmethod
- def num_outputs(cls):
- """
- Get number of output feature.
- """
- return [4] * 7
-
- @classmethod
- def get_encoder_names(cls):
- """
- Get the name string of backbones which will be used to initialize flexible unet.
- """
- return cls.backbone_names
-
- def forward(self, x: torch.Tensor):
- feature_list = []
- x = self.conv1(x)
- x = self.bn1(x)
- x = self.relu(x)
- if not self.no_max_pool:
- x = self.maxpool(x)
- x = self.layer1(x)
- feature_list.append(x)
- x = self.layer2(x)
- feature_list.append(x)
- x = self.layer3(x)
- feature_list.append(x)
- x = self.layer4(x)
- feature_list.append(x)
-
- return feature_list
-
-
-FLEXUNET_BACKBONE.register_class(ResNetEncoder)
FLEXUNET_BACKBONE.register_class(DummyEncoder)
@@ -204,9 +108,7 @@ def make_shape_cases(
def make_error_case():
- error_dummy_backbones = DummyEncoder.get_encoder_names()
- error_resnet_backbones = ResNetEncoder.get_encoder_names()
- error_backbones = error_dummy_backbones + error_resnet_backbones
+ error_backbones = DummyEncoder.get_encoder_names()
error_param_list = []
for backbone in error_backbones:
error_param_list.append(
@@ -232,7 +134,7 @@ def make_error_case():
norm="instance",
)
CASES_3D = make_shape_cases(
- models=[SEL_MODELS[0]],
+ models=[SEL_MODELS[0], SEL_MODELS[2]],
spatial_dims=[3],
batches=[1],
pretrained=[False],
@@ -345,6 +247,7 @@ def make_error_case():
"spatial_dims": 2,
"norm": ("batch", {"eps": 1e-3, "momentum": 0.01}),
},
+ EfficientNetBNFeatures,
{
"in_channels": 3,
"num_classes": 10,
@@ -354,7 +257,20 @@ def make_error_case():
"norm": ("batch", {"eps": 1e-3, "momentum": 0.01}),
},
["_conv_stem.weight"],
- )
+ ),
+ (
+ {
+ "in_channels": 1,
+ "out_channels": 10,
+ "backbone": SEL_MODELS[2],
+ "pretrained": True,
+ "spatial_dims": 3,
+ "norm": ("batch", {"eps": 1e-3, "momentum": 0.01}),
+ },
+ ResNetFeatures,
+ {"model_name": SEL_MODELS[2], "pretrained": True, "spatial_dims": 3, "in_channels": 1},
+ ["conv1.weight"],
+ ),
]
CASE_ERRORS = make_error_case()
@@ -363,6 +279,7 @@ def make_error_case():
CASE_REGISTER_ENCODER = ["EfficientNetEncoder", "monai.networks.nets.EfficientNetEncoder"]
+@SkipIfNoModule("hf_hub_download")
@skip_if_quick
class TestFLEXIBLEUNET(unittest.TestCase):
@@ -381,19 +298,19 @@ def test_shape(self, input_param, input_shape, expected_shape):
self.assertEqual(result.shape, expected_shape)
@parameterized.expand(CASES_PRETRAIN)
- def test_pretrain(self, input_param, efficient_input_param, weight_list):
+ def test_pretrain(self, flexunet_input_param, feature_extractor_class, feature_extractor_input_param, weight_list):
device = "cuda" if torch.cuda.is_available() else "cpu"
with skip_if_downloading_fails():
- net = FlexibleUNet(**input_param).to(device)
+ net = FlexibleUNet(**flexunet_input_param).to(device)
with skip_if_downloading_fails():
- eff_net = EfficientNetBNFeatures(**efficient_input_param).to(device)
+ feature_extractor_net = feature_extractor_class(**feature_extractor_input_param).to(device)
for weight_name in weight_list:
- if weight_name in net.encoder.state_dict() and weight_name in eff_net.state_dict():
+ if weight_name in net.encoder.state_dict() and weight_name in feature_extractor_net.state_dict():
net_weight = net.encoder.state_dict()[weight_name]
- download_weight = eff_net.state_dict()[weight_name]
+ download_weight = feature_extractor_net.state_dict()[weight_name]
weight_diff = torch.abs(net_weight - download_weight)
diff_sum = torch.sum(weight_diff)
# check if a weight in weight_list equals to the downloaded weight.
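The generalised test_pretrain above compares named weights of the FlexibleUNet encoder with the corresponding weights of a standalone feature-extractor network. A hedged, plain-torch sketch of that comparison step (helper and module names here are illustrative, not MONAI API):

import torch
import torch.nn as nn


def max_weight_difference(net_a: nn.Module, net_b: nn.Module, weight_names: list[str]) -> float:
    """Return the largest absolute difference over the requested weights present in both nets."""
    state_a, state_b = net_a.state_dict(), net_b.state_dict()
    max_diff = 0.0
    for name in weight_names:
        if name in state_a and name in state_b:
            diff = torch.abs(state_a[name] - state_b[name]).max().item()
            max_diff = max(max_diff, diff)
    return max_diff


# usage sketch: after copying the state dict, the two modules should report a zero difference
a = nn.Conv2d(3, 8, 3)
b = nn.Conv2d(3, 8, 3)
b.load_state_dict(a.state_dict())
assert max_weight_difference(a, b, ["weight", "bias"]) == 0.0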
diff --git a/tests/test_flipd.py b/tests/test_flipd.py
index 277f387051..1df6d34056 100644
--- a/tests/test_flipd.py
+++ b/tests/test_flipd.py
@@ -78,7 +78,7 @@ def test_torch(self, spatial_axis, img: torch.Tensor, track_meta: bool, device):
def test_meta_dict(self):
xform = Flipd("image", [0, 1])
res = xform({"image": torch.zeros(1, 3, 4)})
- self.assertTrue(res["image"].applied_operations == res["image_transforms"])
+ self.assertEqual(res["image"].applied_operations, res["image_transforms"])
if __name__ == "__main__":
diff --git a/tests/test_freeze_layers.py b/tests/test_freeze_layers.py
index 1bea4ed1b5..7be8e576bf 100644
--- a/tests/test_freeze_layers.py
+++ b/tests/test_freeze_layers.py
@@ -40,9 +40,9 @@ def test_freeze_vars(self, device):
for name, param in model.named_parameters():
if "class_layer" in name:
- self.assertEqual(param.requires_grad, False)
+ self.assertFalse(param.requires_grad)
else:
- self.assertEqual(param.requires_grad, True)
+ self.assertTrue(param.requires_grad)
@parameterized.expand(TEST_CASES)
def test_exclude_vars(self, device):
@@ -53,9 +53,9 @@ def test_exclude_vars(self, device):
for name, param in model.named_parameters():
if "class_layer" in name:
- self.assertEqual(param.requires_grad, True)
+ self.assertTrue(param.requires_grad)
else:
- self.assertEqual(param.requires_grad, False)
+ self.assertFalse(param.requires_grad)
if __name__ == "__main__":
diff --git a/tests/test_gdsdataset.py b/tests/test_gdsdataset.py
index f0a419dcf5..5d2e2aa013 100644
--- a/tests/test_gdsdataset.py
+++ b/tests/test_gdsdataset.py
@@ -23,7 +23,7 @@
from monai.data import GDSDataset, json_hashing
from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform
from monai.utils import optional_import
-from tests.utils import TEST_NDARRAYS, assert_allclose
+from tests.utils import TEST_NDARRAYS, assert_allclose, skip_if_no_cuda
_, has_cp = optional_import("cupy")
nib, has_nib = optional_import("nibabel")
@@ -70,9 +70,9 @@ def __call__(self, data):
return data
+@skip_if_no_cuda
@unittest.skipUnless(has_cp, "Requires CuPy library.")
-@unittest.skipUnless(has_nib, "Requires nibabel package.")
-@unittest.skipUnless(has_kvikio_numpy, "Requires scikit-image library.")
+@unittest.skipUnless(has_cp and has_kvikio_numpy, "Requires CuPy and kvikio library.")
class TestDataset(unittest.TestCase):
def test_cache(self):
@@ -131,6 +131,7 @@ def test_dtype(self):
self.assertEqual(ds[0].dtype, DTYPES[_dtype])
self.assertEqual(ds1[0].dtype, DTYPES[_dtype])
+ @unittest.skipUnless(has_nib, "Requires nibabel package.")
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_shape(self, transform, expected_shape):
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
diff --git a/tests/test_generalized_dice_focal_loss.py b/tests/test_generalized_dice_focal_loss.py
index 8a0a80865e..65252611ca 100644
--- a/tests/test_generalized_dice_focal_loss.py
+++ b/tests/test_generalized_dice_focal_loss.py
@@ -59,8 +59,18 @@ def test_result_no_onehot_no_bg(self):
def test_ill_shape(self):
loss = GeneralizedDiceFocalLoss()
- with self.assertRaisesRegex(ValueError, ""):
- loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+ with self.assertRaises(AssertionError):
+ loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))
+
+ def test_ill_shape2(self):
+ loss = GeneralizedDiceFocalLoss()
+ with self.assertRaises(ValueError):
+ loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+
+ def test_ill_shape3(self):
+ loss = GeneralizedDiceFocalLoss()
+ with self.assertRaises(ValueError):
+ loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))
def test_ill_lambda(self):
with self.assertRaisesRegex(ValueError, ""):
diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py
index 7499507129..5738f4a089 100644
--- a/tests/test_generalized_dice_loss.py
+++ b/tests/test_generalized_dice_loss.py
@@ -184,7 +184,7 @@ def test_differentiability(self):
generalized_dice_loss = GeneralizedDiceLoss()
loss = generalized_dice_loss(prediction, target)
- self.assertNotEqual(loss.grad_fn, None)
+ self.assertIsNotNone(loss.grad_fn)
def test_batch(self):
prediction = torch.zeros(2, 3, 3, 3)
@@ -194,7 +194,7 @@ def test_batch(self):
generalized_dice_loss = GeneralizedDiceLoss(batch=True)
loss = generalized_dice_loss(prediction, target)
- self.assertNotEqual(loss.grad_fn, None)
+ self.assertIsNotNone(loss.grad_fn)
def test_script(self):
loss = GeneralizedDiceLoss()
diff --git a/tests/test_get_package_version.py b/tests/test_get_package_version.py
index ab9e69cd31..e9e1d8eca6 100644
--- a/tests/test_get_package_version.py
+++ b/tests/test_get_package_version.py
@@ -20,14 +20,14 @@ class TestGetVersion(unittest.TestCase):
def test_default(self):
output = get_package_version("42foobarnoexist")
- self.assertTrue("UNKNOWN" in output)
+ self.assertIn("UNKNOWN", output)
output = get_package_version("numpy")
- self.assertFalse("UNKNOWN" in output)
+ self.assertNotIn("UNKNOWN", output)
def test_msg(self):
output = get_package_version("42foobarnoexist", "test")
- self.assertTrue("test" in output)
+ self.assertIn("test", output)
if __name__ == "__main__":
diff --git a/tests/test_global_mutual_information_loss.py b/tests/test_global_mutual_information_loss.py
index 36a1978c93..22f5e88431 100644
--- a/tests/test_global_mutual_information_loss.py
+++ b/tests/test_global_mutual_information_loss.py
@@ -15,6 +15,7 @@
import numpy as np
import torch
+from parameterized import parameterized
from monai import transforms
from monai.losses.image_dissimilarity import GlobalMutualInformationLoss
@@ -116,24 +117,33 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.
class TestGlobalMutualInformationLossIll(unittest.TestCase):
- def test_ill_shape(self):
+ @parameterized.expand(
+ [
+ (torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)), # mismatched_simple_dims
+ (
+ torch.ones((1, 3, 3), dtype=torch.float),
+ torch.ones((1, 3), dtype=torch.float),
+ ), # mismatched_advanced_dims
+ ]
+ )
+ def test_ill_shape(self, input1, input2):
loss = GlobalMutualInformationLoss()
- with self.assertRaisesRegex(ValueError, ""):
- loss.forward(torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float, device=device))
- with self.assertRaisesRegex(ValueError, ""):
- loss.forward(torch.ones((1, 3, 3), dtype=torch.float), torch.ones((1, 3), dtype=torch.float, device=device))
-
- def test_ill_opts(self):
+ with self.assertRaises(ValueError):
+ loss.forward(input1, input2)
+
+ @parameterized.expand(
+ [
+ (0, "mean", ValueError, ""), # num_bins_zero
+ (-1, "mean", ValueError, ""), # num_bins_negative
+ (64, "unknown", ValueError, ""), # reduction_unknown
+ (64, None, ValueError, ""), # reduction_none
+ ]
+ )
+ def test_ill_opts(self, num_bins, reduction, expected_exception, expected_message):
pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device)
target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device)
- with self.assertRaisesRegex(ValueError, ""):
- GlobalMutualInformationLoss(num_bins=0)(pred, target)
- with self.assertRaisesRegex(ValueError, ""):
- GlobalMutualInformationLoss(num_bins=-1)(pred, target)
- with self.assertRaisesRegex(ValueError, ""):
- GlobalMutualInformationLoss(reduction="unknown")(pred, target)
- with self.assertRaisesRegex(ValueError, ""):
- GlobalMutualInformationLoss(reduction=None)(pred, target)
+ with self.assertRaisesRegex(expected_exception, expected_message):
+ GlobalMutualInformationLoss(num_bins=num_bins, reduction=reduction)(pred, target)
if __name__ == "__main__":
diff --git a/tests/test_grid_patch.py b/tests/test_grid_patch.py
index 4b324eda1a..56af123548 100644
--- a/tests/test_grid_patch.py
+++ b/tests/test_grid_patch.py
@@ -124,11 +124,11 @@ def test_grid_patch_meta(self, input_parameters, image, expected, expected_meta)
self.assertTrue(output.meta["path"] == expected_meta[0]["path"])
for output_patch, expected_patch, expected_patch_meta in zip(output, expected, expected_meta):
assert_allclose(output_patch, expected_patch, type_test=False)
- self.assertTrue(isinstance(output_patch, MetaTensor))
- self.assertTrue(output_patch.meta["location"] == expected_patch_meta["location"])
+ self.assertIsInstance(output_patch, MetaTensor)
+ self.assertEqual(output_patch.meta["location"], expected_patch_meta["location"])
self.assertTrue(output_patch.meta["spatial_shape"], list(output_patch.shape[1:]))
if "path" in expected_meta[0]:
- self.assertTrue(output_patch.meta["path"] == expected_patch_meta["path"])
+ self.assertEqual(output_patch.meta["path"], expected_patch_meta["path"])
if __name__ == "__main__":
diff --git a/tests/test_handler_garbage_collector.py b/tests/test_handler_garbage_collector.py
index 317eba1b11..4254a73a6b 100644
--- a/tests/test_handler_garbage_collector.py
+++ b/tests/test_handler_garbage_collector.py
@@ -19,10 +19,9 @@
from ignite.engine import Engine
from parameterized import parameterized
-from monai.config import IgniteInfo
from monai.data import Dataset
from monai.handlers import GarbageCollector
-from monai.utils import min_version, optional_import
+from monai.utils import IgniteInfo, min_version, optional_import
Events, has_ignite = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
diff --git a/tests/test_handler_ignite_metric.py b/tests/test_handler_ignite_metric.py
index 28e0b69621..3e42bda35d 100644
--- a/tests/test_handler_ignite_metric.py
+++ b/tests/test_handler_ignite_metric.py
@@ -16,7 +16,7 @@
import torch
from parameterized import parameterized
-from monai.handlers import IgniteMetric, IgniteMetricHandler, from_engine
+from monai.handlers import IgniteMetricHandler, from_engine
from monai.losses import DiceLoss
from monai.metrics import LossMetric
from tests.utils import SkipIfNoModule, assert_allclose, optional_import
@@ -172,7 +172,7 @@ def _val_func(engine, batch):
@parameterized.expand(TEST_CASES[0:2])
def test_old_ignite_metric(self, input_param, input_data, expected_val):
loss_fn = DiceLoss(**input_param)
- ignite_metric = IgniteMetric(loss_fn=loss_fn, output_transform=from_engine(["pred", "label"]))
+ ignite_metric = IgniteMetricHandler(loss_fn=loss_fn, output_transform=from_engine(["pred", "label"]))
def _val_func(engine, batch):
pass
diff --git a/tests/test_handler_metrics_saver_dist.py b/tests/test_handler_metrics_saver_dist.py
index 46c9ad27d7..2e12b08aa9 100644
--- a/tests/test_handler_metrics_saver_dist.py
+++ b/tests/test_handler_metrics_saver_dist.py
@@ -51,8 +51,10 @@ def _val_func(engine, batch):
engine = Engine(_val_func)
+ # define `data` here so the variable always exists regardless of the rank-specific branches below
+ data = [{PostFix.meta("image"): {"filename_or_obj": [fnames[0]]}}]
+
if my_rank == 0:
- data = [{PostFix.meta("image"): {"filename_or_obj": [fnames[0]]}}]
@engine.on(Events.EPOCH_COMPLETED)
def _save_metrics0(engine):
diff --git a/tests/test_handler_mlflow.py b/tests/test_handler_mlflow.py
index 44adc49fc2..36d59ff1bf 100644
--- a/tests/test_handler_mlflow.py
+++ b/tests/test_handler_mlflow.py
@@ -122,6 +122,11 @@ def _train_func(engine, batch):
def _update_metric(engine):
current_metric = engine.state.metrics.get("acc", 0.1)
engine.state.metrics["acc"] = current_metric + 0.1
+ # log nested metrics
+ engine.state.metrics["acc_per_label"] = {
+ "label_0": current_metric + 0.1,
+ "label_1": current_metric + 0.2,
+ }
engine.state.test = current_metric
# set up testing handler
@@ -138,10 +143,12 @@ def _update_metric(engine):
state_attributes=["test"],
experiment_param=experiment_param,
artifacts=[artifact_path],
- close_on_complete=True,
+ close_on_complete=False,
)
handler.attach(engine)
engine.run(range(3), max_epochs=2)
+ cur_run = handler.client.get_run(handler.cur_run.info.run_id)
+ self.assertTrue("label_0" in cur_run.data.metrics.keys())
handler.close()
# check logging output
self.assertTrue(len(glob.glob(test_path)) > 0)
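The handler test above logs a nested metrics dictionary ("acc_per_label") and then checks that a flattened "label_0" entry appears in the recorded run. A minimal sketch of one way such nesting can be reduced to scalar entries, assuming inner keys are logged under their own names as the assertion above implies; the real handler's key scheme may differ:

def flatten_leaf_metrics(metrics: dict) -> dict:
    """Collect scalar metrics, using the innermost key name for nested entries."""
    flat = {}
    for key, value in metrics.items():
        if isinstance(value, dict):
            flat.update(flatten_leaf_metrics(value))
        else:
            flat[key] = value
    return flat


# mirrors the nested "acc_per_label" metric built in the test above
print(flatten_leaf_metrics({"acc": 0.2, "acc_per_label": {"label_0": 0.2, "label_1": 0.3}}))
# {'acc': 0.2, 'label_0': 0.2, 'label_1': 0.3}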
diff --git a/tests/test_handler_prob_map_producer.py b/tests/test_handler_prob_map_producer.py
index 347f8cb92c..406fe77c8f 100644
--- a/tests/test_handler_prob_map_producer.py
+++ b/tests/test_handler_prob_map_producer.py
@@ -30,6 +30,7 @@
class TestDataset(Dataset):
+ __test__ = False # indicate to pytest that this class is not intended for collection
def __init__(self, name, size):
super().__init__(
@@ -64,6 +65,7 @@ def __getitem__(self, index):
class TestEvaluator(Evaluator):
+ __test__ = False # indicate to pytest that this class is not intended for collection
def _iteration(self, engine, batchdata):
return batchdata
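Several helper classes in this patch (TestDataset, TestEvaluator, TestModule) gain `__test__ = False` because pytest collects any class whose name starts with "Test" and warns when such a class defines an __init__. A tiny illustration of the marker (the class name here is illustrative):

class TestFixtureHelper:
    # pytest would otherwise try to collect this class because of the "Test" prefix;
    # the attribute tells the collector to skip it
    __test__ = False

    def __init__(self, size: int):
        self.size = size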
diff --git a/tests/test_handler_stats.py b/tests/test_handler_stats.py
index f876cff2a3..52da5c179b 100644
--- a/tests/test_handler_stats.py
+++ b/tests/test_handler_stats.py
@@ -76,9 +76,9 @@ def _update_metric(engine):
if has_key_word.match(line):
content_count += 1
if epoch_log is True:
- self.assertTrue(content_count == max_epochs)
+ self.assertEqual(content_count, max_epochs)
else:
- self.assertTrue(content_count == 2) # 2 = len([1, 2]) from event_filter
+ self.assertEqual(content_count, 2) # 2 = len([1, 2]) from event_filter
@parameterized.expand([[True], [get_event_filter([1, 3])]])
def test_loss_print(self, iteration_log):
@@ -116,9 +116,9 @@ def _train_func(engine, batch):
if has_key_word.match(line):
content_count += 1
if iteration_log is True:
- self.assertTrue(content_count == num_iters * max_epochs)
+ self.assertEqual(content_count, num_iters * max_epochs)
else:
- self.assertTrue(content_count == 2) # 2 = len([1, 3]) from event_filter
+ self.assertEqual(content_count, 2) # 2 = len([1, 3]) from event_filter
def test_loss_dict(self):
log_stream = StringIO()
@@ -150,7 +150,7 @@ def _train_func(engine, batch):
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
- self.assertTrue(content_count > 0)
+ self.assertGreater(content_count, 0)
def test_loss_file(self):
key_to_handler = "test_logging"
@@ -184,7 +184,7 @@ def _train_func(engine, batch):
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
- self.assertTrue(content_count > 0)
+ self.assertGreater(content_count, 0)
def test_exception(self):
# set up engine
@@ -239,7 +239,7 @@ def _update_metric(engine):
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
- self.assertTrue(content_count > 0)
+ self.assertGreater(content_count, 0)
def test_default_logger(self):
log_stream = StringIO()
@@ -274,7 +274,7 @@ def _train_func(engine, batch):
for line in output_str.split("\n"):
if has_key_word.match(line):
content_count += 1
- self.assertTrue(content_count > 0)
+ self.assertGreater(content_count, 0)
if __name__ == "__main__":
diff --git a/tests/test_handler_validation.py b/tests/test_handler_validation.py
index 752b1d3df7..92f8578f11 100644
--- a/tests/test_handler_validation.py
+++ b/tests/test_handler_validation.py
@@ -22,6 +22,7 @@
class TestEvaluator(Evaluator):
+ __test__ = False # indicate to pytest that this class is not intended for collection
def _iteration(self, engine, batchdata):
engine.state.output = "called"
diff --git a/tests/test_hausdorff_loss.py b/tests/test_hausdorff_loss.py
index f279d45b14..f2211008c2 100644
--- a/tests/test_hausdorff_loss.py
+++ b/tests/test_hausdorff_loss.py
@@ -219,17 +219,12 @@ def test_ill_opts(self):
with self.assertRaisesRegex(ValueError, ""):
HausdorffDTLoss(reduction=None)(chn_input, chn_target)
- def test_input_warnings(self):
+ @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)])
+ def test_input_warnings(self, include_background, softmax, to_onehot_y):
chn_input = torch.ones((1, 1, 1, 3))
chn_target = torch.ones((1, 1, 1, 3))
with self.assertWarns(Warning):
- loss = HausdorffDTLoss(include_background=False)
- loss.forward(chn_input, chn_target)
- with self.assertWarns(Warning):
- loss = HausdorffDTLoss(softmax=True)
- loss.forward(chn_input, chn_target)
- with self.assertWarns(Warning):
- loss = HausdorffDTLoss(to_onehot_y=True)
+ loss = HausdorffDTLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)
loss.forward(chn_input, chn_target)
@@ -256,17 +251,12 @@ def test_ill_opts(self):
with self.assertRaisesRegex(ValueError, ""):
LogHausdorffDTLoss(reduction=None)(chn_input, chn_target)
- def test_input_warnings(self):
+ @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)])
+ def test_input_warnings(self, include_background, softmax, to_onehot_y):
chn_input = torch.ones((1, 1, 1, 3))
chn_target = torch.ones((1, 1, 1, 3))
with self.assertWarns(Warning):
- loss = LogHausdorffDTLoss(include_background=False)
- loss.forward(chn_input, chn_target)
- with self.assertWarns(Warning):
- loss = LogHausdorffDTLoss(softmax=True)
- loss.forward(chn_input, chn_target)
- with self.assertWarns(Warning):
- loss = LogHausdorffDTLoss(to_onehot_y=True)
+ loss = LogHausdorffDTLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)
loss.forward(chn_input, chn_target)
diff --git a/tests/test_hilbert_transform.py b/tests/test_hilbert_transform.py
index 879a74969d..b91ba3f6b7 100644
--- a/tests/test_hilbert_transform.py
+++ b/tests/test_hilbert_transform.py
@@ -19,11 +19,11 @@
from monai.networks.layers import HilbertTransform
from monai.utils import OptionalImportError
-from tests.utils import SkipIfModule, SkipIfNoModule, skip_if_no_cuda
+from tests.utils import SkipIfModule, SkipIfNoModule
def create_expected_numpy_output(input_datum, **kwargs):
- x = np.fft.fft(input_datum.cpu().numpy() if input_datum.device.type == "cuda" else input_datum.numpy(), **kwargs)
+ x = np.fft.fft(input_datum.cpu().numpy(), **kwargs)
f = np.fft.fftfreq(x.shape[kwargs["axis"]])
u = np.heaviside(f, 0.5)
new_dims_before = kwargs["axis"]
@@ -44,19 +44,15 @@ def create_expected_numpy_output(input_datum, **kwargs):
# CPU TEST DATA
cpu_input_data = {}
-cpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=cpu).unsqueeze(0).unsqueeze(0)
-cpu_input_data["2D"] = (
- torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu).unsqueeze(0).unsqueeze(0)
-)
-cpu_input_data["3D"] = (
- torch.as_tensor(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu)
- .unsqueeze(0)
- .unsqueeze(0)
-)
-cpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu).unsqueeze(0)
+cpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=cpu)[None, None]
+cpu_input_data["2D"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu)[None, None]
+cpu_input_data["3D"] = torch.as_tensor(
+ np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu
+)[None, None]
+cpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=cpu)[None]
cpu_input_data["2D 2CH"] = torch.as_tensor(
np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=cpu
-).unsqueeze(0)
+)[None]
# SINGLE-CHANNEL CPU VALUE TESTS
@@ -97,64 +93,21 @@ def create_expected_numpy_output(input_datum, **kwargs):
1e-5, # absolute tolerance
]
+TEST_CASES_CPU = [
+ TEST_CASE_1D_SINE_CPU,
+ TEST_CASE_2D_SINE_CPU,
+ TEST_CASE_3D_SINE_CPU,
+ TEST_CASE_1D_2CH_SINE_CPU,
+ TEST_CASE_2D_2CH_SINE_CPU,
+]
+
# GPU TEST DATA
if torch.cuda.is_available():
gpu = torch.device("cuda")
-
- gpu_input_data = {}
- gpu_input_data["1D"] = torch.as_tensor(hann_windowed_sine, device=gpu).unsqueeze(0).unsqueeze(0)
- gpu_input_data["2D"] = (
- torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=gpu).unsqueeze(0).unsqueeze(0)
- )
- gpu_input_data["3D"] = (
- torch.as_tensor(np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=gpu)
- .unsqueeze(0)
- .unsqueeze(0)
- )
- gpu_input_data["1D 2CH"] = torch.as_tensor(np.stack([hann_windowed_sine] * 10, axis=1), device=gpu).unsqueeze(0)
- gpu_input_data["2D 2CH"] = torch.as_tensor(
- np.stack([np.stack([hann_windowed_sine] * 10, axis=1)] * 10, axis=2), device=gpu
- ).unsqueeze(0)
-
- # SINGLE CHANNEL GPU VALUE TESTS
-
- TEST_CASE_1D_SINE_GPU = [
- {}, # args (empty, so use default)
- gpu_input_data["1D"], # Input data: Random 1D signal
- create_expected_numpy_output(gpu_input_data["1D"], axis=2), # Expected output: FFT of signal
- 1e-5, # absolute tolerance
- ]
-
- TEST_CASE_2D_SINE_GPU = [
- {}, # args (empty, so use default)
- gpu_input_data["2D"], # Input data: Random 1D signal
- create_expected_numpy_output(gpu_input_data["2D"], axis=2), # Expected output: FFT of signal
- 1e-5, # absolute tolerance
- ]
-
- TEST_CASE_3D_SINE_GPU = [
- {}, # args (empty, so use default)
- gpu_input_data["3D"], # Input data: Random 1D signal
- create_expected_numpy_output(gpu_input_data["3D"], axis=2), # Expected output: FFT of signal
- 1e-5, # absolute tolerance
- ]
-
- # MULTICHANNEL GPU VALUE TESTS, PROCESS ALONG FIRST SPATIAL AXIS
-
- TEST_CASE_1D_2CH_SINE_GPU = [
- {}, # args (empty, so use default)
- gpu_input_data["1D 2CH"], # Input data: Random 1D signal
- create_expected_numpy_output(gpu_input_data["1D 2CH"], axis=2),
- 1e-5, # absolute tolerance
- ]
-
- TEST_CASE_2D_2CH_SINE_GPU = [
- {}, # args (empty, so use default)
- gpu_input_data["2D 2CH"], # Input data: Random 1D signal
- create_expected_numpy_output(gpu_input_data["2D 2CH"], axis=2),
- 1e-5, # absolute tolerance
- ]
+ TEST_CASES_GPU = [[args, image.to(gpu), exp_data, atol] for args, image, exp_data, atol in TEST_CASES_CPU]
+else:
+ TEST_CASES_GPU = []
# TESTS CHECKING PADDING, AXIS SELECTION ETC ARE COVERED BY test_detect_envelope.py
@@ -162,42 +115,10 @@ def create_expected_numpy_output(input_datum, **kwargs):
@SkipIfNoModule("torch.fft")
class TestHilbertTransformCPU(unittest.TestCase):
- @parameterized.expand(
- [
- TEST_CASE_1D_SINE_CPU,
- TEST_CASE_2D_SINE_CPU,
- TEST_CASE_3D_SINE_CPU,
- TEST_CASE_1D_2CH_SINE_CPU,
- TEST_CASE_2D_2CH_SINE_CPU,
- ]
- )
- def test_value(self, arguments, image, expected_data, atol):
- result = HilbertTransform(**arguments)(image)
- result = result.squeeze(0).squeeze(0).numpy()
- np.testing.assert_allclose(result, expected_data.squeeze(), atol=atol)
-
-
-@skip_if_no_cuda
-@SkipIfNoModule("torch.fft")
-class TestHilbertTransformGPU(unittest.TestCase):
-
- @parameterized.expand(
- (
- []
- if not torch.cuda.is_available()
- else [
- TEST_CASE_1D_SINE_GPU,
- TEST_CASE_2D_SINE_GPU,
- TEST_CASE_3D_SINE_GPU,
- TEST_CASE_1D_2CH_SINE_GPU,
- TEST_CASE_2D_2CH_SINE_GPU,
- ]
- ),
- skip_on_empty=True,
- )
+ @parameterized.expand(TEST_CASES_CPU + TEST_CASES_GPU)
def test_value(self, arguments, image, expected_data, atol):
result = HilbertTransform(**arguments)(image)
- result = result.squeeze(0).squeeze(0).cpu().numpy()
+ result = np.squeeze(result.cpu().numpy())
np.testing.assert_allclose(result, expected_data.squeeze(), atol=atol)
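Rather than maintaining a hand-written GPU copy of every case, the rewrite above derives the GPU cases from the CPU list by moving each input tensor to the device. A short generic restatement of that pattern (illustrative case data):

import torch

# each case: (args, input_tensor, expected, atol) -- mirrors the structure used above
TEST_CASES_CPU = [({}, torch.zeros(1, 1, 10), None, 1e-5)]

if torch.cuda.is_available():
    gpu = torch.device("cuda")
    # reuse the CPU definitions; only the input tensor needs to live on the GPU
    TEST_CASES_GPU = [[args, image.to(gpu), expected, atol] for args, image, expected, atol in TEST_CASES_CPU]
else:
    # an empty list keeps parameterized.expand(TEST_CASES_CPU + TEST_CASES_GPU) valid on CPU-only machines
    TEST_CASES_GPU = []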
diff --git a/tests/test_image_filter.py b/tests/test_image_filter.py
index adc9dade9c..fb08b2295d 100644
--- a/tests/test_image_filter.py
+++ b/tests/test_image_filter.py
@@ -38,6 +38,7 @@
class TestModule(torch.nn.Module):
+ __test__ = False # indicate to pytest that this class is not intended for collection
def __init__(self):
super().__init__()
@@ -133,6 +134,12 @@ def test_pass_empty_metadata_dict(self):
out_tensor = filter(image)
self.assertTrue(isinstance(out_tensor, MetaTensor))
+ def test_gaussian_filter_without_filter_size(self):
+ "Test Gaussian filter without specifying filter_size"
+ filter = ImageFilter("gauss", sigma=2)
+ out_tensor = filter(SAMPLE_IMAGE_2D)
+ self.assertEqual(out_tensor.shape[1:], SAMPLE_IMAGE_2D.shape[1:])
+
class TestImageFilterDict(unittest.TestCase):
diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py
index c2e0fb55b7..60aaef05bf 100644
--- a/tests/test_integration_bundle_run.py
+++ b/tests/test_integration_bundle_run.py
@@ -135,9 +135,8 @@ def test_scripts_fold(self):
command_run = cmd + ["run", "training", "--config_file", config_file, "--meta_file", meta_file]
completed_process = subprocess.run(command_run, check=True, capture_output=True, text=True)
output = repr(completed_process.stdout).replace("\\n", "\n").replace("\\t", "\t") # Get the captured output
- print(output)
- self.assertTrue(expected_condition in output)
+ self.assertIn(expected_condition, output)
command_run_workflow = cmd + [
"run_workflow",
"--run_id",
@@ -149,8 +148,7 @@ def test_scripts_fold(self):
]
completed_process = subprocess.run(command_run_workflow, check=True, capture_output=True, text=True)
output = repr(completed_process.stdout).replace("\\n", "\n").replace("\\t", "\t") # Get the captured output
- print(output)
- self.assertTrue(expected_condition in output)
+ self.assertIn(expected_condition, output)
# test missing meta file
self.assertIn("ERROR", command_line_tests(cmd + ["run", "training", "--config_file", config_file]))
diff --git a/tests/test_integration_unet_2d.py b/tests/test_integration_unet_2d.py
index 918190775c..3b40682de0 100644
--- a/tests/test_integration_unet_2d.py
+++ b/tests/test_integration_unet_2d.py
@@ -35,6 +35,7 @@ def __getitem__(self, _unused_id):
def __len__(self):
return train_steps
+ net = None
if net_name == "basicunet":
net = BasicUNet(spatial_dims=2, in_channels=1, out_channels=1, features=(4, 8, 8, 16, 16, 32))
elif net_name == "unet":
diff --git a/tests/test_integration_workflows_adversarial.py b/tests/test_integration_workflows_adversarial.py
new file mode 100644
index 0000000000..f323fc9917
--- /dev/null
+++ b/tests/test_integration_workflows_adversarial.py
@@ -0,0 +1,173 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+import shutil
+import tempfile
+import unittest
+from glob import glob
+
+import numpy as np
+import torch
+
+import monai
+from monai.data import create_test_image_2d
+from monai.engines import AdversarialTrainer
+from monai.handlers import CheckpointSaver, StatsHandler, TensorBoardStatsHandler
+from monai.networks.nets import AutoEncoder, Discriminator
+from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, RandFlipd, ScaleIntensityd
+from monai.utils import AdversarialKeys as Keys
+from monai.utils import CommonKeys, optional_import, set_determinism
+from tests.utils import DistTestCase, TimedCall, skip_if_quick
+
+nib, has_nibabel = optional_import("nibabel")
+
+
+def run_training_test(root_dir, device="cuda:0"):
+ learning_rate = 2e-4
+ real_label = 1
+ fake_label = 0
+
+ real_images = sorted(glob(os.path.join(root_dir, "img*.nii.gz")))
+ train_files = [{CommonKeys.IMAGE: img, CommonKeys.LABEL: img} for img in zip(real_images)]
+
+ # prepare real data
+ train_transforms = Compose(
+ [
+ LoadImaged(keys=[CommonKeys.IMAGE, CommonKeys.LABEL]),
+ EnsureChannelFirstd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], channel_dim=2),
+ ScaleIntensityd(keys=[CommonKeys.IMAGE]),
+ RandFlipd(keys=[CommonKeys.IMAGE, CommonKeys.LABEL], prob=0.5),
+ ]
+ )
+ train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
+ train_loader = monai.data.DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=4)
+
+ # Create Discriminator
+ discriminator_net = Discriminator(
+ in_shape=(1, 64, 64), channels=(8, 16, 32, 64, 1), strides=(2, 2, 2, 2, 1), num_res_units=1, kernel_size=5
+ ).to(device)
+ discriminator_opt = torch.optim.Adam(discriminator_net.parameters(), learning_rate)
+ discriminator_loss_criterion = torch.nn.BCELoss()
+
+ def discriminator_loss(real_logits, fake_logits):
+ real_target = real_logits.new_full((real_logits.shape[0], 1), real_label)
+ fake_target = fake_logits.new_full((fake_logits.shape[0], 1), fake_label)
+ real_loss = discriminator_loss_criterion(real_logits, real_target)
+ fake_loss = discriminator_loss_criterion(fake_logits.detach(), fake_target)
+ return torch.div(torch.add(real_loss, fake_loss), 2)
+
+ # Create Generator
+ generator_network = AutoEncoder(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ channels=(8, 16, 32, 64),
+ strides=(2, 2, 2, 2),
+ num_res_units=1,
+ num_inter_units=1,
+ )
+ generator_network = generator_network.to(device)
+ generator_optimiser = torch.optim.Adam(generator_network.parameters(), learning_rate)
+ generator_loss_criterion = torch.nn.MSELoss()
+
+ def reconstruction_loss(recon_images, real_images):
+ return generator_loss_criterion(recon_images, real_images)
+
+ def generator_loss(fake_logits):
+ fake_target = fake_logits.new_full((fake_logits.shape[0], 1), real_label)
+ recon_loss = discriminator_loss_criterion(fake_logits.detach(), fake_target)
+ return recon_loss
+
+ key_train_metric = None
+
+ train_handlers = [
+ StatsHandler(
+ name="training_loss",
+ output_transform=lambda x: {
+ Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS],
+ Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS],
+ Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS],
+ },
+ ),
+ TensorBoardStatsHandler(
+ log_dir=root_dir,
+ tag_name="training_loss",
+ output_transform=lambda x: {
+ Keys.RECONSTRUCTION_LOSS: x[Keys.RECONSTRUCTION_LOSS],
+ Keys.DISCRIMINATOR_LOSS: x[Keys.DISCRIMINATOR_LOSS],
+ Keys.GENERATOR_LOSS: x[Keys.GENERATOR_LOSS],
+ },
+ ),
+ CheckpointSaver(
+ save_dir=root_dir,
+ save_dict={"g_net": generator_network, "d_net": discriminator_net},
+ save_interval=2,
+ epoch_level=True,
+ ),
+ ]
+
+ num_epochs = 5
+
+ trainer = AdversarialTrainer(
+ device=device,
+ max_epochs=num_epochs,
+ train_data_loader=train_loader,
+ g_network=generator_network,
+ g_optimizer=generator_optimiser,
+ g_loss_function=generator_loss,
+ recon_loss_function=reconstruction_loss,
+ d_network=discriminator_net,
+ d_optimizer=discriminator_opt,
+ d_loss_function=discriminator_loss,
+ non_blocking=True,
+ key_train_metric=key_train_metric,
+ train_handlers=train_handlers,
+ )
+ trainer.run()
+
+ return trainer.state
+
+
+@skip_if_quick
+@unittest.skipUnless(has_nibabel, "Requires nibabel library.")
+class IntegrationWorkflowsAdversarialTrainer(DistTestCase):
+ def setUp(self):
+ set_determinism(seed=0)
+
+ self.data_dir = tempfile.mkdtemp()
+ for i in range(40):
+ im, _ = create_test_image_2d(64, 64, num_objs=3, rad_max=14, num_seg_classes=1, channel_dim=-1)
+ n = nib.Nifti1Image(im, np.eye(4))
+ nib.save(n, os.path.join(self.data_dir, f"img{i:d}.nii.gz"))
+
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu:0")
+ monai.config.print_config()
+
+ def tearDown(self):
+ set_determinism(seed=None)
+ shutil.rmtree(self.data_dir)
+
+ @TimedCall(seconds=200, daemon=False)
+ def test_training(self):
+ torch.manual_seed(0)
+
+ finish_state = run_training_test(self.data_dir, device=self.device)
+
+ # Assert AdversarialTrainer training finished
+ self.assertEqual(finish_state.iteration, 100)
+ self.assertEqual(finish_state.epoch, 5)
+
+
+if __name__ == "__main__":
+ unittest.main()
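The final assertions follow directly from the data setup in the new adversarial-training test: 40 synthetic images with batch_size=2 give 20 iterations per epoch, and 5 epochs therefore finish at iteration 100. A one-line check of that arithmetic (values copied from the test above):

num_images, batch_size, num_epochs = 40, 2, 5
iterations_per_epoch = num_images // batch_size   # 20
assert iterations_per_epoch * num_epochs == 100   # matches finish_state.iteration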
diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py
index f33b5c67eb..bf3972e6bd 100644
--- a/tests/test_inverse_collation.py
+++ b/tests/test_inverse_collation.py
@@ -133,7 +133,7 @@ def test_collation(self, _, transform, collate_fn, ndim):
d = decollate_batch(item)
self.assertTrue(len(d) <= self.batch_size)
for b in d:
- self.assertTrue(isinstance(b["image"], MetaTensor))
+ self.assertIsInstance(b["image"], MetaTensor)
np.testing.assert_array_equal(
b["image"].applied_operations[-1]["orig_size"], b["label"].applied_operations[-1]["orig_size"]
)
diff --git a/tests/test_invertd.py b/tests/test_invertd.py
index c32a3af643..f6e8fc40e7 100644
--- a/tests/test_invertd.py
+++ b/tests/test_invertd.py
@@ -134,7 +134,7 @@ def test_invert(self):
# 25300: 2 workers (cpu, non-macos)
# 1812: 0 workers (gpu or macos)
# 1821: windows torch 1.10.0
- self.assertTrue((reverted.size - n_good) < 40000, f"diff. {reverted.size - n_good}")
+ self.assertLess((reverted.size - n_good), 40000, f"diff. {reverted.size - n_good}")
set_determinism(seed=None)
diff --git a/tests/test_latent_diffusion_inferer.py b/tests/test_latent_diffusion_inferer.py
new file mode 100644
index 0000000000..2e04ad6c5c
--- /dev/null
+++ b/tests/test_latent_diffusion_inferer.py
@@ -0,0 +1,824 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from unittest import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.inferers import LatentDiffusionInferer
+from monai.networks.nets import VQVAE, AutoencoderKL, DiffusionModelUNet, SPADEAutoencoderKL, SPADEDiffusionModelUNet
+from monai.networks.schedulers import DDPMScheduler
+from monai.utils import optional_import
+
+_, has_einops = optional_import("einops")
+TEST_CASES = [
+ [
+ "AutoencoderKL",
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "latent_channels": 3,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ "norm_num_groups": 4,
+ },
+ "DiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [4, 4],
+ "norm_num_groups": 4,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 4,
+ },
+ (1, 1, 8, 8),
+ (1, 3, 4, 4),
+ ],
+ [
+ "VQVAE",
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": [4, 4],
+ "num_res_layers": 1,
+ "num_res_channels": [4, 4],
+ "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
+ "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
+ "num_embeddings": 16,
+ "embedding_dim": 3,
+ },
+ "DiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [8, 8],
+ "norm_num_groups": 8,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ },
+ (1, 1, 16, 16),
+ (1, 3, 4, 4),
+ ],
+ [
+ "VQVAE",
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": [4, 4],
+ "num_res_layers": 1,
+ "num_res_channels": [4, 4],
+ "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
+ "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
+ "num_embeddings": 16,
+ "embedding_dim": 3,
+ },
+ "DiffusionModelUNet",
+ {
+ "spatial_dims": 3,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [8, 8],
+ "norm_num_groups": 8,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ },
+ (1, 1, 16, 16, 16),
+ (1, 3, 4, 4, 4),
+ ],
+ [
+ "AutoencoderKL",
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "latent_channels": 3,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ "norm_num_groups": 4,
+ },
+ "SPADEDiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [4, 4],
+ "norm_num_groups": 4,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 4,
+ },
+ (1, 1, 8, 8),
+ (1, 3, 4, 4),
+ ],
+]
+TEST_CASES_DIFF_SHAPES = [
+ [
+ "AutoencoderKL",
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "latent_channels": 3,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ "norm_num_groups": 4,
+ },
+ "DiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [4, 4],
+ "norm_num_groups": 4,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 4,
+ },
+ (1, 1, 12, 12),
+ (1, 3, 8, 8),
+ ],
+ [
+ "VQVAE",
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": [4, 4],
+ "num_res_layers": 1,
+ "num_res_channels": [4, 4],
+ "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
+ "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
+ "num_embeddings": 16,
+ "embedding_dim": 3,
+ },
+ "DiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [8, 8],
+ "norm_num_groups": 8,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ },
+ (1, 1, 12, 12),
+ (1, 3, 8, 8),
+ ],
+ [
+ "VQVAE",
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": [4, 4],
+ "num_res_layers": 1,
+ "num_res_channels": [4, 4],
+ "downsample_parameters": ((2, 4, 1, 1), (2, 4, 1, 1)),
+ "upsample_parameters": ((2, 4, 1, 1, 0), (2, 4, 1, 1, 0)),
+ "num_embeddings": 16,
+ "embedding_dim": 3,
+ },
+ "DiffusionModelUNet",
+ {
+ "spatial_dims": 3,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [8, 8],
+ "norm_num_groups": 8,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ },
+ (1, 1, 12, 12, 12),
+ (1, 3, 8, 8, 8),
+ ],
+ [
+ "SPADEAutoencoderKL",
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "latent_channels": 3,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ "norm_num_groups": 4,
+ },
+ "DiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [4, 4],
+ "norm_num_groups": 4,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 4,
+ },
+ (1, 1, 8, 8),
+ (1, 3, 4, 4),
+ ],
+ [
+ "AutoencoderKL",
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "latent_channels": 3,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ "norm_num_groups": 4,
+ },
+ "SPADEDiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [4, 4],
+ "norm_num_groups": 4,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 4,
+ },
+ (1, 1, 8, 8),
+ (1, 3, 4, 4),
+ ],
+ [
+ "SPADEAutoencoderKL",
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "latent_channels": 3,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ "norm_num_groups": 4,
+ },
+ "SPADEDiffusionModelUNet",
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 3,
+ "out_channels": 3,
+ "channels": [4, 4],
+ "norm_num_groups": 4,
+ "attention_levels": [False, False],
+ "num_res_blocks": 1,
+ "num_head_channels": 4,
+ },
+ (1, 1, 8, 8),
+ (1, 3, 4, 4),
+ ],
+]
+
+
+class TestDiffusionSamplingInferer(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_prediction_shape(
+ self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ input = torch.randn(input_shape).to(device)
+ noise = torch.randn(latent_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+ timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ prediction = inferer(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ seg=input_seg,
+ noise=noise,
+ timesteps=timesteps,
+ )
+ else:
+ prediction = inferer(
+ inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps
+ )
+ self.assertEqual(prediction.shape, latent_shape)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_sample_shape(
+ self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ noise = torch.randn(latent_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ sample = inferer.sample(
+ input_noise=noise,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ seg=input_seg,
+ )
+ else:
+ sample = inferer.sample(
+ input_noise=noise, autoencoder_model=stage_1, diffusion_model=stage_2, scheduler=scheduler
+ )
+ self.assertEqual(sample.shape, input_shape)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_sample_intermediates(
+ self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if ae_model_type == "SPADEAutoencoderKL":
+ stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ noise = torch.randn(latent_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ sample, intermediates = inferer.sample(
+ input_noise=noise,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ seg=input_seg,
+ save_intermediates=True,
+ intermediate_steps=1,
+ )
+ else:
+ sample, intermediates = inferer.sample(
+ input_noise=noise,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ save_intermediates=True,
+ intermediate_steps=1,
+ )
+ self.assertEqual(len(intermediates), 10)
+ self.assertEqual(intermediates[0].shape, input_shape)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_get_likelihoods(
+ self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if ae_model_type == "SPADEAutoencoderKL":
+ stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ input = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ sample, intermediates = inferer.get_likelihood(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ save_intermediates=True,
+ seg=input_seg,
+ )
+ else:
+ sample, intermediates = inferer.get_likelihood(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ save_intermediates=True,
+ )
+ self.assertEqual(len(intermediates), 10)
+ self.assertEqual(intermediates[0].shape, latent_shape)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_resample_likelihoods(
+ self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if ae_model_type == "SPADEAutoencoderKL":
+ stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ input = torch.randn(input_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ sample, intermediates = inferer.get_likelihood(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ save_intermediates=True,
+ resample_latent_likelihoods=True,
+ seg=input_seg,
+ )
+ else:
+ sample, intermediates = inferer.get_likelihood(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ save_intermediates=True,
+ resample_latent_likelihoods=True,
+ )
+ self.assertEqual(len(intermediates), 10)
+ self.assertEqual(intermediates[0].shape[2:], input_shape[2:])
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_prediction_shape_conditioned_concat(
+ self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if ae_model_type == "SPADEAutoencoderKL":
+ stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+ stage_2_params = stage_2_params.copy()
+ n_concat_channel = 3
+ stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ input = torch.randn(input_shape).to(device)
+ noise = torch.randn(latent_shape).to(device)
+ conditioning_shape = list(latent_shape)
+ conditioning_shape[1] = n_concat_channel
+ conditioning = torch.randn(conditioning_shape).to(device)
+
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ prediction = inferer(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ noise=noise,
+ timesteps=timesteps,
+ condition=conditioning,
+ mode="concat",
+ seg=input_seg,
+ )
+ else:
+ prediction = inferer(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ noise=noise,
+ timesteps=timesteps,
+ condition=conditioning,
+ mode="concat",
+ )
+ self.assertEqual(prediction.shape, latent_shape)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_sample_shape_conditioned_concat(
+ self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if ae_model_type == "SPADEAutoencoderKL":
+ stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+ stage_2_params = stage_2_params.copy()
+ n_concat_channel = 3
+ stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ noise = torch.randn(latent_shape).to(device)
+ conditioning_shape = list(latent_shape)
+ conditioning_shape[1] = n_concat_channel
+ conditioning = torch.randn(conditioning_shape).to(device)
+
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ sample = inferer.sample(
+ input_noise=noise,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ conditioning=conditioning,
+ mode="concat",
+ seg=input_seg,
+ )
+ else:
+ sample = inferer.sample(
+ input_noise=noise,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ conditioning=conditioning,
+ mode="concat",
+ )
+ self.assertEqual(sample.shape, input_shape)
+
+ @parameterized.expand(TEST_CASES_DIFF_SHAPES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_sample_shape_different_latents(
+ self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+ ):
+ stage_1 = None
+
+ if ae_model_type == "AutoencoderKL":
+ stage_1 = AutoencoderKL(**autoencoder_params)
+ if ae_model_type == "VQVAE":
+ stage_1 = VQVAE(**autoencoder_params)
+ if ae_model_type == "SPADEAutoencoderKL":
+ stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+ else:
+ stage_2 = DiffusionModelUNet(**stage_2_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ input = torch.randn(input_shape).to(device)
+ noise = torch.randn(latent_shape).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+        # infer the autoencoder's latent spatial shape from its downsampling factor (one factor of 2 per level)
+ autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
+ inferer = LatentDiffusionInferer(
+ scheduler=scheduler,
+ scale_factor=1.0,
+ ldm_latent_shape=list(latent_shape[2:]),
+ autoencoder_latent_shape=autoencoder_latent_shape,
+ )
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
+
+ if dm_model_type == "SPADEDiffusionModelUNet":
+ input_shape_seg = list(input_shape)
+ if "label_nc" in stage_2_params.keys():
+ input_shape_seg[1] = stage_2_params["label_nc"]
+ else:
+ input_shape_seg[1] = autoencoder_params["label_nc"]
+ input_seg = torch.randn(input_shape_seg).to(device)
+ prediction = inferer(
+ inputs=input,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ noise=noise,
+ timesteps=timesteps,
+ seg=input_seg,
+ )
+ else:
+ prediction = inferer(
+ inputs=input, autoencoder_model=stage_1, diffusion_model=stage_2, noise=noise, timesteps=timesteps
+ )
+ self.assertEqual(prediction.shape, latent_shape)
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_incompatible_spade_setup(self):
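+        # label_nc of the SPADE autoencoder (6) deliberately differs from the SPADE diffusion UNet (3) and the
+        # 3-channel seg below, so sampling must raise ValueError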
+ stage_1 = SPADEAutoencoderKL(
+ spatial_dims=2,
+ label_nc=6,
+ in_channels=1,
+ out_channels=1,
+ channels=(4, 4),
+ latent_channels=3,
+ attention_levels=[False, False],
+ num_res_blocks=1,
+ with_encoder_nonlocal_attn=False,
+ with_decoder_nonlocal_attn=False,
+ norm_num_groups=4,
+ )
+ stage_2 = SPADEDiffusionModelUNet(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=3,
+ out_channels=3,
+ channels=[4, 4],
+ norm_num_groups=4,
+ attention_levels=[False, False],
+ num_res_blocks=1,
+ num_head_channels=4,
+ )
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+ noise = torch.randn((1, 3, 4, 4)).to(device)
+ input_seg = torch.randn((1, 3, 8, 8)).to(device)
+ scheduler = DDPMScheduler(num_train_timesteps=10)
+ inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
+ scheduler.set_timesteps(num_inference_steps=10)
+
+ with self.assertRaises(ValueError):
+ _ = inferer.sample(
+ input_noise=noise,
+ autoencoder_model=stage_1,
+ diffusion_model=stage_2,
+ scheduler=scheduler,
+ seg=input_seg,
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py
index 699ed70059..914240c705 100644
--- a/tests/test_load_imaged.py
+++ b/tests/test_load_imaged.py
@@ -190,7 +190,7 @@ def test_correct(self, input_p, expected_shape, track_meta):
self.assertTrue(hasattr(r, "affine"))
self.assertIsInstance(r.affine, torch.Tensor)
self.assertEqual(r.meta["space"], "RAS")
- self.assertTrue("qform_code" not in r.meta)
+ self.assertNotIn("qform_code", r.meta)
else:
self.assertIsInstance(r, torch.Tensor)
self.assertNotIsInstance(r, MetaTensor)
diff --git a/tests/test_load_spacing_orientation.py b/tests/test_load_spacing_orientation.py
index 63422761ca..cbc730e1bb 100644
--- a/tests/test_load_spacing_orientation.py
+++ b/tests/test_load_spacing_orientation.py
@@ -48,7 +48,7 @@ def test_load_spacingd(self, filename):
ref = resample_to_output(anat, (1, 0.2, 1), order=1)
t2 = time.time()
print(f"time scipy: {t2 - t1}")
- self.assertTrue(t2 >= t1)
+ self.assertGreaterEqual(t2, t1)
np.testing.assert_allclose(res_dict["image"].affine, ref.affine)
np.testing.assert_allclose(res_dict["image"].shape[1:], ref.shape)
np.testing.assert_allclose(ref.get_fdata(), res_dict["image"][0], atol=0.05)
@@ -68,7 +68,7 @@ def test_load_spacingd_rotate(self, filename):
ref = resample_to_output(anat, (1, 2, 3), order=1)
t2 = time.time()
print(f"time scipy: {t2 - t1}")
- self.assertTrue(t2 >= t1)
+ self.assertGreaterEqual(t2, t1)
np.testing.assert_allclose(res_dict["image"].affine, ref.affine)
if "anatomical" not in filename:
np.testing.assert_allclose(res_dict["image"].shape[1:], ref.shape)
diff --git a/tests/test_look_up_option.py b/tests/test_look_up_option.py
index d40b7eaa8c..75560b4ac4 100644
--- a/tests/test_look_up_option.py
+++ b/tests/test_look_up_option.py
@@ -56,7 +56,7 @@ def test_default(self):
def test_str_enum(self):
output = look_up_option("C", {"A", "B"}, default=None)
- self.assertEqual(output, None)
+ self.assertIsNone(output)
self.assertEqual(list(_CaseStrEnum), ["A", "B"])
self.assertEqual(_CaseStrEnum.MODE_A, "A")
self.assertEqual(str(_CaseStrEnum.MODE_A), "A")
diff --git a/tests/test_map_and_generate_sampling_centers.py b/tests/test_map_and_generate_sampling_centers.py
new file mode 100644
index 0000000000..ff74f974b9
--- /dev/null
+++ b/tests/test_map_and_generate_sampling_centers.py
@@ -0,0 +1,87 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from copy import deepcopy
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.transforms import map_and_generate_sampling_centers
+from monai.utils.misc import set_determinism
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
+TEST_CASE_1 = [
+ # test Argmax data
+ {
+ "label": (np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])),
+ "spatial_size": [2, 2, 2],
+ "num_samples": 2,
+ "label_spatial_shape": [3, 3, 3],
+ "num_classes": 3,
+ "image": None,
+ "ratios": [0, 1, 2],
+ "image_threshold": 0.0,
+ },
+ tuple,
+ 2,
+ 3,
+]
+
+TEST_CASE_2 = [
+ {
+ "label": (
+ np.array(
+ [
+ [[1, 0, 0], [0, 1, 0], [0, 0, 1]],
+ [[0, 1, 0], [0, 0, 1], [1, 0, 0]],
+ [[0, 0, 1], [1, 0, 0], [0, 1, 0]],
+ ]
+ )
+ ),
+ "spatial_size": [2, 2, 2],
+ "num_samples": 1,
+ "ratios": None,
+ "label_spatial_shape": [3, 3, 3],
+ "image": None,
+ "image_threshold": 0.0,
+ },
+ tuple,
+ 1,
+ 3,
+]
+
+
+class TestMapAndGenerateSamplingCenters(unittest.TestCase):
+
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+ def test_map_and_generate_sampling_centers(self, input_data, expected_type, expected_count, expected_shape):
+ results = []
+ for p in TEST_NDARRAYS + (None,):
+ input_data = deepcopy(input_data)
+ if p is not None:
+ input_data["label"] = p(input_data["label"])
+ set_determinism(0)
+ result = map_and_generate_sampling_centers(**input_data)
+ self.assertIsInstance(result, expected_type)
+ self.assertEqual(len(result), expected_count)
+ self.assertEqual(len(result[0]), expected_shape)
+ # check for consistency between numpy, torch and torch.cuda
+ results.append(result)
+ if len(results) > 1:
+                for x, y in zip(results[0], results[-1]):
+ assert_allclose(x, y, type_test=False)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_mapping_file.py b/tests/test_mapping_file.py
new file mode 100644
index 0000000000..97fa4312ed
--- /dev/null
+++ b/tests/test_mapping_file.py
@@ -0,0 +1,117 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.data import DataLoader, Dataset
+from monai.transforms import Compose, LoadImage, SaveImage, WriteFileMapping
+from monai.utils import optional_import
+
+nib, has_nib = optional_import("nibabel")
+
+
+def create_input_file(temp_dir, name):
+ test_image = np.random.rand(128, 128, 128)
+ output_ext = ".nii.gz"
+ input_file = os.path.join(temp_dir, name + output_ext)
+ nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file)
+ return input_file
+
+
+def create_transform(temp_dir, mapping_file_path, savepath_in_metadict=True):
+ return Compose(
+ [
+ LoadImage(image_only=True),
+ SaveImage(output_dir=temp_dir, output_ext=".nii.gz", savepath_in_metadict=savepath_in_metadict),
+ WriteFileMapping(mapping_file_path=mapping_file_path),
+ ]
+ )
+
+
+@unittest.skipUnless(has_nib, "nibabel required")
+class TestWriteFileMapping(unittest.TestCase):
+ def setUp(self):
+ self.temp_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ shutil.rmtree(self.temp_dir)
+
+ @parameterized.expand([(True,), (False,)])
+ def test_mapping_file(self, savepath_in_metadict):
+ mapping_file_path = os.path.join(self.temp_dir, "mapping.json")
+ name = "test_image"
+ input_file = create_input_file(self.temp_dir, name)
+ output_file = os.path.join(self.temp_dir, name, name + "_trans.nii.gz")
+
+ transform = create_transform(self.temp_dir, mapping_file_path, savepath_in_metadict)
+
+ if savepath_in_metadict:
+ transform(input_file)
+ self.assertTrue(os.path.exists(mapping_file_path))
+ with open(mapping_file_path) as f:
+ mapping_data = json.load(f)
+ self.assertEqual(len(mapping_data), 1)
+ self.assertEqual(mapping_data[0]["input"], input_file)
+ self.assertEqual(mapping_data[0]["output"], output_file)
+ else:
+ with self.assertRaises(RuntimeError) as cm:
+ transform(input_file)
+ cause_exception = cm.exception.__cause__
+ self.assertIsInstance(cause_exception, KeyError)
+ self.assertIn(
+ "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.",
+ str(cause_exception),
+ )
+
+ def test_multiprocess_mapping_file(self):
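+        # mapping entries written from several worker processes should match the single-process result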
+ num_images = 50
+
+ single_mapping_file = os.path.join(self.temp_dir, "single_mapping.json")
+ multi_mapping_file = os.path.join(self.temp_dir, "multi_mapping.json")
+
+ data = [create_input_file(self.temp_dir, f"test_image_{i}") for i in range(num_images)]
+
+ # single process
+ single_transform = create_transform(self.temp_dir, single_mapping_file)
+ single_dataset = Dataset(data=data, transform=single_transform)
+ single_loader = DataLoader(single_dataset, batch_size=1, num_workers=0, shuffle=True)
+ for _ in single_loader:
+ pass
+
+ # multiple processes
+ multi_transform = create_transform(self.temp_dir, multi_mapping_file)
+ multi_dataset = Dataset(data=data, transform=multi_transform)
+ multi_loader = DataLoader(multi_dataset, batch_size=4, num_workers=3, shuffle=True)
+ for _ in multi_loader:
+ pass
+
+ with open(single_mapping_file) as f:
+ single_mapping_data = json.load(f)
+ with open(multi_mapping_file) as f:
+ multi_mapping_data = json.load(f)
+
+ single_set = {(entry["input"], entry["output"]) for entry in single_mapping_data}
+ multi_set = {(entry["input"], entry["output"]) for entry in multi_mapping_data}
+
+ self.assertEqual(single_set, multi_set)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_mapping_filed.py b/tests/test_mapping_filed.py
new file mode 100644
index 0000000000..d0f8bcf938
--- /dev/null
+++ b/tests/test_mapping_filed.py
@@ -0,0 +1,122 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import json
+import os
+import shutil
+import tempfile
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.data import DataLoader, Dataset, decollate_batch
+from monai.inferers import sliding_window_inference
+from monai.networks.nets import UNet
+from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, SaveImaged, WriteFileMappingd
+from monai.utils import optional_import
+
+nib, has_nib = optional_import("nibabel")
+
+
+def create_input_file(temp_dir, name):
+ test_image = np.random.rand(128, 128, 128)
+ input_file = os.path.join(temp_dir, name + ".nii.gz")
+ nib.save(nib.Nifti1Image(test_image, np.eye(4)), input_file)
+ return input_file
+
+
+# Test cases that should succeed
+SUCCESS_CASES = [(["seg"], ["seg"]), (["image", "seg"], ["seg"])]
+
+# Test cases that should fail
+FAILURE_CASES = [(["seg"], ["image"]), (["image"], ["seg"]), (["seg"], ["image", "seg"])]
+
+
+@unittest.skipUnless(has_nib, "nibabel required")
+class TestWriteFileMappingd(unittest.TestCase):
+ def setUp(self):
+ self.temp_dir = tempfile.mkdtemp()
+ self.output_dir = os.path.join(self.temp_dir, "output")
+ os.makedirs(self.output_dir)
+ self.mapping_file_path = os.path.join(self.temp_dir, "mapping.json")
+
+ def tearDown(self):
+ shutil.rmtree(self.temp_dir)
+ if os.path.exists(self.mapping_file_path):
+ os.remove(self.mapping_file_path)
+
+ def run_test(self, save_keys, write_keys):
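+        # load an image, run sliding-window UNet inference, save the prediction and record it via WriteFileMappingd;
+        # returns the input path and the expected saved-output path for the mapping assertions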
+ name = "test_image"
+ input_file = create_input_file(self.temp_dir, name)
+ output_file = os.path.join(self.output_dir, name, name + "_seg.nii.gz")
+ data = [{"image": input_file}]
+
+ test_transforms = Compose([LoadImaged(keys=["image"]), EnsureChannelFirstd(keys=["image"])])
+
+ post_transforms = Compose(
+ [
+ SaveImaged(
+ keys=save_keys,
+ meta_keys="image_meta_dict",
+ output_dir=self.output_dir,
+ output_postfix="seg",
+ savepath_in_metadict=True,
+ ),
+ WriteFileMappingd(keys=write_keys, mapping_file_path=self.mapping_file_path),
+ ]
+ )
+
+ dataset = Dataset(data=data, transform=test_transforms)
+ dataloader = DataLoader(dataset, batch_size=1)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = UNet(spatial_dims=3, in_channels=1, out_channels=2, channels=(16, 32), strides=(2,)).to(device)
+ model.eval()
+
+ with torch.no_grad():
+ for batch_data in dataloader:
+ test_inputs = batch_data["image"].to(device)
+ roi_size = (64, 64, 64)
+ sw_batch_size = 2
+ batch_data["seg"] = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model)
+ batch_data = [post_transforms(i) for i in decollate_batch(batch_data)]
+
+ return input_file, output_file
+
+ @parameterized.expand(SUCCESS_CASES)
+ def test_successful_mapping_filed(self, save_keys, write_keys):
+ input_file, output_file = self.run_test(save_keys, write_keys)
+ self.assertTrue(os.path.exists(self.mapping_file_path))
+ with open(self.mapping_file_path) as f:
+ mapping_data = json.load(f)
+ self.assertEqual(len(mapping_data), len(write_keys))
+ for entry in mapping_data:
+ self.assertEqual(entry["input"], input_file)
+ self.assertEqual(entry["output"], output_file)
+
+ @parameterized.expand(FAILURE_CASES)
+ def test_failure_mapping_filed(self, save_keys, write_keys):
+ with self.assertRaises(RuntimeError) as cm:
+ self.run_test(save_keys, write_keys)
+
+ cause_exception = cm.exception.__cause__
+ self.assertIsInstance(cause_exception, KeyError)
+ self.assertIn(
+ "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True.",
+ str(cause_exception),
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_matshow3d.py b/tests/test_matshow3d.py
index e513025e69..2eba310f4e 100644
--- a/tests/test_matshow3d.py
+++ b/tests/test_matshow3d.py
@@ -78,7 +78,7 @@ def test_samples(self):
fig, mat = matshow3d(
[im[keys] for im in ims], title=f"testing {keys}", figsize=(2, 2), frames_per_row=5, every_n=2, show=False
)
- self.assertTrue(mat.dtype == np.float32)
+ self.assertEqual(mat.dtype, np.float32)
with tempfile.TemporaryDirectory() as tempdir:
tempimg = f"{tempdir}/matshow3d_patch_test.png"
@@ -114,6 +114,7 @@ def test_3d_rgb(self):
every_n=2,
frame_dim=-1,
channel_dim=0,
+ fill_value=0,
show=False,
)
diff --git a/tests/test_median_filter.py b/tests/test_median_filter.py
index 1f5e623260..02fa812380 100644
--- a/tests/test_median_filter.py
+++ b/tests/test_median_filter.py
@@ -15,27 +15,21 @@
import numpy as np
import torch
+from parameterized import parameterized
from monai.networks.layers import MedianFilter
class MedianFilterTestCase(unittest.TestCase):
- def test_3d_big(self):
- a = torch.ones(1, 1, 2, 3, 5)
- g = MedianFilter([1, 2, 4]).to(torch.device("cpu:0"))
+    @parameterized.expand([(torch.ones(1, 1, 2, 3, 5), [1, 2, 4]), (torch.ones(1, 1, 4, 3, 4), 1)])  # cases: 3d_big, 3d
+ def test_3d(self, input_tensor, radius):
+ filter = MedianFilter(radius).to(torch.device("cpu:0"))
- expected = a.numpy()
- out = g(a).cpu().numpy()
- np.testing.assert_allclose(out, expected, rtol=1e-5)
-
- def test_3d(self):
- a = torch.ones(1, 1, 4, 3, 4)
- g = MedianFilter(1).to(torch.device("cpu:0"))
+ expected = input_tensor.numpy()
+ output = filter(input_tensor).cpu().numpy()
- expected = a.numpy()
- out = g(a).cpu().numpy()
- np.testing.assert_allclose(out, expected, rtol=1e-5)
+ np.testing.assert_allclose(output, expected, rtol=1e-5)
def test_3d_radii(self):
a = torch.ones(1, 1, 4, 3, 2)
diff --git a/tests/test_mednext.py b/tests/test_mednext.py
new file mode 100644
index 0000000000..b4ba4f9939
--- /dev/null
+++ b/tests/test_mednext.py
@@ -0,0 +1,122 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.nets import MedNeXt, MedNeXtL, MedNeXtM, MedNeXtS
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
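+# sweep spatial dims (2D/3D), init_filters, deep supervision and residual connections on 16^d inputs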
+TEST_CASE_MEDNEXT = []
+for spatial_dims in range(2, 4):
+ for init_filters in [8, 16]:
+ for deep_supervision in [False, True]:
+ for do_res in [False, True]:
+ test_case = [
+ {
+ "spatial_dims": spatial_dims,
+ "init_filters": init_filters,
+ "deep_supervision": deep_supervision,
+ "use_residual_connection": do_res,
+ },
+ (2, 1, *([16] * spatial_dims)),
+ (2, 2, *([16] * spatial_dims)),
+ ]
+ TEST_CASE_MEDNEXT.append(test_case)
+
+TEST_CASE_MEDNEXT_2 = []
+for spatial_dims in range(2, 4):
+ for out_channels in [1, 2]:
+ for deep_supervision in [False, True]:
+ test_case = [
+ {
+ "spatial_dims": spatial_dims,
+ "init_filters": 8,
+ "out_channels": out_channels,
+ "deep_supervision": deep_supervision,
+ },
+ (2, 1, *([16] * spatial_dims)),
+ (2, out_channels, *([16] * spatial_dims)),
+ ]
+ TEST_CASE_MEDNEXT_2.append(test_case)
+
+TEST_CASE_MEDNEXT_VARIANTS = []
+for model in [MedNeXtS, MedNeXtM, MedNeXtL]:
+ for spatial_dims in range(2, 4):
+ for out_channels in [1, 2]:
+ test_case = [
+ model, # type: ignore
+ {"spatial_dims": spatial_dims, "in_channels": 1, "out_channels": out_channels},
+ (2, 1, *([16] * spatial_dims)),
+ (2, out_channels, *([16] * spatial_dims)),
+ ]
+ TEST_CASE_MEDNEXT_VARIANTS.append(test_case)
+
+
+class TestMedNeXt(unittest.TestCase):
+
+ @parameterized.expand(TEST_CASE_MEDNEXT)
+ def test_shape(self, input_param, input_shape, expected_shape):
+ net = MedNeXt(**input_param).to(device)
+ with eval_mode(net):
+ result = net(torch.randn(input_shape).to(device))
+ if input_param["deep_supervision"] and net.training:
+                self.assertIsInstance(result, tuple)
+ self.assertEqual(result[0].shape, expected_shape, msg=str(input_param))
+ else:
+ self.assertEqual(result.shape, expected_shape, msg=str(input_param))
+
+ @parameterized.expand(TEST_CASE_MEDNEXT_2)
+ def test_shape2(self, input_param, input_shape, expected_shape):
+ net = MedNeXt(**input_param).to(device)
+
+ net.train()
+ result = net(torch.randn(input_shape).to(device))
+ if input_param["deep_supervision"]:
+            self.assertIsInstance(result, tuple)
+            self.assertEqual(result[0].shape, expected_shape, msg=str(input_param))
+        else:
+            self.assertIsInstance(result, torch.Tensor)
+ self.assertEqual(result.shape, expected_shape, msg=str(input_param))
+
+ net.eval()
+ result = net(torch.randn(input_shape).to(device))
+        self.assertIsInstance(result, torch.Tensor)
+ self.assertEqual(result.shape, expected_shape, msg=str(input_param))
+
+ def test_ill_arg(self):
+ with self.assertRaises(AssertionError):
+ MedNeXt(spatial_dims=4)
+
+ @parameterized.expand(TEST_CASE_MEDNEXT_VARIANTS)
+ def test_mednext_variants(self, model, input_param, input_shape, expected_shape):
+ net = model(**input_param).to(device)
+
+ net.train()
+ result = net(torch.randn(input_shape).to(device))
+        self.assertIsInstance(result, torch.Tensor)
+ self.assertEqual(result.shape, expected_shape, msg=str(input_param))
+
+ net.eval()
+ with torch.no_grad():
+ result = net(torch.randn(input_shape).to(device))
+            self.assertIsInstance(result, torch.Tensor)
+ self.assertEqual(result.shape, expected_shape, msg=str(input_param))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py
index baf3bf4f2d..c1b21e9373 100644
--- a/tests/test_mednistdataset.py
+++ b/tests/test_mednistdataset.py
@@ -41,7 +41,7 @@ def _test_dataset(dataset):
self.assertEqual(len(dataset), int(MEDNIST_FULL_DATASET_LENGTH * dataset.test_frac))
self.assertTrue("image" in dataset[0])
self.assertTrue("label" in dataset[0])
- self.assertTrue(isinstance(dataset[0]["image"], MetaTensor))
+ self.assertIsInstance(dataset[0]["image"], MetaTensor)
self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64))
with skip_if_downloading_fails():
@@ -65,11 +65,8 @@ def _test_dataset(dataset):
self.assertEqual(data[0]["class_name"], "AbdomenCT")
self.assertEqual(data[0]["label"], 0)
shutil.rmtree(os.path.join(testing_dir, "MedNIST"))
- try:
+ with self.assertRaisesRegex(RuntimeError, "^Cannot find dataset directory"):
MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=False)
- except RuntimeError as e:
- print(str(e))
- self.assertTrue(str(e).startswith("Cannot find dataset directory"))
if __name__ == "__main__":
diff --git a/tests/test_meta_affine.py b/tests/test_meta_affine.py
index 95764a0c89..890734391f 100644
--- a/tests/test_meta_affine.py
+++ b/tests/test_meta_affine.py
@@ -160,7 +160,7 @@ def test_linear_consistent(self, xform_cls, input_dict, atol):
diff = np.abs(itk.GetArrayFromImage(ref_2) - itk.GetArrayFromImage(expected))
avg_diff = np.mean(diff)
- self.assertTrue(avg_diff < atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}")
+ self.assertLess(avg_diff, atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}")
@parameterized.expand(TEST_CASES_DICT)
def test_linear_consistent_dict(self, xform_cls, input_dict, atol):
@@ -175,7 +175,7 @@ def test_linear_consistent_dict(self, xform_cls, input_dict, atol):
diff = {k: np.abs(itk.GetArrayFromImage(ref_2[k]) - itk.GetArrayFromImage(expected[k])) for k in keys}
avg_diff = {k: np.mean(diff[k]) for k in keys}
for k in keys:
- self.assertTrue(avg_diff[k] < atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}")
+ self.assertLess(avg_diff[k], atol, f"{xform_cls} avg_diff: {avg_diff}, tol: {atol}")
if __name__ == "__main__":
diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py
index 1e0f188b63..60b6019703 100644
--- a/tests/test_meta_tensor.py
+++ b/tests/test_meta_tensor.py
@@ -222,9 +222,9 @@ def test_stack(self, device, dtype):
def test_get_set_meta_fns(self):
set_track_meta(False)
- self.assertEqual(get_track_meta(), False)
+ self.assertFalse(get_track_meta())
set_track_meta(True)
- self.assertEqual(get_track_meta(), True)
+ self.assertTrue(get_track_meta())
@parameterized.expand(TEST_DEVICES)
def test_torchscript(self, device):
@@ -448,7 +448,7 @@ def test_shape(self):
def test_astype(self):
t = MetaTensor([1.0], affine=torch.tensor(1), meta={"fname": "filename"})
- for np_types in ("float32", "np.float32", "numpy.float32", np.float32, float, "int", np.compat.long, np.uint16):
+ for np_types in ("float32", "np.float32", "numpy.float32", np.float32, float, "int", np.uint16):
self.assertIsInstance(t.astype(np_types), np.ndarray)
for pt_types in ("torch.float", torch.float, "torch.float64"):
self.assertIsInstance(t.astype(pt_types), torch.Tensor)
diff --git a/tests/test_mlp.py b/tests/test_mlp.py
index 54f70d3318..2598d8877d 100644
--- a/tests/test_mlp.py
+++ b/tests/test_mlp.py
@@ -15,10 +15,12 @@
import numpy as np
import torch
+import torch.nn as nn
from parameterized import parameterized
from monai.networks import eval_mode
from monai.networks.blocks.mlp import MLPBlock
+from monai.networks.layers.factories import split_args
TEST_CASE_MLP = []
for dropout_rate in np.linspace(0, 1, 4):
@@ -31,6 +33,14 @@
]
TEST_CASE_MLP.append(test_case)
+# test different activation layers
+TEST_CASE_ACT = []
+for act in ["GELU", "GEGLU", ("GEGLU", {})]: # type: ignore
+ TEST_CASE_ACT.append([{"hidden_size": 128, "mlp_dim": 0, "act": act}, (2, 512, 128), (2, 512, 128)])
+
+# test different dropout modes
+TEST_CASE_DROP = [["vit", nn.Dropout], ["swin", nn.Dropout], ["vista3d", nn.Identity]]
+
class TestMLPBlock(unittest.TestCase):
@@ -45,6 +55,24 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=5.0)
+ @parameterized.expand(TEST_CASE_ACT)
+ def test_act(self, input_param, input_shape, expected_shape):
+ net = MLPBlock(**input_param)
+ with eval_mode(net):
+ result = net(torch.randn(input_shape))
+ self.assertEqual(result.shape, expected_shape)
+ act_name, _ = split_args(input_param["act"])
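+        # GEGLU splits the first projection into value and gate halves, so linear1's output is twice its input width here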
+ if act_name == "GEGLU":
+ self.assertEqual(net.linear1.in_features, net.linear1.out_features // 2)
+ else:
+ self.assertEqual(net.linear1.in_features, net.linear1.out_features)
+
+ @parameterized.expand(TEST_CASE_DROP)
+ def test_dropout_mode(self, dropout_mode, dropout_layer):
+ net = MLPBlock(hidden_size=128, mlp_dim=512, dropout_rate=0.1, dropout_mode=dropout_mode)
+        self.assertIsInstance(net.drop1, dropout_layer)
+        self.assertIsInstance(net.drop2, dropout_layer)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_mmar_download.py b/tests/test_mmar_download.py
index 6af3d09fb2..2ac73a8149 100644
--- a/tests/test_mmar_download.py
+++ b/tests/test_mmar_download.py
@@ -142,7 +142,7 @@ def test_load_ckpt(self, input_args, expected_name, expected_val):
def test_unique(self):
# model ids are unique
keys = sorted(m["id"] for m in MODEL_DESC)
- self.assertTrue(keys == sorted(set(keys)))
+ self.assertEqual(keys, sorted(set(keys)))
def test_search(self):
self.assertEqual(_get_val({"a": 1, "b": 2}, key="b"), 2)
diff --git a/tests/test_monai_utils_misc.py b/tests/test_monai_utils_misc.py
index a2a4ed62f7..f4eb5d3956 100644
--- a/tests/test_monai_utils_misc.py
+++ b/tests/test_monai_utils_misc.py
@@ -92,12 +92,11 @@ def test_run_cmd(self):
cmd2 = "-c"
cmd3 = 'import sys; print("\\tThis is on stderr\\n", file=sys.stderr); sys.exit(1)'
os.environ["MONAI_DEBUG"] = str(True)
- try:
+ with self.assertRaises(RuntimeError) as cm:
run_cmd([cmd1, cmd2, cmd3], check=True)
- except RuntimeError as err:
- self.assertIn("This is on stderr", str(err))
- self.assertNotIn("\\n", str(err))
- self.assertNotIn("\\t", str(err))
+ self.assertIn("This is on stderr", str(cm.exception))
+ self.assertNotIn("\\n", str(cm.exception))
+ self.assertNotIn("\\t", str(cm.exception))
if __name__ == "__main__":
diff --git a/tests/test_morphological_ops.py b/tests/test_morphological_ops.py
new file mode 100644
index 0000000000..422e8c4b9d
--- /dev/null
+++ b/tests/test_morphological_ops.py
@@ -0,0 +1,102 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.transforms.utils_morphological_ops import dilate, erode, get_morphological_filter_result_t
+from tests.utils import TEST_NDARRAYS, assert_allclose
+
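+# shape checks: erode should preserve batch, channel and spatial dimensions for 2D and 3D masks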
+TESTS_SHAPE = []
+for p in TEST_NDARRAYS:
+ mask = torch.zeros(1, 1, 5, 5, 5)
+ filter_size = 3
+ TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 5, 5, 5]])
+ mask = torch.zeros(3, 2, 5, 5, 5)
+ filter_size = 5
+ TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [3, 2, 5, 5, 5]])
+ mask = torch.zeros(1, 1, 1, 1, 1)
+ filter_size = 5
+ TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1, 1]])
+ mask = torch.zeros(1, 1, 1, 1)
+ filter_size = 5
+ TESTS_SHAPE.append([{"mask": p(mask), "filter_size": filter_size}, [1, 1, 1, 1]])
+
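+# get_morphological_filter_result_t: a uniform mask padded with its own value passes through unchanged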
+TESTS_VALUE_T = []
+filter_size = 3
+mask = torch.ones(3, 2, 3, 3, 3)
+TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3, 3)])
+mask = torch.zeros(3, 2, 3, 3, 3)
+TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3, 3)])
+mask = torch.ones(3, 2, 3, 3)
+TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 1.0}, torch.ones(3, 2, 3, 3)])
+mask = torch.zeros(3, 2, 3, 3)
+TESTS_VALUE_T.append([{"mask": mask, "filter_size": filter_size, "pad_value": 0.0}, torch.zeros(3, 2, 3, 3)])
+
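+# value checks: all-zero and all-one masks are unchanged by erode/dilate; a single centre voxel erodes away and dilates to fill the patch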
+TESTS_VALUE = []
+for p in TEST_NDARRAYS:
+ mask = torch.zeros(3, 2, 5, 5, 5)
+ filter_size = 3
+ TESTS_VALUE.append(
+ [{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 5, 5, 5)), p(torch.zeros(3, 2, 5, 5, 5))]
+ )
+ mask = torch.ones(1, 1, 3, 3, 3)
+ filter_size = 3
+ TESTS_VALUE.append(
+ [{"mask": p(mask), "filter_size": filter_size}, p(torch.ones(1, 1, 3, 3, 3)), p(torch.ones(1, 1, 3, 3, 3))]
+ )
+ mask = torch.ones(1, 2, 3, 3, 3)
+ filter_size = 3
+ TESTS_VALUE.append(
+ [{"mask": p(mask), "filter_size": filter_size}, p(torch.ones(1, 2, 3, 3, 3)), p(torch.ones(1, 2, 3, 3, 3))]
+ )
+ mask = torch.zeros(3, 2, 3, 3, 3)
+ mask[:, :, 1, 1, 1] = 1.0
+ filter_size = 3
+ TESTS_VALUE.append(
+ [{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3, 3)), p(torch.ones(3, 2, 3, 3, 3))]
+ )
+ mask = torch.zeros(3, 2, 3, 3)
+ mask[:, :, 1, 1] = 1.0
+ filter_size = 3
+ TESTS_VALUE.append(
+ [{"mask": p(mask), "filter_size": filter_size}, p(torch.zeros(3, 2, 3, 3)), p(torch.ones(3, 2, 3, 3))]
+ )
+
+
+class TestMorph(unittest.TestCase):
+
+ @parameterized.expand(TESTS_SHAPE)
+ def test_shape(self, input_data, expected_result):
+ result1 = erode(input_data["mask"], input_data["filter_size"])
+ assert_allclose(result1.shape, expected_result, type_test=False, device_test=False, atol=0.0)
+
+ @parameterized.expand(TESTS_VALUE_T)
+ def test_value_t(self, input_data, expected_result):
+ result1 = get_morphological_filter_result_t(
+ input_data["mask"], input_data["filter_size"], input_data["pad_value"]
+ )
+ assert_allclose(result1, expected_result, type_test=False, device_test=False, atol=0.0)
+
+ @parameterized.expand(TESTS_VALUE)
+ def test_value(self, input_data, expected_erode_result, expected_dilate_result):
+ result1 = erode(input_data["mask"], input_data["filter_size"])
+ assert_allclose(result1, expected_erode_result, type_test=True, device_test=True, atol=0.0)
+ result2 = dilate(input_data["mask"], input_data["filter_size"])
+ assert_allclose(result2, expected_dilate_result, type_test=True, device_test=True, atol=0.0)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_multi_scale.py b/tests/test_multi_scale.py
index 6681f266a8..0b49087216 100644
--- a/tests/test_multi_scale.py
+++ b/tests/test_multi_scale.py
@@ -58,17 +58,24 @@ def test_shape(self, input_param, input_data, expected_val):
result = MultiScaleLoss(**input_param).forward(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)
- def test_ill_opts(self):
- with self.assertRaisesRegex(ValueError, ""):
- MultiScaleLoss(loss=dice_loss, kernel="none")
- with self.assertRaisesRegex(ValueError, ""):
- MultiScaleLoss(loss=dice_loss, scales=[-1])(
- torch.ones((1, 1, 3), device=device), torch.ones((1, 1, 3), device=device)
- )
- with self.assertRaisesRegex(ValueError, ""):
- MultiScaleLoss(loss=dice_loss, scales=[-1], reduction="none")(
- torch.ones((1, 1, 3), device=device), torch.ones((1, 1, 3), device=device)
- )
+ @parameterized.expand(
+ [
+ ({"loss": dice_loss, "kernel": "none"}, None, None), # kernel_none
+ ({"loss": dice_loss, "scales": [-1]}, torch.ones((1, 1, 3)), torch.ones((1, 1, 3))), # scales_negative
+ (
+ {"loss": dice_loss, "scales": [-1], "reduction": "none"},
+ torch.ones((1, 1, 3)),
+ torch.ones((1, 1, 3)),
+ ), # scales_negative_reduction_none
+ ]
+ )
+ def test_ill_opts(self, kwargs, input, target):
+ if input is None and target is None:
+ with self.assertRaisesRegex(ValueError, ""):
+ MultiScaleLoss(**kwargs)
+ else:
+ with self.assertRaisesRegex(ValueError, ""):
+ MultiScaleLoss(**kwargs)(input, target)
def test_script(self):
input_param, input_data, expected_val = TEST_CASES[0]
diff --git a/tests/test_nacl_loss.py b/tests/test_nacl_loss.py
new file mode 100644
index 0000000000..704bbdb9b1
--- /dev/null
+++ b/tests/test_nacl_loss.py
@@ -0,0 +1,167 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.losses import NACLLoss
+
+inputs = torch.tensor(
+ [
+ [
+ [
+ [0.1498, 0.1158, 0.3996, 0.3730],
+ [0.2155, 0.1585, 0.8541, 0.8579],
+ [0.6640, 0.2424, 0.0774, 0.0324],
+ [0.0580, 0.2180, 0.3447, 0.8722],
+ ],
+ [
+ [0.3908, 0.9366, 0.1779, 0.1003],
+ [0.9630, 0.6118, 0.4405, 0.7916],
+ [0.5782, 0.9515, 0.4088, 0.3946],
+ [0.7860, 0.3910, 0.0324, 0.9568],
+ ],
+ [
+ [0.0759, 0.0238, 0.5570, 0.1691],
+ [0.2703, 0.7722, 0.1611, 0.6431],
+ [0.8051, 0.6596, 0.4121, 0.1125],
+ [0.5283, 0.6746, 0.5528, 0.7913],
+ ],
+ ]
+ ]
+)
+targets = torch.tensor([[[1, 1, 1, 1], [1, 1, 1, 0], [0, 0, 1, 0], [0, 1, 0, 0]]])
+
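+# regression values for fixed inputs/targets across kernel, sigma, distance type and alpha settings (2D and 3D)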
+TEST_CASES = [
+ [{"classes": 3, "dim": 2}, {"inputs": inputs, "targets": targets}, 1.1442],
+ [{"classes": 3, "dim": 2}, {"inputs": inputs.repeat(4, 1, 1, 1), "targets": targets.repeat(4, 1, 1)}, 1.1442],
+ [{"classes": 3, "dim": 2, "kernel_ops": "gaussian"}, {"inputs": inputs, "targets": targets}, 1.1433],
+ [{"classes": 3, "dim": 2, "kernel_ops": "gaussian", "sigma": 0.5}, {"inputs": inputs, "targets": targets}, 1.1469],
+ [{"classes": 3, "dim": 2, "distance_type": "l2"}, {"inputs": inputs, "targets": targets}, 1.1269],
+ [{"classes": 3, "dim": 2, "alpha": 0.2}, {"inputs": inputs, "targets": targets}, 1.1790],
+ [
+ {"classes": 3, "dim": 3, "kernel_ops": "gaussian"},
+ {
+ "inputs": torch.tensor(
+ [
+ [
+ [
+ [
+ [0.5977, 0.2767, 0.0591, 0.1675],
+ [0.4835, 0.3778, 0.8406, 0.3065],
+ [0.6047, 0.2860, 0.9742, 0.2013],
+ [0.9128, 0.8368, 0.6711, 0.4384],
+ ],
+ [
+ [0.9797, 0.1863, 0.5584, 0.6652],
+ [0.2272, 0.2004, 0.7914, 0.4224],
+ [0.5097, 0.8818, 0.2581, 0.3495],
+ [0.1054, 0.5483, 0.3732, 0.3587],
+ ],
+ [
+ [0.3060, 0.7066, 0.7922, 0.4689],
+ [0.1733, 0.8902, 0.6704, 0.2037],
+ [0.8656, 0.5561, 0.2701, 0.0092],
+ [0.1866, 0.7714, 0.6424, 0.9791],
+ ],
+ [
+ [0.5067, 0.3829, 0.6156, 0.8985],
+ [0.5192, 0.8347, 0.2098, 0.2260],
+ [0.8887, 0.3944, 0.6400, 0.5345],
+ [0.1207, 0.3763, 0.5282, 0.7741],
+ ],
+ ],
+ [
+ [
+ [0.8499, 0.4759, 0.1964, 0.5701],
+ [0.3190, 0.1238, 0.2368, 0.9517],
+ [0.0797, 0.6185, 0.0135, 0.8672],
+ [0.4116, 0.1683, 0.1355, 0.0545],
+ ],
+ [
+ [0.7533, 0.2658, 0.5955, 0.4498],
+ [0.9500, 0.2317, 0.2825, 0.9763],
+ [0.1493, 0.1558, 0.3743, 0.8723],
+ [0.1723, 0.7980, 0.8816, 0.0133],
+ ],
+ [
+ [0.8426, 0.2666, 0.2077, 0.3161],
+ [0.1725, 0.8414, 0.1515, 0.2825],
+ [0.4882, 0.5159, 0.4120, 0.1585],
+ [0.2551, 0.9073, 0.7691, 0.9898],
+ ],
+ [
+ [0.4633, 0.8717, 0.8537, 0.2899],
+ [0.3693, 0.7953, 0.1183, 0.4596],
+ [0.0087, 0.7925, 0.0989, 0.8385],
+ [0.8261, 0.6920, 0.7069, 0.4464],
+ ],
+ ],
+ [
+ [
+ [0.0110, 0.1608, 0.4814, 0.6317],
+ [0.0194, 0.9669, 0.3259, 0.0028],
+ [0.5674, 0.8286, 0.0306, 0.5309],
+ [0.3973, 0.8183, 0.0238, 0.1934],
+ ],
+ [
+ [0.8947, 0.6629, 0.9439, 0.8905],
+ [0.0072, 0.1697, 0.4634, 0.0201],
+ [0.7184, 0.2424, 0.0820, 0.7504],
+ [0.3937, 0.1424, 0.4463, 0.5779],
+ ],
+ [
+ [0.4123, 0.6227, 0.0523, 0.8826],
+ [0.0051, 0.0353, 0.3662, 0.7697],
+ [0.4867, 0.8986, 0.2510, 0.5316],
+ [0.1856, 0.2634, 0.9140, 0.9725],
+ ],
+ [
+ [0.2041, 0.4248, 0.2371, 0.7256],
+ [0.2168, 0.5380, 0.4538, 0.7007],
+ [0.9013, 0.2623, 0.0739, 0.2998],
+ [0.1366, 0.5590, 0.2952, 0.4592],
+ ],
+ ],
+ ]
+ ]
+ ),
+ "targets": torch.tensor(
+ [
+ [
+ [[0, 1, 0, 1], [1, 2, 1, 0], [2, 1, 1, 1], [1, 1, 0, 1]],
+ [[2, 1, 0, 2], [1, 2, 0, 2], [1, 0, 1, 1], [1, 1, 0, 0]],
+ [[1, 0, 2, 1], [0, 2, 2, 1], [1, 0, 1, 1], [0, 0, 2, 1]],
+ [[2, 1, 1, 0], [1, 0, 0, 2], [1, 0, 2, 1], [2, 1, 0, 1]],
+ ]
+ ]
+ ),
+ },
+ 1.15035,
+ ],
+]
+
+
+class TestNACLLoss(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_result(self, input_param, input_data, expected_val):
+ loss = NACLLoss(**input_param)
+ result = loss(**input_data)
+ np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_network_consistency.py b/tests/test_network_consistency.py
index 4182501808..bcfd448144 100644
--- a/tests/test_network_consistency.py
+++ b/tests/test_network_consistency.py
@@ -14,8 +14,8 @@
import json
import os
import unittest
+from collections.abc import Sequence
from glob import glob
-from typing import Sequence
from unittest.case import skipIf
import torch
diff --git a/tests/test_nifti_endianness.py b/tests/test_nifti_endianness.py
index 4475d8aaab..f8531dc08f 100644
--- a/tests/test_nifti_endianness.py
+++ b/tests/test_nifti_endianness.py
@@ -82,7 +82,7 @@ def test_switch(self): # verify data types
after = switch_endianness(before)
np.testing.assert_allclose(after.astype(float), expected_float)
- before = np.array(["1.12", "-9.2", "42"], dtype=np.string_)
+ before = np.array(["1.12", "-9.2", "42"], dtype=np.bytes_)
after = switch_endianness(before)
np.testing.assert_array_equal(before, after)
diff --git a/tests/test_nrrd_reader.py b/tests/test_nrrd_reader.py
index 649b9fa94d..5bf958e970 100644
--- a/tests/test_nrrd_reader.py
+++ b/tests/test_nrrd_reader.py
@@ -40,8 +40,8 @@
"dimension": 4,
"space": "left-posterior-superior",
"sizes": [3, 4, 4, 1],
- "space directions": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
- "space origin": [0.0, 0.0, 0.0],
+ "space directions": [[0.7, 0.0, 0.0], [0.0, 0.0, -0.8], [0.0, 0.9, 0.0]],
+ "space origin": [1.0, 5.0, 20.0],
},
]
@@ -110,6 +110,10 @@ def test_read_with_header(self, data_shape, filename, expected_shape, dtype, ref
np.testing.assert_allclose(image_array, test_image)
self.assertIsInstance(image_header, dict)
self.assertTupleEqual(tuple(image_header["spatial_shape"]), expected_shape)
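+        # the affine follows from the LPS "space directions"/"space origin" header above, converted to RAS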
+ np.testing.assert_allclose(
+ image_header["affine"],
+ np.array([[-0.7, 0.0, 0.0, -1.0], [0.0, 0.0, -0.9, -5.0], [0.0, -0.8, 0.0, 20.0], [0.0, 0.0, 0.0, 1.0]]),
+ )
@parameterized.expand([TEST_CASE_8])
def test_read_with_header_index_order_c(self, data_shape, filename, expected_shape, dtype, reference_header):
diff --git a/tests/test_optional_import.py b/tests/test_optional_import.py
index e7e1c03fd0..2f640f88d0 100644
--- a/tests/test_optional_import.py
+++ b/tests/test_optional_import.py
@@ -13,22 +13,20 @@
import unittest
+from parameterized import parameterized
+
from monai.utils import OptionalImportError, exact_version, optional_import
class TestOptionalImport(unittest.TestCase):
- def test_default(self):
- my_module, flag = optional_import("not_a_module")
+ @parameterized.expand(["not_a_module", "torch.randint"])
+ def test_default(self, import_module):
+ my_module, flag = optional_import(import_module)
self.assertFalse(flag)
with self.assertRaises(OptionalImportError):
my_module.test
- my_module, flag = optional_import("torch.randint")
- with self.assertRaises(OptionalImportError):
- self.assertFalse(flag)
- print(my_module.test)
-
def test_import_valid(self):
my_module, flag = optional_import("torch")
self.assertTrue(flag)
@@ -47,18 +45,9 @@ def test_import_wrong_number(self):
self.assertTrue(flag)
print(my_module.randint(1, 2, (1, 2)))
- def test_import_good_number(self):
- my_module, flag = optional_import("torch", "0")
- my_module.nn
- self.assertTrue(flag)
- print(my_module.randint(1, 2, (1, 2)))
-
- my_module, flag = optional_import("torch", "0.0.0.1")
- my_module.nn
- self.assertTrue(flag)
- print(my_module.randint(1, 2, (1, 2)))
-
- my_module, flag = optional_import("torch", "1.1.0")
+ @parameterized.expand(["0", "0.0.0.1", "1.1.0"])
+ def test_import_good_number(self, version_number):
+ my_module, flag = optional_import("torch", version_number)
my_module.nn
self.assertTrue(flag)
print(my_module.randint(1, 2, (1, 2)))
diff --git a/tests/test_ordering.py b/tests/test_ordering.py
new file mode 100644
index 0000000000..e6b235e179
--- /dev/null
+++ b/tests/test_ordering.py
@@ -0,0 +1,289 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+from parameterized import parameterized
+
+from monai.utils.enums import OrderingTransformations, OrderingType
+from monai.utils.ordering import Ordering
+
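+# expected flattened orderings of a 2x2 grid: raster scan is row-major, S-curve alternates row direction;
+# further cases combine reflection, transposition and 90-degree rotation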
+TEST_2D_NON_RANDOM = [
+ [
+ {
+ "ordering_type": OrderingType.RASTER_SCAN,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (),
+ "transpositions_axes": (),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [0, 1, 2, 3],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.S_CURVE,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (),
+ "transpositions_axes": (),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [0, 1, 3, 2],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.RASTER_SCAN,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": (),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [2, 3, 0, 1],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.S_CURVE,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": (),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [2, 3, 1, 0],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.RASTER_SCAN,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (),
+ "transpositions_axes": ((1, 0),),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [0, 2, 1, 3],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.S_CURVE,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (),
+ "transpositions_axes": ((1, 0),),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [0, 2, 3, 1],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.RASTER_SCAN,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (),
+ "transpositions_axes": (),
+ "rot90_axes": ((0, 1),),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [1, 3, 0, 2],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.S_CURVE,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (),
+ "transpositions_axes": (),
+ "rot90_axes": ((0, 1),),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [1, 3, 2, 0],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.RASTER_SCAN,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": ((1, 0),),
+ "rot90_axes": ((0, 1),),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [0, 1, 2, 3],
+ ],
+ [
+ {
+ "ordering_type": OrderingType.S_CURVE,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": ((1, 0),),
+ "rot90_axes": ((0, 1),),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [0, 1, 3, 2],
+ ],
+]
+
+
+TEST_3D = [
+ [
+ {
+ "ordering_type": OrderingType.RASTER_SCAN,
+ "spatial_dims": 3,
+ "dimensions": (1, 2, 2, 2),
+ "reflected_spatial_dims": (),
+ "transpositions_axes": (),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ },
+ [0, 1, 2, 3, 4, 5, 6, 7],
+ ]
+]
+
+TEST_ORDERING_TYPE_FAILURE = [
+ [
+ {
+ "ordering_type": "hilbert",
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": ((1, 0),),
+ "rot90_axes": ((0, 1),),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ }
+ ]
+]
+
+TEST_ORDERING_TRANSFORMATION_FAILURE = [
+ [
+ {
+ "ordering_type": OrderingType.S_CURVE,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": ((1, 0),),
+ "rot90_axes": ((0, 1),),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ "flip",
+ ),
+ }
+ ]
+]
+
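+# Applying the sequence ordering followed by the revert ordering should recover the original
+# flattened sequence, which test_revert checks on random integer data.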
+TEST_REVERT = [
+ [
+ {
+ "ordering_type": OrderingType.S_CURVE,
+ "spatial_dims": 2,
+ "dimensions": (1, 2, 2),
+ "reflected_spatial_dims": (True, False),
+ "transpositions_axes": (),
+ "rot90_axes": (),
+ "transformation_order": (
+ OrderingTransformations.TRANSPOSE.value,
+ OrderingTransformations.ROTATE_90.value,
+ OrderingTransformations.REFLECT.value,
+ ),
+ }
+ ]
+]
+
+
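+# Each 2D/3D case pairs the Ordering constructor arguments with the expected flattened
+# index permutation returned by get_sequence_ordering().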
+class TestOrdering(unittest.TestCase):
+ @parameterized.expand(TEST_2D_NON_RANDOM + TEST_3D)
+ def test_ordering(self, input_param, expected_sequence_ordering):
+ ordering = Ordering(**input_param)
+ self.assertTrue(np.array_equal(ordering.get_sequence_ordering(), expected_sequence_ordering, equal_nan=True))
+
+ @parameterized.expand(TEST_ORDERING_TYPE_FAILURE)
+ def test_ordering_type_failure(self, input_param):
+ with self.assertRaises(ValueError):
+ Ordering(**input_param)
+
+ @parameterized.expand(TEST_ORDERING_TRANSFORMATION_FAILURE)
+ def test_ordering_transformation_failure(self, input_param):
+ with self.assertRaises(ValueError):
+ Ordering(**input_param)
+
+ @parameterized.expand(TEST_REVERT)
+ def test_revert(self, input_param):
+ sequence = np.random.randint(0, 100, size=input_param["dimensions"]).flatten()
+
+ ordering = Ordering(**input_param)
+
+ reverted_sequence = sequence[ordering.get_sequence_ordering()]
+ reverted_sequence = reverted_sequence[ordering.get_revert_sequence_ordering()]
+
+ self.assertTrue(np.array_equal(sequence, reverted_sequence, equal_nan=True))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py
index ee6e001438..9d5012c9a3 100644
--- a/tests/test_pad_collation.py
+++ b/tests/test_pad_collation.py
@@ -89,7 +89,7 @@ def tearDown(self) -> None:
@parameterized.expand(TESTS)
def test_pad_collation(self, t_type, collate_method, transform):
- if t_type == dict:
+ if t_type is dict:
dataset = CacheDataset(self.dict_data, transform, progress=False)
else:
dataset = _Dataset(self.list_data, self.list_labels, transform)
@@ -104,7 +104,7 @@ def test_pad_collation(self, t_type, collate_method, transform):
loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method)
# check collation in forward direction
for data in loader:
- if t_type == dict:
+ if t_type is dict:
shapes = []
decollated_data = decollate_batch(data)
for d in decollated_data:
@@ -113,11 +113,11 @@ def test_pad_collation(self, t_type, collate_method, transform):
self.assertTrue(len(output["image"].applied_operations), len(dataset.transform.transforms))
self.assertTrue(len(set(shapes)) > 1) # inverted shapes must be different because of random xforms
- if t_type == dict:
+ if t_type is dict:
batch_inverse = BatchInverseTransform(dataset.transform, loader)
for data in loader:
output = batch_inverse(data)
- self.assertTrue(output[0]["image"].shape, (1, 10, 9))
+ self.assertEqual(output[0]["image"].shape, (1, 10, 9))
if __name__ == "__main__":
diff --git a/tests/test_patch_gan_dicriminator.py b/tests/test_patch_gan_dicriminator.py
new file mode 100644
index 0000000000..c19898e70d
--- /dev/null
+++ b/tests/test_patch_gan_dicriminator.py
@@ -0,0 +1,179 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.nets import MultiScalePatchDiscriminator, PatchDiscriminator
+from tests.utils import test_script_save
+
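+# Single-discriminator cases: constructor arguments, an input tensor, the expected shape of the
+# first returned feature map, and the expected shape of the final output map.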
+TEST_PATCHGAN = [
+ [
+ {
+ "num_layers_d": 3,
+ "spatial_dims": 2,
+ "channels": 8,
+ "in_channels": 3,
+ "out_channels": 1,
+ "kernel_size": 3,
+ "activation": "LEAKYRELU",
+ "norm": "instance",
+ "bias": False,
+ "dropout": 0.1,
+ },
+ torch.rand([1, 3, 256, 512]),
+ (1, 8, 128, 256),
+ (1, 1, 32, 64),
+ ],
+ [
+ {
+ "num_layers_d": 3,
+ "spatial_dims": 3,
+ "channels": 8,
+ "in_channels": 3,
+ "out_channels": 1,
+ "kernel_size": 3,
+ "activation": "LEAKYRELU",
+ "norm": "instance",
+ "bias": False,
+ "dropout": 0.1,
+ },
+ torch.rand([1, 3, 256, 512, 256]),
+ (1, 8, 128, 256, 128),
+ (1, 1, 32, 64, 32),
+ ],
+]
+
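+# Multi-scale cases: constructor arguments, an input tensor, the expected output shape from each
+# discriminator, and the expected number of intermediate feature maps per discriminator.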
+TEST_MULTISCALE_PATCHGAN = [
+ [
+ {
+ "num_d": 2,
+ "num_layers_d": 3,
+ "spatial_dims": 2,
+ "channels": 8,
+ "in_channels": 3,
+ "out_channels": 1,
+ "kernel_size": 3,
+ "activation": "LEAKYRELU",
+ "norm": "instance",
+ "bias": False,
+ "dropout": 0.1,
+ "minimum_size_im": 256,
+ },
+ torch.rand([1, 3, 256, 512]),
+ [(1, 1, 32, 64), (1, 1, 4, 8)],
+ [4, 7],
+ ],
+ [
+ {
+ "num_d": 2,
+ "num_layers_d": 3,
+ "spatial_dims": 3,
+ "channels": 8,
+ "in_channels": 3,
+ "out_channels": 1,
+ "kernel_size": 3,
+ "activation": "LEAKYRELU",
+ "norm": "instance",
+ "bias": False,
+ "dropout": 0.1,
+ "minimum_size_im": 256,
+ },
+ torch.rand([1, 3, 256, 512, 256]),
+ [(1, 1, 32, 64, 32), (1, 1, 4, 8, 4)],
+ [4, 7],
+ ],
+]
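+# These arguments are expected to fail MultiScalePatchDiscriminator's image-size check at
+# construction time (test_too_small_shape below asserts an AssertionError).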
+TEST_TOO_SMALL_SIZE = [
+ {
+ "num_d": 2,
+ "num_layers_d": 6,
+ "spatial_dims": 2,
+ "channels": 8,
+ "in_channels": 3,
+ "out_channels": 1,
+ "kernel_size": 3,
+ "activation": "LEAKYRELU",
+ "norm": "instance",
+ "bias": False,
+ "dropout": 0.1,
+ "minimum_size_im": 256,
+ }
+]
+
+
+class TestPatchGAN(unittest.TestCase):
+ @parameterized.expand(TEST_PATCHGAN)
+ def test_shape(self, input_param, input_data, expected_shape_feature, expected_shape_output):
+ net = PatchDiscriminator(**input_param)
+ with eval_mode(net):
+ result = net.forward(input_data)
+ self.assertEqual(tuple(result[0].shape), expected_shape_feature)
+ self.assertEqual(tuple(result[-1].shape), expected_shape_output)
+
+ def test_script(self):
+ net = PatchDiscriminator(
+ num_layers_d=3,
+ spatial_dims=2,
+ channels=8,
+ in_channels=3,
+ out_channels=1,
+ kernel_size=3,
+ activation="LEAKYRELU",
+ norm="instance",
+ bias=False,
+ dropout=0.1,
+ )
+ i = torch.rand([1, 3, 256, 512])
+ test_script_save(net, i)
+
+
+class TestMultiscalePatchGAN(unittest.TestCase):
+ @parameterized.expand(TEST_MULTISCALE_PATCHGAN)
+ def test_shape(self, input_param, input_data, expected_shape, features_lengths=None):
+ net = MultiScalePatchDiscriminator(**input_param)
+ with eval_mode(net):
+ result, features = net.forward(input_data)
+ for r_ind, r in enumerate(result):
+ self.assertEqual(tuple(r.shape), expected_shape[r_ind])
+ for o_d_ind, o_d in enumerate(features):
+ self.assertEqual(len(o_d), features_lengths[o_d_ind])
+
+ def test_too_small_shape(self):
+ with self.assertRaises(AssertionError):
+ MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0])
+
+ def test_script(self):
+ net = MultiScalePatchDiscriminator(
+ num_d=2,
+ num_layers_d=3,
+ spatial_dims=2,
+ channels=8,
+ in_channels=3,
+ out_channels=1,
+ kernel_size=3,
+ activation="LEAKYRELU",
+ norm="instance",
+ bias=False,
+ dropout=0.1,
+ minimum_size_im=256,
+ )
+ i = torch.rand([1, 3, 256, 512])
+ test_script_save(net, i)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py
index d059145033..71ac767966 100644
--- a/tests/test_patchembedding.py
+++ b/tests/test_patchembedding.py
@@ -43,7 +43,7 @@
"patch_size": (patch_size,) * nd,
"hidden_size": hidden_size,
"num_heads": num_heads,
- "pos_embed": proj_type,
+ "proj_type": proj_type,
"pos_embed_type": pos_embed_type,
"dropout_rate": dropout_rate,
},
@@ -127,7 +127,7 @@ def test_ill_arg(self):
patch_size=(16, 16, 16),
hidden_size=128,
num_heads=12,
- pos_embed="conv",
+ proj_type="conv",
pos_embed_type="sincos",
dropout_rate=5.0,
)
@@ -139,7 +139,7 @@ def test_ill_arg(self):
patch_size=(64, 64, 64),
hidden_size=512,
num_heads=8,
- pos_embed="perceptron",
+ proj_type="perceptron",
pos_embed_type="sincos",
dropout_rate=0.3,
)
@@ -151,7 +151,7 @@ def test_ill_arg(self):
patch_size=(8, 8, 8),
hidden_size=512,
num_heads=14,
- pos_embed="conv",
+ proj_type="conv",
dropout_rate=0.3,
)
@@ -162,7 +162,7 @@ def test_ill_arg(self):
patch_size=(4, 4, 4),
hidden_size=768,
num_heads=8,
- pos_embed="perceptron",
+ proj_type="perceptron",
dropout_rate=0.3,
)
with self.assertRaises(ValueError):
@@ -183,7 +183,7 @@ def test_ill_arg(self):
patch_size=(16, 16, 16),
hidden_size=768,
num_heads=12,
- pos_embed="perc",
+ proj_type="perc",
dropout_rate=0.3,
)
diff --git a/tests/test_perceptual_loss.py b/tests/test_perceptual_loss.py
index 02232e6f8d..b8aa2e5982 100644
--- a/tests/test_perceptual_loss.py
+++ b/tests/test_perceptual_loss.py
@@ -18,7 +18,7 @@
from monai.losses import PerceptualLoss
from monai.utils import optional_import
-from tests.utils import SkipIfBeforePyTorchVersion, skip_if_downloading_fails, skip_if_quick
+from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose, skip_if_downloading_fails, skip_if_quick
_, has_torchvision = optional_import("torchvision")
TEST_CASES = [
@@ -40,11 +40,31 @@
(2, 1, 64, 64, 64),
(2, 1, 64, 64, 64),
],
+ [
+ {"spatial_dims": 3, "network_type": "medicalnet_resnet10_23datasets", "is_fake_3d": False},
+ (2, 6, 64, 64, 64),
+ (2, 6, 64, 64, 64),
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "network_type": "medicalnet_resnet10_23datasets",
+ "is_fake_3d": False,
+ "channel_wise": True,
+ },
+ (2, 6, 64, 64, 64),
+ (2, 6, 64, 64, 64),
+ ],
[
{"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False},
(2, 1, 64, 64, 64),
(2, 1, 64, 64, 64),
],
+ [
+ {"spatial_dims": 3, "network_type": "medicalnet_resnet50_23datasets", "is_fake_3d": False},
+ (2, 6, 64, 64, 64),
+ (2, 6, 64, 64, 64),
+ ],
[
{"spatial_dims": 3, "network_type": "resnet50", "is_fake_3d": True, "pretrained": True, "fake_3d_ratio": 0.2},
(2, 1, 64, 64, 64),
@@ -63,7 +83,11 @@ def test_shape(self, input_param, input_shape, target_shape):
with skip_if_downloading_fails():
loss = PerceptualLoss(**input_param)
result = loss(torch.randn(input_shape), torch.randn(target_shape))
- self.assertEqual(result.shape, torch.Size([]))
+
+ if "channel_wise" in input_param.keys() and input_param["channel_wise"]:
+ self.assertEqual(result.shape, torch.Size([input_shape[1]]))
+ else:
+ self.assertEqual(result.shape, torch.Size([]))
@parameterized.expand(TEST_CASES)
def test_identical_input(self, input_param, input_shape, target_shape):
@@ -71,7 +95,11 @@ def test_identical_input(self, input_param, input_shape, target_shape):
loss = PerceptualLoss(**input_param)
tensor = torch.randn(input_shape)
result = loss(tensor, tensor)
- self.assertEqual(result, torch.Tensor([0.0]))
+
+ if "channel_wise" in input_param.keys() and input_param["channel_wise"]:
+ assert_allclose(result, torch.Tensor([0.0] * input_shape[1]))
+ else:
+ self.assertEqual(result, torch.Tensor([0.0]))
def test_different_shape(self):
with skip_if_downloading_fails():
@@ -85,12 +113,10 @@ def test_1d(self):
with self.assertRaises(NotImplementedError):
PerceptualLoss(spatial_dims=1)
- def test_medicalnet_on_2d_data(self):
- with self.assertRaises(ValueError):
- PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet10_23datasets")
-
+ @parameterized.expand(["medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets"])
+ def test_medicalnet_on_2d_data(self, network_type):
with self.assertRaises(ValueError):
- PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet50_23datasets")
+ PerceptualLoss(spatial_dims=2, network_type=network_type)
if __name__ == "__main__":
diff --git a/tests/test_persistentdataset.py b/tests/test_persistentdataset.py
index b7bf2fbb11..7c4969e283 100644
--- a/tests/test_persistentdataset.py
+++ b/tests/test_persistentdataset.py
@@ -165,7 +165,7 @@ def test_different_transforms(self):
im1 = PersistentDataset([im], Identity(), cache_dir=path, hash_transform=json_hashing)[0]
im2 = PersistentDataset([im], Flip(1), cache_dir=path, hash_transform=json_hashing)[0]
l2 = ((im1 - im2) ** 2).sum() ** 0.5
- self.assertTrue(l2 > 1)
+ self.assertGreater(l2, 1)
if __name__ == "__main__":
diff --git a/tests/test_point_based_window_inferer.py b/tests/test_point_based_window_inferer.py
new file mode 100644
index 0000000000..1b293288c4
--- /dev/null
+++ b/tests/test_point_based_window_inferer.py
@@ -0,0 +1,77 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.apps.vista3d.inferer import point_based_window_inferer
+from monai.networks import eval_mode
+from monai.networks.nets.vista3d import vista3d132
+from monai.utils import optional_import
+from tests.utils import SkipIfBeforePyTorchVersion, skip_if_quick
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+_, has_tqdm = optional_import("tqdm")
+
+TEST_CASES = [
+ [
+ {"encoder_embed_dim": 48, "in_channels": 1},
+ (1, 1, 64, 64, 64),
+ {
+ "roi_size": [32, 32, 32],
+ "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device),
+ "point_labels": torch.tensor([[1, 0]], device=device),
+ },
+ ],
+ [
+ {"encoder_embed_dim": 48, "in_channels": 1},
+ (1, 1, 64, 64, 64),
+ {
+ "roi_size": [32, 32, 32],
+ "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device),
+ "point_labels": torch.tensor([[1, 0]], device=device),
+ "class_vector": torch.tensor([1], device=device),
+ },
+ ],
+ [
+ {"encoder_embed_dim": 48, "in_channels": 1},
+ (1, 1, 64, 64, 64),
+ {
+ "roi_size": [32, 32, 32],
+ "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device),
+ "point_labels": torch.tensor([[1, 0]], device=device),
+ "class_vector": torch.tensor([1], device=device),
+ "point_start": 1,
+ },
+ ],
+]
+
+
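+# Each case runs point_based_window_inferer with a vista3d132 network over a (1, 1, 64, 64, 64)
+# input using 32^3 ROIs and the given point prompts; the stitched output should match the
+# input shape.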
+@SkipIfBeforePyTorchVersion((1, 11))
+@skip_if_quick
+class TestPointBasedWindowInferer(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_vista3d(self, vista3d_params, inputs_shape, inferer_params):
+ vista3d = vista3d132(**vista3d_params).to(device)
+ with eval_mode(vista3d):
+ inferer_params["predictor"] = vista3d
+ inferer_params["inputs"] = torch.randn(*inputs_shape).to(device)
+ stitched_output = point_based_window_inferer(**inferer_params)
+ self.assertEqual(stitched_output.shape, inputs_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_prepare_batch_default.py b/tests/test_prepare_batch_default.py
index d5a5fbf57e..093468ce27 100644
--- a/tests/test_prepare_batch_default.py
+++ b/tests/test_prepare_batch_default.py
@@ -14,12 +14,14 @@
import unittest
import torch
+from parameterized import parameterized
from monai.engines import PrepareBatchDefault, SupervisedEvaluator
from tests.utils import assert_allclose
class TestNet(torch.nn.Module):
+ __test__ = False # indicate to pytest that this class is not intended for collection
def forward(self, x: torch.Tensor):
return x
@@ -27,85 +29,48 @@ def forward(self, x: torch.Tensor):
class TestPrepareBatchDefault(unittest.TestCase):
- def test_dict_content(self):
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- dataloader = [
- {
- "image": torch.tensor([1, 2]),
- "label": torch.tensor([3, 4]),
- "extra1": torch.tensor([5, 6]),
- "extra2": 16,
- "extra3": "test",
- }
+ @parameterized.expand(
+ [
+ (
+ [
+ {
+ "image": torch.tensor([1, 2]),
+ "label": torch.tensor([3, 4]),
+ "extra1": torch.tensor([5, 6]),
+ "extra2": 16,
+ "extra3": "test",
+ }
+ ],
+ TestNet(),
+ True,
+ ), # dict_content
+ ([torch.tensor([1, 2])], torch.nn.Identity(), True), # tensor_content
+ ([(torch.tensor([1, 2]), torch.tensor([3, 4]))], torch.nn.Identity(), True), # pair_content
+ ([], TestNet(), False), # empty_data
]
- # set up engine
- evaluator = SupervisedEvaluator(
- device=device,
- val_data_loader=dataloader,
- epoch_length=1,
- network=TestNet(),
- non_blocking=False,
- prepare_batch=PrepareBatchDefault(),
- decollate=False,
- mode="eval",
- )
- evaluator.run()
- output = evaluator.state.output
- assert_allclose(output["image"], torch.tensor([1, 2], device=device))
- assert_allclose(output["label"], torch.tensor([3, 4], device=device))
-
- def test_tensor_content(self):
+ )
+ def test_prepare_batch(self, dataloader, network, should_run):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- dataloader = [torch.tensor([1, 2])]
-
- # set up engine
evaluator = SupervisedEvaluator(
device=device,
val_data_loader=dataloader,
- epoch_length=1,
- network=torch.nn.Identity(),
+ epoch_length=len(dataloader) if should_run else 0,
+ network=network,
non_blocking=False,
prepare_batch=PrepareBatchDefault(),
decollate=False,
- mode="eval",
+ mode="eval" if should_run else "train",
)
evaluator.run()
- output = evaluator.state.output
- assert_allclose(output["image"], torch.tensor([1, 2], device=device))
- self.assertTrue(output["label"] is None)
- def test_pair_content(self):
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- dataloader = [(torch.tensor([1, 2]), torch.tensor([3, 4]))]
-
- # set up engine
- evaluator = SupervisedEvaluator(
- device=device,
- val_data_loader=dataloader,
- epoch_length=1,
- network=torch.nn.Identity(),
- non_blocking=False,
- prepare_batch=PrepareBatchDefault(),
- decollate=False,
- mode="eval",
- )
- evaluator.run()
- output = evaluator.state.output
- assert_allclose(output["image"], torch.tensor([1, 2], device=device))
- assert_allclose(output["label"], torch.tensor([3, 4], device=device))
-
- def test_empty_data(self):
- dataloader = []
- evaluator = SupervisedEvaluator(
- val_data_loader=dataloader,
- device=torch.device("cpu"),
- epoch_length=0,
- network=TestNet(),
- non_blocking=False,
- prepare_batch=PrepareBatchDefault(),
- decollate=False,
- )
- evaluator.run()
+ if should_run:
+ output = evaluator.state.output
+ if isinstance(dataloader[0], (dict, tuple)):
+ assert_allclose(output["image"], torch.tensor([1, 2], device=device))
+ assert_allclose(output["label"], torch.tensor([3, 4], device=device))
+ else:
+ assert_allclose(output["image"], torch.tensor([1, 2], device=device))
+ self.assertIsNone(output["label"])
if __name__ == "__main__":
diff --git a/tests/test_prepare_batch_default_dist.py b/tests/test_prepare_batch_default_dist.py
index 0c53a74834..53a79575e6 100644
--- a/tests/test_prepare_batch_default_dist.py
+++ b/tests/test_prepare_batch_default_dist.py
@@ -43,6 +43,7 @@
class TestNet(torch.nn.Module):
+ __test__ = False # indicate to pytest that this class is not intended for collection
def forward(self, x: torch.Tensor):
return x
diff --git a/tests/test_prepare_batch_diffusion.py b/tests/test_prepare_batch_diffusion.py
new file mode 100644
index 0000000000..d969c06368
--- /dev/null
+++ b/tests/test_prepare_batch_diffusion.py
@@ -0,0 +1,104 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.engines import SupervisedEvaluator
+from monai.engines.utils import DiffusionPrepareBatch
+from monai.inferers import DiffusionInferer
+from monai.networks.nets import DiffusionModelUNet
+from monai.networks.schedulers import DDPMScheduler
+
+TEST_CASES = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": [8],
+ "norm_num_groups": 8,
+ "attention_levels": [True],
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ },
+ (2, 1, 8, 8),
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": [8],
+ "norm_num_groups": 8,
+ "attention_levels": [True],
+ "num_res_blocks": 1,
+ "num_head_channels": 8,
+ },
+ (2, 1, 8, 8, 8),
+ ],
+]
+
+
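+# The tests check that, with DiffusionPrepareBatch, the evaluator produces predictions with the
+# image shape and a label with the same shape as the image, with and without cross-attention
+# conditioning via the "context" key.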
+class TestPrepareBatchDiffusion(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_output_sizes(self, input_args, image_size):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ dataloader = [{"image": torch.randn(image_size).to(device)}]
+ scheduler = DDPMScheduler(num_train_timesteps=20)
+ inferer = DiffusionInferer(scheduler=scheduler)
+ network = DiffusionModelUNet(**input_args).to(device)
+ evaluator = SupervisedEvaluator(
+ device=device,
+ val_data_loader=dataloader,
+ epoch_length=1,
+ network=network,
+ inferer=inferer,
+ non_blocking=True,
+ prepare_batch=DiffusionPrepareBatch(num_train_timesteps=20),
+ decollate=False,
+ )
+ evaluator.run()
+ output = evaluator.state.output
+ # check shapes are the same
+ self.assertEqual(output["pred"].shape, image_size)
+ self.assertEqual(output["label"].shape, output["image"].shape)
+
+ @parameterized.expand(TEST_CASES)
+ def test_conditioning(self, input_args, image_size):
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ dataloader = [{"image": torch.randn(image_size).to(device), "context": torch.randn((2, 4, 3)).to(device)}]
+ scheduler = DDPMScheduler(num_train_timesteps=20)
+ inferer = DiffusionInferer(scheduler=scheduler)
+ network = DiffusionModelUNet(**input_args, with_conditioning=True, cross_attention_dim=3).to(device)
+ evaluator = SupervisedEvaluator(
+ device=device,
+ val_data_loader=dataloader,
+ epoch_length=1,
+ network=network,
+ inferer=inferer,
+ non_blocking=True,
+ prepare_batch=DiffusionPrepareBatch(num_train_timesteps=20, condition_name="context"),
+ decollate=False,
+ )
+ evaluator.run()
+ output = evaluator.state.output
+ # check shapes are the same
+ self.assertEqual(output["pred"].shape, image_size)
+ self.assertEqual(output["label"].shape, output["image"].shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_prepare_batch_extra_input.py b/tests/test_prepare_batch_extra_input.py
index f20c6e7352..3c53cc6481 100644
--- a/tests/test_prepare_batch_extra_input.py
+++ b/tests/test_prepare_batch_extra_input.py
@@ -36,6 +36,7 @@
class TestNet(torch.nn.Module):
+ __test__ = False # indicate to pytest that this class is not intended for collection
def forward(self, x: torch.Tensor, t1=None, t2=None, t3=None):
return {"x": x, "t1": t1, "t2": t2, "t3": t3}
diff --git a/tests/test_prepare_batch_hovernet.py b/tests/test_prepare_batch_hovernet.py
index 773fcb53bf..ae9554a3e8 100644
--- a/tests/test_prepare_batch_hovernet.py
+++ b/tests/test_prepare_batch_hovernet.py
@@ -28,6 +28,7 @@
class TestNet(torch.nn.Module):
+ __test__ = False # indicate to pytest that this class is not intended for collection
def forward(self, x: torch.Tensor):
return {HoVerNetBranch.NP: torch.tensor([1, 2]), HoVerNetBranch.NC: torch.tensor([4, 4]), HoVerNetBranch.HV: 16}
diff --git a/tests/test_profiling.py b/tests/test_profiling.py
index 6bee7ba262..649d980ebf 100644
--- a/tests/test_profiling.py
+++ b/tests/test_profiling.py
@@ -35,6 +35,7 @@ def setUp(self):
self.scale = mt.ScaleIntensity()
self.scale_call_name = "ScaleIntensity.__call__"
+ self.compose_call_name = "Compose.__call__"
self.test_comp = mt.Compose([mt.ScaleIntensity(), mt.RandAxisFlip(0.5)])
self.test_image = torch.rand(1, 16, 16, 16)
self.pid = os.getpid()
@@ -82,7 +83,7 @@ def test_profile_multithread(self):
self.assertSequenceEqual(batch.shape, (4, 1, 16, 16, 16))
results = wp.get_results()
- self.assertSequenceEqual(list(results), [self.scale_call_name])
+ self.assertSequenceEqual(list(results), [self.scale_call_name, self.compose_call_name])
prs = results[self.scale_call_name]
@@ -98,6 +99,7 @@ def test_profile_context(self):
self.scale(self.test_image)
results = wp.get_results()
+
self.assertSequenceEqual(set(results), {"ScaleIntensity.__call__", "context"})
prs = results["context"]
diff --git a/tests/test_rand_affine.py b/tests/test_rand_affine.py
index 23e3fd148c..2c827b7426 100644
--- a/tests/test_rand_affine.py
+++ b/tests/test_rand_affine.py
@@ -152,11 +152,10 @@ def test_rand_affine(self, input_param, input_data, expected_val):
self.assertTrue(g._cached_grid is not None)
assert_allclose(result, expected_val, rtol=_rtol, atol=1e-4, type_test="tensor")
- def test_ill_cache(self):
+ @parameterized.expand([(None,), ((1, 1, -1),)])
+ def test_ill_cache(self, spatial_size):
with self.assertWarns(UserWarning):
- RandAffine(cache_grid=True)
- with self.assertWarns(UserWarning):
- RandAffine(cache_grid=True, spatial_size=(1, 1, -1))
+ RandAffine(cache_grid=True, spatial_size=spatial_size)
@parameterized.expand(TEST_CASES_SKIPPED_CONSISTENCY)
def test_skipped_transform_consistency(self, im, in_dtype):
diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py
index 32fde8dc0f..eb8ebd06c5 100644
--- a/tests/test_rand_affined.py
+++ b/tests/test_rand_affined.py
@@ -240,7 +240,7 @@ def test_rand_affined(self, input_param, input_data, expected_val, track_meta):
resampler.lazy = False
if input_param.get("cache_grid", False):
- self.assertTrue(g.rand_affine._cached_grid is not None)
+ self.assertIsNotNone(g.rand_affine._cached_grid)
for key in res:
if isinstance(key, str) and key.endswith("_transforms"):
continue
@@ -272,13 +272,10 @@ def test_rand_affined(self, input_param, input_data, expected_val, track_meta):
self.assertEqual(len(v.applied_operations), 0)
self.assertTupleEqual(v.shape, input_data[k].shape)
- def test_ill_cache(self):
+ @parameterized.expand([(None,), ((2, -1),)])  # spatial size is None or dynamic
+ def test_ill_cache(self, spatial_size):
with self.assertWarns(UserWarning):
- # spatial size is None
- RandAffined(device=device, spatial_size=None, prob=1.0, cache_grid=True, keys=("img", "seg"))
- with self.assertWarns(UserWarning):
- # spatial size is dynamic
- RandAffined(device=device, spatial_size=(2, -1), prob=1.0, cache_grid=True, keys=("img", "seg"))
+ RandAffined(device=device, spatial_size=spatial_size, prob=1.0, cache_grid=True, keys=("img", "seg"))
if __name__ == "__main__":
diff --git a/tests/test_rand_bias_field.py b/tests/test_rand_bias_field.py
index 333a9ecba5..328f46b7ee 100644
--- a/tests/test_rand_bias_field.py
+++ b/tests/test_rand_bias_field.py
@@ -39,7 +39,7 @@ def test_output_shape(self, class_args, img_shape):
img = p(np.random.rand(*img_shape))
output = bias_field(img)
np.testing.assert_equal(output.shape, img_shape)
- self.assertTrue(output.dtype in (np.float32, torch.float32))
+ self.assertIn(output.dtype, (np.float32, torch.float32))
img_zero = np.zeros([*img_shape])
output_zero = bias_field(img_zero)
diff --git a/tests/test_rand_weighted_crop.py b/tests/test_rand_weighted_crop.py
index 47a8f3bfa2..f509065a56 100644
--- a/tests/test_rand_weighted_crop.py
+++ b/tests/test_rand_weighted_crop.py
@@ -90,6 +90,21 @@ def get_data(ndim):
[[63, 37], [31, 43], [66, 20]],
]
)
+ im = SEG1_2D
+ weight_map = np.zeros_like(im, dtype=np.int32)
+ weight_map[0, 30, 20] = 3
+ weight_map[0, 45, 44] = 1
+ weight_map[0, 60, 50] = 2
+ TESTS.append(
+ [
+ "int w 2d",
+ dict(spatial_size=(10, 12), num_samples=3),
+ p(im),
+ q(weight_map),
+ (1, 10, 12),
+ [[60, 50], [30, 20], [45, 44]],
+ ]
+ )
im = SEG1_3D
weight = np.zeros_like(im)
weight[0, 5, 30, 17] = 1.1
@@ -149,6 +164,21 @@ def get_data(ndim):
[[32, 24, 40], [32, 24, 40], [32, 24, 40]],
]
)
+ im = SEG1_3D
+ weight_map = np.zeros_like(im, dtype=np.int32)
+ weight_map[0, 6, 22, 19] = 4
+ weight_map[0, 8, 40, 31] = 2
+ weight_map[0, 13, 20, 24] = 3
+ TESTS.append(
+ [
+ "int w 3d",
+ dict(spatial_size=(8, 10, 12), num_samples=3),
+ p(im),
+ q(weight_map),
+ (1, 8, 10, 12),
+ [[13, 20, 24], [6, 22, 19], [8, 40, 31]],
+ ]
+ )
class TestRandWeightedCrop(CropTest):
diff --git a/tests/test_rand_weighted_cropd.py b/tests/test_rand_weighted_cropd.py
index 1524442f61..a1414df0ac 100644
--- a/tests/test_rand_weighted_cropd.py
+++ b/tests/test_rand_weighted_cropd.py
@@ -154,7 +154,7 @@ def test_rand_weighted_cropd(self, _, init_params, input_data, expected_shape, e
crop = RandWeightedCropd(**init_params)
crop.set_random_state(10)
result = crop(input_data)
- self.assertTrue(len(result) == init_params["num_samples"])
+ self.assertEqual(len(result), init_params["num_samples"])
_len = len(tuple(input_data.keys()))
self.assertTupleEqual(tuple(result[0].keys())[:_len], tuple(input_data.keys()))
diff --git a/tests/test_recon_net_utils.py b/tests/test_recon_net_utils.py
index 1815000777..48d3b59a17 100644
--- a/tests/test_recon_net_utils.py
+++ b/tests/test_recon_net_utils.py
@@ -64,7 +64,7 @@ def test_reshape_channel_complex(self, test_data):
def test_complex_normalize(self, test_data):
result, mean, std = complex_normalize(test_data)
result = result * std + mean
- self.assertTrue((((result - test_data) ** 2).mean() ** 0.5).item() < 1e-5)
+ self.assertLess((((result - test_data) ** 2).mean() ** 0.5).item(), 1e-5)
@parameterized.expand(TEST_PAD)
def test_pad(self, test_data):
diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py
index e8f82eb0c2..8afc2da6ad 100644
--- a/tests/test_reg_loss_integration.py
+++ b/tests/test_reg_loss_integration.py
@@ -83,6 +83,9 @@ def forward(self, x):
# initialize a SGD optimizer
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
+ # declare first for pylint
+ init_loss = None
+
# train the network
for it in range(max_iter):
# set the gradient to zero
@@ -99,7 +102,7 @@ def forward(self, x):
# backward pass
loss_val.backward()
optimizer.step()
- self.assertTrue(init_loss > loss_val, "loss did not decrease")
+ self.assertGreater(init_loss, loss_val, "loss did not decrease")
if __name__ == "__main__":
diff --git a/tests/test_regularization.py b/tests/test_regularization.py
index d381ea72ca..12d64637d5 100644
--- a/tests/test_regularization.py
+++ b/tests/test_regularization.py
@@ -13,20 +13,31 @@
import unittest
+import numpy as np
import torch
-from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd
+from monai.transforms import CutMix, CutMixd, CutOut, CutOutd, MixUp, MixUpd
+from tests.utils import assert_allclose
class TestMixup(unittest.TestCase):
+
def test_mixup(self):
for dims in [2, 3]:
shape = (6, 3) + (32,) * dims
sample = torch.rand(*shape, dtype=torch.float32)
mixup = MixUp(6, 1.0)
+ mixup.set_random_state(seed=0)
output = mixup(sample)
+ np.random.seed(0)
+ # simulate the randomize() of transform
+ np.random.random()
+ weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
+ perm = np.random.permutation(6)
self.assertEqual(output.shape, sample.shape)
- self.assertTrue(any(not torch.allclose(sample, mixup(sample)) for _ in range(10)))
+ mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)]
+ expected = mixweight * sample + (1 - mixweight) * sample[perm, ...]
+ assert_allclose(output, expected, type_test=False, atol=1e-7)
with self.assertRaises(ValueError):
MixUp(6, -0.5)
@@ -44,19 +55,32 @@ def test_mixupd(self):
t = torch.rand(*shape, dtype=torch.float32)
sample = {"a": t, "b": t}
mixup = MixUpd(["a", "b"], 6)
+ mixup.set_random_state(seed=0)
output = mixup(sample)
- self.assertTrue(torch.allclose(output["a"], output["b"]))
+ np.random.seed(0)
+ # simulate the randomize() of transform
+ np.random.random()
+ weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
+ perm = np.random.permutation(6)
+ self.assertEqual(output["a"].shape, sample["a"].shape)
+ mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)]
+ expected = mixweight * sample["a"] + (1 - mixweight) * sample["a"][perm, ...]
+ assert_allclose(output["a"], expected, type_test=False, atol=1e-7)
+ assert_allclose(output["a"], output["b"], type_test=False, atol=1e-7)
with self.assertRaises(ValueError):
MixUpd(["k1", "k2"], 6, -0.5)
class TestCutMix(unittest.TestCase):
+
def test_cutmix(self):
for dims in [2, 3]:
shape = (6, 3) + (32,) * dims
sample = torch.rand(*shape, dtype=torch.float32)
cutmix = CutMix(6, 1.0)
+ cutmix.set_random_state(seed=0)
output = cutmix(sample)
self.assertEqual(output.shape, sample.shape)
self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10)))
@@ -68,22 +92,50 @@ def test_cutmixd(self):
label = torch.randint(0, 1, shape)
sample = {"a": t, "b": t, "lbl1": label, "lbl2": label}
cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2"))
+ cutmix.set_random_state(seed=123)
output = cutmix(sample)
- # croppings are different on each application
- self.assertTrue(not torch.allclose(output["a"], output["b"]))
# but mixing of labels is not affected by it
self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"]))
class TestCutOut(unittest.TestCase):
+
def test_cutout(self):
for dims in [2, 3]:
shape = (6, 3) + (32,) * dims
sample = torch.rand(*shape, dtype=torch.float32)
cutout = CutOut(6, 1.0)
+ cutout.set_random_state(seed=123)
output = cutout(sample)
+ np.random.seed(123)
+ # simulate the randomize() of transform
+ np.random.random()
+ weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
+ perm = np.random.permutation(6)
+ coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in sample.shape[2:]]
+ assert_allclose(weight, cutout._params[0])
+ assert_allclose(perm, cutout._params[1])
+ self.assertSequenceEqual(coords, cutout._params[2])
self.assertEqual(output.shape, sample.shape)
- self.assertTrue(any(not torch.allclose(sample, cutout(sample)) for _ in range(10)))
+
+ def test_cutoutd(self):
+ for dims in [2, 3]:
+ shape = (6, 3) + (32,) * dims
+ t = torch.rand(*shape, dtype=torch.float32)
+ sample = {"a": t, "b": t}
+ cutout = CutOutd(["a", "b"], 6, 1.0)
+ cutout.set_random_state(seed=123)
+ output = cutout(sample)
+ np.random.seed(123)
+ # simulate the randomize() of transform
+ np.random.random()
+ weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
+ perm = np.random.permutation(6)
+ coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in t.shape[2:]]
+ assert_allclose(weight, cutout.cutout._params[0])
+ assert_allclose(perm, cutout.cutout._params[1])
+ self.assertSequenceEqual(coords, cutout.cutout._params[2])
+ self.assertEqual(output["a"].shape, sample["a"].shape)
if __name__ == "__main__":
diff --git a/tests/test_resnet.py b/tests/test_resnet.py
index ad1aad8fc6..a55d18f5de 100644
--- a/tests/test_resnet.py
+++ b/tests/test_resnet.py
@@ -24,6 +24,7 @@
from monai.networks import eval_mode
from monai.networks.nets import (
ResNet,
+ ResNetFeatures,
get_medicalnet_pretrained_resnet_args,
get_pretrained_resnet_medicalnet,
resnet10,
@@ -36,7 +37,14 @@
)
from monai.networks.nets.resnet import ResNetBlock
from monai.utils import optional_import
-from tests.utils import equal_state_dict, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, test_script_save
+from tests.utils import (
+ SkipIfNoModule,
+ equal_state_dict,
+ skip_if_downloading_fails,
+ skip_if_no_cuda,
+ skip_if_quick,
+ test_script_save,
+)
if TYPE_CHECKING:
import torchvision
@@ -99,6 +107,7 @@
"num_classes": 3,
"conv1_t_size": [3],
"conv1_t_stride": 1,
+ "act": ("relu", {"inplace": False}),
},
(1, 2, 32),
(1, 3),
@@ -177,19 +186,60 @@
(1, 3),
]
+TEST_CASE_8 = [
+ {
+ "block": "bottleneck",
+ "layers": [3, 4, 6, 3],
+ "block_inplanes": [64, 128, 256, 512],
+ "spatial_dims": 1,
+ "n_input_channels": 2,
+ "num_classes": 3,
+ "conv1_t_size": [3],
+ "conv1_t_stride": 1,
+ "act": ("relu", {"inplace": False}),
+ },
+ (1, 2, 32),
+ (1, 3),
+]
+
+TEST_CASE_9 = [ # Layer norm
+ {
+ "block": ResNetBlock,
+ "layers": [3, 4, 6, 3],
+ "block_inplanes": [64, 128, 256, 512],
+ "spatial_dims": 1,
+ "n_input_channels": 2,
+ "num_classes": 3,
+ "conv1_t_size": [3],
+ "conv1_t_stride": 1,
+ "act": ("relu", {"inplace": False}),
+ "norm": ("layer", {"normalized_shape": (64, 32)}),
+ },
+ (1, 2, 32),
+ (1, 3),
+]
+
TEST_CASES = []
PRETRAINED_TEST_CASES = []
for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
TEST_CASES.append([model, *case])
PRETRAINED_TEST_CASES.append([model, *case])
-for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7]:
+for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]:
TEST_CASES.append([ResNet, *case])
TEST_SCRIPT_CASES = [
[model, *TEST_CASE_1] for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]
]
+CASE_EXTRACT_FEATURES = [
+ (
+ {"model_name": "resnet10", "pretrained": True, "spatial_dims": 3, "in_channels": 1},
+ [1, 1, 64, 64, 64],
+ ([1, 64, 32, 32, 32], [1, 64, 16, 16, 16], [1, 128, 8, 8, 8], [1, 256, 4, 4, 4], [1, 512, 2, 2, 2]),
+ )
+]
+
class TestResNet(unittest.TestCase):
@@ -211,12 +261,12 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape):
if input_param.get("feed_forward", True):
self.assertEqual(result.shape, expected_shape)
else:
- self.assertTrue(result.shape in expected_shape)
+ self.assertIn(result.shape, expected_shape)
@parameterized.expand(PRETRAINED_TEST_CASES)
@skip_if_quick
@skip_if_no_cuda
- def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape):
+ def test_resnet_pretrained(self, model, input_param, _input_shape, _expected_shape):
net = model(**input_param).to(device)
# Save ckpt
torch.save(net.state_dict(), self.tmp_ckpt_filename)
@@ -240,9 +290,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape
and input_param.get("n_input_channels", 3) == 1
and input_param.get("feed_forward", True) is False
and input_param.get("shortcut_type", "B") == shortcut_type
- and (
- input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True
- )
+ and (input_param.get("bias_downsample", True) == bias_downsample)
):
model(**cp_input_param)
else:
@@ -253,7 +301,7 @@ def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape
cp_input_param["n_input_channels"] = 1
cp_input_param["feed_forward"] = False
cp_input_param["shortcut_type"] = shortcut_type
- cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample != -1 else True
+ cp_input_param["bias_downsample"] = bias_downsample
if cp_input_param.get("spatial_dims", 3) == 3:
with skip_if_downloading_fails():
pretrained_net = model(**cp_input_param).to(device)
@@ -270,5 +318,25 @@ def test_script(self, model, input_param, input_shape, expected_shape):
test_script_save(net, test_data)
+@SkipIfNoModule("hf_hub_download")
+class TestExtractFeatures(unittest.TestCase):
+
+ @parameterized.expand(CASE_EXTRACT_FEATURES)
+ def test_shape(self, input_param, input_shape, expected_shapes):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ with skip_if_downloading_fails():
+ net = ResNetFeatures(**input_param).to(device)
+
+ # run inference with random tensor
+ with eval_mode(net):
+ features = net(torch.randn(input_shape).to(device))
+
+ # check output shape
+ self.assertEqual(len(features), len(expected_shapes))
+ for feature, expected_shape in zip(features, expected_shapes):
+ self.assertEqual(feature.shape, torch.Size(expected_shape))
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py
index 7c3a684a00..a7390efe72 100644
--- a/tests/test_scale_intensity_range_percentiles.py
+++ b/tests/test_scale_intensity_range_percentiles.py
@@ -14,6 +14,7 @@
import unittest
import numpy as np
+import torch
from monai.transforms.intensity.array import ScaleIntensityRangePercentiles
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
@@ -34,6 +35,7 @@ def test_scaling(self):
scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max, dtype=np.uint8)
for p in TEST_NDARRAYS:
result = scaler(p(img))
+ self.assertEqual(result.dtype, torch.uint8)
assert_allclose(result, p(expected), type_test="tensor", rtol=1e-4)
def test_relative_scaling(self):
diff --git a/tests/test_scheduler_ddim.py b/tests/test_scheduler_ddim.py
new file mode 100644
index 0000000000..1a8f8cab67
--- /dev/null
+++ b/tests/test_scheduler_ddim.py
@@ -0,0 +1,83 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks.schedulers import DDIMScheduler
+from tests.utils import assert_allclose
+
+TEST_2D_CASE = []
+for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
+ TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)])
+
+TEST_3D_CASE = []
+for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
+ TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)])
+
+TEST_CASES = TEST_2D_CASE + TEST_3D_CASE
+
+TEST_FULL_LOOP = [
+ [{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-0.9579, -0.6457], [0.4684, -0.9694]]]])]
+]
+
+
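+# Shape checks for add_noise/step across 2D and 3D inputs, plus a fixed-seed full sampling loop
+# compared against a precomputed reference tensor.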
+class TestDDIMScheduler(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_add_noise(self, input_param, input_shape, expected_shape):
+ scheduler = DDIMScheduler(**input_param)
+ scheduler.set_timesteps(num_inference_steps=100)
+ original_sample = torch.zeros(input_shape)
+ noise = torch.randn_like(original_sample)
+ timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long()
+
+ noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)
+ self.assertEqual(noisy.shape, expected_shape)
+
+ @parameterized.expand(TEST_CASES)
+ def test_step_shape(self, input_param, input_shape, expected_shape):
+ scheduler = DDIMScheduler(**input_param)
+ scheduler.set_timesteps(num_inference_steps=100)
+ model_output = torch.randn(input_shape)
+ sample = torch.randn(input_shape)
+ output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
+ self.assertEqual(output_step[0].shape, expected_shape)
+ self.assertEqual(output_step[1].shape, expected_shape)
+
+ @parameterized.expand(TEST_FULL_LOOP)
+ def test_full_timestep_loop(self, input_param, input_shape, expected_output):
+ scheduler = DDIMScheduler(**input_param)
+ scheduler.set_timesteps(50)
+ torch.manual_seed(42)
+ model_output = torch.randn(input_shape)
+ sample = torch.randn(input_shape)
+ for t in range(50):
+ sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)
+ assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3)
+
+ def test_set_timesteps(self):
+ scheduler = DDIMScheduler(num_train_timesteps=1000)
+ scheduler.set_timesteps(num_inference_steps=100)
+ self.assertEqual(scheduler.num_inference_steps, 100)
+ self.assertEqual(len(scheduler.timesteps), 100)
+
+ def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
+ scheduler = DDIMScheduler(num_train_timesteps=1000)
+ with self.assertRaises(ValueError):
+ scheduler.set_timesteps(num_inference_steps=2000)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_scheduler_ddpm.py b/tests/test_scheduler_ddpm.py
new file mode 100644
index 0000000000..f0447aded2
--- /dev/null
+++ b/tests/test_scheduler_ddpm.py
@@ -0,0 +1,104 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks.schedulers import DDPMScheduler
+from tests.utils import assert_allclose
+
+TEST_2D_CASE = []
+for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
+ for variance_type in ["fixed_small", "fixed_large"]:
+ TEST_2D_CASE.append(
+ [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16), (2, 6, 16, 16)]
+ )
+
+TEST_3D_CASE = []
+for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
+ for variance_type in ["fixed_small", "fixed_large"]:
+ TEST_3D_CASE.append(
+ [{"schedule": beta_schedule, "variance_type": variance_type}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)]
+ )
+
+TEST_CASES = TEST_2D_CASE + TEST_3D_CASE
+
+TEST_FULL_LOOP = [
+ [{"schedule": "linear_beta"}, (1, 1, 2, 2), torch.Tensor([[[[-1.0153, -0.3218], [0.8454, -0.7870]]]])]
+]
+
+
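+# Same structure as the DDIM tests, with additional checks for get_velocity and the
+# "learned"/"learned_range" variance types.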
+class TestDDPMScheduler(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_add_noise(self, input_param, input_shape, expected_shape):
+ scheduler = DDPMScheduler(**input_param)
+ original_sample = torch.zeros(input_shape)
+ noise = torch.randn_like(original_sample)
+ timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long()
+
+ noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)
+ self.assertEqual(noisy.shape, expected_shape)
+
+ @parameterized.expand(TEST_CASES)
+ def test_step_shape(self, input_param, input_shape, expected_shape):
+ scheduler = DDPMScheduler(**input_param)
+ model_output = torch.randn(input_shape)
+ sample = torch.randn(input_shape)
+ output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
+ self.assertEqual(output_step[0].shape, expected_shape)
+ self.assertEqual(output_step[1].shape, expected_shape)
+
+ @parameterized.expand(TEST_FULL_LOOP)
+ def test_full_timestep_loop(self, input_param, input_shape, expected_output):
+ scheduler = DDPMScheduler(**input_param)
+ scheduler.set_timesteps(50)
+ torch.manual_seed(42)
+ model_output = torch.randn(input_shape)
+ sample = torch.randn(input_shape)
+ for t in range(50):
+ sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)
+ assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3)
+
+ @parameterized.expand(TEST_CASES)
+ def test_get_velocity_shape(self, input_param, input_shape, expected_shape):
+ scheduler = DDPMScheduler(**input_param)
+ sample = torch.randn(input_shape)
+ timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],)).long()
+ velocity = scheduler.get_velocity(sample=sample, noise=sample, timesteps=timesteps)
+ self.assertEqual(velocity.shape, expected_shape)
+
+ def test_step_learned(self):
+ for variance_type in ["learned", "learned_range"]:
+ scheduler = DDPMScheduler(variance_type=variance_type)
+ model_output = torch.randn(2, 6, 16, 16)
+ sample = torch.randn(2, 3, 16, 16)
+ output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
+ self.assertEqual(output_step[0].shape, sample.shape)
+ self.assertEqual(output_step[1].shape, sample.shape)
+
+ def test_set_timesteps(self):
+ scheduler = DDPMScheduler(num_train_timesteps=1000)
+ scheduler.set_timesteps(num_inference_steps=100)
+ self.assertEqual(scheduler.num_inference_steps, 100)
+ self.assertEqual(len(scheduler.timesteps), 100)
+
+ def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
+ scheduler = DDPMScheduler(num_train_timesteps=1000)
+ with self.assertRaises(ValueError):
+ scheduler.set_timesteps(num_inference_steps=2000)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_scheduler_pndm.py b/tests/test_scheduler_pndm.py
new file mode 100644
index 0000000000..69e5e403f5
--- /dev/null
+++ b/tests/test_scheduler_pndm.py
@@ -0,0 +1,108 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks.schedulers import PNDMScheduler
+from tests.utils import assert_allclose
+
+TEST_2D_CASE = []
+for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
+ TEST_2D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16), (2, 6, 16, 16)])
+
+TEST_3D_CASE = []
+for beta_schedule in ["linear_beta", "scaled_linear_beta"]:
+ TEST_3D_CASE.append([{"schedule": beta_schedule}, (2, 6, 16, 16, 16), (2, 6, 16, 16, 16)])
+
+TEST_CASES = TEST_2D_CASE + TEST_3D_CASE
+
+TEST_FULL_LOOP = [
+ [
+ {"schedule": "linear_beta"},
+ (1, 1, 2, 2),
+ torch.Tensor([[[[-2123055.2500, -459014.2812], [2863438.0000, -1263401.7500]]]]),
+ ]
+]
+
+
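+# PNDM-specific checks: step() returns no predicted original sample (its second element is None),
+# and set_timesteps adds the extra PRK steps when skip_prk_steps=False.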
+class TestPNDMScheduler(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_add_noise(self, input_param, input_shape, expected_shape):
+ scheduler = PNDMScheduler(**input_param)
+ original_sample = torch.zeros(input_shape)
+ noise = torch.randn_like(original_sample)
+ timesteps = torch.randint(0, scheduler.num_train_timesteps, (original_sample.shape[0],)).long()
+ noisy = scheduler.add_noise(original_samples=original_sample, noise=noise, timesteps=timesteps)
+ self.assertEqual(noisy.shape, expected_shape)
+
+ @parameterized.expand(TEST_CASES)
+ def test_step_shape(self, input_param, input_shape, expected_shape):
+ scheduler = PNDMScheduler(**input_param)
+ scheduler.set_timesteps(600)
+ model_output = torch.randn(input_shape)
+ sample = torch.randn(input_shape)
+ output_step = scheduler.step(model_output=model_output, timestep=500, sample=sample)
+ self.assertEqual(output_step[0].shape, expected_shape)
+ self.assertIsNone(output_step[1])
+
+ @parameterized.expand(TEST_FULL_LOOP)
+ def test_full_timestep_loop(self, input_param, input_shape, expected_output):
+ scheduler = PNDMScheduler(**input_param)
+ scheduler.set_timesteps(50)
+ torch.manual_seed(42)
+ model_output = torch.randn(input_shape)
+ sample = torch.randn(input_shape)
+ for t in range(50):
+ sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)
+ assert_allclose(sample, expected_output, rtol=1e-3, atol=1e-3)
+
+ @parameterized.expand(TEST_FULL_LOOP)
+ def test_timestep_two_loops(self, input_param, input_shape, expected_output):
+ scheduler = PNDMScheduler(**input_param)
+ scheduler.set_timesteps(50)
+ torch.manual_seed(42)
+ model_output = torch.randn(input_shape)
+ sample = torch.randn(input_shape)
+ for t in range(50):
+ sample, _ = scheduler.step(model_output=model_output, timestep=t, sample=sample)
+ torch.manual_seed(42)
+ model_output2 = torch.randn(input_shape)
+ sample2 = torch.randn(input_shape)
+ scheduler.set_timesteps(50)
+ for t in range(50):
+ sample2, _ = scheduler.step(model_output=model_output2, timestep=t, sample=sample2)
+ assert_allclose(sample, sample2, rtol=1e-3, atol=1e-3)
+
+ def test_set_timesteps(self):
+ scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=True)
+ scheduler.set_timesteps(num_inference_steps=100)
+ self.assertEqual(scheduler.num_inference_steps, 100)
+ self.assertEqual(len(scheduler.timesteps), 100)
+
+ def test_set_timesteps_prk(self):
+ scheduler = PNDMScheduler(num_train_timesteps=1000, skip_prk_steps=False)
+ scheduler.set_timesteps(num_inference_steps=100)
+ self.assertEqual(scheduler.num_inference_steps, 109)
+ self.assertEqual(len(scheduler.timesteps), 109)
+
+ def test_set_timesteps_with_num_inference_steps_bigger_than_num_train_timesteps(self):
+ scheduler = PNDMScheduler(num_train_timesteps=1000)
+ with self.assertRaises(ValueError):
+ scheduler.set_timesteps(num_inference_steps=2000)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_segresnet_ds.py b/tests/test_segresnet_ds.py
index 5372fcc8ae..eab7bac9a0 100644
--- a/tests/test_segresnet_ds.py
+++ b/tests/test_segresnet_ds.py
@@ -17,7 +17,7 @@
from parameterized import parameterized
from monai.networks import eval_mode
-from monai.networks.nets import SegResNetDS
+from monai.networks.nets import SegResNetDS, SegResNetDS2
from tests.utils import SkipIfBeforePyTorchVersion, test_script_save
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -71,7 +71,7 @@
]
-class TestResNetDS(unittest.TestCase):
+class TestSegResNetDS(unittest.TestCase):
@parameterized.expand(TEST_CASE_SEGRESNET_DS)
def test_shape(self, input_param, input_shape, expected_shape):
@@ -80,47 +80,71 @@ def test_shape(self, input_param, input_shape, expected_shape):
result = net(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape, msg=str(input_param))
+ @parameterized.expand(TEST_CASE_SEGRESNET_DS)
+ def test_shape_ds2(self, input_param, input_shape, expected_shape):
+ net = SegResNetDS2(**input_param).to(device)
+ with eval_mode(net):
+ result = net(torch.randn(input_shape).to(device), with_label=False)
+ self.assertEqual(result[0].shape, expected_shape, msg=str(input_param))
+ self.assertEqual(result[1], [])
+
+ result = net(torch.randn(input_shape).to(device), with_point=False)
+ self.assertEqual(result[1].shape, expected_shape, msg=str(input_param))
+ self.assertEqual(result[0], [])
+
@parameterized.expand(TEST_CASE_SEGRESNET_DS2)
def test_shape2(self, input_param, input_shape, expected_shape):
dsdepth = input_param.get("dsdepth", 1)
- net = SegResNetDS(**input_param).to(device)
-
- net.train()
- result = net(torch.randn(input_shape).to(device))
- if dsdepth > 1:
- assert isinstance(result, list)
- self.assertEqual(dsdepth, len(result))
- for i in range(dsdepth):
- self.assertEqual(
- result[i].shape,
- expected_shape[:2] + tuple(e // (2**i) for e in expected_shape[2:]),
- msg=str(input_param),
- )
- else:
- assert isinstance(result, torch.Tensor)
- self.assertEqual(result.shape, expected_shape, msg=str(input_param))
-
- net.eval()
- result = net(torch.randn(input_shape).to(device))
- assert isinstance(result, torch.Tensor)
- self.assertEqual(result.shape, expected_shape, msg=str(input_param))
+ for net in [SegResNetDS, SegResNetDS2]:
+ net = net(**input_param).to(device)
+ net.train()
+ if isinstance(net, SegResNetDS2):
+ result = net(torch.randn(input_shape).to(device), with_label=False)[0]
+ else:
+ result = net(torch.randn(input_shape).to(device))
+ if dsdepth > 1:
+ assert isinstance(result, list)
+ self.assertEqual(dsdepth, len(result))
+ for i in range(dsdepth):
+ self.assertEqual(
+ result[i].shape,
+ expected_shape[:2] + tuple(e // (2**i) for e in expected_shape[2:]),
+ msg=str(input_param),
+ )
+ else:
+ assert isinstance(result, torch.Tensor)
+ self.assertEqual(result.shape, expected_shape, msg=str(input_param))
+
+ if not isinstance(net, SegResNetDS2):
+ # eval mode of SegResNetDS2 has same output as training mode
+ # so only test eval mode for SegResNetDS
+ net.eval()
+ result = net(torch.randn(input_shape).to(device))
+ assert isinstance(result, torch.Tensor)
+ self.assertEqual(result.shape, expected_shape, msg=str(input_param))
@parameterized.expand(TEST_CASE_SEGRESNET_DS3)
def test_shape3(self, input_param, input_shape, expected_shapes):
dsdepth = input_param.get("dsdepth", 1)
- net = SegResNetDS(**input_param).to(device)
-
- net.train()
- result = net(torch.randn(input_shape).to(device))
- assert isinstance(result, list)
- self.assertEqual(dsdepth, len(result))
- for i in range(dsdepth):
- self.assertEqual(result[i].shape, expected_shapes[i], msg=str(input_param))
+ for net in [SegResNetDS, SegResNetDS2]:
+ net = net(**input_param).to(device)
+ net.train()
+ if isinstance(net, SegResNetDS2):
+ result = net(torch.randn(input_shape).to(device), with_point=False)[1]
+ else:
+ result = net(torch.randn(input_shape).to(device))
+ assert isinstance(result, list)
+ self.assertEqual(dsdepth, len(result))
+ for i in range(dsdepth):
+ self.assertEqual(result[i].shape, expected_shapes[i], msg=str(input_param))
def test_ill_arg(self):
with self.assertRaises(ValueError):
SegResNetDS(spatial_dims=4)
+ with self.assertRaises(ValueError):
+ SegResNetDS2(spatial_dims=4)
+
@SkipIfBeforePyTorchVersion((1, 10))
def test_script(self):
input_param, input_shape, _ = TEST_CASE_SEGRESNET_DS[0]
diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py
index b8be4fd1b6..88919fd8b1 100644
--- a/tests/test_selfattention.py
+++ b/tests/test_selfattention.py
@@ -20,7 +20,9 @@
from monai.networks import eval_mode
from monai.networks.blocks.selfattention import SABlock
+from monai.networks.layers.factories import RelPosEmbedding
from monai.utils import optional_import
+from tests.utils import SkipIfBeforePyTorchVersion, assert_allclose, test_script_save
einops, has_einops = optional_import("einops")
@@ -28,18 +30,32 @@
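+# use_flash_attention is only enabled when rel_pos_embedding is None; combining the two raises ValueError
+# (see test_rel_pos_embedding_with_flash_attention)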
for dropout_rate in np.linspace(0, 1, 4):
for hidden_size in [360, 480, 600, 768]:
for num_heads in [4, 6, 8, 12]:
- test_case = [
- {"hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate},
- (2, 512, hidden_size),
- (2, 512, hidden_size),
- ]
- TEST_CASE_SABLOCK.append(test_case)
+ for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
+ for input_size in [(16, 32), (8, 8, 8)]:
+ for include_fc in [True, False]:
+ for use_combined_linear in [True, False]:
+ test_case = [
+ {
+ "hidden_size": hidden_size,
+ "num_heads": num_heads,
+ "dropout_rate": dropout_rate,
+ "rel_pos_embedding": rel_pos_embedding,
+ "input_size": input_size,
+ "include_fc": include_fc,
+ "use_combined_linear": use_combined_linear,
+ "use_flash_attention": True if rel_pos_embedding is None else False,
+ },
+ (2, 512, hidden_size),
+ (2, 512, hidden_size),
+ ]
+ TEST_CASE_SABLOCK.append(test_case)
class TestResBlock(unittest.TestCase):
@parameterized.expand(TEST_CASE_SABLOCK)
@skipUnless(has_einops, "Requires einops")
+ @SkipIfBeforePyTorchVersion((2, 0))
def test_shape(self, input_param, input_shape, expected_shape):
net = SABlock(**input_param)
with eval_mode(net):
@@ -53,6 +69,60 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
SABlock(hidden_size=620, num_heads=8, dropout_rate=0.4)
+ @SkipIfBeforePyTorchVersion((2, 0))
+ def test_rel_pos_embedding_with_flash_attention(self):
+ with self.assertRaises(ValueError):
+ SABlock(
+ hidden_size=128,
+ num_heads=3,
+ dropout_rate=0.1,
+ use_flash_attention=True,
+ save_attn=False,
+ rel_pos_embedding=RelPosEmbedding.DECOMPOSED,
+ )
+
+ @SkipIfBeforePyTorchVersion((1, 13))
+ def test_save_attn_with_flash_attention(self):
+ with self.assertRaises(ValueError):
+ SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1, use_flash_attention=True, save_attn=True)
+
+ def test_attention_dim_not_multiple_of_heads(self):
+ with self.assertRaises(ValueError):
+ SABlock(hidden_size=128, num_heads=3, dropout_rate=0.1)
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_inner_dim_different(self):
+ SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, dim_head=30)
+
+ def test_causal_no_sequence_length(self):
+ with self.assertRaises(ValueError):
+ SABlock(hidden_size=128, num_heads=4, dropout_rate=0.1, causal=True)
+
+ @skipUnless(has_einops, "Requires einops")
+ @SkipIfBeforePyTorchVersion((2, 0))
+ def test_causal_flash_attention(self):
+ block = SABlock(
+ hidden_size=128,
+ num_heads=1,
+ dropout_rate=0.1,
+ causal=True,
+ sequence_length=16,
+ save_attn=False,
+ use_flash_attention=True,
+ )
+ input_shape = (1, 16, 128)
+ # Check it runs correctly
+ block(torch.randn(input_shape))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_causal(self):
+ block = SABlock(hidden_size=128, num_heads=1, dropout_rate=0.1, causal=True, sequence_length=16, save_attn=True)
+ input_shape = (1, 16, 128)
+ block(torch.randn(input_shape))
+ # check upper triangular part of the attention matrix is zero
+ assert torch.triu(block.att_mat, diagonal=1).sum() == 0
+
+ @skipUnless(has_einops, "Requires einops")
def test_access_attn_matrix(self):
# input format
hidden_size = 128
@@ -74,6 +144,73 @@ def test_access_attn_matrix(self):
matrix_acess_blk(torch.randn(input_shape))
assert matrix_acess_blk.att_mat.shape == (input_shape[0], input_shape[0], input_shape[1], input_shape[1])
+ def test_number_of_parameters(self):
+
+ def count_sablock_params(*args, **kwargs):
+ """Count the number of parameters in a SABlock."""
+ sablock = SABlock(*args, **kwargs)
+ return sum([x.numel() for x in sablock.parameters() if x.requires_grad])
+
+ hidden_size = 128
+ num_heads = 8
+ default_dim_head = hidden_size // num_heads
+
+ # Default dim_head is hidden_size // num_heads
+ nparams_default = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads)
+ nparams_like_default = count_sablock_params(
+ hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head
+ )
+ self.assertEqual(nparams_default, nparams_like_default)
+
+ # Increasing dim_head should increase the number of parameters
+ nparams_custom_large = count_sablock_params(
+ hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head * 2
+ )
+ self.assertGreater(nparams_custom_large, nparams_default)
+
+ # Decreasing dim_head should decrease the number of parameters
+ nparams_custom_small = count_sablock_params(
+ hidden_size=hidden_size, num_heads=num_heads, dim_head=default_dim_head // 2
+ )
+ self.assertGreater(nparams_default, nparams_custom_small)
+
+ # Increasing the number of heads with the default behaviour should not change the number of params.
+ nparams_default_more_heads = count_sablock_params(hidden_size=hidden_size, num_heads=num_heads * 2)
+ self.assertEqual(nparams_default, nparams_default_more_heads)
+
+ @parameterized.expand([[True, False], [True, True], [False, True], [False, False]])
+ @skipUnless(has_einops, "Requires einops")
+ @SkipIfBeforePyTorchVersion((2, 0))
+ def test_script(self, include_fc, use_combined_linear):
+ input_param = {
+ "hidden_size": 360,
+ "num_heads": 4,
+ "dropout_rate": 0.0,
+ "rel_pos_embedding": None,
+ "input_size": (16, 32),
+ "include_fc": include_fc,
+ "use_combined_linear": use_combined_linear,
+ }
+ net = SABlock(**input_param)
+ input_shape = (2, 512, 360)
+ test_data = torch.randn(input_shape)
+ test_script_save(net, test_data)
+
+ @skipUnless(has_einops, "Requires einops")
+ @SkipIfBeforePyTorchVersion((2, 0))
+ def test_flash_attention(self):
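+        # with identical weights, outputs with and without flash attention should agree within tolerance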
+ for causal in [True, False]:
+ input_param = {"hidden_size": 360, "num_heads": 4, "input_size": (16, 32), "causal": causal}
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ block_w_flash_attention = SABlock(**input_param, use_flash_attention=True).to(device)
+ block_wo_flash_attention = SABlock(**input_param, use_flash_attention=False).to(device)
+ block_wo_flash_attention.load_state_dict(block_w_flash_attention.state_dict())
+ test_data = torch.randn(2, 512, 360).to(device)
+
+ out_1 = block_w_flash_attention(test_data)
+ out_2 = block_wo_flash_attention(test_data)
+ assert_allclose(out_1, out_2, atol=1e-4)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_set_visible_devices.py b/tests/test_set_visible_devices.py
index 7860656b3d..b4f44957a2 100644
--- a/tests/test_set_visible_devices.py
+++ b/tests/test_set_visible_devices.py
@@ -14,7 +14,7 @@
import os
import unittest
-from tests.utils import skip_if_no_cuda
+from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_no_cuda
class TestVisibleDevices(unittest.TestCase):
@@ -25,6 +25,7 @@ def run_process_and_get_exit_code(code_to_execute):
return int(bin(value).replace("0b", "").rjust(16, "0")[:8], 2)
@skip_if_no_cuda
+ @SkipIfAtLeastPyTorchVersion((2, 2, 1))
def test_visible_devices(self):
num_gpus_before = self.run_process_and_get_exit_code(
'python -c "import os; import torch; '
diff --git a/tests/test_signal_fillempty.py b/tests/test_signal_fillempty.py
index a3ee623cc5..2be4bd8600 100644
--- a/tests/test_signal_fillempty.py
+++ b/tests/test_signal_fillempty.py
@@ -30,7 +30,7 @@ class TestSignalFillEmptyNumpy(unittest.TestCase):
def test_correct_parameters_multi_channels(self):
self.assertIsInstance(SignalFillEmpty(replacement=0.0), SignalFillEmpty)
sig = np.load(TEST_SIGNAL)
- sig[:, 123] = np.NAN
+ sig[:, 123] = np.nan
fillempty = SignalFillEmpty(replacement=0.0)
fillemptysignal = fillempty(sig)
self.assertTrue(not np.isnan(fillemptysignal).any())
@@ -42,7 +42,7 @@ class TestSignalFillEmptyTorch(unittest.TestCase):
def test_correct_parameters_multi_channels(self):
self.assertIsInstance(SignalFillEmpty(replacement=0.0), SignalFillEmpty)
sig = convert_to_tensor(np.load(TEST_SIGNAL))
- sig[:, 123] = convert_to_tensor(np.NAN)
+ sig[:, 123] = convert_to_tensor(np.nan)
fillempty = SignalFillEmpty(replacement=0.0)
fillemptysignal = fillempty(sig)
self.assertTrue(not torch.isnan(fillemptysignal).any())
diff --git a/tests/test_signal_fillemptyd.py b/tests/test_signal_fillemptyd.py
index ee8c571ef8..7710279495 100644
--- a/tests/test_signal_fillemptyd.py
+++ b/tests/test_signal_fillemptyd.py
@@ -30,7 +30,7 @@ class TestSignalFillEmptyNumpy(unittest.TestCase):
def test_correct_parameters_multi_channels(self):
self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd)
sig = np.load(TEST_SIGNAL)
- sig[:, 123] = np.NAN
+ sig[:, 123] = np.nan
data = {}
data["signal"] = sig
fillempty = SignalFillEmptyd(keys=("signal",), replacement=0.0)
@@ -46,7 +46,7 @@ class TestSignalFillEmptyTorch(unittest.TestCase):
def test_correct_parameters_multi_channels(self):
self.assertIsInstance(SignalFillEmptyd(replacement=0.0), SignalFillEmptyd)
sig = convert_to_tensor(np.load(TEST_SIGNAL))
- sig[:, 123] = convert_to_tensor(np.NAN)
+ sig[:, 123] = convert_to_tensor(np.nan)
data = {}
data["signal"] = sig
fillempty = SignalFillEmptyd(keys=("signal",), replacement=0.0)
diff --git a/tests/test_sobel_gradient.py b/tests/test_sobel_gradient.py
index 3d995a60c9..a0d7cf5a8b 100644
--- a/tests/test_sobel_gradient.py
+++ b/tests/test_sobel_gradient.py
@@ -164,8 +164,8 @@ def test_sobel_gradients(self, image, arguments, expected_grad):
)
def test_sobel_kernels(self, arguments, expected_kernels):
sobel = SobelGradients(**arguments)
- self.assertTrue(sobel.kernel_diff.dtype == expected_kernels[0].dtype)
- self.assertTrue(sobel.kernel_smooth.dtype == expected_kernels[0].dtype)
+ self.assertEqual(sobel.kernel_diff.dtype, expected_kernels[0].dtype)
+ self.assertEqual(sobel.kernel_smooth.dtype, expected_kernels[0].dtype)
assert_allclose(sobel.kernel_diff, expected_kernels[0])
assert_allclose(sobel.kernel_smooth, expected_kernels[1])
diff --git a/tests/test_sobel_gradientd.py b/tests/test_sobel_gradientd.py
index 7499a0410b..03524823a5 100644
--- a/tests/test_sobel_gradientd.py
+++ b/tests/test_sobel_gradientd.py
@@ -187,8 +187,8 @@ def test_sobel_gradients(self, image_dict, arguments, expected_grad):
)
def test_sobel_kernels(self, arguments, expected_kernels):
sobel = SobelGradientsd(**arguments)
- self.assertTrue(sobel.kernel_diff.dtype == expected_kernels[0].dtype)
- self.assertTrue(sobel.kernel_smooth.dtype == expected_kernels[0].dtype)
+ self.assertEqual(sobel.kernel_diff.dtype, expected_kernels[0].dtype)
+ self.assertEqual(sobel.kernel_smooth.dtype, expected_kernels[0].dtype)
assert_allclose(sobel.kernel_diff, expected_kernels[0])
assert_allclose(sobel.kernel_smooth, expected_kernels[1])
diff --git a/tests/test_soft_clip.py b/tests/test_soft_clip.py
new file mode 100644
index 0000000000..de5122e982
--- /dev/null
+++ b/tests/test_soft_clip.py
@@ -0,0 +1,125 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.transforms.utils import soft_clip
+
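+# each case pairs soft_clip arguments with the expected smoothly clipped values; torch inputs first, then the same cases with numpy inputs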
+TEST_CASES = [
+ [
+ {"minv": 2, "maxv": 8, "sharpness_factor": 10},
+ {
+ "input": torch.arange(10).float(),
+ "clipped": torch.tensor([2.0000, 2.0000, 2.0693, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 7.9307, 8.0000]),
+ },
+ ],
+ [
+ {"minv": 2, "maxv": None, "sharpness_factor": 10},
+ {
+ "input": torch.arange(10).float(),
+ "clipped": torch.tensor([2.0000, 2.0000, 2.0693, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000]),
+ },
+ ],
+ [
+ {"minv": None, "maxv": 7, "sharpness_factor": 10},
+ {
+ "input": torch.arange(10).float(),
+ "clipped": torch.tensor([0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 6.9307, 7.0000, 7.0000]),
+ },
+ ],
+ [
+ {"minv": 2, "maxv": 8, "sharpness_factor": 1.0},
+ {
+ "input": torch.arange(10).float(),
+ "clipped": torch.tensor([2.1266, 2.3124, 2.6907, 3.3065, 4.1088, 5.0000, 5.8912, 6.6935, 7.3093, 7.6877]),
+ },
+ ],
+ [
+ {"minv": 2, "maxv": 8, "sharpness_factor": 3.0},
+ {
+ "input": torch.arange(10).float(),
+ "clipped": torch.tensor([2.0008, 2.0162, 2.2310, 3.0162, 4.0008, 5.0000, 5.9992, 6.9838, 7.7690, 7.9838]),
+ },
+ ],
+ [
+ {"minv": 2, "maxv": 8, "sharpness_factor": 5.0},
+ {
+ "input": torch.arange(10).float(),
+ "clipped": torch.tensor([2.0000, 2.0013, 2.1386, 3.0013, 4.0000, 5.0000, 6.0000, 6.9987, 7.8614, 7.9987]),
+ },
+ ],
+ [
+ {"minv": 2, "maxv": 8, "sharpness_factor": 10},
+ {
+ "input": np.arange(10).astype(np.float32),
+ "clipped": np.array([2.0000, 2.0000, 2.0693, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 7.9307, 8.0000]),
+ },
+ ],
+ [
+ {"minv": 2, "maxv": None, "sharpness_factor": 10},
+ {
+ "input": np.arange(10).astype(float),
+ "clipped": np.array([2.0000, 2.0000, 2.0693, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000, 9.0000]),
+ },
+ ],
+ [
+ {"minv": None, "maxv": 7, "sharpness_factor": 10},
+ {
+ "input": np.arange(10).astype(float),
+ "clipped": np.array([0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 6.9307, 7.0000, 7.0000]),
+ },
+ ],
+ [
+ {"minv": 2, "maxv": 8, "sharpness_factor": 1.0},
+ {
+ "input": np.arange(10).astype(float),
+ "clipped": np.array([2.1266, 2.3124, 2.6907, 3.3065, 4.1088, 5.0000, 5.8912, 6.6935, 7.3093, 7.6877]),
+ },
+ ],
+ [
+ {"minv": 2, "maxv": 8, "sharpness_factor": 3.0},
+ {
+ "input": np.arange(10).astype(float),
+ "clipped": np.array([2.0008, 2.0162, 2.2310, 3.0162, 4.0008, 5.0000, 5.9992, 6.9838, 7.7690, 7.9838]),
+ },
+ ],
+ [
+ {"minv": 2, "maxv": 8, "sharpness_factor": 5.0},
+ {
+ "input": np.arange(10).astype(float),
+ "clipped": np.array([2.0000, 2.0013, 2.1386, 3.0013, 4.0000, 5.0000, 6.0000, 6.9987, 7.8614, 7.9987]),
+ },
+ ],
+]
+
+
+class TestSoftClip(unittest.TestCase):
+
+ @parameterized.expand(TEST_CASES)
+ def test_result(self, input_param, input_data):
+ outputs = soft_clip(input_data["input"], **input_param)
+ expected_val = input_data["clipped"]
+ if isinstance(outputs, torch.Tensor):
+ np.testing.assert_allclose(
+ outputs.detach().cpu().numpy(), expected_val.detach().cpu().numpy(), atol=1e-4, rtol=1e-4
+ )
+ else:
+ np.testing.assert_allclose(outputs, expected_val, atol=1e-4, rtol=1e-4)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_spade_autoencoderkl.py b/tests/test_spade_autoencoderkl.py
new file mode 100644
index 0000000000..9353ceedc2
--- /dev/null
+++ b/tests/test_spade_autoencoderkl.py
@@ -0,0 +1,295 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from unittest import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.nets import SPADEAutoencoderKL
+from monai.utils import optional_import
+
+einops, has_einops = optional_import("einops")
+
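+# the attention-enabled cases require einops, so they are only added to CASES when it is available (see below)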
+CASES_NO_ATTENTION = [
+ [
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, False),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ },
+ (1, 1, 16, 16),
+ (1, 3, 16, 16),
+ (1, 1, 16, 16),
+ (1, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "label_nc": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, False),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ "with_encoder_nonlocal_attn": False,
+ "with_decoder_nonlocal_attn": False,
+ },
+ (1, 1, 16, 16, 16),
+ (1, 3, 16, 16, 16),
+ (1, 1, 16, 16, 16),
+ (1, 4, 4, 4, 4),
+ ],
+]
+
+CASES_ATTENTION = [
+ [
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, False),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ },
+ (1, 1, 16, 16),
+ (1, 3, 16, 16),
+ (1, 1, 16, 16),
+ (1, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, False),
+ "num_res_blocks": (1, 1, 2),
+ "norm_num_groups": 4,
+ },
+ (1, 1, 16, 16),
+ (1, 3, 16, 16),
+ (1, 1, 16, 16),
+ (1, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, False),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ },
+ (1, 1, 16, 16),
+ (1, 3, 16, 16),
+ (1, 1, 16, 16),
+ (1, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, True),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ },
+ (1, 1, 16, 16),
+ (1, 3, 16, 16),
+ (1, 1, 16, 16),
+ (1, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, False),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ "with_encoder_nonlocal_attn": False,
+ },
+ (1, 1, 16, 16),
+ (1, 3, 16, 16),
+ (1, 1, 16, 16),
+ (1, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "label_nc": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, True),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ },
+ (1, 1, 16, 16, 16),
+ (1, 3, 16, 16, 16),
+ (1, 1, 16, 16, 16),
+ (1, 4, 4, 4, 4),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "label_nc": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4, 4),
+ "latent_channels": 4,
+ "attention_levels": (False, False, True),
+ "num_res_blocks": 1,
+ "norm_num_groups": 4,
+ "spade_intermediate_channels": 32,
+ },
+ (1, 1, 16, 16),
+ (1, 3, 16, 16),
+ (1, 1, 16, 16),
+ (1, 4, 4, 4),
+ ],
+]
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+if has_einops:
+ CASES = CASES_ATTENTION + CASES_NO_ATTENTION
+else:
+ CASES = CASES_NO_ATTENTION
+
+
+class TestSPADEAutoEncoderKL(unittest.TestCase):
+ @parameterized.expand(CASES)
+ def test_shape(self, input_param, input_shape, input_seg, expected_shape, expected_latent_shape):
+ net = SPADEAutoencoderKL(**input_param).to(device)
+ with eval_mode(net):
+ result = net.forward(torch.randn(input_shape).to(device), torch.randn(input_seg).to(device))
+ self.assertEqual(result[0].shape, expected_shape)
+ self.assertEqual(result[1].shape, expected_latent_shape)
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_model_channels_not_multiple_of_norm_num_group(self):
+ with self.assertRaises(ValueError):
+ SPADEAutoencoderKL(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ channels=(24, 24, 24),
+ attention_levels=(False, False, False),
+ latent_channels=8,
+ num_res_blocks=1,
+ norm_num_groups=16,
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_model_channels_not_same_size_of_attention_levels(self):
+ with self.assertRaises(ValueError):
+ SPADEAutoencoderKL(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ channels=(24, 24, 24),
+ attention_levels=(False, False),
+ latent_channels=8,
+ num_res_blocks=1,
+ norm_num_groups=16,
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_model_channels_not_same_size_of_num_res_blocks(self):
+ with self.assertRaises(ValueError):
+ SPADEAutoencoderKL(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ channels=(24, 24, 24),
+ attention_levels=(False, False, False),
+ latent_channels=8,
+ num_res_blocks=(8, 8),
+ norm_num_groups=16,
+ )
+
+ def test_shape_encode(self):
+ input_param, input_shape, _, _, expected_latent_shape = CASES[0]
+ net = SPADEAutoencoderKL(**input_param).to(device)
+ with eval_mode(net):
+ result = net.encode(torch.randn(input_shape).to(device))
+ self.assertEqual(result[0].shape, expected_latent_shape)
+ self.assertEqual(result[1].shape, expected_latent_shape)
+
+ def test_shape_sampling(self):
+ input_param, _, _, _, expected_latent_shape = CASES[0]
+ net = SPADEAutoencoderKL(**input_param).to(device)
+ with eval_mode(net):
+ result = net.sampling(
+ torch.randn(expected_latent_shape).to(device), torch.randn(expected_latent_shape).to(device)
+ )
+ self.assertEqual(result.shape, expected_latent_shape)
+
+ def test_shape_decode(self):
+ input_param, _, input_seg_shape, expected_input_shape, latent_shape = CASES[0]
+ net = SPADEAutoencoderKL(**input_param).to(device)
+ with eval_mode(net):
+ result = net.decode(torch.randn(latent_shape).to(device), torch.randn(input_seg_shape).to(device))
+ self.assertEqual(result.shape, expected_input_shape)
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_wrong_shape_decode(self):
+ net = SPADEAutoencoderKL(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ channels=(4, 4, 4),
+ latent_channels=4,
+ attention_levels=(False, False, False),
+ num_res_blocks=1,
+ norm_num_groups=4,
+ )
+ with self.assertRaises(RuntimeError):
+ _ = net.decode(torch.randn((1, 1, 16, 16)).to(device), torch.randn((1, 6, 16, 16)).to(device))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_spade_diffusion_model_unet.py b/tests/test_spade_diffusion_model_unet.py
new file mode 100644
index 0000000000..481705f56f
--- /dev/null
+++ b/tests/test_spade_diffusion_model_unet.py
@@ -0,0 +1,574 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from unittest import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.nets import SPADEDiffusionModelUNet
+from monai.utils import optional_import
+
+einops, has_einops = optional_import("einops")
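+# unconditioned 2D/3D configurations and cross-attention conditioned 2D configurations used by the shape tests below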
+UNCOND_CASES_2D = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ "label_nc": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": (1, 1, 2),
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ "label_nc": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ "resblock_updown": True,
+ "label_nc": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ "label_nc": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ "resblock_updown": True,
+ "label_nc": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "label_nc": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, True, True),
+ "num_head_channels": (0, 2, 4),
+ "norm_num_groups": 8,
+ "label_nc": 3,
+ }
+ ],
+]
+
+UNCOND_CASES_3D = [
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ "label_nc": 3,
+ "spade_intermediate_channels": 256,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ "label_nc": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, False),
+ "norm_num_groups": 8,
+ "resblock_updown": True,
+ "label_nc": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ "label_nc": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 8,
+ "norm_num_groups": 8,
+ "resblock_updown": True,
+ "label_nc": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "label_nc": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": (0, 0, 4),
+ "norm_num_groups": 8,
+ "label_nc": 3,
+ }
+ ],
+]
+
+COND_CASES_2D = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ "label_nc": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ "resblock_updown": True,
+ "label_nc": 3,
+ }
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "num_res_blocks": 1,
+ "channels": (8, 8, 8),
+ "attention_levels": (False, False, True),
+ "num_head_channels": 4,
+ "norm_num_groups": 8,
+ "with_conditioning": True,
+ "transformer_num_layers": 1,
+ "cross_attention_dim": 3,
+ "upcast_attention": True,
+ "label_nc": 3,
+ }
+ ],
+]
+
+
+class TestSPADEDiffusionModelUNet2D(unittest.TestCase):
+ @parameterized.expand(UNCOND_CASES_2D)
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_unconditioned_models(self, input_param):
+ net = SPADEDiffusionModelUNet(**input_param)
+ with eval_mode(net):
+ result = net.forward(
+ torch.rand((1, 1, 16, 16)),
+ torch.randint(0, 1000, (1,)).long(),
+ torch.rand((1, input_param["label_nc"], 16, 16)),
+ )
+ self.assertEqual(result.shape, (1, 1, 16, 16))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_timestep_with_wrong_shape(self):
+ net = SPADEDiffusionModelUNet(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, False),
+ norm_num_groups=8,
+ )
+ with self.assertRaises(ValueError):
+ with eval_mode(net):
+ net.forward(
+ torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1, 1)).long(), torch.rand((1, 3, 16, 16))
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_label_with_wrong_shape(self):
+ net = SPADEDiffusionModelUNet(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, False),
+ norm_num_groups=8,
+ )
+ with self.assertRaises(RuntimeError):
+ with eval_mode(net):
+ net.forward(torch.rand((1, 1, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 6, 16, 16)))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_with_different_in_channel_out_channel(self):
+ in_channels = 6
+ out_channels = 3
+ net = SPADEDiffusionModelUNet(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, False),
+ norm_num_groups=8,
+ )
+ with eval_mode(net):
+ result = net.forward(
+ torch.rand((1, in_channels, 16, 16)), torch.randint(0, 1000, (1,)).long(), torch.rand((1, 3, 16, 16))
+ )
+ self.assertEqual(result.shape, (1, out_channels, 16, 16))
+
+ def test_model_channels_not_multiple_of_norm_num_group(self):
+ with self.assertRaises(ValueError):
+ SPADEDiffusionModelUNet(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 12),
+ attention_levels=(False, False, False),
+ norm_num_groups=8,
+ )
+
+ def test_attention_levels_with_different_length_num_head_channels(self):
+ with self.assertRaises(ValueError):
+ SPADEDiffusionModelUNet(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, False),
+ num_head_channels=(0, 2),
+ norm_num_groups=8,
+ )
+
+ def test_num_res_blocks_with_different_length_channels(self):
+ with self.assertRaises(ValueError):
+ SPADEDiffusionModelUNet(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=(1, 1),
+ channels=(8, 8, 8),
+ attention_levels=(False, False, False),
+ norm_num_groups=8,
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_conditioned_models(self):
+ net = SPADEDiffusionModelUNet(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ with_conditioning=True,
+ transformer_num_layers=1,
+ cross_attention_dim=3,
+ norm_num_groups=8,
+ num_head_channels=8,
+ )
+ with eval_mode(net):
+ result = net.forward(
+ x=torch.rand((1, 1, 16, 32)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ seg=torch.rand((1, 3, 16, 32)),
+ context=torch.rand((1, 1, 3)),
+ )
+ self.assertEqual(result.shape, (1, 1, 16, 32))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_with_conditioning_cross_attention_dim_none(self):
+ with self.assertRaises(ValueError):
+ SPADEDiffusionModelUNet(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ with_conditioning=True,
+ transformer_num_layers=1,
+ cross_attention_dim=None,
+ norm_num_groups=8,
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_context_with_conditioning_none(self):
+ net = SPADEDiffusionModelUNet(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ with_conditioning=False,
+ transformer_num_layers=1,
+ norm_num_groups=8,
+ )
+
+ with self.assertRaises(ValueError):
+ with eval_mode(net):
+ net.forward(
+ x=torch.rand((1, 1, 16, 32)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ seg=torch.rand((1, 3, 16, 32)),
+ context=torch.rand((1, 1, 3)),
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_conditioned_models_class_conditioning(self):
+ net = SPADEDiffusionModelUNet(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ norm_num_groups=8,
+ num_head_channels=8,
+ num_class_embeds=2,
+ )
+ with eval_mode(net):
+ result = net.forward(
+ x=torch.rand((1, 1, 16, 32)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ seg=torch.rand((1, 3, 16, 32)),
+ class_labels=torch.randint(0, 2, (1,)).long(),
+ )
+ self.assertEqual(result.shape, (1, 1, 16, 32))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_conditioned_models_no_class_labels(self):
+ net = SPADEDiffusionModelUNet(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ norm_num_groups=8,
+ num_head_channels=8,
+ num_class_embeds=2,
+ )
+
+ with self.assertRaises(ValueError):
+ net.forward(
+ x=torch.rand((1, 1, 16, 32)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ seg=torch.rand((1, 3, 16, 32)),
+ )
+
+ def test_model_channels_not_same_size_of_attention_levels(self):
+ with self.assertRaises(ValueError):
+ SPADEDiffusionModelUNet(
+ spatial_dims=2,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False),
+ norm_num_groups=8,
+ num_head_channels=8,
+ num_class_embeds=2,
+ )
+
+ @parameterized.expand(COND_CASES_2D)
+ @skipUnless(has_einops, "Requires einops")
+ def test_conditioned_2d_models_shape(self, input_param):
+ net = SPADEDiffusionModelUNet(**input_param)
+ with eval_mode(net):
+ result = net.forward(
+ torch.rand((1, 1, 16, 16)),
+ torch.randint(0, 1000, (1,)).long(),
+ torch.rand((1, input_param["label_nc"], 16, 16)),
+ torch.rand((1, 1, 3)),
+ )
+ self.assertEqual(result.shape, (1, 1, 16, 16))
+
+
+class TestSPADEDiffusionModelUNet3D(unittest.TestCase):
+ @parameterized.expand(UNCOND_CASES_3D)
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_unconditioned_models(self, input_param):
+ net = SPADEDiffusionModelUNet(**input_param)
+ with eval_mode(net):
+ result = net.forward(
+ torch.rand((1, 1, 16, 16, 16)),
+ torch.randint(0, 1000, (1,)).long(),
+ torch.rand((1, input_param["label_nc"], 16, 16, 16)),
+ )
+ self.assertEqual(result.shape, (1, 1, 16, 16, 16))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_with_different_in_channel_out_channel(self):
+ in_channels = 6
+ out_channels = 3
+ net = SPADEDiffusionModelUNet(
+ spatial_dims=3,
+ label_nc=3,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ num_res_blocks=1,
+ channels=(8, 8, 8),
+ attention_levels=(False, False, True),
+ norm_num_groups=4,
+ )
+ with eval_mode(net):
+ result = net.forward(
+ torch.rand((1, in_channels, 16, 16, 16)),
+ torch.randint(0, 1000, (1,)).long(),
+ torch.rand((1, 3, 16, 16, 16)),
+ )
+ self.assertEqual(result.shape, (1, out_channels, 16, 16, 16))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape_conditioned_models(self):
+ net = SPADEDiffusionModelUNet(
+ spatial_dims=3,
+ label_nc=3,
+ in_channels=1,
+ out_channels=1,
+ num_res_blocks=1,
+ channels=(16, 16, 16),
+ attention_levels=(False, False, True),
+ norm_num_groups=16,
+ with_conditioning=True,
+ transformer_num_layers=1,
+ cross_attention_dim=3,
+ )
+ with eval_mode(net):
+ result = net.forward(
+ x=torch.rand((1, 1, 16, 16, 16)),
+ timesteps=torch.randint(0, 1000, (1,)).long(),
+ seg=torch.rand((1, 3, 16, 16, 16)),
+ context=torch.rand((1, 1, 3)),
+ )
+ self.assertEqual(result.shape, (1, 1, 16, 16, 16))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_spade_vaegan.py b/tests/test_spade_vaegan.py
new file mode 100644
index 0000000000..3fdb9b74cb
--- /dev/null
+++ b/tests/test_spade_vaegan.py
@@ -0,0 +1,140 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.nets import SPADENet
+
+CASE_2D = [
+ [[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], 16, True]],
+ [[2, 1, 1, 3, [64, 64], [16, 32, 64, 128], None, False]],
+]
+CASE_3D = [
+ [[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], 16, True]],
+ [[3, 1, 1, 3, [64, 64, 64], [16, 32, 64, 128], None, False]],
+]
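+# in each case, index 3 is the number of semantic regions and index 4 the spatial shape passed to create_semantic_data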
+
+
+def create_semantic_data(shape: list, semantic_regions: int):
+ """
+    Create mock semantic label and image inputs for the network.
+    Args:
+        shape: spatial shape of the generated tensors
+        semantic_regions: number of semantic regions
+    Returns:
+        one-hot encoded label map and the corresponding mock image tensor.
+ """
+ out_label = torch.zeros(shape)
+ out_image = torch.zeros(shape) + torch.randn(shape) * 0.01
+ for i in range(1, semantic_regions):
+        shape_square = [dim // np.random.choice(list(range(2, dim // 2))) for dim in shape]
+        start_point = [np.random.choice(list(range(shape[ind] - shape_square[ind]))) for ind in range(len(shape))]
+ if len(shape) == 2:
+ out_label[
+ start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1])
+ ] = i
+ base_intensity = torch.ones(shape_square) * np.random.randn()
+ out_image[
+ start_point[0] : (start_point[0] + shape_square[0]), start_point[1] : (start_point[1] + shape_square[1])
+ ] = (base_intensity + torch.randn(shape_square) * 0.1)
+ elif len(shape) == 3:
+ out_label[
+ start_point[0] : (start_point[0] + shape_square[0]),
+ start_point[1] : (start_point[1] + shape_square[1]),
+ start_point[2] : (start_point[2] + shape_square[2]),
+ ] = i
+ base_intensity = torch.ones(shape_square) * np.random.randn()
+ out_image[
+ start_point[0] : (start_point[0] + shape_square[0]),
+ start_point[1] : (start_point[1] + shape_square[1]),
+ start_point[2] : (start_point[2] + shape_square[2]),
+ ] = (base_intensity + torch.randn(shape_square) * 0.1)
+ else:
+            raise ValueError("Supports only 2D and 3D tensors")
+
+ # One hot encode label
+ out_label_ = torch.zeros([semantic_regions] + list(out_label.shape))
+ for ch in range(semantic_regions):
+ out_label_[ch, ...] = out_label == ch
+
+ return out_label_.unsqueeze(0), out_image.unsqueeze(0).unsqueeze(0)
+
+
+class TestSpadeNet(unittest.TestCase):
+ @parameterized.expand(CASE_2D)
+ def test_forward_2d(self, input_param):
+ """
+ Check that forward method is called correctly and output shape matches.
+ """
+ net = SPADENet(*input_param)
+ in_label, in_image = create_semantic_data(input_param[4], input_param[3])
+ with eval_mode(net):
+ if not net.is_vae:
+ out = net(in_label, in_image)
+ out = out[0]
+ else:
+ out, z_mu, z_logvar = net(in_label, in_image)
+ self.assertTrue(torch.all(torch.isfinite(z_mu)))
+ self.assertTrue(torch.all(torch.isfinite(z_logvar)))
+
+ self.assertTrue(torch.all(torch.isfinite(out)))
+ self.assertEqual(list(out.shape), [1, 1, 64, 64])
+
+ @parameterized.expand(CASE_2D)
+ def test_encoder_decoder(self, input_param):
+ """
+        Check that encode and decode produce outputs with the expected shapes.
+ """
+ net = SPADENet(*input_param)
+ in_label, in_image = create_semantic_data(input_param[4], input_param[3])
+ with eval_mode(net):
+ out_z = net.encode(in_image)
+ if net.is_vae:
+ self.assertEqual(list(out_z.shape), [1, 16])
+ else:
+                self.assertIsNone(out_z)
+ out_i = net.decode(in_label, out_z)
+ self.assertEqual(list(out_i.shape), [1, 1, 64, 64])
+
+ @parameterized.expand(CASE_3D)
+ def test_forward_3d(self, input_param):
+ """
+ Check that forward method is called correctly and output shape matches.
+ """
+ net = SPADENet(*input_param)
+ in_label, in_image = create_semantic_data(input_param[4], input_param[3])
+ with eval_mode(net):
+ if net.is_vae:
+ out, z_mu, z_logvar = net(in_label, in_image)
+ self.assertTrue(torch.all(torch.isfinite(z_mu)))
+ self.assertTrue(torch.all(torch.isfinite(z_logvar)))
+ else:
+ out = net(in_label, in_image)
+ out = out[0]
+ self.assertTrue(torch.all(torch.isfinite(out)))
+ self.assertEqual(list(out.shape), [1, 1, 64, 64, 64])
+
+ def test_shape_wrong(self):
+ """
+        Pass an input shape that isn't divisible by 2**(number of downsampling steps).
+ """
+ with self.assertRaises(ValueError):
+ _ = SPADENet(1, 1, 8, [16, 16], [16, 32, 64, 128], 16, True)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_spatialattention.py b/tests/test_spatialattention.py
new file mode 100644
index 0000000000..70b78263c5
--- /dev/null
+++ b/tests/test_spatialattention.py
@@ -0,0 +1,55 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from unittest import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.blocks.spatialattention import SpatialAttentionBlock
+from monai.utils import optional_import
+
+einops, has_einops = optional_import("einops")
+
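+# (block configuration, input shape, expected output shape) for 2D and 3D inputs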
+TEST_CASES = [
+ [
+ {"spatial_dims": 2, "num_channels": 128, "num_head_channels": 32, "norm_num_groups": 32, "norm_eps": 1e-6},
+ (1, 128, 32, 32),
+ (1, 128, 32, 32),
+ ],
+ [
+ {"spatial_dims": 3, "num_channels": 16, "num_head_channels": 8, "norm_num_groups": 8, "norm_eps": 1e-6},
+ (1, 16, 8, 8, 8),
+ (1, 16, 8, 8, 8),
+ ],
+]
+
+
+class TestSpatialAttentionBlock(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_shape(self, input_param, input_shape, expected_shape):
+ net = SpatialAttentionBlock(**input_param)
+ with eval_mode(net):
+ result = net(torch.randn(input_shape))
+ self.assertEqual(result.shape, expected_shape)
+
+ def test_attention_dim_not_multiple_of_heads(self):
+ with self.assertRaises(ValueError):
+ SpatialAttentionBlock(spatial_dims=2, num_channels=128, num_head_channels=33)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_subpixel_upsample.py b/tests/test_subpixel_upsample.py
index 5abbe57e11..fe9fb1c328 100644
--- a/tests/test_subpixel_upsample.py
+++ b/tests/test_subpixel_upsample.py
@@ -55,9 +55,9 @@
(2, 1, 32, 16, 8),
]
-TEST_CASE_SUBPIXEL.append(TEST_CASE_SUBPIXEL_2D_EXTRA)
-TEST_CASE_SUBPIXEL.append(TEST_CASE_SUBPIXEL_3D_EXTRA)
-TEST_CASE_SUBPIXEL.append(TEST_CASE_SUBPIXEL_CONV_BLOCK_EXTRA)
+TEST_CASE_SUBPIXEL.append(TEST_CASE_SUBPIXEL_2D_EXTRA) # type: ignore
+TEST_CASE_SUBPIXEL.append(TEST_CASE_SUBPIXEL_3D_EXTRA) # type: ignore
+TEST_CASE_SUBPIXEL.append(TEST_CASE_SUBPIXEL_CONV_BLOCK_EXTRA) # type: ignore
# add every test back with the pad/pool sequential component omitted
for tests in list(TEST_CASE_SUBPIXEL):
diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py
index 903f9bd2ca..fb8f5dda72 100644
--- a/tests/test_sure_loss.py
+++ b/tests/test_sure_loss.py
@@ -65,7 +65,7 @@ def operator(x):
loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False)
loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True)
- self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=6)
+ self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=5)
if __name__ == "__main__":
diff --git a/tests/test_synthetic.py b/tests/test_synthetic.py
index 7db3c3e77a..4ab2144568 100644
--- a/tests/test_synthetic.py
+++ b/tests/test_synthetic.py
@@ -47,7 +47,7 @@ def test_create_test_image(self, dim, input_param, expected_img, expected_seg, e
set_determinism(seed=0)
if dim == 2:
img, seg = create_test_image_2d(**input_param)
- elif dim == 3:
+ else: # dim == 3
img, seg = create_test_image_3d(**input_param)
self.assertEqual(img.shape, expected_shape)
self.assertEqual(seg.max(), expected_max_cls)
diff --git a/tests/test_tciadataset.py b/tests/test_tciadataset.py
index d996922e20..5a16bb4816 100644
--- a/tests/test_tciadataset.py
+++ b/tests/test_tciadataset.py
@@ -108,7 +108,7 @@ def _test_dataset(dataset):
)[0]
shutil.rmtree(os.path.join(testing_dir, collection))
- try:
+ with self.assertRaisesRegex(RuntimeError, "^Cannot find dataset directory"):
TciaDataset(
root_dir=testing_dir,
collection=collection,
@@ -117,8 +117,6 @@ def _test_dataset(dataset):
download=False,
val_frac=val_frac,
)
- except RuntimeError as e:
- self.assertTrue(str(e).startswith("Cannot find dataset directory"))
if __name__ == "__main__":
diff --git a/tests/test_threadcontainer.py b/tests/test_threadcontainer.py
index 9551dec703..568461748b 100644
--- a/tests/test_threadcontainer.py
+++ b/tests/test_threadcontainer.py
@@ -62,7 +62,7 @@ def test_container(self):
self.assertTrue(con.is_alive)
self.assertIsNotNone(con.status())
- self.assertTrue(len(con.status_dict) > 0)
+ self.assertGreater(len(con.status_dict), 0)
con.join()
diff --git a/tests/test_to_cupy.py b/tests/test_to_cupy.py
index 5a1754e7c5..38400f0d3f 100644
--- a/tests/test_to_cupy.py
+++ b/tests/test_to_cupy.py
@@ -62,8 +62,8 @@ def test_numpy_input_dtype(self):
test_data = np.rot90(test_data)
self.assertFalse(test_data.flags["C_CONTIGUOUS"])
result = ToCupy(np.uint8)(test_data)
- self.assertTrue(result.dtype == cp.uint8)
- self.assertTrue(isinstance(result, cp.ndarray))
+ self.assertEqual(result.dtype, cp.uint8)
+ self.assertIsInstance(result, cp.ndarray)
self.assertTrue(result.flags["C_CONTIGUOUS"])
cp.testing.assert_allclose(result, test_data)
@@ -72,8 +72,8 @@ def test_tensor_input(self):
test_data = test_data.rot90()
self.assertFalse(test_data.is_contiguous())
result = ToCupy()(test_data)
- self.assertTrue(result.dtype == cp.float32)
- self.assertTrue(isinstance(result, cp.ndarray))
+ self.assertEqual(result.dtype, cp.float32)
+ self.assertIsInstance(result, cp.ndarray)
self.assertTrue(result.flags["C_CONTIGUOUS"])
cp.testing.assert_allclose(result, test_data)
@@ -83,8 +83,8 @@ def test_tensor_cuda_input(self):
test_data = test_data.rot90()
self.assertFalse(test_data.is_contiguous())
result = ToCupy()(test_data)
- self.assertTrue(result.dtype == cp.float32)
- self.assertTrue(isinstance(result, cp.ndarray))
+ self.assertEqual(result.dtype, cp.float32)
+ self.assertIsInstance(result, cp.ndarray)
self.assertTrue(result.flags["C_CONTIGUOUS"])
cp.testing.assert_allclose(result, test_data)
@@ -95,8 +95,8 @@ def test_tensor_cuda_input_dtype(self):
self.assertFalse(test_data.is_contiguous())
result = ToCupy(dtype="float32")(test_data)
- self.assertTrue(result.dtype == cp.float32)
- self.assertTrue(isinstance(result, cp.ndarray))
+ self.assertEqual(result.dtype, cp.float32)
+ self.assertIsInstance(result, cp.ndarray)
self.assertTrue(result.flags["C_CONTIGUOUS"])
cp.testing.assert_allclose(result, test_data)
diff --git a/tests/test_to_numpy.py b/tests/test_to_numpy.py
index 8f7cf34865..f4e5f80a29 100644
--- a/tests/test_to_numpy.py
+++ b/tests/test_to_numpy.py
@@ -32,7 +32,7 @@ def test_cupy_input(self):
test_data = cp.rot90(test_data)
self.assertFalse(test_data.flags["C_CONTIGUOUS"])
result = ToNumpy()(test_data)
- self.assertTrue(isinstance(result, np.ndarray))
+ self.assertIsInstance(result, np.ndarray)
self.assertTrue(result.flags["C_CONTIGUOUS"])
assert_allclose(result, test_data.get(), type_test=False)
@@ -41,8 +41,8 @@ def test_numpy_input(self):
test_data = np.rot90(test_data)
self.assertFalse(test_data.flags["C_CONTIGUOUS"])
result = ToNumpy(dtype="float32")(test_data)
- self.assertTrue(isinstance(result, np.ndarray))
- self.assertTrue(result.dtype == np.float32)
+ self.assertIsInstance(result, np.ndarray)
+ self.assertEqual(result.dtype, np.float32)
self.assertTrue(result.flags["C_CONTIGUOUS"])
assert_allclose(result, test_data, type_test=False)
@@ -51,7 +51,7 @@ def test_tensor_input(self):
test_data = test_data.rot90()
self.assertFalse(test_data.is_contiguous())
result = ToNumpy(dtype=torch.uint8)(test_data)
- self.assertTrue(isinstance(result, np.ndarray))
+ self.assertIsInstance(result, np.ndarray)
self.assertTrue(result.flags["C_CONTIGUOUS"])
assert_allclose(result, test_data, type_test=False)
@@ -61,7 +61,7 @@ def test_tensor_cuda_input(self):
test_data = test_data.rot90()
self.assertFalse(test_data.is_contiguous())
result = ToNumpy()(test_data)
- self.assertTrue(isinstance(result, np.ndarray))
+ self.assertIsInstance(result, np.ndarray)
self.assertTrue(result.flags["C_CONTIGUOUS"])
assert_allclose(result, test_data, type_test=False)
@@ -71,13 +71,13 @@ def test_list_tuple(self):
assert_allclose(result, np.asarray(test_data), type_test=False)
test_data = ((1, 2), (3, 4))
result = ToNumpy(wrap_sequence=False)(test_data)
- self.assertTrue(type(result), tuple)
+ self.assertIsInstance(result, tuple)
assert_allclose(result, ((np.asarray(1), np.asarray(2)), (np.asarray(3), np.asarray(4))))
def test_single_value(self):
for test_data in [5, np.array(5), torch.tensor(5)]:
result = ToNumpy(dtype=np.uint8)(test_data)
- self.assertTrue(isinstance(result, np.ndarray))
+ self.assertIsInstance(result, np.ndarray)
assert_allclose(result, np.asarray(test_data), type_test=False)
self.assertEqual(result.ndim, 0)
diff --git a/tests/test_torchscript_utils.py b/tests/test_torchscript_utils.py
index 6f8f231829..5a5fb47864 100644
--- a/tests/test_torchscript_utils.py
+++ b/tests/test_torchscript_utils.py
@@ -23,6 +23,7 @@
class TestModule(torch.nn.Module):
+ __test__ = False # indicate to pytest that this class is not intended for collection
def forward(self, x):
return x + 10
diff --git a/tests/test_torchvision_fc_model.py b/tests/test_torchvision_fc_model.py
index 322cce1161..9cc19db62c 100644
--- a/tests/test_torchvision_fc_model.py
+++ b/tests/test_torchvision_fc_model.py
@@ -195,8 +195,8 @@ def test_get_module(self):
mod = look_up_named_module("model.1.submodule.1.submodule.1.submodule.0.conv", net)
self.assertTrue(str(mod).startswith("Conv2d"))
self.assertIsInstance(set_named_module(net, "model", torch.nn.Identity()).model, torch.nn.Identity)
- self.assertEqual(look_up_named_module("model.1.submodule.1.submodule.1.submodule.conv", net), None)
- self.assertEqual(look_up_named_module("test attribute", net), None)
+ self.assertIsNone(look_up_named_module("model.1.submodule.1.submodule.1.submodule.conv", net))
+ self.assertIsNone(look_up_named_module("test attribute", net))
if __name__ == "__main__":
diff --git a/tests/test_traceable_transform.py b/tests/test_traceable_transform.py
index dd139053e3..6a499b2dd9 100644
--- a/tests/test_traceable_transform.py
+++ b/tests/test_traceable_transform.py
@@ -33,12 +33,12 @@ def test_default(self):
expected_key = "_transforms"
a = _TraceTest()
for x in a.transform_info_keys():
- self.assertTrue(x in a.get_transform_info())
+ self.assertIn(x, a.get_transform_info())
self.assertEqual(a.trace_key(), expected_key)
data = {"image": "test"}
data = a(data) # adds to the stack
- self.assertTrue(isinstance(data[expected_key], list))
+ self.assertIsInstance(data[expected_key], list)
self.assertEqual(data[expected_key][0]["class"], "_TraceTest")
data = a(data) # adds to the stack
diff --git a/tests/test_transformer.py b/tests/test_transformer.py
new file mode 100644
index 0000000000..b371809d47
--- /dev/null
+++ b/tests/test_transformer.py
@@ -0,0 +1,109 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+import tempfile
+import unittest
+from unittest import skipUnless
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.apps import download_url
+from monai.networks import eval_mode
+from monai.networks.nets import DecoderOnlyTransformer
+from monai.utils import optional_import
+from tests.utils import skip_if_downloading_fails, testing_data_config
+
+_, has_einops = optional_import("einops")
+TEST_CASES = []
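+# sweep embedding dropout rate, attention layer width, and head count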
+for dropout_rate in np.linspace(0, 1, 2):
+ for attention_layer_dim in [360, 480, 600, 768]:
+ for num_heads in [4, 6, 8, 12]:
+ TEST_CASES.append(
+ [
+ {
+ "num_tokens": 10,
+ "max_seq_len": 16,
+ "attn_layers_dim": attention_layer_dim,
+ "attn_layers_depth": 2,
+ "attn_layers_heads": num_heads,
+ "embedding_dropout_rate": dropout_rate,
+ }
+ ]
+ )
+
+
+class TestDecoderOnlyTransformer(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_unconditioned_models(self, input_param):
+ net = DecoderOnlyTransformer(**input_param)
+ with eval_mode(net):
+ net.forward(torch.randint(0, 10, (1, 16)))
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_conditioned_models(self, input_param):
+ net = DecoderOnlyTransformer(**input_param, with_cross_attention=True)
+ with eval_mode(net):
+ net.forward(torch.randint(0, 10, (1, 16)), context=torch.randn(1, 3, input_param["attn_layers_dim"]))
+
+ def test_attention_dim_not_multiple_of_heads(self):
+ with self.assertRaises(ValueError):
+ DecoderOnlyTransformer(
+ num_tokens=10, max_seq_len=16, attn_layers_dim=8, attn_layers_depth=2, attn_layers_heads=3
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_dropout_rate_negative(self):
+
+ with self.assertRaises(ValueError):
+ DecoderOnlyTransformer(
+ num_tokens=10,
+ max_seq_len=16,
+ attn_layers_dim=8,
+ attn_layers_depth=2,
+ attn_layers_heads=2,
+ embedding_dropout_rate=-1,
+ )
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_compatibility_with_monai_generative(self):
+ # test loading weights from a model saved in MONAI Generative, version 0.2.3
+ with skip_if_downloading_fails():
+ net = DecoderOnlyTransformer(
+ num_tokens=10,
+ max_seq_len=16,
+ attn_layers_dim=8,
+ attn_layers_depth=2,
+ attn_layers_heads=2,
+ with_cross_attention=True,
+ embedding_dropout_rate=0,
+ )
+
+ tmpdir = tempfile.mkdtemp()
+ key = "decoder_only_transformer_monai_generative_weights"
+ url = testing_data_config("models", key, "url")
+ hash_type = testing_data_config("models", key, "hash_type")
+ hash_val = testing_data_config("models", key, "hash_val")
+ filename = "decoder_only_transformer_monai_generative_weights.pt"
+ weight_path = os.path.join(tmpdir, filename)
+ download_url(url=url, filepath=weight_path, hash_val=hash_val, hash_type=hash_type)
+
+ net.load_old_state_dict(torch.load(weight_path), verbose=False)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py
index 5a8dbba83c..a850cc6f74 100644
--- a/tests/test_transformerblock.py
+++ b/tests/test_transformerblock.py
@@ -12,6 +12,7 @@
from __future__ import annotations
import unittest
+from unittest import skipUnless
import numpy as np
import torch
@@ -19,28 +20,33 @@
from monai.networks import eval_mode
from monai.networks.blocks.transformerblock import TransformerBlock
+from monai.utils import optional_import
+einops, has_einops = optional_import("einops")
TEST_CASE_TRANSFORMERBLOCK = []
for dropout_rate in np.linspace(0, 1, 4):
for hidden_size in [360, 480, 600, 768]:
for num_heads in [4, 8, 12]:
for mlp_dim in [1024, 3072]:
- test_case = [
- {
- "hidden_size": hidden_size,
- "num_heads": num_heads,
- "mlp_dim": mlp_dim,
- "dropout_rate": dropout_rate,
- },
- (2, 512, hidden_size),
- (2, 512, hidden_size),
- ]
- TEST_CASE_TRANSFORMERBLOCK.append(test_case)
+ for cross_attention in [False, True]:
+ test_case = [
+ {
+ "hidden_size": hidden_size,
+ "num_heads": num_heads,
+ "mlp_dim": mlp_dim,
+ "dropout_rate": dropout_rate,
+ "with_cross_attention": cross_attention,
+ },
+ (2, 512, hidden_size),
+ (2, 512, hidden_size),
+ ]
+ TEST_CASE_TRANSFORMERBLOCK.append(test_case)
class TestTransformerBlock(unittest.TestCase):
@parameterized.expand(TEST_CASE_TRANSFORMERBLOCK)
+ @skipUnless(has_einops, "Requires einops")
def test_shape(self, input_param, input_shape, expected_shape):
net = TransformerBlock(**input_param)
with eval_mode(net):
@@ -54,6 +60,7 @@ def test_ill_arg(self):
with self.assertRaises(ValueError):
TransformerBlock(hidden_size=622, num_heads=8, mlp_dim=3072, dropout_rate=0.4)
+ @skipUnless(has_einops, "Requires einops")
def test_access_attn_matrix(self):
# input format
hidden_size = 128
diff --git a/tests/test_trt_compile.py b/tests/test_trt_compile.py
new file mode 100644
index 0000000000..49404fdbbe
--- /dev/null
+++ b/tests/test_trt_compile.py
@@ -0,0 +1,148 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import tempfile
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.handlers import TrtHandler
+from monai.networks import trt_compile
+from monai.networks.nets import UNet, cell_sam_wrapper, vista3d132
+from monai.utils import min_version, optional_import
+from tests.utils import (
+ SkipIfAtLeastPyTorchVersion,
+ SkipIfBeforeComputeCapabilityVersion,
+ skip_if_no_cuda,
+ skip_if_quick,
+ skip_if_windows,
+)
+
+trt, trt_imported = optional_import("tensorrt", "10.1.0", min_version)
+polygraphy, polygraphy_imported = optional_import("polygraphy")
+build_sam_vit_b, has_sam = optional_import("segment_anything.build_sam", name="build_sam_vit_b")
+
+TEST_CASE_1 = ["fp32"]
+TEST_CASE_2 = ["fp16"]
+
+
+@skip_if_windows
+@skip_if_no_cuda
+@skip_if_quick
+@unittest.skipUnless(trt_imported, "tensorrt is required")
+@unittest.skipUnless(polygraphy_imported, "polygraphy is required")
+@SkipIfBeforeComputeCapabilityVersion((7, 0))
+class TestTRTCompile(unittest.TestCase):
+
+ def setUp(self):
+ self.gpu_device = torch.cuda.current_device()
+
+ def tearDown(self):
+ current_device = torch.cuda.current_device()
+ if current_device != self.gpu_device:
+ torch.cuda.set_device(self.gpu_device)
+
+ @SkipIfAtLeastPyTorchVersion((2, 4, 1))
+ def test_handler(self):
+ from ignite.engine import Engine
+
+ net1 = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()])
+ data1 = net1.state_dict()
+ data1["0.weight"] = torch.tensor([0.1])
+ data1["1.weight"] = torch.tensor([0.2])
+ net1.load_state_dict(data1)
+ net1.cuda()
+
+ with tempfile.TemporaryDirectory() as tempdir:
+ engine = Engine(lambda e, b: None)
+ args = {"method": "torch_trt"}
+ TrtHandler(net1, tempdir + "/trt_handler", args=args).attach(engine)
+ engine.run([0] * 8, max_epochs=1)
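+            # the handler installs the TRT compiler wrapper on the model; the engine itself is only built lazily on the first forward call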
+ self.assertIsNotNone(net1._trt_compiler)
+ self.assertIsNone(net1._trt_compiler.engine)
+ net1.forward(torch.tensor([[0.0, 1.0], [1.0, 2.0]], device="cuda"))
+ self.assertIsNotNone(net1._trt_compiler.engine)
+
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+ def test_unet_value(self, precision):
+ model = UNet(
+ spatial_dims=3,
+ in_channels=1,
+ out_channels=2,
+ channels=(2, 2, 4, 8, 4),
+ strides=(2, 2, 2, 2),
+ num_res_units=2,
+ norm="batch",
+ ).cuda()
+ with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
+ model.eval()
+ input_example = torch.randn(2, 1, 96, 96, 96).cuda()
+ output_example = model(input_example)
+ args: dict = {"builder_optimization_level": 1}
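+            # dynamic_batchsize is assumed here to give the [min, opt, max] batch sizes for the TRT optimization profile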
+ trt_compile(
+ model,
+ f"{tmpdir}/test_unet_trt_compile",
+ args={"precision": precision, "build_args": args, "dynamic_batchsize": [1, 4, 8]},
+ )
+ self.assertIsNone(model._trt_compiler.engine)
+ trt_output = model(input_example)
+ # Check that lazy TRT build succeeded
+ self.assertIsNotNone(model._trt_compiler.engine)
+ torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)
+
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+ @unittest.skipUnless(has_sam, "Requires SAM installation")
+ def test_cell_sam_wrapper_value(self, precision):
+ model = cell_sam_wrapper.CellSamWrapper(checkpoint=None).to("cuda")
+ with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
+ model.eval()
+ input_example = torch.randn(1, 3, 128, 128).to("cuda")
+ output_example = model(input_example)
+ trt_compile(
+ model,
+ f"{tmpdir}/test_cell_sam_wrapper_trt_compile",
+ args={"precision": precision, "dynamic_batchsize": [1, 1, 1]},
+ )
+ self.assertIsNone(model._trt_compiler.engine)
+ trt_output = model(input_example)
+ # Check that lazy TRT build succeeded
+ self.assertIsNotNone(model._trt_compiler.engine)
+ torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)
+
+ @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
+ def test_vista3d(self, precision):
+ model = vista3d132(in_channels=1).to("cuda")
+ with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir:
+ model.eval()
+ input_example = torch.randn(1, 1, 64, 64, 64).to("cuda")
+ output_example = model(input_example)
+ model = trt_compile(
+ model,
+ f"{tmpdir}/test_vista3d_trt_compile",
+ args={"precision": precision, "dynamic_batchsize": [1, 1, 1]},
+ submodule=["image_encoder.encoder", "class_head"],
+ )
+ self.assertIsNotNone(model.image_encoder.encoder._trt_compiler)
+ self.assertIsNotNone(model.class_head._trt_compiler)
+ trt_output = model.forward(input_example)
+ # Check that lazy TRT build succeeded
+ # TODO: set up input_example in such a way that image_encoder.encoder and class_head are called
+ # and uncomment the asserts below
+ # self.assertIsNotNone(model.image_encoder.encoder._trt_compiler.engine)
+ # self.assertIsNotNone(model.class_head._trt_compiler.engine)
+ torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py
index efe1f2cdf3..0365503ea2 100644
--- a/tests/test_tversky_loss.py
+++ b/tests/test_tversky_loss.py
@@ -165,17 +165,12 @@ def test_ill_shape(self):
with self.assertRaisesRegex(ValueError, ""):
TverskyLoss(reduction=None)(chn_input, chn_target)
- def test_input_warnings(self):
+ @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)])
+ def test_input_warnings(self, include_background, softmax, to_onehot_y):
chn_input = torch.ones((1, 1, 3))
chn_target = torch.ones((1, 1, 3))
with self.assertWarns(Warning):
- loss = TverskyLoss(include_background=False)
- loss.forward(chn_input, chn_target)
- with self.assertWarns(Warning):
- loss = TverskyLoss(softmax=True)
- loss.forward(chn_input, chn_target)
- with self.assertWarns(Warning):
- loss = TverskyLoss(to_onehot_y=True)
+ loss = TverskyLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)
loss.forward(chn_input, chn_target)
def test_script(self):
diff --git a/tests/test_ultrasound_confidence_map_transform.py b/tests/test_ultrasound_confidence_map_transform.py
index f672961700..1c6b8f7635 100644
--- a/tests/test_ultrasound_confidence_map_transform.py
+++ b/tests/test_ultrasound_confidence_map_transform.py
@@ -11,14 +11,20 @@
from __future__ import annotations
+import os
import unittest
import numpy as np
import torch
+from parameterized import parameterized
+from PIL import Image
from monai.transforms import UltrasoundConfidenceMapTransform
+from monai.utils import optional_import
from tests.utils import assert_allclose
+_, has_scipy = optional_import("scipy")
+
TEST_INPUT = np.array(
[
[1, 2, 3, 23, 13, 22, 5, 1, 2, 3],
@@ -31,7 +37,8 @@
[1, 2, 3, 32, 33, 34, 35, 1, 2, 3],
[1, 2, 3, 36, 37, 38, 39, 1, 2, 3],
[1, 2, 3, 40, 41, 42, 43, 1, 2, 3],
- ]
+ ],
+ dtype=np.float32,
)
TEST_MASK = np.array(
@@ -46,477 +53,439 @@
[1, 1, 1, 0, 0, 0, 1, 1, 1, 0],
[1, 1, 1, 0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
- ]
+ ],
+ dtype=np.float32,
)
SINK_ALL_OUTPUT = np.array(
[
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[
- 0.97514489,
- 0.96762971,
- 0.96164186,
- 0.95463443,
- 0.9941512,
- 0.99023054,
- 0.98559401,
- 0.98230057,
- 0.96601224,
- 0.95119599,
- ],
- [
- 0.92960533,
- 0.92638451,
- 0.9056675,
- 0.9487176,
- 0.9546961,
- 0.96165853,
- 0.96172303,
- 0.92686401,
- 0.92122613,
- 0.89957239,
- ],
- [
- 0.86490963,
- 0.85723665,
- 0.83798141,
- 0.90816201,
- 0.90816097,
- 0.90815301,
- 0.9081427,
- 0.85933627,
- 0.85146935,
- 0.82948586,
- ],
- [
- 0.77430346,
- 0.76731372,
- 0.74372311,
- 0.89128774,
- 0.89126885,
- 0.89125066,
- 0.89123521,
- 0.76858589,
- 0.76106647,
- 0.73807776,
- ],
- [
- 0.66098109,
- 0.65327697,
- 0.63090644,
- 0.33086588,
- 0.3308383,
- 0.33081937,
- 0.33080718,
- 0.6557468,
- 0.64825099,
- 0.62593375,
- ],
- [
- 0.52526945,
- 0.51832586,
- 0.49709412,
- 0.25985059,
- 0.25981009,
- 0.25977729,
- 0.25975222,
- 0.52118958,
- 0.51426328,
- 0.49323164,
- ],
- [
- 0.3697845,
- 0.36318971,
- 0.34424661,
- 0.17386804,
- 0.17382046,
- 0.17377993,
- 0.17374668,
- 0.36689317,
- 0.36036096,
- 0.3415582,
- ],
- [
- 0.19546374,
- 0.1909659,
- 0.17319999,
- 0.08423318,
- 0.08417993,
- 0.08413242,
- 0.08409104,
- 0.19393909,
- 0.18947485,
- 0.17185031,
+ 0.8884930952884654,
+ 0.8626656901726876,
+ 0.8301161870669913,
+ 0.9757179300830185,
+ 0.9989819637626414,
+ 0.9994717624885747,
+ 0.9954377526794013,
+ 0.8898638133944221,
+ 0.862604343021387,
+ 0.8277862494812598,
+ ],
+ [
+ 0.7765718877433174,
+ 0.7363731552518268,
+ 0.6871875923653379,
+ 0.9753673327387775,
+ 0.9893175316399789,
+ 0.9944181334242039,
+ 0.9936979128319371,
+ 0.7778001700035326,
+ 0.7362622619974832,
+ 0.6848377775329241,
+ ],
+ [
+ 0.6648416226360719,
+ 0.6178079903692397,
+ 0.5630152545966568,
+ 0.8278402502498404,
+ 0.82790391019578,
+ 0.8289702087149963,
+ 0.8286730258710652,
+ 0.6658773633169731,
+ 0.6176836507071695,
+ 0.5609165245633834,
+ ],
+ [
+ 0.5534420483956817,
+ 0.5055401989946189,
+ 0.451865872383879,
+ 0.7541423053657541,
+ 0.7544115886347456,
+ 0.7536884376055174,
+ 0.7524927915364896,
+ 0.5542943466824017,
+ 0.505422678400297,
+ 0.4502051549732117,
+ ],
+ [
+ 0.4423657561928356,
+ 0.398221575954319,
+ 0.35030055029978124,
+ 0.4793202144786371,
+ 0.48057175662074125,
+ 0.4812057229564038,
+ 0.48111949176149327,
+ 0.44304092606050766,
+ 0.39812149713417405,
+ 0.34902458531143377,
+ ],
+ [
+ 0.3315561576450342,
+ 0.29476346732036784,
+ 0.2558303772864961,
+ 0.35090405668257535,
+ 0.3515225984307705,
+ 0.35176548159366317,
+ 0.3516979775419521,
+ 0.33205839061494885,
+ 0.2946859567272435,
+ 0.2549042599220772,
+ ],
+ [
+ 0.22094175240967673,
+ 0.19431840633358133,
+ 0.16672448058324435,
+ 0.22716195845848167,
+ 0.22761996456848282,
+ 0.22782525614780919,
+ 0.22781876632199002,
+ 0.22127471252104777,
+ 0.19426593309729956,
+ 0.16612306610996525,
+ ],
+ [
+ 0.11044782531624744,
+ 0.09623229814933323,
+ 0.08174664901235043,
+ 0.11081911718888311,
+ 0.11102310514207447,
+ 0.1111041051969924,
+ 0.11108329076967229,
+ 0.11061376973431204,
+ 0.09620592927336903,
+ 0.08145227209865454,
],
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
- ]
+ ],
+ dtype=np.float32,
)
SINK_MID_OUTPUT = np.array(
[
+ [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- ],
- [
- 9.99996103e-01,
- 9.99994823e-01,
- 9.99993550e-01,
- 9.99930863e-01,
- 9.99990782e-01,
- 9.99984683e-01,
- 9.99979000e-01,
- 9.99997804e-01,
- 9.99995985e-01,
- 9.99994325e-01,
- ],
- [
- 9.99989344e-01,
- 9.99988600e-01,
- 9.99984099e-01,
- 9.99930123e-01,
- 9.99926598e-01,
- 9.99824297e-01,
- 9.99815032e-01,
- 9.99991228e-01,
- 9.99990881e-01,
- 9.99988462e-01,
- ],
- [
- 9.99980787e-01,
- 9.99979264e-01,
- 9.99975828e-01,
- 9.59669286e-01,
- 9.59664779e-01,
- 9.59656566e-01,
- 9.59648332e-01,
- 9.99983882e-01,
- 9.99983038e-01,
- 9.99980732e-01,
- ],
- [
- 9.99970181e-01,
- 9.99969032e-01,
- 9.99965730e-01,
- 9.45197806e-01,
- 9.45179593e-01,
- 9.45163629e-01,
- 9.45151458e-01,
- 9.99973352e-01,
- 9.99973254e-01,
- 9.99971098e-01,
- ],
- [
- 9.99958608e-01,
- 9.99957307e-01,
- 9.99953444e-01,
- 4.24743523e-01,
- 4.24713305e-01,
- 4.24694646e-01,
- 4.24685271e-01,
- 9.99960948e-01,
- 9.99961829e-01,
- 9.99960347e-01,
- ],
- [
- 9.99946675e-01,
- 9.99945139e-01,
- 9.99940312e-01,
- 3.51353224e-01,
- 3.51304003e-01,
- 3.51268260e-01,
- 3.51245366e-01,
- 9.99947688e-01,
- 9.99950165e-01,
- 9.99949512e-01,
- ],
- [
- 9.99935877e-01,
- 9.99934088e-01,
- 9.99928982e-01,
- 2.51197134e-01,
- 2.51130273e-01,
- 2.51080014e-01,
- 2.51045852e-01,
- 9.99936187e-01,
- 9.99939716e-01,
- 9.99940022e-01,
- ],
- [
- 9.99927846e-01,
- 9.99925911e-01,
- 9.99920188e-01,
- 1.31550973e-01,
- 1.31462736e-01,
- 1.31394558e-01,
- 1.31346069e-01,
- 9.99927275e-01,
- 9.99932142e-01,
- 9.99933313e-01,
- ],
- [
- 9.99924204e-01,
- 9.99922004e-01,
- 9.99915767e-01,
- 3.04861147e-04,
- 1.95998056e-04,
- 0.00000000e00,
- 2.05182682e-05,
- 9.99923115e-01,
- 9.99928835e-01,
- 9.99930535e-01,
- ],
- ]
+ 0.9999957448889315,
+ 0.9999781044114231,
+ 0.9999142422442185,
+ 0.999853253199584,
+ 0.9999918403054282,
+ 0.9999874855193227,
+ 0.9999513619364747,
+ 0.9999589247003497,
+ 0.9999861765528631,
+ 0.9999939213967494,
+ ],
+ [
+ 0.9999918011366045,
+ 0.9999588498417253,
+ 0.9998388659316617,
+ 0.9998496524281603,
+ 0.9999154673258592,
+ 0.9997827845182361,
+ 0.9998160234579786,
+ 0.9999163964511287,
+ 0.9999743435786168,
+ 0.9999894752861168,
+ ],
+ [
+ 0.9999883847481621,
+ 0.9999427334014465,
+ 0.9997703972600652,
+ 0.9853967608835997,
+ 0.9852517829915376,
+ 0.9853308520519438,
+ 0.9854102394414211,
+ 0.9998728503298413,
+ 0.9999642585978225,
+ 0.999986204909933,
+ ],
+ [
+ 0.999985544721449,
+ 0.9999296195017368,
+ 0.9997066149628903,
+ 0.9753803016111353,
+ 0.9750688049429371,
+ 0.9749211929217173,
+ 0.9750052047129354,
+ 0.9998284130289159,
+ 0.9999558481338295,
+ 0.9999837966320273,
+ ],
+ [
+ 0.9999832723447848,
+ 0.9999192263814408,
+ 0.9996472692076177,
+ 0.90541293509353,
+ 0.9049945536526819,
+ 0.9051142437853055,
+ 0.9057005861296792,
+ 0.9997839348839027,
+ 0.9999490318922627,
+ 0.9999820419085812,
+ ],
+ [
+ 0.9999815409510937,
+ 0.9999113168889934,
+ 0.9995930143319085,
+ 0.8370025145062345,
+ 0.8358345435164332,
+ 0.8358231468627223,
+ 0.8369430449157075,
+ 0.9997408260265034,
+ 0.9999437526409107,
+ 0.9999808010740554,
+ ],
+ [
+ 0.9999803198262347,
+ 0.9999057164296593,
+ 0.9995461103528891,
+ 0.7047260555380003,
+ 0.7023346743490383,
+ 0.7022946969603594,
+ 0.7045662738042475,
+ 0.9997017258131392,
+ 0.9999399744001316,
+ 0.9999799785302944,
+ ],
+ [
+ 0.9999795785255197,
+ 0.9999022923125928,
+ 0.999510772973329,
+ 0.46283993237260707,
+ 0.4577365087549323,
+ 0.4571888733219068,
+ 0.4614967878524538,
+ 0.9996710272733927,
+ 0.9999376682163403,
+ 0.9999795067125865,
+ ],
+ [
+ 0.9999792877553907,
+ 0.9999009179811408,
+ 0.9994950057121632,
+ 0.05049460567213739,
+ 0.030946131978013824,
+ 0.0,
+ 0.019224121648385283,
+ 0.9996568912408903,
+ 0.9999367861122628,
+ 0.9999793358521326,
+ ],
+ ],
+ dtype=np.float32,
)
SINK_MIN_OUTPUT = np.array(
[
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[
- 0.99997545,
- 0.99996582,
- 0.99995245,
- 0.99856594,
- 0.99898314,
- 0.99777223,
- 0.99394423,
- 0.98588113,
- 0.97283215,
- 0.96096504,
- ],
- [
- 0.99993872,
- 0.99993034,
- 0.9998832,
- 0.9986147,
- 0.99848741,
- 0.9972981,
- 0.99723719,
- 0.94157173,
- 0.9369832,
- 0.91964243,
- ],
- [
- 0.99990802,
- 0.99989475,
- 0.99986873,
- 0.98610197,
- 0.98610047,
- 0.98609749,
- 0.98609423,
- 0.88741275,
- 0.88112911,
- 0.86349156,
- ],
- [
- 0.99988924,
- 0.99988509,
- 0.99988698,
- 0.98234089,
- 0.98233591,
- 0.98233065,
- 0.98232562,
- 0.81475172,
- 0.80865978,
- 0.79033138,
- ],
- [
- 0.99988418,
- 0.99988484,
- 0.99988323,
- 0.86796555,
- 0.86795874,
- 0.86795283,
- 0.86794756,
- 0.72418193,
- 0.71847704,
- 0.70022037,
- ],
- [
- 0.99988241,
- 0.99988184,
- 0.99988103,
- 0.85528225,
- 0.85527303,
- 0.85526389,
- 0.85525499,
- 0.61716519,
- 0.61026209,
- 0.59503671,
- ],
- [
- 0.99988015,
- 0.99987985,
- 0.99987875,
- 0.84258114,
- 0.84257121,
- 0.84256042,
- 0.84254897,
- 0.48997924,
- 0.49083978,
- 0.46891561,
- ],
- [
- 0.99987865,
- 0.99987827,
- 0.9998772,
- 0.83279589,
- 0.83278624,
- 0.83277384,
- 0.83275897,
- 0.36345545,
- 0.33690244,
- 0.35696828,
- ],
- [
- 0.99987796,
- 0.99987756,
- 0.99987643,
- 0.82873223,
- 0.82872648,
- 0.82871803,
- 0.82870711,
- 0.0,
- 0.26106012,
- 0.29978657,
- ],
- ]
+ 0.9999961997987318,
+ 0.9999801752476248,
+ 0.9999185667341594,
+ 0.9993115972922259,
+ 0.9999536433504382,
+ 0.9997590064584757,
+ 0.9963282396026231,
+ 0.9020645423682648,
+ 0.965641014946897,
+ 0.9847003633599846,
+ ],
+ [
+ 0.9999926824858815,
+ 0.9999628275604145,
+ 0.9998472915971415,
+ 0.9992953054409239,
+ 0.9995550237000549,
+ 0.9972853256638443,
+ 0.9958871482234863,
+ 0.8006505271617617,
+ 0.9360757301263053,
+ 0.9734843475613124,
+ ],
+ [
+ 0.9999896427490426,
+ 0.9999484707116104,
+ 0.9997841142091455,
+ 0.9321779021295554,
+ 0.9308591506422442,
+ 0.9299937642438358,
+ 0.9286536283468563,
+ 0.6964658886602826,
+ 0.9106656689679997,
+ 0.9652109119709528,
+ ],
+ [
+ 0.9999871227708508,
+ 0.9999369646510842,
+ 0.9997276125796202,
+ 0.9006206490361908,
+ 0.8987968702587018,
+ 0.8965696900664386,
+ 0.8941507574801211,
+ 0.5892568658180841,
+ 0.8892240419729905,
+ 0.9590996257620853,
+ ],
+ [
+ 0.9999851119906539,
+ 0.9999280075234918,
+ 0.9996788394671484,
+ 0.778755271203017,
+ 0.7763917808258874,
+ 0.7737517385551721,
+ 0.7707980517990098,
+ 0.4788014936236403,
+ 0.8715671104783401,
+ 0.954632732759503,
+ ],
+ [
+ 0.9999835837292402,
+ 0.999921323618806,
+ 0.9996389455307461,
+ 0.7222961578407286,
+ 0.7186158832946955,
+ 0.7146983167265393,
+ 0.7105768254632475,
+ 0.3648911004360315,
+ 0.8575943501305144,
+ 0.9514642802768379,
+ ],
+ [
+ 0.9999825081019064,
+ 0.999916683268467,
+ 0.9996093996776352,
+ 0.6713490686473397,
+ 0.6664914636518112,
+ 0.6613110504728309,
+ 0.6558325489984669,
+ 0.247299682539502,
+ 0.8473037957967624,
+ 0.9493580587294981,
+ ],
+ [
+ 0.999981856118739,
+ 0.9999138938063622,
+ 0.9995907248497593,
+ 0.6331535096751639,
+ 0.6271637176135582,
+ 0.6206687804556549,
+ 0.6136262027168252,
+ 0.12576864809108962,
+ 0.8407892431959736,
+ 0.9481472656653798,
+ ],
+ [
+ 0.9999816006081851,
+ 0.9999127861527936,
+ 0.9995832399159849,
+ 0.6133274396648696,
+ 0.6086364734302403,
+ 0.6034602717119345,
+ 0.5978473214165134,
+ 0.0,
+ 0.8382338778894218,
+ 0.9477082231321966,
+ ],
+ ],
+ dtype=np.float32,
)
SINK_MASK_OUTPUT = np.array(
[
+ [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+ [0.0, 0.0, 0.0, 0.9047934405899283, 0.9936046284605553, 0.9448690902377527, 0.0, 0.0, 0.0, 0.8363773255131761],
+ [0.0, 0.0, 0.0, 0.90375200446097, 0.9434594475474036, 0.4716831449516178, 0.0, 0.0, 0.0, 0.7364197333910302],
+ [
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.09080438801405301,
+ 0.06774182873204163,
+ 0.038207095016625024,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.6745641479264269,
+ ],
+ [
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.01731082802870267,
+ 0.013540929458217351,
+ 0.007321202161532623,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.6341231654271253,
+ ],
+ [
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0006444251665178544,
+ 0.0005397129128756325,
+ 0.0003048384803626333,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.6070178708536365,
+ ],
+ [
+ 0.0,
+ 0.0,
+ 0.0,
+ 5.406078586212675e-05,
+ 4.416783924970537e-05,
+ 2.4597362039020103e-05,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.5889413683184284,
+ ],
[
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- 1.00000000e00,
- ],
- [
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 2.86416400e-01,
- 7.93271181e-01,
- 5.81341234e-01,
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 1.98395623e-01,
- ],
- [
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 2.66733297e-01,
- 2.80741490e-01,
- 4.14078784e-02,
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 7.91676486e-04,
- ],
- [
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 1.86244537e-04,
- 1.53413401e-04,
- 7.85806495e-05,
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 5.09797387e-06,
- ],
- [
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 9.62904581e-07,
- 7.23946225e-07,
- 3.68824440e-07,
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 4.79525316e-08,
- ],
- [
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 1.50939343e-10,
- 1.17724874e-10,
- 6.21760843e-11,
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 6.08922784e-10,
- ],
- [
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 2.57593754e-13,
- 1.94066716e-13,
- 9.83784370e-14,
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 9.80828665e-12,
- ],
- [
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 4.22323494e-16,
- 3.17556633e-16,
- 1.60789400e-16,
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 1.90789819e-13,
- ],
- [
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 7.72677888e-19,
- 5.83029424e-19,
- 2.95946659e-19,
- 0.00000000e00,
- 0.00000000e00,
- 0.00000000e00,
- 4.97038275e-15,
- ],
- [
- 2.71345908e-24,
- 5.92006757e-24,
- 2.25580089e-23,
- 3.82601970e-18,
- 3.82835349e-18,
- 3.83302158e-18,
- 3.84002606e-18,
- 8.40760586e-16,
- 1.83433696e-15,
- 1.11629633e-15,
- ],
- ]
+ 0.0,
+ 0.0,
+ 0.0,
+ 4.39259327223233e-06,
+ 3.6050656774754658e-06,
+ 2.0127120155893425e-06,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.5774279920364456,
+ ],
+ [
+ 0.0,
+ 0.0,
+ 0.0,
+ 4.0740501726718113e-07,
+ 3.374875487404489e-07,
+ 1.9113630985667455e-07,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.5709897726747111,
+ ],
+ [
+ 3.2266922388030425e-17,
+ 1.801110982679718e-14,
+ 9.325899448306927e-12,
+ 3.913608442133728e-07,
+ 3.9581822403393465e-07,
+ 4.02383505118481e-07,
+ 4.14820241328287e-07,
+ 4.281640797396309e-06,
+ 0.0023900192231620593,
+ 0.5686882523793125,
+ ],
+ ],
+ dtype=np.float32,
)
+@unittest.skipUnless(has_scipy, "Requires scipy")
class TestUltrasoundConfidenceMapTransform(unittest.TestCase):
def setUp(self):
@@ -526,6 +495,21 @@ def setUp(self):
self.input_img_torch = torch.from_numpy(TEST_INPUT).unsqueeze(0) # mock image (torch tensor)
self.input_mask_torch = torch.from_numpy(TEST_MASK).unsqueeze(0) # mock mask (torch tensor)
+ self.real_input_img_paths = [
+ os.path.join(os.path.dirname(__file__), "testing_data", "ultrasound_confidence_map", "neck_input.png"),
+ os.path.join(os.path.dirname(__file__), "testing_data", "ultrasound_confidence_map", "femur_input.png"),
+ ]
+
+ self.real_result_npy_paths = [
+ os.path.join(os.path.dirname(__file__), "testing_data", "ultrasound_confidence_map", "neck_result.npy"),
+ os.path.join(os.path.dirname(__file__), "testing_data", "ultrasound_confidence_map", "femur_result.npy"),
+ ]
+
+        self.real_input_parameters = [
+ {"alpha": 2.0, "beta": 90, "gamma": 0.03},
+ {"alpha": 2.0, "beta": 90, "gamma": 0.06},
+ ]
+
def test_parameters(self):
# Unknown mode
with self.assertRaises(ValueError):
@@ -535,162 +519,92 @@ def test_parameters(self):
with self.assertRaises(ValueError):
UltrasoundConfidenceMapTransform(sink_mode="unknown")
- def test_rgb(self):
+ @parameterized.expand(
+ [("all", SINK_ALL_OUTPUT), ("mid", SINK_MID_OUTPUT), ("min", SINK_MIN_OUTPUT), ("mask", SINK_MASK_OUTPUT, True)]
+ )
+ def test_ultrasound_confidence_map_transform(self, sink_mode, expected_output, use_mask=False):
# RGB image
input_img_rgb = np.expand_dims(np.repeat(self.input_img_np, 3, axis=0), axis=0)
input_img_rgb_torch = torch.from_numpy(input_img_rgb)
- transform = UltrasoundConfidenceMapTransform(sink_mode="all")
- result_torch = transform(input_img_rgb_torch)
- self.assertIsInstance(result_torch, torch.Tensor)
- assert_allclose(result_torch, torch.tensor(SINK_ALL_OUTPUT), rtol=1e-4, atol=1e-4)
- result_np = transform(input_img_rgb)
- self.assertIsInstance(result_np, np.ndarray)
- assert_allclose(result_np, SINK_ALL_OUTPUT, rtol=1e-4, atol=1e-4)
-
- transform = UltrasoundConfidenceMapTransform(sink_mode="mid")
- result_torch = transform(input_img_rgb_torch)
- self.assertIsInstance(result_torch, torch.Tensor)
- assert_allclose(result_torch, torch.tensor(SINK_MID_OUTPUT), rtol=1e-4, atol=1e-4)
- result_np = transform(input_img_rgb)
- self.assertIsInstance(result_np, np.ndarray)
- assert_allclose(result_np, SINK_MID_OUTPUT, rtol=1e-4, atol=1e-4)
+ transform = UltrasoundConfidenceMapTransform(sink_mode=sink_mode)
- transform = UltrasoundConfidenceMapTransform(sink_mode="min")
- result_torch = transform(input_img_rgb_torch)
- self.assertIsInstance(result_torch, torch.Tensor)
- assert_allclose(result_torch, torch.tensor(SINK_MIN_OUTPUT), rtol=1e-4, atol=1e-4)
- result_np = transform(input_img_rgb)
- self.assertIsInstance(result_np, np.ndarray)
- assert_allclose(result_np, SINK_MIN_OUTPUT, rtol=1e-4, atol=1e-4)
+ if use_mask:
+ result_torch = transform(input_img_rgb_torch, self.input_mask_torch)
+ result_np = transform(input_img_rgb, self.input_mask_np)
+ else:
+ result_torch = transform(input_img_rgb_torch)
+ result_np = transform(input_img_rgb)
- transform = UltrasoundConfidenceMapTransform(sink_mode="mask")
- result_torch = transform(input_img_rgb_torch, self.input_mask_torch)
self.assertIsInstance(result_torch, torch.Tensor)
- assert_allclose(result_torch, torch.tensor(SINK_MASK_OUTPUT), rtol=1e-4, atol=1e-4)
- result_np = transform(input_img_rgb, self.input_mask_np)
+ assert_allclose(result_torch, torch.tensor(expected_output), rtol=1e-4, atol=1e-4)
self.assertIsInstance(result_np, np.ndarray)
- assert_allclose(result_np, SINK_MASK_OUTPUT, rtol=1e-4, atol=1e-4)
+ assert_allclose(result_np, expected_output, rtol=1e-4, atol=1e-4)
- def test_multi_channel_2d(self):
- # 2D multi-channel image
+ @parameterized.expand(
+ [
+ ("all", SINK_ALL_OUTPUT),
+ ("mid", SINK_MID_OUTPUT),
+ ("min", SINK_MIN_OUTPUT),
+ ("mask", SINK_MASK_OUTPUT, True), # Adding a flag for mask cases
+ ]
+ )
+ def test_multi_channel_2d(self, sink_mode, expected_output, use_mask=False):
input_img_rgb = np.expand_dims(np.repeat(self.input_img_np, 17, axis=0), axis=0)
input_img_rgb_torch = torch.from_numpy(input_img_rgb)
- transform = UltrasoundConfidenceMapTransform(sink_mode="all")
- result_torch = transform(input_img_rgb_torch)
- self.assertIsInstance(result_torch, torch.Tensor)
- assert_allclose(result_torch, torch.tensor(SINK_ALL_OUTPUT), rtol=1e-4, atol=1e-4)
- result_np = transform(input_img_rgb)
- self.assertIsInstance(result_np, np.ndarray)
- assert_allclose(result_np, SINK_ALL_OUTPUT, rtol=1e-4, atol=1e-4)
-
- transform = UltrasoundConfidenceMapTransform(sink_mode="mid")
- result_torch = transform(input_img_rgb_torch)
- self.assertIsInstance(result_torch, torch.Tensor)
- assert_allclose(result_torch, torch.tensor(SINK_MID_OUTPUT), rtol=1e-4, atol=1e-4)
- result_np = transform(input_img_rgb)
- self.assertIsInstance(result_np, np.ndarray)
- assert_allclose(result_np, SINK_MID_OUTPUT, rtol=1e-4, atol=1e-4)
+ transform = UltrasoundConfidenceMapTransform(sink_mode=sink_mode)
- transform = UltrasoundConfidenceMapTransform(sink_mode="min")
- result_torch = transform(input_img_rgb_torch)
- self.assertIsInstance(result_torch, torch.Tensor)
- assert_allclose(result_torch, torch.tensor(SINK_MIN_OUTPUT), rtol=1e-4, atol=1e-4)
- result_np = transform(input_img_rgb)
- self.assertIsInstance(result_np, np.ndarray)
- assert_allclose(result_np, SINK_MIN_OUTPUT, rtol=1e-4, atol=1e-4)
+ if use_mask:
+ result_torch = transform(input_img_rgb_torch, self.input_mask_torch)
+ result_np = transform(input_img_rgb, self.input_mask_np)
+ else:
+ result_torch = transform(input_img_rgb_torch)
+ result_np = transform(input_img_rgb)
- transform = UltrasoundConfidenceMapTransform(sink_mode="mask")
- result_torch = transform(input_img_rgb_torch, self.input_mask_torch)
self.assertIsInstance(result_torch, torch.Tensor)
- assert_allclose(result_torch, torch.tensor(SINK_MASK_OUTPUT), rtol=1e-4, atol=1e-4)
- result_np = transform(input_img_rgb, self.input_mask_np)
+ assert_allclose(result_torch, torch.tensor(expected_output), rtol=1e-4, atol=1e-4)
self.assertIsInstance(result_np, np.ndarray)
- assert_allclose(result_np, SINK_MASK_OUTPUT, rtol=1e-4, atol=1e-4)
+ assert_allclose(result_np, expected_output, rtol=1e-4, atol=1e-4)
- def test_non_one_first_dim(self):
- # Image without first dimension as 1
+ @parameterized.expand([("all",), ("mid",), ("min",), ("mask",)])
+ def test_non_one_first_dim(self, sink_mode):
+ transform = UltrasoundConfidenceMapTransform(sink_mode=sink_mode)
input_img_rgb = np.repeat(self.input_img_np, 3, axis=0)
input_img_rgb_torch = torch.from_numpy(input_img_rgb)
- transform = UltrasoundConfidenceMapTransform(sink_mode="all")
- with self.assertRaises(ValueError):
- transform(input_img_rgb_torch)
- with self.assertRaises(ValueError):
- transform(input_img_rgb)
-
- transform = UltrasoundConfidenceMapTransform(sink_mode="mid")
- with self.assertRaises(ValueError):
- transform(input_img_rgb_torch)
- with self.assertRaises(ValueError):
- transform(input_img_rgb)
-
- transform = UltrasoundConfidenceMapTransform(sink_mode="min")
- with self.assertRaises(ValueError):
- transform(input_img_rgb_torch)
- with self.assertRaises(ValueError):
- transform(input_img_rgb)
-
- transform = UltrasoundConfidenceMapTransform(sink_mode="mask")
- with self.assertRaises(ValueError):
- transform(input_img_rgb_torch, self.input_mask_torch)
- with self.assertRaises(ValueError):
- transform(input_img_rgb, self.input_mask_np)
-
- def test_no_first_dim(self):
- # Image without first dimension
+ if sink_mode == "mask":
+ with self.assertRaises(ValueError):
+ transform(input_img_rgb_torch, self.input_mask_torch)
+ with self.assertRaises(ValueError):
+ transform(input_img_rgb, self.input_mask_np)
+ else:
+ with self.assertRaises(ValueError):
+ transform(input_img_rgb_torch)
+ with self.assertRaises(ValueError):
+ transform(input_img_rgb)
+
+ @parameterized.expand([("all",), ("mid",), ("min",), ("mask",)])
+ def test_no_first_dim(self, sink_mode):
input_img_rgb = self.input_img_np[0]
input_img_rgb_torch = torch.from_numpy(input_img_rgb)
- transform = UltrasoundConfidenceMapTransform(sink_mode="all")
- with self.assertRaises(ValueError):
- transform(input_img_rgb_torch)
- with self.assertRaises(ValueError):
- transform(input_img_rgb)
-
- transform = UltrasoundConfidenceMapTransform(sink_mode="mid")
- with self.assertRaises(ValueError):
- transform(input_img_rgb_torch)
- with self.assertRaises(ValueError):
- transform(input_img_rgb)
+ transform = UltrasoundConfidenceMapTransform(sink_mode=sink_mode)
- transform = UltrasoundConfidenceMapTransform(sink_mode="min")
with self.assertRaises(ValueError):
transform(input_img_rgb_torch)
with self.assertRaises(ValueError):
transform(input_img_rgb)
- transform = UltrasoundConfidenceMapTransform(sink_mode="mask")
- with self.assertRaises(ValueError):
- transform(input_img_rgb_torch, self.input_mask_torch)
- with self.assertRaises(ValueError):
- transform(input_img_rgb, self.input_mask_np)
+ if sink_mode == "mask":
+ with self.assertRaises(ValueError):
+ transform(input_img_rgb_torch, self.input_mask_torch)
+ with self.assertRaises(ValueError):
+ transform(input_img_rgb, self.input_mask_np)
- def test_sink_all(self):
- transform = UltrasoundConfidenceMapTransform(sink_mode="all")
-
- # This should not raise an exception for torch tensor
- result_torch = transform(self.input_img_torch)
- self.assertIsInstance(result_torch, torch.Tensor)
-
- # This should not raise an exception for numpy array
- result_np = transform(self.input_img_np)
- self.assertIsInstance(result_np, np.ndarray)
-
- def test_sink_mid(self):
- transform = UltrasoundConfidenceMapTransform(sink_mode="mid")
-
- # This should not raise an exception for torch tensor
- result_torch = transform(self.input_img_torch)
- self.assertIsInstance(result_torch, torch.Tensor)
-
- # This should not raise an exception for numpy array
- result_np = transform(self.input_img_np)
- self.assertIsInstance(result_np, np.ndarray)
-
- def test_sink_min(self):
- transform = UltrasoundConfidenceMapTransform(sink_mode="min")
+ @parameterized.expand([("all",), ("mid",), ("min",)])
+ def test_sink_mode(self, mode):
+ transform = UltrasoundConfidenceMapTransform(sink_mode=mode)
# This should not raise an exception for torch tensor
result_torch = transform(self.input_img_torch)
@@ -752,6 +666,44 @@ def test_func(self):
output = transform(self.input_img_torch, self.input_mask_torch)
assert_allclose(output, torch.tensor(SINK_MASK_OUTPUT), rtol=1e-4, atol=1e-4)
+ def test_against_official_code(self):
+ # This test is to compare the output of the transform with the official code
+ # The official code is available at:
+ # https://campar.in.tum.de/Main/AthanasiosKaramalisCode
+
+ for input_img_path, result_npy_path, params in zip(
+            self.real_input_img_paths, self.real_result_npy_paths, self.real_input_parameters
+ ):
+ input_img = np.array(Image.open(input_img_path))
+ input_img = np.expand_dims(input_img, axis=0)
+
+ result_img = np.load(result_npy_path)
+
+ transform = UltrasoundConfidenceMapTransform(sink_mode="all", **params)
+ output = transform(input_img)
+
+ assert_allclose(output, result_img, rtol=1e-4, atol=1e-4)
+
+ def test_against_official_code_using_cg(self):
+ # This test is to compare the output of the transform with the official code
+ # The official code is available at:
+ # https://campar.in.tum.de/Main/AthanasiosKaramalisCode
+
+ for input_img_path, result_npy_path, params in zip(
+            self.real_input_img_paths, self.real_result_npy_paths, self.real_input_parameters
+ ):
+ input_img = np.array(Image.open(input_img_path))
+ input_img = np.expand_dims(input_img, axis=0)
+
+ result_img = np.load(result_npy_path)
+
+ transform = UltrasoundConfidenceMapTransform(
+ sink_mode="all", use_cg=True, cg_tol=1.0e-6, cg_maxiter=300, **params
+ )
+ output = transform(input_img)
+
+ assert_allclose(output, result_img, rtol=1e-2, atol=1e-2)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_unetr.py b/tests/test_unetr.py
index 46018d2bc0..8c5ecb32e1 100644
--- a/tests/test_unetr.py
+++ b/tests/test_unetr.py
@@ -30,7 +30,7 @@
for num_heads in [8]:
for mlp_dim in [3072]:
for norm_name in ["instance"]:
- for pos_embed in ["perceptron"]:
+ for proj_type in ["perceptron"]:
for nd in (2, 3):
test_case = [
{
@@ -42,7 +42,7 @@
"norm_name": norm_name,
"mlp_dim": mlp_dim,
"num_heads": num_heads,
- "pos_embed": pos_embed,
+ "proj_type": proj_type,
"dropout_rate": dropout_rate,
"conv_block": True,
"res_block": False,
@@ -75,7 +75,7 @@ def test_ill_arg(self):
hidden_size=128,
mlp_dim=3072,
num_heads=12,
- pos_embed="conv",
+ proj_type="conv",
norm_name="instance",
dropout_rate=5.0,
)
@@ -89,7 +89,7 @@ def test_ill_arg(self):
hidden_size=512,
mlp_dim=3072,
num_heads=12,
- pos_embed="conv",
+ proj_type="conv",
norm_name="instance",
dropout_rate=0.5,
)
@@ -103,7 +103,7 @@ def test_ill_arg(self):
hidden_size=512,
mlp_dim=3072,
num_heads=14,
- pos_embed="conv",
+ proj_type="conv",
norm_name="batch",
dropout_rate=0.4,
)
@@ -117,13 +117,13 @@ def test_ill_arg(self):
hidden_size=768,
mlp_dim=3072,
num_heads=12,
- pos_embed="perc",
+ proj_type="perc",
norm_name="instance",
dropout_rate=0.2,
)
@parameterized.expand(TEST_CASE_UNETR)
- @SkipIfBeforePyTorchVersion((1, 9))
+ @SkipIfBeforePyTorchVersion((2, 0))
def test_script(self, input_param, input_shape, _):
net = UNETR(**(input_param))
net.eval()
diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py
index 6e655289e4..90c0401e46 100644
--- a/tests/test_utils_pytorch_numpy_unification.py
+++ b/tests/test_utils_pytorch_numpy_unification.py
@@ -17,7 +17,7 @@
import torch
from parameterized import parameterized
-from monai.transforms.utils_pytorch_numpy_unification import mode, percentile
+from monai.transforms.utils_pytorch_numpy_unification import max, min, mode, percentile
from monai.utils import set_determinism
from tests.utils import TEST_NDARRAYS, assert_allclose, skip_if_quick
@@ -27,6 +27,13 @@
TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False])
TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True])
+TEST_MIN_MAX = []
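+# each case: input array, kwargs for the call, the unified min/max function under test, and the expected result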
+for p in TEST_NDARRAYS:
+ TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, min, p(1)])
+ TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, min, p([3.1, 3])])
+ TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, max, p(5)])
+ TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, max, p([5.1, 5])])
+
class TestPytorchNumpyUnification(unittest.TestCase):
@@ -74,6 +81,11 @@ def test_mode(self, array, expected, to_long):
res = mode(array, to_long=to_long)
assert_allclose(res, expected)
+ @parameterized.expand(TEST_MIN_MAX)
+ def test_min_max(self, array, input_params, func, expected):
+ res = func(array, **input_params)
+ assert_allclose(res, expected, type_test=False)
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_vector_quantizer.py b/tests/test_vector_quantizer.py
new file mode 100644
index 0000000000..43533d0377
--- /dev/null
+++ b/tests/test_vector_quantizer.py
@@ -0,0 +1,89 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from math import prod
+
+import torch
+from parameterized import parameterized
+
+from monai.networks.layers import EMAQuantizer, VectorQuantizer
+
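+# each case: quantizer kwargs, input shape (B, C, *spatial), expected shape of the quantization indices (B, *spatial)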
+TEST_CASES = [
+ [{"spatial_dims": 2, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4), (1, 4, 4)],
+ [{"spatial_dims": 3, "num_embeddings": 16, "embedding_dim": 8}, (1, 8, 4, 4, 4), (1, 4, 4, 4)],
+]
+
+
+class TestEMA(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_ema_shape(self, input_param, input_shape, output_shape):
+ layer = EMAQuantizer(**input_param)
+ x = torch.randn(input_shape)
+ layer = layer.train()
+ outputs = layer(x)
+ self.assertEqual(outputs[0].shape, input_shape)
+ self.assertEqual(outputs[2].shape, output_shape)
+
+ layer = layer.eval()
+ outputs = layer(x)
+ self.assertEqual(outputs[0].shape, input_shape)
+ self.assertEqual(outputs[2].shape, output_shape)
+
+ @parameterized.expand(TEST_CASES)
+ def test_ema_quantize(self, input_param, input_shape, output_shape):
+ layer = EMAQuantizer(**input_param)
+ x = torch.randn(input_shape)
+ outputs = layer.quantize(x)
+ self.assertEqual(outputs[0].shape, (prod(input_shape[2:]), input_shape[1])) # (HxW[xD], C)
+ self.assertEqual(outputs[1].shape, (prod(input_shape[2:]), input_param["num_embeddings"])) # (HxW[xD], E)
+ self.assertEqual(outputs[2].shape, (input_shape[0],) + input_shape[2:]) # (1, H, W, [D])
+
+ def test_ema(self):
+ layer = EMAQuantizer(spatial_dims=2, num_embeddings=2, embedding_dim=2, epsilon=0, decay=0)
+ original_weight_0 = layer.embedding.weight[0].clone()
+ original_weight_1 = layer.embedding.weight[1].clone()
+ x_0 = original_weight_0
+ x_0 = x_0.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+ x_0 = x_0.repeat(1, 1, 1, 2) + 0.001
+
+ x_1 = original_weight_1
+ x_1 = x_1.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+ x_1 = x_1.repeat(1, 1, 1, 2)
+
+ x = torch.cat([x_0, x_1], dim=0)
+ layer = layer.train()
+ _ = layer(x)
+
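+        # with decay=0 the EMA statistics are replaced by the current batch: entry 0 sees perturbed inputs (+0.001) and moves,
+        # while entry 1 sees an exact match of its own weight and stays unchanged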
+ self.assertTrue(all(layer.embedding.weight[0] != original_weight_0))
+ self.assertTrue(all(layer.embedding.weight[1] == original_weight_1))
+
+
+class TestVectorQuantizer(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_vector_quantizer_shape(self, input_param, input_shape, output_shape):
+ layer = VectorQuantizer(EMAQuantizer(**input_param))
+ x = torch.randn(input_shape)
+ outputs = layer(x)
+ self.assertEqual(outputs[1].shape, input_shape)
+
+ @parameterized.expand(TEST_CASES)
+ def test_vector_quantizer_quantize(self, input_param, input_shape, output_shape):
+ layer = VectorQuantizer(EMAQuantizer(**input_param))
+ x = torch.randn(input_shape)
+ outputs = layer.quantize(x)
+ self.assertEqual(outputs.shape, output_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_pytorch_version_after.py b/tests/test_version_after.py
similarity index 72%
rename from tests/test_pytorch_version_after.py
rename to tests/test_version_after.py
index 147707d2c0..b6cb741382 100644
--- a/tests/test_pytorch_version_after.py
+++ b/tests/test_version_after.py
@@ -15,9 +15,9 @@
from parameterized import parameterized
-from monai.utils import pytorch_after
+from monai.utils import compute_capabilities_after, pytorch_after
-TEST_CASES = (
+TEST_CASES_PT = (
(1, 5, 9, "1.6.0"),
(1, 6, 0, "1.6.0"),
(1, 6, 1, "1.6.0", False),
@@ -36,14 +36,30 @@
(1, 6, 1, "1.6.0+cpu", False),
)
+TEST_CASES_SM = [
+ # (major, minor, sm, expected)
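+    # expected is True when the "sm" compute capability string is at least major.minor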
+ (6, 1, "6.1", True),
+ (6, 1, "6.0", False),
+ (6, 0, "8.6", True),
+ (7, 0, "8", True),
+ (8, 6, "8", False),
+]
+
class TestPytorchVersionCompare(unittest.TestCase):
- @parameterized.expand(TEST_CASES)
+ @parameterized.expand(TEST_CASES_PT)
def test_compare(self, a, b, p, current, expected=True):
"""Test pytorch_after with a and b"""
self.assertEqual(pytorch_after(a, b, p, current), expected)
+class TestComputeCapabilitiesAfter(unittest.TestCase):
+
+ @parameterized.expand(TEST_CASES_SM)
+ def test_compute_capabilities_after(self, major, minor, sm, expected):
+ self.assertEqual(compute_capabilities_after(major, minor, sm), expected)
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/tests/test_vis_cam.py b/tests/test_vis_cam.py
index b641599af2..68b12de2f8 100644
--- a/tests/test_vis_cam.py
+++ b/tests/test_vis_cam.py
@@ -70,6 +70,8 @@ class TestClassActivationMap(unittest.TestCase):
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
def test_shape(self, input_data, expected_shape):
+ model = None
+
if input_data["model"] == "densenet2d":
model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
if input_data["model"] == "densenet3d":
@@ -80,6 +82,7 @@ def test_shape(self, input_data, expected_shape):
model = SEResNet50(spatial_dims=2, in_channels=3, num_classes=4)
if input_data["model"] == "senet3d":
model = SEResNet50(spatial_dims=3, in_channels=3, num_classes=4)
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()
diff --git a/tests/test_vis_gradcam.py b/tests/test_vis_gradcam.py
index 325b74b3ce..f77d916a5b 100644
--- a/tests/test_vis_gradcam.py
+++ b/tests/test_vis_gradcam.py
@@ -153,6 +153,8 @@ class TestGradientClassActivationMap(unittest.TestCase):
@parameterized.expand(TESTS)
def test_shape(self, cam_class, input_data, expected_shape):
+ model = None
+
if input_data["model"] == "densenet2d":
model = DenseNet121(spatial_dims=2, in_channels=1, out_channels=3)
elif input_data["model"] == "densenet2d_bin":
diff --git a/tests/test_vista3d.py b/tests/test_vista3d.py
new file mode 100644
index 0000000000..d3b4e0c10e
--- /dev/null
+++ b/tests/test_vista3d.py
@@ -0,0 +1,85 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.nets import VISTA3D, SegResNetDS2
+from monai.networks.nets.vista3d import ClassMappingClassify, PointMappingSAM
+from tests.utils import SkipIfBeforePyTorchVersion, skip_if_quick
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
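+# each case: encoder kwargs, prompt kwargs, input image shape, expected output shape (leading dim follows the number of class/point prompts)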
+TEST_CASES = [
+ [{"encoder_embed_dim": 48, "in_channels": 1}, {}, (1, 1, 64, 64, 64), (1, 1, 64, 64, 64)],
+ [{"encoder_embed_dim": 48, "in_channels": 2}, {}, (1, 2, 64, 64, 64), (1, 1, 64, 64, 64)],
+ [
+ {"encoder_embed_dim": 48, "in_channels": 1},
+ {"class_vector": torch.tensor([1, 2, 3], device=device)},
+ (1, 1, 64, 64, 64),
+ (3, 1, 64, 64, 64),
+ ],
+ [
+ {"encoder_embed_dim": 48, "in_channels": 1},
+ {
+ "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]]], device=device),
+ "point_labels": torch.tensor([[1, 0]], device=device),
+ },
+ (1, 1, 64, 64, 64),
+ (1, 1, 64, 64, 64),
+ ],
+ [
+ {"encoder_embed_dim": 48, "in_channels": 1},
+ {
+ "class_vector": torch.tensor([1, 2], device=device),
+ "point_coords": torch.tensor([[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]], device=device),
+ "point_labels": torch.tensor([[1, 0], [1, 0]], device=device),
+ },
+ (1, 1, 64, 64, 64),
+ (2, 1, 64, 64, 64),
+ ],
+]
+
+
+@SkipIfBeforePyTorchVersion((1, 11))
+@skip_if_quick
+class TestVista3d(unittest.TestCase):
+
+ @parameterized.expand(TEST_CASES)
+ def test_vista3d_shape(self, args, input_params, input_shape, expected_shape):
+ segresnet = SegResNetDS2(
+ in_channels=args["in_channels"],
+ blocks_down=(1, 2, 2, 4, 4),
+ norm="instance",
+ out_channels=args["encoder_embed_dim"],
+ init_filters=args["encoder_embed_dim"],
+ dsdepth=1,
+ )
+ point_head = PointMappingSAM(feature_size=args["encoder_embed_dim"], n_classes=512, last_supported=132)
+ class_head = ClassMappingClassify(n_classes=512, feature_size=args["encoder_embed_dim"], use_mlp=True)
+ net = VISTA3D(image_encoder=segresnet, class_head=class_head, point_head=point_head).to(device)
+ with eval_mode(net):
+ result = net.forward(
+ torch.randn(input_shape).to(device),
+ point_coords=input_params.get("point_coords", None),
+ point_labels=input_params.get("point_labels", None),
+ class_vector=input_params.get("class_vector", None),
+ )
+ self.assertEqual(result.shape, expected_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_vista3d_sampler.py b/tests/test_vista3d_sampler.py
new file mode 100644
index 0000000000..6945d250d2
--- /dev/null
+++ b/tests/test_vista3d_sampler.py
@@ -0,0 +1,100 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.apps.vista3d.sampler import sample_prompt_pairs
+
+label = torch.zeros([1, 1, 64, 64, 64])
+label[:, :, :10, :10, :10] = 1
+label[:, :, 20:30, 20:30, 20:30] = 2
+label[:, :, 30:40, 30:40, 30:40] = 3
+label1 = torch.zeros([1, 1, 64, 64, 64])
+
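+# each case: sampler kwargs and the expected leading dimension of each of the four returned prompt tensors (None when that output is dropped)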
+TEST_VISTA_SAMPLE_PROMPT = [
+ [
+ {
+ "labels": label,
+ "label_set": [0, 1, 2, 3, 4],
+ "max_prompt": 5,
+ "max_foreprompt": 4,
+ "max_backprompt": 1,
+ "drop_label_prob": 0,
+ "drop_point_prob": 0,
+ },
+ [4, 4, 4, 4],
+ ],
+ [
+ {
+ "labels": label,
+ "label_set": [0, 1],
+ "max_prompt": 5,
+ "max_foreprompt": 4,
+ "max_backprompt": 1,
+ "drop_label_prob": 0,
+ "drop_point_prob": 1,
+ },
+ [2, None, None, 2],
+ ],
+ [
+ {
+ "labels": label,
+ "label_set": [0, 1, 2, 3, 4],
+ "max_prompt": 5,
+ "max_foreprompt": 4,
+ "max_backprompt": 1,
+ "drop_label_prob": 1,
+ "drop_point_prob": 0,
+ },
+ [None, 3, 3, 3],
+ ],
+ [
+ {
+ "labels": label1,
+ "label_set": [0, 1],
+ "max_prompt": 5,
+ "max_foreprompt": 4,
+ "max_backprompt": 1,
+ "drop_label_prob": 0,
+ "drop_point_prob": 1,
+ },
+ [1, None, None, 1],
+ ],
+ [
+ {
+ "labels": label1,
+ "label_set": [0, 1],
+ "max_prompt": 5,
+ "max_foreprompt": 4,
+ "max_backprompt": 0,
+ "drop_label_prob": 0,
+ "drop_point_prob": 1,
+ },
+ [None, None, None, None],
+ ],
+]
+
+
+class TestGeneratePrompt(unittest.TestCase):
+ @parameterized.expand(TEST_VISTA_SAMPLE_PROMPT)
+ def test_result(self, input_data, expected):
+ output = sample_prompt_pairs(**input_data)
+ result = [i.shape[0] if i is not None else None for i in output]
+ self.assertEqual(result, expected)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_vista3d_transforms.py b/tests/test_vista3d_transforms.py
new file mode 100644
index 0000000000..9d61fe2fc2
--- /dev/null
+++ b/tests/test_vista3d_transforms.py
@@ -0,0 +1,94 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from unittest.case import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.apps.vista3d.transforms import VistaPostTransformd, VistaPreTransformd
+from monai.utils import min_version
+from monai.utils.module import optional_import
+
+measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version)
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
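+# each case: (input prompt dict, expected dict after VistaPreTransformd) -- point labels of the special_index classes are offset by 2
+# and label prompts listed in the subclass mapping are expanded into their subclass ids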
+TEST_VISTA_PRETRANSFORM = [
+ [
+ {"label_prompt": [1], "points": [[0, 0, 0]], "point_labels": [1]},
+ {"label_prompt": [1], "points": [[0, 0, 0]], "point_labels": [3]},
+ ],
+ [
+ {"label_prompt": [2], "points": [[0, 0, 0]], "point_labels": [0]},
+ {"label_prompt": [2], "points": [[0, 0, 0]], "point_labels": [2]},
+ ],
+ [
+ {"label_prompt": [3], "points": [[0, 0, 0]], "point_labels": [0]},
+ {"label_prompt": [4, 5], "points": [[0, 0, 0]], "point_labels": [0]},
+ ],
+ [
+ {"label_prompt": [6], "points": [[0, 0, 0]], "point_labels": [0]},
+ {"label_prompt": [7, 8], "points": [[0, 0, 0]], "point_labels": [0]},
+ ],
+]
+
+
+pred1 = torch.zeros([2, 64, 64, 64])
+pred1[0, :10, :10, :10] = 1
+pred1[1, 20:30, 20:30, 20:30] = 1
+output1 = torch.zeros([1, 64, 64, 64])
+output1[:, :10, :10, :10] = 2
+output1[:, 20:30, 20:30, 20:30] = 3
+
+# -1 is needed since pred should be before sigmoid.
+pred2 = torch.zeros([1, 64, 64, 64]) - 1
+pred2[:, :10, :10, :10] = 1
+pred2[:, 20:30, 20:30, 20:30] = 1
+output2 = torch.zeros([1, 64, 64, 64])
+output2[:, 20:30, 20:30, 20:30] = 1
+
+TEST_VISTA_POSTTRANSFORM = [
+ [{"pred": pred1.to(device), "label_prompt": torch.tensor([2, 3]).to(device)}, output1.to(device)],
+ [
+ {
+ "pred": pred2.to(device),
+ "points": torch.tensor([[25, 25, 25]]).to(device),
+ "point_labels": torch.tensor([1]).to(device),
+ },
+ output2.to(device),
+ ],
+]
+
+
+class TestVistaPreTransformd(unittest.TestCase):
+ @parameterized.expand(TEST_VISTA_PRETRANSFORM)
+ def test_result(self, input_data, expected):
+ transform = VistaPreTransformd(keys="image", subclass={"3": [4, 5], "6": [7, 8]}, special_index=[1, 2])
+ result = transform(input_data)
+ self.assertEqual(result, expected)
+
+
+@skipUnless(has_measure, "skimage.measure required")
+class TestVistaPostTransformd(unittest.TestCase):
+ @parameterized.expand(TEST_VISTA_POSTTRANSFORM)
+ def test_result(self, input_data, expected):
+ transform = VistaPostTransformd(keys="pred")
+ result = transform(input_data)
+ self.assertEqual((result["pred"] == expected).all(), True)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_vista3d_utils.py b/tests/test_vista3d_utils.py
new file mode 100644
index 0000000000..5a0caedd61
--- /dev/null
+++ b/tests/test_vista3d_utils.py
@@ -0,0 +1,162 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from unittest.case import skipUnless
+
+import numpy as np
+import torch
+from parameterized import parameterized
+
+from monai.transforms.utils import convert_points_to_disc, keep_merge_components_with_points, sample_points_from_label
+from monai.utils import min_version
+from monai.utils.module import optional_import
+from tests.utils import skip_if_no_cuda, skip_if_quick
+
+cp, has_cp = optional_import("cupy")
+cucim_skimage, has_cucim = optional_import("cucim.skimage")
+measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version)
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+
+TESTS_SAMPLE_POINTS_FROM_LABEL = []
+for use_center in [True, False]:
+ labels = torch.zeros(1, 1, 32, 32, 32)
+ labels[0, 0, 5:10, 5:10, 5:10] = 1
+ labels[0, 0, 10:15, 10:15, 10:15] = 3
+ labels[0, 0, 20:25, 20:25, 20:25] = 5
+ TESTS_SAMPLE_POINTS_FROM_LABEL.append(
+ [{"labels": labels, "label_set": (1, 3, 5), "use_center": use_center}, (3, 1, 3), (3, 1)]
+ )
+
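+# convert_points_to_disc is expected to return a (B, 2, *image_size) mask, with the channel picked by each point label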
+TEST_CONVERT_POINTS_TO_DISC = []
+for radius in [1, 2]:
+ for disc in [True, False]:
+ image_size = (32, 32, 32)
+ point = torch.randn(3, 1, 3)
+ point_label = torch.randint(0, 4, (3, 1))
+ expected_shape = (point.shape[0], 2, *image_size)
+ TEST_CONVERT_POINTS_TO_DISC.append(
+ [
+ {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc},
+ expected_shape,
+ ]
+ )
+ image_size = (16, 32, 64)
+ point = torch.tensor([[[8, 16, 42], [2, 8, 21]]])
+ point_label = torch.tensor([[1, 0]])
+ expected_shape = (point.shape[0], 2, *image_size)
+ TEST_CONVERT_POINTS_TO_DISC.append(
+ [
+ {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc},
+ expected_shape,
+ ]
+ )
+
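+# Value checks: the channel selected by point_label is expected to equal 1 at each point coordinate.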
+TEST_CONVERT_POINTS_TO_DISC_VALUE = []
+image_size = (16, 32, 64)
+point = torch.tensor([[[8, 16, 42], [2, 8, 21]]])
+point_label = torch.tensor([[1, 0]])
+expected_shape = (point.shape[0], 2, *image_size)
+for radius in [5, 10]:
+ for disc in [True, False]:
+ TEST_CONVERT_POINTS_TO_DISC_VALUE.append(
+ [
+ {"image_size": image_size, "point": point, "point_label": point_label, "radius": radius, "disc": disc},
+ [point, point_label],
+ ]
+ )
+
+
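+# Shape-only checks for keep_merge_components_with_points on torch (cupy/cucim) and numpy (skimage) inputs.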
+TEST_LCC_MASK_POINT_TORCH = []
+for bs in [1, 2]:
+ for num_points in [1, 3]:
+ shape = (bs, 1, 128, 32, 32)
+ TEST_LCC_MASK_POINT_TORCH.append(
+ [
+ {
+ "img_pos": torch.randint(0, 2, shape, dtype=torch.bool),
+ "img_neg": torch.randint(0, 2, shape, dtype=torch.bool),
+ "point_coords": torch.randint(0, 10, (bs, num_points, 3)),
+ "point_labels": torch.randint(0, 4, (bs, num_points)),
+ },
+ shape,
+ ]
+ )
+
+TEST_LCC_MASK_POINT_NP = []
+for bs in [1, 2]:
+ for num_points in [1, 3]:
+ shape = (bs, 1, 32, 32, 64)
+ TEST_LCC_MASK_POINT_NP.append(
+ [
+ {
+ "img_pos": np.random.randint(0, 2, shape, dtype=bool),
+ "img_neg": np.random.randint(0, 2, shape, dtype=bool),
+ "point_coords": np.random.randint(0, 5, (bs, num_points, 3)),
+ "point_labels": np.random.randint(0, 4, (bs, num_points)),
+ },
+ shape,
+ ]
+ )
+
+
+@skipUnless(has_measure or has_cucim, "skimage or cucim.skimage required")
+class TestSamplePointsFromLabel(unittest.TestCase):
+
+ @parameterized.expand(TESTS_SAMPLE_POINTS_FROM_LABEL)
+ def test_shape(self, input_data, expected_point_shape, expected_point_label_shape):
+ point, point_label = sample_points_from_label(**input_data)
+ self.assertEqual(point.shape, expected_point_shape)
+ self.assertEqual(point_label.shape, expected_point_label_shape)
+
+
+class TestConvertPointsToDisc(unittest.TestCase):
+
+ @parameterized.expand(TEST_CONVERT_POINTS_TO_DISC)
+ def test_shape(self, input_data, expected_shape):
+ result = convert_points_to_disc(**input_data)
+ self.assertEqual(result.shape, expected_shape)
+
+ @parameterized.expand(TEST_CONVERT_POINTS_TO_DISC_VALUE)
+ def test_value(self, input_data, points):
+ result = convert_points_to_disc(**input_data)
+ point, point_label = points
+ for i in range(point.shape[0]):
+ for j in range(point.shape[1]):
+ self.assertEqual(result[i, point_label[i, j], point[i, j][0], point[i, j][1], point[i, j][2]], True)
+
+
+@skipUnless(has_measure or has_cucim, "skimage or cucim.skimage required")
+class TestKeepMergeComponentsWithPoints(unittest.TestCase):
+
+ @skip_if_quick
+ @skip_if_no_cuda
+ @skipUnless(has_cp and cucim_skimage, "cupy and cucim.skimage required")
+ @parameterized.expand(TEST_LCC_MASK_POINT_TORCH)
+ def test_cp_shape(self, input_data, shape):
+ for key in input_data:
+ input_data[key] = input_data[key].to(device)
+ mask = keep_merge_components_with_points(**input_data)
+ self.assertEqual(mask.shape, shape)
+
+ @skipUnless(has_measure, "skimage required")
+ @parameterized.expand(TEST_LCC_MASK_POINT_NP)
+ def test_np_shape(self, input_data, shape):
+ mask = keep_merge_components_with_points(**input_data)
+ self.assertEqual(mask.shape, shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_vit.py b/tests/test_vit.py
index a84883cba0..a3ffd0b2ef 100644
--- a/tests/test_vit.py
+++ b/tests/test_vit.py
@@ -30,7 +30,7 @@
for mlp_dim in [3072]:
for num_layers in [4]:
for num_classes in [8]:
- for pos_embed in ["conv", "perceptron"]:
+ for proj_type in ["conv", "perceptron"]:
for classification in [False, True]:
for nd in (2, 3):
test_case = [
@@ -42,7 +42,7 @@
"mlp_dim": mlp_dim,
"num_layers": num_layers,
"num_heads": num_heads,
- "pos_embed": pos_embed,
+ "proj_type": proj_type,
"classification": classification,
"num_classes": num_classes,
"dropout_rate": dropout_rate,
@@ -69,79 +69,44 @@ def test_shape(self, input_param, input_shape, expected_shape):
result, _ = net(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)
- def test_ill_arg(self):
+ @parameterized.expand(
+ [
+            (1, (128, 128, 128), (16, 16, 16), 128, 3072, 12, 12, "conv", False, 5.0),  # dropout_rate_out_of_range
+            (1, (32, 32, 32), (64, 64, 64), 512, 3072, 12, 8, "perceptron", False, 0.3),  # patch_size_too_large
+            (1, (96, 96, 96), (8, 8, 8), 512, 3072, 12, 14, "conv", False, 0.3),  # hidden_size % num_heads != 0
+            (1, (97, 97, 97), (4, 4, 4), 768, 3072, 12, 8, "perceptron", True, 0.3),  # img_size % patch_size != 0
+            (4, (96, 96, 96), (16, 16, 16), 768, 3072, 12, 12, "perc", False, 0.3),  # invalid_proj_type
+ ]
+ )
+ def test_ill_arg(
+ self,
+ in_channels,
+ img_size,
+ patch_size,
+ hidden_size,
+ mlp_dim,
+ num_layers,
+ num_heads,
+ proj_type,
+ classification,
+ dropout_rate,
+ ):
with self.assertRaises(ValueError):
ViT(
- in_channels=1,
- img_size=(128, 128, 128),
- patch_size=(16, 16, 16),
- hidden_size=128,
- mlp_dim=3072,
- num_layers=12,
- num_heads=12,
- pos_embed="conv",
- classification=False,
- dropout_rate=5.0,
+ in_channels=in_channels,
+ img_size=img_size,
+ patch_size=patch_size,
+ hidden_size=hidden_size,
+ mlp_dim=mlp_dim,
+ num_layers=num_layers,
+ num_heads=num_heads,
+ proj_type=proj_type,
+ classification=classification,
+ dropout_rate=dropout_rate,
)
- with self.assertRaises(ValueError):
- ViT(
- in_channels=1,
- img_size=(32, 32, 32),
- patch_size=(64, 64, 64),
- hidden_size=512,
- mlp_dim=3072,
- num_layers=12,
- num_heads=8,
- pos_embed="perceptron",
- classification=False,
- dropout_rate=0.3,
- )
-
- with self.assertRaises(ValueError):
- ViT(
- in_channels=1,
- img_size=(96, 96, 96),
- patch_size=(8, 8, 8),
- hidden_size=512,
- mlp_dim=3072,
- num_layers=12,
- num_heads=14,
- pos_embed="conv",
- classification=False,
- dropout_rate=0.3,
- )
-
- with self.assertRaises(ValueError):
- ViT(
- in_channels=1,
- img_size=(97, 97, 97),
- patch_size=(4, 4, 4),
- hidden_size=768,
- mlp_dim=3072,
- num_layers=12,
- num_heads=8,
- pos_embed="perceptron",
- classification=True,
- dropout_rate=0.3,
- )
-
- with self.assertRaises(ValueError):
- ViT(
- in_channels=4,
- img_size=(96, 96, 96),
- patch_size=(16, 16, 16),
- hidden_size=768,
- mlp_dim=3072,
- num_layers=12,
- num_heads=12,
- pos_embed="perc",
- classification=False,
- dropout_rate=0.3,
- )
-
- @parameterized.expand(TEST_CASE_Vit)
- @SkipIfBeforePyTorchVersion((1, 9))
+ @parameterized.expand(TEST_CASE_Vit[:1])
+ @SkipIfBeforePyTorchVersion((2, 0))
def test_script(self, input_param, input_shape, _):
net = ViT(**(input_param))
net.eval()
diff --git a/tests/test_vitautoenc.py b/tests/test_vitautoenc.py
index cc3d493bb3..9a503948d0 100644
--- a/tests/test_vitautoenc.py
+++ b/tests/test_vitautoenc.py
@@ -23,7 +23,7 @@
for in_channels in [1, 4]:
for img_size in [64, 96, 128]:
for patch_size in [16]:
- for pos_embed in ["conv", "perceptron"]:
+ for proj_type in ["conv", "perceptron"]:
for nd in [2, 3]:
test_case = [
{
@@ -34,7 +34,7 @@
"mlp_dim": 3072,
"num_layers": 4,
"num_heads": 12,
- "pos_embed": pos_embed,
+ "proj_type": proj_type,
"dropout_rate": 0.6,
"spatial_dims": nd,
},
@@ -54,7 +54,7 @@
"mlp_dim": 3072,
"num_layers": 4,
"num_heads": 12,
- "pos_embed": "conv",
+ "proj_type": "conv",
"dropout_rate": 0.6,
"spatial_dims": 3,
},
@@ -82,83 +82,30 @@ def test_shape(self, input_param, input_shape, expected_shape):
result, _ = net(torch.randn(input_shape))
self.assertEqual(result.shape, expected_shape)
- def test_ill_arg(self):
+ @parameterized.expand(
+ [
+            (1, (32, 32, 32), (64, 64, 64), 512, 3072, 12, 8, "perceptron", 0.3),  # patch_size_too_large_for_img_size
+ (1, (96, 96, 96), (8, 8, 8), 512, 3072, 12, 14, "conv", 0.3), # num_heads_out_of_bound
+ (1, (97, 97, 97), (4, 4, 4), 768, 3072, 12, 8, "perceptron", 0.3), # img_size_not_divisible_by_patch_size
+ (4, (96, 96, 96), (16, 16, 16), 768, 3072, 12, 12, "perc", 0.3), # invalid_pos_embed
+ (4, (96, 96, 96), (9, 9, 9), 768, 3072, 12, 12, "perc", 0.3), # patch_size_not_divisible
+ # Add more test cases as needed
+ ]
+ )
+ def test_ill_arg(
+ self, in_channels, img_size, patch_size, hidden_size, mlp_dim, num_layers, num_heads, proj_type, dropout_rate
+ ):
with self.assertRaises(ValueError):
ViTAutoEnc(
- in_channels=1,
- img_size=(128, 128, 128),
- patch_size=(16, 16, 16),
- hidden_size=128,
- mlp_dim=3072,
- num_layers=12,
- num_heads=12,
- pos_embed="conv",
- dropout_rate=5.0,
- )
-
- with self.assertRaises(ValueError):
- ViTAutoEnc(
- in_channels=1,
- img_size=(32, 32, 32),
- patch_size=(64, 64, 64),
- hidden_size=512,
- mlp_dim=3072,
- num_layers=12,
- num_heads=8,
- pos_embed="perceptron",
- dropout_rate=0.3,
- )
-
- with self.assertRaises(ValueError):
- ViTAutoEnc(
- in_channels=1,
- img_size=(96, 96, 96),
- patch_size=(8, 8, 8),
- hidden_size=512,
- mlp_dim=3072,
- num_layers=12,
- num_heads=14,
- pos_embed="conv",
- dropout_rate=0.3,
- )
-
- with self.assertRaises(ValueError):
- ViTAutoEnc(
- in_channels=1,
- img_size=(97, 97, 97),
- patch_size=(4, 4, 4),
- hidden_size=768,
- mlp_dim=3072,
- num_layers=12,
- num_heads=8,
- pos_embed="perceptron",
- dropout_rate=0.3,
- )
-
- with self.assertRaises(ValueError):
- ViTAutoEnc(
- in_channels=4,
- img_size=(96, 96, 96),
- patch_size=(16, 16, 16),
- hidden_size=768,
- mlp_dim=3072,
- num_layers=12,
- num_heads=12,
- pos_embed="perc",
- dropout_rate=0.3,
- )
-
- with self.assertRaises(ValueError):
- ViTAutoEnc(
- in_channels=4,
- img_size=(96, 96, 96),
- patch_size=(9, 9, 9),
- hidden_size=768,
- mlp_dim=3072,
- num_layers=12,
- num_heads=12,
- pos_embed="perc",
- dropout_rate=0.3,
+ in_channels=in_channels,
+ img_size=img_size,
+ patch_size=patch_size,
+ hidden_size=hidden_size,
+ mlp_dim=mlp_dim,
+ num_layers=num_layers,
+ num_heads=num_heads,
+ proj_type=proj_type,
+ dropout_rate=dropout_rate,
)
diff --git a/tests/test_vqvae.py b/tests/test_vqvae.py
new file mode 100644
index 0000000000..4916dc2faa
--- /dev/null
+++ b/tests/test_vqvae.py
@@ -0,0 +1,274 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+from parameterized import parameterized
+
+from monai.networks import eval_mode
+from monai.networks.nets.vqvae import VQVAE
+from tests.utils import SkipIfBeforePyTorchVersion
+
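+# Each case: VQVAE constructor kwargs, input shape, expected reconstruction shape.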
+TEST_CASES = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "num_res_layers": 1,
+ "num_res_channels": (4, 4),
+ "downsample_parameters": ((2, 4, 1, 1),) * 2,
+ "upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
+ "num_embeddings": 8,
+ "embedding_dim": 8,
+ },
+ (1, 1, 8, 8),
+ (1, 1, 8, 8),
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "num_res_layers": 1,
+ "num_res_channels": 4,
+ "downsample_parameters": ((2, 4, 1, 1),) * 2,
+ "upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
+ "num_embeddings": 8,
+ "embedding_dim": 8,
+ },
+ (1, 1, 8, 8, 8),
+ (1, 1, 8, 8, 8),
+ ],
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "num_res_layers": 1,
+ "num_res_channels": (4, 4),
+ "downsample_parameters": (2, 4, 1, 1),
+ "upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
+ "num_embeddings": 8,
+ "embedding_dim": 8,
+ },
+ (1, 1, 8, 8),
+ (1, 1, 8, 8),
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (4, 4),
+ "num_res_layers": 1,
+ "num_res_channels": (4, 4),
+ "downsample_parameters": ((2, 4, 1, 1),) * 2,
+ "upsample_parameters": (2, 4, 1, 1, 0),
+ "num_embeddings": 8,
+ "embedding_dim": 8,
+ },
+ (1, 1, 8, 8, 8),
+ (1, 1, 8, 8, 8),
+ ],
+]
+
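+# Shared config for the latent-shape tests below: 32x32 inputs encode to an 8-channel 8x8 latent.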
+TEST_LATENT_SHAPE = {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "downsample_parameters": ((2, 4, 1, 1),) * 2,
+ "upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
+ "num_res_layers": 1,
+ "channels": (8, 8),
+ "num_res_channels": (8, 8),
+ "num_embeddings": 16,
+ "embedding_dim": 8,
+}
+
+
+class TestVQVAE(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ def test_shape(self, input_param, input_shape, expected_shape):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ net = VQVAE(**input_param).to(device)
+
+ with eval_mode(net):
+ result, _ = net(torch.randn(input_shape).to(device))
+
+ self.assertEqual(result.shape, expected_shape)
+
+ @parameterized.expand(TEST_CASES)
+ @SkipIfBeforePyTorchVersion((1, 11))
+ def test_shape_with_checkpoint(self, input_param, input_shape, expected_shape):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ input_param = input_param.copy()
+ input_param.update({"use_checkpointing": True})
+
+ net = VQVAE(**input_param).to(device)
+
+ with eval_mode(net):
+ result, _ = net(torch.randn(input_shape).to(device))
+
+ self.assertEqual(result.shape, expected_shape)
+
+    # Removed this test case since TorchScript currently does not support activation checkpointing.
+ # def test_script(self):
+ # net = VQVAE(
+ # spatial_dims=2,
+ # in_channels=1,
+ # out_channels=1,
+ # downsample_parameters=((2, 4, 1, 1),) * 2,
+ # upsample_parameters=((2, 4, 1, 1, 0),) * 2,
+ # num_res_layers=1,
+ # channels=(8, 8),
+ # num_res_channels=(8, 8),
+ # num_embeddings=16,
+ # embedding_dim=8,
+ # ddp_sync=False,
+ # )
+ # test_data = torch.randn(1, 1, 16, 16)
+ # test_script_save(net, test_data)
+
+ def test_channels_not_same_size_of_num_res_channels(self):
+ with self.assertRaises(ValueError):
+ VQVAE(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ channels=(16, 16),
+ num_res_channels=(16, 16, 16),
+ downsample_parameters=((2, 4, 1, 1),) * 2,
+ upsample_parameters=((2, 4, 1, 1, 0),) * 2,
+ )
+
+ def test_channels_not_same_size_of_downsample_parameters(self):
+ with self.assertRaises(ValueError):
+ VQVAE(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ channels=(16, 16),
+ num_res_channels=(16, 16),
+ downsample_parameters=((2, 4, 1, 1),) * 3,
+ upsample_parameters=((2, 4, 1, 1, 0),) * 2,
+ )
+
+ def test_channels_not_same_size_of_upsample_parameters(self):
+ with self.assertRaises(ValueError):
+ VQVAE(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ channels=(16, 16),
+ num_res_channels=(16, 16),
+ downsample_parameters=((2, 4, 1, 1),) * 2,
+ upsample_parameters=((2, 4, 1, 1, 0),) * 3,
+ )
+
+ def test_downsample_parameters_not_sequence_or_int(self):
+ with self.assertRaises(ValueError):
+ VQVAE(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ channels=(16, 16),
+ num_res_channels=(16, 16),
+ downsample_parameters=(("test", 4, 1, 1),) * 2,
+ upsample_parameters=((2, 4, 1, 1, 0),) * 2,
+ )
+
+ def test_upsample_parameters_not_sequence_or_int(self):
+ with self.assertRaises(ValueError):
+ VQVAE(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ channels=(16, 16),
+ num_res_channels=(16, 16),
+ downsample_parameters=((2, 4, 1, 1),) * 2,
+ upsample_parameters=(("test", 4, 1, 1, 0),) * 2,
+ )
+
+ def test_downsample_parameter_length_different_4(self):
+ with self.assertRaises(ValueError):
+ VQVAE(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ channels=(16, 16),
+ num_res_channels=(16, 16),
+ downsample_parameters=((2, 4, 1),) * 3,
+ upsample_parameters=((2, 4, 1, 1, 0),) * 2,
+ )
+
+ def test_upsample_parameter_length_different_5(self):
+ with self.assertRaises(ValueError):
+ VQVAE(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ channels=(16, 16),
+ num_res_channels=(16, 16, 16),
+ downsample_parameters=((2, 4, 1, 1),) * 2,
+ upsample_parameters=((2, 4, 1, 1, 0, 1),) * 3,
+ )
+
+ def test_encode_shape(self):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ net = VQVAE(**TEST_LATENT_SHAPE).to(device)
+
+ with eval_mode(net):
+ latent = net.encode(torch.randn(1, 1, 32, 32).to(device))
+
+ self.assertEqual(latent.shape, (1, 8, 8, 8))
+
+ def test_index_quantize_shape(self):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ net = VQVAE(**TEST_LATENT_SHAPE).to(device)
+
+ with eval_mode(net):
+ latent = net.index_quantize(torch.randn(1, 1, 32, 32).to(device))
+
+ self.assertEqual(latent.shape, (1, 8, 8))
+
+ def test_decode_shape(self):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ net = VQVAE(**TEST_LATENT_SHAPE).to(device)
+
+ with eval_mode(net):
+ latent = net.decode(torch.randn(1, 8, 8, 8).to(device))
+
+ self.assertEqual(latent.shape, (1, 1, 32, 32))
+
+ def test_decode_samples_shape(self):
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ net = VQVAE(**TEST_LATENT_SHAPE).to(device)
+
+ with eval_mode(net):
+ latent = net.decode_samples(torch.randint(low=0, high=16, size=(1, 8, 8)).to(device))
+
+ self.assertEqual(latent.shape, (1, 1, 32, 32))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_vqvaetransformer_inferer.py b/tests/test_vqvaetransformer_inferer.py
new file mode 100644
index 0000000000..36b715f588
--- /dev/null
+++ b/tests/test_vqvaetransformer_inferer.py
@@ -0,0 +1,295 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from unittest import skipUnless
+
+import torch
+from parameterized import parameterized
+
+from monai.inferers import VQVAETransformerInferer
+from monai.networks.nets import VQVAE, DecoderOnlyTransformer
+from monai.utils import optional_import
+from monai.utils.ordering import Ordering, OrderingType
+
+einops, has_einops = optional_import("einops")
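+# Each case: VQVAE kwargs, transformer kwargs, ordering kwargs, input shape, logits shape, latent shape.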
+TEST_CASES = [
+ [
+ {
+ "spatial_dims": 2,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (8, 8),
+ "num_res_channels": (8, 8),
+ "downsample_parameters": ((2, 4, 1, 1),) * 2,
+ "upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
+ "num_res_layers": 1,
+ "num_embeddings": 16,
+ "embedding_dim": 8,
+ },
+ {
+ "num_tokens": 16 + 1,
+ "max_seq_len": 4,
+ "attn_layers_dim": 4,
+ "attn_layers_depth": 2,
+ "attn_layers_heads": 1,
+ "with_cross_attention": False,
+ },
+ {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 2, "dimensions": (2, 2, 2)},
+ (2, 1, 8, 8),
+ (2, 4, 17),
+ (2, 2, 2),
+ ],
+ [
+ {
+ "spatial_dims": 3,
+ "in_channels": 1,
+ "out_channels": 1,
+ "channels": (8, 8),
+ "num_res_channels": (8, 8),
+ "downsample_parameters": ((2, 4, 1, 1),) * 2,
+ "upsample_parameters": ((2, 4, 1, 1, 0),) * 2,
+ "num_res_layers": 1,
+ "num_embeddings": 16,
+ "embedding_dim": 8,
+ },
+ {
+ "num_tokens": 16 + 1,
+ "max_seq_len": 8,
+ "attn_layers_dim": 4,
+ "attn_layers_depth": 2,
+ "attn_layers_heads": 1,
+ "with_cross_attention": False,
+ },
+ {"ordering_type": OrderingType.RASTER_SCAN.value, "spatial_dims": 3, "dimensions": (2, 2, 2, 2)},
+ (2, 1, 8, 8, 8),
+ (2, 8, 17),
+ (2, 2, 2, 2),
+ ],
+]
+
+
+class TestVQVAETransformerInferer(unittest.TestCase):
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_prediction_shape(
+ self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+ ):
+ stage_1 = VQVAE(**stage_1_params)
+ stage_2 = DecoderOnlyTransformer(**stage_2_params)
+ ordering = Ordering(**ordering_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ input = torch.randn(input_shape).to(device)
+
+ inferer = VQVAETransformerInferer()
+ prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
+ self.assertEqual(prediction.shape, logits_shape)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_prediction_shape_shorter_sequence(
+ self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+ ):
+ stage_1 = VQVAE(**stage_1_params)
+ max_seq_len = 3
+ stage_2_params_shorter = dict(stage_2_params)
+ stage_2_params_shorter["max_seq_len"] = max_seq_len
+ stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)
+ ordering = Ordering(**ordering_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ input = torch.randn(input_shape).to(device)
+
+ inferer = VQVAETransformerInferer()
+ prediction = inferer(inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering)
+ cropped_logits_shape = (logits_shape[0], max_seq_len, logits_shape[2])
+ self.assertEqual(prediction.shape, cropped_logits_shape)
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_sample(self):
+
+ stage_1 = VQVAE(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ channels=(8, 8),
+ num_res_channels=(8, 8),
+ downsample_parameters=((2, 4, 1, 1),) * 2,
+ upsample_parameters=((2, 4, 1, 1, 0),) * 2,
+ num_res_layers=1,
+ num_embeddings=16,
+ embedding_dim=8,
+ )
+ stage_2 = DecoderOnlyTransformer(
+ num_tokens=16 + 1,
+ max_seq_len=4,
+ attn_layers_dim=4,
+ attn_layers_depth=2,
+ attn_layers_heads=1,
+ with_cross_attention=False,
+ )
+ ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ inferer = VQVAETransformerInferer()
+
+ starting_token = 16 # from stage_1 num_embeddings
+
+ sample = inferer.sample(
+ latent_spatial_dim=(2, 2),
+ starting_tokens=starting_token * torch.ones((2, 1), device=device),
+ vqvae_model=stage_1,
+ transformer_model=stage_2,
+ ordering=ordering,
+ )
+ self.assertEqual(sample.shape, (2, 1, 8, 8))
+
+ @skipUnless(has_einops, "Requires einops")
+ def test_sample_shorter_sequence(self):
+ stage_1 = VQVAE(
+ spatial_dims=2,
+ in_channels=1,
+ out_channels=1,
+ channels=(8, 8),
+ num_res_channels=(8, 8),
+ downsample_parameters=((2, 4, 1, 1),) * 2,
+ upsample_parameters=((2, 4, 1, 1, 0),) * 2,
+ num_res_layers=1,
+ num_embeddings=16,
+ embedding_dim=8,
+ )
+ stage_2 = DecoderOnlyTransformer(
+ num_tokens=16 + 1,
+ max_seq_len=2,
+ attn_layers_dim=4,
+ attn_layers_depth=2,
+ attn_layers_heads=1,
+ with_cross_attention=False,
+ )
+ ordering = Ordering(ordering_type=OrderingType.RASTER_SCAN.value, spatial_dims=2, dimensions=(2, 2, 2))
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ inferer = VQVAETransformerInferer()
+
+ starting_token = 16 # from stage_1 num_embeddings
+
+ sample = inferer.sample(
+ latent_spatial_dim=(2, 2),
+ starting_tokens=starting_token * torch.ones((2, 1), device=device),
+ vqvae_model=stage_1,
+ transformer_model=stage_2,
+ ordering=ordering,
+ )
+ self.assertEqual(sample.shape, (2, 1, 8, 8))
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_get_likelihood(
+ self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+ ):
+ stage_1 = VQVAE(**stage_1_params)
+ stage_2 = DecoderOnlyTransformer(**stage_2_params)
+ ordering = Ordering(**ordering_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ input = torch.randn(input_shape).to(device)
+
+ inferer = VQVAETransformerInferer()
+ likelihood = inferer.get_likelihood(
+ inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering
+ )
+ self.assertEqual(likelihood.shape, latent_shape)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_get_likelihood_shorter_sequence(
+ self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+ ):
+ stage_1 = VQVAE(**stage_1_params)
+ max_seq_len = 3
+ stage_2_params_shorter = dict(stage_2_params)
+ stage_2_params_shorter["max_seq_len"] = max_seq_len
+ stage_2 = DecoderOnlyTransformer(**stage_2_params_shorter)
+ ordering = Ordering(**ordering_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ input = torch.randn(input_shape).to(device)
+
+ inferer = VQVAETransformerInferer()
+ likelihood = inferer.get_likelihood(
+ inputs=input, vqvae_model=stage_1, transformer_model=stage_2, ordering=ordering
+ )
+ self.assertEqual(likelihood.shape, latent_shape)
+
+ @parameterized.expand(TEST_CASES)
+ @skipUnless(has_einops, "Requires einops")
+ def test_get_likelihood_resampling(
+ self, stage_1_params, stage_2_params, ordering_params, input_shape, logits_shape, latent_shape
+ ):
+ stage_1 = VQVAE(**stage_1_params)
+ stage_2 = DecoderOnlyTransformer(**stage_2_params)
+ ordering = Ordering(**ordering_params)
+
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ stage_1.to(device)
+ stage_2.to(device)
+ stage_1.eval()
+ stage_2.eval()
+
+ input = torch.randn(input_shape).to(device)
+
+ inferer = VQVAETransformerInferer()
+ likelihood = inferer.get_likelihood(
+ inputs=input,
+ vqvae_model=stage_1,
+ transformer_model=stage_2,
+ ordering=ordering,
+ resample_latent_likelihoods=True,
+ resample_interpolation_mode="nearest",
+ )
+ self.assertEqual(likelihood.shape, input_shape)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_warp.py b/tests/test_warp.py
index bac595224f..0e5f2466db 100644
--- a/tests/test_warp.py
+++ b/tests/test_warp.py
@@ -124,7 +124,7 @@ def test_itk_benchmark(self):
relative_diff = np.mean(
np.divide(monai_result - itk_result, itk_result, out=np.zeros_like(itk_result), where=(itk_result != 0))
)
- self.assertTrue(relative_diff < 0.01)
+ self.assertLess(relative_diff, 0.01)
@parameterized.expand(TEST_CASES, skip_on_empty=True)
def test_resample(self, input_param, input_data, expected_val):
@@ -217,6 +217,7 @@ def itk_warp(img, ddf):
# warp
warp_filter.SetDisplacementField(displacement_field)
warp_filter.SetInput(itk_img)
+ warp_filter.Update()
warped_img = warp_filter.GetOutput()
warped_img = np.asarray(warped_img)
diff --git a/tests/testing_data/data_config.json b/tests/testing_data/data_config.json
index a570c787ba..79033dd0d6 100644
--- a/tests/testing_data/data_config.json
+++ b/tests/testing_data/data_config.json
@@ -138,13 +138,33 @@
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth",
"hash_type": "sha256",
"hash_val": "c3564f40a6a051d3753a6d8fae5cc8eaf21ce8d82a9a3baf80748d15664055e8"
+ },
+ "decoder_only_transformer_monai_generative_weights": {
+ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/decoder_only_transformer.pth",
+ "hash_type": "sha256",
+ "hash_val": "f93de37d64d77cf91f3bde95cdf93d161aee800074c89a92aff9d5699120ec0d"
+ },
+ "diffusion_model_unet_monai_generative_weights": {
+ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/diffusion_model_unet.pth",
+ "hash_type": "sha256",
+ "hash_val": "0d2171b386902f5b4fd3e967b4024f63e353694ca45091b114970019d045beee"
+ },
+ "autoencoderkl_monai_generative_weights": {
+ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/autoencoderkl.pth",
+ "hash_type": "sha256",
+ "hash_val": "6e02c9540c51b16b9ba98b5c0c75d6b84b430afe9a3237df1d67a520f8d34184"
+ },
+ "controlnet_monai_generative_weights": {
+ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/controlnet.pth",
+ "hash_type": "sha256",
+ "hash_val": "cd100d0c69f47569ae5b4b7df653a1cb19f5e02eff1630db3210e2646fb1ab2e"
}
},
"configs": {
"test_meta_file": {
- "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
+ "url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
"hash_type": "md5",
- "hash_val": "662135097106b71067cd1fc657f8720f"
+ "hash_val": "06954cad2cc5d3784e72077ac76f0fc8"
}
}
}
diff --git a/tests/testing_data/fl_infer_properties.json b/tests/testing_data/fl_infer_properties.json
new file mode 100644
index 0000000000..72e97cd2c6
--- /dev/null
+++ b/tests/testing_data/fl_infer_properties.json
@@ -0,0 +1,67 @@
+{
+ "bundle_root": {
+ "description": "root path of the bundle.",
+ "required": true,
+ "id": "bundle_root"
+ },
+ "device": {
+ "description": "target device to execute the bundle workflow.",
+ "required": true,
+ "id": "device"
+ },
+ "dataset_dir": {
+ "description": "directory path of the dataset.",
+ "required": true,
+ "id": "dataset_dir"
+ },
+ "dataset": {
+ "description": "PyTorch dataset object for the inference / evaluation logic.",
+ "required": true,
+ "id": "dataset"
+ },
+ "evaluator": {
+ "description": "inference / evaluation workflow engine.",
+ "required": true,
+ "id": "evaluator"
+ },
+ "network_def": {
+ "description": "network module for the inference.",
+ "required": true,
+ "id": "network_def"
+ },
+ "inferer": {
+ "description": "MONAI Inferer object to execute the model computation in inference.",
+ "required": true,
+ "id": "inferer"
+ },
+ "dataset_data": {
+ "description": "data source for the inference / evaluation dataset.",
+ "required": false,
+ "id": "dataset::data",
+ "refer_id": null
+ },
+ "handlers": {
+ "description": "event-handlers for the inference / evaluation logic.",
+ "required": false,
+ "id": "handlers",
+ "refer_id": "evaluator::val_handlers"
+ },
+ "preprocessing": {
+ "description": "preprocessing for the input data.",
+ "required": false,
+ "id": "preprocessing",
+ "refer_id": "dataset::transform"
+ },
+ "postprocessing": {
+ "description": "postprocessing for the model output data.",
+ "required": false,
+ "id": "postprocessing",
+ "refer_id": "evaluator::postprocessing"
+ },
+ "key_metric": {
+ "description": "the key metric during evaluation.",
+ "required": false,
+ "id": "key_metric",
+ "refer_id": "evaluator::key_val_metric"
+ }
+}
diff --git a/tests/testing_data/metadata.json b/tests/testing_data/metadata.json
index 98a17b73c5..29737e3a9d 100644
--- a/tests/testing_data/metadata.json
+++ b/tests/testing_data/metadata.json
@@ -1,5 +1,5 @@
{
- "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
"version": "0.1.0",
"changelog": {
"0.1.0": "complete the model package",
@@ -8,7 +8,7 @@
"monai_version": "0.9.0",
"pytorch_version": "1.10.0",
"numpy_version": "1.21.2",
- "optional_packages_version": {
+ "required_packages_version": {
"nibabel": "3.2.1"
},
"task": "Decathlon spleen segmentation",
diff --git a/tests/testing_data/ultrasound_confidence_map/femur_input.png b/tests/testing_data/ultrasound_confidence_map/femur_input.png
new file mode 100644
index 0000000000..0343e58720
Binary files /dev/null and b/tests/testing_data/ultrasound_confidence_map/femur_input.png differ
diff --git a/tests/testing_data/ultrasound_confidence_map/femur_result.npy b/tests/testing_data/ultrasound_confidence_map/femur_result.npy
new file mode 100644
index 0000000000..a3f322b113
Binary files /dev/null and b/tests/testing_data/ultrasound_confidence_map/femur_result.npy differ
diff --git a/tests/testing_data/ultrasound_confidence_map/neck_input.png b/tests/testing_data/ultrasound_confidence_map/neck_input.png
new file mode 100644
index 0000000000..74a64a9d90
Binary files /dev/null and b/tests/testing_data/ultrasound_confidence_map/neck_input.png differ
diff --git a/tests/testing_data/ultrasound_confidence_map/neck_result.npy b/tests/testing_data/ultrasound_confidence_map/neck_result.npy
new file mode 100644
index 0000000000..8bf760182c
Binary files /dev/null and b/tests/testing_data/ultrasound_confidence_map/neck_result.npy differ
diff --git a/tests/utils.py b/tests/utils.py
index ea73a3ed81..2a00af50e9 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -47,7 +47,7 @@
from monai.networks import convert_to_onnx, convert_to_torchscript
from monai.utils import optional_import
from monai.utils.misc import MONAIEnvVars
-from monai.utils.module import pytorch_after
+from monai.utils.module import compute_capabilities_after, pytorch_after
from monai.utils.tf32 import detect_default_tf32
from monai.utils.type_conversion import convert_data_type
@@ -156,6 +156,7 @@ def skip_if_downloading_fails():
"limit", # HTTP Error 503: Egress is over the account limit
"authenticate",
"timed out", # urlopen error [Errno 110] Connection timed out
+ "HTTPError", # HTTPError: 429 Client Error: Too Many Requests for huggingface hub
)
):
raise unittest.SkipTest(f"error while downloading: {rt_e}") from rt_e # incomplete download
@@ -285,6 +286,20 @@ def __call__(self, obj):
)(obj)
+class SkipIfBeforeComputeCapabilityVersion:
+ """Decorator to be used if test should be skipped
+ with Compute Capability older than that given."""
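+    # Example (hypothetical): @SkipIfBeforeComputeCapabilityVersion((8, 0)) skips on GPUs older than Ampere (CC 8.0).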
+
+ def __init__(self, compute_capability_tuple):
+ self.min_version = compute_capability_tuple
+ self.version_too_old = not compute_capabilities_after(*compute_capability_tuple)
+
+ def __call__(self, obj):
+ return unittest.skipIf(
+ self.version_too_old, f"Skipping tests that fail on Compute Capability versions before: {self.min_version}"
+ )(obj)
+
+
def is_main_test_process():
ps = torch.multiprocessing.current_process()
if not ps or not hasattr(ps, "name"):
@@ -474,7 +489,7 @@ def run_process(self, func, local_rank, args, kwargs, results):
if self.verbose:
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["NCCL_DEBUG_SUBSYS"] = "ALL"
- os.environ["NCCL_BLOCKING_WAIT"] = str(1)
+ os.environ["TORCH_NCCL_BLOCKING_WAIT"] = str(1)
os.environ["OMP_NUM_THREADS"] = str(1)
os.environ["WORLD_SIZE"] = str(self.nproc_per_node * self.nnodes)
os.environ["RANK"] = str(self.nproc_per_node * self.node_rank + local_rank)