From 9b3bf7cc5bf691254b6638297c28845df2194052 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Manr=C3=ADquez?= Date: Wed, 13 Nov 2024 09:35:12 -0300 Subject: [PATCH 1/3] Add typings to tex_mobject.py and numbers.py --- manim/mobject/mobject.py | 33 +-- manim/mobject/svg/svg_mobject.py | 10 +- manim/mobject/text/numbers.py | 119 ++++++---- manim/mobject/text/tex_mobject.py | 264 ++++++++++++---------- manim/mobject/types/vectorized_mobject.py | 7 +- manim/mobject/value_tracker.py | 7 +- manim/utils/tex_file_writing.py | 2 +- 7 files changed, 255 insertions(+), 187 deletions(-) diff --git a/manim/mobject/mobject.py b/manim/mobject/mobject.py index 0359f66045..5c4e562324 100644 --- a/manim/mobject/mobject.py +++ b/manim/mobject/mobject.py @@ -14,10 +14,9 @@ import sys import types import warnings -from collections.abc import Iterable from functools import partialmethod, reduce from pathlib import Path -from typing import TYPE_CHECKING, Callable, Literal +from typing import TYPE_CHECKING import numpy as np @@ -40,13 +39,15 @@ from ..utils.space_ops import angle_between_vectors, normalize, rotation_matrix if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + from typing import Callable, Literal + from typing_extensions import Self, TypeAlias from manim.typing import ( FunctionOverride, InternalPoint3D, ManimFloat, - ManimInt, MappingFunction, PathFuncType, PixelArray, @@ -100,18 +101,18 @@ def __init__( color: ParsableManimColor | list[ParsableManimColor] = WHITE, name: str | None = None, dim: int = 3, - target=None, + target: Mobject | None = None, z_index: float = 0, ) -> None: - self.name = self.__class__.__name__ if name is None else name - self.dim = dim - self.target = target - self.z_index = z_index + self.name: str = self.__class__.__name__ if name is None else name + self.dim: int = dim + self.target: Mobject | None = target + self.z_index: float = z_index self.point_hash = None - self.submobjects = [] + self.submobjects: Sequence[Mobject] = [] self.updaters: list[Updater] = [] self.updating_suspended = False - self.color = ManimColor.parse(color) + self.color: ManimColor | list[ManimColor] = ManimColor.parse(color) self.reset_points() self.generate_points() @@ -2291,16 +2292,16 @@ def get_mobject_type_class() -> type[Mobject]: """Return the base class of this mobject type.""" return Mobject - def split(self) -> list[Self]: + def split(self) -> Sequence[Self]: result = [self] if len(self.points) > 0 else [] return result + self.submobjects - def get_family(self, recurse: bool = True) -> list[Self]: + def get_family(self, recurse: bool = True) -> Sequence[Self]: sub_families = [x.get_family() for x in self.submobjects] all_mobjects = [self] + list(it.chain(*sub_families)) return remove_list_redundancies(all_mobjects) - def family_members_with_points(self) -> list[Self]: + def family_members_with_points(self) -> Sequence[Self]: return [m for m in self.get_family() if m.get_num_points() > 0] def arrange( @@ -2576,13 +2577,13 @@ def init_sizes(sizes, num, measures, name): def sort( self, - point_to_num_func: Callable[[Point3D], ManimInt] = lambda p: p[0], - submob_func: Callable[[Mobject], ManimInt] | None = None, + point_to_num_func: Callable[[Point3D], float] = lambda p: p[0], + submob_func: Callable[[Mobject], float] | None = None, ) -> Self: """Sorts the list of :attr:`submobjects` by a function defined by ``submob_func``.""" if submob_func is None: - def submob_func(m: Mobject): + def submob_func(m: Mobject) -> float: return point_to_num_func(m.get_center()) self.submobjects.sort(key=submob_func) diff --git a/manim/mobject/svg/svg_mobject.py b/manim/mobject/svg/svg_mobject.py index 82c121fce7..914c310c47 100644 --- a/manim/mobject/svg/svg_mobject.py +++ b/manim/mobject/svg/svg_mobject.py @@ -4,6 +4,7 @@ import os from pathlib import Path +from typing import TYPE_CHECKING from xml.etree import ElementTree as ET import numpy as np @@ -21,6 +22,9 @@ from ..opengl.opengl_compatibility import ConvertToOpenGL from ..types.vectorized_mobject import VMobject +if TYPE_CHECKING: + from manim.utils.color import ParsableManimColor + __all__ = ["SVGMobject", "VMobjectFromSVGPath"] @@ -98,11 +102,11 @@ def __init__( should_center: bool = True, height: float | None = 2, width: float | None = None, - color: str | None = None, + color: ParsableManimColor | None = None, opacity: float | None = None, - fill_color: str | None = None, + fill_color: ParsableManimColor | None = None, fill_opacity: float | None = None, - stroke_color: str | None = None, + stroke_color: ParsableManimColor | None = None, stroke_opacity: float | None = None, stroke_width: float | None = None, svg_default: dict | None = None, diff --git a/manim/mobject/text/numbers.py b/manim/mobject/text/numbers.py index 5283c24a20..0c3bb0a02f 100644 --- a/manim/mobject/text/numbers.py +++ b/manim/mobject/text/numbers.py @@ -4,19 +4,30 @@ __all__ = ["DecimalNumber", "Integer", "Variable"] -from collections.abc import Sequence +from typing import TYPE_CHECKING import numpy as np -from manim import config +from manim._config import config from manim.constants import * from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL -from manim.mobject.text.tex_mobject import MathTex, SingleStringMathTex, Tex -from manim.mobject.text.text_mobject import Text +from manim.mobject.text.tex_mobject import MathTex, SingleStringMathTex from manim.mobject.types.vectorized_mobject import VMobject from manim.mobject.value_tracker import ValueTracker -string_to_mob_map = {} +if TYPE_CHECKING: + from typing import Any, Union + + from typing_extensions import Self, TypeAlias + + from manim.mobject.text.tex_mobject import Tex + from manim.mobject.text.text_mobject import MarkupText, Text + from manim.typing import Vector3D + + TextLike: TypeAlias = Union[SingleStringMathTex, MathTex, Tex, Text, MarkupText] + + +string_to_mob_map: dict[str, TextLike] = {} __all__ = ["DecimalNumber", "Integer", "Variable"] @@ -83,9 +94,9 @@ def construct(self): def __init__( self, - number: float = 0, + number: float | complex = 0, num_decimal_places: int = 2, - mob_class: VMobject = MathTex, + mob_class: type[TextLike] = MathTex, include_sign: bool = False, group_with_commas: bool = True, digit_buff_per_font_unit: float = 0.001, @@ -93,28 +104,28 @@ def __init__( unit: str | None = None, # Aligned to bottom unless it starts with "^" unit_buff_per_font_unit: float = 0, include_background_rectangle: bool = False, - edge_to_fix: Sequence[float] = LEFT, + edge_to_fix: Vector3D = LEFT, font_size: float = DEFAULT_FONT_SIZE, stroke_width: float = 0, fill_opacity: float = 1.0, - **kwargs, - ): + **kwargs: Any, + ) -> None: super().__init__(**kwargs, stroke_width=stroke_width) - self.number = number - self.num_decimal_places = num_decimal_places - self.include_sign = include_sign - self.mob_class = mob_class - self.group_with_commas = group_with_commas - self.digit_buff_per_font_unit = digit_buff_per_font_unit - self.show_ellipsis = show_ellipsis - self.unit = unit - self.unit_buff_per_font_unit = unit_buff_per_font_unit - self.include_background_rectangle = include_background_rectangle - self.edge_to_fix = edge_to_fix - self._font_size = font_size - self.fill_opacity = fill_opacity - - self.initial_config = kwargs.copy() + self.number: float | complex = number + self.num_decimal_places: int = num_decimal_places + self.include_sign: bool = include_sign + self.mob_class: type[TextLike] = mob_class + self.group_with_commas: bool = group_with_commas + self.digit_buff_per_font_unit: float = digit_buff_per_font_unit + self.show_ellipsis: bool = show_ellipsis + self.unit: str | None = unit + self.unit_buff_per_font_unit: float = unit_buff_per_font_unit + self.include_background_rectangle: bool = include_background_rectangle + self.edge_to_fix: Vector3D = edge_to_fix + self._font_size: float = font_size + self.fill_opacity: float = fill_opacity + + self.initial_config: dict[str, Any] = kwargs.copy() self.initial_config.update( { "num_decimal_places": num_decimal_places, @@ -136,12 +147,12 @@ def __init__( self.init_colors() @property - def font_size(self): + def font_size(self) -> float: """The font size of the tex mobject.""" return self.height / self.initial_height * self._font_size @font_size.setter - def font_size(self, font_val): + def font_size(self, font_val: float) -> None: if font_val <= 0: raise ValueError("font_size must be greater than 0.") elif self.height > 0: @@ -152,7 +163,7 @@ def font_size(self, font_val): # font_size does not depend on current size. self.scale(font_val / self.font_size) - def _set_submobjects_from_number(self, number): + def _set_submobjects_from_number(self, number: float | complex) -> None: self.number = number self.submobjects = [] @@ -161,8 +172,9 @@ def _set_submobjects_from_number(self, number): # Add non-numerical bits if self.show_ellipsis: + # TODO: Why MyPy 'cannot determine type of "color"'? self.add( - self._string_to_mob("\\dots", SingleStringMathTex, color=self.color), + self._string_to_mob(r"\dots", SingleStringMathTex, color=self.color), # type: ignore [has-type] ) self.arrange( @@ -196,12 +208,12 @@ def _set_submobjects_from_number(self, number): self.unit_sign.align_to(self, UP) # track the initial height to enable scaling via font_size - self.initial_height = self.height + self.initial_height: float = self.height if self.include_background_rectangle: self.add_background_rectangle() - def _get_num_string(self, number): + def _get_num_string(self, number: float | complex) -> str: if isinstance(number, complex): formatter = self._get_complex_formatter() else: @@ -214,17 +226,21 @@ def _get_num_string(self, number): return num_string - def _string_to_mob(self, string: str, mob_class: VMobject | None = None, **kwargs): + def _string_to_mob( + self, string: str, mob_class: type[TextLike] | None = None, **kwargs: Any + ) -> TextLike: if mob_class is None: mob_class = self.mob_class + _mob_class = self.mob_class if mob_class is None else mob_class + if string not in string_to_mob_map: - string_to_mob_map[string] = mob_class(string, **kwargs) + string_to_mob_map[string] = _mob_class(string, **kwargs) mob = string_to_mob_map[string].copy() mob.font_size = self._font_size return mob - def _get_formatter(self, **kwargs): + def _get_formatter(self, **kwargs: Any) -> str: """ Configuration is based first off instance attributes, but overwritten by any kew word argument. Relevant @@ -257,16 +273,16 @@ def _get_formatter(self, **kwargs): ], ) - def _get_complex_formatter(self): + def _get_complex_formatter(self, **kwargs: Any) -> str: return "".join( [ - self._get_formatter(field_name="0.real"), - self._get_formatter(field_name="0.imag", include_sign=True), + self._get_formatter(field_name="0.real", **kwargs), + self._get_formatter(field_name="0.imag", **kwargs, include_sign=True), "i", ], ) - def set_value(self, number: float): + def set_value(self, number: float | complex) -> Self: """Set the value of the :class:`~.DecimalNumber` to a new number. Parameters @@ -303,11 +319,12 @@ def set_value(self, number: float): self.init_colors() return self - def get_value(self): + def get_value(self) -> float | complex: return self.number - def increment_value(self, delta_t=1): + def increment_value(self, delta_t: float | complex = 1.0) -> Self: self.set_value(self.get_value() + delta_t) + return self class Integer(DecimalNumber): @@ -327,10 +344,15 @@ def construct(self): self.add(Integer(number=6.28).set_x(-1.5).set_y(-2).set_color(YELLOW).scale(1.4)) """ - def __init__(self, number=0, num_decimal_places=0, **kwargs): + def __init__( + self, + number: float | complex = 0, + num_decimal_places: int = 0, + **kwargs: Any, + ) -> None: super().__init__(number=number, num_decimal_places=num_decimal_places, **kwargs) - def get_value(self): + def get_value(self) -> int: return int(np.round(super().get_value())) @@ -441,16 +463,19 @@ def __init__( self, var: float, label: str | Tex | MathTex | Text | SingleStringMathTex, - var_type: DecimalNumber | Integer = DecimalNumber, + var_type: type[DecimalNumber | Integer] = DecimalNumber, num_decimal_places: int = 2, - **kwargs, - ): - self.label = MathTex(label) if isinstance(label, str) else label + **kwargs: Any, + ) -> None: + self.label: Tex | MathTex | Text | SingleStringMathTex = ( + MathTex(label) if isinstance(label, str) else label + ) equals = MathTex("=").next_to(self.label, RIGHT) self.label.add(equals) - self.tracker = ValueTracker(var) + self.tracker: ValueTracker = ValueTracker(var) + self.value: DecimalNumber | Integer if var_type == DecimalNumber: self.value = DecimalNumber( self.tracker.get_value(), diff --git a/manim/mobject/text/tex_mobject.py b/manim/mobject/text/tex_mobject.py index 26334a60d9..1c4fa7add2 100644 --- a/manim/mobject/text/tex_mobject.py +++ b/manim/mobject/text/tex_mobject.py @@ -12,8 +12,6 @@ from __future__ import annotations -from manim.utils.color import BLACK, ManimColor, ParsableManimColor - __all__ = [ "SingleStringMathTex", "MathTex", @@ -23,22 +21,27 @@ ] -import itertools as it import operator as op import re -from collections.abc import Iterable from functools import reduce from textwrap import dedent +from typing import TYPE_CHECKING from manim import config, logger from manim.constants import * from manim.mobject.geometry.line import Line from manim.mobject.svg.svg_mobject import SVGMobject from manim.mobject.types.vectorized_mobject import VGroup, VMobject -from manim.utils.tex import TexTemplate +from manim.utils.color import BLACK from manim.utils.tex_file_writing import tex_to_svg_file -tex_string_to_mob_map = {} +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from typing_extensions import Any, Self + + from manim.utils.color import ParsableManimColor + from manim.utils.tex import TexTemplate class SingleStringMathTex(SVGMobject): @@ -63,27 +66,28 @@ def __init__( tex_template: TexTemplate | None = None, font_size: float = DEFAULT_FONT_SIZE, color: ParsableManimColor | None = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: if color is None: - color = VMobject().color + # TODO: Why MyPy 'cannot determine type of "color"'? + color = VMobject().color # type: ignore [has-type] - self._font_size = font_size - self.organize_left_to_right = organize_left_to_right - self.tex_environment = tex_environment + self._font_size: float = font_size + self.organize_left_to_right: bool = organize_left_to_right + self.tex_environment: str = tex_environment if tex_template is None: tex_template = config["tex_template"] - self.tex_template = tex_template + self.tex_template: TexTemplate = tex_template assert isinstance(tex_string, str) - self.tex_string = tex_string - file_name = tex_to_svg_file( + self.tex_string: str = tex_string + file_path = tex_to_svg_file( self._get_modified_expression(tex_string), environment=self.tex_environment, tex_template=self.tex_template, ) super().__init__( - file_name=file_name, + file_name=file_path, should_center=should_center, stroke_width=stroke_width, height=height, @@ -97,24 +101,24 @@ def __init__( self.init_colors() # used for scaling via font_size.setter - self.initial_height = self.height + self.initial_height: float = self.height if height is None: - self.font_size = self._font_size + self.font_size: float = self._font_size if self.organize_left_to_right: self._organize_submobjects_left_to_right() - def __repr__(self): + def __repr__(self) -> str: return f"{type(self).__name__}({repr(self.tex_string)})" @property - def font_size(self): + def font_size(self) -> float: """The font size of the tex mobject.""" return self.height / self.initial_height / SCALE_FACTOR_PER_FONT_POINT @font_size.setter - def font_size(self, font_val): + def font_size(self, font_val: float) -> None: if font_val <= 0: raise ValueError("font_size must be greater than 0.") elif self.height > 0: @@ -125,14 +129,14 @@ def font_size(self, font_val): # font_size does not depend on current size. self.scale(font_val / self.font_size) - def _get_modified_expression(self, tex_string): - result = tex_string - result = result.strip() - result = self._modify_special_strings(result) - return result - - def _modify_special_strings(self, tex): + def _get_modified_expression(self, tex_string: str) -> str: + tex = tex_string tex = tex.strip() + tex = self._modify_special_strings(tex) + return tex + + def _modify_special_strings(self, tex_string: str) -> str: + tex = tex_string.strip() should_add_filler = reduce( op.or_, [ @@ -184,13 +188,13 @@ def _modify_special_strings(self, tex): tex = "" return tex - def _remove_stray_braces(self, tex): - r""" - Makes :class:`~.MathTex` resilient to unmatched braces. + def _remove_stray_braces(self, tex_string: str) -> str: + r"""Makes :class:`~.MathTex` resilient to unmatched braces. This is important when the braces in the TeX code are spread over multiple arguments as in, e.g., ``MathTex(r"e^{i", r"\tau} = 1")``. """ + tex = tex_string # "\{" does not count (it's a brace literal), but "\\{" counts (it's a new line and then brace) num_lefts = tex.count("{") - tex.count("\\{") + tex.count("\\\\{") num_rights = tex.count("}") - tex.count("\\}") + tex.count("\\\\}") @@ -202,24 +206,26 @@ def _remove_stray_braces(self, tex): num_rights += 1 return tex - def _organize_submobjects_left_to_right(self): + def _organize_submobjects_left_to_right(self) -> Self: self.sort(lambda p: p[0]) return self - def get_tex_string(self): + def get_tex_string(self) -> str: return self.tex_string - def init_colors(self, propagate_colors=True): + def init_colors(self, propagate_colors: bool = True) -> Self: for submobject in self.submobjects: # needed to preserve original (non-black) # TeX colors of individual submobjects - if submobject.color != BLACK: + # TODO: Why MyPy 'cannot determine type of "color"'? + if submobject.color != BLACK: # type: ignore [has-type] continue - submobject.color = self.color + submobject.color = self.color # type: ignore [has-type] if config.renderer == RendererType.OPENGL: submobject.init_colors() elif config.renderer == RendererType.CAIRO: submobject.init_colors(propagate_colors=propagate_colors) + return self class MathTex(SingleStringMathTex): @@ -255,24 +261,28 @@ def construct(self): def __init__( self, - *tex_strings, + *tex_strings: str, arg_separator: str = " ", substrings_to_isolate: Iterable[str] | None = None, - tex_to_color_map: dict[str, ManimColor] = None, + tex_to_color_map: dict[str | Iterable[str], ParsableManimColor] | None = None, tex_environment: str = "align*", - **kwargs, - ): - self.tex_template = kwargs.pop("tex_template", config["tex_template"]) - self.arg_separator = arg_separator - self.substrings_to_isolate = ( - [] if substrings_to_isolate is None else substrings_to_isolate + **kwargs: Any, + ) -> None: + self.tex_template: TexTemplate = kwargs.pop( + "tex_template", config["tex_template"] + ) + self.arg_separator: str = arg_separator + self.substrings_to_isolate: list[str] = ( + [] if substrings_to_isolate is None else list(substrings_to_isolate) ) - self.tex_to_color_map = tex_to_color_map - if self.tex_to_color_map is None: - self.tex_to_color_map = {} - self.tex_environment = tex_environment - self.brace_notation_split_occurred = False - self.tex_strings = self._break_up_tex_strings(tex_strings) + if tex_to_color_map is None: + tex_to_color_map = {} + self.tex_to_color_map: dict[str | Iterable[str], ParsableManimColor] = ( + tex_to_color_map + ) + self.tex_environment: str = tex_environment + self.brace_notation_split_occurred: bool = False + self.tex_strings: list[str] = self._break_up_tex_strings(tex_strings) try: super().__init__( self.arg_separator.join(self.tex_strings), @@ -301,42 +311,45 @@ def __init__( if self.organize_left_to_right: self._organize_submobjects_left_to_right() - def _break_up_tex_strings(self, tex_strings): + def _break_up_tex_strings(self, tex_strings: Sequence[str]) -> list[str]: # Separate out anything surrounded in double braces pre_split_length = len(tex_strings) - tex_strings = [re.split("{{(.*?)}}", str(t)) for t in tex_strings] - tex_strings = sum(tex_strings, []) - if len(tex_strings) > pre_split_length: + # TODO: do we need that str(t)? + pre_pieces_arr = [re.split("{{(.*?)}}", str(t)) for t in tex_strings] + pre_pieces = sum(pre_pieces_arr, []) + if len(pre_pieces) > pre_split_length: self.brace_notation_split_occurred = True # Separate out any strings specified in the isolate # or tex_to_color_map lists. - patterns = [] - patterns.extend( - [ - f"({re.escape(ss)})" - for ss in it.chain( - self.substrings_to_isolate, - self.tex_to_color_map.keys(), - ) - ], - ) + patterns: list[str] = self.substrings_to_isolate.copy() + for key in self.tex_to_color_map: + try: + # If the given key behaves like tex_strings + key + "" # type: ignore [operator] + patterns.append(key) # type: ignore [arg-type] + except TypeError: + # If the given key is a tuple + patterns.extend(key) + + patterns = [f"({re.escape(pattern)})" for pattern in patterns] pattern = "|".join(patterns) + + pieces: list[str] if pattern: pieces = [] - for s in tex_strings: - pieces.extend(re.split(pattern, s)) + for p in pre_pieces: + pieces.extend(re.split(pattern, p)) else: - pieces = tex_strings + pieces = pre_pieces return [p for p in pieces if p] - def _break_up_by_substrings(self): - """ - Reorganize existing submobjects one layer + def _break_up_by_substrings(self) -> Self: + """Reorganize existing submobjects one layer deeper based on the structure of tex_strings (as a list of tex_strings) """ - new_submobjects = [] + new_submobjects: list[VMobject] = [] curr_index = 0 for tex_string in self.tex_strings: sub_tex_mob = SingleStringMathTex( @@ -358,8 +371,10 @@ def _break_up_by_substrings(self): self.submobjects = new_submobjects return self - def get_parts_by_tex(self, tex, substring=True, case_sensitive=True): - def test(tex1, tex2): + def get_parts_by_tex( + self, tex_string: str, substring: bool = True, case_sensitive: bool = True + ) -> VGroup: + def test(tex1: str, tex2: str) -> bool: if not case_sensitive: tex1 = tex1.lower() tex2 = tex2.lower() @@ -368,21 +383,31 @@ def test(tex1, tex2): else: return tex1 == tex2 - return VGroup(*(m for m in self.submobjects if test(tex, m.get_tex_string()))) + return VGroup( + *(m for m in self.submobjects if test(tex_string, m.get_tex_string())) + ) - def get_part_by_tex(self, tex, **kwargs): - all_parts = self.get_parts_by_tex(tex, **kwargs) + def get_part_by_tex( + self, tex_string: str, **kwargs: Any + ) -> SingleStringMathTex | None: + all_parts = self.get_parts_by_tex(tex_string, **kwargs) return all_parts[0] if all_parts else None - def set_color_by_tex(self, tex, color, **kwargs): - parts_to_color = self.get_parts_by_tex(tex, **kwargs) + def set_color_by_tex( + self, tex_string: str, color: ParsableManimColor, **kwargs: Any + ) -> Self: + parts_to_color = self.get_parts_by_tex(tex_string, **kwargs) for part in parts_to_color: part.set_color(color) return self def set_opacity_by_tex( - self, tex: str, opacity: float = 0.5, remaining_opacity: float = None, **kwargs - ): + self, + tex_string: str, + opacity: float = 0.5, + remaining_opacity: float | None = None, + **kwargs: Any, + ) -> Self: """ Sets the opacity of the tex specified. If 'remaining_opacity' is specified, then the remaining tex will be set to that opacity. @@ -399,34 +424,39 @@ def set_opacity_by_tex( """ if remaining_opacity is not None: self.set_opacity(opacity=remaining_opacity) - for part in self.get_parts_by_tex(tex): + for part in self.get_parts_by_tex(tex_string): part.set_opacity(opacity) return self - def set_color_by_tex_to_color_map(self, texs_to_color_map, **kwargs): - for texs, color in list(texs_to_color_map.items()): + def set_color_by_tex_to_color_map( + self, + texes_to_color_map: dict[str | Iterable[str], ParsableManimColor], + **kwargs: Any, + ) -> Self: + for texes, color in list(texes_to_color_map.items()): try: # If the given key behaves like tex_strings - texs + "" - self.set_color_by_tex(texs, color, **kwargs) + texes + "" # type: ignore [operator] + self.set_color_by_tex(texes, color, **kwargs) # type: ignore [arg-type] except TypeError: # If the given key is a tuple - for tex in texs: + for tex in texes: self.set_color_by_tex(tex, color, **kwargs) return self - def index_of_part(self, part): + def index_of_part(self, part: MathTex | SingleStringMathTex | None) -> int: split_self = self.split() if part not in split_self: raise ValueError("Trying to get index of part not in MathTex") return split_self.index(part) - def index_of_part_by_tex(self, tex, **kwargs): - part = self.get_part_by_tex(tex, **kwargs) + def index_of_part_by_tex(self, tex_string: str, **kwargs: Any) -> int: + part = self.get_part_by_tex(tex_string, **kwargs) return self.index_of_part(part) - def sort_alphabetically(self): + def sort_alphabetically(self) -> Self: self.submobjects.sort(key=lambda m: m.get_tex_string()) + return self class Tex(MathTex): @@ -447,8 +477,12 @@ class Tex(MathTex): """ def __init__( - self, *tex_strings, arg_separator="", tex_environment="center", **kwargs - ): + self, + *tex_strings: Any, + arg_separator: str = "", + tex_environment: str = "center", + **kwargs: Any, + ) -> None: super().__init__( *tex_strings, arg_separator=arg_separator, @@ -477,27 +511,24 @@ def construct(self): def __init__( self, - *items, - buff=MED_LARGE_BUFF, - dot_scale_factor=2, - tex_environment=None, - **kwargs, - ): - self.buff = buff - self.dot_scale_factor = dot_scale_factor - self.tex_environment = tex_environment - line_separated_items = [s + "\\\\" for s in items] - super().__init__( - *line_separated_items, tex_environment=tex_environment, **kwargs - ) + *items: str, + buff: float = MED_LARGE_BUFF, + dot_scale_factor: float = 2, + **kwargs: Any, + ) -> None: + self.buff: float = buff + self.dot_scale_factor: float = dot_scale_factor + line_separated_items: list[str] = [s + r"\\" for s in items] + super().__init__(*line_separated_items, **kwargs) for part in self: - dot = MathTex("\\cdot").scale(self.dot_scale_factor) + dot = MathTex(r"\cdot").scale(self.dot_scale_factor) dot.next_to(part[0], LEFT, SMALL_BUFF) part.add_to_back(dot) self.arrange(DOWN, aligned_edge=LEFT, buff=self.buff) - def fade_all_but(self, index_or_string, opacity=0.5): + def fade_all_but(self, index_or_string: int | str, opacity: float = 0.5) -> Self: arg = index_or_string + part: VMobject | None if isinstance(arg, str): part = self.get_part_by_tex(arg) elif isinstance(arg, int): @@ -509,6 +540,7 @@ def fade_all_but(self, index_or_string, opacity=0.5): other_part.set_fill(opacity=1) else: other_part.set_fill(opacity=opacity) + return self class Title(Tex): @@ -531,15 +563,15 @@ def construct(self): def __init__( self, - *text_parts, - include_underline=True, - match_underline_width_to_text=False, - underline_buff=MED_SMALL_BUFF, - **kwargs, - ): - self.include_underline = include_underline - self.match_underline_width_to_text = match_underline_width_to_text - self.underline_buff = underline_buff + *text_parts: str, + include_underline: bool = True, + match_underline_width_to_text: bool = False, + underline_buff: float = MED_SMALL_BUFF, + **kwargs: Any, + ) -> None: + self.include_underline: bool = include_underline + self.match_underline_width_to_text: bool = match_underline_width_to_text + self.underline_buff: float = underline_buff super().__init__(*text_parts, **kwargs) self.to_edge(UP) if self.include_underline: diff --git a/manim/mobject/types/vectorized_mobject.py b/manim/mobject/types/vectorized_mobject.py index a8d32682fd..128bb286cd 100644 --- a/manim/mobject/types/vectorized_mobject.py +++ b/manim/mobject/types/vectorized_mobject.py @@ -14,8 +14,8 @@ import itertools as it import sys -from collections.abc import Generator, Hashable, Iterable, Mapping, Sequence -from typing import TYPE_CHECKING, Callable, Literal +from collections.abc import Iterable +from typing import TYPE_CHECKING import numpy as np from PIL.Image import Image @@ -48,7 +48,8 @@ from manim.utils.space_ops import rotate_vector, shoelace_direction if TYPE_CHECKING: - from typing import Any + from collections.abc import Generator, Hashable, Mapping, Sequence + from typing import Any, Callable, Literal import numpy.typing as npt from typing_extensions import Self diff --git a/manim/mobject/value_tracker.py b/manim/mobject/value_tracker.py index 9d81035e89..aa8e595ec1 100644 --- a/manim/mobject/value_tracker.py +++ b/manim/mobject/value_tracker.py @@ -5,12 +5,17 @@ __all__ = ["ValueTracker", "ComplexValueTracker"] +from typing import TYPE_CHECKING + import numpy as np from manim.mobject.mobject import Mobject from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL from manim.utils.paths import straight_path +if TYPE_CHECKING: + from typing_extensions import Any + class ValueTracker(Mobject, metaclass=ConvertToOpenGL): """A mobject that can be used for tracking (real-valued) parameters. @@ -69,7 +74,7 @@ def construct(self): """ - def __init__(self, value=0, **kwargs): + def __init__(self, value: float = 0, **kwargs: Any) -> None: super().__init__(**kwargs) self.set(points=np.zeros((1, 3))) self.set_value(value) diff --git a/manim/utils/tex_file_writing.py b/manim/utils/tex_file_writing.py index 45e84d4907..8f0202192e 100644 --- a/manim/utils/tex_file_writing.py +++ b/manim/utils/tex_file_writing.py @@ -34,7 +34,7 @@ def tex_to_svg_file( expression: str, environment: str | None = None, tex_template: TexTemplate | None = None, -): +) -> Path: r"""Takes a tex expression and returns the svg version of the compiled tex Parameters From e96beca5e47b058c8c8738077f1c036d5a6ce270 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Manr=C3=ADquez?= Date: Wed, 13 Nov 2024 09:47:08 -0300 Subject: [PATCH 2/3] Added missing changes to mypy.ini --- mypy.ini | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mypy.ini b/mypy.ini index 65a77f1d00..59344fed1b 100644 --- a/mypy.ini +++ b/mypy.ini @@ -73,6 +73,12 @@ ignore_errors = True [mypy-manim.mobject.geometry.*] ignore_errors = False +[mypy-manim.mobject.text.numbers] +ignore_errors = False + +[mypy-manim.mobject.text.tex_mobject] +ignore_errors = False + [mypy-manim.plugins.*] ignore_errors = True From 139a3253a897a9df3c4861bef328051e5a699f4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francisco=20Manr=C3=ADquez?= Date: Tue, 26 Nov 2024 23:44:15 -0300 Subject: [PATCH 3/3] Address requested changes --- manim/mobject/text/numbers.py | 6 ++--- manim/mobject/text/tex_mobject.py | 40 +++++++++++++++---------------- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/manim/mobject/text/numbers.py b/manim/mobject/text/numbers.py index 0c3bb0a02f..aec81a3f6b 100644 --- a/manim/mobject/text/numbers.py +++ b/manim/mobject/text/numbers.py @@ -232,10 +232,8 @@ def _string_to_mob( if mob_class is None: mob_class = self.mob_class - _mob_class = self.mob_class if mob_class is None else mob_class - if string not in string_to_mob_map: - string_to_mob_map[string] = _mob_class(string, **kwargs) + string_to_mob_map[string] = mob_class(string, **kwargs) mob = string_to_mob_map[string].copy() mob.font_size = self._font_size return mob @@ -346,7 +344,7 @@ def construct(self): def __init__( self, - number: float | complex = 0, + number: float = 0, num_decimal_places: int = 0, **kwargs: Any, ) -> None: diff --git a/manim/mobject/text/tex_mobject.py b/manim/mobject/text/tex_mobject.py index 1c4fa7add2..94701b2953 100644 --- a/manim/mobject/text/tex_mobject.py +++ b/manim/mobject/text/tex_mobject.py @@ -104,7 +104,7 @@ def __init__( self.initial_height: float = self.height if height is None: - self.font_size: float = self._font_size + self.font_size = self._font_size if self.organize_left_to_right: self._organize_submobjects_left_to_right() @@ -323,14 +323,11 @@ def _break_up_tex_strings(self, tex_strings: Sequence[str]) -> list[str]: # Separate out any strings specified in the isolate # or tex_to_color_map lists. patterns: list[str] = self.substrings_to_isolate.copy() - for key in self.tex_to_color_map: - try: - # If the given key behaves like tex_strings - key + "" # type: ignore [operator] - patterns.append(key) # type: ignore [arg-type] - except TypeError: - # If the given key is a tuple - patterns.extend(key) + for tex_or_texes in self.tex_to_color_map: + if isinstance(tex_or_texes, str): + patterns.append(tex_or_texes) + else: + patterns += tex_or_texes patterns = [f"({re.escape(pattern)})" for pattern in patterns] pattern = "|".join(patterns) @@ -374,7 +371,7 @@ def _break_up_by_substrings(self) -> Self: def get_parts_by_tex( self, tex_string: str, substring: bool = True, case_sensitive: bool = True ) -> VGroup: - def test(tex1: str, tex2: str) -> bool: + def compare_tex(tex1: str, tex2: str) -> bool: if not case_sensitive: tex1 = tex1.lower() tex2 = tex2.lower() @@ -384,7 +381,11 @@ def test(tex1: str, tex2: str) -> bool: return tex1 == tex2 return VGroup( - *(m for m in self.submobjects if test(tex_string, m.get_tex_string())) + *( + m + for m in self.submobjects + if compare_tex(tex_string, m.get_tex_string()) + ) ) def get_part_by_tex( @@ -430,17 +431,14 @@ def set_opacity_by_tex( def set_color_by_tex_to_color_map( self, - texes_to_color_map: dict[str | Iterable[str], ParsableManimColor], + tex_to_color_map: dict[str | Iterable[str], ParsableManimColor], **kwargs: Any, ) -> Self: - for texes, color in list(texes_to_color_map.items()): - try: - # If the given key behaves like tex_strings - texes + "" # type: ignore [operator] - self.set_color_by_tex(texes, color, **kwargs) # type: ignore [arg-type] - except TypeError: - # If the given key is a tuple - for tex in texes: + for tex_or_texes, color in tex_to_color_map.items(): + if isinstance(tex_or_texes, str): + self.set_color_by_tex(tex_or_texes, color, **kwargs) + else: + for tex in tex_or_texes: self.set_color_by_tex(tex, color, **kwargs) return self @@ -478,7 +476,7 @@ class Tex(MathTex): def __init__( self, - *tex_strings: Any, + *tex_strings: str, arg_separator: str = "", tex_environment: str = "center", **kwargs: Any,