Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add decomposition for Grid_Sample and Floor op #54

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

meenakshiramanathan1
Copy link
Contributor

@meenakshiramanathan1 meenakshiramanathan1 commented Dec 24, 2024

Decomposition for torch.grid_sample and torch.floor op has been added.
Grid Sample op decomposition has been adapted from bilinear grid sample from here.
Currently this decomposition supports bilinear interpolation mode with 4d input shapes only, nearest interpolation mode and 5d inputs support will be added later.

The change in DecomposeMultiDimSqueeze callback has been added to handle this failure when squeeze_() was used.
TypeError: 'NoneType' object is not iterable

So in that cases, all dimensions with value 1 will be squeezed.

@meenakshiramanathan1 meenakshiramanathan1 marked this pull request as draft December 24, 2024 07:09
@meenakshiramanathan1 meenakshiramanathan1 force-pushed the mramanathan/grid_sample_tvm branch 6 times, most recently from ffbf66b to ea03475 Compare December 27, 2024 06:44
@meenakshiramanathan1 meenakshiramanathan1 marked this pull request as ready for review December 27, 2024 13:22
python/tvm/relay/op/contrib/forge/forge_passes.py Outdated Show resolved Hide resolved

# Compute integer pixel indices for bilinear interpolation
if mode == 'bilinear':
x0 = tvm.relay.floor(x).astype("int32")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to do a type cast?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, if it's not typecasted we are ending up with dtype mismatch later.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Let's print out a warning as well.

python/tvm/relay/op/contrib/forge/forge_passes.py Outdated Show resolved Hide resolved
python/tvm/relay/op/contrib/forge/forge_passes.py Outdated Show resolved Hide resolved
@meenakshiramanathan1 meenakshiramanathan1 force-pushed the mramanathan/grid_sample_tvm branch 3 times, most recently from 863f508 to 7cec18c Compare December 31, 2024 15:47
@meenakshiramanathan1 meenakshiramanathan1 force-pushed the mramanathan/grid_sample_tvm branch from 7cec18c to 8aed232 Compare January 2, 2025 05:38

# Compute integer pixel indices for bilinear interpolation
if mode == 'bilinear':
x0 = tvm.relay.floor(x).astype("int32")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. Let's print out a warning as well.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants