diff --git a/fortls/helper_functions.py b/fortls/helper_functions.py index 0d589377..b5e42167 100644 --- a/fortls/helper_functions.py +++ b/fortls/helper_functions.py @@ -100,10 +100,9 @@ def strip_line_label(line: str) -> tuple[str, str | None]: match = FRegex.LINE_LABEL.match(line) if match is None: return line, None - else: - line_label = match.group(1) - out_str = line[: match.start(1)] + " " * len(line_label) + line[match.end(1) :] - return out_str, line_label + line_label = match.group(1) + out_str = line[: match.start(1)] + " " * len(line_label) + line[match.end(1) :] + return out_str, line_label def strip_strings(in_line: str, maintain_len: bool = False) -> str: @@ -172,7 +171,7 @@ def separate_def_list(test_str: str) -> list[str] | None: if curr_str != "": def_list.append(curr_str) curr_str = "" - elif (curr_str == "") and (len(def_list) == 0): + elif not def_list: return None continue curr_str += char @@ -198,17 +197,20 @@ def find_word_in_line(line: str, word: str) -> Range: start and end positions (indices) of the word if not found it returns -1, len(word) -1 """ - i = -1 - for poss_name in FRegex.WORD.finditer(line): - if poss_name.group() == word: - i = poss_name.start() - break + i = next( + ( + poss_name.start() + for poss_name in FRegex.WORD.finditer(line) + if poss_name.group() == word + ), + -1, + ) # TODO: if i == -1: return None makes more sense return Range(i, i + len(word)) def find_paren_match(string: str) -> int: - """Find matching closing parenthesis **from an already open parenthesis scope** + """Find matching closing parenthesis from an already open parenthesis scope by forward search of the string, returns -1 if no match is found Parameters @@ -237,7 +239,6 @@ def find_paren_match(string: str) -> int: -1 """ paren_count = 1 - ind = -1 for i, char in enumerate(string): if char == "(": paren_count += 1 @@ -245,7 +246,7 @@ def find_paren_match(string: str) -> int: paren_count -= 1 if paren_count == 0: return i - return ind + return -1 def get_line_prefix( @@ -282,17 +283,16 @@ def get_line_prefix( col += len(prepend_string) line_prefix = curr_line[:col].lower() # Ignore string literals - if qs: - if (line_prefix.find("'") > -1) or (line_prefix.find('"') > -1): - sq_count = 0 - dq_count = 0 - for char in line_prefix: - if (char == "'") and (dq_count % 2 == 0): - sq_count += 1 - elif (char == '"') and (sq_count % 2 == 0): - dq_count += 1 - if (dq_count % 2 == 1) or (sq_count % 2 == 1): - return None + if qs and ((line_prefix.find("'") > -1) or (line_prefix.find('"') > -1)): + sq_count = 0 + dq_count = 0 + for char in line_prefix: + if (char == "'") and (dq_count % 2 == 0): + sq_count += 1 + elif (char == '"') and (sq_count % 2 == 0): + dq_count += 1 + if (dq_count % 2 == 1) or (sq_count % 2 == 1): + return None return line_prefix @@ -329,14 +329,12 @@ def resolve_globs(glob_path: str, root_path: str = None) -> list[str]: >>> resolve_globs('test') == [str(pathlib.Path(os.getcwd()) / 'test')] True """ - # Resolve absolute paths i.e. not in our root_path - if os.path.isabs(glob_path) or not root_path: - p = Path(glob_path).resolve() - root = p.anchor # drive letter + root path - rel = str(p.relative_to(root)) # contains glob pattern - return [str(p.resolve()) for p in Path(root).glob(rel)] - else: + if not os.path.isabs(glob_path) and root_path: return [str(p.resolve()) for p in Path(root_path).resolve().glob(glob_path)] + p = Path(glob_path).resolve() + root = p.anchor # drive letter + root path + rel = str(p.relative_to(root)) # contains glob pattern + return [str(p.resolve()) for p in Path(root).glob(rel)] def only_dirs(paths: list[str]) -> list[str]: @@ -406,7 +404,9 @@ def map_keywords(keywords: list[str]): return mapped_keywords, keyword_info -def get_keywords(keywords: list, keyword_info: dict = {}): +def get_keywords(keywords: list, keyword_info: dict = None): + if keyword_info is None: + keyword_info = {} keyword_strings = [] for keyword_id in keywords: string_rep = KEYWORD_LIST[keyword_id] @@ -461,10 +461,7 @@ def get_paren_substring(string: str) -> str | None: """ i1 = string.find("(") i2 = string.rfind(")") - if -1 < i1 < i2: - return string[i1 + 1 : i2] - else: - return None + return string[i1 + 1 : i2] if -1 < i1 < i2 else None def get_paren_level(line: str) -> tuple[str, list[Range]]: @@ -496,7 +493,7 @@ def get_paren_level(line: str) -> tuple[str, list[Range]]: ('', [Range(start=0, end=0)]) """ - if line == "": + if not line: return "", [Range(0, 0)] level = 0 in_string = False @@ -526,9 +523,7 @@ def get_paren_level(line: str) -> tuple[str, list[Range]]: if level == 0: sections.append(Range(i, i1)) sections.reverse() - out_string = "" - for section in sections: - out_string += line[section.start : section.end] + out_string = "".join(line[section.start : section.end] for section in sections) return out_string, sections @@ -564,7 +559,7 @@ def get_var_stack(line: str) -> list[str]: >>> get_var_stack('') [''] """ - if len(line) == 0: + if not line: return [""] final_var, sections = get_paren_level(line) if final_var == "": @@ -574,10 +569,9 @@ def get_var_stack(line: str) -> list[str]: for i, section in enumerate(sections): if not line[section.start : section.end].strip().startswith("%"): iLast = i - final_var = "" - for section in sections[iLast:]: - final_var += line[section.start : section.end] - + final_var = "".join( + line[section.start : section.end] for section in sections[iLast:] + ) if final_var is not None: final_var = "%".join([i.strip() for i in final_var.split("%")]) final_op_split: list[str] = FRegex.OBJBREAK.split(final_var) @@ -586,6 +580,43 @@ def get_var_stack(line: str) -> list[str]: return None +def get_placeholders(arg_list: list[str]) -> tuple[str, str]: + """ + Function used to generate placeholders for snippets + + Parameters + ---------- + arg_list : list[str] + Method arguments list + + Returns + ------- + Tuple[str, str] + Tuple of arguments as a string and snippet string + + Examples + -------- + >>> get_placeholders(['x', 'y']) + ('(x, y)', '(${1:x}, ${2:y})') + + >>> get_placeholders(['x=1', 'y=2']) + ('(x=1, y=2)', '(x=${1:1}, y=${2:2})') + + >>> get_placeholders(['x', 'y=2', 'z']) + ('(x, y=2, z)', '(${1:x}, y=${2:2}, ${3:z})') + """ + place_holders = [] + for i, arg in enumerate(arg_list): + opt_split = arg.split("=") + if len(opt_split) > 1: + place_holders.append(f"{opt_split[0]}=${{{i+1}:{opt_split[1]}}}") + else: + place_holders.append(f"${{{i+1}:{arg}}}") + arg_str = f"({', '.join(arg_list)})" + arg_snip = f"({', '.join(place_holders)})" + return arg_str, arg_snip + + def fortran_md(code: str, docs: str | None): """Convert Fortran code to markdown diff --git a/fortls/intrinsics.py b/fortls/intrinsics.py index 3f128dbd..701e3724 100644 --- a/fortls/intrinsics.py +++ b/fortls/intrinsics.py @@ -3,8 +3,9 @@ import glob import json import os +import pathlib -from fortls.helper_functions import fortran_md, map_keywords +from fortls.helper_functions import fortran_md, get_placeholders, map_keywords from fortls.objects import ( FortranAST, FortranObj, @@ -26,9 +27,7 @@ def set_lowercase_intrinsics(): def intrinsics_case(name: str, args: str): - if lowercase_intrinsics: - return name.lower(), args.lower() - return name, args + return (name.lower(), args.lower()) if lowercase_intrinsics else (name, args) class Intrinsic(FortranObj): @@ -68,19 +67,13 @@ def get_snippet(self, name_replace=None, drop_arg=-1): arg_snip = None else: arg_list = self.args.split(",") - arg_str, arg_snip = self.get_placeholders(arg_list) - name = self.name - if name_replace is not None: - name = name_replace - snippet = None - if arg_snip is not None: - snippet = name + arg_snip + arg_str, arg_snip = get_placeholders(arg_list) + name = name_replace if name_replace is not None else self.name + snippet = name + arg_snip if arg_snip is not None else None return name + arg_str, snippet def get_signature(self): - arg_sigs = [] - for arg in self.args.split(","): - arg_sigs.append({"label": arg}) + arg_sigs = [{"label": arg} for arg in self.args.split(",")] call_sig, _ = self.get_snippet() return call_sig, self.doc_str, arg_sigs @@ -89,13 +82,11 @@ def get_hover(self, long=False): def get_hover_md(self, long=False): msg, docs = self.get_hover(long) - msg = msg if msg else "" + msg = msg or "" return fortran_md(msg, docs) def is_callable(self): - if self.type == 2: - return True - return False + return self.type == 2 def load_intrinsics(): @@ -281,8 +272,7 @@ def update_m_intrinsics(): for f in sorted(files): key = f.replace("M_intrinsics/md/", "") key = key.replace(".md", "").upper() # remove md extension - with open(f) as md_f: - val = md_f.read() + val = pathlib.Path(f).read_text() # remove manpage tag val = val.replace(f"**{key.lower()}**(3)", f"**{key.lower()}**") val = val.replace(f"**{key.upper()}**(3)", f"**{key.upper()}**") diff --git a/fortls/json_templates.py b/fortls/json_templates.py index bd9f15ea..34e51a0a 100644 --- a/fortls/json_templates.py +++ b/fortls/json_templates.py @@ -10,7 +10,7 @@ def range_json(sln: int, sch: int, eln: int = None, ech: int = None): } -def diagnostic_json(sln: int, sch: int, eln: int, ech: int, msg: str, sev: str): +def diagnostic_json(sln: int, sch: int, eln: int, ech: int, msg: str, sev: int): return {**range_json(sln, sch, eln, ech), "message": msg, "severity": sev} diff --git a/fortls/objects.py b/fortls/objects.py index 72186305..955e69c4 100644 --- a/fortls/objects.py +++ b/fortls/objects.py @@ -1,10 +1,12 @@ from __future__ import annotations +import contextlib import copy import os import re from dataclasses import dataclass from typing import Pattern +from typing import Type as T from fortls.constants import ( ASSOC_TYPE_ID, @@ -31,6 +33,7 @@ fortran_md, get_keywords, get_paren_substring, + get_placeholders, get_var_stack, ) from fortls.json_templates import diagnostic_json, location_json, range_json @@ -41,10 +44,17 @@ def get_use_tree( scope: Scope, use_dict: dict[str, Use | Import], obj_tree: dict, - only_list: set[str] = set(), - rename_map: dict[str, str] = {}, - curr_path: list[str] = [], + only_list: list[str] = None, + rename_map: dict[str, str] = None, + curr_path: list[str] = None, ): + if only_list is None: + only_list = set() + if rename_map is None: + rename_map = {} + if curr_path is None: + curr_path = [] + def intersect_only(use_stmnt: Use | Import): tmp_list = [] tmp_map = rename_map.copy() @@ -72,7 +82,7 @@ def intersect_only(use_stmnt: Use | Import): if type(use_stmnt) is Import and use_stmnt.import_type is ImportTypes.NONE: continue # Intersect parent and current ONLY list and renaming - if len(only_list) == 0: + if not only_list: merged_use_list = use_stmnt.only_list.copy() merged_rename = use_stmnt.rename_map.copy() elif len(use_stmnt.only_list) == 0: @@ -123,10 +133,8 @@ def intersect_only(use_stmnt: Use | Import): only_list=set(merged_use_list), rename_map=merged_rename, ) - try: + with contextlib.suppress(AttributeError): use_dict[use_stmnt.mod_name].scope = scope.parent.parent - except AttributeError: - pass # Do not descent the IMPORT tree, because it does not exist if type(use_stmnt) is Import: continue @@ -161,9 +169,9 @@ def check_scope( tmp_var = check_scope(child, var_name_lower, filter_public) if tmp_var is not None: return tmp_var - if filter_public: - if (child.vis < 0) or ((local_scope.def_vis < 0) and (child.vis <= 0)): - continue + is_private = child.vis < 0 or (local_scope.def_vis < 0 and child.vis <= 0) + if filter_public and is_private: + continue if child.name.lower() == var_name_lower: # For functions with an implicit result() variable the name # of the function is used. If we are hovering over the function @@ -183,13 +191,13 @@ def check_scope( def check_import_scope(scope: Scope, var_name_lower: str): for use_stmnt in scope.use: - if not type(use_stmnt) is Import: + if type(use_stmnt) is not Import: continue if use_stmnt.import_type == ImportTypes.ONLY: # Check if name is in only list if var_name_lower in use_stmnt.only_list: return ImportTypes.ONLY - # Get Get the parent scope + # Get the parent scope elif use_stmnt.import_type == ImportTypes.ALL: return ImportTypes.ALL # Skip looking for parent scope @@ -227,9 +235,8 @@ def check_import_scope(scope: Scope, var_name_lower: str): if use_mod.lower() == var_name_lower: return use_scope # Filter children by only_list - if len(use_info.only_list) > 0: - if var_name_lower not in use_info.only_list: - continue + if len(use_info.only_list) > 0 and var_name_lower not in use_info.only_list: + continue mod_name = use_info.rename_map.get(var_name_lower, var_name_lower) tmp_var = check_scope(use_scope, mod_name, filter_public=True) if tmp_var is not None: @@ -320,10 +327,14 @@ class Use: def __init__( self, mod_name: str, - only_list: set[str] = set(), - rename_map: dict[str, str] = {}, - line_number: int | None = 0, + only_list: set[str] = None, + rename_map: dict[str, str] = None, + line_number: int = 0, ): + if only_list is None: + only_list = set() + if rename_map is None: + rename_map = {} self.mod_name: str = mod_name.lower() self._line_no: int = line_number self.only_list: set[str] = only_list @@ -341,14 +352,13 @@ def line_number(self): def line_number(self, line_number: int): self._line_no = line_number - def rename(self, only_list: list[str] = []): + def rename(self, only_list: list[str] = None): """Rename ONLY:, statements""" + if only_list is None: + only_list = [] if not only_list: only_list = self.only_list - renamed_only_list = [] - for only_name in only_list: - renamed_only_list.append(self.rename_map.get(only_name, only_name)) - return renamed_only_list + return [self.rename_map.get(only_name, only_name) for only_name in only_list] class ImportTypes: @@ -365,10 +375,14 @@ def __init__( self, name: str, import_type: ImportTypes = ImportTypes.DEFAULT, - only_list: set[str] = set(), - rename_map: dict[str, str] = {}, + only_list: set[str] = None, + rename_map: dict[str, str] = None, line_number: int = 0, ): + if only_list is None: + only_list = set() + if rename_map is None: + rename_map = {} super().__init__(name, only_list, rename_map, line_number) self.import_type = import_type self._scope: Scope | Module | None = None @@ -485,19 +499,6 @@ def get_desc(self): def get_snippet(self, name_replace=None, drop_arg=-1): return None, None - @staticmethod - def get_placeholders(arg_list: list[str]): - place_holders = [] - for i, arg in enumerate(arg_list): - opt_split = arg.split("=") - if len(opt_split) > 1: - place_holders.append(f"{opt_split[0]}=${{{i+1}:{opt_split[1]}}}") - else: - place_holders.append(f"${{{i+1}:{arg}}}") - arg_str = f"({', '.join(arg_list)})" - arg_snip = f"({', '.join(place_holders)})" - return arg_str, arg_snip - def get_documentation(self): return self.doc_str @@ -526,11 +527,10 @@ def get_diagnostics(self): def get_implicit(self): if self.parent is None: return self.implicit_vars - else: - parent_implicit = self.parent.get_implicit() - if (self.implicit_vars is not None) or (parent_implicit is None): - return self.implicit_vars - return parent_implicit + parent_implicit = self.parent.get_implicit() + if (self.implicit_vars is not None) or (parent_implicit is None): + return self.implicit_vars + return parent_implicit def get_actions(self, sline, eline): return None @@ -571,7 +571,7 @@ def __init__(self, file_ast, line_number: int, name: str, keywords: list = None) self.sline: int = line_number self.eline: int = line_number self.name: str = name - self.children: list = [] + self.children: list[T[Scope]] = [] self.members: list = [] self.use: list[Use | Import] = [] self.keywords: list = keywords @@ -581,7 +581,7 @@ def __init__(self, file_ast, line_number: int, name: str, keywords: list = None) self.implicit_line = None self.FQSN: str = self.name.lower() if file_ast.enc_scope_name is not None: - self.FQSN = file_ast.enc_scope_name.lower() + "::" + self.name.lower() + self.FQSN = f"{file_ast.enc_scope_name.lower()}::{self.name.lower()}" def copy_from(self, copy_source: Scope): # Pass the reference, we don't want shallow copy since that would still @@ -613,7 +613,7 @@ def add_child(self, child): def update_fqsn(self, enc_scope=None): if enc_scope is not None: - self.FQSN = enc_scope.lower() + "::" + self.name.lower() + self.FQSN = f"{enc_scope.lower()}::{self.name.lower()}" else: self.FQSN = self.name.lower() for child in self.children: @@ -622,56 +622,53 @@ def update_fqsn(self, enc_scope=None): def add_member(self, member): self.members.append(member) - def get_children(self, public_only=False): - if public_only: - pub_children = [] - for child in self.children: - if (child.vis < 0) or ((self.def_vis < 0) and (child.vis <= 0)): - continue - if child.name.startswith("#GEN_INT"): - pub_children.append(child) - continue - pub_children.append(child) - return pub_children - else: + def get_children(self, public_only=False) -> list[T[FortranObj]]: + if not public_only: return copy.copy(self.children) + pub_children = [] + for child in self.children: + if (child.vis < 0) or ((self.def_vis < 0) and (child.vis <= 0)): + continue + if child.name.startswith("#GEN_INT"): + pub_children.append(child) + continue + pub_children.append(child) + return pub_children - def check_definitions(self, obj_tree): + def check_definitions(self, obj_tree) -> list[Diagnostic]: """Check for definition errors in scope""" - FQSN_dict = {} + fqsn_dict: dict[str, int] = {} + errors: list[Diagnostic] = [] + known_types: dict[str, FortranObj] = {} + for child in self.children: # Skip masking/double checks for interfaces if child.get_type() == INTERFACE_TYPE_ID: continue # Check other variables in current scope - if child.FQSN in FQSN_dict: - if child.sline < FQSN_dict[child.FQSN]: - FQSN_dict[child.FQSN] = child.sline - 1 + if child.FQSN in fqsn_dict: + if child.sline < fqsn_dict[child.FQSN]: + fqsn_dict[child.FQSN] = child.sline - 1 else: - FQSN_dict[child.FQSN] = child.sline - 1 - # + fqsn_dict[child.FQSN] = child.sline - 1 + contains_line = -1 - after_contains_list = (SUBROUTINE_TYPE_ID, FUNCTION_TYPE_ID) if self.get_type() in ( MODULE_TYPE_ID, SUBMODULE_TYPE_ID, SUBROUTINE_TYPE_ID, FUNCTION_TYPE_ID, ): - if self.contains_start is None: - contains_line = self.eline - else: - contains_line = self.contains_start + contains_line = ( + self.contains_start if self.contains_start is not None else self.eline + ) # Detect interface definitions - is_interface = False - if ( - (self.parent is not None) - and (self.parent.get_type() == INTERFACE_TYPE_ID) - and (not self.is_mod_scope()) - ): - is_interface = True - errors = [] - known_types = {} + is_interface = ( + self.parent is not None + and self.parent.get_type() == INTERFACE_TYPE_ID + and not self.is_mod_scope() + ) + for child in self.children: if child.name.startswith("#"): continue @@ -683,8 +680,9 @@ def check_definitions(self, obj_tree): if def_error is not None: errors.append(def_error) # Detect contains errors - if (contains_line >= child.sline) and ( - child.get_type(no_link=True) in after_contains_list + if contains_line >= child.sline and child.get_type(no_link=True) in ( + SUBROUTINE_TYPE_ID, + FUNCTION_TYPE_ID, ): new_diag = Diagnostic( line_number, @@ -693,37 +691,41 @@ def check_definitions(self, obj_tree): ) errors.append(new_diag) # Skip masking/double checks for interfaces and members - if (self.get_type() == INTERFACE_TYPE_ID) or ( - child.get_type() == INTERFACE_TYPE_ID + if ( + self.get_type() == INTERFACE_TYPE_ID + or child.get_type() == INTERFACE_TYPE_ID ): continue # Check other variables in current scope - if child.FQSN in FQSN_dict: - if line_number > FQSN_dict[child.FQSN]: - new_diag = Diagnostic( - line_number, - message=f'Variable "{child.name}" declared twice in scope', - severity=1, - find_word=child.name, - ) - new_diag.add_related( - path=self.file_ast.path, - line=FQSN_dict[child.FQSN], - message="First declaration", - ) - errors.append(new_diag) - continue + if child.FQSN in fqsn_dict and line_number > fqsn_dict[child.FQSN]: + new_diag = Diagnostic( + line_number, + message=f'Variable "{child.name}" declared twice in scope', + severity=1, + find_word=child.name, + ) + new_diag.add_related( + path=self.file_ast.path, + line=fqsn_dict[child.FQSN], + message="First declaration", + ) + errors.append(new_diag) + continue # Check for masking from parent scope in subroutines, functions, and blocks - if (self.parent is not None) and ( - self.get_type() in (SUBROUTINE_TYPE_ID, FUNCTION_TYPE_ID, BLOCK_TYPE_ID) + if self.parent is not None and self.get_type() in ( + SUBROUTINE_TYPE_ID, + FUNCTION_TYPE_ID, + BLOCK_TYPE_ID, ): parent_var = find_in_scope(self.parent, child.name, obj_tree) if parent_var is not None: # Ignore if function return variable - if (self.get_type() == FUNCTION_TYPE_ID) and ( - parent_var.FQSN == self.FQSN + if ( + self.get_type() == FUNCTION_TYPE_ID + and parent_var.FQSN == self.FQSN ): continue + new_diag = Diagnostic( line_number, message=( @@ -738,6 +740,7 @@ def check_definitions(self, obj_tree): message="First declaration", ) errors.append(new_diag) + return errors def check_use(self, obj_tree): @@ -809,10 +812,8 @@ def get_hover(self, long=False, drop_arg=-1) -> tuple[str, str]: doc_str = self.get_documentation() return hover, doc_str - def check_valid_parent(self): - if self.parent is not None: - return False - return True + def check_valid_parent(self) -> bool: + return self.parent is None class Include(Scope): @@ -861,47 +862,59 @@ def require_inherit(self): return True def resolve_link(self, obj_tree): + def get_ancestor_interfaces( + ancestor_children: list[Scope], + ) -> list[T[Interface]]: + interfaces = [] + for child in ancestor_children: + if child.get_type() != INTERFACE_TYPE_ID: + continue + for interface in child.children: + interface_type = interface.get_type() + if ( + interface_type + in (SUBROUTINE_TYPE_ID, FUNCTION_TYPE_ID, BASE_TYPE_ID) + ) and interface.is_mod_scope(): + interfaces.append(interface) + return interfaces + + def create_child_from_prototype(child: Scope, interface: Interface): + if interface.get_type() == SUBROUTINE_TYPE_ID: + return Subroutine(child.file_ast, child.sline, child.name) + elif interface.get_type() == FUNCTION_TYPE_ID: + return Function(child.file_ast, child.sline, child.name) + else: + raise ValueError(f"Unsupported interface type: {interface.get_type()}") + + def replace_child_in_scope_list(child: Scope, child_old: Scope): + for i, file_scope in enumerate(child.file_ast.scope_list): + if file_scope is child_old: + child.file_ast.scope_list[i] = child + return child + # Link subroutine/function implementations to prototypes if self.ancestor_obj is None: return - # Grab ancestor interface definitions (function/subroutine only) - ancestor_interfaces = [] - for child in self.ancestor_obj.children: - if child.get_type() == INTERFACE_TYPE_ID: - for prototype in child.children: - prototype_type = prototype.get_type() - if ( - prototype_type - in (SUBROUTINE_TYPE_ID, FUNCTION_TYPE_ID, BASE_TYPE_ID) - ) and prototype.is_mod_scope(): - ancestor_interfaces.append(prototype) + + ancestor_interfaces = get_ancestor_interfaces(self.ancestor_obj.children) # Match interface definitions to implementations - for prototype in ancestor_interfaces: + for interface in ancestor_interfaces: for i, child in enumerate(self.children): - if child.name.lower() == prototype.name.lower(): - # Create correct object for interface - if child.get_type() == BASE_TYPE_ID: - child_old = child - if prototype.get_type() == SUBROUTINE_TYPE_ID: - child = Subroutine( - child_old.file_ast, child_old.sline, child_old.name - ) - elif prototype.get_type() == FUNCTION_TYPE_ID: - child = Function( - child_old.file_ast, child_old.sline, child_old.name - ) - child.copy_from(child_old) - # Replace in child and scope lists - self.children[i] = child - for j, file_scope in enumerate(child.file_ast.scope_list): - if file_scope is child_old: - child.file_ast.scope_list[j] = child - if child.get_type() == prototype.get_type(): - # Link the interface with the implementation - prototype.link_obj = child - prototype.resolve_link(obj_tree) - child.copy_interface(prototype) - break + if child.name.lower() != interface.name.lower(): + continue + + if child.get_type() == BASE_TYPE_ID: + child_old = child + child = create_child_from_prototype(child_old, interface) + child.copy_from(child_old) + self.children[i] = child + child = replace_child_in_scope_list(child, child_old) + + if child.get_type() == interface.get_type(): + interface.link_obj = child + interface.resolve_link(obj_tree) + child.copy_interface(interface) + break def require_link(self): return True @@ -938,9 +951,7 @@ def copy_interface(self, copy_source: Subroutine) -> list[str]: self.args_snip = copy_source.args_snip self.arg_objs = copy_source.arg_objs # Get current fields - child_names = [] - for child in self.children: - child_names.append(child.name.lower()) + child_names = [child.name.lower() for child in self.children] # Import arg_objs from copy object self.in_children = [] for child in copy_source.arg_objs: @@ -1004,15 +1015,11 @@ def get_snippet(self, name_replace=None, drop_arg=-1): del arg_list[drop_arg] arg_snip = None if len(arg_list) > 0: - arg_str, arg_snip = self.get_placeholders(arg_list) + arg_str, arg_snip = get_placeholders(arg_list) else: arg_str = "()" - name = self.name - if name_replace is not None: - name = name_replace - snippet = None - if arg_snip is not None: - snippet = name + arg_snip + name = name_replace if name_replace is not None else self.name + snippet = name + arg_snip if arg_snip is not None else None return name + arg_str, snippet def get_desc(self): @@ -1091,14 +1098,14 @@ def get_signature(self, drop_arg=-1): # TODO: fix this def get_interface_array( - self, keywords: list[str], signature: str, change_arg=-1, change_strings=None + self, keywords: list[str], signature: str, drop_arg=-1, change_strings=None ): interface_array = [" ".join(keywords) + signature] for i, arg_obj in enumerate(self.arg_objs): if arg_obj is None: return None arg_doc, docs = arg_obj.get_hover() - if i == change_arg: + if i == drop_arg: i0 = arg_doc.lower().find(change_strings[0].lower()) if i0 >= 0: i1 = i0 + len(change_strings[0]) @@ -1106,16 +1113,14 @@ def get_interface_array( interface_array.append(f"{arg_doc} :: {arg_obj.name}") return interface_array - def get_interface(self, name_replace=None, change_arg=-1, change_strings=None): + def get_interface(self, name_replace=None, drop_arg=-1, change_strings=None): sub_sig, _ = self.get_snippet(name_replace=name_replace) keyword_list = get_keywords(self.keywords) keyword_list.append("SUBROUTINE ") interface_array = self.get_interface_array( - keyword_list, sub_sig, change_arg, change_strings + keyword_list, sub_sig, drop_arg, change_strings ) - name = self.name - if name_replace is not None: - name = name_replace + name = name_replace if name_replace is not None else self.name interface_array.append(f"END SUBROUTINE {name}") return "\n".join(interface_array) @@ -1192,9 +1197,11 @@ def copy_interface(self, copy_source: Function): self.result_name = copy_source.result_name self.result_type = copy_source.result_type self.result_obj = copy_source.result_obj - if copy_source.result_obj is not None: - if copy_source.result_obj.name.lower() not in child_names: - self.in_children.append(copy_source.result_obj) + if ( + copy_source.result_obj is not None + and copy_source.result_obj.name.lower() not in child_names + ): + self.in_children.append(copy_source.result_obj) def resolve_link(self, obj_tree): self.resolve_arg_link(obj_tree) @@ -1210,9 +1217,8 @@ def get_type(self, no_link=False): return FUNCTION_TYPE_ID def get_desc(self): - if self.result_type: - return self.result_type + " FUNCTION" - return "FUNCTION" + token = "FUNCTION" + return f"{self.result_type} {token}" if self.result_type else token def is_callable(self): return False @@ -1266,7 +1272,7 @@ def get_hover(self, long: bool = False, drop_arg: int = -1) -> tuple[str, str]: return "\n ".join(hover_array), " \n".join(docs) # TODO: fix this - def get_interface(self, name_replace=None, change_arg=-1, change_strings=None): + def get_interface(self, name_replace=None, drop_arg=-1, change_strings=None): fun_sig, _ = self.get_snippet(name_replace=name_replace) fun_sig += f" RESULT({self.result_name})" # XXX: @@ -1277,14 +1283,12 @@ def get_interface(self, name_replace=None, change_arg=-1, change_strings=None): keyword_list.append("FUNCTION ") interface_array = self.get_interface_array( - keyword_list, fun_sig, change_arg, change_strings + keyword_list, fun_sig, drop_arg, change_strings ) if self.result_obj is not None: arg_doc, docs = self.result_obj.get_hover() interface_array.append(f"{arg_doc} :: {self.result_obj.name}") - name = self.name - if name_replace is not None: - name = name_replace + name = name_replace if name_replace is not None else self.name interface_array.append(f"END FUNCTION {name}") return "\n".join(interface_array) @@ -1294,16 +1298,12 @@ def __init__( self, file_ast: FortranAST, line_number: int, name: str, keywords: list ): super().__init__(file_ast, line_number, name, keywords) - # self.in_children: list = [] self.inherit = None self.inherit_var = None self.inherit_tmp = None self.inherit_version = -1 - if self.keywords.count(KEYWORD_ID_DICT["abstract"]) > 0: - self.abstract = True - else: - self.abstract = False + self.abstract = self.keywords.count(KEYWORD_ID_DICT["abstract"]) > 0 if self.keywords.count(KEYWORD_ID_DICT["public"]) > 0: self.vis = 1 if self.keywords.count(KEYWORD_ID_DICT["private"]) > 0: @@ -1326,21 +1326,22 @@ def resolve_inherit(self, obj_tree, inherit_version): self.inherit_version = inherit_version self.inherit_var = find_in_scope(self.parent, self.inherit, obj_tree) if self.inherit_var is not None: - # Resolve parent inheritance while avoiding circular recursion - self.inherit_tmp = self.inherit - self.inherit = None - self.inherit_var.resolve_inherit(obj_tree, inherit_version) - self.inherit = self.inherit_tmp - self.inherit_tmp = None - # Get current fields - child_names = [] - for child in self.children: - child_names.append(child.name.lower()) - # Import for parent objects - self.in_children = [] - for child in self.inherit_var.get_children(): - if child.name.lower() not in child_names: - self.in_children.append(child) + self._resolve_inherit_parent(obj_tree, inherit_version) + + def _resolve_inherit_parent(self, obj_tree, inherit_version): + # Resolve parent inheritance while avoiding circular recursion + self.inherit_tmp = self.inherit + self.inherit = None + self.inherit_var.resolve_inherit(obj_tree, inherit_version) + self.inherit = self.inherit_tmp + self.inherit_tmp = None + # Get current fields + child_names = [child.name.lower() for child in self.children] + # Import for parent objects + self.in_children = [] + for child in self.inherit_var.get_children(): + if child.name.lower() not in child_names: + self.in_children.append(child) def require_inherit(self): return True @@ -1359,11 +1360,8 @@ def get_overridden(self, field_name): def check_valid_parent(self): if self.parent is None: return False - else: - parent_type = self.parent.get_type() - if (parent_type == CLASS_TYPE_ID) or (parent_type >= BLOCK_TYPE_ID): - return False - return True + parent_type = self.parent.get_type() + return parent_type != CLASS_TYPE_ID and parent_type < BLOCK_TYPE_ID def get_diagnostics(self): errors = [] @@ -1567,12 +1565,10 @@ def resolve_link(self, obj_tree): if type_scope is None: continue var_obj = find_in_scope(type_scope, var_stack[-1], obj_tree) - if var_obj is not None: - assoc.var.link_obj = var_obj else: var_obj = find_in_scope(self, assoc.link_name, obj_tree) - if var_obj is not None: - assoc.var.link_obj = var_obj + if var_obj is not None: + assoc.var.link_obj = var_obj def require_link(self): return True @@ -1623,7 +1619,7 @@ def is_type_binding(self): return self.select_type == 2 def is_type_region(self): - return (self.select_type == 3) or (self.select_type == 4) + return self.select_type in [3, 4] def create_binding_variable(self, file_ast, line_number, var_desc, case_type): if self.parent.get_type() != SELECT_TYPE_ID: @@ -1641,7 +1637,7 @@ def create_binding_variable(self, file_ast, line_number, var_desc, case_type): return Variable( file_ast, line_number, binding_name, var_desc, [], link_obj=bound_var ) - elif (binding_name is None) and (bound_var is not None): + elif bound_var is not None: return Variable(file_ast, line_number, bound_var, var_desc, []) return None @@ -1723,7 +1719,7 @@ def __init__( if link_obj is not None: self.link_name = link_obj.lower() if file_ast.enc_scope_name is not None: - self.FQSN = file_ast.enc_scope_name.lower() + "::" + self.name.lower() + self.FQSN = f"{file_ast.enc_scope_name.lower()}::{self.name.lower()}" if self.keywords.count(KEYWORD_ID_DICT["public"]) > 0: self.vis = 1 if self.keywords.count(KEYWORD_ID_DICT["private"]) > 0: @@ -1738,7 +1734,7 @@ def __init__( def update_fqsn(self, enc_scope=None): if enc_scope is not None: - self.FQSN = enc_scope.lower() + "::" + self.name.lower() + self.FQSN = f"{enc_scope.lower()}::{self.name.lower()}" else: self.FQSN = self.name.lower() for child in self.children: @@ -1766,9 +1762,7 @@ def get_desc(self, no_link=False): if not no_link and self.link_obj is not None: return self.link_obj.get_desc() # Normal variable - if self.kind: - return self.desc + self.kind - return self.desc + return self.desc + self.kind if self.kind else self.desc def get_type_obj(self, obj_tree): if self.link_obj is not None: @@ -1794,9 +1788,7 @@ def set_dim(self, dim_str): self.keywords.sort() def get_snippet(self, name_replace=None, drop_arg=-1): - name = self.name - if name_replace is not None: - name = name_replace + name = name_replace if name_replace is not None else self.name if self.link_obj is not None: return self.link_obj.get_snippet(name, drop_arg) # Normal variable @@ -1826,10 +1818,7 @@ def get_keywords(self): return get_keywords(self.keywords, self.keyword_info) def is_optional(self): - if self.keywords.count(KEYWORD_ID_DICT["optional"]) > 0: - return True - else: - return False + return self.keywords.count(KEYWORD_ID_DICT["optional"]) > 0 def is_callable(self): return self.callable @@ -1844,7 +1833,9 @@ def set_external_attr(self): self.keywords.append(KEYWORD_ID_DICT["external"]) self.is_external = True - def check_definition(self, obj_tree, known_types={}, interface=False): + def check_definition(self, obj_tree, known_types=None, interface=False): + if known_types is None: + known_types = {} # Check for type definition in scope type_match = FRegex.DEF_KIND.match(self.get_desc(no_link=True)) if type_match is not None: @@ -1860,51 +1851,55 @@ def check_definition(self, obj_tree, known_types={}, interface=False): interface=interface, ) if type_def is None: - type_defs = find_in_workspace( - obj_tree, - desc_obj_name, - filter_public=True, - exact_match=True, + self._check_definition_type_def( + obj_tree, desc_obj_name, known_types, type_match ) - known_types[desc_obj_name] = None - var_type = type_match.group(1).strip().lower() - filter_id = VAR_TYPE_ID - if (var_type == "class") or (var_type == "type"): - filter_id = CLASS_TYPE_ID - for type_def in type_defs: - if type_def.get_type() == filter_id: - known_types[desc_obj_name] = (1, type_def) - break else: known_types[desc_obj_name] = (0, type_def) type_info = known_types[desc_obj_name] - if type_info is not None: - if type_info[0] == 1: - if interface: - out_diag = Diagnostic( - self.sline - 1, - message=( - f'Object "{desc_obj_name}" not imported in interface' - ), - severity=1, - find_word=desc_obj_name, - ) - else: - out_diag = Diagnostic( - self.sline - 1, - message=f'Object "{desc_obj_name}" not found in scope', - severity=1, - find_word=desc_obj_name, - ) - type_def = type_info[1] - out_diag.add_related( - path=type_def.file_ast.path, - line=type_def.sline - 1, - message="Possible object", - ) - return out_diag, known_types + if type_info is not None and type_info[0] == 1: + if interface: + out_diag = Diagnostic( + self.sline - 1, + message=f'Object "{desc_obj_name}" not imported in interface', + severity=1, + find_word=desc_obj_name, + ) + else: + out_diag = Diagnostic( + self.sline - 1, + message=f'Object "{desc_obj_name}" not found in scope', + severity=1, + find_word=desc_obj_name, + ) + type_def = type_info[1] + out_diag.add_related( + path=type_def.file_ast.path, + line=type_def.sline - 1, + message="Possible object", + ) + return out_diag, known_types return None, known_types + def _check_definition_type_def( + self, obj_tree, desc_obj_name, known_types, type_match + ): + type_defs = find_in_workspace( + obj_tree, + desc_obj_name, + filter_public=True, + exact_match=True, + ) + known_types[desc_obj_name] = None + var_type = type_match.group(1).strip().lower() + filter_id = VAR_TYPE_ID + if var_type in ["class", "type"]: + filter_id = CLASS_TYPE_ID + for type_def in type_defs: + if type_def.get_type() == filter_id: + known_types[desc_obj_name] = (1, type_def) + break + class Method(Variable): # i.e. TypeBound procedure def __init__( @@ -1947,10 +1942,7 @@ def set_parent(self, parent_obj): def get_snippet(self, name_replace=None, drop_arg=-1): if self.link_obj is not None: - if name_replace is None: - name = self.name - else: - name = name_replace + name = self.name if name_replace is None else name_replace return self.link_obj.get_snippet(name, self.drop_arg) return None, None @@ -2000,7 +1992,7 @@ def get_signature(self, drop_arg=-1): return call_sig, self.get_documentation(), arg_sigs return None, None, None - def get_interface(self, name_replace=None, change_arg=-1, change_strings=None): + def get_interface(self, name_replace=None, drop_arg=-1, change_strings=None): if self.link_obj is not None: return self.link_obj.get_interface( name_replace, self.drop_arg, change_strings @@ -2027,7 +2019,9 @@ def resolve_link(self, obj_tree): def is_callable(self): return True - def check_definition(self, obj_tree, known_types={}, interface=False): + def check_definition(self, obj_tree, known_types=None, interface=False): + if known_types is None: + known_types = {} return None, known_types @@ -2070,9 +2064,7 @@ def create_none_scope(self): def get_enc_scope_name(self): """Get current enclosing scope name""" - if self.current_scope is None: - return None - return self.current_scope.FQSN + return None if self.current_scope is None else self.current_scope.FQSN def add_scope( self, @@ -2089,12 +2081,11 @@ def add_scope( if self.current_scope is None: if req_container: self.create_none_scope() - new_scope.FQSN = self.none_scope.FQSN + "::" + new_scope.name.lower() + new_scope.FQSN = f"{self.none_scope.FQSN}::{new_scope.name.lower()}" self.current_scope.add_child(new_scope) self.scope_stack.append(self.current_scope) - else: - if exportable: - self.global_dict[new_scope.FQSN] = new_scope + elif exportable: + self.global_dict[new_scope.FQSN] = new_scope else: self.current_scope.add_child(new_scope) self.scope_stack.append(self.current_scope) @@ -2128,7 +2119,7 @@ def end_scope(self, line_number: int, check: bool = True): def add_variable(self, new_var: Variable): if self.current_scope is None: self.create_none_scope() - new_var.FQSN = self.none_scope.FQSN + "::" + new_var.name.lower() + new_var.FQSN = f"{self.none_scope.FQSN}::{new_var.name.lower()}" self.current_scope.add_child(new_var) self.variable_list.append(new_var) if new_var.is_external: @@ -2144,10 +2135,10 @@ def add_int_member(self, key): self.current_scope.add_member(key) def add_private(self, name: str): - self.private_list.append(self.enc_scope_name + "::" + name) + self.private_list.append(f"{self.enc_scope_name}::{name}") def add_public(self, name: str): - self.public_list.append(self.enc_scope_name + "::" + name) + self.public_list.append(f"{self.enc_scope_name}::{name}") def add_use(self, use_mod: Use | Import): if self.current_scope is None: @@ -2158,13 +2149,12 @@ def add_include(self, path: str, line_number: int): self.include_statements.append(IncludeInfo(line_number, path, None, [])) def add_doc(self, doc_string: str, forward: bool = False): - if doc_string == "": + if not doc_string: return if forward: self.pending_doc = doc_string - else: - if self.last_obj is not None: - self.last_obj.add_doc(doc_string) + elif self.last_obj is not None: + self.last_obj.add_doc(doc_string) def add_error(self, msg: str, sev: int, ln: int, sch: int, ech: int = None): """Add a Diagnostic error, encountered during parsing, for a range @@ -2213,26 +2203,27 @@ def get_scopes(self, line_number: int = None): if (line_number >= scope.sline) and (line_number <= scope.eline): if type(scope.parent) == Interface: for use_stmnt in scope.use: - if not type(use_stmnt) == Import: + if type(use_stmnt) != Import: continue # Exclude the parent and all other scopes if use_stmnt.import_type == ImportTypes.NONE: return [scope] scope_list.append(scope) - for ancestor in scope.get_ancestors(): - scope_list.append(ancestor) - if (len(scope_list) == 0) and (self.none_scope is not None): + scope_list.extend(iter(scope.get_ancestors())) + if scope_list or self.none_scope is None: + return scope_list + else: return [self.none_scope] - return scope_list def get_inner_scope(self, line_number: int): scope_sline = -1 curr_scope = None for scope in self.scope_list: - if scope.sline > scope_sline: - if (line_number >= scope.sline) and (line_number <= scope.eline): - curr_scope = scope - scope_sline = scope.sline + if scope.sline > scope_sline and ( + (line_number >= scope.sline) and (line_number <= scope.eline) + ): + curr_scope = scope + scope_sline = scope.sline if (curr_scope is None) and (self.none_scope is not None): return self.none_scope return curr_scope @@ -2271,7 +2262,7 @@ def resolve_includes(self, workspace, path: str = None): file_dir = os.path.dirname(self.path) for inc in self.include_statements: file_path = os.path.normpath(os.path.join(file_dir, inc.path)) - if path and not (path == file_path): + if path and path != file_path: continue parent_scope = self.get_inner_scope(inc.line_number) added_entities = inc.scope_objs