From 6ece63ed35bcf0852699f484af84f2d9c3d939a4 Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Fri, 16 Jun 2023 12:52:07 +0800 Subject: [PATCH] [Feature] Support Pure Python style Configuration File (#1071) --- docs/en/advanced_tutorials/config.md | 463 ++++++++++- docs/en/conf.py | 1 + docs/zh_cn/advanced_tutorials/config.md | 465 +++++++++++ docs/zh_cn/conf.py | 1 + mmengine/config/config.py | 786 +++++++++++++++--- mmengine/config/lazy.py | 216 +++++ mmengine/config/utils.py | 304 ++++++- mmengine/model/averaged_model.py | 3 +- mmengine/model/base_module.py | 5 +- mmengine/registry/build_functions.py | 2 +- mmengine/registry/registry.py | 28 +- mmengine/runner/loops.py | 2 +- mmengine/runner/runner.py | 24 +- mmengine/utils/__init__.py | 5 +- mmengine/utils/misc.py | 42 +- mmengine/utils/package_utils.py | 36 +- requirements/docs.txt | 2 + .../config/lazy_module_config/__init__.py | 1 + .../lazy_module_config/_base_/__init__.py | 1 + .../lazy_module_config/_base_/base_model.py | 4 + .../_base_/default_runtime.py | 36 + .../lazy_module_config/_base_/scheduler.py | 20 + .../lazy_module_config/error_mix_using1.py | 2 + .../lazy_module_config/error_mix_using2.py | 3 + .../lazy_module_config/load_mmdet_config.py | 7 + .../lazy_module_config/test_ast_transform.py | 13 + .../test_ast_transform_error_catching1.py | 2 + .../config/lazy_module_config/toy_model.py | 45 + tests/test_config/test_config.py | 180 ++++ tests/test_config/test_lazy.py | 182 ++++ tests/test_registry/test_registry.py | 12 + tests/test_utils/test_misc.py | 23 +- tests/test_utils/test_package_utils.py | 41 + 33 files changed, 2809 insertions(+), 148 deletions(-) create mode 100644 mmengine/config/lazy.py create mode 100644 tests/data/config/lazy_module_config/__init__.py create mode 100644 tests/data/config/lazy_module_config/_base_/__init__.py create mode 100644 tests/data/config/lazy_module_config/_base_/base_model.py create mode 100644 
tests/data/config/lazy_module_config/_base_/default_runtime.py create mode 100644 tests/data/config/lazy_module_config/_base_/scheduler.py create mode 100644 tests/data/config/lazy_module_config/error_mix_using1.py create mode 100644 tests/data/config/lazy_module_config/error_mix_using2.py create mode 100644 tests/data/config/lazy_module_config/load_mmdet_config.py create mode 100644 tests/data/config/lazy_module_config/test_ast_transform.py create mode 100644 tests/data/config/lazy_module_config/test_ast_transform_error_catching1.py create mode 100644 tests/data/config/lazy_module_config/toy_model.py create mode 100644 tests/test_config/test_lazy.py create mode 100644 tests/test_utils/test_package_utils.py diff --git a/docs/en/advanced_tutorials/config.md b/docs/en/advanced_tutorials/config.md index 57c227bdd0..f8815385bd 100644 --- a/docs/en/advanced_tutorials/config.md +++ b/docs/en/advanced_tutorials/config.md @@ -1,6 +1,31 @@ # Config -MMEngine implements an abstract configuration class (`Config`) to provide a unified configuration access interface for users. `Config` supports different type of configuration file, including `python`, `json` and `yaml`, and you can choose the type according to your preference. `Config` overrides some magic method, which could help you access the data stored in `Config` just like getting values from `dict`, or getting attributes from instances. Besides, `Config` also provides an inheritance mechanism, which could help you better organize and manage the configuration files. 
+- [Config](#config) + - [Read the configuration file](#read-the-configuration-file) + - [How to use `Config`](#how-to-use-config) + - [Inheritance between configuration files](#inheritance-between-configuration-files) + - [Overview of inheritance mechanism](#overview-of-inheritance-mechanism) + - [Modify the inherited fields](#modify-the-inherited-fields) + - [Delete key in `dict`](#delete-key-in-dict) + - [Reference of the inherited file](#reference-of-the-inherited-file) + - [Dump the configuration file](#dump-the-configuration-file) + - [Advanced usage](#advanced-usage) + - [Predefined fields](#predefined-fields) + - [Modify the fields in command line](#modify-the-fields-in-command-line) + - [Replace fields with environment variables](#replace-fields-with-environment-variables) + - [import the custom module](#import-the-custom-module) + - [Inherit configuration files across repository](#inherit-configuration-files-across-repository) + - [Get configuration files across repository](#get-configuration-files-across-repository) + - [A Pure Python style Configuration File (Beta)](#a-pure-python-style-configuration-file-beta) + - [Basic Syntax](#basic-syntax) + - [Module Construction](#module-construction) + - [Inheritance](#inheritance) + - [Dump the Configuration File](#dump-the-configuration-file-1) + - [What is Lazy Import](#what-is-lazy-import) + - [Limitations](#limitations) + - [Migration Guide](#migration-guide) + +MMEngine implements an abstract configuration class (`Config`) to provide a unified configuration access interface for users. `Config` supports different types of configuration file, including `python`, `json` and `yaml`, and you can choose the type according to your preference. `Config` overrides some magic method, which could help you access the data stored in `Config` just like getting values from `dict`, or getting attributes from instances. 
Besides, `Config` also provides an inheritance mechanism, which could help you better organize and manage the configuration files. Before starting the tutorial, let's download the configuration files needed in the tutorial (it is recommended to execute them in a temporary directory to facilitate deleting these files latter.): @@ -25,6 +50,10 @@ wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/c wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/modify_base_var.py ``` +```{note} +The `Config` supports two styles of configuration files: text style and pure Python style (introduced in v0.8.0). Each has its own characteristics while maintaining a unified interface for calling. For users who are not familiar with the basic usage of the `Config`, it is recommended to start reading from the section on [Read the configuration file](#read-the-configuration-file) to understand the functionality of the `Config` and the syntax of text style configuration files. In some cases, the syntax of text style configuration files is more concise and compatible with different formats such as `json` and `yaml`. If you prefer a more flexible syntax for configuration files, it is recommended to use the [Pure Python Style Configuration Files (beta)](#a-pure-python-style-configuration-file-beta). +``` + ## Read the configuration file `Config` provides a uniform interface `Config.fromfile()` to read and parse configuration files. @@ -374,6 +403,10 @@ b=2 In this section, we'll introduce some advanced usage of the `Config`, and some tips that could make it easier for users to develop and use downstream repositories. +```{note} +If you use a pure Python style configuration file, the advanced usage described in this section is not supported, except for the feature described in "Modify the fields in command line". +``` + ### Predefined fields Sometimes we need some fields in the configuration file, which are related to the path to the workspace.
For example, we define a working directory in the configuration file that holds the models and logs for this set of experimental configurations. We expect to have different working directories for different configuration files. A common choice is to use the configuration file name directly as part of the working directory name. @@ -691,3 +724,431 @@ print(type(model)) http loads checkpoint from path: https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth ``` + +## A Pure Python style Configuration File (Beta) + +In the previous tutorial, we introduced how to use configuration files to build modules with registry and how to use `_base_` to inherit configuration files. These pure text style configuration files can satisfy most of our development needs and some module aliases can greatly simplify the configuration files (e.g. `ResNet` can refer to `mmcls.models.ResNet`). However, there are also some disadvantages: + +1. In the configuration file, the `type` field is specified by a string, and IDE cannot directly jump to the corresponding class definition, which is not conducive to code reading and jumping. +2. The inheritance of configuration files is also specified by a string, and IDE cannot directly jump to the inherited file. When the inheritance structure of the configuration file is complex, it is not conducive to reading and jumping of the configuration file. +3. The inheritance rules are relatively implicit, and beginners find it difficult to understand how the configuration file merges variables with the same fields and derives special syntax such as `_delete_`, resulting in a higher learning cost. +4. It is easy for users to forget to register the module and cause `module not found` errors. +5. 
In the yet-to-be-mentioned cross-codebase inheritance, the introduction of the scope makes the inheritance rules of the configuration file more complicated, and beginners find it difficult to understand. + +In summary, although pure text style configuration files can provide the same syntax rules for `python`, `json`, and `yaml` format configurations, when the configuration files become complex, pure text style configuration files will appear inadequate. Therefore, we provide a pure Python style configuration file, i.e., the `lazy import` mode, which can fully utilize Python's syntax rules to solve the above problems. At the same time, the pure Python style configuration file also supports exporting to `json` and `yaml` formats. + +### Basic Syntax + +In the previous tutorial, we introduced module construction, inheritance, and export based on pure text style configuration files. This section will introduce pure Python style configuration files based on these three aspects. + +#### Module Construction + +We use a simple example to compare pure Python style and pure text style configuration files: + +```{eval-rst} +.. tabs:: + .. tabs:: + + .. code-tab:: python Pure Python style + + # No need for registration + + .. code-tab:: python Pure text style + + # Registration process + from torch.optim import SGD + from mmengine.registry import OPTIMIZERS + + OPTIMIZERS.register_module(module=SGD, name='SGD') + + .. tabs:: + + .. code-tab:: python Pure Python style + + # Configuration file writing + from torch.optim import SGD + + + optimizer = dict(type=SGD, lr=0.1) + + .. code-tab:: python Pure text style + + # Configuration file writing + optimizer = dict(type='SGD', lr=0.1) + + .. tabs:: + + .. 
code-tab:: python Pure Python style + + # The construction process is exactly the same + import torch.nn as nn + from mmengine.registry import OPTIMIZERS + + + cfg = Config.fromfile('optimizer.py') + model = nn.Conv2d(1, 1, 1) + cfg.optimizer.params = model.parameters() + optimizer = OPTIMIZERS.build(cfg.optimizer) + + .. code-tab:: python Pure text style + + # The construction process is exactly the same + import torch.nn as nn + from mmengine.registry import OPTIMIZERS + + + cfg = Config.fromfile('optimizer.py') + model = nn.Conv2d(1, 1, 1) + cfg.optimizer.params = model.parameters() + optimizer = OPTIMIZERS.build(cfg.optimizer) +``` + +From the above example, we can see that the difference between pure Python style and pure text style configuration files is: + +1. Pure Python style configuration files do not require module registration. +2. In pure Python style configuration files, the `type` field is no longer a string but directly refers to the module. Correspondingly, import syntax needs to be added in the configuration file. + +It should be noted that the OpenMMLab series algorithm library still retains the registration process when adding modules. When users build their own projects based on MMEngine, if they use pure Python style configuration files, registration is not required. You may wonder that if you are not in an environment with torch installed, you cannot parse the sample configuration file. Can this configuration file still be called a configuration file? Don't worry, we will explain this part later. + +#### Inheritance + +The inheritance syntax of pure Python style configuration files is slightly different: + +```{eval-rst} +.. tabs:: + + .. code-tab:: python Pure Python style Inheritance + + if '_base_': + from .optimizer import * + + .. code-tab:: python Pure text style Inheritance + + _base_ = ['./optimizer.py'] + +``` + +Pure Python style configuration files use import syntax to achieve inheritance.
The advantage of doing this is that we can directly jump to the inherited configuration file for easy reading and jumping. The variable inheritance rule (add, delete, change, and search) is completely aligned with Python syntax. For example, if I want to modify the learning rate of the optimizer in the base configuration file: + +```python +if '_base_': + from .optimizer import * + +# optimizer is a variable defined in the base configuration file +optimizer.update( + lr=0.01, +) +``` + +Of course, if you are already accustomed to the inheritance rules of pure text style configuration files and the variable is of the `dict` type in the `_base_` configuration file, you can also use merge syntax to achieve the same inheritance rule as pure text style configuration files: + +```python +if '_base_': + from .optimizer import * + +# optimizer is a variable defined in the base configuration file +optimizer.merge( + _delete_=True, + lr=0.01, + type='SGD' +) + +# The equivalent Python style writing is as follows, completely consistent with Python's import rules +# optimizer = dict( +# lr=0.01, +# type='SGD' +# ) +``` + +````{note} +It should be noted that the `update` method of the dictionary in pure Python style configuration files is slightly different from `dict.update`. Pure Python style update will recursively update the content in the dictionary, for example: + +```python +x = dict(a=1, b=dict(c=2, d=3)) + +x.update(dict(b=dict(d=4))) +# Update rules in the configuration file: +# {a: 1, b: {c: 2, d: 4}} +# Update rules in the normal dict: +# {a: 1, b: {d: 4}} +``` + +It can be seen that using the update method in the configuration file will recursively update the fields, rather than simply covering them. +```` + +Compared with pure text style configuration files, the inheritance rule of pure Python style configuration files is completely aligned with the import syntax of Python, which is easier to understand and supports jumping between configuration files. 
You may wonder since both inheritance and module imports use import syntax, why do we need an `if '_base_'` statement for inheriting configuration files? On the one hand, this can improve the readability of configuration files, making inherited configuration files more prominent. On the other hand, it is also restricted by the rules of lazy_import, which will be explained later. + +#### Dump the Configuration File + +The pure Python style configuration files can also be exported via the `dump` interface, and there is no difference in usage. However, the exported contents will be different: + +```{eval-rst} +.. tabs:: + + .. tabs:: + + .. code-tab:: python Export in pure Python style + + optimizer = dict(type='torch.optim.SGD', lr=0.1) + + .. code-tab:: python Export in pure text style + + optimizer = dict(type='SGD', lr=0.1) + + .. tabs:: + + .. code-tab:: yaml Export in pure Python style + + optimizer: + type: torch.optim.SGD + lr: 0.1 + + .. code-tab:: yaml Export in pure text style + + optimizer: + type: SGD + lr: 0.1 + + .. tabs:: + + .. code-tab:: json Export in pure Python style + + {"optimizer": {"type": "torch.optim.SGD", "lr": 0.1}} + + .. code-tab:: json Export in pure text style + + {"optimizer": {"type": "SGD", "lr": 0.1}} +``` + +As can be seen, the type field exported in pure Python style contains the full module information. The exported configuration file can also be directly loaded to construct an instance through the registry. + +### What is Lazy Import + +You may find that pure Python style configuration files seem to organize configuration files using pure Python syntax. Then, I do not need configuration classes, and I could just import configuration files using Python syntax. If you have such a feeling, then it is worth celebrating because this is exactly the effect we want. + +As mentioned earlier, parsing configuration files requires dependencies on third-party libraries referenced in the configuration files. This is actually a very unreasonable thing.
For example, suppose I trained a model based on MMagic and want to deploy it with the onnxruntime backend of MMDeploy. The deployment environment does not have torch installed, but torch is needed to parse the configuration file, so I cannot directly use the configuration file of MMagic as the deployment configuration. To solve this problem, we introduced the concept of lazy_import. + +It is a complex task to discuss the specific implementation of lazy_import, so here we only briefly introduce its function. The core idea of lazy_import is to delay the execution of the import statements in the configuration file until the imported objects are actually accessed, instead of executing them while the configuration file is being parsed, so that the dependency problem caused by the import statements in the configuration file can be avoided. During the configuration file parsing process, the equivalent code executed by the Python interpreter is as follows: + +```{eval-rst} +.. tabs:: + .. code-tab:: python Original configuration file + + from torch.optim import SGD + + + optimizer = dict(type=SGD) + + .. code-tab:: python Code actually executed by the python interpreter through the configuration class + + lazy_obj = LazyObject('torch.optim', 'SGD') + + optimizer = dict(type=lazy_obj) +``` + +As an internal type of the `Config` module, the `LazyObject` cannot be accessed directly by users. When accessing the type field, it will undergo a series of conversions to convert `LazyObject` into the actual `torch.optim.SGD` type. In this way, parsing the configuration file will not trigger the import of third-party libraries, while users can still access the types of third-party libraries normally when using the configuration file. + +To access the internal type of `LazyObject`, you can use the `Config.to_dict` interface: + +```python +cfg = Config.fromfile('optimizer.py').to_dict() +print(type(cfg['optimizer']['type'])) +# mmengine.config.lazy.LazyObject +``` + +At this point, the type accessed is the `LazyObject` type.
+ +However, we cannot adopt the lazy import strategy for the inheritance (import) of base files since we need the configuration file parsed to include the fields defined in the base configuration file, and we need to trigger the import really. Therefore, we have added a restriction on importing base files, which must be imported in the `if '_base_'` code block. + +### Limitations + +1. Functions and classes cannot be defined in the configuration file. +2. The configuration file name must comply with the naming convention of Python modules, which can only contain letters, numbers, and underscores, and cannot start with a number. +3. When importing variables from the base configuration file, such as `from ._base_.alpha import beta`, the `alpha` here must be the module (module) name, i.e., a Python file, rather than the package (package) name containing `__init__.py`. +4. Importing multiple variables simultaneously in an absolute import statement, such as `import torch, numpy, os`, is not supported. Multiple import statements need to be used instead, such as `import torch; import numpy; import os`. + +### Migration Guide + +To migrate from a pure text style configuration file to a pure Python style configuration file, the following rules must be followed: + +1. Replace the string type with the specific class: + + - If the code does not depend on the type field being a string, and no special processing is done on the type field, the string type of the type field can be replaced with the specific class, and the class should be imported at the beginning of the configuration file. + - If the code depends on the type field being a string, the code needs to be modified, or the original string format of the type should be retained. + +2. Rename the configuration file. The configuration file name must comply with the naming convention of Python modules, which can only contain letters, numbers, and underscores, and cannot start with a number. + +3. 
Remove scope-related configurations. Pure Python style configuration files no longer need to use scope to get modules across libraries, and modules can be directly imported. For compatibility reasons, we still set the `default_scope` parameter of the Runner to `mmengine`, and users need to manually set it to `None`. + +4. For modules that have aliases in the registry, replace their aliases with their corresponding real modules. The following is a table of commonly used alias replacements: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
| Alias | Module | Notes |
| --- | --- | --- |
| nearest | torch.nn.modules.upsampling.Upsample | When replacing 'type' with 'Upsample', the 'mode' parameter needs to be specified as 'nearest'. |
| bilinear | torch.nn.modules.upsampling.Upsample | When replacing 'type' with 'Upsample', the 'mode' parameter needs to be specified as 'bilinear'. |
| Clip | mmcv.cnn.bricks.activation.Clamp | None |
| Conv | mmcv.cnn.bricks.wrappers.Conv2d | None |
| BN | torch.nn.modules.batchnorm.BatchNorm2d | None |
| BN1d | torch.nn.modules.batchnorm.BatchNorm1d | None |
| BN2d | torch.nn.modules.batchnorm.BatchNorm2d | None |
| BN3d | torch.nn.modules.batchnorm.BatchNorm3d | None |
| SyncBN | torch.nn.SyncBatchNorm | None |
| GN | torch.nn.modules.normalization.GroupNorm | None |
| LN | torch.nn.modules.normalization.LayerNorm | None |
| IN | torch.nn.modules.instancenorm.InstanceNorm2d | None |
| IN1d | torch.nn.modules.instancenorm.InstanceNorm1d | None |
| IN2d | torch.nn.modules.instancenorm.InstanceNorm2d | None |
| IN3d | torch.nn.modules.instancenorm.InstanceNorm3d | None |
| zero | torch.nn.modules.padding.ZeroPad2d | None |
| reflect | torch.nn.modules.padding.ReflectionPad2d | None |
| replicate | torch.nn.modules.padding.ReplicationPad2d | None |
| ConvWS | mmcv.cnn.bricks.conv_ws.ConvWS2d | None |
| ConvAWS | mmcv.cnn.bricks.conv_ws.ConvAWS2d | None |
| HSwish | torch.nn.modules.activation.Hardswish | None |
| pixel_shuffle | mmcv.cnn.bricks.upsample.PixelShufflePack | None |
| deconv | mmcv.cnn.bricks.wrappers.ConvTranspose2d | None |
| deconv3d | mmcv.cnn.bricks.wrappers.ConvTranspose3d | None |
| Constant | mmengine.model.weight_init.ConstantInit | None |
| Xavier | mmengine.model.weight_init.XavierInit | None |
| Normal | mmengine.model.weight_init.NormalInit | None |
| TruncNormal | mmengine.model.weight_init.TruncNormalInit | None |
| Uniform | mmengine.model.weight_init.UniformInit | None |
| Kaiming | mmengine.model.weight_init.KaimingInit | None |
| Caffe2Xavier | mmengine.model.weight_init.Caffe2XavierInit | None |
| Pretrained | mmengine.model.weight_init.PretrainedInit | None |
diff --git a/docs/en/conf.py b/docs/en/conf.py index de24cfb121..c2b4961477 100644 --- a/docs/en/conf.py +++ b/docs/en/conf.py @@ -47,6 +47,7 @@ 'myst_parser', 'sphinx_copybutton', 'sphinx.ext.autodoc.typehints', + 'sphinx_tabs.tabs', ] # yapf: disable autodoc_typehints = 'description' myst_heading_anchors = 4 diff --git a/docs/zh_cn/advanced_tutorials/config.md b/docs/zh_cn/advanced_tutorials/config.md index f1f66580ba..f7b1440446 100644 --- a/docs/zh_cn/advanced_tutorials/config.md +++ b/docs/zh_cn/advanced_tutorials/config.md @@ -1,5 +1,30 @@ # 配置(Config) +- [配置(Config)](#配置config) + - [配置文件读取](#配置文件读取) + - [配置文件的使用](#配置文件的使用) + - [配置文件的继承](#配置文件的继承) + - [继承机制概述](#继承机制概述) + - [修改继承字段](#修改继承字段) + - [删除字典中的 key](#删除字典中的-key) + - [引用被继承文件中的变量](#引用被继承文件中的变量) + - [配置文件的导出](#配置文件的导出) + - [其他进阶用法](#其他进阶用法) + - [预定义字段](#预定义字段) + - [命令行修改配置](#命令行修改配置) + - [使用环境变量替换配置](#使用环境变量替换配置) + - [导入自定义 Python 模块](#导入自定义-python-模块) + - [跨项目继承配置文件](#跨项目继承配置文件) + - [跨项目获取配置文件](#跨项目获取配置文件) + - [纯 Python 风格的配置文件(Beta)](#纯-python-风格的配置文件beta) + - [基本语法](#基本语法) + - [模块构建](#模块构建) + - [继承](#继承) + - [配置文件的导出](#配置文件的导出-1) + - [什么是 lazy import](#什么是-lazy-import) + - [功能限制](#功能限制) + - [迁移指南](#迁移指南) + MMEngine 实现了抽象的配置类(Config),为用户提供统一的配置访问接口。配置类能够支持不同格式的配置文件,包括 `python`,`json`,`yaml`,用户可以根据需求选择自己偏好的格式。配置类提供了类似字典或者 Python 对象属性的访问接口,用户可以十分自然地进行配置字段的读取和修改。为了方便算法框架管理配置文件,配置类也实现了一些特性,例如配置文件的字段继承等。 在开始教程之前,我们先将教程中需要用到的配置文件下载到本地(建议在临时目录下执行,方便后续删除示例配置文件): @@ -25,6 +50,10 @@ wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/c wget https://raw.githubusercontent.com/open-mmlab/mmengine/main/docs/resources/config/modify_base_var.py ``` +```{note} +配置类支持两种风格的配置文件,即纯文本风格的配置文件和纯 Python 风格的配置文件(v0.8.0 的新特性),二者在调用接口统一的前提下各有特色。对于尚且不了解配置类基本用法用户,建议从[配置文件读取](#配置文件读取) 一节开始阅读,以了解配置类的功能和纯文本配置文件的语法。在一些情况下,纯文本风格的配置文件写法更加简洁,语法兼容性更好(`json`、`yaml` 通用)。如果你希望配置文件的写法可以更加灵活,建议阅读并使用[纯 Python 风格的配置文件](#纯-python-风格的配置文件beta)(beta) +``` + ## 配置文件读取 配置类提供了统一的接口 `Config.fromfile()`,来读取和解析配置文件。 @@ -368,8 
+397,16 @@ b=2 这里介绍一下配置类的进阶用法,这些小技巧可能使用户开发和使用算法库更简单方便。 +```{note} +需要注意的是,如果你用的是纯 Python 风格的配置文件,只有“命令行修改配置”一节中提到功能是有效的。 +``` + ### 预定义字段 +```{note} +该用法仅适用于非 `lazy_import` 模式,具体见纯 Python 风格的配置文件一节 +``` + 有时候我们希望配置文件中的一些字段和当前路径或者文件名等相关,这里举一个典型使用场景的例子。在训练模型时,我们会在配置文件中定义一个工作目录,存放这组实验配置的模型和日志,那么对于不同的配置文件,我们期望定义不同的工作目录。用户的一种常见选择是,直接使用配置文件名作为工作目录名的一部分,例如对于配置文件 `predefined_var.py`,工作目录就是 `./work_dir/predefined_var`。 使用预定义字段可以方便地实现这种需求,在配置文件 `predefined_var.py` 中可以这样写: @@ -698,3 +735,431 @@ print(cfg.model_path) ``` https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth ``` + +## 纯 Python 风格的配置文件(Beta) + +在之前的教程里,我们介绍了如何使用配置文件,搭配注册器来构建模块;如何使用 `_base_` 来继承配置文件。这些纯文本风格的配置文件固然能够满足我们平时开发的大部分需求,并且一些模块的 alias 也大大简化了配置文件(例如 `ResNet` 就能指代 `mmcls.models.ResNet`)。但是也存在一些弊端: + +1. 配置文件中,type 字段是通过字符串来指定的,在 IDE 中无法直接跳转到对应的类定义处,不利于代码阅读和跳转 +2. 配置文件的继承,也是通过字符串来指定的,IDE 无法直接跳转到被继承的文件中,当配置文件继承结构复杂时,不利于配置文件的阅读和跳转 +3. 继承规则较为隐式,初学者很难理解配置文件是如何对相同字段的变量进行融合,且衍生出 `_delete_` 这类特殊语法,学习成本较高 +4. 用户忘记注册模块时,容易发生 module not found 的 error +5. 在尚且没有提到的跨库继承中,scope 的引入导致配置文件的继承规则更加复杂,初学者很难理解 + +综上所述,尽管纯文本风格的配置文件能够为 `python`、`json`、`yaml` 格式的配置提供相同的语法规则,但是当配置文件变得复杂时,纯文本风格的配置文件会显得力不从心。为此,我们提供了纯 Python 风格的配置文件,即 `lazy import` 模式,它能够充分利用 Python 的语法规则,解决上述问题。与此同时,纯 Python 风格的配置文件也支持导出成 `json` 和 `yaml` 格式。 + +### 基本语法 + +之前的教程分别介绍了基于纯文本风格配置文件的模块构建、继承和导出,本节将基于这三个方面来介绍纯 Python 风格的配置文件。 + +#### 模块构建 + +我们通过一个简单的例子来对比纯 Python 风格和纯文本风格的配置文件: + +```{eval-rst} +.. tabs:: + .. tabs:: + + .. code-tab:: python 纯 Python 风格 + + # 无需注册 + + .. code-tab:: python 纯文本风格 + + # 注册流程 + from torch.optim import SGD + from mmengine.registry import OPTIMIZERS + + OPTIMIZERS.register_module(module=SGD, name='SGD') + + .. tabs:: + + .. code-tab:: python 纯 Python 风格 + + # 配置文件写法 + from torch.optim import SGD + + + optimizer = dict(type=SGD, lr=0.1) + + .. code-tab:: python 纯文本风格 + + # 配置文件写法 + optimizer = dict(type='SGD', lr=0.1) + + .. tabs:: + + .. 
code-tab:: python 纯 Python 风格 + + # 构建流程完全一致 + import torch.nn as nn + from mmengine.registry import OPTIMIZERS + + + cfg = Config.fromfile('optimizer.py') + model = nn.Conv2d(1, 1, 1) + cfg.optimizer.params = model.parameters() + optimizer = OPTIMIZERS.build(cfg.optimizer) + + .. code-tab:: python 纯文本风格 + + # 构建流程完全一致 + import torch.nn as nn + from mmengine.registry import OPTIMIZERS + + + cfg = Config.fromfile('optimizer.py') + model = nn.Conv2d(1, 1, 1) + cfg.optimizer.params = model.parameters() + optimizer = OPTIMIZERS.build(cfg.optimizer) +``` + +从上面的例子可以看出,纯 Python 风格的配置文件和纯文本风格的配置文件的区别在于: + +1. 纯 Python 风格的配置文件无需注册模块 +2. 纯 Python 风格的配置文件中,type 字段不再是字符串,而是直接指代模块。相应的配置文件需要多出 import 语法 + +需要注意的是,OpenMMLab 系列算法库在新增模块时仍会保留注册过程,用户基于 MMEngine 构建自己的项目时,如果使用纯 Python 风格的配置文件,则无需注册。看到这你会或许会好奇,这样没有安装 PyTorch 的环境不就没法解析样例配置文件了么,这样的配置文件还叫配置文件么?不要着急,这部分的内容我们会在后面介绍。 + +#### 继承 + +纯 Python 风格的配置文件继承语法有所不同: + +```{eval-rst} +.. tabs:: + + .. code-tab:: python 纯 Python 风格继承 + + if '_base_': + from .optimizer import * + + .. 
code-tab:: python 纯文本风格继承 + + _base_ = ['./optimizer.py'] + +``` + +纯 Python 风格的配置文件通过 import 语法来实现继承,这样做的好处是,我们可以直接跳转到被继承的配置文件中,方便阅读和跳转。变量的继承规则(增删改查)完全对齐 Python 语法,例如我想修改 base 配置文件中 optimizer 的学习率: + +```python +if '_base_': + from .optimizer import * + +# optimizer 为 base 配置文件定义的变量 +optimizer.update( + lr=0.01, +) +``` + +当然了,如果你已经习惯了纯文本风格的继承规则,且该变量在 _base_ 配置文件中为 `dict` 类型,也可以通过 merge 语法来实现和纯文本风格配置文件一致的继承规则: + +```python +if '_base_': + from .optimizer import * + +# optimizer 为 base 配置文件定义的变量 +optimizer.merge( + _delete_=True, + lr=0.01, + type='SGD' +) + +# 等价的 python 风格写法如下,与 Python 的 import 规则完全一致 +# optimizer = dict( +# lr=0.01, +# type='SGD' +# ) +``` + +````{note} +需要注意的是,纯 Python 风格的配置文件中,字典的 `update` 方法与 `dict.update` 稍有不同。纯 Python 风格的 update 会递归地去更新字典中的内容,例如: + +```python +x = dict(a=1, b=dict(c=2, d=3)) + +x.update(dict(b=dict(d=4))) +# 配置文件中的 update 规则: +# {a: 1, b: {c: 2, d: 4}} +# 普通 dict 的 update 规则: +# {a: 1, b: {d: 4}} +``` + +可见在配置文件中使用 update 方法会递归地去更新字段,而不是简单的覆盖。 +```` + +与纯文本风格的配置文件相比,纯 Python 风格的配置文件的继承规则完全对齐 import 语法,更容易理解,且支持配置文件之间的跳转。你或许会好奇既然继承和模块的导入都使用了 import 语法,为什么继承配置文件还需要额外的 `if '_base_'` 语句呢?一方面这样可以提升配置文件的可读性,可以让继承的配置文件更加突出,另一方面也是受限于 lazy_import 的规则,这个会在后面讲到。 + +#### 配置文件的导出 + +纯 python 风格配置文件也通过 dump 接口导出,使用上没有任何区别,但是导出的内容会有所不同: + +```{eval-rst} +.. tabs:: + + .. tabs:: + + .. code-tab:: python 纯 Python 风格导出 + + optimizer = dict(type='torch.optim.SGD', lr=0.1) + + .. code-tab:: python 纯文本风格导出 + + optimizer = dict(type='SGD', lr=0.1) + + .. tabs:: + + .. code-tab:: yaml 纯 Python 风格导出 + + optimizer: + type: torch.optim.SGD + lr: 0.1 + + .. code-tab:: yaml 纯文本风格导出 + + optimizer: + type: SGD + lr: 0.1 + + .. tabs:: + + .. code-tab:: json 纯 Python 风格导出 + + {"optimizer": {"type": "torch.optim.SGD", "lr": 0.1}} + + ..
code-tab:: json 纯文本风格导出 + + {"optimizer": {"type": "SGD", "lr": 0.1}} +``` + +可以看到,纯 Python 风格导出的 type 字段会包含模块的全量信息。导出的配置文件也可以被直接加载,通过注册器来构建实例。 + +### 什么是 lazy import + +看到这你可能会吐槽,这纯 Python 风格的配置文件感觉就像是用纯 Python 语法来组织配置文件嘛。这样我哪还需要配置类,直接用 Python 语法来导入配置文件不就好了。如果你有这样的感受,那真是一件值得庆祝的事,因为这正是我们想要的效果。 + +正如前面所提到的,解析配置文件需要依赖配置文件中引用的三方库,这其实是一件非常不合理的事。例如我基于 MMagic 训练了一个模型,想使用 MMDeploy 的 onnxruntime 后端部署。由于部署环境中没有 torch,而配置文件解析过程中需要 torch,这就导致了我无法直接使用 MMagic 的配置文件作为部署的配置,这是非常不方便的。为了解决这个问题,我们引入了 lazy_import 的概念。 + +要聊 lazy_import 的具体实现是一件比较复杂的事,在此我们仅对其功能做简要介绍。lazy_import 的核心思想是,将配置文件中的 import 语句延迟到配置文件被解析时才执行,这样就可以避免配置文件中的 import 语句导致的三方库依赖问题。配置文件解析过程时,Python 解释器实际执行的等效代码如下 + +```{eval-rst} +.. tabs:: + .. code-tab:: python 原始配置文件 + + from torch.optim import SGD + + + optimizer = dict(type=SGD) + + .. code-tab:: python 通过配置类,Python 解释器实际执行的代码 + + lazy_obj = LazyObject('torch.optim', 'SGD') + + optimizer = dict(type=lazy_obj) +``` + +LazyObject 作为 `Config` 模块的內部类型,无法被用户直接访问。用户在访问 type 字段时,会经过一系列的转换,将 `LazyObject` 转化成真正的 `torch.optim.SGD` 类型。这样一来,配置文件的解析不会触发三方库的导入,而用户使用配置文件时,又可以正常访问三方库的类型。 + +要想访问 `LazyObject` 的内部类型,可以通过 `Config.to_dict` 接口: + +```python +cfg = Config.fromfile('optimizer.py').to_dict() +print(type(cfg['optimizer']['type'])) +# mmengine.config.lazy.LazyObject +``` + +此时得到的 type 就是 `LazyObject` 类型。 + +然而对于 base 文件的继承(导入,import),我们不能够采取 lazy import 的策略,这是因为我们希望解析后的配置文件能够包含 base 配置文件定义的字段,需要真正的触发 import。因此我们对 base 文件的导入加了一层限制,即必须在 `if '_base_'` 的代码块中导入。 + +### 功能限制 + +1. 不能在配置文件中定义函数、类等 +2. 配置文件名必须符合 Python 模块名的命名规范,即只能包含字母、数字、下划线,且不能以数字开头 +3. 导入 base 配置文件中的变量时,例如 `from ._base_.alpha import beta`,此处的 `alpha` 必须是模块(module)名,即 Python 文件,而不能是含有 `__init__.py` 的包(package)名 +4. 不支持在 absolute import 语句中同时导入多个变量,例如 `import torch, numpy, os`。需要通过多个 import 语句来实现,例如 `import torch; import numpy; import os` + +### 迁移指南 + +从纯文本风格的配置文件迁移到纯 Python 风格的配置文件,需要遵守以下规则: + +1.
type 从字符串替换成具体的类: + + - 代码不依赖 type 字段是字符串,且没有对 type 字段做特殊处理,则可以将字符串类型的 type 替换成具体的类,并在配置文件的开头导入该类 + - 代码依赖 type 字段是字符串,则需要修改代码,或保持原有的字符串格式的 type + +2. 重命名配置文件,配置文件命名需要符合 Python 模块名的命名规范,即只能包含字母、数字、下划线,且不能以数字开头 + +3. 删除 scope 相关配置。纯 Python 风格的配置文件不再需要通过 scope 来跨库调用模块,直接通过 import 导入即可。出于兼容性方面的考虑,我们仍然让 Runner 的 default_scope 参数为 `mmengine`,用户需要将其手动设置为 `None` + +4. 对于注册器中存在别名的(alias)的模块,将其别名替换成其对应的真实模块即可,以下是常用的别名替换表: + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
| 别名 | 模块 | 注意事项 |
| --- | --- | --- |
| nearest | torch.nn.modules.upsampling.Upsample | 将 type 替换成 Upsample 后,需要额外将 mode 参数指定为 'nearest' |
| bilinear | torch.nn.modules.upsampling.Upsample | 将 type 替换成 Upsample 后,需要额外将 mode 参数指定为 'bilinear' |
| Clip | mmcv.cnn.bricks.activation.Clamp | |
| Conv | mmcv.cnn.bricks.wrappers.Conv2d | |
| BN | torch.nn.modules.batchnorm.BatchNorm2d | |
| BN1d | torch.nn.modules.batchnorm.BatchNorm1d | |
| BN2d | torch.nn.modules.batchnorm.BatchNorm2d | |
| BN3d | torch.nn.modules.batchnorm.BatchNorm3d | |
| SyncBN | torch.nn.SyncBatchNorm | |
| GN | torch.nn.modules.normalization.GroupNorm | |
| LN | torch.nn.modules.normalization.LayerNorm | |
| IN | torch.nn.modules.instancenorm.InstanceNorm2d | |
| IN1d | torch.nn.modules.instancenorm.InstanceNorm1d | |
| IN2d | torch.nn.modules.instancenorm.InstanceNorm2d | |
| IN3d | torch.nn.modules.instancenorm.InstanceNorm3d | |
| zero | torch.nn.modules.padding.ZeroPad2d | |
| reflect | torch.nn.modules.padding.ReflectionPad2d | |
| replicate | torch.nn.modules.padding.ReplicationPad2d | |
| ConvWS | mmcv.cnn.bricks.conv_ws.ConvWS2d | |
| ConvAWS | mmcv.cnn.bricks.conv_ws.ConvAWS2d | |
| HSwish | torch.nn.modules.activation.Hardswish | |
| pixel_shuffle | mmcv.cnn.bricks.upsample.PixelShufflePack | |
| deconv | mmcv.cnn.bricks.wrappers.ConvTranspose2d | |
| deconv3d | mmcv.cnn.bricks.wrappers.ConvTranspose3d | |
| Constant | mmengine.model.weight_init.ConstantInit | |
| Xavier | mmengine.model.weight_init.XavierInit | |
| Normal | mmengine.model.weight_init.NormalInit | |
| TruncNormal | mmengine.model.weight_init.TruncNormalInit | |
| Uniform | mmengine.model.weight_init.UniformInit | |
| Kaiming | mmengine.model.weight_init.KaimingInit | |
| Caffe2Xavier | mmengine.model.weight_init.Caffe2XavierInit | |
| Pretrained | mmengine.model.weight_init.PretrainedInit | |
diff --git a/docs/zh_cn/conf.py b/docs/zh_cn/conf.py index dbd2a8223e..ad611187f9 100644 --- a/docs/zh_cn/conf.py +++ b/docs/zh_cn/conf.py @@ -52,6 +52,7 @@ 'myst_parser', 'sphinx_copybutton', 'sphinx.ext.autodoc.typehints', + 'sphinx_tabs.tabs', ] # yapf: disable autodoc_typehints = 'description' myst_heading_anchors = 4 diff --git a/mmengine/config/config.py b/mmengine/config/config.py index 71d1bf68f6..614f9a8e12 100644 --- a/mmengine/config/config.py +++ b/mmengine/config/config.py @@ -22,8 +22,11 @@ from mmengine.logging import print_log from mmengine.utils import (check_file_exist, get_installed_path, import_modules_from_strings, is_installed) -from .utils import (RemoveAssignFromAST, _get_external_cfg_base_path, - _get_external_cfg_path, _get_package_and_cfg_path) +from .lazy import LazyAttr, LazyObject +from .utils import (ImportTransformer, RemoveAssignFromAST, + _gather_abs_import_lazyobj, _get_external_cfg_base_path, + _get_external_cfg_path, _get_package_and_cfg_path, + _is_builtin_module) BASE_KEY = '_base_' DELETE_KEY = '_delete_' @@ -42,7 +45,41 @@ class ConfigDict(Dict): The Config class would transform the nested fields (dictionary-like fields) in config file into ``ConfigDict``. + + If the class attribute ``lazy`` is ``False``, users will get the + object built by ``LazyObject`` or ``LazyAttr``, otherwise users will get + the ``LazyObject`` or ``LazyAttr`` itself. + + The ``lazy`` should be set to ``True`` to avoid building the imported + object during configuration parsing, and it should be set to False outside + the Config to ensure that users do not experience the ``LazyObject``. 
""" + lazy = False + + def __init__(__self, *args, **kwargs): + object.__setattr__(__self, '__parent', kwargs.pop('__parent', None)) + object.__setattr__(__self, '__key', kwargs.pop('__key', None)) + object.__setattr__(__self, '__frozen', False) + for arg in args: + if not arg: + continue + # Since ConfigDict.items will convert LazyObject to real object + # automatically, we need to call super().items() to make sure + # the LazyObject will not be converted. + if isinstance(arg, ConfigDict): + for key, val in dict.items(arg): + __self[key] = __self._hook(val) + elif isinstance(arg, dict): + for key, val in arg.items(): + __self[key] = __self._hook(val) + elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)): + __self[arg[0]] = __self._hook(arg[1]) + else: + for key, val in iter(arg): + __self[key] = __self._hook(val) + + for key, val in dict.items(kwargs): + __self[key] = __self._hook(val) def __missing__(self, name): raise KeyError(name) @@ -50,6 +87,8 @@ def __missing__(self, name): def __getattr__(self, name): try: value = super().__getattr__(name) + if isinstance(value, (LazyAttr, LazyObject)) and not self.lazy: + value = value.build() except KeyError: raise AttributeError(f"'{self.__class__.__name__}' object has no " f"attribute '{name}'") @@ -58,6 +97,150 @@ def __getattr__(self, name): else: return value + def __setattr__(self, name, value): + value = self._hook(value) + return super().__setattr__(name, value) + + def __setitem__(self, name, value): + value = self._hook(value) + return super().__setitem__(name, value) + + def __getitem__(self, key): + return self.build_lazy(super().__getitem__(key)) + + def __deepcopy__(self, memo): + other = self.__class__() + memo[id(self)] = other + for key, value in super().items(): + other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo) + return other + + def get(self, key: str, default: Optional[Any] = None) -> Any: + """Get the value of the key. 
If class attribute ``lazy`` is True, the + LazyObject will be built and returned. + + Args: + key (str): The key. + default (any, optional): The default value. Defaults to None. + + Returns: + Any: The value of the key. + """ + return self.build_lazy(super().get(key, default)) + + def pop(self, key, default=None): + """Pop the value of the key. If class attribute ``lazy`` is True, the + LazyObject will be built and returned. + + Args: + key (str): The key. + default (any, optional): The default value. Defaults to None. + + Returns: + Any: The value of the key. + """ + return self.build_lazy(super().pop(key, default)) + + def update(self, *args, **kwargs) -> None: + """Override this method to make sure the LazyObject will not be built + during updating.""" + other = {} + if args: + if len(args) > 1: + raise TypeError('update only accept one positional argument') + # Avoid to used self.items to build LazyObject + for key, value in dict.items(args[0]): + other[key] = value + + for key, value in dict(kwargs).items(): + other[key] = value + for k, v in other.items(): + if ((k not in self) or (not isinstance(self[k], dict)) + or (not isinstance(v, dict))): + self[k] = self._hook(v) + else: + self[k].update(v) + + def build_lazy(self, value: Any) -> Any: + """If class attribute ``lazy`` is False, the LazyObject will be built + and returned. + + Args: + value (Any): The value to be built. + + Returns: + Any: The built value. + """ + if isinstance(value, (LazyAttr, LazyObject)) and not self.lazy: + value = value.build() + return value + + def values(self): + """Yield the values of the dictionary. + + If class attribute ``lazy`` is False, the value of ``LazyObject`` or + ``LazyAttr`` will be built and returned. + """ + for value in super().values(): + yield self.build_lazy(value) + + def items(self): + """Yield the keys and values of the dictionary. + + If class attribute ``lazy`` is False, the value of ``LazyObject`` or + ``LazyAttr`` will be built and returned. 
+ """ + for key, value in super().items(): + yield key, self.build_lazy(value) + + def merge(self, other: dict): + """Merge another dictionary into current dictionary. + + Args: + other (dict): Another dictionary. + """ + default = object() + + def _merge_a_into_b(a, b): + if isinstance(a, dict): + if not isinstance(b, dict): + a.pop(DELETE_KEY, None) + return a + if a.pop(DELETE_KEY, False): + b.clear() + all_keys = list(b.keys()) + list(a.keys()) + return { + key: + _merge_a_into_b(a.get(key, default), b.get(key, default)) + for key in all_keys if key != DELETE_KEY + } + else: + return a if a is not default else b + + merged = _merge_a_into_b(copy.deepcopy(other), copy.deepcopy(self)) + self.clear() + for key, value in merged.items(): + self[key] = value + + def to_dict(self): + """Convert the ConfigDict to a normal dictionary recursively, and keep + the ``LazyObject`` or ``LazyAttr`` object not built.""" + + def _to_dict(data): + if isinstance(data, ConfigDict): + return { + key: _to_dict(value) + for key, value in Dict.items(data) + } + elif isinstance(data, dict): + return {key: _to_dict(value) for key, value in data.items()} + elif isinstance(data, (list, tuple)): + return type(data)(_to_dict(item) for item in data) + else: + return data + + return _to_dict(self) + def add_args(parser: ArgumentParser, cfg: dict, @@ -109,6 +292,8 @@ class Config: filename (str or Path, optional): Name of config file. Defaults to None. + Here is a simple example: + Examples: >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1]))) >>> cfg.a @@ -126,7 +311,11 @@ class Config: "Config [path: /home/username/projects/mmengine/tests/data/config/a.py] :" "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}" - """ + + You can find more advance usage in the `config tutorial`_. + + .. 
_config tutorial: https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html + """ # noqa: E501 def __init__(self, cfg_dict: dict = None, @@ -143,7 +332,9 @@ def __init__(self, if key in RESERVED_KEYS: raise KeyError(f'{key} is reserved for config file') - super().__setattr__('_cfg_dict', ConfigDict(cfg_dict)) + if not isinstance(cfg_dict, ConfigDict): + cfg_dict = ConfigDict(cfg_dict) + super().__setattr__('_cfg_dict', cfg_dict) super().__setattr__('_filename', filename) if cfg_text: text = cfg_text @@ -161,7 +352,8 @@ def __init__(self, def fromfile(filename: Union[str, Path], use_predefined_variables: bool = True, import_custom_modules: bool = True, - use_environment_variables: bool = True) -> 'Config': + use_environment_variables: bool = True, + lazy_import: Optional[bool] = None) -> 'Config': """Build a Config instance from config file. Args: @@ -169,32 +361,61 @@ def fromfile(filename: Union[str, Path], use_predefined_variables (bool, optional): Whether to use predefined variables. Defaults to True. import_custom_modules (bool, optional): Whether to support - importing custom modules in config. Defaults to True. + importing custom modules in config. Defaults to None. + lazy_import (bool): Whether to load config in `lazy_import` mode. + If it is `None`, it will be deduced by the content of the + config file. Defaults to None. Returns: Config: Config instance built from config file. 
""" filename = str(filename) if isinstance(filename, Path) else filename - cfg_dict, cfg_text, env_variables = Config._file2dict( - filename, use_predefined_variables, use_environment_variables) - if import_custom_modules and cfg_dict.get('custom_imports', None): + if lazy_import is None: + lazy_import = Config._is_lazy_import(filename) + if not lazy_import: + cfg_dict, cfg_text, env_variables = Config._file2dict( + filename, use_predefined_variables, use_environment_variables) + if import_custom_modules and cfg_dict.get('custom_imports', None): + try: + import_modules_from_strings(**cfg_dict['custom_imports']) + except ImportError as e: + err_msg = ( + 'Failed to import custom modules from ' + f"{cfg_dict['custom_imports']}, the current sys.path " + 'is: ') + for p in sys.path: + err_msg += f'\n {p}' + err_msg += ( + '\nYou should set `PYTHONPATH` to make `sys.path` ' + 'include the directory which contains your custom ' + 'module') + raise ImportError(err_msg) from e + return Config( + cfg_dict, + cfg_text=cfg_text, + filename=filename, + env_variables=env_variables) + else: + # Enable lazy import when parsing the config. + # Using try-except to make sure ``ConfigDict.lazy`` will be reset + # to False. 
See more details about lazy in the docstring of + # ConfigDict + ConfigDict.lazy = True try: - import_modules_from_strings(**cfg_dict['custom_imports']) - except ImportError as e: - err_msg = ( - 'Failed to import custom modules from ' - f"{cfg_dict['custom_imports']}, the current sys.path is: ") - for p in sys.path: - err_msg += f'\n {p}' - err_msg += ( - '\nYou should set `PYTHONPATH` to make `sys.path` include ' - 'the directory which contains your custom module') - raise ImportError(err_msg) from e - return Config( - cfg_dict, - cfg_text=cfg_text, - filename=filename, - env_variables=env_variables) + cfg_dict, imported_names = Config._parse_lazy_import(filename) + except Exception as e: + raise e + finally: + ConfigDict.lazy = False + for key, value in list(cfg_dict.to_dict().items()): + if isinstance(value, (types.FunctionType, types.ModuleType)): + cfg_dict.pop(key) + + # disable lazy import to get the real type. See more details about + # lazy in the docstring of ConfigDict + cfg = Config(cfg_dict, filename=filename) + object.__setattr__(cfg, '_imported_names', imported_names) + return cfg @staticmethod def fromstring(cfg_str: str, file_format: str) -> 'Config': @@ -232,6 +453,77 @@ def fromstring(cfg_str: str, file_format: str) -> 'Config': os.remove(temp_file.name) # manually delete the temporary file return cfg + @staticmethod + def _get_base_modules(nodes: list) -> list: + """Get base module name from parsed code. + + Args: + nodes (list): Parsed code of the config file. + + Returns: + list: Name of base modules. + """ + + def _get_base_module_from_if(if_nodes: list) -> list: + """Get base module name from if statement in python file. + + Args: + if_nodes (list): List of if statement. + + Returns: + list: Name of base modules. + """ + base_modules = [] + for node in if_nodes: + assert isinstance(node, ast.ImportFrom), ( + 'Illegal syntax in config file! Only ' + '`from ... 
import ...` could be implemented` in ' + '`if _base_`') + assert node.module is not None, ( + 'Illegal syntax in config file! Syntax like ' + '`from . import xxx` is not allowed in `if _base_` ') + base_modules.append(node.level * '.' + node.module) + return base_modules + + for idx, node in enumerate(nodes): + if (isinstance(node, ast.Assign) + and isinstance(node.targets[0], ast.Name) + and node.targets[0].id == BASE_KEY): + raise RuntimeError( + 'The configuration file type in the inheritance chain ' + 'must match the current configuration file type, either ' + '"lazy_import" or non-"lazy_import". You got this error ' + f'since you use the syntax like `_base_ = "{node.targets[0].id}"` ' # noqa: E501 + 'in your config. You should use `if "_base_": ... to` ' + 'mark the inherited config file. See more information ' + 'in https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html' # noqa: E501 + ) + + if not isinstance(node, ast.If): + continue + value = node.test + if isinstance(value, ast.Constant) and not value.value == BASE_KEY: + continue + if isinstance(value, ast.Str) and not value.s == BASE_KEY: + continue + + # The original code: + # ``` + # if _base_: + # from .._base_.default_runtime import * + # ``` + # The processed code: + # ``` + # from .._base_.default_runtime import * + # ``` + # As you can see, the if statement is removed and the + # from ... import statement will be unindent + for nested_idx, nested_node in enumerate(node.body): + nodes.insert(idx + nested_idx + 1, nested_node) + nodes.pop(idx) + return _get_base_module_from_if(node.body) + return [] + @staticmethod def _validate_py_syntax(filename: str): """Validate syntax of python config. @@ -454,89 +746,107 @@ def _file2dict( Returns: Tuple[dict, str]: Variables dictionary and text of Config. 
""" + if Config._is_lazy_import(filename): + raise RuntimeError( + 'The configuration file type in the inheritance chain ' + 'must match the current configuration file type, either ' + '"lazy_import" or non-"lazy_import". You got this error ' + 'since you use the syntax like `if "_base_": ...` ' + f'or import non-builtin module in {filename}. See more ' + 'information in https://mmengine.readthedocs.io/en/latest/advanced_tutorials/config.html' # noqa: E501 + ) + filename = osp.abspath(osp.expanduser(filename)) check_file_exist(filename) fileExtname = osp.splitext(filename)[1] if fileExtname not in ['.py', '.json', '.yaml', '.yml']: raise OSError('Only py/yml/yaml/json type are supported now!') + try: + with tempfile.TemporaryDirectory() as temp_config_dir: + temp_config_file = tempfile.NamedTemporaryFile( + dir=temp_config_dir, suffix=fileExtname, delete=False) + if platform.system() == 'Windows': + temp_config_file.close() + + # Substitute predefined variables + if use_predefined_variables: + Config._substitute_predefined_vars(filename, + temp_config_file.name) + else: + shutil.copyfile(filename, temp_config_file.name) + # Substitute environment variables + env_variables = dict() + if use_environment_variables: + env_variables = Config._substitute_env_variables( + temp_config_file.name, temp_config_file.name) + # Substitute base variables from placeholders to strings + base_var_dict = Config._pre_substitute_base_vars( + temp_config_file.name, temp_config_file.name) - with tempfile.TemporaryDirectory() as temp_config_dir: - temp_config_file = tempfile.NamedTemporaryFile( - dir=temp_config_dir, suffix=fileExtname) - if platform.system() == 'Windows': + # Handle base files + base_cfg_dict = ConfigDict() + cfg_text_list = list() + for base_cfg_path in Config._get_base_files( + temp_config_file.name): + base_cfg_path, scope = Config._get_cfg_path( + base_cfg_path, filename) + _cfg_dict, _cfg_text, _env_variables = Config._file2dict( + filename=base_cfg_path, + 
use_predefined_variables=use_predefined_variables, + use_environment_variables=use_environment_variables) + cfg_text_list.append(_cfg_text) + env_variables.update(_env_variables) + duplicate_keys = base_cfg_dict.keys() & _cfg_dict.keys() + if len(duplicate_keys) > 0: + raise KeyError( + 'Duplicate key is not allowed among bases. ' + f'Duplicate keys: {duplicate_keys}') + + # _dict_to_config_dict will do the following things: + # 1. Recursively converts ``dict`` to :obj:`ConfigDict`. + # 2. Set `_scope_` for the outer dict variable for the base + # config. + # 3. Set `scope` attribute for each base variable. + # Different from `_scope_`, `scope` is not a key of base + # dict, `scope` attribute will be parsed to key `_scope_` + # by function `_parse_scope` only if the base variable is + # accessed by the current config. + _cfg_dict = Config._dict_to_config_dict(_cfg_dict, scope) + base_cfg_dict.update(_cfg_dict) + + if filename.endswith('.py'): + with open(temp_config_file.name, encoding='utf-8') as f: + parsed_codes = ast.parse(f.read()) + parsed_codes = RemoveAssignFromAST(BASE_KEY).visit( + parsed_codes) + codeobj = compile(parsed_codes, '', mode='exec') + # Support load global variable in nested function of the + # config. 
+ global_locals_var = {BASE_KEY: base_cfg_dict} + ori_keys = set(global_locals_var.keys()) + eval(codeobj, global_locals_var, global_locals_var) + cfg_dict = { + key: value + for key, value in global_locals_var.items() + if (key not in ori_keys and not key.startswith('__')) + } + elif filename.endswith(('.yml', '.yaml', '.json')): + cfg_dict = load(temp_config_file.name) + # close temp file + for key, value in list(cfg_dict.items()): + if isinstance(value, + (types.FunctionType, types.ModuleType)): + cfg_dict.pop(key) temp_config_file.close() - # Substitute predefined variables - if use_predefined_variables: - Config._substitute_predefined_vars(filename, - temp_config_file.name) - else: - shutil.copyfile(filename, temp_config_file.name) - # Substitute environment variables - env_variables = dict() - if use_environment_variables: - env_variables = Config._substitute_env_variables( - temp_config_file.name, temp_config_file.name) - # Substitute base variables from placeholders to strings - base_var_dict = Config._pre_substitute_base_vars( - temp_config_file.name, temp_config_file.name) - - # Handle base files - base_cfg_dict = ConfigDict() - cfg_text_list = list() - for base_cfg_path in Config._get_base_files(temp_config_file.name): - base_cfg_path, scope = Config._get_cfg_path( - base_cfg_path, filename) - _cfg_dict, _cfg_text, _env_variables = Config._file2dict( - filename=base_cfg_path, - use_predefined_variables=use_predefined_variables, - use_environment_variables=use_environment_variables) - cfg_text_list.append(_cfg_text) - env_variables.update(_env_variables) - duplicate_keys = base_cfg_dict.keys() & _cfg_dict.keys() - if len(duplicate_keys) > 0: - raise KeyError('Duplicate key is not allowed among bases. ' - f'Duplicate keys: {duplicate_keys}') - - # _dict_to_config_dict will do the following things: - # 1. Recursively converts ``dict`` to :obj:`ConfigDict`. - # 2. Set `_scope_` for the outer dict variable for the base - # config. - # 3. 
Set `scope` attribute for each base variable. Different - # from `_scope_`, `scope` is not a key of base dict, - # `scope` attribute will be parsed to key `_scope_` by - # function `_parse_scope` only if the base variable is - # accessed by the current config. - _cfg_dict = Config._dict_to_config_dict(_cfg_dict, scope) - base_cfg_dict.update(_cfg_dict) - - if filename.endswith('.py'): - with open(temp_config_file.name, encoding='utf-8') as f: - codes = ast.parse(f.read()) - codes = RemoveAssignFromAST(BASE_KEY).visit(codes) - codeobj = compile(codes, '', mode='exec') - # Support load global variable in nested function of the - # config. - global_locals_var = {'_base_': base_cfg_dict} - ori_keys = set(global_locals_var.keys()) - eval(codeobj, global_locals_var, global_locals_var) - cfg_dict = { - key: value - for key, value in global_locals_var.items() - if (key not in ori_keys and not key.startswith('__')) - } - elif filename.endswith(('.yml', '.yaml', '.json')): - cfg_dict = load(temp_config_file.name) - # close temp file - for key, value in list(cfg_dict.items()): - if isinstance(value, (types.FunctionType, types.ModuleType)): - cfg_dict.pop(key) - temp_config_file.close() - - # If the current config accesses a base variable of base - # configs, The ``scope`` attribute of corresponding variable - # will be converted to the `_scope_`. - Config._parse_scope(cfg_dict) + # If the current config accesses a base variable of base + # configs, The ``scope`` attribute of corresponding variable + # will be converted to the `_scope_`. + Config._parse_scope(cfg_dict) + except Exception as e: + if osp.exists(temp_config_dir): + shutil.rmtree(temp_config_dir) + raise e # check deprecation information if DEPRECATION_KEY in cfg_dict: @@ -573,6 +883,172 @@ def _file2dict( return cfg_dict, cfg_text, env_variables + @staticmethod + def _parse_lazy_import(filename: str) -> Tuple[ConfigDict, set]: + """Transform file to variables dictionary. 
+ + Args: + filename (str): Name of config file. + + Returns: + Tuple[ConfigDict, set]: ``cfg_dict`` and ``imported_names``. + + - cfg_dict (ConfigDict): Variables dictionary of parsed config. + - imported_names (set): Used to mark the names of + imported objects. + """ + # In lazy import mode, users can use the Python syntax `import` to + # implement inheritance between configuration files, which makes it + # easier to understand the hierarchical relationships between + # different configuration files. + + # Besides, users can also use the `import` syntax to import the + # module which will be filled in the `type` field. It means users + # can directly navigate to the source of the module in the + # configuration file by clicking the `type` field. + + # To avoid really importing the third party package like `torch` + # while importing the `type` object, we use `_parse_lazy_import` to + # parse the configuration file, which will not actually trigger the + # import process, but simply parse the imported `type`s as LazyObject + # objects. + + # The overall pipeline of _parse_lazy_import is: + # 1. Parse the base module from the config file. + # || + # \/ + # base_module = ['mmdet.configs.default_runtime'] + # || + # \/ + # 2. Recursively parse the base module and gather imported objects to + # a dict. + # || + # \/ + # The base_dict will be: + # { + # 'mmdet.configs.default_runtime': {...} + # 'mmdet.configs.retinanet_r50_fpn_1x_coco': {...} + # ... + # }, each item in base_dict is a dict of `LazyObject` + # 3. Parse the current config file, filling the imported variables + # with the base_dict. + # + # 4. During the parsing process, all imported variables will be + # recorded in the `imported_names` set. These variables can be + # accessed, but will not be dumped by default. 
+ + with open(filename, encoding='utf-8') as f: + global_dict = {'LazyObject': LazyObject} + base_dict = {} + + parsed_codes = ast.parse(f.read()) + # get the names of base modules, and remove the + # `if '_base_:'` statement + base_modules = Config._get_base_modules(parsed_codes.body) + base_imported_names = set() + for base_module in base_modules: + # If base_module means a relative import, assuming the level is + # 2, which means the module is imported like + # "from ..a.b import c". we must ensure that c is an + # object `defined` in module b, and module b should not be a + # package including `__init__` file but a single python file. + level = len(re.match(r'\.*', base_module).group()) + if level > 0: + # Relative import + base_dir = osp.dirname(filename) + module_path = osp.join( + base_dir, *(['..'] * (level - 1)), + f'{base_module[level:].replace(".", "/")}.py') + else: + # Absolute import + module_list = base_module.split('.') + if len(module_list) == 1: + raise RuntimeError( + 'The imported configuration file should not be ' + f'an independent package {module_list[0]}. Here ' + 'is an example: ' + '"_base_ = mmdet.configs.retinanet_r50_fpn_1x_coco"' # noqa: E501 + ) + else: + package = module_list[0] + root_path = get_installed_path(package) + module_path = f'{osp.join(root_path, *module_list[1:])}.py' # noqa: E501 + if not osp.isfile(module_path): + raise FileNotFoundError( + f'{module_path} not found! It means that incorrect ' + 'module is defined in ' + f"_base_ = ['{base_module}', ...], please " + 'make sure the base config module is valid ' + 'and is consistent with the prior import ' + 'logic') + _base_cfg_dict, _base_imported_names = Config._parse_lazy_import( # noqa: E501 + module_path) + base_imported_names |= _base_imported_names + # The base_dict will be: + # { + # 'mmdet.configs.default_runtime': {...} + # 'mmdet.configs.retinanet_r50_fpn_1x_coco': {...} + # ... 
+ # } + base_dict[base_module] = _base_cfg_dict + + # `base_dict` contains all the imported modules from `base_cfg`. + # In order to collect the specific imported module from `base_cfg` + # before parse the current file, we using AST Transform to + # transverse the imported module from base_cfg and merge then into + # the global dict. After the ast transformation, most of import + # syntax will be removed (except for the builtin import) and + # replaced with the `LazyObject` + transform = ImportTransformer( + global_dict=global_dict, + base_dict=base_dict, + filename=filename) + modified_code = transform.visit(parsed_codes) + modified_code, abs_imported = _gather_abs_import_lazyobj( + modified_code, filename=filename) + imported_names = transform.imported_obj | abs_imported + imported_names |= base_imported_names + modified_code = ast.fix_missing_locations(modified_code) + exec( + compile(modified_code, filename, mode='exec'), global_dict, + global_dict) + + ret: dict = {} + for key, value in global_dict.items(): + if key.startswith('__') or key in ['LazyObject']: + continue + ret[key] = value + # convert dict to ConfigDict + cfg_dict = Config._dict_to_config_dict_lazy(ret) + + return cfg_dict, imported_names + + @staticmethod + def _dict_to_config_dict_lazy(cfg: dict): + """Recursively converts ``dict`` to :obj:`ConfigDict`. The only + difference between ``_dict_to_config_dict_lazy`` and + ``_dict_to_config_dict_lazy`` is that the former one does not consider + the scope, and will not trigger the building of ``LazyObject``. + + Args: + cfg (dict): Config dict. + + Returns: + ConfigDict: Converted dict. + """ + # Only the outer dict with key `type` should have the key `_scope_`. + if isinstance(cfg, dict): + if isinstance(cfg, ConfigDict): + # Use to_dict to avoid build lazy object. 
+ cfg = cfg.to_dict() + cfg_dict = ConfigDict() + for key, value in cfg.items(): + cfg_dict[key] = Config._dict_to_config_dict_lazy(value) + return cfg_dict + if isinstance(cfg, (tuple, list)): + return type(cfg)( + Config._dict_to_config_dict_lazy(_cfg) for _cfg in cfg) + return cfg + @staticmethod def _dict_to_config_dict(cfg: dict, scope: Optional[str] = None, @@ -644,14 +1120,15 @@ def _get_base_files(filename: str) -> list: if file_format == '.py': Config._validate_py_syntax(filename) with open(filename, encoding='utf-8') as f: - codes = ast.parse(f.read()).body + parsed_codes = ast.parse(f.read()).body def is_base_line(c): return (isinstance(c, ast.Assign) and isinstance(c.targets[0], ast.Name) and c.targets[0].id == BASE_KEY) - base_code = next((c for c in codes if is_base_line(c)), None) + base_code = next((c for c in parsed_codes if is_base_line(c)), + None) if base_code is not None: base_code = ast.Expression( # type: ignore body=base_code.value) # type: ignore @@ -817,6 +1294,8 @@ def _indent(s_, num_spaces): def _format_basic_types(k, v, use_mapping=False): if isinstance(v, str): v_str = repr(v) + elif isinstance(v, (LazyObject, LazyAttr)): + v_str = f"'{v.module}.{str(v)}'" else: v_str = str(v) @@ -829,21 +1308,37 @@ def _format_basic_types(k, v, use_mapping=False): return attr_str - def _format_list(k, v, use_mapping=False): + def _format_list_tuple(k, v, use_mapping=False): + if isinstance(v, list): + left = '[' + right = ']' + else: + left = '(' + right = ')' + + v_str = f'{left}\n' # check if all items in the list are dict - if all(isinstance(_, dict) for _ in v): - v_str = '[\n' - v_str += '\n'.join( - f'dict({_indent(_format_dict(v_), indent)}),' - for v_ in v).rstrip(',') - if use_mapping: - k_str = f"'{k}'" if isinstance(k, str) else str(k) - attr_str = f'{k_str}: {v_str}' + for item in v: + if isinstance(item, dict): + v_str += f'dict({_indent(_format_dict(item), indent)}),\n' + elif isinstance(item, tuple): + v_str += 
f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 + elif isinstance(item, list): + v_str += f'{_indent(_format_list_tuple(None, item), indent)},\n' # noqa: 501 + elif isinstance(item, str): + v_str += f'{_indent(repr(item), indent)},\n' + elif isinstance(item, (LazyObject, LazyAttr)): + v_str += f"'{str(item)}',\n" else: - attr_str = f'{str(k)}={v_str}' - attr_str = _indent(attr_str, indent) + ']' + v_str += str(item) + ',\n' + if k is None: + return _indent(v_str, indent) + right + if use_mapping: + k_str = f"'{k}'" if isinstance(k, str) else str(k) + attr_str = f'{k_str}: {v_str}' else: - attr_str = _format_basic_types(k, v, use_mapping) + attr_str = f'{str(k)}={v_str}' + attr_str = _indent(attr_str, indent) + right return attr_str def _contain_invalid_identifier(dict_str): @@ -863,7 +1358,9 @@ def _format_dict(input_dict, outest_level=False): for idx, (k, v) in enumerate(input_dict.items()): is_last = idx >= len(input_dict) - 1 end = '' if outest_level or is_last else ',' - if isinstance(v, dict): + if isinstance(v, (LazyObject, LazyAttr)): + attr_str = _format_basic_types(k, v, use_mapping) + end + elif isinstance(v, dict): v_str = '\n' + _format_dict(v) if use_mapping: k_str = f"'{k}'" if isinstance(k, str) else str(k) @@ -871,8 +1368,8 @@ def _format_dict(input_dict, outest_level=False): else: attr_str = f'{str(k)}=dict({v_str}' attr_str = _indent(attr_str, indent) + ')' + end - elif isinstance(v, list): - attr_str = _format_list(k, v, use_mapping) + end + elif isinstance(v, (list, tuple)): + attr_str = _format_list_tuple(k, v, use_mapping) + end else: attr_str = _format_basic_types(k, v, use_mapping) + end @@ -882,14 +1379,18 @@ def _format_dict(input_dict, outest_level=False): r += '}' return r - cfg_dict = self._cfg_dict.to_dict() + cfg_dict = self.to_dict(keep_imported=False) text = _format_dict(cfg_dict, outest_level=True) # copied from setup.cfg yapf_style = dict( based_on_style='pep8', blank_line_before_nested_class_or_def=True, 
split_before_expression_after_opening_paren=True) - text, _ = FormatCode(text, style_config=yapf_style, verify=True) + try: + text, _ = FormatCode(text, style_config=yapf_style, verify=True) + except: # noqa: E722 + raise SyntaxError('Failed to format the config file, please ' + f'check the syntax of: \n{text}') return text @@ -1020,6 +1521,53 @@ def merge_from_dict(self, Config._merge_a_into_b( option_cfg_dict, cfg_dict, allow_list_keys=allow_list_keys)) + @staticmethod + def _is_lazy_import(filename: str) -> bool: + if not filename.endswith('.py'): + return False + with open(filename, encoding='utf-8') as f: + codes_str = f.read() + parsed_codes = ast.parse(codes_str) + for node in ast.walk(parsed_codes): + if (isinstance(node, ast.Assign) + and isinstance(node.targets[0], ast.Name) + and node.targets[0].id == BASE_KEY): + return False + if (isinstance(node, ast.If) + and isinstance(node.test, ast.Constant) + and node.test.value == BASE_KEY): + return True + if isinstance(node, ast.ImportFrom): + # relative import -> lazy_import + if node.level != 0: + return True + # Skip checking when using `mmengine.config` in cfg file + if (node.module == 'mmengine' and len(node.names) == 1 + and node.names[0].name == 'Config'): + continue + if not isinstance(node.module, str): + continue + # non-builtin module -> lazy_import + if not _is_builtin_module(node.module): + return True + if isinstance(node, ast.Import): + for alias_node in node.names: + if not _is_builtin_module(alias_node.name): + return True + return False + + def to_dict(self, keep_imported: bool = True) -> dict: + """Convert config object to dictionary and filter the imported + object.""" + res = self._cfg_dict.to_dict() + if hasattr(self, '_imported_names') and not keep_imported: + res = { + key: value + for key, value in res.items() + if key not in self._imported_names + } + return res + class DictAction(Action): """ diff --git a/mmengine/config/lazy.py b/mmengine/config/lazy.py new file mode 100644 index 
0000000000..ab5ce35d6b --- /dev/null +++ b/mmengine/config/lazy.py @@ -0,0 +1,216 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import importlib +from typing import Any, Optional, Union + +from mmengine.utils import is_seq_of + + +class LazyObject: + """LazyObject is used to lazily initialize the imported module while + parsing the configuration file. + + During the parsing process, syntax like: + + Examples: + >>> import torch.nn as nn + >>> from mmdet.models import RetinaNet + >>> import mmcls.models + >>> import mmcls.datasets + >>> import mmcls + + Will be parsed as: + + Examples: + >>> # import torch.nn as nn + >>> nn = LazyObject('torch.nn') + >>> # from mmdet.models import RetinaNet + >>> RetinaNet = LazyObject('mmdet.models', 'RetinaNet') + >>> # import mmcls.models; import mmcls.datasets; import mmcls + >>> mmcls = LazyObject(['mmcls', 'mmcls.datasets', 'mmcls.models']) + + ``LazyObject`` records all module information and will be further + referenced by the configuration file. + + Args: + module (str or list or tuple): The module name to be imported. + imported (str, optional): The imported module name. Defaults to None. + location (str, optional): The filename and line number where the + import statement appears.
+ """ + + def __init__(self, + module: Union[str, list, tuple], + imported: Optional[str] = None, + location: Optional[str] = None): + if not isinstance(module, str) and not is_seq_of(module, str): + raise TypeError('module should be `str`, `list`, or `tuple`' + f'but got {type(module)}, this might be ' + 'a bug of MMEngine, please report it to ' + 'https://github.com/open-mmlab/mmengine/issues') + self._module: Union[str, list, tuple] = module + + if not isinstance(imported, str) and imported is not None: + raise TypeError('imported should be `str` or None, but got ' + f'{type(imported)}, this might be ' + 'a bug of MMEngine, please report it to ' + 'https://github.com/open-mmlab/mmengine/issues') + self._imported = imported + self.location = location + + def build(self) -> Any: + """Return imported object. + + Returns: + Any: Imported object + """ + if isinstance(self._module, str): + try: + module = importlib.import_module(self._module) + except Exception as e: + raise type(e)(f'Failed to import {self._module} ' + f'in {self.location} for {e}') + + if self._imported is not None: + if hasattr(module, self._imported): + module = getattr(module, self._imported) + else: + raise ImportError( + f'Failed to import {self._imported} ' + f'from {self._module} in {self.location}') + + return module + else: + # import xxx.xxx + # import xxx.yyy + # import xxx.zzz + # return imported xxx + try: + for module in self._module: + importlib.import_module(module) # type: ignore + module_name = self._module[0].split('.')[0] + return importlib.import_module(module_name) + except Exception as e: + raise type(e)(f'Failed to import {self.module} ' + f'in {self.location} for {e}') + + @property + def module(self): + if isinstance(self._module, str): + return self._module + return self._module[0].split('.')[0] + + def __call__(self, *args, **kwargs): + raise RuntimeError() + + def __deepcopy__(self, memo): + return LazyObject(self._module, self._imported, self.location) + + def 
__getattr__(self, name): + # Cannot locate the line number of the attribute access. + # Therefore only record the filename. + if self.location is not None: + location = self.location.split(', line')[0] + else: + location = self.location + return LazyAttr(name, self, location) + + def __str__(self) -> str: + if self._imported is not None: + return self._imported + return self.module + + __repr__ = __str__ + + +class LazyAttr: + """The attribute of the LazyObject. + + When parsing the configuration file, the import syntax will be + parsed as an assignment of ``LazyObject``. During the subsequent parsing + process, users may reference the attributes of the LazyObject. + To ensure that these attributes also contain information needed to + reconstruct the attribute itself, LazyAttr was introduced. + + Examples: + >>> models = LazyObject(['mmdet.models']) + >>> model = dict(type=models.RetinaNet) + >>> print(type(model['type'])) # <class 'mmengine.config.lazy.LazyAttr'> + >>> print(model['type'].build()) # <class 'mmdet.models.detectors.retinanet.RetinaNet'> + """ # noqa: E501 + + def __init__(self, + name: str, + source: Union['LazyObject', 'LazyAttr'], + location=None): + self.name = name + self.source: Union[LazyAttr, LazyObject] = source + + if isinstance(self.source, LazyObject): + if isinstance(self.source._module, str): + # In this case, the source code of LazyObject could be one of + # the following: + # 1. import xxx.yyy as zzz + # 2. from xxx.yyy import zzz + + # The equivalent code of LazyObject is: + # 1. zzz = LazyObject('xxx.yyy') + # 2. zzz = LazyObject('xxx.yyy', 'zzz') + + # The source code of LazyAttr will be: + # eee = zzz.eee + # Then, eee._module = xxx.yyy + self._module = self.source._module + else: + # The source code of LazyObject should be + # 1. import xxx.yyy + # 2. import xxx.zzz + # Equivalent to + # xxx = LazyObject(['xxx.yyy', 'xxx.zzz']) + + # The source code of LazyAttr should be + # eee = xxx.eee + # Then, eee._module = xxx + self._module = str(self.source) + elif isinstance(self.source, LazyAttr): + # 1. import xxx + # 2.
zzz = xxx.yyy.zzz + + # Equivalent to: + # xxx = LazyObject('xxx') + # zzz = xxx.yyy.zzz + # zzz._module = xxx.yyy._module + zzz.name + self._module = f'{self.source._module}.{self.source.name}' + self.location = location + + @property + def module(self): + return self._module + + def __call__(self, *args, **kwargs: Any) -> Any: + raise RuntimeError() + + def __getattr__(self, name: str) -> 'LazyAttr': + return LazyAttr(name, self) + + def __deepcopy__(self, memo): + return LazyAttr(self.name, self.source) + + def build(self) -> Any: + """Return the attribute of the imported object. + + Returns: + Any: attribute of the imported object. + """ + obj = self.source.build() + try: + return getattr(obj, self.name) + except AttributeError: + raise ImportError(f'Failed to import {self.module}.{self.name} in ' + f'{self.location}') + except ImportError as e: + raise e + + def __str__(self) -> str: + return self.name + + __repr__ = __str__ diff --git a/mmengine/config/utils.py b/mmengine/config/utils.py index a967bb3691..f6703c8010 100644 --- a/mmengine/config/utils.py +++ b/mmengine/config/utils.py @@ -2,12 +2,17 @@ import ast import os.path as osp import re +import sys import warnings -from typing import Tuple +from collections import defaultdict +from importlib.util import find_spec +from typing import List, Optional, Tuple, Union from mmengine.fileio import load from mmengine.utils import check_file_exist +PYTHON_ROOT_DIR = osp.dirname(osp.dirname(sys.executable)) + MODULE2PACKAGE = { 'mmcls': 'mmcls', 'mmdet': 'mmdet', @@ -144,3 +149,300 @@ def visit_Assign(self, node): return None else: return node + + +def _is_builtin_module(module_name: str) -> bool: + """Check if a module is a built-in module. + + Arg: + module_name: name of module. 
+ """ + if module_name.startswith('.'): + return False + if module_name.startswith('mmengine.config'): + return True + spec = find_spec(module_name.split('.')[0]) + # Module not found + if spec is None: + return False + origin_path = getattr(spec, 'origin', None) + if origin_path is None: + return False + origin_path = osp.abspath(origin_path) + if ('site-package' in origin_path + or not origin_path.startswith(PYTHON_ROOT_DIR)): + return False + else: + return True + + +class ImportTransformer(ast.NodeTransformer): + """Convert the import syntax to the assignment of + :class:`mmengine.config.LazyObject` and preload the base variable before + parsing the configuration file. + + Since you are already looking at this part of the code, I believe you must + be interested in the mechanism of the ``lazy_import`` feature of + :class:`Config`. In this docstring, we will dive deeper into its + principles. + + Most of OpenMMLab users maybe bothered with that: + + * In most of popular IDEs, they cannot navigate to the source code in + configuration file + * In most of popular IDEs, they cannot jump to the base file in current + configuration file, which is much painful when the inheritance + relationship is complex. + + In order to solve this problem, we introduce the ``lazy_import`` mode. + + A very intuitive idea for solving this problem is to import the module + corresponding to the "type" field using the ``import`` syntax. Similarly, + we can also ``import`` base file. + + However, this approach has a significant drawback. It requires triggering + the import logic to parse the configuration file, which can be + time-consuming. Additionally, it implies downloading numerous dependencies + solely for the purpose of parsing the configuration file. + However, it's possible that only a portion of the config will actually be + used. For instance, the package used in the ``train_pipeline`` may not + be necessary for an evaluation task. 
Forcing users to download these + unused packages is not a desirable solution. + + To avoid this problem, we introduce :class:`mmengine.config.LazyObject` and + :class:`mmengine.config.LazyAttr`. Before we proceed with further + explanations, you may refer to the documentation of these two classes to + gain an understanding of their functionalities. + + Actually, one of the functions of ``ImportTransformer`` is to hack the + ``import`` syntax. It will replace the import syntax + (excluding imports of the base files) with assignments of ``LazyObject``. + + As for the import syntax of the base file, we cannot lazy import it since + we must merge the fields of the current file and base files. Therefore, + another function of the ``ImportTransformer`` is to collaborate with + ``Config._parse_lazy_import`` to parse the base files. + + Args: + global_dict (dict): The global dict of the current configuration file. + If we divide ordinary Python syntax into two parts, namely the + import section and the non-import section (assuming a simple case + with imports at the beginning and the rest of the code following), + the variables generated by the import statements are stored in + global variables for subsequent code use. In this context, + the ``global_dict`` represents the global variables required when + executing the non-import code. ``global_dict`` will be filled + while visiting the parsed code. + base_dict (dict): All variables defined in base files. + + Examples: + >>> if '_base_': + >>> from .._base_.default_runtime import * + >>> from .._base_.datasets.coco_detection import dataset + + In this case, the base_dict will be: + + Examples: + >>> base_dict = { + >>> '.._base_.default_runtime': ...
+ '.._base_.datasets.coco_detection': dataset} + + and `global_dict` will be updated like this: + + Examples: + >>> global_dict.update(base_dict['.._base_.default_runtime']) # `import *` means update all data + >>> global_dict.update(dataset=base_dict['.._base_.datasets.coco_detection']['dataset']) # only update `dataset` + """ # noqa: E501 + + def __init__(self, + global_dict: dict, + base_dict: Optional[dict] = None, + filename: Optional[str] = None): + self.base_dict = base_dict if base_dict is not None else {} + self.global_dict = global_dict + # In Windows, the filename could be like this: + # "C:\\Users\\runneradmin\\AppData\\Local\\" + # Although it has been a raw string, ast.parse will first escape + # it as the executed code: + # "C:\Users\runneradmin\AppData\Local\\\" + # As you see, the `\U` will be treated as a part of + # the escape sequence during code parsing, leading to a + # parsing error + # Here we use `encode('unicode_escape').decode()` for double escaping + if isinstance(filename, str): + filename = filename.encode('unicode_escape').decode() + self.filename = filename + self.imported_obj: set = set() + super().__init__() + + def visit_ImportFrom( + self, node: ast.ImportFrom + ) -> Optional[Union[List[ast.Assign], ast.ImportFrom]]: + """Hack the ``from ... import ...`` syntax and update the global_dict. + + Examples: + >>> from mmdet.models import RetinaNet + + Will be parsed as: + + Examples: + >>> RetinaNet = LazyObject('mmdet.models', 'RetinaNet') + + ``global_dict`` will also be updated by ``base_dict`` as the + class docstring says. + + Args: + node (ast.AST): The node of the current import statement. + + Returns: + Optional[List[ast.Assign]]: There are three cases: + + * If the node is a statement importing base files, + None will be returned. + * If the node is a statement importing a builtin module, + node will be directly returned. + * Otherwise, it will return the assignment statements of + ``LazyObject``.
+ """ + # Built-in modules will not be parsed as LazyObject + module = f'{node.level*"."}{node.module}' + if _is_builtin_module(module): + return node + + if module in self.base_dict: + for alias_node in node.names: + if alias_node.name == '*': + self.global_dict.update(self.base_dict[module]) + return None + if alias_node.asname is not None: + base_key = alias_node.asname + else: + base_key = alias_node.name + self.global_dict[base_key] = self.base_dict[module][ + alias_node.name] + return None + + nodes: List[ast.Assign] = [] + for alias_node in node.names: + # `ast.alias` has lineno attr after Python 3.10, + if hasattr(alias_node, 'lineno'): + lineno = alias_node.lineno + else: + lineno = node.lineno + if alias_node.name == '*': + # TODO: If users import * from a non-config module, it should + # fallback to import the real module and raise a warning to + # remind users the real module will be imported which will slow + # down the parsing speed. + raise RuntimeError( + 'Illegal syntax in config! `from xxx import *` is not ' + 'allowed to appear outside the `if base:` statement') + elif alias_node.asname is not None: + # case1: + # from mmengine.dataset import BaseDataset as Dataset -> + # Dataset = LazyObject('mmengine.dataset', 'BaseDataset') + code = f'{alias_node.asname} = LazyObject("{module}", "{alias_node.name}", "{self.filename}, line {lineno}")' # noqa: E501 + self.imported_obj.add(alias_node.asname) + else: + # case2: + # from mmengine.model import BaseModel + # BaseModel = LazyObject('mmengine.model', 'BaseModel') + code = f'{alias_node.name} = LazyObject("{module}", "{alias_node.name}", "{self.filename}, line {lineno}")' # noqa: E501 + self.imported_obj.add(alias_node.name) + try: + nodes.append(ast.parse(code).body[0]) # type: ignore + except Exception as e: + raise ImportError( + f'Cannot import {alias_node} from {module}', + '1. Cannot import * from 3rd party lib in the config ' + 'file\n' + '2. 
Please check if the module is a base config which ' + 'should be added to `_base_`\n', + ) from e + return nodes + + def visit_Import(self, node) -> Union[ast.Assign, ast.Import]: + """Work with ``_gather_abs_import_lazyobj`` to hack the ``import ...`` + syntax. + + Examples: + >>> import mmcls.models + >>> import mmcls.datasets + >>> import mmcls + + Will be parsed as: + + Examples: + >>> # import mmcls.models; import mmcls.datasets; import mmcls + >>> mmcls = lazyObject(['mmcls', 'mmcls.datasets', 'mmcls.models']) + + Args: + node (ast.AST): The node of the current import statement. + + Returns: + ast.Assign: If the import statement is ``import ... as ...``, + ast.Assign will be returned, otherwise node will be directly + returned. + """ + # For absolute import like: `import mmdet.configs as configs`. + # It will be parsed as: + # configs = LazyObject('mmdet.configs') + # For absolute import like: + # `import mmdet.configs` + # `import mmdet.configs.default_runtime` + # This will be parsed as + # mmdet = LazyObject(['mmdet.configs.default_runtime', 'mmdet.configs]) + # However, visit_Import cannot gather other import information, so + # `_gather_abs_import_LazyObject` will gather all import information + # from the same module and construct the LazyObject. + alias_list = node.names + assert len(alias_list) == 1, ( + 'Illegal syntax in config! 
import multiple modules in one line is ' + 'not supported') + # TODO Support multiline import + alias = alias_list[0] + if alias.asname is not None: + self.imported_obj.add(alias.asname) + return ast.parse( # type: ignore + f'{alias.asname} = LazyObject(' + f'"{alias.name}",' + f'location="{self.filename}, line {node.lineno}")').body[0] + return node + + +def _gather_abs_import_lazyobj(tree: ast.Module, + filename: Optional[str] = None): + """Experimental implementation of gathering absolute import information.""" + if isinstance(filename, str): + filename = filename.encode('unicode_escape').decode() + imported = defaultdict(list) + abs_imported = set() + new_body: List[ast.stmt] = [] + # module2node is used to get lineno when Python < 3.10 + module2node: dict = dict() + for node in tree.body: + if isinstance(node, ast.Import): + for alias in node.names: + # Skip converting built-in module to LazyObject + if _is_builtin_module(alias.name): + new_body.append(node) + continue + module = alias.name.split('.')[0] + module2node.setdefault(module, node) + imported[module].append(alias) + continue + new_body.append(node) + + for key, value in imported.items(): + names = [_value.name for _value in value] + if hasattr(value[0], 'lineno'): + lineno = value[0].lineno + else: + lineno = module2node[key].lineno + lazy_module_assign = ast.parse( + f'{key} = LazyObject({names}, location="{filename}, line {lineno}")' # noqa: E501 + ) # noqa: E501 + abs_imported.add(key) + new_body.insert(0, lazy_module_assign.body[0]) + tree.body = new_body + return tree, abs_imported diff --git a/mmengine/model/averaged_model.py b/mmengine/model/averaged_model.py index a8876b81c2..58457c2a6e 100644 --- a/mmengine/model/averaged_model.py +++ b/mmengine/model/averaged_model.py @@ -258,5 +258,6 @@ def avg_func(self, averaged_param: Tensor, source_param: Tensor, steps (int): The number of times the parameters have been updated. 
""" - momentum = max(self.momentum, self.gamma / (self.gamma + self.steps)) + momentum = max(self.momentum, + self.gamma / (self.gamma + self.steps.item())) averaged_param.lerp_(source_param, momentum) diff --git a/mmengine/model/base_module.py b/mmengine/model/base_module.py index 6912a35357..6e9f8ee6d8 100644 --- a/mmengine/model/base_module.py +++ b/mmengine/model/base_module.py @@ -10,7 +10,7 @@ from mmengine.dist import master_only from mmengine.logging import MMLogger, print_log -from .weight_init import initialize, update_init_info +from .weight_init import PretrainedInit, initialize, update_init_info from .wrappers.utils import is_model_wrapper @@ -116,7 +116,8 @@ def init_weights(self): pretrained_cfg = [] for init_cfg in init_cfgs: assert isinstance(init_cfg, dict) - if init_cfg['type'] == 'Pretrained': + if (init_cfg['type'] == 'Pretrained' + or init_cfg['type'] is PretrainedInit): pretrained_cfg.append(init_cfg) else: other_cfgs.append(init_cfg) diff --git a/mmengine/registry/build_functions.py b/mmengine/registry/build_functions.py index 57b2c85d0e..ba22680e85 100644 --- a/mmengine/registry/build_functions.py +++ b/mmengine/registry/build_functions.py @@ -186,7 +186,7 @@ def build_runner_from_cfg(cfg: Union[dict, ConfigDict, Config], # temporarily. 
scope = args.pop('_scope_', None) with registry.switch_scope_and_registry(scope) as registry: - obj_type = args.get('runner_type', 'mmengine.Runner') + obj_type = args.get('runner_type', 'Runner') if isinstance(obj_type, str): runner_cls = registry.get(obj_type) if runner_cls is None: diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py index 18545e97a3..f2c21b063c 100644 --- a/mmengine/registry/registry.py +++ b/mmengine/registry/registry.py @@ -11,7 +11,7 @@ from rich.table import Table from mmengine.config.utils import MODULE2PACKAGE -from mmengine.utils import is_seq_of +from mmengine.utils import get_object_from_string, is_seq_of from .default_scope import DefaultScope @@ -384,8 +384,12 @@ def import_from_location(self) -> None: def get(self, key: str) -> Optional[Type]: """Get the registry record. - The method will first parse :attr:`key` and check whether it contains - a scope name. The logic to search for :attr:`key`: + If ``key`` represents the whole object name with its module + information, for example, ``mmengine.model.BaseModel``, ``get`` + will directly return the class object :class:`BaseModel`. + + Otherwise, it will first parse ``key`` and check whether it + contains a scope name. The logic to search for ``key``: - ``key`` does not contain a scope name, i.e., it is purely a module name like "ResNet": :meth:`get` will search for ``ResNet`` from the @@ -433,6 +437,24 @@ def get(self, key: str) -> Optional[Type]: # Avoid circular import from ..logging import print_log + if not isinstance(key, str): + raise TypeError( + 'The key argument of `Registry.get` must be a str, ' + f'got {type(key)}') + + # Actually, it's strange to implement this `try ... except` to get the + # object by its name in `Registry.get`. However, if we want to build + # the model using a configuration like + # `dict(type='mmengine.model.BaseModel')`, which can + # be dumped by lazy import config, we need this code snippet + # for `Registry.get` to work.
+ try: + obj_cls = get_object_from_string(key) + except Exception: + raise RuntimeError(f'Failed to get {key}') + if obj_cls is not None: + return obj_cls + scope, real_key = self.split_scope_key(key) obj_cls = None registry_name = self.name diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index cd9310c751..6a874a6ad6 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -334,7 +334,7 @@ def __init__(self, fp16: bool = False) -> None: super().__init__(runner, dataloader) - if isinstance(evaluator, dict) or isinstance(evaluator, list): + if isinstance(evaluator, (dict, list)): self.evaluator = runner.build_evaluator(evaluator) # type: ignore else: assert isinstance(evaluator, Evaluator), ( diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index b0b1468969..2431011c3f 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -367,8 +367,12 @@ def __init__( mmengine.mkdir_or_exist(self._log_dir) # Used to reset registries location. See :meth:`Registry.build` for # more details. - self.default_scope = DefaultScope.get_instance( - self._experiment_name, scope_name=default_scope) + if default_scope is not None: + default_scope = DefaultScope.get_instance( # type: ignore + self._experiment_name, + scope_name=default_scope) + self.default_scope = default_scope + # Build log processor to format message. 
log_processor = dict() if log_processor is None else log_processor self.log_processor = self.build_log_processor(log_processor) @@ -878,6 +882,7 @@ def wrap_model( broadcast_buffers=False, find_unused_parameters=find_unused_parameters) else: + model_wrapper_cfg.setdefault('type', 'MMDistributedDataParallel') model_wrapper_type = MODEL_WRAPPERS.get( model_wrapper_cfg.get('type')) # type: ignore default_args: dict = dict() @@ -1384,7 +1389,14 @@ def build_dataloader(dataloader: Union[DataLoader, Dict], if 'worker_init_fn' in dataloader_cfg: worker_init_fn_cfg = dataloader_cfg.pop('worker_init_fn') worker_init_fn_type = worker_init_fn_cfg.pop('type') - worker_init_fn = FUNCTIONS.get(worker_init_fn_type) + if isinstance(worker_init_fn_type, str): + worker_init_fn = FUNCTIONS.get(worker_init_fn_type) + elif callable(worker_init_fn_type): + worker_init_fn = worker_init_fn_type + else: + raise TypeError( + 'type of worker_init_fn should be string or callable ' + f'object, but got {type(worker_init_fn)}') assert callable(worker_init_fn) init_fn = partial(worker_init_fn, **worker_init_fn_cfg) # type: ignore @@ -1423,7 +1435,10 @@ def build_dataloader(dataloader: Union[DataLoader, Dict], dict(type='pseudo_collate')) if isinstance(collate_fn_cfg, dict): collate_fn_type = collate_fn_cfg.pop('type') - collate_fn = FUNCTIONS.get(collate_fn_type) + if isinstance(collate_fn_type, str): + collate_fn = FUNCTIONS.get(collate_fn_type) + else: + collate_fn = collate_fn_type collate_fn = partial(collate_fn, **collate_fn_cfg) # type: ignore elif callable(collate_fn_cfg): collate_fn = collate_fn_cfg @@ -1431,7 +1446,6 @@ def build_dataloader(dataloader: Union[DataLoader, Dict], raise TypeError( 'collate_fn should be a dict or callable object, but got ' f'{collate_fn_cfg}') - data_loader = DataLoader( dataset=dataset, sampler=sampler if batch_sampler is None else None, diff --git a/mmengine/utils/__init__.py b/mmengine/utils/__init__.py index 2800d935c6..ba89c4ffe8 100644 --- 
a/mmengine/utils/__init__.py +++ b/mmengine/utils/__init__.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .manager import ManagerMeta, ManagerMixin from .misc import (apply_to, check_prerequisites, concat_list, - deprecated_api_warning, deprecated_function, has_method, + deprecated_api_warning, deprecated_function, + get_object_from_string, has_method, import_modules_from_strings, is_list_of, is_method_overridden, is_seq_of, is_str, is_tuple_of, iter_cast, list_cast, requires_executable, requires_package, @@ -28,5 +29,5 @@ 'get_git_hash', 'ManagerMeta', 'ManagerMixin', 'Timer', 'check_time', 'TimerError', 'ProgressBar', 'track_iter_progress', 'track_parallel_progress', 'track_progress', 'deprecated_function', - 'apply_to' + 'apply_to', 'get_object_from_string' ] diff --git a/mmengine/utils/misc.py b/mmengine/utils/misc.py index aaea15c4a3..948329f603 100644 --- a/mmengine/utils/misc.py +++ b/mmengine/utils/misc.py @@ -9,7 +9,7 @@ import warnings from collections import abc from importlib import import_module -from inspect import getfullargspec +from inspect import getfullargspec, ismodule from itertools import repeat from typing import Any, Callable, Optional, Type, Union @@ -500,3 +500,43 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def get_object_from_string(obj_name: str): + """Get object from name. + + Args: + obj_name (str): The name of the object. + + Examples: + >>> get_object_from_string('torch.optim.sgd.SGD') + >>> torch.optim.sgd.SGD + """ + parts = iter(obj_name.split('.')) + module_name = next(parts) + # import module + while True: + try: + module = import_module(module_name) + part = next(parts) + # mmcv.ops has nms.py has nms function at the same time. 
So the + # function will have a higher priority + obj = getattr(module, part, None) + if obj is not None and not ismodule(obj): + break + module_name = f'{module_name}.{part}' + except StopIteration: + # if obj is a module + return module + except ImportError: + return None + + # get class or attribute from module + while True: + try: + obj_cls = getattr(module, part) + part = next(parts) + except StopIteration: + return obj_cls + except AttributeError: + return None diff --git a/mmengine/utils/package_utils.py b/mmengine/utils/package_utils.py index 47a301721d..8abf4965c6 100644 --- a/mmengine/utils/package_utils.py +++ b/mmengine/utils/package_utils.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -import importlib import os.path as osp import subprocess @@ -13,6 +12,8 @@ def is_installed(package: str) -> bool: # When executing `import mmengine.runner`, # pkg_resources will be imported and it takes too much time. # Therefore, import it in function scope to save time. + import importlib.util + import pkg_resources from pkg_resources import get_distribution @@ -23,7 +24,7 @@ def is_installed(package: str) -> bool: get_distribution(package) return True except pkg_resources.DistributionNotFound: - return False + return importlib.util.find_spec(package) is not None def get_installed_path(package: str) -> str: @@ -36,13 +37,40 @@ def get_installed_path(package: str) -> str: >>> get_installed_path('mmcls') >>> '.../lib/python3.7/site-packages/mmcls' """ - from pkg_resources import get_distribution + import importlib.util + + from pkg_resources import DistributionNotFound, get_distribution # if the package name is not the same as module name, module name should be # inferred. For example, mmcv-full is the package name, but mmcv is module # name. 
If we want to get the installed path of mmcv-full, we should concat # the pkg.location and module name - pkg = get_distribution(package) + try: + pkg = get_distribution(package) + except DistributionNotFound as e: + # if the package is not installed, package path set in PYTHONPATH + # can be detected by `find_spec` + spec = importlib.util.find_spec(package) + if spec is not None: + if spec.origin is not None: + return osp.dirname(spec.origin) + # For namespace packages, the origin is None, and the first path + # in submodule_search_locations will be returned. + # namespace packages: https://packaging.python.org/en/latest/guides/packaging-namespace-packages/ # noqa: E501 + elif spec.submodule_search_locations is not None: + locations = spec.submodule_search_locations + if isinstance(locations, list): + return locations[0] + else: + # `submodule_search_locations` is not subscriptable in + # python3.7. Therefore we use `_path` to get the first + # path. + return locations._path[0] # type: ignore + else: + raise e + else: + raise e + possible_path = osp.join(pkg.location, package) if osp.exists(possible_path): return possible_path diff --git a/requirements/docs.txt b/requirements/docs.txt index a63366dc85..a0d0e05000 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -4,7 +4,9 @@ opencv-python -e git+https://github.com/open-mmlab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme sphinx==4.5.0 sphinx-copybutton +sphinx-tabs sphinx_markdown_tables +tabulate torch torchvision urllib3<2.0.0 diff --git a/tests/data/config/lazy_module_config/__init__.py b/tests/data/config/lazy_module_config/__init__.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/tests/data/config/lazy_module_config/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved.
diff --git a/tests/data/config/lazy_module_config/_base_/__init__.py b/tests/data/config/lazy_module_config/_base_/__init__.py new file mode 100644 index 0000000000..ef101fec61 --- /dev/null +++ b/tests/data/config/lazy_module_config/_base_/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/data/config/lazy_module_config/_base_/base_model.py b/tests/data/config/lazy_module_config/_base_/base_model.py new file mode 100644 index 0000000000..8e3a9dab7a --- /dev/null +++ b/tests/data/config/lazy_module_config/_base_/base_model.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.testing.runner_test_case import ToyModel + +model = dict(type=ToyModel) diff --git a/tests/data/config/lazy_module_config/_base_/default_runtime.py b/tests/data/config/lazy_module_config/_base_/default_runtime.py new file mode 100644 index 0000000000..d8ab215548 --- /dev/null +++ b/tests/data/config/lazy_module_config/_base_/default_runtime.py @@ -0,0 +1,36 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, + LoggerHook, ParamSchedulerHook) + +default_scope = 'test_config' + +# configure default hooks +default_hooks = dict( + timer=dict(type=IterTimerHook), + logger=dict(type=LoggerHook, interval=100), + param_scheduler=dict(type=ParamSchedulerHook), + checkpoint=dict(type=CheckpointHook, interval=1), + sampler_seed=dict(type=DistSamplerSeedHook), +) + +# configure environment +env_cfg = dict( + # whether to enable cudnn benchmark + cudnn_benchmark=False, + # set multi process parameters + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters + dist_cfg=dict(backend='nccl'), +) + +# set log level +log_level = 'INFO' + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = False + +# Defaults to use random seed and disable `deterministic` +randomness = dict(seed=None, deterministic=False) diff --git a/tests/data/config/lazy_module_config/_base_/scheduler.py b/tests/data/config/lazy_module_config/_base_/scheduler.py new file mode 100644 index 0000000000..a9a4c15af8 --- /dev/null +++ b/tests/data/config/lazy_module_config/_base_/scheduler.py @@ -0,0 +1,20 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch.optim import SGD + +from mmengine.optim.scheduler import MultiStepLR + +# optimizer +optim_wrapper = dict( + optimizer=dict(type=SGD, lr=0.1, momentum=0.9, weight_decay=0.0001)) +# learning policy +param_scheduler = dict( + type=MultiStepLR, by_epoch=True, milestones=[1, 2], gamma=0.1) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=5, val_interval=1) +val_cfg = dict() +test_cfg = dict() + +# NOTE: `auto_scale_lr` is for automatically scaling LR +# based on the actual training batch size. 
+auto_scale_lr = dict(base_batch_size=128) diff --git a/tests/data/config/lazy_module_config/error_mix_using1.py b/tests/data/config/lazy_module_config/error_mix_using1.py new file mode 100644 index 0000000000..7328294a4d --- /dev/null +++ b/tests/data/config/lazy_module_config/error_mix_using1.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +_base_ = './toy_model.py' diff --git a/tests/data/config/lazy_module_config/error_mix_using2.py b/tests/data/config/lazy_module_config/error_mix_using2.py new file mode 100644 index 0000000000..e02c9c6fbb --- /dev/null +++ b/tests/data/config/lazy_module_config/error_mix_using2.py @@ -0,0 +1,3 @@ +# Copyright (c) OpenMMLab. All rights reserved. +if '_base_': + from ...config.py_config.test_base_variables import * diff --git a/tests/data/config/lazy_module_config/load_mmdet_config.py b/tests/data/config/lazy_module_config/load_mmdet_config.py new file mode 100644 index 0000000000..9102d7486f --- /dev/null +++ b/tests/data/config/lazy_module_config/load_mmdet_config.py @@ -0,0 +1,7 @@ +# Copyright (c) OpenMMLab. All rights reserved. +if '_base_': + from mmdet.configs.retinanet.retinanet_r50_caffe_fpn_1x_coco import * + from mmdet.configs.retinanet.retinanet_r101_caffe_fpn_1x_coco import \ + model as r101 + +model = r101 diff --git a/tests/data/config/lazy_module_config/test_ast_transform.py b/tests/data/config/lazy_module_config/test_ast_transform.py new file mode 100644 index 0000000000..8e8b02f445 --- /dev/null +++ b/tests/data/config/lazy_module_config/test_ast_transform.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os +from importlib.util import find_spec as find_module + +import numpy +import numpy.compat +import numpy.linalg as linalg + +from mmengine.config import Config +from mmengine.fileio import LocalBackend as local +from mmengine.fileio import PetrelBackend +from ._base_.default_runtime import default_scope as scope +from ._base_.scheduler import val_cfg diff --git a/tests/data/config/lazy_module_config/test_ast_transform_error_catching1.py b/tests/data/config/lazy_module_config/test_ast_transform_error_catching1.py new file mode 100644 index 0000000000..e226796591 --- /dev/null +++ b/tests/data/config/lazy_module_config/test_ast_transform_error_catching1.py @@ -0,0 +1,2 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch.optim import * diff --git a/tests/data/config/lazy_module_config/toy_model.py b/tests/data/config/lazy_module_config/toy_model.py new file mode 100644 index 0000000000..55d71b1959 --- /dev/null +++ b/tests/data/config/lazy_module_config/toy_model.py @@ -0,0 +1,45 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from mmengine.dataset import DefaultSampler +from mmengine.hooks import EMAHook +from mmengine.model import MomentumAnnealingEMA +from mmengine.testing.runner_test_case import ToyDataset, ToyMetric + +if '_base_': + from ._base_.base_model import * + from ._base_.default_runtime import * + from ._base_.scheduler import * + +param_scheduler.milestones = [2, 4] + + +train_dataloader = dict( + dataset=dict(type=ToyDataset), + sampler=dict(type=DefaultSampler, shuffle=True), + batch_size=3, + num_workers=0) + +val_dataloader = dict( + dataset=dict(type=ToyDataset), + sampler=dict(type=DefaultSampler, shuffle=False), + batch_size=3, + num_workers=0) + +val_evaluator = [dict(type=ToyMetric)] + +test_dataloader = dict( + dataset=dict(type=ToyDataset), + sampler=dict(type=DefaultSampler, shuffle=False), + batch_size=3, + num_workers=0) + +test_evaluator = [dict(type=ToyMetric)] + +custom_hooks = [ + dict( + type=EMAHook, + ema_type=MomentumAnnealingEMA, + momentum=0.0002, + update_buffers=True, + strict_load=False, + priority=49) +] diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 36a47bc6dd..58087690da 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -8,11 +8,14 @@ import tempfile from importlib import import_module from pathlib import Path +from unittest import TestCase from unittest.mock import patch import pytest +import mmengine from mmengine import Config, ConfigDict, DictAction +from mmengine.config.lazy import LazyObject from mmengine.fileio import dump, load from mmengine.registry import MODELS, DefaultScope, Registry from mmengine.utils import is_installed @@ -942,3 +945,180 @@ class ToyModel: assert model.backbone.style == 'pytorch' assert isinstance(model.roi_head.bbox_head.loss_cls, ToyLoss) DefaultScope._instance_dict.pop('test1') + + def test_lazy_import(self, tmp_path): + lazy_import_cfg_path = osp.join( + self.data_path, 'config/lazy_module_config/toy_model.py') + cfg = 
Config.fromfile(lazy_import_cfg_path) + # Dump config + dumped_cfg_path = tmp_path / 'test_dump_lazy.py' + cfg.dump(dumped_cfg_path) + dumped_cfg = Config.fromfile(dumped_cfg_path) + + def _compare_dict(a, b): + if isinstance(a, dict): + assert len(a) == len(b) + for k, v in a.items(): + _compare_dict(v, b[k]) + elif isinstance(a, list): + assert len(a) == len(b) + for item_a, item_b in zip(a, b): + _compare_dict(item_a, item_b) + else: + if isinstance(a, str) and a != '_module_': + assert a == b + elif isinstance(a, LazyObject): + assert str(a) == str(b) + + _compare_dict(cfg, dumped_cfg) + + # TODO reimplement this part of unit test when mmdetection adds the + # new config. + # if find_spec('mmdet') is not None: + # cfg = Config.fromfile( + # osp.join(self.data_path, + # 'config/lazy_module_config/load_mmdet_config.py')) + # assert cfg.model.backbone.depth == 101 + # cfg.work_dir = str(tmp_path) + # else: + # pytest.skip('skip testing loading config from mmdet since mmdet ' + # 'is not installed or mmdet version is too low') + + # catch import error correctly + error_obj = tmp_path / 'error_obj.py' + error_obj.write_text("""from mmengine.fileio import error_obj""") + # match pattern should be double escaped + match = str(error_obj).encode('unicode_escape').decode() + with pytest.raises(ImportError, match=match): + cfg = Config.fromfile(str(error_obj)) + cfg.error_obj + + error_attr = tmp_path / 'error_attr.py' + error_attr.write_text(""" +import mmengine +error_attr = mmengine.error_attr +""") # noqa: E122 + match = str(error_attr).encode('unicode_escape').decode() + with pytest.raises(ImportError, match=match): + cfg = Config.fromfile(str(error_attr)) + cfg.error_attr + + error_module = tmp_path / 'error_module.py' + error_module.write_text("""import error_module""") + match = str(error_module).encode('unicode_escape').decode() + with pytest.raises(ImportError, match=match): + cfg = Config.fromfile(str(error_module)) + cfg.error_module + + # lazy-import and 
non-lazy-import should not be mixed. + # current text config, base lazy-import config + with pytest.raises(RuntimeError, match='if "_base_"'): + Config.fromfile( + osp.join(self.data_path, + 'config/lazy_module_config/error_mix_using1.py')) + + # current lazy-import config, base text config + with pytest.raises(RuntimeError, match='_base_ ='): + Config.fromfile( + osp.join(self.data_path, + 'config/lazy_module_config/error_mix_using2.py')) + + +class TestConfigDict(TestCase): + + def test_build_lazy(self): + # This unit test is divided into two parts: + # I. ConfigDict will never return a `LazyObject` instance. Only the + # built object will be returned. The `LazyObject` can be accessed after + # `to_dict` is called. + + # II. LazyObject will always be kept in the ConfigDict no matter what + # operation is performed, such as ``update``, ``setitem``, or + # building another ConfigDict from the current one. The updated + # ConfigDict also follows the rule of Part I + + # Part I + # Keep key-value the same + raw = dict(a=1, b=dict(c=2, e=[dict(f=(2, ))])) + cfg_dict = ConfigDict(raw) + self.assertDictEqual(cfg_dict, raw) + + # Check `items` and `values` will only return the built object + raw = dict( + a=LazyObject('mmengine'), + b=dict( + c=2, + e=[ + dict( + f=dict(h=LazyObject('mmengine')), + g=LazyObject('mmengine')) + ])) + cfg_dict = ConfigDict(raw) + # check `items` and values + self.assertDictEqual(cfg_dict.to_dict(), raw) + self._check(cfg_dict) + + # check getattr + self.assertIs(cfg_dict.a, mmengine) + self.assertIs(cfg_dict.b.e[0].f.h, mmengine) + self.assertIs(cfg_dict.b.e[0].g, mmengine) + + # check get + self.assertIs(cfg_dict.get('a'), mmengine) + self.assertIs( + cfg_dict.get('b').get('e')[0].get('f').get('h'), mmengine) + self.assertIs(cfg_dict.get('b').get('e')[0].get('g'), mmengine) + + # check pop + a = cfg_dict.pop('a') + b = cfg_dict.pop('b') + e = b.pop('e') + h = e[0].pop('f')['h'] + g = e[0].pop('g') + self.assertIs(a, mmengine) + 
self.assertIs(h, mmengine) + self.assertIs(g, mmengine) + self.assertEqual(cfg_dict, {}) + self.assertEqual(b, {'c': 2}) + + # Part II + # check update with dict and ConfigDict + for dict_type in (dict, ConfigDict): + cfg_dict = ConfigDict(x=LazyObject('mmengine')) + cfg_dict.update(dict_type(raw)) + self._check(cfg_dict) + + # Create a new ConfigDict + new_dict = ConfigDict(cfg_dict) + self._check(new_dict) + + # Update the ConfigDict by __setitem__ and __setattr__ + new_dict['b']['h'] = LazyObject('mmengine') + new_dict['b']['k'] = dict(l=dict(n=LazyObject('mmengine'))) + new_dict.b.e[0].i = LazyObject('mmengine') + new_dict.b.e[0].j = dict(l=dict(n=LazyObject('mmengine'))) + self._check(new_dict) + + def _check(self, cfg_dict): + self._recursive_check_lazy(cfg_dict, + lambda x: not isinstance(x, LazyObject)) + self._recursive_check_lazy(cfg_dict.to_dict(), + lambda x: x is not mmengine) + self._recursive_check_lazy( + cfg_dict.to_dict(), lambda x: not isinstance(x, ConfigDict) + if isinstance(x, dict) else True) + self._recursive_check_lazy( + cfg_dict, lambda x: isinstance(x, ConfigDict) + if isinstance(x, dict) else True) + + def _recursive_check_lazy(self, cfg, expr): + if isinstance(cfg, dict): + { + key: self._recursive_check_lazy(value, expr) + for key, value in cfg.items() + } + [self._recursive_check_lazy(value, expr) for value in cfg.values()] + elif isinstance(cfg, (tuple, list)): + [self._recursive_check_lazy(value, expr) for value in cfg] + else: + self.assertTrue(expr(cfg)) diff --git a/tests/test_config/test_lazy.py b/tests/test_config/test_lazy.py new file mode 100644 index 0000000000..265f22de9b --- /dev/null +++ b/tests/test_config/test_lazy.py @@ -0,0 +1,182 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import ast +import copy +import os +import os.path as osp +from importlib import import_module +from importlib.util import find_spec +from unittest import TestCase + +import numpy +import numpy.compat +import numpy.linalg as linalg + +import mmengine +from mmengine.config import Config +from mmengine.config.lazy import LazyAttr, LazyObject +from mmengine.config.utils import ImportTransformer, _gather_abs_import_lazyobj +from mmengine.fileio import LocalBackend, PetrelBackend + + +class TestImportTransformer(TestCase): + + @classmethod + def setUpClass(cls) -> None: + cls.data_dir = osp.join( # type: ignore + osp.dirname(__file__), '..', 'data', 'config', + 'lazy_module_config') + super().setUpClass() + + def test_lazy_module(self): + cfg_path = osp.join(self.data_dir, 'test_ast_transform.py') + with open(cfg_path) as f: + codestr = f.read() + codeobj = ast.parse(codestr) + global_dict = { + 'LazyObject': LazyObject, + } + base_dict = { + '._base_.default_runtime': { + 'default_scope': 'test_config' + }, + '._base_.scheduler': { + 'val_cfg': {} + }, + } + codeobj = ImportTransformer(global_dict, base_dict).visit(codeobj) + codeobj, _ = _gather_abs_import_lazyobj(codeobj) + codeobj = ast.fix_missing_locations(codeobj) + + exec(compile(codeobj, cfg_path, mode='exec'), global_dict, global_dict) + # 1. absolute import + # 1.1 import module as LazyObject + lazy_numpy = global_dict['numpy'] + self.assertIsInstance(lazy_numpy, LazyObject) + + # 1.2 getattr as LazyAttr + self.assertIsInstance(lazy_numpy.linalg, LazyAttr) + self.assertIsInstance(lazy_numpy.compat, LazyAttr) + + # 1.3 Build module from LazyObject. 
linalg and compat can be accessed + imported_numpy = lazy_numpy.build() + self.assertIs(imported_numpy.linalg, linalg) + self.assertIs(imported_numpy.compat, numpy.compat) + + # 1.4 Build module from LazyAttr + imported_linalg = lazy_numpy.linalg.build() + imported_compat = lazy_numpy.compat.build() + self.assertIs(imported_compat, numpy.compat) + self.assertIs(imported_linalg, linalg) + + # 1.5 import ... as, and build module from LazyObject + lazy_linalg = global_dict['linalg'] + self.assertIsInstance(lazy_linalg, LazyObject) + self.assertIs(lazy_linalg.build(), linalg) + self.assertIsInstance(lazy_linalg.norm, LazyAttr) + self.assertIs(lazy_linalg.norm.build(), linalg.norm) + + # 1.6 import built in module + imported_os = global_dict['os'] + self.assertIs(imported_os, os) + + # 2. Relative import + # 2.1 from ... import ... as ... + lazy_local_backend = global_dict['local'] + self.assertIsInstance(lazy_local_backend, LazyObject) + self.assertIs(lazy_local_backend.build(), LocalBackend) + + # 2.2 from ... import ... + lazy_petrel_backend = global_dict['PetrelBackend'] + self.assertIsInstance(lazy_petrel_backend, LazyObject) + self.assertIs(lazy_petrel_backend.build(), PetrelBackend) + + # 2.3 from ... import builtin module or obj from `mmengine.Config` + self.assertIs(global_dict['find_module'], find_spec) + self.assertIs(global_dict['Config'], Config) + + # 3 test import base config + # 3.1 simple from ... import and from ... import ... as + self.assertEqual(global_dict['scope'], 'test_config') + self.assertDictEqual(global_dict['val_cfg'], {}) + + # 4. Error catching + cfg_path = osp.join(self.data_dir, + 'test_ast_transform_error_catching1.py') + with open(cfg_path) as f: + codestr = f.read() + codeobj = ast.parse(codestr) + global_dict = {'LazyObject': LazyObject} + with self.assertRaisesRegex( + RuntimeError, + r'Illegal syntax in config! 
`from xxx import \*`'): + codeobj = ImportTransformer(global_dict).visit(codeobj) + + +class TestLazyObject(TestCase): + + def test_init(self): + LazyObject('mmengine') + LazyObject('mmengine.fileio') + LazyObject('mmengine.fileio', 'LocalBackend') + + # module must be str + with self.assertRaises(TypeError): + LazyObject(1) + + # imported must be a sequence of string or None + with self.assertRaises(TypeError): + LazyObject('mmengine', ['error_type']) + + def test_build(self): + lazy_mmengine = LazyObject('mmengine') + self.assertIs(lazy_mmengine.build(), mmengine) + + lazy_mmengine_fileio = LazyObject('mmengine.fileio') + self.assertIs(lazy_mmengine_fileio.build(), + import_module('mmengine.fileio')) + + lazy_local_backend = LazyObject('mmengine.fileio', 'LocalBackend') + self.assertIs(lazy_local_backend.build(), LocalBackend) + + # TODO: The commented test is required, we need to test the built + # LazyObject can access the `mmengine.dataset`. We need to clean the + # environment to make sure the `dataset` is not imported before, and + # it is triggered by lazy_mmengine.build(). However, simply popping + # `mmengine.dataset` causes other tests to fail, and the reason is + # still unknown. We need to figure out the reason and fix it + # later. + + # sys.modules.pop('mmengine.config') + # sys.modules.pop('mmengine.fileio') + # sys.modules.pop('mmengine') + # lazy_mmengine = LazyObject(['mmengine', 'mmengine.dataset']) + # self.assertIs(lazy_mmengine.build().dataset, + # import_module('mmengine.config')) + copied = copy.deepcopy(lazy_local_backend) + self.assertDictEqual(copied.__dict__, lazy_local_backend.__dict__) + + with self.assertRaises(RuntimeError): + lazy_mmengine() + + with self.assertRaises(ImportError): + LazyObject('unknown').build() + + +class TestLazyAttr(TestCase): + # Since LazyAttr should only be built from LazyObject, we only test + # the build method here. 
+ def test_build(self): + lazy_mmengine = LazyObject('mmengine') + local_backend = lazy_mmengine.fileio.LocalBackend + self.assertIs(local_backend.build(), LocalBackend) + + copied = copy.deepcopy(local_backend) + self.assertDictEqual(copied.__dict__, local_backend.__dict__) + + with self.assertRaises(RuntimeError): + local_backend() + + with self.assertRaisesRegex( + ImportError, + 'Failed to import mmengine.fileio.LocalBackend.unknown'): + local_backend.unknown.build() diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py index 0ddac65f51..ec1f0d2486 100644 --- a/tests/test_registry/test_registry.py +++ b/tests/test_registry/test_registry.py @@ -228,6 +228,10 @@ def test_get(self): DOGS, HOUNDS, LITTLE_HOUNDS = registries[:3] MID_HOUNDS, SAMOYEDS, LITTLE_SAMOYEDS = registries[3:] + # error type of key + with pytest.raises(TypeError): + MID_HOUNDS.get(None) + @DOGS.register_module() def bark(word, times): return [word] * times @@ -318,6 +322,14 @@ class LittlePedigreeSamoyed: assert DOGS.get('samoyed.LittlePedigreeSamoyed') is None assert LITTLE_HOUNDS.get('mid_hound.PedigreeSamoyedddddd') is None + # Get mmengine.utils by string + utils = LITTLE_HOUNDS.get('mmengine.utils') + import mmengine.utils + assert utils is mmengine.utils + + unknown = LITTLE_HOUNDS.get('mmengine.unknown') + assert unknown is None + def test__search_child(self): # Hierarchical Registry # DOGS diff --git a/tests/test_utils/test_misc.py b/tests/test_utils/test_misc.py index 3798759e69..700cde8759 100644 --- a/tests/test_utils/test_misc.py +++ b/tests/test_utils/test_misc.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
from collections import namedtuple +from importlib import import_module import numpy as np import pytest @@ -8,13 +9,13 @@ from mmengine import MMLogger # yapf: disable from mmengine.utils.misc import (apply_to, concat_list, deprecated_api_warning, - deprecated_function, has_method, - import_modules_from_strings, is_list_of, - is_method_overridden, is_seq_of, is_tuple_of, - iter_cast, list_cast, requires_executable, - requires_package, slice_list, to_1tuple, - to_2tuple, to_3tuple, to_4tuple, to_ntuple, - tuple_cast) + deprecated_function, get_object_from_string, + has_method, import_modules_from_strings, + is_list_of, is_method_overridden, is_seq_of, + is_tuple_of, iter_cast, list_cast, + requires_executable, requires_package, + slice_list, to_1tuple, to_2tuple, to_3tuple, + to_4tuple, to_ntuple, tuple_cast) # yapf: enable @@ -327,3 +328,11 @@ def test_apply_to(): assert result[0] == 'train' assert isinstance(result.b['a'][0]['c'], torch.Tensor) assert isinstance(result.b['b'], float) + + +def test_locate(): + assert get_object_from_string('a.b.c') is None + model_module = import_module('mmengine.model') + assert get_object_from_string('mmengine.model') is model_module + assert get_object_from_string( + 'mmengine.model.BaseModel') is model_module.BaseModel diff --git a/tests/test_utils/test_package_utils.py b/tests/test_utils/test_package_utils.py new file mode 100644 index 0000000000..c23d1c31c4 --- /dev/null +++ b/tests/test_utils/test_package_utils.py @@ -0,0 +1,41 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import sys +from pathlib import Path + +from mmengine.utils import get_installed_path, is_installed + + +def test_is_installed(): + # TODO: Windows CI may fail for an unknown reason. 
Skip checking the value + is_installed('mmengine') + + # package set by PYTHONPATH + assert not is_installed('py_config') + sys.path.append(osp.abspath(osp.join(osp.dirname(__file__), '..'))) + assert is_installed('test_utils') + sys.path.pop() + + +def test_get_install_path(tmp_path: Path): + # TODO: Windows CI may fail for an unknown reason. Skip checking the value + get_installed_path('mmengine') + + # get path for package "installed" by setting PYTHONPATH + PYTHONPATH = osp.abspath(osp.join( + osp.dirname(__file__), + '..', + )) + sys.path.append(PYTHONPATH) + res_path = get_installed_path('test_utils') + assert osp.join(PYTHONPATH, 'test_utils') == res_path + + # return the first path for namespace package + # See more information about namespace package in: + # https://packaging.python.org/en/latest/guides/packaging-namespace-packages/ # noqa:E501 + (tmp_path / 'test_utils').mkdir() + sys.path.insert(-1, str(tmp_path)) + res_path = get_installed_path('test_utils') + assert osp.abspath(osp.join(tmp_path, 'test_utils')) == res_path + sys.path.pop() + sys.path.pop()