Skip to content

Commit

Permalink
Fix __eq__ protocol (#2197)
Browse files Browse the repository at this point in the history
If it is not known how to compare objects, `__eq__` should return
NotImplemented instead of False. That way the right side object of the
comparison might possibly take over.
  • Loading branch information
ndparker authored Oct 25, 2023
1 parent e1d79a3 commit 24adf62
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 4 deletions.
24 changes: 24 additions & 0 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest

from troposphere import (
AWSHelperFn,
AWSObject,
AWSProperty,
Cidr,
Expand All @@ -25,6 +26,19 @@
from troposphere.validators import positive_integer


class TypeComparator:
""" Helper to test the __eq__ protocol """

def __init__(self, valid_types):
self.valid_types = valid_types

def __eq__(self, other):
return isinstance(other, self.valid_types)

def __ne__(self, other):
return not self == other


def double(x):
return positive_integer(x) * 2

Expand Down Expand Up @@ -77,8 +91,13 @@ def test___eq__(self):
"title": "foobar",
"Properties": {"callcorrect": True},
}
assert FakeAWSObject("foobar", callcorrect=True) == TypeComparator(AWSObject)
assert TypeComparator(AWSObject) == FakeAWSObject("foobar", callcorrect=True)

assert GenericHelperFn("foobar") == GenericHelperFn("foobar")
assert GenericHelperFn({"foo": "bar"}) == {"foo": "bar"}
assert GenericHelperFn("foobar") == TypeComparator(AWSHelperFn)
assert TypeComparator(AWSHelperFn) == GenericHelperFn("foobar")

def test___ne__(self):
"""Test __ne__."""
Expand All @@ -89,9 +108,14 @@ def test___ne__(self):
"foobar", callcorrect=False
)
assert FakeAWSObject("foobar", callcorrect=True) != FakeAWSProperty("foobar")
assert FakeAWSObject("foobar", callcorrect=True) != TypeComparator(AWSHelperFn)
assert TypeComparator(AWSHelperFn) != FakeAWSObject("foobar", callcorrect=True)

assert GenericHelperFn("foobar") != GenericHelperFn("bar")
assert GenericHelperFn("foobar") != "foobar"
assert GenericHelperFn("foobar") != FakeAWSProperty("foobar")
assert GenericHelperFn("foobar") != TypeComparator(AWSObject)
assert TypeComparator(AWSObject) != GenericHelperFn("foobar")

def test_badproperty(self):
with self.assertRaises(AttributeError):
Expand Down
19 changes: 19 additions & 0 deletions tests/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ def test_max_mappings(self):
template.add_mapping("mapping", {"n": "v"})


class TypeComparator:
""" Helper to test the __eq__ protocol """

def __init__(self, valid_types):
self.valid_types = valid_types

def __eq__(self, other):
return isinstance(other, self.valid_types)

def __ne__(self, other):
return not self == other


class TestEquality(unittest.TestCase):
def test_eq(self):
metadata = "foo"
Expand All @@ -98,6 +111,9 @@ def test_eq(self):

self.assertEqual(t1, t2)

self.assertEqual(t1, TypeComparator(Template))
self.assertEqual(TypeComparator(Template), t1)

def test_ne(self):
t1 = Template(Description="foo1", Metadata="bar1")
t1.add_resource(Bucket("Baz1"))
Expand All @@ -109,6 +125,9 @@ def test_ne(self):

self.assertNotEqual(t1, t2)

self.assertNotEqual(t1, TypeComparator(Output))
self.assertNotEqual(TypeComparator(Output), t1)

def test_hash(self):
metadata = "foo"
description = "bar"
Expand Down
8 changes: 4 additions & 4 deletions troposphere/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def __eq__(self, other: object) -> bool:
return self.title == other.title and self.to_json() == other.to_json()
if isinstance(other, dict):
return {"title": self.title, **self.to_dict()} == other
return False
return NotImplemented

def __ne__(self, other: object) -> bool:
return not self == other
Expand Down Expand Up @@ -514,7 +514,7 @@ def __eq__(self, other: object) -> bool:
return self.to_json() == other.to_json()
if isinstance(other, (dict, list)):
return self.to_dict() == other
return False
return NotImplemented

def __hash__(self) -> int:
return hash(self.to_json(indent=0))
Expand Down Expand Up @@ -986,10 +986,10 @@ def __eq__(self, other: object) -> bool:
if isinstance(other, Template):
return self.to_json() == other.to_json()
else:
return False
return NotImplemented

def __ne__(self, other: object) -> bool:
return not self.__eq__(other)
return not self == other

def __hash__(self) -> int:
return hash(self.to_json())
Expand Down

0 comments on commit 24adf62

Please sign in to comment.