From 24adf62d35da064f4f0d137139ed86ff74e6cded Mon Sep 17 00:00:00 2001 From: "n.d. parker" Date: Wed, 25 Oct 2023 23:42:48 +0200 Subject: [PATCH] Fix __eq__ protocol (#2197) If it is not known how to compare objects, `__eq__` should return NotImplemented instead of False. That way the right side object of the comparison might possibly take over. --- tests/test_basic.py | 24 ++++++++++++++++++++++++ tests/test_template.py | 19 +++++++++++++++++++ troposphere/__init__.py | 8 ++++---- 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/tests/test_basic.py b/tests/test_basic.py index 09ad41e1d..ef2ad19ce 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -2,6 +2,7 @@ import unittest from troposphere import ( + AWSHelperFn, AWSObject, AWSProperty, Cidr, @@ -25,6 +26,19 @@ from troposphere.validators import positive_integer +class TypeComparator: + """ Helper to test the __eq__ protocol """ + + def __init__(self, valid_types): + self.valid_types = valid_types + + def __eq__(self, other): + return isinstance(other, self.valid_types) + + def __ne__(self, other): + return not self == other + + def double(x): return positive_integer(x) * 2 @@ -77,8 +91,13 @@ def test___eq__(self): "title": "foobar", "Properties": {"callcorrect": True}, } + assert FakeAWSObject("foobar", callcorrect=True) == TypeComparator(AWSObject) + assert TypeComparator(AWSObject) == FakeAWSObject("foobar", callcorrect=True) + assert GenericHelperFn("foobar") == GenericHelperFn("foobar") assert GenericHelperFn({"foo": "bar"}) == {"foo": "bar"} + assert GenericHelperFn("foobar") == TypeComparator(AWSHelperFn) + assert TypeComparator(AWSHelperFn) == GenericHelperFn("foobar") def test___ne__(self): """Test __ne__.""" @@ -89,9 +108,14 @@ def test___ne__(self): "foobar", callcorrect=False ) assert FakeAWSObject("foobar", callcorrect=True) != FakeAWSProperty("foobar") + assert FakeAWSObject("foobar", callcorrect=True) != TypeComparator(AWSHelperFn) + assert TypeComparator(AWSHelperFn) != FakeAWSObject("foobar", callcorrect=True) + assert GenericHelperFn("foobar") != GenericHelperFn("bar") assert GenericHelperFn("foobar") != "foobar" assert GenericHelperFn("foobar") != FakeAWSProperty("foobar") + assert GenericHelperFn("foobar") != TypeComparator(AWSObject) + assert TypeComparator(AWSObject) != GenericHelperFn("foobar") def test_badproperty(self): with self.assertRaises(AttributeError): diff --git a/tests/test_template.py b/tests/test_template.py index 34a533c1a..38bc8d71f 100644 --- a/tests/test_template.py +++ b/tests/test_template.py @@ -81,6 +81,19 @@ def test_max_mappings(self): template.add_mapping("mapping", {"n": "v"}) +class TypeComparator: + """ Helper to test the __eq__ protocol """ + + def __init__(self, valid_types): + self.valid_types = valid_types + + def __eq__(self, other): + return isinstance(other, self.valid_types) + + def __ne__(self, other): + return not self == other + + class TestEquality(unittest.TestCase): def test_eq(self): metadata = "foo" @@ -98,6 +111,9 @@ def test_eq(self): self.assertEqual(t1, t2) + self.assertEqual(t1, TypeComparator(Template)) + self.assertEqual(TypeComparator(Template), t1) + def test_ne(self): t1 = Template(Description="foo1", Metadata="bar1") t1.add_resource(Bucket("Baz1")) @@ -109,6 +125,9 @@ def test_ne(self): self.assertNotEqual(t1, t2) + self.assertNotEqual(t1, TypeComparator(Output)) + self.assertNotEqual(TypeComparator(Output), t1) + def test_hash(self): metadata = "foo" description = "bar" diff --git a/troposphere/__init__.py b/troposphere/__init__.py index a713f774f..f9090bd69 100644 --- a/troposphere/__init__.py +++ b/troposphere/__init__.py @@ -420,7 +420,7 @@ def __eq__(self, other: object) -> bool: return self.title == other.title and self.to_json() == other.to_json() if isinstance(other, dict): return {"title": self.title, **self.to_dict()} == other - return False + return NotImplemented def __ne__(self, other: object) -> bool: return not self == other @@ -514,7 +514,7 @@ def __eq__(self, other: object) -> bool: return self.to_json() == other.to_json() if isinstance(other, (dict, list)): return self.to_dict() == other - return False + return NotImplemented def __hash__(self) -> int: return hash(self.to_json(indent=0)) @@ -986,10 +986,10 @@ def __eq__(self, other: object) -> bool: if isinstance(other, Template): return self.to_json() == other.to_json() else: - return False + return NotImplemented def __ne__(self, other: object) -> bool: - return not self.__eq__(other) + return not self == other def __hash__(self) -> int: return hash(self.to_json())