Add default dim_order asserts
Differential Revision: D61311560

Pull Request resolved: #4725
digantdesai authored Aug 16, 2024
Parent: 7b795d7 · Commit: b75e7d7
Showing 2 changed files with 24 additions and 0 deletions.
backends/xnnpack/runtime/XNNExecutor.cpp (5 additions, 0 deletions)
@@ -86,6 +86,11 @@ __ET_NODISCARD Error XNNExecutor::prepare_args(EValue** args) {
   // Reshape runtime inputs
   if (i < input_ids_.size()) {
     size_t num_dims = tensor->dim();
+    ET_CHECK_OR_RETURN_ERROR(
+        is_contiguous_dim_order(tensor->dim_order().data(), tensor->dim()),
+        Internal,
+        "Expecting default dim_order but got a non-default dim_order tensor for external input %u",
+        i);
     size_t dims[XNN_MAX_TENSOR_DIMS];
     ET_CHECK_OR_RETURN_ERROR(
         num_dims <= XNN_MAX_TENSOR_DIMS,
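
For reference, the default dim order is just the identity permutation over a tensor's dimensions. A minimal Python sketch of the predicate this runtime check relies on (the name mirrors the C++ is_contiguous_dim_order helper; this is an illustration of the idea, not ExecuTorch's implementation):

    def is_contiguous_dim_order(dim_order: tuple[int, ...]) -> bool:
        # The default (contiguous) dim order is the identity permutation
        # (0, 1, ..., n-1): dimensions stored outermost-to-innermost in
        # declaration order, e.g. NCHW for a 4-D tensor.
        return dim_order == tuple(range(len(dim_order)))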
backends/xnnpack/xnnpack_preprocess.py (19 additions, 0 deletions)
@@ -78,6 +78,22 @@ def generate_node_to_external_map(
     return node_to_external_map
 
 
+def assert_default_dim_order(edge_graph_module: torch.fx.GraphModule) -> None:
+    for node in edge_graph_module.graph.nodes:
+        if node.op != "placeholder":
+            continue
+
+        # We expect the default dim order for all tensor-like inputs, i.e. inputs, buffers, and params.
+        t = node.meta.get("val", None)
+        if t is not None and getattr(t, "dim_order", None) is not None:
+            default_dim_order = tuple(range(t.dim()))
+            if t.dim_order() != default_dim_order:
+                raise RuntimeError(
+                    f"XNNPACK backend only supports contiguous memory format for inputs. "
+                    f"Expecting dim_order: {default_dim_order}, but got {node.meta['val'].dim_order()} for a placeholder node {node}."
+                )
+
+
 @final
 class XnnpackBackend(BackendDetails):
     @staticmethod
@@ -126,6 +142,9 @@ def preprocess(
 
         node_to_external_map = generate_node_to_external_map(ep, graph_module)
 
+        # Make sure all inputs use the contiguous memory format, i.e. NCHW / the default dim order
+        assert_default_dim_order(graph_module)
+
         # TODO: retrace the graph module to lift new params that may have
         # been added to the graph in passes