Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow VGroup type subscripting #3606

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
c698f0b
Add __class_getitem__ to Mobject
JasonGrace2282 Feb 1, 2024
d31f6d0
feat(VGroup): Make VGroup Generic in VMobjectT
JasonGrace2282 Apr 30, 2024
88ea3d5
Fix undefined import
JasonGrace2282 Apr 30, 2024
401f71d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 30, 2024
1d895b2
Remove default= in TypeVar definition
JasonGrace2282 May 2, 2024
e807493
Merge branch 'main' into VGroup_type_subscripting
JasonGrace2282 May 23, 2024
03c830f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
0cee738
Add back import+default
JasonGrace2282 May 23, 2024
2b00112
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 23, 2024
c07ccf6
Move import to TYPE_CHECKING
JasonGrace2282 May 23, 2024
b8bf21c
Add Generic import back after merge conflict
JasonGrace2282 May 23, 2024
2d6c4df
Add missing import from merge conflict
JasonGrace2282 May 23, 2024
bc59d1e
Merge branch 'main' into VGroup_type_subscripting
JasonGrace2282 Jun 19, 2024
8777b9b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 19, 2024
ab2a7a7
Merge branch 'main' into VGroup_type_subscripting
JasonGrace2282 Jun 26, 2024
7d1834d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 26, 2024
0e7e26e
Change list->Sequence
JasonGrace2282 Jun 26, 2024
7c5e64a
Merge branch 'main' into VGroup_type_subscripting
JasonGrace2282 Dec 7, 2024
8cbe628
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/contributing/docs/typings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,14 @@ Typing guidelines
from manim.typing import Vector3D
# type stuff with Vector3D

* When typing something like ``VGroup``, type it as if it were a list, not as if it was a tuple.

.. code:: py
# not VGroup[Line, Line]
def get_two_lines() -> VGroup[Line]:
return VGroup(Line(), Line().shift(LEFT))


Missing Sections for typehints are:
-----------------------------------

Expand Down
17 changes: 9 additions & 8 deletions manim/animation/changing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@

__all__ = ["AnimatedBoundary", "TracedPath"]

from collections.abc import Sequence
from typing import Callable

