diff --git a/rope/refactor/similarfinder.py b/rope/refactor/similarfinder.py index dd5bec36..3c197e42 100644 --- a/rope/refactor/similarfinder.py +++ b/rope/refactor/similarfinder.py @@ -269,10 +269,26 @@ def get_region(self): class CodeTemplate: + _dollar_name_pattern = r"(?P\$\{[^\s\$\}]*\})" + def __init__(self, template): self.template = template self._find_names() + @classmethod + def _get_pattern(cls): + if cls._match_pattern is None: + pattern = "|".join( + ( + codeanalyze.get_comment_pattern(), + codeanalyze.get_string_pattern(), + f"(?P{codeanalyze.get_formatted_string_pattern()})", + cls._dollar_name_pattern, + ) + ) + cls._match_pattern = re.compile(pattern) + return cls._match_pattern + def _find_names(self): self.names = {} for match in CodeTemplate._get_pattern().finditer(self.template): @@ -283,6 +299,29 @@ def _find_names(self): self.names[name] = [] self.names[name].append((start, end)) + elif "fstring" in match.groupdict() and match.group("fstring") is not None: + self._fstring_case(match) + + def _fstring_case(self, fstring_match: re.Match): + """Needed because CodeTemplate._match_pattern short circuits + as soon as it sees a '#', even if that '#' is inside a f-string + that has a ${variable}.""" + + string_start, string_end = fstring_match.span("fstring") + + for match in re.finditer(self._dollar_name_pattern, self.template): + if match.start("name") < string_start: + continue + + if match.end("name") > string_end: + break + + start, end = match.span("name") + name = self.template[start + 2 : end - 1] + if name not in self.names: + self.names[name] = [] + self.names[name].append((start, end)) + def get_names(self): return self.names.keys() @@ -298,19 +337,6 @@ def substitute(self, mapping): _match_pattern = None - @classmethod - def _get_pattern(cls): - if cls._match_pattern is None: - pattern = ( - codeanalyze.get_comment_pattern() - + "|" - + codeanalyze.get_string_pattern() - + "|" - + r"(?P\$\{[^\s\$\}]*\})" - ) - cls._match_pattern = re.compile(pattern) - return cls._match_pattern - class _RopeVariable: """Transform and identify rope inserted wildcards""" diff --git a/ropetest/refactor/extracttest.py b/ropetest/refactor/extracttest.py index 4b34d58d..54fc83bd 100644 --- a/ropetest/refactor/extracttest.py +++ b/ropetest/refactor/extracttest.py @@ -108,6 +108,28 @@ def extracted(): """) self.assertEqual(expected, refactored) + def test_extract_function_with_fstring(self): + code = dedent("""\ + def main(): + h = 1 + g = f"#{h}" + print(g) + """) + start, end = self._convert_line_range_to_offset(code, 3, 3) + refactored = self.do_extract_method(code, start, end, "extracted") + + expected = dedent("""\ + def main(): + h = 1 + g = extracted(h) + print(g) + + def extracted(h): + g = f"#{h}" + return g + """) + self.assertEqual(expected, refactored) + def test_extract_function_containing_dict_generalized_unpacking(self): code = dedent("""\ def a_func(dict1):