
Plugin inference and loading from onnx #4266

Closed

idantene opened this issue Dec 2, 2024 · 7 comments

@idantene

idantene commented Dec 2, 2024

Hey!

I'm trying to follow the logic of how plugin usage is automatically parsed from ONNX files, as described in Loading an ONNX model with the custom operator.
To experiment with this, I've played around with the plugins provided by TensorRT-LLM (v0.16) and TRT 10.6.

My understanding is that when an unknown operation type is detected, the OnnxParser looks it up in the plugin registry (provided a matching plugin_namespace attribute exists? Is the version not required?). If everything matches, the parser uses that plugin.

Now, I wanted to compare the default implementation of e.g. the Gemm operation with the one provided by TensorRT-LLM. At this point I was just curious to see whether they are different implementations at all, or whether the one in TRT-LLM is offered as a sort of backwards compatibility.
However, since the op_type Gemm exists natively, it seems the OnnxParser never even looks at the plugin registry.

Are there any specific flags missing? Is this a design choice or an oversight? Can one even reuse native ONNX operation names/types as plugin names?

Thanks!

@asfiyab-nvidia
Collaborator

cc @venkywonka

@venkywonka
Collaborator

Thank you for taking the time to create this, @idantene.

when an unknown operation type is detected, the OnnxParser looks it up in the plugin registry (if the matching plugin_namespace attribute exists? Is the version not required?).

You are right that the OnnxParser looks up unknown operations (those that aren't natively supported in TensorRT) in the plugin registry, with the unique key for the lookup being (plugin_name, plugin_namespace, plugin_version).

The OnnxParser extracts (plugin_name, plugin_namespace, plugin_version) from the particular ONNX node's op property, plugin_namespace string attribute, and plugin_version string attribute respectively.

If there exists no plugin_version attribute in the ONNX node, then version 1 is assumed by default by the parser. So in the provided example, the parser tries to get the plugin creator by using the key: (circ_pad_plugin, example, 1) from the plugin registry.

If multiple versions of a plugin (and its corresponding creator) are present, the ONNX node is expected to carry the right plugin_version attribute as part of its node proto. This is a rarer situation, though.
I agree there seems to be a documentation gap here; we shall add information to make this explicit. Thank you for pointing it out.
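
To make the lookup key concrete, here's a minimal sketch in Python (using the onnx package) of a node carrying these attributes, based on the circ_pad_plugin example linked above; the tensor shapes and the pads attribute are placeholders:

import onnx
from onnx import helper, TensorProto

# The op type becomes plugin_name; plugin_namespace and plugin_version
# are plain string attributes on the node. plugin_version may be
# omitted, in which case the parser assumes "1".
node = helper.make_node(
    "circ_pad_plugin",            # -> plugin_name
    inputs=["x"],
    outputs=["y"],
    plugin_namespace="example",   # -> plugin namespace
    plugin_version="1",           # optional; defaults to "1"
    pads=[1, 1, 1, 1],            # plugin-specific attribute (placeholder)
)

graph = helper.make_graph(
    [node],
    "circ_pad_graph",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 3, 32, 32])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1, 3, 34, 34])],
)
onnx.save(helper.make_model(graph), "circ_pad_plugin.onnx")

With this node, the parser would query the registry with the key (circ_pad_plugin, example, 1).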

--

Are there any specific flags missing? Is this a design choice or an oversight? Can one even reuse native ONNX operation names/types as plugin names?

It was an intended design choice to prefer the native implementation if the plugin names match. One would have to use a different name that doesn't match an ONNX Op, as the plugin registry is only checked as a fallback unfortunately.
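
As a side note, if you want to see what's already registered (to check a prospective plugin name for collisions with other plugins), a minimal sketch using the TensorRT Python API:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")   # register the built-in plugins
registry = trt.get_plugin_registry()

# Print every registered creator's lookup key.
for creator in registry.plugin_creator_list:
    print(creator.name, creator.plugin_namespace, creator.plugin_version)

Keep in mind this only reveals collisions with other plugins; a collision with a native ONNX op (like Gemm) won't show up here.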

@idantene
Author

idantene commented Dec 5, 2024

Are there any specific flags missing? Is this a design choice or an oversight? Can one even reuse native ONNX operation names/types as plugin names?

It was an intended design choice to prefer the native implementation if the plugin names match. One would have to use a different name that doesn't match an ONNX Op, as the plugin registry is only checked as a fallback unfortunately.

Thanks for clarifying! As mentioned, I'm trying to use TRT-LLM's Gemm plugin, but because of this constraint, I cannot, as Gemm matches an ONNX op.
Are there plans to change this behavior? For example, providing a flag to OnnxParser so that it prefers plugins over ONNX ops?

EDIT: Alternatively (and/or additionally), I think the documentation would also benefit from this clarification, as would the TRT-LLM folks, so that they don't name their plugins in a way that conflicts with this design decision...

@venkywonka
Collaborator

venkywonka commented Dec 5, 2024

Thank you @idantene, yes, we are working on it! Unfortunately, for now there aren't any flags in the OnnxParser that could arbitrate this.
If the user wants to use a plugin that shares its name with a native ONNX op (as the TRT-LLM Gemm plugin happens to), there are two options:

  1. Edit the parser sources to manually choose the plugin over the native layer for Gemm ops
  2. Edit the plugin name (in the TRT-LLM plugin library code) and the ONNX graph nodes (using tools like ONNX-GraphSurgeon or DL Designer) to a non-official name, e.g. Gemm_Plugin; a sketch of the graph edit follows below. Everything else should be identical to your existing workflow.
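
For option 2, the graph-editing half might look like this minimal sketch using ONNX-GraphSurgeon; the Gemm_Plugin name and the tensorrt_llm namespace are placeholders that must match whatever name and namespace the rebuilt plugin actually registers:

import onnx
import onnx_graphsurgeon as gs

graph = gs.import_onnx(onnx.load("model.onnx"))

# Rename Gemm nodes to a non-native op type so the parser falls back
# to the plugin registry, and attach the lookup attributes.
for node in graph.nodes:
    if node.op == "Gemm":
        node.op = "Gemm_Plugin"                          # placeholder name
        node.attrs["plugin_namespace"] = "tensorrt_llm"  # placeholder namespace
        node.attrs["plugin_version"] = "1"

onnx.save(gs.export_onnx(graph), "model_plugin.onnx")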

Thanks!

@idantene
Author

idantene commented Dec 9, 2024

  1. Edit the parser sources to manually choose the plugin over the native layer for Gemm ops
  2. Edit the plugin name (in the TRT-LLM plugin library code) and the ONNX graph nodes (using tools like ONNX-GraphSurgeon or DL Designer) to a non-official name, e.g. Gemm_Plugin. Everything else should be identical to your existing workflow.

Thank you for these suggestions!

If I understand correctly, (2) would then require recompiling the TRTLLM plugin, so I'm a bit more naturally drawn to (1).
How would I go about changing the parser sources? Any guidance to a starting point would be greatly appreciated 👍🏻

@venkywonka
Collaborator

Digging through, this does seem doable inside the parser's source.
If you look at these lines in onnx-tensorrt (which is the parser submodule of TensorRT), you might just want to tweak the precedence.
Maybe try adding:

if (opImporters.count(nodeType))
{
    // inject your override here:
    // start
    if (isNodeInPluginRegistry(ctx, node))
    {
        LOG_INFO("Found registered plugin: " << nodeType << ". Importing Native Op as a plugin.");
        importFunc = &opImporters.at("FallbackPluginImporter");
    }
    else
    // end
        importFunc = &opImporters.at(nodeType);
}

We're internally tracking this, and although it seems simple, we shall improve it based on evaluating broader use cases and implications.

Thank you for exposing this gap, and helping us improve! 😄

Do let me know if I can go ahead and close this.

@idantene
Author

That's great, above and beyond what I expected, thanks @venkywonka! Much appreciated.

I assume once these changes are in place, I could simply follow the build process as described in the README -- so feel free to close this now. Thanks again!