from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
from manim.mobject.types.vectorized_mobject import VGroup, VMobject
from manim.mobject.types.vectorized_mobject import VGroup, VMobject, VMobjectT
from manim.utils.color import (
BLUE_B,
BLUE_D,
Expand All @@ -19,7 +20,7 @@
from manim.utils.rate_functions import smooth


class AnimatedBoundary(VGroup):
class AnimatedBoundary(VGroup[VMobjectT]):
"""Boundary of a :class:`.VMobject` with animated color change.

Examples
Expand All @@ -38,11 +39,11 @@ def construct(self):

def __init__(
self,
vmobject,
colors=[BLUE_D, BLUE_B, BLUE_E, GREY_BROWN],
max_stroke_width=3,
cycle_rate=0.5,
back_and_forth=True,
vmobject: VMobjectT,
colors: Sequence[ParsableManimColor] = [BLUE_D, BLUE_B, BLUE_E, GREY_BROWN],
max_stroke_width: float = 3,
cycle_rate: float = 0.5,
back_and_forth: bool = True,
draw_rate_func=smooth,
fade_rate_func=smooth,
**kwargs,
Expand All @@ -60,7 +61,7 @@ def __init__(
]
self.add(*self.boundary_copies)
self.total_time = 0
self.add_updater(lambda m, dt: self.update_boundary_copies(dt))
self.add_updater(lambda _, dt: self.update_boundary_copies(dt))

def update_boundary_copies(self, dt):
# Not actual time, but something which passes at
Expand Down
14 changes: 6 additions & 8 deletions manim/mobject/geometry/shape_matchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

__all__ = ["SurroundingRectangle", "BackgroundRectangle", "Cross", "Underline"]

from typing import Any

from typing_extensions import Self
from typing import TYPE_CHECKING, Any

from manim import logger
from manim._config import config
Expand All @@ -15,14 +13,16 @@
LEFT,
RIGHT,
SMALL_BUFF,
UP,
)
from manim.mobject.geometry.line import Line
from manim.mobject.geometry.polygram import RoundedRectangle
from manim.mobject.mobject import Mobject
from manim.mobject.types.vectorized_mobject import VGroup
from manim.utils.color import BLACK, RED, YELLOW, ManimColor, ParsableManimColor

if TYPE_CHECKING:

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'VGroup' may not be defined if module
manim.mobject.types.vectorized_mobject
is imported before module
manim.mobject.geometry.shape_matchers
, as the
definition
of VGroup occurs after the cyclic
import
of manim.mobject.geometry.shape_matchers.
'VGroup' may not be defined if module
manim.mobject.types.vectorized_mobject
is imported before module
manim.mobject.geometry.shape_matchers
, as the
definition
of VGroup occurs after the cyclic
import
of manim.mobject.geometry.shape_matchers.
from typing_extensions import Self


class SurroundingRectangle(RoundedRectangle):
r"""A rectangle surrounding a :class:`~.Mobject`
Expand Down Expand Up @@ -151,7 +151,7 @@
return temp_color


class Cross(VGroup):
class Cross(VGroup[Line]):
"""Creates a cross.

Parameters
Expand Down Expand Up @@ -184,9 +184,7 @@
scale_factor: float = 1.0,
**kwargs: Any,
) -> None:
super().__init__(
Line(UP + LEFT, DOWN + RIGHT), Line(UP + RIGHT, DOWN + LEFT), **kwargs
)
super().__init__(Line(UL, DR), Line(UR, DL), **kwargs)
if mobject is not None:
self.replace(mobject, stretch=True)
self.scale(scale_factor)
Expand Down
4 changes: 2 additions & 2 deletions manim/mobject/logo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..animation.creation import Create, SpiralIn
from ..animation.fading import FadeIn
from ..mobject.svg.svg_mobject import VMobjectFromSVGPath
from ..mobject.types.vectorized_mobject import VGroup
from ..mobject.types.vectorized_mobject import VGroup, VMobjectT
from ..utils.rate_functions import ease_in_out_cubic, smooth

MANIM_SVG_PATHS: list[se.Path] = [
Expand Down Expand Up @@ -100,7 +100,7 @@
]


class ManimBanner(VGroup):
class ManimBanner(VGroup[VMobjectT]):
r"""Convenience class representing Manim's banner.

Can be animated using custom methods.
Expand Down
4 changes: 2 additions & 2 deletions manim/mobject/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,12 @@ def construct(self):
from ..animation.composition import AnimationGroup
from ..animation.creation import Create, Write
from ..animation.fading import FadeIn
from ..mobject.types.vectorized_mobject import VGroup, VMobject
from ..mobject.types.vectorized_mobject import VGroup, VMobject, VMobjectT
from ..utils.color import BLACK, YELLOW, ManimColor, ParsableManimColor
from .utils import get_vectorized_mobject_class


class Table(VGroup):
class Table(VGroup[VMobjectT]):
r"""A mobject that displays a table on the screen.

Parameters
Expand Down
4 changes: 2 additions & 2 deletions manim/mobject/text/code_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
from manim.mobject.geometry.polygram import RoundedRectangle
from manim.mobject.geometry.shape_matchers import SurroundingRectangle
from manim.mobject.text.text_mobject import Paragraph
from manim.mobject.types.vectorized_mobject import VGroup
from manim.mobject.types.vectorized_mobject import VGroup, VMobjectT
from manim.utils.color import WHITE


class Code(VGroup):
class Code(VGroup[VMobjectT]):
"""A highlighted source code listing.

An object ``listing`` of :class:`.Code` is a :class:`.VGroup` consisting
Expand Down
4 changes: 2 additions & 2 deletions manim/mobject/text/text_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
from manim.constants import *
from manim.mobject.geometry.arc import Dot
from manim.mobject.svg.svg_mobject import SVGMobject
from manim.mobject.types.vectorized_mobject import VGroup, VMobject
from manim.mobject.types.vectorized_mobject import VGroup, VMobject, VMobjectT
Fixed Show fixed Hide fixed

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'VGroup' may not be defined if module
manim.mobject.types.vectorized_mobject
is imported before module
manim.mobject.text.text_mobject
, as the
definition
of VGroup occurs after the cyclic
import
of manim.mobject.text.text_mobject.
'VGroup' may not be defined if module
manim.mobject.types.vectorized_mobject
is imported before module
manim.mobject.text.text_mobject
, as the
definition
of VGroup occurs after the cyclic
import
of manim.mobject.text.text_mobject.

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'VMobject' may not be defined if module
manim.mobject.types.vectorized_mobject
is imported before module
manim.mobject.text.text_mobject
, as the
definition
of VMobject occurs after the cyclic
import
of manim.mobject.text.text_mobject.
'VMobject' may not be defined if module
manim.mobject.types.vectorized_mobject
is imported before module
manim.mobject.text.text_mobject
, as the
definition
of VMobject occurs after the cyclic
import
of manim.mobject.text.text_mobject.

Check failure

Code scanning / CodeQL

Module-level cyclic import Error

'VMobjectT' may not be defined if module
manim.mobject.types.vectorized_mobject
is imported before module
manim.mobject.text.text_mobject
, as the
definition
of VMobjectT occurs after the cyclic
import
of manim.mobject.text.text_mobject.
'VMobjectT' may not be defined if module
manim.mobject.types.vectorized_mobject
is imported before module
manim.mobject.text.text_mobject
, as the
definition
of VMobjectT occurs after the cyclic
import
of manim.mobject.text.text_mobject.
from manim.utils.color import ManimColor, ParsableManimColor, color_gradient
from manim.utils.deprecation import deprecated

Expand Down Expand Up @@ -115,7 +115,7 @@
return mobject_without_dots


class Paragraph(VGroup):
class Paragraph(VGroup[VMobjectT]):
r"""Display a paragraph of text.

For a given :class:`.Paragraph` ``par``, the attribute ``par.chars`` is a
Expand Down
4 changes: 2 additions & 2 deletions manim/mobject/three_d/polyhedra.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from manim.mobject.geometry.polygram import Polygon
from manim.mobject.graph import Graph
from manim.mobject.three_d.three_dimensions import Dot3D
from manim.mobject.types.vectorized_mobject import VGroup
from manim.mobject.types.vectorized_mobject import VGroup, VMobjectT
from manim.utils.qhull import QuickHull

if TYPE_CHECKING:
Expand All @@ -26,7 +26,7 @@
]


class Polyhedron(VGroup):
class Polyhedron(VGroup[VMobjectT]):
"""An abstract polyhedra class.

In this implementation, polyhedra are defined with a list of vertex coordinates in space, and a list
Expand Down
9 changes: 7 additions & 2 deletions manim/mobject/three_d/three_dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
from manim.mobject.mobject import *
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
from manim.mobject.opengl.opengl_mobject import OpenGLMobject
from manim.mobject.types.vectorized_mobject import VectorizedPoint, VGroup, VMobject
from manim.mobject.types.vectorized_mobject import (
VectorizedPoint,
VGroup,
VMobject,
VMobjectT,
)
from manim.utils.color import (
ManimColor,
ParsableManimColor,
Expand Down Expand Up @@ -458,7 +463,7 @@ def __init__(
self.set_color(color)


class Cube(VGroup):
class Cube(VGroup[VMobjectT]):
"""A three-dimensional cube.

Parameters
Expand Down
17 changes: 10 additions & 7 deletions manim/mobject/types/vectorized_mobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

import itertools as it
import sys
from collections.abc import Generator, Hashable, Iterable, Mapping, Sequence
from typing import TYPE_CHECKING, Callable, Literal
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal

import numpy as np
from PIL.Image import Image
from typing_extensions import TypeVar

from manim import config
from manim.constants import *
Expand Down Expand Up @@ -48,7 +48,7 @@
from manim.utils.space_ops import rotate_vector, shoelace_direction

if TYPE_CHECKING:
from typing import Any
from collections.abc import Generator, Hashable, Iterable, Mapping, Sequence

import numpy.typing as npt
from typing_extensions import Self
Expand Down Expand Up @@ -583,7 +583,7 @@ def get_fill_colors(self) -> list[ManimColor | None]:
def get_fill_opacities(self) -> npt.NDArray[ManimFloat]:
return self.get_fill_rgbas()[:, 3]

def get_stroke_rgbas(self, background: bool = False) -> RGBA_Array_float | Zeros:
def get_stroke_rgbas(self, background: bool = False) -> RGBA_Array_Float | Zeros:
try:
if background:
self.background_stroke_rgbas: RGBA_Array_Float
Expand Down Expand Up @@ -2059,7 +2059,10 @@ def force_direction(self, target_direction: Literal["CW", "CCW"]) -> Self:
return self


class VGroup(VMobject, metaclass=ConvertToOpenGL):
VMobjectT = TypeVar("VMobjectT", bound=VMobject, default=VMobject)
Fixed Show fixed Hide fixed


class VGroup(VMobject, Generic[VMobjectT], metaclass=ConvertToOpenGL):
"""A group of vectorized mobjects.

This can be used to group multiple :class:`~.VMobject` instances together
Expand Down Expand Up @@ -2119,7 +2122,7 @@ def construct(self):
"""

def __init__(
self, *vmobjects: VMobject | Iterable[VMobject], **kwargs: Any
self, *vmobjects: VMobjectT | Iterable[VMobjectT], **kwargs: Any
) -> None:
super().__init__(**kwargs)
self.add(*vmobjects)
Expand Down Expand Up @@ -2660,7 +2663,7 @@ def set_location(self, new_loc: Point3D):
self.set_points(np.array([new_loc]))


class CurvesAsSubmobjects(VGroup):
class CurvesAsSubmobjects(VGroup[VMobject]):
"""Convert a curve's elements to submobjects.

Examples
Expand Down
Loading
Loading