diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index b87eaa05..0df83531 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -926,9 +926,11 @@ def _augment_expression_dataclass( from pytools.codegen import remove_common_indentation augment_code = remove_common_indentation( - f""" + """ from warnings import warn from dataclasses import is_dataclass + """ + + (f""" def {cls.__name__}_eq(self, other): @@ -951,6 +953,8 @@ def {cls.__name__}_eq(self, other): return self.__class__ == other.__class__ and {comparison} cls.__eq__ = {cls.__name__}_eq + """ if generate_hash else "") + + (f""" def {cls.__name__}_hash(self): @@ -973,8 +977,9 @@ def {cls.__name__}_hash(self): object.__setattr__(self, "_hash_value", hash_val) return hash_val - if {generate_hash}: - cls.__hash__ = {cls.__name__}_hash + cls.__hash__ = {cls.__name__}_hash + """ if generate_hash else "") + + f""" def {cls.__name__}_init_arg_names(self):