-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
220 changed files
with
9,805 additions
and
452 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,186 @@ | ||
import argparse | ||
import math | ||
from pathlib import Path | ||
|
||
import torch | ||
from rich.console import Console | ||
|
||
console = Console() | ||
|
||
prog_description = """\ | ||
Draw the state dict tree. | ||
""" | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser(description=prog_description) | ||
parser.add_argument( | ||
'path', | ||
type=Path, | ||
help='The path of the checkpoint or model config to draw.') | ||
parser.add_argument('--depth', type=int, help='The max depth to draw.') | ||
parser.add_argument( | ||
'--full-name', | ||
action='store_true', | ||
help='Whether to print the full name of the key.') | ||
parser.add_argument( | ||
'--shape', | ||
action='store_true', | ||
help='Whether to print the shape of the parameter.') | ||
parser.add_argument( | ||
'--state-key', | ||
type=str, | ||
help='The key of the state dict in the checkpoint.') | ||
parser.add_argument( | ||
'--number', | ||
action='store_true', | ||
help='Mark all parameters and their index number.') | ||
parser.add_argument( | ||
'--node', | ||
type=str, | ||
help='Show the sub-tree of a node, like "backbone.layers".') | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def ckpt_to_state_dict(checkpoint, key=None): | ||
if key is not None: | ||
state_dict = checkpoint[key] | ||
elif 'state_dict' in checkpoint: | ||
# try mmcls style | ||
state_dict = checkpoint['state_dict'] | ||
elif 'model' in checkpoint: | ||
state_dict = checkpoint['model'] | ||
elif isinstance(next(iter(checkpoint.values())), torch.Tensor): | ||
# try native style | ||
state_dict = checkpoint | ||
else: | ||
raise KeyError('Please specify the key of state ' | ||
f'dict from {list(checkpoint.keys())}.') | ||
return state_dict | ||
|
||
|
||
class StateDictTree: | ||
|
||
def __init__(self, key='', value=None): | ||
self.children = {} | ||
self.key: str = key | ||
self.value = value | ||
|
||
def add_parameter(self, key, value): | ||
keys = key.split('.', 1) | ||
if len(keys) == 1: | ||
self.children[key] = StateDictTree(key, value) | ||
elif keys[0] in self.children: | ||
self.children[keys[0]].add_parameter(keys[1], value) | ||
else: | ||
node = StateDictTree(keys[0]) | ||
node.add_parameter(keys[1], value) | ||
self.children[keys[0]] = node | ||
|
||
def __getitem__(self, key: str): | ||
return self.children[key] | ||
|
||
def __repr__(self) -> str: | ||
with console.capture() as capture: | ||
for line in self.iter_tree(): | ||
console.print(line) | ||
return capture.get() | ||
|
||
def __len__(self): | ||
return len(self.children) | ||
|
||
def draw_tree(self, | ||
max_depth=None, | ||
full_name=False, | ||
with_shape=False, | ||
with_value=False): | ||
for line in self.iter_tree( | ||
max_depth=max_depth, | ||
full_name=full_name, | ||
with_shape=with_shape, | ||
with_value=with_value, | ||
): | ||
console.print(line, highlight=False) | ||
|
||
def iter_tree( | ||
self, | ||
lead='', | ||
prefix='', | ||
max_depth=None, | ||
full_name=False, | ||
with_shape=False, | ||
with_value=False, | ||
): | ||
if self.value is None: | ||
key_str = f'[blue]{self.key}[/]' | ||
elif with_shape: | ||
key_str = f'[green]{self.key}[/] {tuple(self.value.shape)}' | ||
elif with_value: | ||
key_str = f'[green]{self.key}[/] {self.value}' | ||
else: | ||
key_str = f'[green]{self.key}[/]' | ||
|
||
yield lead + prefix + key_str | ||
|
||
lead = lead.replace('├─', '│ ') | ||
lead = lead.replace('└─', ' ') | ||
if self.key and full_name: | ||
prefix = f'{prefix}{self.key}.' | ||
|
||
if max_depth == 0: | ||
return | ||
elif max_depth is not None: | ||
max_depth -= 1 | ||
|
||
for i, child in enumerate(self.children.values()): | ||
level_lead = '├─' if i < len(self.children) - 1 else '└─' | ||
yield from child.iter_tree( | ||
lead=f'{lead}{level_lead} ', | ||
prefix=prefix, | ||
max_depth=max_depth, | ||
full_name=full_name, | ||
with_shape=with_shape, | ||
with_value=with_value) | ||
|
||
|
||
def main(): | ||
args = parse_args() | ||
if args.path.suffix in ['.json', '.py', '.yml']: | ||
from mmengine.runner import get_state_dict | ||
|
||
from mmcls.apis import init_model | ||
model = init_model(args.path, device='cpu') | ||
state_dict = get_state_dict(model) | ||
else: | ||
ckpt = torch.load(args.path, map_location='cpu') | ||
state_dict = ckpt_to_state_dict(ckpt, args.state_key) | ||
|
||
root = StateDictTree() | ||
for k, v in state_dict.items(): | ||
root.add_parameter(k, v) | ||
|
||
para_index = 0 | ||
mark_width = math.floor(math.log(len(state_dict), 10) + 1) | ||
if args.node is not None: | ||
for key in args.node.split('.'): | ||
root = root[key] | ||
|
||
for line in root.iter_tree( | ||
max_depth=args.depth, | ||
full_name=args.full_name, | ||
with_shape=args.shape, | ||
): | ||
if not args.number: | ||
mark = '' | ||
# A hack method to determine whether a line is parameter. | ||
elif '[green]' in line: | ||
mark = f'[red]({str(para_index).ljust(mark_width)})[/]' | ||
para_index += 1 | ||
else: | ||
mark = ' ' * (mark_width + 2) | ||
console.print(mark + line, highlight=False) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.