Skip to content

Commit

Permalink
Trampoline for generator functions
Browse files Browse the repository at this point in the history
  • Loading branch information
boxed committed Oct 22, 2024
1 parent b4b553d commit 056e72c
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 5 deletions.
32 changes: 29 additions & 3 deletions mutmut/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def _mutmut_trampoline(orig, mutants, *args, **kwargs):
return mutants[mutant_name](*args, **kwargs)
"""
yield_from_trampoline_impl = trampoline_impl.replace('return ', 'yield from ').replace('_mutmut_trampoline', '_mutmut_yield_from_trampoline')


def create_mutants():
Expand Down Expand Up @@ -297,7 +298,7 @@ def write_all_mutants_to_file(*, out, source, filename):
return mutant_names, hash_by_function_name


def build_trampoline(orig_name, mutants, class_name=None):
def build_trampoline(*, orig_name, mutants, class_name, is_generator):
assert orig_name not in NEVER_MUTATE_FUNCTION_NAMES

mangled_name = mangle_function_name(name=orig_name, class_name=class_name)
Expand All @@ -309,11 +310,18 @@ def build_trampoline(orig_name, mutants, class_name=None):
access_prefix = f'object.__getattribute__(self, "'
access_suffix = '")'

if is_generator:
return_or_yield_statement = 'yield from'
trampoline_name = '_mutmut_yield_from_trampoline'
else:
return_or_yield_statement = 'return'
trampoline_name = '_mutmut_trampoline'

return f"""
{mutants_dict}
def {orig_name}({'self, ' if class_name is not None else ''}*args, **kwargs):
return _mutmut_trampoline({access_prefix}{mangled_name}__mutmut_orig{access_suffix}, {access_prefix}{mangled_name}__mutmut_mutants{access_suffix}, *args, **kwargs)
{return_or_yield_statement} {trampoline_name}({access_prefix}{mangled_name}__mutmut_orig{access_suffix}, {access_prefix}{mangled_name}__mutmut_mutants{access_suffix}, *args, **kwargs)
{orig_name}.__signature__ = _mutmut_signature({mangled_name}__mutmut_orig)
{mangled_name}__mutmut_orig.__name__ = '{mangled_name}'
Expand Down Expand Up @@ -447,6 +455,23 @@ def is_inside_dict_synonym_call(self):
return False


def is_generator(node):
assert node.type == 'funcdef'

def _is_generator(n):
if n is not node and n.type in ('funcdef', 'classdef'):
return False

if n.type == 'keyword' and n.value == 'yield':
return True

for c in getattr(n, 'children', []):
if _is_generator(c):
return True
return False
return _is_generator(node)


def yield_mutants_for_function(node, *, class_name=None, no_mutate_lines):
assert node.type == 'funcdef'

Expand Down Expand Up @@ -481,7 +506,7 @@ def yield_mutants_for_function(node, *, class_name=None, no_mutate_lines):
finally:
context.stack.pop()

trampoline = build_trampoline(node.name.value, context.mutants, class_name=class_name)
trampoline = build_trampoline(orig_name=node.name.value, mutants=context.mutants, class_name=class_name, is_generator=is_generator(node))
if class_name is not None:
trampoline = indent(trampoline, ' ')
yield 'trampoline', trampoline, None, None
Expand Down Expand Up @@ -530,6 +555,7 @@ def yield_mutants_for_module(node, no_mutate_lines):
yield from yield_future_imports(node)

yield 'trampoline_impl', trampoline_impl, None, None
yield 'trampoline_impl', yield_from_trampoline_impl, None, None
yield 'filler', '\n', None, None
for child_node in node.children:
if child_node.type == 'funcdef':
Expand Down
38 changes: 38 additions & 0 deletions tests/test_mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
CLASS_NAME_SEPARATOR,
FuncContext,
get_diff_for_mutant,
is_generator,
mangle_function_name,
orig_function_and_class_names_from_key,
pragma_no_mutate_lines,
Expand Down Expand Up @@ -358,3 +359,40 @@ def foo():
mutated_source = full_mutated_source(source)
assert mutated_source.split('\n')[0] == 'from __future__ import annotations'
assert mutated_source.count('from __future__') == 1


def test_preserve_generators():
source = '''
def foo():
yield 1
'''.strip()
mutated_source = full_mutated_source(source)
assert 'yield from _mutmut_yield_from_trampoline' in mutated_source


def test_is_generator():
source = '''
def foo():
yield 1
'''.strip()
assert is_generator(parse(source).children[0])

source = '''
def foo():
yield from bar()
'''.strip()
assert is_generator(parse(source).children[0])

source = '''
def foo():
return 1
'''.strip()
assert not is_generator(parse(source).children[0])

source = '''
def foo():
def bar():
yield 2
return 1
'''.strip()
assert not is_generator(parse(source).children[0])
5 changes: 3 additions & 2 deletions tests/test_mutmut3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from mutmut.__main__ import (
trampoline_impl,
yield_from_trampoline_impl,
yield_mutants_for_module,
)

Expand All @@ -14,7 +15,7 @@ def foo(a, b, c):
return a + b * c
"""

expected = trampoline_impl + """
expected = trampoline_impl + yield_from_trampoline_impl + """
a + 1
Expand Down Expand Up @@ -53,7 +54,7 @@ def foo(a: List[int]) -> int:
return 1
"""

expected = trampoline_impl + """
expected = trampoline_impl + yield_from_trampoline_impl + """
def x_foo__mutmut_orig(a: List[int]) -> int:
return 1
Expand Down

0 comments on commit 056e72c

Please sign in to comment.