Skip to content

Commit

Permalink
Refactor 2024 (#33)
Browse files Browse the repository at this point in the history
* update convertor

* add allocator

* add test

* execute v3

* update gather embedding

* to device
  • Loading branch information
grimoire authored Feb 18, 2024
1 parent b25188d commit c429009
Show file tree
Hide file tree
Showing 41 changed files with 888 additions and 2,508 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ on: [push, pull_request]

jobs:
lint:
runs-on: ubuntu-18.04
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
- name: Set up Python 3.9
uses: actions/setup-python@v2
with:
python-version: 3.7
python-version: 3.9
- name: Install pre-commit hook
run: |
pip install pre-commit
Expand Down
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[settings]
known_third_party = graphviz,numpy,packaging,setuptools,tensorrt,termcolor,torch,torchvision
known_third_party = distutils,graphviz,numpy,packaging,pytest,setuptools,tensorrt,torch,torchvision
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://gitlab.com/pycqa/flake8.git
rev: 3.8.3
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
- repo: https://github.com/asottile/seed-isort-config
Expand All @@ -12,11 +12,11 @@ repos:
hooks:
- id: isort
- repo: https://github.com/pre-commit/mirrors-yapf
rev: v0.30.0
rev: v0.32.0
hooks:
- id: yapf
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.1.0
rev: v4.2.0
hooks:
- id: trailing-whitespace
- id: check-yaml
Expand Down
46 changes: 23 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# torch2trt dynamic

This is a branch of [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt) with dynamic input support

Note that not all layers support dynamic input such as `torch.split()` etc...
This is a branch of [torch2trt](https://github.com/NVIDIA-AI-IOT/torch2trt) with dynamic input support.

## Usage

Expand All @@ -11,58 +9,60 @@ Here are some examples
### Convert

```python
from torch2trt_dynamic import torch2trt_dynamic
from torch2trt_dynamic import module2trt, BuildEngineConfig
import torch
from torch import nn
from torchvision.models.resnet import resnet50
from torchvision.models import resnet18

# create some regular pytorch model...
model = resnet50().cuda().eval()
model = resnet18().cuda().eval()

# create example data
x = torch.ones((1, 3, 224, 224)).cuda()

# convert to TensorRT feeding sample data as input
opt_shape_param = [
[
[1, 3, 128, 128], # min
[1, 3, 256, 256], # opt
[1, 3, 512, 512] # max
]
]
model_trt = torch2trt_dynamic(model, [x], fp16_mode=False, opt_shape_param=opt_shape_param)
config = BuildEngineConfig(
shape_ranges=dict(
x=dict(
min=(1, 3, 224, 224),
opt=(2, 3, 224, 224),
max=(4, 3, 224, 224),
)
))
trt_model = module2trt(
model,
args=[x],
config=config)
```

### Execute

We can execute the returned `TRTModule` just like the original PyTorch model

```python
x = torch.rand(1,3,256,256).cuda()
x = torch.rand(1, 3, 224, 224).cuda()
with torch.no_grad():
y = model(x)
y_trt = model_trt(x)
y_trt = trt_model(x)

# check the output against PyTorch
print(torch.max(torch.abs(y - y_trt)))
torch.testing.assert_close(y, y_trt)
```

### Save and load

We can save the model as a ``state_dict``.

```python
torch.save(model_trt.state_dict(), 'alexnet_trt.pth')
torch.save(trt_model.state_dict(), 'my_engine.pth')
```

We can load the saved model into a ``TRTModule``

```python
from torch2trt_dynamic import TRTModule

model_trt = TRTModule()

model_trt.load_state_dict(torch.load('alexnet_trt.pth'))
trt_model = TRTModule()
trt_model.load_state_dict(torch.load('my_engine.pth'))
```

## Setup
Expand All @@ -72,7 +72,7 @@ To install without compiling plugins, call the following
```bash
git clone https://github.com/grimoire/torch2trt_dynamic.git torch2trt_dynamic
cd torch2trt_dynamic
python setup.py develop
pip install .
```

### Set plugins(optional)
Expand Down
20 changes: 0 additions & 20 deletions benchmarks/JETSON_NANO.md

This file was deleted.

31 changes: 0 additions & 31 deletions benchmarks/JETSON_XAVIER.md

This file was deleted.

171 changes: 0 additions & 171 deletions notebooks/image_classification/conversion.ipynb

This file was deleted.

Loading

0 comments on commit c429009

Please sign in to comment.