Skip to content

Commit

Permalink
Fix tf.where inconsistent dtype bug (#126)
Browse files Browse the repository at this point in the history
* Make sure `tf.where` having same dtype for x and y

* Upgrade to `0.5.8`
  • Loading branch information
james77777778 authored Nov 13, 2023
1 parent 362d47c commit 9f6c60d
Show file tree
Hide file tree
Showing 7 changed files with 28 additions and 18 deletions.
2 changes: 1 addition & 1 deletion keras_aug/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
from keras_aug.core import SignedNormalFactorSampler
from keras_aug.core import UniformFactorSampler

__version__ = "0.5.7"
__version__ = "0.5.8"
4 changes: 3 additions & 1 deletion keras_aug/datapoints/bounding_box/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ def _relative_area(boxes, bounding_box_format):
heights = boxes[..., 3]
# handle corner case where shear performs a full inversion.
return tf.where(
tf.math.logical_and(widths > 0, heights > 0), widths * heights, 0.0
tf.math.logical_and(widths > 0, heights > 0),
widths * heights,
tf.constant(0.0, dtype=widths.dtype),
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,22 @@ def get_random_transformation_batch(
tops = tf.where(
new_heights < self.crop_height,
tf.cast((self.crop_height - new_heights) / 2, tf.int32),
0,
tf.constant(0, dtype=tf.int32),
)
bottoms = tf.where(
new_heights < self.crop_height,
self.crop_height - new_heights - tops,
0,
tf.constant(0, dtype=tf.int32),
)
lefts = tf.where(
new_widths < self.crop_width,
tf.cast((self.crop_width - new_widths) / 2, tf.int32),
0,
tf.constant(0, dtype=tf.int32),
)
rights = tf.where(
new_widths < self.crop_width,
self.crop_width - new_widths - lefts,
0,
tf.constant(0, dtype=tf.int32),
)
(tops, bottoms, lefts, rights) = augmentation_utils.get_position_params(
tops, bottoms, lefts, rights, self.position, self._random_generator
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def augment_bounding_boxes(
# bounding_box.to_ragged() after self.augment_bounding_boxes()
bounding_boxes["classes"] = tf.where(
intersection_ratios >= self.bbox_removal_threshold,
-1.0,
tf.constant(-1.0, dtype=bounding_boxes["classes"].dtype),
bounding_boxes["classes"],
)
return bounding_boxes
Expand Down
24 changes: 16 additions & 8 deletions keras_aug/layers/preprocessing/geometry/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,36 +111,44 @@ def get_random_transformation_batch(
tops = tf.where(
new_heights > self.height,
tf.cast((new_heights - self.height) / 2, tf.int32),
0,
tf.constant(0, tf.int32),
)
bottoms = tf.where(
new_heights > self.height, new_heights - self.height - tops, 0
new_heights > self.height,
new_heights - self.height - tops,
tf.constant(0, tf.int32),
)
lefts = tf.where(
new_widths > self.width,
tf.cast((new_widths - self.width) / 2, tf.int32),
0,
tf.constant(0, tf.int32),
)
rights = tf.where(
new_widths > self.width, new_widths - self.width - lefts, 0
new_widths > self.width,
new_widths - self.width - lefts,
tf.constant(0, tf.int32),
)
else:
assert self.pad_to_aspect_ratio
tops = tf.where(
new_heights < self.height,
tf.cast((self.height - new_heights) / 2, tf.int32),
0,
tf.constant(0, tf.int32),
)
bottoms = tf.where(
new_heights < self.height, self.height - new_heights - tops, 0
new_heights < self.height,
self.height - new_heights - tops,
tf.constant(0, tf.int32),
)
lefts = tf.where(
new_widths < self.width,
tf.cast((self.width - new_widths) / 2, tf.int32),
0,
tf.constant(0, tf.int32),
)
rights = tf.where(
new_widths < self.width, self.width - new_widths - lefts, 0
new_widths < self.width,
self.width - new_widths - lefts,
tf.constant(0, tf.int32),
)
(tops, bottoms, lefts, rights) = augmentation_utils.get_position_params(
tops, bottoms, lefts, rights, self.position, self._random_generator
Expand Down
4 changes: 2 additions & 2 deletions keras_aug/layers/preprocessing/intensity/auto_contrast.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def augment_images(self, images, transformations, **kwargs):
highs = tf.reduce_max(images, axis=(1, 2), keepdims=True)
scales = 255.0 / (highs - lows)
eq_idxs = tf.math.is_inf(scales)
lows = tf.where(eq_idxs, 0.0, lows)
scales = tf.where(eq_idxs, 1.0, scales)
lows = tf.where(eq_idxs, tf.constant(0.0, dtype=lows.dtype), lows)
scales = tf.where(eq_idxs, tf.constant(1.0, dtype=scales.dtype), scales)
images = tf.clip_by_value((images - lows) * scales, 0, 255)
images = image_utils.transform_value_range(
images,
Expand Down
2 changes: 1 addition & 1 deletion keras_aug/layers/preprocessing/intensity/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def equalize_single_channel(self, image, channel_index):
big_number = 1410065408
histogram_without_zeroes = tf.where(
tf.equal(histogram, 0),
big_number,
tf.constant(big_number, dtype=histogram.dtype),
histogram,
)
step = (
Expand Down

0 comments on commit 9f6c60d

Please sign in to comment.