## added a jax example #11

Open — **tanayvarshney** wants to merge 3 commits into `main` from `tvarshney_jax`.
<!--
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->
# Deploying a JAX Model

This README showcases how to deploy a simple ResNet model on Triton Inference Server. While Triton doesn't yet have a dedicated JAX backend, JAX/Flax models can be deployed using the [Python Backend](https://github.com/triton-inference-server/python_backend). If you are new to Triton, it is recommended to watch this [getting started video](https://www.youtube.com/watch?v=NQDtfSi5QF4) and review [Part 1](https://github.com/triton-inference-server/tutorials/tree/main/Conceptual_Guide/Part_1-model_deployment) of the conceptual guide before proceeding. For the purposes of demonstration, we are using a pre-trained model provided by [flaxmodels](https://github.com/matthias-wright/flaxmodels).
Before diving into the specifics of execution, it helps to understand the underlying structure. The recommended path for serving a JAX or Flax model is a ["Python Model"](https://github.com/triton-inference-server/python_backend#python-backend). Python models in Triton are classes with three Triton-specific functions: `initialize`, `execute`, and `finalize`. Users can customize this class to serve any Python function they write, or any model they want, as long as it can be loaded in a Python runtime. The `initialize` function runs when the Python model is loaded into memory, and the `finalize` function runs when the model is unloaded from memory; both of these functions are optional to define. For the purposes of this example, we will use the `initialize` and `execute` functions to load and run (respectively) a `resnet18` model.
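As a minimal sketch of that structure (the class name `TritonPythonModel` is required by the backend; the placeholder bodies below are illustrative, not real Triton responses):

```python
class TritonPythonModel:
    """Skeleton of the three Triton-specific methods."""

    def initialize(self, args):
        # Runs once when the model is loaded into memory.
        # `args` is a dict of model metadata supplied by Triton.
        self.model = None  # load weights / build the model object here

    def execute(self, requests):
        # Called for each (possibly batched) list of inference requests.
        # Must return exactly one response per request, in order.
        return [None for _ in requests]  # placeholder responses

    def finalize(self):
        # Optional: runs when the model is unloaded from memory.
        pass
```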
We use the `initialize` method to load in the model weights and create our Flax model object. Here, we load a pretrained model from the flaxmodels library. You could also load weights from another pretrained model library, or from a file located in the model directory. Note that with JAX, our model parameters are automatically loaded onto any available accelerator, like a GPU.
In the `execute` function, we perform the actual model inference. Note that the input to the `execute` method is an arbitrary-length _list_ of request objects that may have been dynamically batched together. In this example, we loop through and execute each request individually and append each response to the `responses` list. If your model supports batched inputs, you may find it more efficient to execute all of the requests in one function call.
```python
import triton_python_backend_utils as pb_utils
import jax
import flaxmodels as fm

import numpy as np


class TritonPythonModel:

    def initialize(self, args):
        self.key = jax.random.PRNGKey(0)
        self.resnet18 = fm.ResNet18(output='logits', pretrained='imagenet')
        self.params = None

    def execute(self, requests):
        responses = []
        for request in requests:
            inp = pb_utils.get_input_tensor_by_name(request, "image")
            input_image = inp.as_numpy()

            # Flax's `init` needs a sample input to build the parameter
            # shapes, so it runs lazily on the first request; the result is
            # cached rather than re-initialized on every request.
            if self.params is None:
                self.params = self.resnet18.init(self.key, input_image)
            out = self.resnet18.apply(self.params, input_image, train=False)

            inference_response = pb_utils.InferenceResponse(output_tensors=[
                pb_utils.Tensor(
                    "fc_out",
                    np.array(out),
                )
            ])
            responses.append(inference_response)
        return responses
```
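When requests are batched together, the per-request loop can be replaced by a single `apply` call on the concatenated inputs, then splitting the output back into one chunk per request. A sketch of the split step (the helper name is ours, not a Triton API):

```python
import numpy as np

def split_batched_output(batched_out, batch_sizes):
    """Split an output computed on concatenated inputs back into
    one chunk per original request, preserving request order."""
    offsets = np.cumsum(batch_sizes)[:-1]  # boundaries between requests
    return np.split(batched_out, offsets, axis=0)
```

Each chunk then becomes the output tensor of its own response, so the one-response-per-request contract still holds.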
## Step 1: Set Up Triton Inference Server
To use Triton, we need to build a model repository. The structure of the repository is as follows:

```text
model_repository/
└── resnet50
    ├── 1
    │   └── model.py
    └── config.pbtxt
```
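The layout above can be created with, for example:

```shell
mkdir -p model_repository/resnet50/1
touch model_repository/resnet50/1/model.py
touch model_repository/resnet50/config.pbtxt
```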
For this example, we have pre-built the model repository. Next, we install the required dependencies and launch the Triton Inference Server.

```bash
# Replace yy.mm in the image name with the release year and month
# of the Triton version needed, e.g. 22.12
docker run --gpus=all -it --shm-size=256m --rm -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd):/workspace/ -v$(pwd)/model_repository:/models nvcr.io/nvidia/tritonserver:<yy.mm>-py3 bash

# Note: See the JAX install guide for more details on installing JAX: https://github.com/google/jax#installation
pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git

tritonserver --model-repository=/models
```
## Step 2: Using a Triton Client to Query the Server

Let's break down the client application. First, we set up a connection with the Triton Inference Server.

```python
client = httpclient.InferenceServerClient(url="localhost:8000")
```
Then we set the input and output arrays.

```python
# Set Inputs
input_tensors = [
    httpclient.InferInput("image", image.shape, datatype="FP32")
]
input_tensors[0].set_data_from_numpy(image)

# Set outputs
outputs = [
    httpclient.InferRequestedOutput("fc_out")
]
```
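The client in this example sends the raw decoded image as-is; depending on the model you serve, inputs may need resizing and scaling first. A hypothetical helper (the 224×224 size and [0, 1] scaling are assumptions for illustration, not requirements of this model):

```python
import numpy as np
from PIL import Image

def preprocess(img, size=(224, 224)):
    """Resize an image and scale it to a float32 NHWC batch in [0, 1]."""
    img = img.convert("RGB").resize(size)
    arr = np.asarray(img).astype(np.float32) / 255.0
    return np.expand_dims(arr, axis=0)  # shape (1, H, W, 3)
```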
Lastly, we send an inference request to the Triton Inference Server.

```python
# Query
query_response = client.infer(model_name="resnet50",
                              inputs=input_tensors,
                              outputs=outputs)

# Output
out = query_response.as_numpy("fc_out")
```
The complete client script:

```python
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import numpy as np
from tritonclient.utils import *
from PIL import Image
import tritonclient.http as httpclient
import requests


def main():
    client = httpclient.InferenceServerClient(url="localhost:8000")

    # Inputs
    url = "http://images.cocodataset.org/val2017/000000161642.jpg"
    image = np.asarray(Image.open(requests.get(url, stream=True).raw)).astype(np.float32)
    image = np.expand_dims(image, axis=0)

    # Set Inputs
    input_tensors = [
        httpclient.InferInput("image", image.shape, datatype="FP32")
    ]
    input_tensors[0].set_data_from_numpy(image)

    # Set outputs
    outputs = [
        httpclient.InferRequestedOutput("fc_out")
    ]

    # Query
    query_response = client.infer(model_name="resnet50",
                                  inputs=input_tensors,
                                  outputs=outputs)

    # Output
    out = query_response.as_numpy("fc_out")
    print(out.shape)


if __name__ == "__main__":
    main()
```
The model file (`model.py`):

```python
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import triton_python_backend_utils as pb_utils
import jax
import flaxmodels as fm

import numpy as np


class TritonPythonModel:

    def initialize(self, args):
        self.key = jax.random.PRNGKey(0)
        self.resnet18 = fm.ResNet18(output='logits', pretrained='imagenet')
        self.params = None

    def execute(self, requests):
        responses = []
        for request in requests:
            inp = pb_utils.get_input_tensor_by_name(request, "image")
            input_image = inp.as_numpy()

            # Flax's `init` needs a sample input to build the parameter
            # shapes, so it runs lazily on the first request; the result is
            # cached rather than re-initialized on every request.
            if self.params is None:
                self.params = self.resnet18.init(self.key, input_image)
            out = self.resnet18.apply(self.params, input_image, train=False)

            inference_response = pb_utils.InferenceResponse(output_tensors=[
                pb_utils.Tensor(
                    "fc_out",
                    np.array(out),
                )
            ])
            responses.append(inference_response)
        return responses
```
The model configuration (`config.pbtxt`):

```text
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

name: "resnet50"
backend: "python"
max_batch_size: 8

input [
  {
    name: "image"
    data_type: TYPE_FP32
    dims: [-1, -1, -1]
  }
]
output [
  {
    name: "fc_out"
    data_type: TYPE_FP32
    dims: [-1, -1]
  }
]

instance_group [
  {
    kind: KIND_GPU
  }
]
```

Review comments were left on the output `dims: [-1, -1]`:

> This doesn't seem right. Shouldn't it be …

> It's [-1, 1000] if I remember it correctly. I can re-run and check.
There was also a review discussion on the `init` call inside `execute`:

> I'm surprised by this `init` call here. I'm not overly familiar with how Flax or the flaxmodels library structure things, but why do we call `init` here instead of in the `initialize` method? What does it actually do? Is it about needing to know the shape of the input image?

> Ah, good catch, I missed this.