diff --git a/pymbolic/primitives.py b/pymbolic/primitives.py index 3d7d11bd..89448baa 100644 --- a/pymbolic/primitives.py +++ b/pymbolic/primitives.py @@ -33,7 +33,9 @@ ClassVar, Mapping, NoReturn, + Protocol, TypeVar, + cast, ) from warnings import warn @@ -819,9 +821,13 @@ def __iter__(self): ) +class _HasMapperMethod(Protocol): + mapper_method: ClassVar[str] + + def _augment_expression_dataclass( cls: type[DataclassInstance], - hash: bool, + generate_hash: bool, ) -> None: attr_tuple = ", ".join(f"self.{fld.name}" for fld in fields(cls)) if attr_tuple: @@ -854,8 +860,9 @@ def {cls.__name__}_eq(self, other): return True if self.__class__ is not other.__class__: return False - if hash(self) != hash(other): - return False + if {generate_hash}: + if hash(self) != hash(other): + return False if self.__class__ is not cls and self.init_arg_names != {fld_name_tuple}: warn(f"{{self.__class__}} is derived from {cls}, which is now " f"a dataclass. {{self.__class__}} should be converted to being " @@ -890,7 +897,7 @@ def {cls.__name__}_hash(self): object.__setattr__(self, "_hash_value", hash_val) return hash_val - if {hash}: + if {generate_hash}: cls.__hash__ = {cls.__name__}_hash @@ -956,23 +963,23 @@ def {cls.__name__}_setstate(self, state): # {{{ assign mapper_method - assert issubclass(cls, Expression) + mm_cls = cast(type[_HasMapperMethod], cls) - snake_clsname = _CAMEL_TO_SNAKE_RE.sub("_", cls.__name__).lower() + snake_clsname = _CAMEL_TO_SNAKE_RE.sub("_", mm_cls.__name__).lower() default_mapper_method_name = f"map_{snake_clsname}" # This covers two cases: the class does not have the attribute in the first # place, or it inherits a value but does not set it itself. - sets_mapper_method = "mapper_method" in cls.__dict__ + sets_mapper_method = "mapper_method" in mm_cls.__dict__ if sets_mapper_method: - if default_mapper_method_name == cls.mapper_method: - warn(f"Explicit mapper_method on {cls} not needed, default matches " + if default_mapper_method_name == mm_cls.mapper_method: + warn(f"Explicit mapper_method on {mm_cls} not needed, default matches " "explicit assignment. Just delete the explicit assignment.", stacklevel=3) if not sets_mapper_method: - cls.mapper_method = intern(default_mapper_method_name) + mm_cls.mapper_method = intern(default_mapper_method_name) # }}} @@ -983,18 +990,21 @@ def {cls.__name__}_setstate(self, state): @dataclass_transform(frozen_default=True) def expr_dataclass( init: bool = True, - hash: bool = True + hash: bool = True, ) -> Callable[[type[_T]], type[_T]]: - """A class decorator that makes the class a :func:`~dataclasses.dataclass` + r"""A class decorator that makes the class a :func:`~dataclasses.dataclass` while also adding functionality needed for :class:`Expression` nodes. Specifically, it adds cached hashing, equality comparisons with ``self is other`` shortcuts as well as some methods/attributes - for backward compatibility (e.g. ``__getinitargs__``, ``init_arg_names``) + for backward compatibility (e.g. ``__getinitargs__``, ``init_arg_names``). It also adds a :attr:`Expression.mapper_method` based on the class name if not already present. If :attr:`~Expression.mapper_method` is inherited, it will be viewed as unset and replaced. + Note that the class to which this decorator is applied need not be + a subclass of :class:`~pymbolic.Expression`. + .. versionadded:: 2024.1 """ def map_cls(cls: type[_T]) -> type[_T]: @@ -1008,7 +1018,7 @@ def map_cls(cls: type[_T]) -> type[_T]: # It should just understand that? _augment_expression_dataclass( dc_cls, # type: ignore[arg-type] - hash=hash + generate_hash=hash, ) return dc_cls