From 5e054718ede67c6430df222abe2147e3eaa8f9d9 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Sun, 6 Oct 2024 12:50:23 -0500 Subject: [PATCH] Type some of pymbolic.mapper.optimize --- pymbolic/mapper/optimize.py | 42 ++++++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 12 deletions(-) diff --git a/pymbolic/mapper/optimize.py b/pymbolic/mapper/optimize.py index 4295e097..b7c848b1 100644 --- a/pymbolic/mapper/optimize.py +++ b/pymbolic/mapper/optimize.py @@ -24,7 +24,9 @@ """ import ast +from collections.abc import Callable, Iterable, MutableMapping from functools import cached_property, lru_cache +from typing import TextIO, TypeVar, cast # This machinery applies AST rewriting to the mapper in a mildly brutal @@ -39,7 +41,14 @@ # {{{ ast retrieval -def _get_def_from_ast_container(container, name, node_type): +AstDefNodeT = TypeVar("AstDefNodeT", ast.FunctionDef, ast.ClassDef) + + +def _get_def_from_ast_container( + container: Iterable[ast.AST], + name: str, + node_type: type[AstDefNodeT] + ) -> AstDefNodeT: for entry in container: if isinstance(entry, node_type) and entry.name == name: return entry @@ -48,17 +57,17 @@ def _get_def_from_ast_container(container, name, node_type): @lru_cache -def _get_ast_for_file(filename): +def _get_ast_for_file(filename: str) -> ast.Module: with open(filename) as inf: return ast.parse(inf.read(), filename) -def _get_file_name_for_module_name(module_name): +def _get_file_name_for_module_name(module_name: str) -> str | None: from importlib import import_module return import_module(module_name).__file__ -def _get_ast_for_module_name(module_name): +def _get_ast_for_module_name(module_name: str) -> ast.Module: return _get_ast_for_file(_get_file_name_for_module_name(module_name)) @@ -66,13 +75,13 @@ def _get_module_ast_for_object(obj): return _get_ast_for_module_name(obj.__module__) -def _get_ast_for_class(cls): +def _get_ast_for_class(cls: type) -> ast.ClassDef: mod_ast = _get_module_ast_for_object(cls) return _get_def_from_ast_container( mod_ast.body, cls.__name__, ast.ClassDef) -def _get_ast_for_method(f): +def _get_ast_for_method(f: Callable) -> ast.FunctionDef: dot_components = f.__qualname__.split(".") assert dot_components[-1] == f.__name__ cls_name, = dot_components[:-1] @@ -238,22 +247,31 @@ def visit_Call(self, node): # noqa: N802 return result_expr -def _set_and_return(mapping, key, value): +KeyT = TypeVar("KeyT") +ValueT = TypeVar("ValueT") + + +def _set_and_return( + mapping: MutableMapping[KeyT, ValueT], + key: KeyT, + value: ValueT + ) -> ValueT: mapping[key] = value return value def optimize_mapper( - *, drop_args=False, drop_kwargs=False, - inline_rec=False, inline_cache=False, inline_get_cache_key=False, - print_modified_code_file=None): + *, drop_args: bool = False, drop_kwargs: bool = False, + inline_rec: bool = False, inline_cache: bool = False, + inline_get_cache_key: bool = False, + print_modified_code_file: TextIO | None = None) -> Callable[[type], type]: """ :param print_modified_code_file: a file-like object to which the modified code will be printed, or ``None``. """ # This is a crime, an abomination. But a somewhat effective one. - def wrapper(cls): + def wrapper(cls: type) -> type: try: # Introduced in Py3.9 ast.unparse # noqa: B018 @@ -378,7 +396,7 @@ def wrapper(cls): "exec"), compile_dict) - return compile_dict[cls.__name__] + return cast(type, compile_dict[cls.__name__]) return wrapper