Skip to content

Commit

Permalink
AbstractVars are now overriden by annotations in subclasses.
Browse files Browse the repository at this point in the history
When subclassing, it used to be the case that `cls.__abstractvars__`
basically never got smaller (only if an element was overriden by a
class-level attribute or method):
```python
class Foo(eqx.Module):
  x: AbstractVar[bool]

class Bar(eqx.Module):
  x: bool

Bar.__abstractvars__ == frozenset({"x"})
```
This was intended -- the idea is that the all abstractvars would get
checked during initialisation, i.e. validity wrt this condition being
a property of the instance, rather than being a property of just the
class object.

With this change, the above example will remove `x` from
`__abstractvars__`.

This is because it's useful and typical to reason about whether a class
is abstract or not -- it's much more annoying to have to reason about
whether each individual instance is abstract. Indeed the recent changes
to `eqx.Module`, in strict mode, are a use case in which want to be
able to reason about things in this way.

Thus, any element of either `subcls.__dict__` or
`subcls.__annotations__` can be used to concretise any abstract
variable, rather than just doing `hasattr(self, var)` during
initialisation.
  • Loading branch information
patrick-kidger committed Sep 28, 2023
1 parent fc13135 commit d15b320
Showing 1 changed file with 27 additions and 36 deletions.
63 changes: 27 additions & 36 deletions equinox/_better_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def __new__(mcs, name, bases, namespace, /, **kwargs):
cls = super().__new__(mcs, name, bases, namespace, **kwargs)
abstract_vars = dict()
abstract_class_vars = dict()
cls_annotations = cls.__dict__.get("__annotations__", {})
for attr, group in [
("__abstractvars__", abstract_vars),
("__abstractclassvars__", abstract_class_vars),
Expand All @@ -233,38 +234,34 @@ def __new__(mcs, name, bases, namespace, /, **kwargs):
"Base classes have mismatched type annotations for "
f"{name}"
)
if "__annotations__" in cls.__dict__:
try:
new_annotation = cls.__annotations__[name]
except KeyError:
pass
else:
if not _is_concretisation(new_annotation, annotation):
raise TypeError(
"Base class and derived class have mismatched type "
f"annotations for {name}"
)
try:
new_annotation = cls_annotations[name]
except KeyError:
pass
else:
if not _is_concretisation(new_annotation, annotation):
raise TypeError(
"Base class and derived class have mismatched type "
f"annotations for {name}"
)
# Not just `if name not in namespace`, as `cls.__dict__` may be
# slightly bigger from `__init_subclass__`.
if name not in cls.__dict__:
if name not in cls.__dict__ and name not in cls_annotations:
group[name] = annotation
if "__annotations__" in cls.__dict__:
for name, annotation in cls.__annotations__.items():
is_abstract, is_class = _process_annotation(annotation)
if is_abstract:
if name in namespace:
if is_class:
raise TypeError(
f"Abstract class attribute {name} cannot have value"
)
else:
raise TypeError(
f"Abstract attribute {name} cannot have value"
)
for name, annotation in cls_annotations.items():
is_abstract, is_class = _process_annotation(annotation)
if is_abstract:
if name in namespace:
if is_class:
abstract_class_vars[name] = annotation
raise TypeError(
f"Abstract class attribute {name} cannot have value"
)
else:
abstract_vars[name] = annotation
raise TypeError(f"Abstract attribute {name} cannot have value")
if is_class:
abstract_class_vars[name] = annotation
else:
abstract_vars[name] = annotation
cls.__abstractvars__ = abstract_vars # pyright: ignore
cls.__abstractclassvars__ = abstract_class_vars # pyright: ignore
return cls
Expand All @@ -277,17 +274,11 @@ def __call__(cls, *args, **kwargs):
f"attributes {abstract_class_vars}"
)
self = super().__call__(*args, **kwargs)
abstract_vars = set()
for name in cls.__abstractvars__: # pyright: ignore
# Deliberately not doing `if name in self.__dict__` to allow for use of
# properties (which are actually class attributes) to override abstract
# instance variables.
if getattr(self, name, _sentinel) is _sentinel:
abstract_vars.add(name)
if len(abstract_vars) > 0:
if len(cls.__abstractvars__) > 0: # pyright: ignore
abstract_class_vars = set(cls.__abstractvars__) # pyright: ignore
raise TypeError(
f"Can't instantiate abstract class {cls.__name__} with abstract "
f"attributes {abstract_vars}"
f"attributes {abstract_class_vars}"
)
return self

Expand Down

0 comments on commit d15b320

Please sign in to comment.