[IDEA] Figuring out how to use this library with TensorFlow multi-class and binary classification #20
Hi Jordan! As of v0.5.7, the logits/argmax functionality is not flexible and can definitely be improved. Moreover, your example highlights a limitation that I overlooked completely. I also use the library to tile images, feed the tiles to a semantic segmentation network, and merge them back into a full result. In my case, though, the images have no channel dimension because there is always just one value per pixel, so I never used Merger's logits/argmax functionality and Tiler's channel_dimension at the same time... Merger's [...] If you specify [...] I will try to find the time to implement this soon, and sorry for not supporting your use case yet!
Hi,

Thanks for your response! While this might not be ideal, I was able to work around the channel_dimension constraint by using a second Tiler object whose tile shape has the N channels that the network outputs. The following is an example of the implementation. If you are okay with it, I can create a pull request with a similar example so that other users I point to this library can use it.

Binary segmentation problem, where the model output is N x 256 x 256 x 1:

```python
import numpy as np
from tiler import Tiler, Merger

# `image` (H x W x 4, channels last) and `model` (a Keras model) are defined elsewhere
mode = 'constant'
batch_size = 512

# Tiler for the 4-channel input image
tiler_image = Tiler(
    data_shape=image.shape,
    tile_shape=(256, 256, 4),
    channel_dimension=2,
    overlap=0.50,
    mode=mode,
)
# Second Tiler describing the 1-channel model output; only the Merger uses it
tiler_mask = Tiler(
    data_shape=image.shape,
    tile_shape=(256, 256, 1),
    channel_dimension=2,
    overlap=0.50,
    mode=mode,
)

# Compute the extra padding, update both tilers to the padded shape, pad the image
new_shape, padding = tiler_image.calculate_padding()
tiler_image.recalculate(data_shape=new_shape)
tiler_mask.recalculate(data_shape=new_shape)
padded_image = np.pad(image, padding, mode=mode, constant_values=1200)

merger = Merger(tiler=tiler_mask, window="overlap-tile")
for batch_id, batch in tiler_image(padded_image, batch_size=batch_size):
    batch = model.predict(batch)
    merger.add_batch(batch_id, batch_size, batch)

# Merge as float: if `image` has an integer dtype, dtype=image.dtype would
# truncate the sigmoid probabilities before thresholding at 0.5
prediction = merger.merge(extra_padding=padding, dtype=np.float32)
prediction = np.squeeze(np.where(prediction > 0.5, 1, 0).astype(np.int16))
print(prediction.shape, prediction.min(), prediction.max())
```

The only challenge I am still trying to work around is the presence of artifacts at the image borders for non-uniform images (e.g. an image of size 90538 x 9148 x 4, with a tile size of 256 x 256 and a batch size of 512). Is this something you have worked around with this library? I can open a new issue on this topic as well. In the example below, the vertical lines at the left border of the image are not expected.

[Image: prediction showing unexpected vertical lines along the left border]
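One way to reason about border artifacts is to count how many tiles cover each pixel along an axis. The sketch below assumes a simple layout (step = tile size minus overlap, with tiles added until the axis is covered) and is not tiler's internal algorithm; `tile_starts` and `coverage` are hypothetical helpers. With 50% overlap, pixels within half a tile of the start are seen by a single tile, near that tile's edge, whereas padding each side by half a tile first leaves every original pixel covered twice, which is one reason to pad before tiling.

```python
import math

import numpy as np


def tile_starts(length, tile, overlap):
    """Start offsets of tiles along one axis (hypothetical helper).

    Assumes step = tile - overlap and that tiles are added until the
    axis is fully covered; the last tile may extend past `length`.
    """
    step = tile - overlap
    n_tiles = max(1, math.ceil((length - tile) / step) + 1)
    return [i * step for i in range(n_tiles)]


def coverage(length, tile, overlap):
    """Number of tiles covering each position along an axis of `length`."""
    starts = tile_starts(length, tile, overlap)
    counts = np.zeros(starts[-1] + tile, dtype=int)
    for start in starts:
        counts[start:start + tile] += 1
    return counts[:length]


# Without padding: the first 128 pixels are seen by only one tile
plain = coverage(1000, tile=256, overlap=128)
print(plain[0], plain[500], plain[999])  # 1 2 1

# Pad 128 pixels on each side, tile, then crop back to the original range
pad = 128
padded = coverage(1000 + 2 * pad, tile=256, overlap=128)[pad:pad + 1000]
print(padded.min(), padded.max())  # 2 2
```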
Nice workaround! Hopefully in the near future it will not be needed anymore! Not sure how helpful these suggestions are, but:
I'm curious to hear if you manage to fix this :)
Is your feature request related to a problem? Please describe.
I have been trying to use this library for inference with TensorFlow binary and multi-class segmentation models. I am able to use the Tiler object to perform the predictions, but I have not been able to figure out how to use the Merger for the following cases.
data_shape = 5000 x 3000 x 4
tile_shape = (256, 256, 4)
channel_dimension = 0
The output of the model is either a batch of shape (N x 256 x 256 x 1) or (N x 256 x 256 x 6), where 6 is the number of classes.
ValueError: Passed data shape ([256 256 1]) does not fit expected tile shape ((256, 256, 4)).
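The error above arises because the model output has a different channel count (1 or 6) from the input tiles (4). As an illustration independent of tiler, here is a minimal sketch of merging channels-last output tiles into an (H, W, C_out) buffer: each tile is added at its top-left corner and overlapping pixels are averaged with uniform weights (not tiler's overlap-tile window). `merge_tiles` and `fake_model` are hypothetical names, and the toy sizes are assumptions.

```python
import numpy as np


def merge_tiles(tiles, corners, out_hw, n_channels):
    """Average overlapping channels-last tiles into an (H, W, C_out) array.

    tiles:   iterable of (th, tw, C_out) model outputs
    corners: iterable of (row, col) top-left positions, one per tile
    """
    acc = np.zeros((*out_hw, n_channels), dtype=np.float64)
    weight = np.zeros(out_hw, dtype=np.float64)
    for tile, (r, c) in zip(tiles, corners):
        th, tw = tile.shape[:2]
        acc[r:r + th, c:c + tw] += tile
        weight[r:r + th, c:c + tw] += 1.0
    # Guard against division by zero for pixels that no tile covered
    return acc / np.maximum(weight, 1.0)[..., None]


def fake_model(tile):
    """Hypothetical stand-in model: (h, w, 4) input -> (h, w, 6) class scores."""
    return np.repeat(tile.mean(axis=-1, keepdims=True), 6, axis=-1)


# Toy 4 x 6 x 4 "image" tiled into two 4 x 4 tiles that overlap by 2 columns
image = np.random.default_rng(0).random((4, 6, 4))
corners = [(0, 0), (0, 2)]
tiles = [fake_model(image[r:r + 4, c:c + 4]) for r, c in corners]
merged = merge_tiles(tiles, corners, image.shape[:2], n_channels=6)
print(merged.shape)  # (4, 6, 6)
```

The buffer is sized from the output channel count rather than the input's, which is the part the single-Tiler setup cannot express.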
Describe the solution you'd like
It would be great to have additional examples covering similar use cases with TensorFlow or PyTorch predictions.
Here is an example of what I have been trying:
I am probably missing something, but it would be nice to have this documented. Also, the argmax option seems to be hardcoded for channels-first images, which adds extra computation when using channels-last images, since the output has to be transposed first. Any help would be appreciated.
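If the merged output is kept as raw per-class scores, the class map can be computed directly in a channels-last layout with NumPy by reducing over the last axis, so no transpose is needed. A minimal sketch, with random scores standing in for the merged output (this does not use tiler's own argmax option):

```python
import numpy as np

# Stand-in for merged class scores in channels-last layout: (H, W, n_classes)
scores = np.random.default_rng(1).random((5, 3, 6))

# Multi-class: reduce over the last axis directly, no transpose needed
labels_last = np.argmax(scores, axis=-1)  # shape (5, 3)

# Same result as moving channels first and reducing over axis 0
labels_first = np.argmax(np.moveaxis(scores, -1, 0), axis=0)
assert np.array_equal(labels_last, labels_first)

# Binary: a single sigmoid channel (H, W, 1) is thresholded instead
probs = scores[..., :1]
mask = (probs[..., 0] > 0.5).astype(np.uint8)  # shape (5, 3)
```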