diff --git a/.github/workflows/fenicsx-tests.yml b/.github/workflows/fenicsx-tests.yml index f9c7d811f..bf5c987f6 100644 --- a/.github/workflows/fenicsx-tests.yml +++ b/.github/workflows/fenicsx-tests.yml @@ -41,7 +41,7 @@ jobs: with: path: ./ffcx repository: FEniCS/ffcx - ref: main + ref: mscroggs/ufl-cell - name: Install FFCx run: | cd ffcx @@ -78,7 +78,7 @@ jobs: - name: Install Basix and FFCx run: | python3 -m pip install git+https://github.com/FEniCS/basix.git - python3 -m pip install git+https://github.com/FEniCS/ffcx.git + python3 -m pip install git+https://github.com/FEniCS/ffcx.git@mscroggs/ufl-cell - name: Clone DOLFINx uses: actions/checkout@v3 diff --git a/setup.cfg b/setup.cfg index 256c55d64..1a0e9d878 100644 --- a/setup.cfg +++ b/setup.cfg @@ -43,6 +43,7 @@ setup_requires = wheel install_requires = numpy + typing_extensions; python_version < "3.11" [options.extras_require] docs = sphinx; sphinx_rtd_theme diff --git a/ufl/cell.py b/ufl/cell.py index 4fd0cb8e7..d9f3cf801 100644 --- a/ufl/cell.py +++ b/ufl/cell.py @@ -1,4 +1,4 @@ -"Types for representing a cell." +"""Types for representing a cell.""" # Copyright (C) 2008-2016 Martin Sandve Alnæs # @@ -6,293 +6,438 @@ # # SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import annotations import functools import numbers +import typing +import weakref -import ufl.cell -from ufl.core.ufl_type import attach_operators_from_hash_data +from ufl.core.ufl_type import UFLObject +from abc import abstractmethod + +try: + from typing import Self +except ImportError: + # This alternative is needed pre Python 3.11 + # Delete this after 04 Oct 2026 (Python 3.10 end of life) + from typing_extensions import Self -# Export list for ufl.classes __all_classes__ = ["AbstractCell", "Cell", "TensorProductCell"] -# --- The most abstract cell class, base class for other cell types +class AbstractCell(UFLObject): + """A base class for all cells.""" + @abstractmethod + def topological_dimension(self) -> int: + """Return the dimension of the topology of this cell.""" + + @abstractmethod + def geometric_dimension(self) -> int: + """Return the dimension of the geometry of this cell.""" + + @abstractmethod + def is_simplex(self) -> bool: + """Return True if this is a simplex cell.""" + + @abstractmethod + def has_simplex_facets(self) -> bool: + """Return True if all the facets of this cell are simplex cells.""" + + @abstractmethod + def _lt(self, other: Self) -> bool: + """Define an arbitrarily chosen but fixed sort order for all instances of this type with the same dimensions.""" + + @abstractmethod + def num_sub_entities(self, dim: int) -> int: + """Get the number of sub-entities of the given dimension.""" + + @abstractmethod + def sub_entities(self, dim: int) -> typing.Tuple[AbstractCell, ...]: + """Get the sub-entities of the given dimension.""" + + @abstractmethod + def sub_entity_types(self, dim: int) -> typing.Tuple[AbstractCell, ...]: + """Get the unique sub-entity types of the given dimension.""" + + @abstractmethod + def cellname(self) -> str: + """Return the cellname of the cell.""" + + @abstractmethod + def reconstruct(self, **kwargs: typing.Any) -> Cell: + """Reconstruct this cell, overwriting properties by those in kwargs.""" + + def __lt__(self, other: AbstractCell) -> bool: + """Define an arbitrarily chosen but fixed sort order for all cells.""" + if type(self) == type(other): + s = (self.geometric_dimension(), self.topological_dimension()) + o = (other.geometric_dimension(), other.topological_dimension()) + if s != o: + return s < o + return self._lt(other) + else: + if type(self).__name__ == type(other).__name__: + raise ValueError("Cannot order cell types with the same name") + return type(self).__name__ < type(other).__name__ -class AbstractCell(object): - """Representation of an abstract finite element cell with only the - dimensions known. + def num_vertices(self) -> int: + """Get the number of vertices.""" + return self.num_sub_entities(0) - """ - __slots__ = ("_topological_dimension", - "_geometric_dimension") + def num_edges(self) -> int: + """Get the number of edges.""" + return self.num_sub_entities(1) - def __init__(self, topological_dimension, geometric_dimension): - # Validate dimensions - if not isinstance(geometric_dimension, numbers.Integral): - raise ValueError("Expecting integer geometric_dimension.") - if not isinstance(topological_dimension, numbers.Integral): - raise ValueError("Expecting integer topological_dimension.") - if topological_dimension > geometric_dimension: - raise ValueError("Topological dimension cannot be larger than geometric dimension.") + def num_faces(self) -> int: + """Get the number of faces.""" + return self.num_sub_entities(2) - # Store validated dimensions - self._topological_dimension = topological_dimension - self._geometric_dimension = geometric_dimension - - def topological_dimension(self): - "Return the dimension of the topology of this cell." - return self._topological_dimension - - def geometric_dimension(self): - "Return the dimension of the space this cell is embedded in." - return self._geometric_dimension - - def is_simplex(self): - "Return True if this is a simplex cell." - raise NotImplementedError("Implement this to allow important checks and optimizations.") - - def has_simplex_facets(self): - "Return True if all the facets of this cell are simplex cells." - raise NotImplementedError("Implement this to allow important checks and optimizations.") - - def __lt__(self, other): - "Define an arbitrarily chosen but fixed sort order for all cells." - if not isinstance(other, AbstractCell): - return NotImplemented - # Sort by gdim first, tdim next, then whatever's left - # depending on the subclass - s = (self.geometric_dimension(), self.topological_dimension()) - o = (other.geometric_dimension(), other.topological_dimension()) - if s != o: - return s < o - return self._ufl_hash_data_() < other._ufl_hash_data_() + def num_facets(self) -> int: + """Get the number of facets. + Facets are entities of dimension tdim-1. + """ + tdim = self.topological_dimension() + return self.num_sub_entities(tdim - 1) -# --- Basic topological properties of known basic cells + def num_ridges(self) -> int: + """Get the number of ridges. -# Mapping from cell name to number of cell entities of each -# topological dimension -num_cell_entities = {"vertex": (1,), - "interval": (2, 1), - "triangle": (3, 3, 1), - "quadrilateral": (4, 4, 1), - "tetrahedron": (4, 6, 4, 1), - "prism": (6, 9, 5, 1), - "pyramid": (5, 8, 5, 1), - "hexahedron": (8, 12, 6, 1)} + Ridges are entities of dimension tdim-2. + """ + tdim = self.topological_dimension() + return self.num_sub_entities(tdim - 2) -# Mapping from cell name to topological dimension -cellname2dim = dict((k, len(v) - 1) for k, v in num_cell_entities.items()) + def num_peaks(self) -> int: + """Get the number of peaks. + Peaks are entities of dimension tdim-3. + """ + tdim = self.topological_dimension() + return self.num_sub_entities(tdim - 3) -# --- Basic cell representation classes + def vertices(self) -> typing.Tuple[AbstractCell, ...]: + """Get the vertices.""" + return self.sub_entities(0) -@attach_operators_from_hash_data -class Cell(AbstractCell): - "Representation of a named finite element cell with known structure." - __slots__ = ("_cellname",) + def edges(self) -> typing.Tuple[AbstractCell, ...]: + """Get the edges.""" + return self.sub_entities(1) - def __init__(self, cellname, geometric_dimension=None): - "Initialize basic cell description." + def faces(self) -> typing.Tuple[AbstractCell, ...]: + """Get the faces.""" + return self.sub_entities(2) - self._cellname = cellname + def facets(self) -> typing.Tuple[AbstractCell, ...]: + """Get the facets. - # The topological dimension is defined by the cell type, so - # the cellname must be among the known ones, so we can find - # the known dimension, unless we have a product cell, in which - # the given dimension is used - topological_dimension = len(num_cell_entities[cellname]) - 1 + Facets are entities of dimension tdim-1. + """ + tdim = self.topological_dimension() + return self.sub_entities(tdim - 1) - # The geometric dimension defaults to equal the topological - # dimension unless overridden for embedded cells - if geometric_dimension is None: - geometric_dimension = topological_dimension + def ridges(self) -> typing.Tuple[AbstractCell, ...]: + """Get the ridges. - # Initialize and validate dimensions - AbstractCell.__init__(self, topological_dimension, geometric_dimension) + Ridges are entities of dimension tdim-2. + """ + tdim = self.topological_dimension() + return self.sub_entities(tdim - 2) - # --- Overrides of AbstractCell methods --- + def peaks(self) -> typing.Tuple[AbstractCell, ...]: + """Get the peaks. - def reconstruct(self, geometric_dimension=None): - if geometric_dimension is None: - geometric_dimension = self._geometric_dimension - return Cell(self._cellname, geometric_dimension=geometric_dimension) + Peaks are entities of dimension tdim-3. + """ + tdim = self.topological_dimension() + return self.sub_entities(tdim - 3) - def is_simplex(self): - " Return True if this is a simplex cell." - return self.num_vertices() == self.topological_dimension() + 1 + def vertex_types(self) -> typing.Tuple[AbstractCell, ...]: + """Get the unique vertices types.""" + return self.sub_entity_types(0) - def has_simplex_facets(self): - "Return True if all the facets of this cell are simplex cells." - return self.is_simplex() or self.cellname() == "quadrilateral" + def edge_types(self) -> typing.Tuple[AbstractCell, ...]: + """Get the unique edge types.""" + return self.sub_entity_types(1) - # --- Specific cell properties --- + def face_types(self) -> typing.Tuple[AbstractCell, ...]: + """Get the unique face types.""" + return self.sub_entity_types(2) - def cellname(self): - "Return the cellname of the cell." - return self._cellname + def facet_types(self) -> typing.Tuple[AbstractCell, ...]: + """Get the unique facet types. - def num_vertices(self): - "The number of cell vertices." - return num_cell_entities[self.cellname()][0] + Facets are entities of dimension tdim-1. + """ + tdim = self.topological_dimension() + return self.sub_entity_types(tdim - 1) - def num_edges(self): - "The number of cell edges." - return num_cell_entities[self.cellname()][1] + def ridge_types(self) -> typing.Tuple[AbstractCell, ...]: + """Get the unique ridge types. - def num_facets(self): - "The number of cell facets." + Ridges are entities of dimension tdim-2. + """ tdim = self.topological_dimension() - return num_cell_entities[self.cellname()][tdim - 1] - - # --- Facet properties --- - - def facet_types(self): - "A tuple of ufl.Cell representing the facets of self." - # TODO Move outside method? - facet_type_names = {"interval": ("vertex",), - "triangle": ("interval",), - "quadrilateral": ("interval",), - "tetrahedron": ("triangle",), - "hexahedron": ("quadrilateral",), - "prism": ("triangle", "quadrilateral")} - return tuple(ufl.Cell(facet_name, self.geometric_dimension()) - for facet_name in facet_type_names[self.cellname()]) - - # --- Special functions for proper object behaviour --- - - def __str__(self): - gdim = self.geometric_dimension() - tdim = self.topological_dimension() - s = self.cellname() - if gdim > tdim: - s += "%dD" % gdim - return s + return self.sub_entity_types(tdim - 2) - def __repr__(self): - # For standard cells, return name of builtin cell object if - # possible. This reduces the size of the repr strings for - # domains, elements, etc. as well - gdim = self.geometric_dimension() + def peak_types(self) -> typing.Tuple[AbstractCell, ...]: + """Get the unique peak types. + + Peaks are entities of dimension tdim-3. + """ tdim = self.topological_dimension() - name = self.cellname() - if gdim == tdim and name in cellname2dim: - r = name - else: - r = "Cell(%s, %s)" % (repr(name), repr(gdim)) - return r + return self.sub_entity_types(tdim - 3) - def _ufl_hash_data_(self): - return (self._geometric_dimension, self._topological_dimension, - self._cellname) +_sub_entity_celltypes = { + "vertex": [("vertex", )], + "interval": [tuple("vertex" for i in range(2)), ("interval", )], + "triangle": [tuple("vertex" for i in range(3)), tuple("interval" for i in range(3)), ("triangle", )], + "quadrilateral": [tuple("vertex" for i in range(4)), tuple("interval" for i in range(4)), ("quadrilateral", )], + "tetrahedron": [tuple("vertex" for i in range(4)), tuple("interval" for i in range(4)), + tuple("triangle" for i in range(4)), ("tetrahedron", )], + "hexahedron": [tuple("vertex" for i in range(8)), tuple("interval" for i in range(12)), + tuple("quadrilateral" for i in range(6)), ("hexahedron", )], + "prism": [tuple("vertex" for i in range(6)), tuple("interval" for i in range(9)), + ("triangle", "quadrilateral", "quadrilateral", "quadrilateral", "triangle"), ("prism", )], + "pyramid": [tuple("vertex" for i in range(5)), tuple("interval" for i in range(8)), + ("quadrilateral", "triangle", "triangle", "triangle", "triangle"), ("pyramid", )], +} -@attach_operators_from_hash_data -class TensorProductCell(AbstractCell): - __slots__ = ("_cells",) - def __init__(self, *cells, **kwargs): - keywords = list(kwargs.keys()) - if keywords and keywords != ["geometric_dimension"]: - raise ValueError( - "TensorProductCell got an unexpected keyword argument '%s'" % - keywords[0]) +class Cell(AbstractCell): + """Representation of a named finite element cell with known structure.""" + __slots__ = ("_cellname", "_tdim", "_gdim", "_num_cell_entities", "_sub_entity_types", + "_sub_entities", "_sub_entity_types") - self._cells = tuple(as_cell(cell) for cell in cells) + def __init__(self, cellname: str, geometric_dimension: typing.Optional[int] = None): + if cellname not in _sub_entity_celltypes: + raise ValueError(f"Unsupported cell type: {cellname}") + + self._sub_entity_celltypes = _sub_entity_celltypes[cellname] + + self._cellname = cellname + self._tdim = len(self._sub_entity_celltypes) - 1 + self._gdim = self._tdim if geometric_dimension is None else geometric_dimension + + self._num_cell_entities = [len(i) for i in self._sub_entity_celltypes] + self._sub_entities = [tuple(Cell(t, self._gdim) for t in se_types) for se_types in self._sub_entity_celltypes[:-1]] + self._sub_entity_types = [tuple(Cell(t, self._gdim) for t in set(se_types)) for se_types in self._sub_entity_celltypes[:-1]] + self._sub_entities.append((weakref.proxy(self), )) + self._sub_entity_types.append((weakref.proxy(self), )) + + if not isinstance(self._gdim, numbers.Integral): + raise ValueError("Expecting integer geometric_dimension.") + if not isinstance(self._tdim, numbers.Integral): + raise ValueError("Expecting integer topological_dimension.") + if self._tdim > self._gdim: + raise ValueError("Topological dimension cannot be larger than geometric dimension.") + + def topological_dimension(self) -> int: + """Return the dimension of the topology of this cell.""" + return self._tdim + + def geometric_dimension(self) -> int: + """Return the dimension of the geometry of this cell.""" + return self._gdim + + def is_simplex(self) -> bool: + """Return True if this is a simplex cell.""" + return self._cellname in ["vertex", "interval", "triangle", "tetrahedron"] + + def has_simplex_facets(self) -> bool: + """Return True if all the facets of this cell are simplex cells.""" + return self._cellname in ["interval", "triangle", "quadrilateral", "tetrahedron"] + + def num_sub_entities(self, dim: int) -> int: + """Get the number of sub-entities of the given dimension.""" + try: + return self._num_cell_entities[dim] + except IndexError: + return 0 + + def sub_entities(self, dim: int) -> typing.Tuple[AbstractCell, ...]: + """Get the sub-entities of the given dimension.""" + try: + return self._sub_entities[dim] + except IndexError: + return () + + def sub_entity_types(self, dim: int) -> typing.Tuple[AbstractCell, ...]: + """Get the unique sub-entity types of the given dimension.""" + try: + return self._sub_entity_types[dim] + except IndexError: + return () + + def _lt(self, other: Self) -> bool: + return self._cellname < other._cellname + + def cellname(self) -> str: + """Return the cellname of the cell.""" + return self._cellname - tdim = sum([cell.topological_dimension() for cell in self._cells]) - if kwargs: - gdim = kwargs["geometric_dimension"] + def __str__(self) -> str: + if self._gdim == self._tdim: + return self._cellname else: - gdim = sum([cell.geometric_dimension() for cell in self._cells]) + return f"{self._cellname}{self._gdim}D" - AbstractCell.__init__(self, tdim, gdim) + def __repr__(self) -> str: + if self._gdim == self._tdim: + return self._cellname + else: + return f"Cell({self._cellname}, {self._gdim})" - def cellname(self): - "Return the cellname of the cell." - return " * ".join([cell._cellname for cell in self._cells]) + def _ufl_hash_data_(self) -> typing.Hashable: + return (self._cellname, self._gdim) - def reconstruct(self, geometric_dimension=None): - if geometric_dimension is None: - geometric_dimension = self._geometric_dimension - return TensorProductCell(*(self._cells), geometric_dimension=geometric_dimension) + def reconstruct(self, **kwargs: typing.Any) -> Cell: + """Reconstruct this cell, overwriting properties by those in kwargs.""" + gdim = self._gdim + for key, value in kwargs.items(): + if key == "geometric_dimension": + gdim = value + else: + raise TypeError(f"reconstruct() got unexpected keyword argument '{key}'") + return Cell(self._cellname, geometric_dimension=gdim) - def is_simplex(self): - "Return True if this is a simplex cell." - if len(self._cells) == 1: - return self._cells[0].is_simplex() - return False - def has_simplex_facets(self): - "Return True if all the facets of this cell are simplex cells." - if len(self._cells) == 1: - return self._cells[0].has_simplex_facets() - return False +class TensorProductCell(AbstractCell): + __slots__ = ("_cells", "_tdim", "_gdim") - def num_vertices(self): - "The number of cell vertices." - return functools.reduce(lambda x, y: x * y, [c.num_vertices() for c in self._cells]) + def __init__(self, *cells: Cell, geometric_dimension: typing.Optional[int] = None): + self._cells = tuple(as_cell(cell) for cell in cells) - def num_edges(self): - "The number of cell edges." - raise ValueError("Not defined for TensorProductCell.") + self._tdim = sum([cell.topological_dimension() for cell in self._cells]) + self._gdim = self._tdim if geometric_dimension is None else geometric_dimension - def num_facets(self): - "The number of cell facets." - return sum(c.num_facets() for c in self._cells if c.topological_dimension() > 0) + if not isinstance(self._gdim, numbers.Integral): + raise ValueError("Expecting integer geometric_dimension.") + if not isinstance(self._tdim, numbers.Integral): + raise ValueError("Expecting integer topological_dimension.") + if self._tdim > self._gdim: + raise ValueError("Topological dimension cannot be larger than geometric dimension.") - def sub_cells(self): - "Return list of cell factors." + def sub_cells(self) -> typing.List[AbstractCell]: + """Return list of cell factors.""" return self._cells - def __str__(self): - gdim = self.geometric_dimension() - tdim = self.topological_dimension() - reprs = ", ".join(repr(c) for c in self._cells) - if gdim == tdim: - gdimstr = "" - else: - gdimstr = ", geometric_dimension=%d" % gdim - r = "TensorProductCell(%s%s)" % (reprs, gdimstr) - return r + def topological_dimension(self) -> int: + """Return the dimension of the topology of this cell.""" + return self._tdim - def __repr__(self): - return str(self) + def geometric_dimension(self) -> int: + """Return the dimension of the geometry of this cell.""" + return self._gdim - def _ufl_hash_data_(self): - return tuple(c._ufl_hash_data_() for c in self._cells) + (self._geometric_dimension,) + def is_simplex(self) -> bool: + """Return True if this is a simplex cell.""" + if len(self._cells) == 1: + return self._cells[0].is_simplex() + return False + def has_simplex_facets(self) -> bool: + """Return True if all the facets of this cell are simplex cells.""" + if len(self._cells) == 1: + return self._cells[0].has_simplex_facets() + if self._tdim == 1: + return True + return False -# --- Utility conversion functions + def num_sub_entities(self, dim: int) -> int: + """Get the number of sub-entities of the given dimension.""" + if dim < 0 or dim > self._tdim: + return 0 + if dim == 0: + return functools.reduce(lambda x, y: x * y, [c.num_vertices() for c in self._cells]) + if dim == self._tdim - 1: + # Note: This is not the number of facets that the cell has, but I'm leaving it here for now + # to not change past behaviour + return sum(c.num_facets() for c in self._cells if c.topological_dimension() > 0) + if dim == self._tdim: + return 1 + raise NotImplementedError(f"TensorProductCell.num_sub_entities({dim}) is not implemented.") + + def sub_entities(self, dim: int) -> typing.Tuple[AbstractCell, ...]: + """Get the sub-entities of the given dimension.""" + if dim < 0 or dim > self._tdim: + return [] + if dim == 0: + return [Cell("vertex", self._gdim) for i in range(self.num_sub_entities(0))] + if dim == self._tdim: + return [self] + raise NotImplementedError(f"TensorProductCell.sub_entities({dim}) is not implemented.") + + def sub_entity_types(self, dim: int) -> typing.Tuple[AbstractCell, ...]: + """Get the unique sub-entity types of the given dimension.""" + if dim < 0 or dim > self._tdim: + return [] + if dim == 0: + return [Cell("vertex", self._gdim)] + if dim == self._tdim: + return [self] + raise NotImplementedError(f"TensorProductCell.sub_entities({dim}) is not implemented.") + + def _lt(self, other: Self) -> bool: + return self._ufl_hash_data_() < other._ufl_hash_data_() -# Mapping from topological dimension to reference cell name for -# simplices -_simplex_dim2cellname = {0: "vertex", - 1: "interval", - 2: "triangle", - 3: "tetrahedron"} + def cellname(self) -> str: + """Return the cellname of the cell.""" + return " * ".join([cell.cellname() for cell in self._cells]) -# Mapping from topological dimension to reference cell name for -# hypercubes -_hypercube_dim2cellname = {0: "vertex", - 1: "interval", - 2: "quadrilateral", - 3: "hexahedron"} + def __str__(self) -> str: + s = "TensorProductCell(" + s += ", ".join(f"{c!r}" for c in self._cells) + if self._tdim != self._gdim: + s += f", geometric_dimension={self._gdim}" + s += ")" + return s + def __repr__(self) -> str: + return str(self) -def simplex(topological_dimension, geometric_dimension=None): - "Return a simplex cell of given dimension." - return Cell(_simplex_dim2cellname[topological_dimension], - geometric_dimension) + def _ufl_hash_data_(self) -> typing.Hashable: + return tuple(c._ufl_hash_data_() for c in self._cells) + (self._gdim,) + + def reconstruct(self, **kwargs: typing.Any) -> Cell: + """Reconstruct this cell, overwriting properties by those in kwargs.""" + gdim = self._gdim + for key, value in kwargs.items(): + if key == "geometric_dimension": + gdim = value + else: + raise TypeError(f"reconstruct() got unexpected keyword argument '{key}'") + return TensorProductCell(self._cellname, geometric_dimension=gdim) + + +def simplex(topological_dimension: int, geometric_dimension: typing.Optional[int] = None): + """Return a simplex cell of the given dimension.""" + if topological_dimension == 0: + return Cell("vertex", geometric_dimension) + if topological_dimension == 1: + return Cell("interval", geometric_dimension) + if topological_dimension == 2: + return Cell("triangle", geometric_dimension) + if topological_dimension == 3: + return Cell("tetrahedron", geometric_dimension) + raise ValueError(f"Unsupported topological dimension for simplex: {topological_dimension}") def hypercube(topological_dimension, geometric_dimension=None): - "Return a hypercube cell of given dimension." - return Cell(_hypercube_dim2cellname[topological_dimension], - geometric_dimension) - - -def as_cell(cell): + """Return a hypercube cell of the given dimension.""" + if topological_dimension == 0: + return Cell("vertex", geometric_dimension) + if topological_dimension == 1: + return Cell("interval", geometric_dimension) + if topological_dimension == 2: + return Cell("quadrilateral", geometric_dimension) + if topological_dimension == 3: + return Cell("hexahedron", geometric_dimension) + raise ValueError(f"Unsupported topological dimension for hypercube: {topological_dimension}") + + +def as_cell(cell: typing.Union[AbstractCell, str, typing.Tuple[AbstractCell, ...]]) -> AbstractCell: """Convert any valid object to a Cell or return cell if it is already a Cell. Allows an already valid cell, a known cellname string, or a tuple of cells for a product cell. diff --git a/ufl/core/ufl_type.py b/ufl/core/ufl_type.py index e6435684c..df6c991e5 100644 --- a/ufl/core/ufl_type.py +++ b/ufl/core/ufl_type.py @@ -7,15 +7,43 @@ # SPDX-License-Identifier: LGPL-3.0-or-later # # Modified by Massimiliano Leoni, 2016 +# Modified by Matthew Scroggs, 2023 + +from __future__ import annotations +import typing +import warnings from ufl.core.compute_expr_hash import compute_expr_hash from ufl.utils.formatting import camel2underscore +from abc import ABC, abstractmethod # Avoid circular import import ufl.core as core -# Make UFL type coercion available under the as_ufl name -# as_ufl = Expr._ufl_coerce_ +class UFLObject(ABC): + """A UFL Object.""" + + @abstractmethod + def _ufl_hash_data_(self) -> typing.Hashable: + """Return hashable data that uniquely defines this object.""" + + @abstractmethod + def __str__(self) -> str: + """Return a human-readable string representation of the object.""" + + @abstractmethod + def __repr__(self) -> str: + """Return a string representation of the object.""" + + def __hash__(self) -> int: + return hash(self._ufl_hash_data_()) + + def __eq__(self, other): + return type(self) == type(other) and self._ufl_hash_data_() == other._ufl_hash_data_() + + def __ne__(self, other): + return not self.__eq__(other) + def attach_operators_from_hash_data(cls): """Class decorator to attach ``__hash__``, ``__eq__`` and ``__ne__`` implementations. @@ -23,6 +51,7 @@ def attach_operators_from_hash_data(cls): These are implemented in terms of a ``._ufl_hash_data()`` method on the class, which should return a tuple or hashable and comparable data. """ + warnings.warn("attach_operators_from_hash_data deprecated, please use UFLObject instead.", DeprecationWarning) assert hasattr(cls, "_ufl_hash_data_") def __hash__(self): @@ -37,7 +66,7 @@ def __eq__(self, other): def __ne__(self, other): "__ne__ implementation attached in attach_operators_from_hash_data" - return type(self) != type(other) or self._ufl_hash_data_() != other._ufl_hash_data_() + return not self.__eq__(other) cls.__ne__ = __ne__ return cls