[Bug] When I use RMosaic data augmentation with the workflow set to [('train', 1), ('val', 1)], training fails with KeyError: 'img'. When I change the workflow back to [('train', 1)], the error disappears. Why does this happen?
#1063
Prerequisite
Task
I have modified the scripts/configs, or I'm working on my own tasks/models/datasets.
Branch
master branch https://github.com/open-mmlab/mmrotate
Environment
fatal: not a git repository (or any of the parent directories): .git
sys.platform: win32
Python: 3.8.19 (default, Mar 20 2024, 19:55:45) [MSC v.1916 64 bit (AMD64)]
CUDA available: True
GPU 0: NVIDIA GeForce RTX 3060 Ti
CUDA_HOME: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8
NVCC: Cuda compilation tools, release 11.8, V11.8.89
MSVC: Microsoft (R) C/C++ Optimizing Compiler Version 19.40.33811 for x64
GCC: n/a
PyTorch: 1.8.0+cu111
PyTorch compiling details: PyTorch built with:
imental -DNDEBUG -DUSE_FBGEMM -DUSE_XNNPACK, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.8.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=OFF, USE_OPENMP=ON,
TorchVision: 0.9.0+cu111
OpenCV: 4.10.0
MMCV: 1.6.0
MMCV Compiler: MSVC 192930137
MMCV CUDA Compiler: 11.1
MMRotate: 0.3.4+
Reproduces the problem - code sample
# The new config inherits the settings of the base model
from mmdet.core.evaluation import class_names
_base_ = '../../configs/lsknet/lsk_s_fpn_1x_dota_le90.py'
# 1. Dataset settings
dataset_type = 'DOTADataset'
data_root = 'data/test/'
angle_version = 'le90'
classes = ('dam', 'groyne', 'lock', 'sluice', 'weir')
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(
        type='RRandomFlip',
        flip_ratio=[0.25, 0.25, 0.25],
        direction=['horizontal', 'vertical', 'diagonal'],
        version=angle_version),
    dict(
        type='RMosaic',  # Mosaic data augmentation
        img_scale=(512, 512),
        center_ratio_range=(0.8, 1.2),
        min_bbox_size=6,
        bbox_clip_border=True,
        skip_filter=True,
        version=angle_version),
    dict(
        type='PolyRandomRotate',  # rotation augmentation applied to the image and its bboxes
        rotate_ratio=0.5,
        angles_range=180,
        auto_bound=False,
        version=angle_version),
    dict(
        type='RandomPhotoMetricDistortion',  # photometric distortion augmentation
        prob=0.5,
        brightness_delta=0,  # brightness change range
        contrast_range=(0.9, 1.1),  # contrast change range
        saturation_range=(0.9, 1.1),  # saturation change range
        hue_delta=18),  # hue change range
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
data = dict(
    samples_per_gpu=2,
    workers_per_gpu=4,
    train=dict(
        _delete_=True,
        type='MultiImageMixDataset',
        dataset=dict(
            type=dataset_type,
            classes=classes,
            ann_file=data_root + 'train/annfiles/',
            img_prefix=data_root + 'train/images/',
            pipeline=[
                dict(type='LoadImageFromFile'),
                dict(type='LoadAnnotations', with_bbox=True)
            ],
            version=angle_version),
        pipeline=train_pipeline,
        max_refetch=500),
    val=dict(
        type=dataset_type,
        classes=classes,
        ann_file=data_root + 'val/annfiles/',
        img_prefix=data_root + 'val/images/'),
    test=dict(
        type=dataset_type,
        classes=classes,
        ann_file=data_root + 'test/annfiles/',
        img_prefix=data_root + 'test/images/'))
# 2. Optimizer settings
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.0001,  # /8*gpu_number
    betas=(0.9, 0.999),
    weight_decay=0.05)
lr_config = dict(
    _delete_=True,
    policy='CosineAnnealing',
    by_epoch=False,
    min_lr=0,
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=1.0 / 3)
evaluation = dict(interval=1, metric='mAP', save_best='mAP')
runner = dict(type='EpochBasedRunner', max_epochs=12)
# 3. Output (logging) settings
log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook')
    ])
workflow = [('train', 1), ('val', 1)]
# 4. Model settings
gpu_number = 1  # change this to your number of GPUs
model = dict(
    backbone=dict(
        init_cfg=dict(type='Pretrained', checkpoint='checkpoints/lsk_s_backbone.pth.tar'),
        norm_cfg=dict(type='BN', requires_grad=True)),  # use BN instead of SyncBN for single-GPU training
    roi_head=dict(
        bbox_head=dict(
            num_classes=5)))
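The `lr_config` above combines a linear warmup with iteration-based cosine annealing. The following is a minimal sketch of that schedule, assuming the formulas used by mmcv 1.x's `CosineAnnealingLrUpdaterHook` and its linear warmup, and assuming 95 iterations per epoch (read from the `[50/95]` log line below). `cosine_lr` and `scheduled_lr` are hypothetical helpers, not mmcv APIs. Under these assumptions the sketch reproduces the learning rate logged at iteration 50 of epoch 1.

```python
import math

BASE_LR = 1e-4          # optimizer.lr
MIN_LR = 0.0            # lr_config.min_lr
WARMUP_ITERS = 500      # lr_config.warmup_iters
WARMUP_RATIO = 1.0 / 3  # lr_config.warmup_ratio
ITERS_PER_EPOCH = 95    # assumption: taken from the "[50/95]" log line
MAX_EPOCHS = 12         # runner.max_epochs
MAX_ITERS = ITERS_PER_EPOCH * MAX_EPOCHS


def cosine_lr(cur_iter: int) -> float:
    """Cosine annealing from BASE_LR to MIN_LR over MAX_ITERS (by_epoch=False)."""
    factor = cur_iter / MAX_ITERS
    return MIN_LR + 0.5 * (BASE_LR - MIN_LR) * (math.cos(math.pi * factor) + 1)


def scheduled_lr(cur_iter: int) -> float:
    """Hypothetical helper: the cosine lr, scaled down linearly during warmup."""
    regular = cosine_lr(cur_iter)
    if cur_iter >= WARMUP_ITERS:
        return regular
    # Linear warmup: the scale grows from WARMUP_RATIO at iter 0 to 1 at WARMUP_ITERS.
    k = (1 - cur_iter / WARMUP_ITERS) * (1 - WARMUP_RATIO)
    return regular * (1 - k)


if __name__ == "__main__":
    # "Epoch [1][50/95]" is logged after the 50th iteration, whose 0-based index is 49.
    print(f"{scheduled_lr(49):.3e}")  # 3.969e-05
```

Because `by_epoch=False`, the cosine curve is evaluated per iteration over all 1140 iterations rather than stepping once per epoch.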
Reproduces the problem - command or script
python tools/train.py my_demo/configs/my_lsk_s_orcnn_fpn_1x_dota_le90.py --work-dir my_demo/output/debug/lsk_orcnn
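The two workflows differ in how many data loaders the runner iterates. Below is a simplified sketch of the loop in mmcv 1.x's `EpochBasedRunner.run()`, assuming that it pairs the i-th workflow entry with the i-th data loader and requires their counts to match. `run_workflow` and the loader names are hypothetical stand-ins, not mmcv APIs. With `[('train', 1)]` only the training loader is ever iterated; adding `('val', 1)` makes the runner iterate a second, validation loader after every training epoch, which is where the traceback below originates (`epoch_based_runner.py`, in `val`).

```python
from typing import List, Tuple


def run_workflow(loaders: List[str], workflow: List[Tuple[str, int]],
                 max_epochs: int) -> List[str]:
    """Simplified stand-in for EpochBasedRunner.run(); returns the calls made, in order."""
    # The runner expects exactly one data loader per workflow entry.
    assert len(loaders) == len(workflow), "one data loader per workflow entry"
    calls = []
    epoch = 0
    while epoch < max_epochs:
        for loader, (mode, epochs) in zip(loaders, workflow):
            for _ in range(epochs):
                # Only training epochs are capped by max_epochs.
                if mode == "train" and epoch >= max_epochs:
                    break
                calls.append(f"{mode}:{loader}")
                if mode == "train":
                    epoch += 1  # the runner's epoch counter advances after train()
    return calls


if __name__ == "__main__":
    print(run_workflow(["train_loader"], [("train", 1)], max_epochs=2))
    print(run_workflow(["train_loader", "val_loader"],
                       [("train", 1), ("val", 1)], max_epochs=2))
```

The second call yields `train, val, train, val`, so any problem in the validation loader appears only after the first training epoch has finished, as in the log below.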
Reproduces the problem - error message
2024-09-04 11:10:24,555 - mmrotate - INFO - workflow: [('train', 1), ('val', 1)], max: 12 epochs
2024-09-04 11:10:24,555 - mmrotate - INFO - Checkpoints will be saved to G:\mmrotate\my_demo\output\debug\lsk_orcnn by HardDiskBackend.
D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\mmdet\models\dense_heads\anchor_head.py:123: UserWarning: DeprecationWarning: anchor_generator is deprecated, please use "prior_generator" instead
warnings.warn('DeprecationWarning: anchor_generator is deprecated, '
2024-09-04 11:11:20,375 - mmrotate - INFO - Epoch [1][50/95] lr: 3.969e-05, eta: 0:20:14, time: 1.114, data_time: 0.271, memory: 7126, loss_rpn_cls: 0.5340, loss_rpn_bbox: 0.1305, loss_cls: 0.3892, acc: 92.8320, loss_bbox: 0.0319, loss: 1.0856, grad_norm: 12.6486
2024-09-04 11:11:51,915 - mmrotate - INFO - Saving checkpoint at 1 epochs
[>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>] 7/7, 1.0 task/s, elapsed: 7s, ETA: 0s
2024-09-04 11:12:11,302 - mmrotate - INFO -
+--------+-----+------+--------+-------+
| class  | gts | dets | recall | ap    |
+--------+-----+------+--------+-------+
| dam    | 2   | 129  | 0.000  | 0.000 |
| groyne | 0   | 6    | 0.000  | 0.000 |
| lock   | 0   | 3    | 0.000  | 0.000 |
| sluice | 5   | 8    | 0.000  | 0.000 |
| weir   | 0   | 15   | 0.000  | 0.000 |
+--------+-----+------+--------+-------+
| mAP    |     |      |        | 0.000 |
+--------+-----+------+--------+-------+
2024-09-04 11:12:11,340 - mmrotate - INFO - Exp name: my_lsk_s_orcnn_fpn_1x_dota_le90.py
2024-09-04 11:12:11,340 - mmrotate - INFO - Epoch(val) [1][7] mAP: 0.0000
Traceback (most recent call last):
  File "tools/train.py", line 192, in <module>
    main()
  File "tools/train.py", line 181, in main
    train_detector(
  File "g:\mmrotate\mmrotate\apis\train.py", line 141, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 136, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\autograd\grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 68, in val
    for i, data_batch in enumerate(self.data_loader):
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\utils\data\dataloader.py", line 517, in __next__
    data = self._next_data()
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\utils\data\dataloader.py", line 1199, in _next_data
    return self._process_data(data)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\utils\data\dataloader.py", line 1225, in _process_data
    data.reraise()
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\_utils.py", line 429, in reraise
    raise self.exc_type(msg)
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\utils\data\_utils\worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\torch\utils\data\_utils\fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\mmdet\datasets\custom.py", line 218, in __getitem__
    data = self.prepare_train_img(idx)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\mmdet\datasets\custom.py", line 241, in prepare_train_img
    return self.pipeline(results)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\mmdet\datasets\pipelines\compose.py", line 41, in __call__
    data = t(data)
  File "D:\Program\Anaconda3\envs\mmrotate\lib\site-packages\mmdet\datasets\pipelines\transforms.py", line 469, in __call__
    results[key], direction=results['flip_direction'])
KeyError: 'img'
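A likely explanation, assuming `tools/train.py` follows the mmdet 2.x convention: when `workflow` has two entries, it builds the extra validation dataset from a copy of `data.val` but replaces that copy's pipeline with `data.train.pipeline`. With `MultiImageMixDataset`, that outer pipeline never loads images (loading happens only in the inner `dataset.pipeline`), so the validation `DOTADataset` passes `RRandomFlip` a results dict without `'img'`, which matches the `KeyError: 'img'` raised in `transforms.py` from `prepare_train_img`. The toy sketch below reproduces that control flow in plain Python. `build_val_cfg`, `toy_transform`, `run`, and the file name `P0001.png` are hypothetical, and every transform except loading and flipping is a no-op.

```python
import copy

# Abbreviated version of the config above: only the transform types matter here.
train_pipeline = [dict(type=t) for t in (
    'RRandomFlip', 'RMosaic', 'PolyRandomRotate', 'RandomPhotoMetricDistortion',
    'Normalize', 'Pad', 'DefaultFormatBundle', 'Collect')]
load_pipeline = [dict(type='LoadImageFromFile'),
                 dict(type='LoadAnnotations', with_bbox=True)]
data = dict(
    # Images are loaded only by the *inner* dataset of MultiImageMixDataset.
    train=dict(type='MultiImageMixDataset',
               dataset=dict(type='DOTADataset', pipeline=load_pipeline),
               pipeline=train_pipeline),
    val=dict(type='DOTADataset'),
)


def build_val_cfg(data, workflow):
    """Assumed tools/train.py behaviour: the 'val' workflow reuses data.train.pipeline."""
    if len(workflow) != 2:
        return None  # no second dataset or data loader is built
    val_cfg = copy.deepcopy(data['val'])
    val_cfg['pipeline'] = data['train']['pipeline']  # outer pipeline: no loading step
    return val_cfg


def toy_transform(step, results):
    """Toy stand-ins: loading creates 'img'; RRandomFlip reads it, like mmdet's RandomFlip."""
    if step['type'] == 'LoadImageFromFile':
        results['img'] = 'decoded image of ' + results['filename']
    elif step['type'] == 'RRandomFlip':
        results['img'] = results['img'] + ' (flipped)'  # KeyError: 'img' if never loaded
    return results  # every other transform is a no-op in this sketch


def run(pipeline, filename='P0001.png'):
    """Apply each toy transform in order, as Compose does."""
    results = {'filename': filename}
    for step in pipeline:
        results = toy_transform(step, results)
    return results


if __name__ == '__main__':
    print(build_val_cfg(data, [('train', 1)]))  # None: only the training loader exists
    val_cfg = build_val_cfg(data, [('train', 1), ('val', 1)])
    try:
        run(val_cfg['pipeline'])
    except KeyError as err:
        print('val pipeline failed with KeyError:', err)
```

The mAP table printed just before the traceback presumably comes from the evaluation hook enabled by `evaluation = dict(interval=1, ...)`, which builds its own validation dataset with the base config's test-time pipeline. If so, `workflow = [('train', 1)]` still evaluates on the validation set every epoch, and the `('val', 1)` entry only adds a validation-loss pass.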
Additional information
This is just a test dataset, so the issue of mAP=0 can be ignored.