From ef41f69553df5dd6921c76affc87b542f54f5c86 Mon Sep 17 00:00:00 2001
From: lzhangzz
Date: Mon, 28 Jun 2021 20:35:15 +0800
Subject: [PATCH] add model splitting support (#1)

* add function marker and model extractor

* add fsaf split & partial mask rcnn split, import extract.py

* 1. add value renaming
  2. add apply_marks in config to turn on/off marks

* rewind changes on pytorch2onnx.py

Co-authored-by: q.yao
---
NOTE(review): the line structure of this patch was restored from a
whitespace-collapsed copy. Leading whitespace of the context lines in the
mmdeploy/apis/pytorch2onnx.py hunk is a best guess; apply with
`git apply --ignore-whitespace` if it does not match the target tree.

 configs/mmdet/split.py                        |   4 +
 mmdeploy/apis/pytorch2onnx.py                 |   2 +-
 mmdeploy/mmdet/models/dense_heads/__init__.py |   4 +-
 .../mmdet/models/dense_heads/fsaf_head.py     |   7 +
 mmdeploy/mmdet/models/dense_heads/rpn_head.py |   7 +
 mmdeploy/mmdet/models/detectors/__init__.py   |   3 +-
 mmdeploy/mmdet/models/detectors/two_stage.py  |  19 ++
 mmdeploy/utils/__init__.py                    |   3 +-
 mmdeploy/utils/function_marker.py             |  61 +++++
 tools/extract.py                              | 234 ++++++++++++++++++
 10 files changed, 340 insertions(+), 4 deletions(-)
 create mode 100644 configs/mmdet/split.py
 create mode 100644 mmdeploy/mmdet/models/dense_heads/fsaf_head.py
 create mode 100644 mmdeploy/mmdet/models/dense_heads/rpn_head.py
 create mode 100644 mmdeploy/mmdet/models/detectors/two_stage.py
 create mode 100644 mmdeploy/utils/function_marker.py
 create mode 100644 tools/extract.py

diff --git a/configs/mmdet/split.py b/configs/mmdet/split.py
new file mode 100644
index 000000000..b503f30fb
--- /dev/null
+++ b/configs/mmdet/split.py
@@ -0,0 +1,4 @@
+_base_ = ['./base.py', '../_base_/backends/tensorrt.py']
+
+backend = 'default'
+apply_marks = True
\ No newline at end of file
diff --git a/mmdeploy/apis/pytorch2onnx.py b/mmdeploy/apis/pytorch2onnx.py
index 5ab3c8fc9..e5170e99a 100644
--- a/mmdeploy/apis/pytorch2onnx.py
+++ b/mmdeploy/apis/pytorch2onnx.py
@@ -62,4 +62,4 @@ def torch2onnx(img: Any,
             keep_initializers_as_inputs=pytorch2onnx_cfg[
                 'keep_initializers_as_inputs'])
 
-    ret_value.value = 0
+    ret_value.value = 0
\ No newline at end of file
diff --git a/mmdeploy/mmdet/models/dense_heads/__init__.py b/mmdeploy/mmdet/models/dense_heads/__init__.py
index 37a26044f..582f4ec7f 100644
--- a/mmdeploy/mmdet/models/dense_heads/__init__.py
+++ b/mmdeploy/mmdet/models/dense_heads/__init__.py
@@ -1,3 +1,5 @@
 from .anchor_head import AnchorHead
+from .rpn_head import rpn_head_forward
+from .fsaf_head import fsaf_head_forward
 
-__all__ = ['AnchorHead']
+__all__ = ['AnchorHead', 'rpn_head_forward', 'fsaf_head_forward']
diff --git a/mmdeploy/mmdet/models/dense_heads/fsaf_head.py b/mmdeploy/mmdet/models/dense_heads/fsaf_head.py
new file mode 100644
index 000000000..669a86bd0
--- /dev/null
+++ b/mmdeploy/mmdet/models/dense_heads/fsaf_head.py
@@ -0,0 +1,7 @@
+from mmdeploy.utils import FUNCTION_REWRITERS, mark
+
+
+@FUNCTION_REWRITERS.register_rewriter('mmdet.models.FSAFHead.forward')
+@mark('rpn_forward')
+def fsaf_head_forward(rewriter, *args):
+    return rewriter.origin_func(*args)
diff --git a/mmdeploy/mmdet/models/dense_heads/rpn_head.py b/mmdeploy/mmdet/models/dense_heads/rpn_head.py
new file mode 100644
index 000000000..484db25a3
--- /dev/null
+++ b/mmdeploy/mmdet/models/dense_heads/rpn_head.py
@@ -0,0 +1,7 @@
+from mmdeploy.utils import FUNCTION_REWRITERS, mark
+
+
+@FUNCTION_REWRITERS.register_rewriter('mmdet.models.RPNHead.forward')
+@mark('rpn_forward')
+def rpn_head_forward(rewriter, self, feats):
+    return rewriter.origin_func(self, feats)
diff --git a/mmdeploy/mmdet/models/detectors/__init__.py b/mmdeploy/mmdet/models/detectors/__init__.py
index a4e3a40ad..fcc9c701d 100644
--- a/mmdeploy/mmdet/models/detectors/__init__.py
+++ b/mmdeploy/mmdet/models/detectors/__init__.py
@@ -1,3 +1,4 @@
 from .single_stage import SingleStageDetector
+from .two_stage import extract_feat
 
-__all__ = ['SingleStageDetector']
+__all__ = ['SingleStageDetector', 'extract_feat']
diff --git a/mmdeploy/mmdet/models/detectors/two_stage.py b/mmdeploy/mmdet/models/detectors/two_stage.py
new file mode 100644
index 000000000..c3799a3e9
--- /dev/null
+++ b/mmdeploy/mmdet/models/detectors/two_stage.py
@@ -0,0 +1,19 @@
+from mmdeploy.utils import FUNCTION_REWRITERS, mark
+from mmdeploy.utils import SYMBOLICS_REGISTER
+from mmcv.onnx.symbolic import grid_sampler
+
+
+@FUNCTION_REWRITERS.register_rewriter('mmdet.models.TwoStageDetector.extract_feat')
+@mark('extract_feat')
+def extract_feat(rewriter, self, img):
+    return rewriter.origin_func(self, img)
+
+
+@FUNCTION_REWRITERS.register_rewriter('mmdet.models.TwoStageDetector.forward')
+def two_stage_forward(rewriter, self, img, *args):
+    return rewriter.origin_func(self, [img], img_metas=[[{}]], return_loss=False, *args)
+
+
+@SYMBOLICS_REGISTER.register_symbolic('grid_sampler', is_pytorch=True)
+def symbolic_grid_sample(symbolic_wrapper, *args):
+    return grid_sampler(*args)
diff --git a/mmdeploy/utils/__init__.py b/mmdeploy/utils/__init__.py
index 7135100a0..d96600d95 100644
--- a/mmdeploy/utils/__init__.py
+++ b/mmdeploy/utils/__init__.py
@@ -1,8 +1,9 @@
 from .function_rewriter import FUNCTION_REWRITERS, RewriterContext
 from .module_rewriter import MODULE_REWRITERS, patch_model
 from .symbolic_register import SYMBOLICS_REGISTER, register_extra_symbolics
+from .function_marker import mark
 
 __all__ = [
     'RewriterContext', 'FUNCTION_REWRITERS', 'MODULE_REWRITERS', 'patch_model',
-    'SYMBOLICS_REGISTER', 'register_extra_symbolics'
+    'SYMBOLICS_REGISTER', 'register_extra_symbolics', 'mark'
 ]
diff --git a/mmdeploy/utils/function_marker.py b/mmdeploy/utils/function_marker.py
new file mode 100644
index 000000000..2fedcd8f3
--- /dev/null
+++ b/mmdeploy/utils/function_marker.py
@@ -0,0 +1,61 @@
+import inspect
+import torch
+from .function_rewriter import FUNCTION_REWRITERS
+
+
+class Mark(torch.autograd.Function):
+    @staticmethod
+    def symbolic(g, x, type, name, id, attrs):
+        n = g.op("mmcv::Mark", x, type_s=type, name_s=name, id_i=id, **attrs)
+        return n
+
+    @staticmethod
+    def forward(ctx, x, *args):
+        return x
+
+
+@FUNCTION_REWRITERS.register_rewriter("mmdeploy.utils.function_marker.Mark.symbolic")
+def mark_symbolic(rewriter, g, x, *args):
+    if rewriter.cfg.get("apply_marks", False):
+        return rewriter.origin_func(g, x, *args)
+    return x
+
+
+def mark_tensors(xs, type, name, attrs):
+    index = 0
+    visit = set()
+
+    def impl(ys, prefix):
+        nonlocal index
+        if isinstance(ys, torch.Tensor):
+            if ys not in visit:
+                visit.add(ys)
+                index += 1
+                return Mark.apply(ys, type, prefix, index - 1, attrs)
+            return ys
+        elif isinstance(ys, list):
+            return [impl(y, f'{prefix}/{i}') for i, y in enumerate(ys)]
+        elif isinstance(ys, tuple):
+            return tuple(impl(y, f'{prefix}/{i}') for i, y in enumerate(ys))
+        elif isinstance(ys, dict):
+            return {k: impl(v, f'{prefix}/{k}') for k, v in ys.items()}
+        return ys
+    return impl(xs, name)
+
+
+def mark(func, **attrs):
+    attrs['func_s'] = func
+
+    def decorator(f):
+        params = inspect.signature(f).parameters.keys()
+        def g(*args, **kwargs):
+            if torch.onnx.is_in_onnx_export():
+                args = [mark_tensors(arg, 'input', name, attrs)
+                        for name, arg in zip(params, args)]
+                rets = f(*args, **kwargs)
+                # TODO: maybe we can traverse the AST to get the retval names?
+                return mark_tensors(rets, 'output', func, attrs)
+            else:
+                return f(*args, **kwargs)
+        return g
+    return decorator
diff --git a/tools/extract.py b/tools/extract.py
new file mode 100644
index 000000000..169187e87
--- /dev/null
+++ b/tools/extract.py
@@ -0,0 +1,234 @@
+import argparse
+import os.path as osp
+import onnx
+import onnx.utils
+import onnx.helper
+from onnx import AttributeProto
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description='Extract model based on markers.')
+    parser.add_argument('input_model', help='Input ONNX model')
+    parser.add_argument('output_model', help='Output ONNX model')
+    parser.add_argument(
+        '--start', help='Start markers, format: func:type, e.g. backbone:input')
+    parser.add_argument('--end', help='End markers')
+
+    args = parser.parse_args()
+
+    args.start = args.start.split(',') if args.start else []
+    args.end = args.end.split(',') if args.end else []
+
+    return args
+
+
+def remove_markers(model):
+    shortcut = []
+    success = True
+    while success:
+        success = False
+        for i, node in enumerate(model.graph.node):
+            if node.op_type == 'Mark':
+                for input in node.input:
+                    shortcut.append((input, node.output))
+                del model.graph.node[i]
+                success = True
+                break
+    for src, dsts in shortcut:
+        for curr in model.graph.node:
+            for k, input in enumerate(curr.input):
+                if input in dsts:
+                    curr.input[k] = src
+        # TODO: handle duplicated case?
+        for k, output in enumerate(model.graph.output):
+            print(output.name, dsts)
+            if output.name in dsts:
+                output.name = src
+    return model
+
+
+def attribute_to_dict(attribute):
+    ret = {}
+    for a in attribute:
+        name = a.name
+        if a.type == AttributeProto.AttributeType.STRING:
+            ret[name] = str(a.s, 'utf-8')
+        elif a.type == AttributeProto.AttributeType.INT:
+            ret[name] = a.i
+    return ret
+
+
+def _dfs_search_reacable_nodes_fast(self, node_output_name, graph_input_nodes, reachable_nodes):
+    outputs = {}
+    for index, node in enumerate(self.graph.node):
+        for name in node.output:
+            if name not in outputs:
+                outputs[name] = set()
+            outputs[name].add(index)
+
+    def impl(node_output_name, graph_input_nodes, reachable_nodes):
+        if node_output_name in graph_input_nodes:
+            return
+        if node_output_name not in outputs:
+            return
+        for index in outputs[node_output_name]:
+            node = self.graph.node[index]
+            if node in reachable_nodes:
+                continue
+            reachable_nodes.append(node)
+            for name in node.input:
+                impl(name, graph_input_nodes, reachable_nodes)
+    impl(node_output_name, graph_input_nodes, reachable_nodes)
+
+
+def get_new_name(attrs):
+    if 'name' in attrs:
+        return attrs['name']
+    return '_'.join((attrs['func'], attrs['type'], str(attrs['id'])))
+
+
+def rename_value(model, old_name, new_name):
+    for n in model.graph.node:
+        for i, output in enumerate(n.output):
+            if output == old_name:
+                n.output[i] = new_name
+        for i, input in enumerate(n.input):
+            if input == old_name:
+                n.input[i] = new_name
+    for v in model.graph.value_info:
+        if v.name == old_name:
+            v.name = new_name
+    for v in model.graph.input:  # ValueInfoProto entries, not str: match on .name
+        if v.name == old_name:
+            v.name = new_name
+    for v in model.graph.output:  # same as graph.input: rename via .name
+        if v.name == old_name:
+            v.name = new_name
+
+
+def extract_model(model, start, end):
+    inputs = []
+    outputs = []
+    if not isinstance(start, (list, tuple)):
+        start = [start]
+    for s in start:
+        start_name, start_type = s.split(':')
+        assert start_type in ['input', 'output']
+        for node in model.graph.node:
+            if node.op_type == 'Mark':
+                attr = attribute_to_dict(node.attribute)
+                if attr['func'] == start_name and attr['type'] == start_type:
+                    name = node.output[0] if start_type == 'input' else node.input[0]
+                    if name not in inputs:
+                        new_name = get_new_name(attr)
+                        rename_value(model, name, new_name)
+                        inputs.append(new_name)
+
+    print(f'inputs: {inputs}')
+
+    # collect outputs
+    # outputs = []
+    if not isinstance(end, (list, tuple)):
+        end = [end]
+    for e in end:
+        end_name, end_type = e.split(':')
+        assert end_type in ['input', 'output']
+        for node in model.graph.node:
+            if node.op_type == 'Mark':
+                attr = attribute_to_dict(node.attribute)
+                if attr['func'] == end_name and attr['type'] == end_type:
+                    name = node.input[0] if end_type == 'input' else node.output[0]
+                    if name not in outputs:
+                        new_name = get_new_name(attr)
+                        rename_value(model, name, new_name)
+                        outputs.append(new_name)
+
+    print(f'outputs: {outputs}')
+
+    # replace Mark with Identity
+    for node in model.graph.node:
+        if node.op_type == 'Mark':
+            del node.attribute[:]
+            node.domain = ''
+            node.op_type = 'Identity'
+
+    # patch extractor (class-level monkey patch: swaps in the indexed DFS above)
+    onnx.utils.Extractor._dfs_search_reachable_nodes = _dfs_search_reacable_nodes_fast
+
+    extractor = onnx.utils.Extractor(model)
+    extracted_model = extractor.extract_model(inputs, outputs)
+
+    # collect all used inputs
+    used = set()
+    for node in extracted_model.graph.node:
+        for input in node.input:
+            used.add(input)
+
+    for output in extracted_model.graph.output:
+        used.add(output.name)
+
+    # delete unused inputs
+    success = True
+    while success:
+        success = False
+        for i, input in enumerate(extracted_model.graph.input):
+            if input.name not in used:
+                del extracted_model.graph.input[i]
+                success = True
+                break
+
+    # eliminate output without shape
+    for xs in [extracted_model.graph.output]:
+        for x in xs:
+            if not x.type.tensor_type.shape.dim:
+                print(f'fixing output shape: {x.name}')
+                x.CopyFrom(onnx.helper.make_tensor_value_info(
+                    x.name, x.type.tensor_type.elem_type, []))
+
+    # eliminate 0-batch dimension, dirty workaround for two-stage detectors
+    for input in extracted_model.graph.input:
+        dims = input.type.tensor_type.shape.dim  # may be empty (no shape recorded)
+        if input.name in inputs and dims and dims[0].dim_value == 0:
+            dims[0].dim_value = 1
+
+    # eliminate duplicated value_info for inputs
+    success = True
+    while success:
+        success = False
+        for i, x in enumerate(extracted_model.graph.value_info):
+            if x.name in inputs:
+                del extracted_model.graph.value_info[i]
+                success = True
+                break
+
+    return extracted_model
+
+
+def collect_avaiable_marks(model):
+    marks = []
+    for node in model.graph.node:
+        if node.op_type == 'Mark':
+            attr = attribute_to_dict(node.attribute)
+            func = attr['func']
+            if func not in marks:
+                marks.append(func)
+    return marks
+
+
+def main():
+    args = parse_args()
+
+    model = onnx.load(args.input_model)
+    marks = collect_avaiable_marks(model)
+    print("Available marks:\n {}".format('\n '.join(marks)))
+
+    extracted_model = extract_model(model, args.start, args.end)
+
+    if osp.splitext(args.output_model)[-1] != '.onnx':
+        args.output_model += '.onnx'
+    onnx.save(extracted_model, args.output_model)
+
+
+if __name__ == '__main__':
+    main()