Skip to content

Commit

Permalink
Type some of pymbolic.mapper.optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Oct 6, 2024
1 parent feecb04 commit 5e05471
Showing 1 changed file with 30 additions and 12 deletions.
42 changes: 30 additions & 12 deletions pymbolic/mapper/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
"""

import ast
from collections.abc import Callable, Iterable, MutableMapping
from functools import cached_property, lru_cache
from typing import TextIO, TypeVar, cast


# This machinery applies AST rewriting to the mapper in a mildly brutal
Expand All @@ -39,7 +41,14 @@

# {{{ ast retrieval

def _get_def_from_ast_container(container, name, node_type):
AstDefNodeT = TypeVar("AstDefNodeT", ast.FunctionDef, ast.ClassDef)


def _get_def_from_ast_container(
container: Iterable[ast.AST],
name: str,
node_type: type[AstDefNodeT]
) -> AstDefNodeT:
for entry in container:
if isinstance(entry, node_type) and entry.name == name:
return entry
Expand All @@ -48,31 +57,31 @@ def _get_def_from_ast_container(container, name, node_type):


@lru_cache
def _get_ast_for_file(filename):
def _get_ast_for_file(filename: str) -> ast.Module:
with open(filename) as inf:
return ast.parse(inf.read(), filename)


def _get_file_name_for_module_name(module_name):
def _get_file_name_for_module_name(module_name: str) -> str | None:
from importlib import import_module
return import_module(module_name).__file__


def _get_ast_for_module_name(module_name):
def _get_ast_for_module_name(module_name: str) -> ast.Module:
return _get_ast_for_file(_get_file_name_for_module_name(module_name))


def _get_module_ast_for_object(obj):
return _get_ast_for_module_name(obj.__module__)


def _get_ast_for_class(cls):
def _get_ast_for_class(cls: type) -> ast.ClassDef:
mod_ast = _get_module_ast_for_object(cls)
return _get_def_from_ast_container(
mod_ast.body, cls.__name__, ast.ClassDef)


def _get_ast_for_method(f):
def _get_ast_for_method(f: Callable) -> ast.FunctionDef:
dot_components = f.__qualname__.split(".")
assert dot_components[-1] == f.__name__
cls_name, = dot_components[:-1]
Expand Down Expand Up @@ -238,22 +247,31 @@ def visit_Call(self, node): # noqa: N802
return result_expr


def _set_and_return(mapping, key, value):
KeyT = TypeVar("KeyT")
ValueT = TypeVar("ValueT")


def _set_and_return(
mapping: MutableMapping[KeyT, ValueT],
key: KeyT,
value: ValueT
) -> ValueT:
mapping[key] = value
return value


def optimize_mapper(
*, drop_args=False, drop_kwargs=False,
inline_rec=False, inline_cache=False, inline_get_cache_key=False,
print_modified_code_file=None):
*, drop_args: bool = False, drop_kwargs: bool = False,
inline_rec: bool = False, inline_cache: bool = False,
inline_get_cache_key: bool = False,
print_modified_code_file: TextIO | None = None) -> Callable[[type], type]:
"""
:param print_modified_code_file: a file-like object to which the modified
code will be printed, or ``None``.
"""
# This is a crime, an abomination. But a somewhat effective one.

def wrapper(cls):
def wrapper(cls: type) -> type:
try:
# Introduced in Py3.9
ast.unparse # noqa: B018
Expand Down Expand Up @@ -378,7 +396,7 @@ def wrapper(cls):
"exec"),
compile_dict)

return compile_dict[cls.__name__]
return cast(type, compile_dict[cls.__name__])

return wrapper

Expand Down

0 comments on commit 5e05471

Please sign in to comment.