diff --git a/docs/source/contributing/docs/typings.rst b/docs/source/contributing/docs/typings.rst index 7cc14068c8..b262cda749 100644 --- a/docs/source/contributing/docs/typings.rst +++ b/docs/source/contributing/docs/typings.rst @@ -141,6 +141,14 @@ Typing guidelines from manim.typing import Vector3D # type stuff with Vector3D +* When typing something like ``VGroup``, type it as if it were a list, not as if it was a tuple. + +.. code:: py + # not VGroup[Line, Line] + def get_two_lines() -> VGroup[Line]: + return VGroup(Line(), Line().shift(LEFT)) + + Missing Sections for typehints are: ----------------------------------- diff --git a/manim/animation/changing.py b/manim/animation/changing.py index bb11cfc0a4..17e2f5ad49 100644 --- a/manim/animation/changing.py +++ b/manim/animation/changing.py @@ -4,10 +4,11 @@ __all__ = ["AnimatedBoundary", "TracedPath"] +from collections.abc import Sequence from typing import Callable from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL -from manim.mobject.types.vectorized_mobject import VGroup, VMobject +from manim.mobject.types.vectorized_mobject import VGroup, VMobject, VMobjectT from manim.utils.color import ( BLUE_B, BLUE_D, @@ -19,7 +20,7 @@ from manim.utils.rate_functions import smooth -class AnimatedBoundary(VGroup): +class AnimatedBoundary(VGroup[VMobjectT]): """Boundary of a :class:`.VMobject` with animated color change. Examples @@ -38,11 +39,11 @@ def construct(self): def __init__( self, - vmobject, - colors=[BLUE_D, BLUE_B, BLUE_E, GREY_BROWN], - max_stroke_width=3, - cycle_rate=0.5, - back_and_forth=True, + vmobject: VMobjectT, + colors: Sequence[ParsableManimColor] = [BLUE_D, BLUE_B, BLUE_E, GREY_BROWN], + max_stroke_width: float = 3, + cycle_rate: float = 0.5, + back_and_forth: bool = True, draw_rate_func=smooth, fade_rate_func=smooth, **kwargs, @@ -60,7 +61,7 @@ def __init__( ] self.add(*self.boundary_copies) self.total_time = 0 - self.add_updater(lambda m, dt: self.update_boundary_copies(dt)) + self.add_updater(lambda _, dt: self.update_boundary_copies(dt)) def update_boundary_copies(self, dt): # Not actual time, but something which passes at diff --git a/manim/mobject/geometry/shape_matchers.py b/manim/mobject/geometry/shape_matchers.py index 86afb58db5..a585fd6fcf 100644 --- a/manim/mobject/geometry/shape_matchers.py +++ b/manim/mobject/geometry/shape_matchers.py @@ -4,9 +4,7 @@ __all__ = ["SurroundingRectangle", "BackgroundRectangle", "Cross", "Underline"] -from typing import Any - -from typing_extensions import Self +from typing import TYPE_CHECKING, Any from manim import logger from manim._config import config @@ -15,7 +13,6 @@ LEFT, RIGHT, SMALL_BUFF, - UP, ) from manim.mobject.geometry.line import Line from manim.mobject.geometry.polygram import RoundedRectangle @@ -23,6 +20,9 @@ from manim.mobject.types.vectorized_mobject import VGroup from manim.utils.color import BLACK, RED, YELLOW, ManimColor, ParsableManimColor +if TYPE_CHECKING: + from typing_extensions import Self + class SurroundingRectangle(RoundedRectangle): r"""A rectangle surrounding a :class:`~.Mobject` @@ -151,7 +151,7 @@ def get_fill_color(self) -> ManimColor: return temp_color -class Cross(VGroup): +class Cross(VGroup[Line]): """Creates a cross. Parameters @@ -184,9 +184,7 @@ def __init__( scale_factor: float = 1.0, **kwargs: Any, ) -> None: - super().__init__( - Line(UP + LEFT, DOWN + RIGHT), Line(UP + RIGHT, DOWN + LEFT), **kwargs - ) + super().__init__(Line(UL, DR), Line(UR, DL), **kwargs) if mobject is not None: self.replace(mobject, stretch=True) self.scale(scale_factor) diff --git a/manim/mobject/logo.py b/manim/mobject/logo.py index 6242a4c645..afc6dc0279 100644 --- a/manim/mobject/logo.py +++ b/manim/mobject/logo.py @@ -16,7 +16,7 @@ from ..animation.creation import Create, SpiralIn from ..animation.fading import FadeIn from ..mobject.svg.svg_mobject import VMobjectFromSVGPath -from ..mobject.types.vectorized_mobject import VGroup +from ..mobject.types.vectorized_mobject import VGroup, VMobjectT from ..utils.rate_functions import ease_in_out_cubic, smooth MANIM_SVG_PATHS: list[se.Path] = [ @@ -100,7 +100,7 @@ ] -class ManimBanner(VGroup): +class ManimBanner(VGroup[VMobjectT]): r"""Convenience class representing Manim's banner. Can be animated using custom methods. diff --git a/manim/mobject/table.py b/manim/mobject/table.py index 0810e6d59b..aaed561996 100644 --- a/manim/mobject/table.py +++ b/manim/mobject/table.py @@ -79,12 +79,12 @@ def construct(self): from ..animation.composition import AnimationGroup from ..animation.creation import Create, Write from ..animation.fading import FadeIn -from ..mobject.types.vectorized_mobject import VGroup, VMobject +from ..mobject.types.vectorized_mobject import VGroup, VMobject, VMobjectT from ..utils.color import BLACK, YELLOW, ManimColor, ParsableManimColor from .utils import get_vectorized_mobject_class -class Table(VGroup): +class Table(VGroup[VMobjectT]): r"""A mobject that displays a table on the screen. Parameters diff --git a/manim/mobject/text/code_mobject.py b/manim/mobject/text/code_mobject.py index 999ab3c90e..758b03a91b 100644 --- a/manim/mobject/text/code_mobject.py +++ b/manim/mobject/text/code_mobject.py @@ -22,11 +22,11 @@ from manim.mobject.geometry.polygram import RoundedRectangle from manim.mobject.geometry.shape_matchers import SurroundingRectangle from manim.mobject.text.text_mobject import Paragraph -from manim.mobject.types.vectorized_mobject import VGroup +from manim.mobject.types.vectorized_mobject import VGroup, VMobjectT from manim.utils.color import WHITE -class Code(VGroup): +class Code(VGroup[VMobjectT]): """A highlighted source code listing. An object ``listing`` of :class:`.Code` is a :class:`.VGroup` consisting diff --git a/manim/mobject/text/text_mobject.py b/manim/mobject/text/text_mobject.py index ef14267891..e8841a68dd 100644 --- a/manim/mobject/text/text_mobject.py +++ b/manim/mobject/text/text_mobject.py @@ -70,7 +70,7 @@ def construct(self): from manim.constants import * from manim.mobject.geometry.arc import Dot from manim.mobject.svg.svg_mobject import SVGMobject -from manim.mobject.types.vectorized_mobject import VGroup, VMobject +from manim.mobject.types.vectorized_mobject import VGroup, VMobject, VMobjectT from manim.utils.color import ManimColor, ParsableManimColor, color_gradient from manim.utils.deprecation import deprecated @@ -115,7 +115,7 @@ def remove_invisible_chars(mobject: SVGMobject) -> SVGMobject: return mobject_without_dots -class Paragraph(VGroup): +class Paragraph(VGroup[VMobjectT]): r"""Display a paragraph of text. For a given :class:`.Paragraph` ``par``, the attribute ``par.chars`` is a diff --git a/manim/mobject/three_d/polyhedra.py b/manim/mobject/three_d/polyhedra.py index 8046f6066c..d1d4718818 100644 --- a/manim/mobject/three_d/polyhedra.py +++ b/manim/mobject/three_d/polyhedra.py @@ -9,7 +9,7 @@ from manim.mobject.geometry.polygram import Polygon from manim.mobject.graph import Graph from manim.mobject.three_d.three_dimensions import Dot3D -from manim.mobject.types.vectorized_mobject import VGroup +from manim.mobject.types.vectorized_mobject import VGroup, VMobjectT from manim.utils.qhull import QuickHull if TYPE_CHECKING: @@ -26,7 +26,7 @@ ] -class Polyhedron(VGroup): +class Polyhedron(VGroup[VMobjectT]): """An abstract polyhedra class. In this implementation, polyhedra are defined with a list of vertex coordinates in space, and a list diff --git a/manim/mobject/three_d/three_dimensions.py b/manim/mobject/three_d/three_dimensions.py index 7b30f9a7ad..62945496cb 100644 --- a/manim/mobject/three_d/three_dimensions.py +++ b/manim/mobject/three_d/three_dimensions.py @@ -32,7 +32,12 @@ from manim.mobject.mobject import * from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL from manim.mobject.opengl.opengl_mobject import OpenGLMobject -from manim.mobject.types.vectorized_mobject import VectorizedPoint, VGroup, VMobject +from manim.mobject.types.vectorized_mobject import ( + VectorizedPoint, + VGroup, + VMobject, + VMobjectT, +) from manim.utils.color import ( ManimColor, ParsableManimColor, @@ -458,7 +463,7 @@ def __init__( self.set_color(color) -class Cube(VGroup): +class Cube(VGroup[VMobjectT]): """A three-dimensional cube. Parameters diff --git a/manim/mobject/types/vectorized_mobject.py b/manim/mobject/types/vectorized_mobject.py index a8d32682fd..ac36a9672f 100644 --- a/manim/mobject/types/vectorized_mobject.py +++ b/manim/mobject/types/vectorized_mobject.py @@ -14,11 +14,11 @@ import itertools as it import sys -from collections.abc import Generator, Hashable, Iterable, Mapping, Sequence -from typing import TYPE_CHECKING, Callable, Literal +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal import numpy as np from PIL.Image import Image +from typing_extensions import TypeVar from manim import config from manim.constants import * @@ -48,7 +48,7 @@ from manim.utils.space_ops import rotate_vector, shoelace_direction if TYPE_CHECKING: - from typing import Any + from collections.abc import Generator, Hashable, Iterable, Mapping, Sequence import numpy.typing as npt from typing_extensions import Self @@ -583,7 +583,7 @@ def get_fill_colors(self) -> list[ManimColor | None]: def get_fill_opacities(self) -> npt.NDArray[ManimFloat]: return self.get_fill_rgbas()[:, 3] - def get_stroke_rgbas(self, background: bool = False) -> RGBA_Array_float | Zeros: + def get_stroke_rgbas(self, background: bool = False) -> RGBA_Array_Float | Zeros: try: if background: self.background_stroke_rgbas: RGBA_Array_Float @@ -2059,7 +2059,10 @@ def force_direction(self, target_direction: Literal["CW", "CCW"]) -> Self: return self -class VGroup(VMobject, metaclass=ConvertToOpenGL): +VMobjectT = TypeVar("VMobjectT", bound=VMobject, default=VMobject) + + +class VGroup(VMobject, Generic[VMobjectT], metaclass=ConvertToOpenGL): """A group of vectorized mobjects. This can be used to group multiple :class:`~.VMobject` instances together @@ -2119,7 +2122,7 @@ def construct(self): """ def __init__( - self, *vmobjects: VMobject | Iterable[VMobject], **kwargs: Any + self, *vmobjects: VMobjectT | Iterable[VMobjectT], **kwargs: Any ) -> None: super().__init__(**kwargs) self.add(*vmobjects) @@ -2660,7 +2663,7 @@ def set_location(self, new_loc: Point3D): self.set_points(np.array([new_loc])) -class CurvesAsSubmobjects(VGroup): +class CurvesAsSubmobjects(VGroup[VMobject]): """Convert a curve's elements to submobjects. Examples diff --git a/manim/mobject/vector_field.py b/manim/mobject/vector_field.py index 28f5c6d26f..e9cc74bcc7 100644 --- a/manim/mobject/vector_field.py +++ b/manim/mobject/vector_field.py @@ -12,7 +12,7 @@ import random from collections.abc import Iterable, Sequence from math import ceil, floor -from typing import Callable +from typing import TYPE_CHECKING, Callable import numpy as np from PIL import Image @@ -27,7 +27,7 @@ from ..animation.indication import ShowPassingFlash from ..constants import OUT, RIGHT, UP, RendererType from ..mobject.mobject import Mobject -from ..mobject.types.vectorized_mobject import VGroup +from ..mobject.types.vectorized_mobject import VGroup, VMobjectT from ..mobject.utils import get_vectorized_mobject_class from ..utils.bezier import interpolate, inverse_interpolate from ..utils.color import ( @@ -43,10 +43,16 @@ from ..utils.rate_functions import ease_out_sine, linear from ..utils.simple_functions import sigmoid +if TYPE_CHECKING: + import numpy.typing as npt + from typing_extensions import Self + + from manim.typing import MappingFunction, Point3D, Vector3D + DEFAULT_SCALAR_FIELD_COLORS: list = [BLUE_E, GREEN, YELLOW, RED] -class VectorField(VGroup): +class VectorField(VGroup[VMobjectT]): """A vector field. Vector fields are based on a function defining a vector at every position. @@ -74,14 +80,14 @@ class VectorField(VGroup): def __init__( self, - func: Callable[[np.ndarray], np.ndarray], + func: MappingFunction, color: ParsableManimColor | None = None, - color_scheme: Callable[[np.ndarray], float] | None = None, + color_scheme: Callable[[Vector3D], float] | None = None, min_color_scheme_value: float = 0, max_color_scheme_value: float = 2, colors: Sequence[ParsableManimColor] = DEFAULT_SCALAR_FIELD_COLORS, **kwargs, - ): + ) -> None: super().__init__(**kwargs) self.func = func if color is None: @@ -121,9 +127,9 @@ def pos_to_rgb(pos: np.ndarray) -> tuple[float, float, float, float]: @staticmethod def shift_func( - func: Callable[[np.ndarray], np.ndarray], - shift_vector: np.ndarray, - ) -> Callable[[np.ndarray], np.ndarray]: + func: MappingFunction, + shift_vector: Vector3D, + ) -> MappingFunction: """Shift a vector field function. Parameters @@ -143,9 +149,9 @@ def shift_func( @staticmethod def scale_func( - func: Callable[[np.ndarray], np.ndarray], + func: MappingFunction, scalar: float, - ) -> Callable[[np.ndarray], np.ndarray]: + ) -> MappingFunction: """Scale a vector field function. Parameters @@ -178,7 +184,7 @@ def construct(self): """ return lambda p: func(p * scalar) - def fit_to_coordinate_system(self, coordinate_system: CoordinateSystem): + def fit_to_coordinate_system(self, coordinate_system: CoordinateSystem) -> None: """Scale the vector field to fit a coordinate system. This method is useful when the vector field is defined in a coordinate system @@ -200,7 +206,7 @@ def nudge( dt: float = 1, substeps: int = 1, pointwise: bool = False, - ) -> VectorField: + ) -> Self: """Nudge a :class:`~.Mobject` along the vector field. Parameters @@ -284,7 +290,7 @@ def nudge_submobjects( dt: float = 1, substeps: int = 1, pointwise: bool = False, - ) -> VectorField: + ) -> Self: """Apply a nudge along the vector field to all submobjects. Parameters @@ -335,7 +341,7 @@ def start_submobject_movement( self, speed: float = 1, pointwise: bool = False, - ) -> VectorField: + ) -> Self: """Start continuously moving all submobjects along the vector field. Calling this method multiple times will result in removing the previous updater created by this method. @@ -361,7 +367,7 @@ def start_submobject_movement( self.add_updater(self.submob_movement_updater) return self - def stop_submobject_movement(self) -> VectorField: + def stop_submobject_movement(self) -> Self: """Stops the continuous movement started using :meth:`start_submobject_movement`. Returns @@ -455,7 +461,7 @@ def func(values, opacity=1): return func -class ArrowVectorField(VectorField): +class ArrowVectorField(VectorField[Vector]): """A :class:`VectorField` represented by a set of change vectors. Vector fields are always based on a function defining the :class:`~.Vector` at every position. @@ -540,9 +546,9 @@ def construct(self): def __init__( self, - func: Callable[[np.ndarray], np.ndarray], + func: MappingFunction, color: ParsableManimColor | None = None, - color_scheme: Callable[[np.ndarray], float] | None = None, + color_scheme: Callable[[npt.NDArray], float] | None = None, min_color_scheme_value: float = 0, max_color_scheme_value: float = 2, colors: Sequence[ParsableManimColor] = DEFAULT_SCALAR_FIELD_COLORS, @@ -608,7 +614,7 @@ def __init__( ) self.set_opacity(self.opacity) - def get_vector(self, point: np.ndarray): + def get_vector(self, point: Point3D) -> Vector: """Creates a vector in the vector field. The created vector is based on the function of the vector field and is @@ -634,7 +640,7 @@ def get_vector(self, point: np.ndarray): return vect -class StreamLines(VectorField): +class StreamLines(VectorField[VMobjectT]): """StreamLines represent the flow of a :class:`VectorField` using the trace of moving agents. Vector fields are always based on a function defining the vector at every position. @@ -714,9 +720,9 @@ def construct(self): def __init__( self, - func: Callable[[np.ndarray], np.ndarray], + func: MappingFunction, color: ParsableManimColor | None = None, - color_scheme: Callable[[np.ndarray], float] | None = None, + color_scheme: Callable[[npt.NDArray], float] | None = None, min_color_scheme_value: float = 0, max_color_scheme_value: float = 2, colors: Sequence[ParsableManimColor] = DEFAULT_SCALAR_FIELD_COLORS, diff --git a/manim/scene/vector_space_scene.py b/manim/scene/vector_space_scene.py index be75151471..27361c42a3 100644 --- a/manim/scene/vector_space_scene.py +++ b/manim/scene/vector_space_scene.py @@ -743,7 +743,7 @@ def add_moving_mobject( mobject.target = target_mobject self.add_special_mobjects(self.moving_mobjects, mobject) - def get_ghost_vectors(self) -> VGroup: + def get_ghost_vectors(self) -> VGroup[VGroup[Vector]]: """ Returns all ghost vectors ever added to ``self``. Each element is a ``VGroup`` of two ghost vectors.