From a97e4eb7e2cc35e4f5cd9c3827744b9852b30465 Mon Sep 17 00:00:00 2001 From: Alexander Senier Date: Tue, 13 Aug 2024 15:12:14 +0000 Subject: [PATCH] Use PEP604 type annotations Ref. eng/recordflux/RecordFlux#1752 --- Makefile | 2 +- poetry.lock | 3 +- rflx/ada.py | 121 ++++++----- rflx/cli.py | 3 +- rflx/converter/iana.py | 13 +- rflx/error.py | 12 +- rflx/expr.py | 194 +++++++++--------- rflx/expr_proof.py | 40 ++-- rflx/fatal_error.py | 8 +- rflx/generator/allocator.py | 11 +- rflx/generator/common.py | 15 +- rflx/generator/generator.py | 11 +- rflx/generator/message.py | 3 +- rflx/generator/serializer.py | 7 +- rflx/generator/session.py | 58 +++--- rflx/graph.py | 3 +- rflx/identifier.py | 1 + rflx/integration.py | 19 +- rflx/ir.py | 164 +++++++-------- rflx/ls/lexer.py | 20 +- rflx/ls/model.py | 5 +- rflx/ls/server.py | 10 +- rflx/model/cache.py | 4 +- rflx/model/declaration.py | 16 +- rflx/model/message.py | 85 ++++---- rflx/model/session.py | 26 +-- rflx/model/statement.py | 17 +- rflx/model/top_level_declaration.py | 4 +- rflx/model/type_decl.py | 22 +- rflx/pyrflx/bitstring.py | 3 +- rflx/pyrflx/package.py | 3 +- rflx/pyrflx/pyrflx.py | 7 +- rflx/pyrflx/typevalue.py | 77 +++---- rflx/rapidflux/__init__.pyi | 17 +- rflx/specification/parser.py | 40 ++-- rflx/specification/style.py | 3 +- rflx/typing_.py | 16 +- rflx/validator.py | 36 ++-- rflx/version.py | 3 +- stubs/pydotplus.pyi | 16 +- stubs/z3.pyi | 13 +- tests/feature/__init__.py | 27 ++- tests/property/strategies.py | 28 ++- tests/tools/check_grammar_test.py | 3 +- .../check_unit_test_file_coverage_test.py | 13 +- tests/unit/cli_test.py | 6 +- tests/unit/generator/session_test.py | 6 +- tests/unit/identifier_test.py | 3 +- tests/unit/ls/server_test.py | 4 +- tests/unit/typing__test.py | 5 +- tests/unit/version_test.py | 4 +- tests/utils.py | 37 ++-- tools/check_doc.py | 19 +- tools/check_requirements.py | 5 +- tools/check_unit_test_file_coverage.py | 5 +- tools/extract_packets.py | 3 +- tools/rflxlexer.py | 4 +- 57 files changed, 646 insertions(+), 657 deletions(-) diff --git a/Makefile b/Makefile index e0821653a..0e7f7edcc 100644 --- a/Makefile +++ b/Makefile @@ -40,7 +40,7 @@ GNATCOLL_DIR = contrib/gnatcoll-bindings LANGKIT_DIR = contrib/langkit ADASAT_DIR = contrib/adasat -DEVUTILS_HEAD = 7d5514cdb8d103dd30d2e23453010ac97cec8c18 +DEVUTILS_HEAD = 11948234a10771fff43eaa6f04e312c8a0777535 GNATCOLL_HEAD = f988b2052d01310b830d63e86e19c8dc77d382a2 LANGKIT_HEAD = 07218ed24a932e747b98416ceae74e1620276ecc ADASAT_HEAD = f20d814a4b26508d0dd6ee48a92bc560b122dad5 diff --git a/poetry.lock b/poetry.lock index 837f5a542..461bb91bb 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1742,6 +1742,7 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -1778,7 +1779,7 @@ files = [ [[package]] name = "RecordFlux-devutils" -version = "0.1.dev72+g7d5514c" +version = "0.1.dev73+g1194823" description = "Linter configs and custom checkers" optional = false python-versions = ">=3.8" diff --git a/rflx/ada.py b/rflx/ada.py index 417a42676..db0c5331e 100644 --- a/rflx/ada.py +++ b/rflx/ada.py @@ -8,7 +8,6 @@ from dataclasses import dataclass, field as dataclass_field from enum import Enum from sys import intern -from typing import Optional, Union from typing_extensions import Self @@ -407,7 +406,7 @@ def rflx_expr(self) -> expr.Variable: class Attribute(Name): - def __init__(self, prefix: Union[StrID, Expr]) -> None: + def __init__(self, prefix: StrID | Expr) -> None: if isinstance(prefix, ID): prefix = Variable(prefix) if isinstance(prefix, str): @@ -496,7 +495,7 @@ def _representation(self) -> str: class AttributeExpr(Attribute): def __init__( self, - prefix: Union[StrID, Expr], + prefix: StrID | Expr, expression: Expr, ) -> None: self.expression = expression @@ -520,7 +519,7 @@ class Succ(AttributeExpr): class BinAttributeExpr(Attribute): - def __init__(self, prefix: Union[StrID, Expr], left: Expr, right: Expr) -> None: + def __init__(self, prefix: StrID | Expr, left: Expr, right: Expr) -> None: self.left = left self.right = right super().__init__(prefix) @@ -539,7 +538,7 @@ class Max(BinAttributeExpr): class NamedAttributeExpr(Attribute): - def __init__(self, prefix: Union[StrID, Expr], *associations: tuple[StrID, Expr]) -> None: + def __init__(self, prefix: StrID | Expr, *associations: tuple[StrID, Expr]) -> None: self.associations = [(ID(n) if isinstance(n, str) else n, e) for n, e in associations] super().__init__(prefix) @@ -602,8 +601,8 @@ class Call(Name): def __init__( self, identifier: StrID, - arguments: Optional[Sequence[Expr]] = None, - named_arguments: Optional[Mapping[ID, Expr]] = None, + arguments: Sequence[Expr] | None = None, + named_arguments: Mapping[ID, Expr] | None = None, ) -> None: self.identifier = ID(identifier) self.arguments = arguments or [] @@ -680,7 +679,7 @@ def rflx_expr(self) -> expr.String: class NamedAggregate(Expr): - def __init__(self, *elements: tuple[Union[StrID, Expr], Expr]) -> None: + def __init__(self, *elements: tuple[StrID | Expr, Expr]) -> None: super().__init__() self.elements = [(ID(n) if isinstance(n, str) else n, e) for n, e in elements] @@ -700,7 +699,7 @@ def precedence(self) -> Precedence: return Precedence.LITERAL def rflx_expr(self) -> expr.NamedAggregate: - elements: list[tuple[Union[ID, expr.Expr], expr.Expr]] = [ + elements: list[tuple[ID | expr.Expr, expr.Expr]] = [ ( n if isinstance(n, ID) else n.rflx_expr(), e.rflx_expr(), @@ -774,7 +773,7 @@ def symbol(self) -> str: def If( # noqa: N802 condition_expressions: Sequence[tuple[Expr, Expr]], - else_expression: Optional[Expr] = None, + else_expression: Expr | None = None, ) -> Expr: if len(condition_expressions) == 0 and else_expression is not None: return else_expression @@ -786,7 +785,7 @@ def If( # noqa: N802 def IfThenElse( # noqa: N802 condition: Expr, then_expr: Expr, - else_expr: Optional[Expr] = None, + else_expr: Expr | None = None, ) -> Expr: return If([(condition, then_expr)], else_expr) @@ -795,7 +794,7 @@ class IfExpr(Expr): def __init__( self, condition_expressions: Sequence[tuple[Expr, Expr]], - else_expression: Optional[Expr] = None, + else_expression: Expr | None = None, ) -> None: super().__init__() self.condition_expressions = condition_expressions @@ -950,7 +949,7 @@ def keyword(self) -> str: class ValueRange(Expr): - def __init__(self, lower: Expr, upper: Expr, type_identifier: Optional[StrID] = None): + def __init__(self, lower: Expr, upper: Expr, type_identifier: StrID | None = None): super().__init__() self.lower = lower self.upper = upper @@ -1010,7 +1009,7 @@ def rflx_expr(self) -> expr.Expr: class Raise(Expr): - def __init__(self, identifier: StrID, string: Optional[Expr] = None) -> None: + def __init__(self, identifier: StrID, string: Expr | None = None) -> None: super().__init__() self.identifier = ID(identifier) self.string = string @@ -1184,7 +1183,7 @@ def input_values(values: Sequence[StrID]) -> str: class AlwaysTerminates(Aspect): - def __init__(self, expression: Optional[Expr] = None) -> None: + def __init__(self, expression: Expr | None = None) -> None: self.expression = expression @property @@ -1430,7 +1429,7 @@ class FormalSubprogramDeclaration(FormalDeclaration): def __init__( self, specification: SubprogramSpecification, - default: Optional[StrID] = None, + default: StrID | None = None, ) -> None: self.specification = specification self.default = ID(default) if default else None @@ -1448,7 +1447,7 @@ def __init__( self, identifier: StrID, generic_identifier: StrID, - associations: Optional[Sequence[StrID]] = None, + associations: Sequence[StrID] | None = None, ) -> None: self.identifier = ID(identifier) self.generic_identifier = ID(generic_identifier) @@ -1468,10 +1467,10 @@ class PackageDeclaration(Declaration): def __init__( self, identifier: StrID, - declarations: Optional[Sequence[Declaration]] = None, - private_declarations: Optional[Sequence[Declaration]] = None, - formal_parameters: Optional[Sequence[FormalDeclaration]] = None, - aspects: Optional[Sequence[Aspect]] = None, + declarations: Sequence[Declaration] | None = None, + private_declarations: Sequence[Declaration] | None = None, + formal_parameters: Sequence[FormalDeclaration] | None = None, + aspects: Sequence[Aspect] | None = None, ) -> None: self.identifier = ID(identifier) self.declarations = declarations or [] @@ -1493,9 +1492,9 @@ class PackageBody(Declaration): def __init__( self, identifier: StrID, - declarations: Optional[Sequence[Declaration]] = None, - statements: Optional[Sequence[Statement]] = None, - aspects: Optional[Sequence[Aspect]] = None, + declarations: Sequence[Declaration] | None = None, + statements: Sequence[Statement] | None = None, + aspects: Sequence[Aspect] | None = None, ) -> None: self.identifier = ID(identifier) self.declarations = declarations or [] @@ -1523,7 +1522,7 @@ def __init__( self, identifier: StrID, generic_package: StrID, - associations: Optional[Sequence[StrID]] = None, + associations: Sequence[StrID] | None = None, ) -> None: self.identifier = ID(identifier) self.generic_package = ID(generic_package) @@ -1555,11 +1554,11 @@ class ObjectDeclaration(Declaration): def __init__( # noqa: PLR0913 self, identifiers: Sequence[StrID], - type_identifier: Union[StrID, Expr], - expression: Optional[Expr] = None, + type_identifier: StrID | Expr, + expression: Expr | None = None, constant: bool = False, aliased: bool = False, - aspects: Optional[Sequence[Aspect]] = None, + aspects: Sequence[Aspect] | None = None, ) -> None: self.identifiers = list(map(ID, identifiers)) self.type_identifier = ( @@ -1589,7 +1588,7 @@ def __init__( self, identifiers: Sequence[StrID], type_identifier: StrID, - default: Optional[Expr] = None, + default: Expr | None = None, ) -> None: self.identifiers = list(map(ID, identifiers)) self.type_identifier = ID(type_identifier) @@ -1605,8 +1604,8 @@ class TypeDeclaration(Declaration, FormalDeclaration): def __init__( self, identifier: StrID, - discriminants: Optional[Sequence[Discriminant]] = None, - aspects: Optional[Sequence[Aspect]] = None, + discriminants: Sequence[Discriminant] | None = None, + aspects: Sequence[Aspect] | None = None, ) -> None: self.identifier = ID(identifier) self.discriminants = discriminants @@ -1640,7 +1639,7 @@ def __init__( self, identifier: StrID, modulus: Expr, - aspects: Optional[Sequence[Aspect]] = None, + aspects: Sequence[Aspect] | None = None, ) -> None: super().__init__(identifier, aspects=aspects or []) self.modulus = modulus @@ -1656,7 +1655,7 @@ def __init__( identifier: StrID, first: Expr, last: Expr, - aspects: Optional[Sequence[Aspect]] = None, + aspects: Sequence[Aspect] | None = None, ) -> None: super().__init__(identifier, aspects=aspects or []) self.first = first @@ -1671,8 +1670,8 @@ class EnumerationType(TypeDeclaration): def __init__( self, identifier: StrID, - literals: Mapping[ID, Optional[Number]], - size: Optional[Expr] = None, + literals: Mapping[ID, Number | None], + size: Expr | None = None, ) -> None: super().__init__(identifier, aspects=([SizeAspect(size)] if size else [])) self.literals = ( @@ -1709,7 +1708,7 @@ def __init__( self, identifier: StrID, base_identifier: StrID, - aspects: Optional[Sequence[Aspect]] = None, + aspects: Sequence[Aspect] | None = None, ) -> None: super().__init__(identifier, aspects=aspects) self.base_identifier = ID(base_identifier) @@ -1738,7 +1737,7 @@ def __init__( self, identifier: StrID, type_identifier: StrID, - record_extension: Optional[Sequence[Component]] = None, + record_extension: Sequence[Component] | None = None, ) -> None: super().__init__(identifier) self.type_identifier = ID(type_identifier) @@ -1805,8 +1804,8 @@ class Component(Base): def __init__( self, identifier: StrID, - type_identifier: Union[StrID, Expr], - default: Optional[Expr] = None, + type_identifier: StrID | Expr, + default: Expr | None = None, aliased: bool = False, ) -> None: self.identifier = ID(identifier) @@ -1856,9 +1855,9 @@ def __init__( # noqa: PLR0913 self, identifier: StrID, components: Sequence[Component], - discriminants: Optional[Sequence[Discriminant]] = None, - variant_part: Optional[VariantPart] = None, - aspects: Optional[Sequence[Aspect]] = None, + discriminants: Sequence[Discriminant] | None = None, + variant_part: VariantPart | None = None, + aspects: Sequence[Aspect] | None = None, abstract: bool = False, tagged: bool = False, limited: bool = False, @@ -1905,7 +1904,7 @@ def __str__(self) -> str: class Assignment(Statement): - def __init__(self, name: Union[StrID, Expr], expression: Expr) -> None: + def __init__(self, name: StrID | Expr, expression: Expr) -> None: self.name = name if isinstance(name, Expr) else Variable(name) self.expression = expression @@ -1917,8 +1916,8 @@ class CallStatement(Statement): def __init__( self, identifier: StrID, - arguments: Optional[Sequence[Expr]] = None, - named_arguments: Optional[Mapping[ID, Expr]] = None, + arguments: Sequence[Expr] | None = None, + named_arguments: Mapping[ID, Expr] | None = None, ) -> None: self.identifier = ID(identifier) self.arguments = arguments or [] @@ -1951,7 +1950,7 @@ def __str__(self) -> str: class ReturnStatement(Statement): - def __init__(self, expression: Optional[Expr] = None) -> None: + def __init__(self, expression: Expr | None = None) -> None: self.expression = expression def __str__(self) -> str: @@ -1963,7 +1962,7 @@ def __str__(self) -> str: class ExitStatement(Statement): - def __init__(self, expression: Optional[Expr] = None) -> None: + def __init__(self, expression: Expr | None = None) -> None: self.expression = expression def __str__(self) -> str: @@ -2003,7 +2002,7 @@ class IfStatement(Statement): def __init__( self, condition_statements: Sequence[tuple[Expr, Sequence[Statement]]], - else_statements: Optional[Sequence[Statement]] = None, + else_statements: Sequence[Statement] | None = None, ) -> None: assert condition_statements or else_statements self.condition_statements = condition_statements @@ -2126,7 +2125,7 @@ def iterator_spec(self) -> str: class RaiseStatement(Statement): - def __init__(self, identifier: StrID, string: Optional[Expr] = None) -> None: + def __init__(self, identifier: StrID, string: Expr | None = None) -> None: super().__init__() self.identifier = ID(identifier) self.string = string @@ -2156,7 +2155,7 @@ def __init__( self, identifiers: Sequence[StrID], type_identifier: StrID, - default: Optional[Expr] = None, + default: Expr | None = None, ) -> None: self.identifiers = list(map(ID, identifiers)) self.type_identifier = ID(type_identifier) @@ -2189,7 +2188,7 @@ def __init__( self, identifiers: Sequence[StrID], type_identifier: StrID, - default: Optional[Expr] = None, + default: Expr | None = None, constant: bool = False, ) -> None: super().__init__(identifiers, type_identifier, default) @@ -2201,7 +2200,7 @@ def mode(self) -> str: class SubprogramSpecification(Base): - def __init__(self, identifier: StrID, parameters: Optional[Sequence[Parameter]] = None) -> None: + def __init__(self, identifier: StrID, parameters: Sequence[Parameter] | None = None) -> None: self.identifier = ID(identifier) self.parameters = parameters or [] @@ -2226,7 +2225,7 @@ def __init__( self, identifier: StrID, return_type: StrID, - parameters: Optional[Sequence[Parameter]] = None, + parameters: Sequence[Parameter] | None = None, ) -> None: super().__init__(identifier, parameters) self.return_type = ID(return_type) @@ -2254,8 +2253,8 @@ class SubprogramDeclaration(Subprogram): def __init__( self, specification: SubprogramSpecification, - aspects: Optional[Sequence[Aspect]] = None, - formal_parameters: Optional[Sequence[FormalDeclaration]] = None, + aspects: Sequence[Aspect] | None = None, + formal_parameters: Sequence[FormalDeclaration] | None = None, abstract: bool = False, ) -> None: super().__init__(specification) @@ -2277,7 +2276,7 @@ def __init__( specification: SubprogramSpecification, declarations: Sequence[Declaration], statements: Sequence[Statement], - aspects: Optional[Sequence[Aspect]] = None, + aspects: Sequence[Aspect] | None = None, ) -> None: super().__init__(specification) self.declarations = declarations or [] @@ -2306,7 +2305,7 @@ def __init__( self, specification: FunctionSpecification, expression: Expr, - aspects: Optional[Sequence[Aspect]] = None, + aspects: Sequence[Aspect] | None = None, ) -> None: super().__init__(specification) self.expression = expression @@ -2322,7 +2321,7 @@ def __init__( self, identifier: StrID, specification: ProcedureSpecification, - associations: Optional[Sequence[StrID]] = None, + associations: Sequence[StrID] | None = None, ) -> None: super().__init__(specification) self.identifier = ID(identifier) @@ -2344,7 +2343,7 @@ def __init__( self, identifier: StrID, specification: FunctionSpecification, - associations: Optional[Sequence[StrID]] = None, + associations: Sequence[StrID] | None = None, ) -> None: super().__init__(specification) self.identifier = ID(identifier) @@ -2375,7 +2374,7 @@ def __str__(self) -> str: class Pragma(Declaration, ContextItem): - def __init__(self, identifier: StrID, parameters: Optional[Sequence[Expr]] = None) -> None: + def __init__(self, identifier: StrID, parameters: Sequence[Expr] | None = None) -> None: super().__init__(identifier) self.pragma_parameters = parameters or [] @@ -2515,7 +2514,7 @@ class SubprogramUnitPart: statements: list[Statement] = dataclass_field(default_factory=list) -def generic_formal_part(parameters: Optional[Sequence[FormalDeclaration]] = None) -> str: +def generic_formal_part(parameters: Sequence[FormalDeclaration] | None = None) -> str: if parameters is None: return "" return ( diff --git a/rflx/cli.py b/rflx/cli.py index 4b68d51e6..2e4e8c298 100644 --- a/rflx/cli.py +++ b/rflx/cli.py @@ -11,7 +11,6 @@ from enum import Enum from multiprocessing import cpu_count from pathlib import Path -from typing import Optional import importlib_resources from importlib_resources.abc import Traversable @@ -538,7 +537,7 @@ def parse( no_caching: bool, no_verification: bool, workers: int = 1, - integration_files_dir: Optional[Path] = None, + integration_files_dir: Path | None = None, ) -> tuple[Model, Integration]: parser = Parser( cache(no_caching, no_verification), diff --git a/rflx/converter/iana.py b/rflx/converter/iana.py index 9fec94211..2a1e2bcb1 100644 --- a/rflx/converter/iana.py +++ b/rflx/converter/iana.py @@ -7,7 +7,6 @@ from collections.abc import Sequence from datetime import datetime from pathlib import Path -from typing import Optional, Union from xml.etree.ElementTree import Element, ParseError from defusedxml import ElementTree @@ -90,8 +89,8 @@ def write_rflx_specification( file: Path, package_entries: list[SpecificationElement], package_name: str, - registry_title: Optional[Element], - registry_last_updated: Optional[Element], + registry_title: Element | None, + registry_last_updated: Element | None, reproducible: bool, ) -> None: package_header: list[CommentBlock] = [] @@ -133,7 +132,7 @@ def resolve_duplicate_literals(enum_types: Sequence[EnumType]) -> None: def _convert_registry_to_enum_type( registry: Element, always_valid: bool, -) -> Optional[EnumType]: +) -> EnumType | None: records = registry.findall("iana:record", NAMESPACE) title = registry.find("iana:title", NAMESPACE) if len(records) == 0 or title is None or title.text is None: @@ -201,7 +200,7 @@ def _convert_registry_to_enum_type( return None -def _get_name_tag(record: Element) -> Optional[str]: +def _get_name_tag(record: Element) -> str | None: sub_elements = record.findall("*", NAMESPACE) child_names = {c.tag[c.tag.index("}") + 1 :] for c in sub_elements} possible_name_tags = ["name", "code", "type", "description"] @@ -281,7 +280,7 @@ def __init__( rflx_name: str, rflx_value: str, bit_length: int, - comments: Optional[list[Element]] = None, + comments: list[Element] | None = None, ): self.name = rflx_name self.value = rflx_value @@ -368,7 +367,7 @@ def wrap_line(line: str) -> str: def _normalize_name(description_text: str) -> str: - t: dict[str, Union[int, str, None]] = {c: " " for c in string.punctuation + "\n"} + t: dict[str, int | str | None] = {c: " " for c in string.punctuation + "\n"} name = description_text.translate(str.maketrans(t)) return "_".join([s[0].upper() + s[1:] for s in name.split()]) diff --git a/rflx/error.py b/rflx/error.py index 93242fd08..3f16f0533 100644 --- a/rflx/error.py +++ b/rflx/error.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import NoReturn, Optional, Sequence +from typing import NoReturn, Sequence from typing_extensions import TypeGuard @@ -16,7 +16,7 @@ def fail( message: str, severity: Severity = Severity.ERROR, - location: Optional[Location] = None, + location: Location | None = None, ) -> NoReturn: raise RecordFluxError([ErrorEntry(message, severity, location)]) @@ -24,26 +24,26 @@ def fail( def fatal_fail( message: str, severity: Severity = Severity.ERROR, - location: Optional[Location] = None, + location: Location | None = None, ) -> NoReturn: raise FatalError(str(RecordFluxError([ErrorEntry(message, severity, location)]))) def warn( message: str, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: RecordFluxError([ErrorEntry(message, Severity.WARNING, location)]).print_messages() def info( message: str, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: RecordFluxError([ErrorEntry(message, Severity.INFO, location)]).print_messages() def are_all_locations_present( - locations: Sequence[Optional[Location]], + locations: Sequence[Location | None], ) -> TypeGuard[Sequence[Location]]: return all(l is not None for l in locations) diff --git a/rflx/expr.py b/rflx/expr.py index 3eb8e7d97..22e2c2c33 100644 --- a/rflx/expr.py +++ b/rflx/expr.py @@ -11,7 +11,7 @@ from operator import itemgetter from pathlib import Path from sys import intern -from typing import TYPE_CHECKING, Final, Optional, Union +from typing import TYPE_CHECKING, Final from rflx import const, typing_ as rty from rflx.common import Base, indent, indent_next, unique @@ -42,7 +42,7 @@ class Expr(Base): def __init__( self, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ): self.type_ = type_ self.location = location @@ -98,7 +98,7 @@ def _check_type_subexpr(self) -> RecordFluxError: """Initialize and check the types of sub-expressions.""" raise NotImplementedError - def check_type(self, expected: Union[rty.Type, tuple[rty.Type, ...]]) -> RecordFluxError: + def check_type(self, expected: rty.Type | tuple[rty.Type, ...]) -> RecordFluxError: """Initialize and check the types of the expression and all sub-expressions.""" error = self._check_type_subexpr() error.extend( @@ -113,7 +113,7 @@ def check_type(self, expected: Union[rty.Type, tuple[rty.Type, ...]]) -> RecordF def check_type_instance( self, - expected: Union[type[rty.Type], tuple[type[rty.Type], ...]], + expected: type[rty.Type] | tuple[type[rty.Type], ...], ) -> RecordFluxError: """Initialize and check the types of the expression and all sub-expressions.""" error = self._check_type_subexpr() @@ -140,8 +140,8 @@ def findall(self, match: Callable[[Expr], bool]) -> Sequence[Expr]: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) return func(self) @@ -157,7 +157,7 @@ def parenthesized(self, expr: Expr) -> str: class Not(Expr): - def __init__(self, expr: Expr, location: Optional[Location] = None) -> None: + def __init__(self, expr: Expr, location: Location | None = None) -> None: super().__init__(rty.BOOLEAN, location) self.expr = expr @@ -185,8 +185,8 @@ def findall(self, match: Callable[[Expr], bool]) -> Sequence[Expr]: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -234,7 +234,7 @@ def __init__( left: Expr, right: Expr, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(type_, location) self.left = left @@ -275,8 +275,8 @@ def findall(self, match: Callable[[Expr], bool]) -> Sequence[Expr]: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -302,7 +302,7 @@ def symbol(self) -> str: class AssExpr(Expr): - def __init__(self, *terms: Expr, location: Optional[Location] = None) -> None: + def __init__(self, *terms: Expr, location: Location | None = None) -> None: super().__init__(rty.UNDEFINED, location=location) self.terms = list(terms) @@ -345,8 +345,8 @@ def findall(self, match: Callable[[Expr], bool]) -> Sequence[Expr]: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -466,7 +466,7 @@ def symbol(self) -> str: class BoolAssExpr(AssExpr): - def __init__(self, *terms: Expr, location: Optional[Location] = None) -> None: + def __init__(self, *terms: Expr, location: Location | None = None) -> None: super().__init__(*terms, location=location) self.type_ = rty.BOOLEAN @@ -572,7 +572,7 @@ def symbol(self) -> str: class Number(Expr): type_: rty.UniversalInteger - def __init__(self, value: int, base: int = 0, location: Optional[Location] = None) -> None: + def __init__(self, value: int, base: int = 0, location: Location | None = None) -> None: super().__init__(rty.UniversalInteger(ty.Bounds(value, value)), location) self.value = value self.base = base @@ -684,7 +684,7 @@ def simplified(self) -> Expr: class Neg(Expr): - def __init__(self, expr: Expr, location: Optional[Location] = None) -> None: + def __init__(self, expr: Expr, location: Location | None = None) -> None: super().__init__(expr.type_, location) self.expr = expr @@ -712,8 +712,8 @@ def findall(self, match: Callable[[Expr], bool]) -> Sequence[Expr]: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -733,7 +733,7 @@ def simplified(self) -> Expr: class MathAssExpr(AssExpr): - def __init__(self, *terms: Expr, location: Optional[Location] = None) -> None: + def __init__(self, *terms: Expr, location: Location | None = None) -> None: super().__init__(*terms, location=location) common_type = rty.common_type([t.type_ for t in terms]) self.type_ = common_type if common_type != rty.UNDEFINED else rty.BASE_INTEGER @@ -815,7 +815,7 @@ def symbol(self) -> str: class MathBinExpr(BinExpr): - def __init__(self, left: Expr, right: Expr, location: Optional[Location] = None) -> None: + def __init__(self, left: Expr, right: Expr, location: Location | None = None) -> None: super().__init__(left, right, rty.common_type([left.type_, right.type_]), location) def _check_type_subexpr(self) -> RecordFluxError: @@ -937,7 +937,7 @@ def __init__( self, immutable: bool = False, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(type_, location) self.immutable = immutable @@ -957,8 +957,8 @@ def representation(self) -> str: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: if self.immutable: return self @@ -974,7 +974,7 @@ def __init__( self, identifier: StrID, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: self.identifier = ID(identifier) super().__init__(immutable=False, type_=type_, location=location) @@ -1002,7 +1002,7 @@ def __init__( self, identifier: StrID, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: self.identifier = ID(identifier) super().__init__(immutable=False, type_=type_, location=location) @@ -1034,9 +1034,9 @@ def variables(self) -> list[Variable]: def copy( self, - identifier: Optional[StrID] = None, - type_: Optional[rty.Type] = None, - location: Optional[Location] = None, + identifier: StrID | None = None, + type_: rty.Type | None = None, + location: Location | None = None, ) -> Literal: return self.__class__( ID(identifier) if identifier is not None else self.identifier, @@ -1051,7 +1051,7 @@ def __init__( identifier: StrID, immutable: bool = False, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: self.identifier = ID(identifier) super().__init__(immutable, type_, location) @@ -1084,10 +1084,10 @@ def variables(self) -> list[Variable]: def copy( self, - identifier: Optional[StrID] = None, - immutable: Optional[bool] = None, - type_: Optional[rty.Type] = None, - location: Optional[Location] = None, + identifier: StrID | None = None, + immutable: bool | None = None, + type_: rty.Type | None = None, + location: Location | None = None, ) -> Variable: return self.__class__( ID(identifier) if identifier is not None else self.identifier, @@ -1110,7 +1110,7 @@ def copy( class Attribute(Name): - def __init__(self, prefix: Union[StrID, Expr]) -> None: + def __init__(self, prefix: StrID | Expr) -> None: if isinstance(prefix, ID): prefix = Variable(prefix, location=prefix.location) if isinstance(prefix, str): @@ -1139,8 +1139,8 @@ def findall(self, match: Callable[[Expr], bool]) -> Sequence[Expr]: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -1158,7 +1158,7 @@ def variables(self) -> list[Variable]: class Size(Attribute): - def __init__(self, prefix: Union[StrID, Expr]) -> None: + def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) self.type_ = rty.UNIVERSAL_INTEGER @@ -1167,7 +1167,7 @@ def _check_type_subexpr(self) -> RecordFluxError: class Length(Attribute): - def __init__(self, prefix: Union[StrID, Expr]) -> None: + def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) self.type_ = rty.UNIVERSAL_INTEGER @@ -1176,7 +1176,7 @@ def _check_type_subexpr(self) -> RecordFluxError: class First(Attribute): - def __init__(self, prefix: Union[StrID, Expr]) -> None: + def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) self.type_ = rty.UNIVERSAL_INTEGER @@ -1185,7 +1185,7 @@ def _check_type_subexpr(self) -> RecordFluxError: class Last(Attribute): - def __init__(self, prefix: Union[StrID, Expr]) -> None: + def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) self.type_ = rty.UNIVERSAL_INTEGER @@ -1194,7 +1194,7 @@ def _check_type_subexpr(self) -> RecordFluxError: class ValidChecksum(Attribute): - def __init__(self, prefix: Union[StrID, Expr]) -> None: + def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) self.type_ = rty.BOOLEAN @@ -1207,7 +1207,7 @@ def representation(self) -> str: class Valid(Attribute): - def __init__(self, prefix: Union[StrID, Expr]) -> None: + def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) self.type_ = rty.BOOLEAN @@ -1218,7 +1218,7 @@ def _check_type_subexpr(self) -> RecordFluxError: class Present(Attribute): - def __init__(self, prefix: Union[StrID, Expr]) -> None: + def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) self.type_ = rty.BOOLEAN @@ -1239,7 +1239,7 @@ def _check_type_subexpr(self) -> RecordFluxError: class HasData(Attribute): - def __init__(self, prefix: Union[StrID, Expr]) -> None: + def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) self.type_ = rty.BOOLEAN @@ -1254,7 +1254,7 @@ def _check_type_subexpr(self) -> RecordFluxError: class Head(Attribute): def __init__( self, - prefix: Union[StrID, Expr], + prefix: StrID | Expr, type_: rty.Type = rty.UNDEFINED, ): super().__init__(prefix) @@ -1277,7 +1277,7 @@ def _check_type_subexpr(self) -> RecordFluxError: class Opaque(Attribute): - def __init__(self, prefix: Union[StrID, Expr]) -> None: + def __init__(self, prefix: StrID | Expr) -> None: super().__init__(prefix) self.type_ = rty.OPAQUE @@ -1297,7 +1297,7 @@ class Val(Attribute): def __init__( self, - prefix: Union[StrID, Expr], + prefix: StrID | Expr, expression: Expr, ) -> None: self.expression = expression @@ -1317,8 +1317,8 @@ def findall(self, match: Callable[[Expr], bool]) -> Sequence[Expr]: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, # noqa: ARG002 - mapping: Optional[Mapping[Name, Expr]] = None, # noqa: ARG002 + func: Callable[[Expr], Expr] | None = None, # noqa: ARG002 + mapping: Mapping[Name, Expr] | None = None, # noqa: ARG002 ) -> Expr: return self @@ -1357,7 +1357,7 @@ def __init__( selector: StrID, immutable: bool = False, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: self.prefix = prefix self.selector = ID(selector) @@ -1407,8 +1407,8 @@ def variables(self) -> list[Variable]: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -1424,11 +1424,11 @@ def substituted( def copy( self, - prefix: Optional[Expr] = None, - selector: Optional[StrID] = None, - immutable: Optional[bool] = None, - type_: Optional[rty.Type] = None, - location: Optional[Location] = None, + prefix: Expr | None = None, + selector: StrID | None = None, + immutable: bool | None = None, + type_: rty.Type | None = None, + location: Location | None = None, ) -> Selected: return self.__class__( prefix if prefix is not None else self.prefix, @@ -1444,10 +1444,10 @@ def __init__( # noqa: PLR0913 self, identifier: StrID, type_: rty.Type, - args: Optional[Sequence[Expr]] = None, + args: Sequence[Expr] | None = None, immutable: bool = False, - argument_types: Optional[Sequence[rty.Type]] = None, - location: Optional[Location] = None, + argument_types: Sequence[rty.Type] | None = None, + location: Location | None = None, ) -> None: self.identifier = ID(identifier) self.args = args or [] @@ -1505,8 +1505,8 @@ def findall(self, match: Callable[[Expr], bool]) -> Sequence[Expr]: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -1557,7 +1557,7 @@ def _check_type_subexpr(self) -> RecordFluxError: class Aggregate(Expr): - def __init__(self, *elements: Expr, location: Optional[Location] = None) -> None: + def __init__(self, *elements: Expr, location: Location | None = None) -> None: super().__init__(rty.Aggregate(rty.common_type([e.type_ for e in elements])), location) self.elements = list(elements) @@ -1587,8 +1587,8 @@ def precedence(self) -> Precedence: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -1608,7 +1608,7 @@ def length(self) -> Expr: class String(Aggregate): - def __init__(self, data: str, location: Optional[Location] = None) -> None: + def __init__(self, data: str, location: Location | None = None) -> None: super().__init__(*[Number(ord(d)) for d in data], location=location) self.data = data @@ -1624,8 +1624,8 @@ def precedence(self) -> Precedence: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) return func(self) @@ -1637,7 +1637,7 @@ def simplified(self) -> Expr: class NamedAggregate(Expr): """Only used by code generator and therefore provides minimum functionality.""" - def __init__(self, *elements: tuple[Union[StrID, Expr], Expr]) -> None: + def __init__(self, *elements: tuple[StrID | Expr, Expr]) -> None: super().__init__() self.elements = [(ID(n) if isinstance(n, str) else n, e) for n, e in elements] @@ -1662,7 +1662,7 @@ def simplified(self) -> Expr: class Relation(BinExpr): - def __init__(self, left: Expr, right: Expr, location: Optional[Location] = None) -> None: + def __init__(self, left: Expr, right: Expr, location: Location | None = None) -> None: super().__init__(left, right, rty.BOOLEAN, location) @abstractmethod @@ -1874,7 +1874,7 @@ class IfExpr(Expr): def __init__( self, condition_expressions: Sequence[tuple[Expr, Expr]], - else_expression: Optional[Expr] = None, + else_expression: Expr | None = None, ) -> None: super().__init__( rty.common_type( @@ -1926,8 +1926,8 @@ def precedence(self) -> Precedence: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -1964,7 +1964,7 @@ def __init__( parameter_identifier: StrID, iterable: Expr, predicate: Expr, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(rty.BOOLEAN, location) self.parameter_identifier = ID(parameter_identifier) @@ -2018,8 +2018,8 @@ def variables(self) -> list[Variable]: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -2080,7 +2080,7 @@ def keyword(self) -> str: class ValueRange(Expr): - def __init__(self, lower: Expr, upper: Expr, location: Optional[Location] = None): + def __init__(self, lower: Expr, upper: Expr, location: Location | None = None): super().__init__(rty.Any(), location) self.lower = lower self.upper = upper @@ -2103,8 +2103,8 @@ def precedence(self) -> Precedence: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -2126,8 +2126,8 @@ def __init__( identifier: StrID, argument: Expr, type_: rty.Type = rty.UNDEFINED, - argument_types: Optional[Sequence[rty.Type]] = None, - location: Optional[Location] = None, + argument_types: Sequence[rty.Type] | None = None, + location: Location | None = None, ) -> None: super().__init__(type_, location) self.identifier = ID(identifier) @@ -2185,8 +2185,8 @@ def precedence(self) -> Precedence: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -2252,7 +2252,7 @@ def __init__( sequence: Expr, selector: Expr, condition: Expr, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(rty.Aggregate(selector.type_), location) self.iterator = ID(iterator) @@ -2300,8 +2300,8 @@ def simplified(self) -> Expr: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -2335,7 +2335,7 @@ def __init__( identifier: StrID, field_values: Mapping[StrID, Expr], type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(type_, location) self.identifier = ID(identifier) @@ -2468,8 +2468,8 @@ def simplified(self) -> Expr: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -2518,7 +2518,7 @@ def _check_for_missing_fields(self) -> RecordFluxError: def substitution( mapping: Mapping[Name, Expr], - func: Optional[Callable[[Expr], Expr]] = None, + func: Callable[[Expr], Expr] | None = None, ) -> Callable[[Expr], Expr]: assert not (mapping and func) if func: @@ -2564,8 +2564,8 @@ class CaseExpr(Expr): def __init__( self, expr: Expr, - choices: Sequence[tuple[Sequence[Union[ID, Number]], Expr]], - location: Optional[Location] = None, + choices: Sequence[tuple[Sequence[ID | Number], Expr]], + location: Location | None = None, ) -> None: super().__init__(rty.common_type([e.type_ for _, e in choices]), location) self.expr = expr @@ -2795,8 +2795,8 @@ def simplified(self) -> Expr: def substituted( self, - func: Optional[Callable[[Expr], Expr]] = None, - mapping: Optional[Mapping[Name, Expr]] = None, + func: Callable[[Expr], Expr] | None = None, + mapping: Mapping[Name, Expr] | None = None, ) -> Expr: func = substitution(mapping or {}, func) expr = func(self) @@ -2840,7 +2840,7 @@ def similar_fields( def _similar_field_names( field: ID, fields: Iterable[ID], - location: Optional[Location], + location: Location | None, ) -> list[ErrorEntry]: similar_flds = similar_fields(field, fields) if similar_flds: diff --git a/rflx/expr_proof.py b/rflx/expr_proof.py index d2b673033..e42c63d05 100644 --- a/rflx/expr_proof.py +++ b/rflx/expr_proof.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from enum import Enum from functools import singledispatch -from typing import Final, Optional, Union +from typing import Final, Union import z3 @@ -19,6 +19,10 @@ PROVER_TIMEOUT: Final = 1800000 +# TODO(eng/recordflux/RecordFlux#1424): Replace with PEP604 union +ArithBoolRef = Union[z3.ArithRef, z3.BoolRef] + + class ProofResult(Enum): SAT = z3.sat UNSAT = z3.unsat @@ -29,14 +33,14 @@ class Proof: def __init__( self, expr: expr.Expr, - facts: Optional[Sequence[expr.Expr]] = None, + facts: Sequence[expr.Expr] | None = None, logic: str = "QF_NIA", ): self._expr = expr self._facts = facts or [] self._result = ProofResult.UNSAT self._logic = logic - self._unknown_reason: Optional[str] = None + self._unknown_reason: str | None = None solver = z3.SolverFor(self._logic) solver.set("timeout", PROVER_TIMEOUT) @@ -53,7 +57,7 @@ def result(self) -> ProofResult: return self._result @property - def error(self) -> list[tuple[str, Optional[Location]]]: + def error(self) -> list[tuple[str, Location | None]]: assert self._result != ProofResult.SAT if self._result == ProofResult.UNKNOWN: @@ -343,8 +347,8 @@ def _(expression: expr.Relation) -> z3.BoolRef: @singledispatch def _relation_operator( _: expr.Relation, - left: Union[z3.ArithRef, z3.BoolRef], # noqa: ARG001 - right: Union[z3.ArithRef, z3.BoolRef], # noqa: ARG001 + left: ArithBoolRef, # noqa: ARG001 + right: ArithBoolRef, # noqa: ARG001 ) -> object: raise NotImplementedError @@ -352,8 +356,8 @@ def _relation_operator( @_relation_operator.register def _( _: expr.Less, - left: Union[z3.ArithRef, z3.BoolRef], - right: Union[z3.ArithRef, z3.BoolRef], + left: ArithBoolRef, + right: ArithBoolRef, ) -> object: return operator.lt(left, right) @@ -361,8 +365,8 @@ def _( @_relation_operator.register def _( _: expr.LessEqual, - left: Union[z3.ArithRef, z3.BoolRef], - right: Union[z3.ArithRef, z3.BoolRef], + left: ArithBoolRef, + right: ArithBoolRef, ) -> object: return operator.le(left, right) @@ -370,8 +374,8 @@ def _( @_relation_operator.register def _( _: expr.Equal, - left: Union[z3.ArithRef, z3.BoolRef], - right: Union[z3.ArithRef, z3.BoolRef], + left: ArithBoolRef, + right: ArithBoolRef, ) -> object: return operator.eq(left, right) @@ -379,8 +383,8 @@ def _( @_relation_operator.register def _( _: expr.GreaterEqual, - left: Union[z3.ArithRef, z3.BoolRef], - right: Union[z3.ArithRef, z3.BoolRef], + left: ArithBoolRef, + right: ArithBoolRef, ) -> object: return operator.ge(left, right) @@ -388,8 +392,8 @@ def _( @_relation_operator.register def _( _: expr.Greater, - left: Union[z3.ArithRef, z3.BoolRef], - right: Union[z3.ArithRef, z3.BoolRef], + left: ArithBoolRef, + right: ArithBoolRef, ) -> object: return operator.gt(left, right) @@ -397,8 +401,8 @@ def _( @_relation_operator.register def _( _: expr.NotEqual, - left: Union[z3.ArithRef, z3.BoolRef], - right: Union[z3.ArithRef, z3.BoolRef], + left: ArithBoolRef, + right: ArithBoolRef, ) -> object: return operator.ne(left, right) diff --git a/rflx/fatal_error.py b/rflx/fatal_error.py index 0f8a99c9b..809f2308d 100644 --- a/rflx/fatal_error.py +++ b/rflx/fatal_error.py @@ -3,7 +3,7 @@ import sys import traceback import types -from typing import Callable, Final, Optional +from typing import Callable, Final from rflx.version import is_gnat_tracker_release, version @@ -32,9 +32,9 @@ def __enter__(self) -> None: def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - tb: Optional[types.TracebackType], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + tb: types.TracebackType | None, ) -> None: if exc_type is not None: self._output_func(fatal_error_message(self._unsafe)) diff --git a/rflx/generator/allocator.py b/rflx/generator/allocator.py index 160843f7d..6f76df1cc 100644 --- a/rflx/generator/allocator.py +++ b/rflx/generator/allocator.py @@ -3,7 +3,6 @@ from collections.abc import Sequence from dataclasses import dataclass from itertools import zip_longest -from typing import Optional from rflx import ir, typing_ as rty from rflx.ada import ( @@ -124,15 +123,15 @@ def get_global_slot_ptrs(self) -> list[ID]: def get_local_slot_ptrs(self) -> list[ID]: return [self._slot_name(s.slot_id) for s in self._numbered_slots if not s.global_] - def get_slot_ptr(self, location: Optional[Location]) -> ID: + def get_slot_ptr(self, location: Location | None) -> ID: assert location is not None slot_id: int = self._allocation_slots[location] return self._slot_name(slot_id) - def get_size(self, variable: Optional[ID] = None, state: Optional[ID] = None) -> int: + def get_size(self, variable: ID | None = None, state: ID | None = None) -> int: return self._integration.get_size(self._session.identifier, variable, state) - def is_externally_managed(self, location: Optional[Location]) -> bool: + def is_externally_managed(self, location: Location | None) -> bool: assert location is not None return location in self._externally_managed_buffers @@ -370,7 +369,7 @@ def _allocate_global_slots( return (slots, externally_managed_buffers) @staticmethod - def _scope(state: ir.State, var_id: ID) -> Optional[ID]: + def _scope(state: ir.State, var_id: ID) -> ID | None: """ Return the scope of the variable var_id. @@ -397,7 +396,7 @@ def _allocate_local_slots( @dataclass class AllocationRequirement: - location: Optional[Location] + location: Location | None size: int def determine_allocation_requirements( diff --git a/rflx/generator/common.py b/rflx/generator/common.py index 64e66ae9c..b9af75f66 100644 --- a/rflx/generator/common.py +++ b/rflx/generator/common.py @@ -5,7 +5,6 @@ import textwrap from collections.abc import Callable from dataclasses import dataclass -from typing import Optional from rflx import expr, expr_conv, ir, model, typing_ as rty from rflx.ada import ( @@ -674,7 +673,7 @@ def context_cursor_unchanged( def sufficient_space_for_field_condition( message_id: ID, field_name: Name, - size: Optional[Expr] = None, + size: Expr | None = None, ) -> Expr: if size is None: size = Call(message_id * "Field_Size", [Variable("Ctx"), field_name]) @@ -757,9 +756,9 @@ def field_condition_call( prefix: str, message: model.Message, field: model.Field, - value: Optional[Expr] = None, - aggregate: Optional[Expr] = None, - size: Optional[Expr] = None, + value: Expr | None = None, + aggregate: Expr | None = None, + size: Expr | None = None, ) -> Expr: package = prefix * message.identifier if value is None: @@ -823,7 +822,7 @@ def contains_function_name(refinement_package: ID, pdu: ID, sdu: ID, field: ID) def has_value_dependent_condition( message: model.Message, - field: Optional[model.Field] = None, + field: model.Field | None = None, ) -> bool: links = message.outgoing(field) if field else message.structure fields = [field] if field else message.fields @@ -842,7 +841,7 @@ def has_value_dependent_condition( def has_aggregate_dependent_condition( message: model.Message, - field: Optional[model.Field] = None, + field: model.Field | None = None, ) -> bool: links = message.outgoing(field) if field else message.structure fields = [field] if field else message.fields @@ -861,7 +860,7 @@ def has_aggregate_dependent_condition( def has_size_dependent_condition( message: model.Message, - field: Optional[model.Field] = None, + field: model.Field | None = None, ) -> bool: field_sizes = {expr.Size(f.name) for f in message.fields} links = message.outgoing(field) if field else message.structure diff --git a/rflx/generator/generator.py b/rflx/generator/generator.py index 6c8d10316..a4eab6697 100644 --- a/rflx/generator/generator.py +++ b/rflx/generator/generator.py @@ -6,7 +6,6 @@ from datetime import date from functools import cached_property from pathlib import Path -from typing import Optional from rflx import __version__, expr, expr_conv, typing_ as rty from rflx.ada import ( @@ -367,11 +366,11 @@ def _create_session(self, session: Session, integration: Integration) -> dict[ID def _create_unit( # noqa: PLR0913 prefix: str, identifier: ID, - declaration_context: Optional[abc.Sequence[ContextItem]] = None, - body_context: Optional[abc.Sequence[ContextItem]] = None, - formal_parameters: Optional[list[FormalDeclaration]] = None, - configuration_pragmas: Optional[abc.Sequence[Pragma]] = None, - aspects: Optional[abc.Sequence[Aspect]] = None, + declaration_context: abc.Sequence[ContextItem] | None = None, + body_context: abc.Sequence[ContextItem] | None = None, + formal_parameters: list[FormalDeclaration] | None = None, + configuration_pragmas: abc.Sequence[Pragma] | None = None, + aspects: abc.Sequence[Aspect] | None = None, terminating: bool = True, ) -> PackageUnit: declaration_context = declaration_context if declaration_context else [] diff --git a/rflx/generator/message.py b/rflx/generator/message.py index 955f48a99..da239cb59 100644 --- a/rflx/generator/message.py +++ b/rflx/generator/message.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections import abc -from typing import Union from rflx import expr, expr_conv, typing_ as rty from rflx.ada import ( @@ -3417,7 +3416,7 @@ def _create_structure_type(prefix: str, message: Message) -> UnitPart: type_ = message.field_types[link.target] - component_type: Union[ID, Expr] + component_type: ID | Expr if isinstance(type_, Scalar): component_type = common.prefixed_type_identifier(type_.identifier, prefix) diff --git a/rflx/generator/serializer.py b/rflx/generator/serializer.py index 161ebe1a0..f9ee1e2a0 100644 --- a/rflx/generator/serializer.py +++ b/rflx/generator/serializer.py @@ -2,7 +2,6 @@ from collections.abc import Mapping from enum import Enum -from typing import Optional from rflx import expr, expr_conv, typing_ as rty from rflx.ada import ( @@ -1720,7 +1719,7 @@ def composite_setter_preconditions( self, message: Message, field: Field, - size: Optional[Expr] = None, + size: Expr | None = None, ) -> list[Expr]: return [ common.sufficient_space_for_field_condition( @@ -1767,8 +1766,8 @@ def scalar_setter_and_getter_relation( @staticmethod def _update_last( - message: Optional[Message] = None, - field: Optional[Field] = None, + message: Message | None = None, + field: Field | None = None, ) -> list[Statement]: assert (message and field) or not (message or field) last = ( diff --git a/rflx/generator/session.py b/rflx/generator/session.py index e77026f80..fe4a296dc 100644 --- a/rflx/generator/session.py +++ b/rflx/generator/session.py @@ -4,7 +4,7 @@ from collections.abc import Callable, Iterable, Mapping, Sequence from dataclasses import dataclass, field as dataclass_field from functools import partial, singledispatchmethod -from typing import NoReturn, Optional, Union +from typing import NoReturn from typing_extensions import Self @@ -782,7 +782,7 @@ def _create_external_buffer_type( def _create_context_type( self, initial_state: ID, - global_variables: Mapping[ID, tuple[ID, Optional[Expr]]], + global_variables: Mapping[ID, tuple[ID, Expr | None]], has_functions: bool, ) -> UnitPart: return UnitPart( @@ -2890,7 +2890,7 @@ def _create_write_procedure( def _evaluate_declarations( self, declarations: Iterable[ir.VarDecl], - is_global: Optional[Callable[[ID], bool]] = None, + is_global: Callable[[ID], bool] | None = None, session_global: bool = False, ) -> EvaluatedDeclaration: if session_global: @@ -3002,8 +3002,8 @@ def _declare( # noqa: PLR0912, PLR0913 identifier: ID, type_: rty.Type, is_global: Callable[[ID], bool], - alloc_id: Optional[Location], - expression: Optional[ir.ComplexExpr] = None, + alloc_id: Location | None, + expression: ir.ComplexExpr | None = None, constant: bool = False, session_global: bool = False, ) -> EvaluatedDeclaration: @@ -3167,7 +3167,7 @@ def _assign( # noqa: PLR0913 exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], state: ID, - alloc_id: Optional[Location], + alloc_id: Location | None, ) -> Sequence[Statement]: if isinstance(expression, ir.DeltaMsgAgg): return self._assign_to_delta_message_aggregate( @@ -3430,7 +3430,7 @@ def _assign_to_head( # noqa: PLR0913 exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], state: ID, - alloc_id: Optional[Location], + alloc_id: Location | None, ) -> Sequence[Statement]: if not isinstance(head.type_, (rty.Integer, rty.Enumeration, rty.Message)): fatal_fail( @@ -3455,7 +3455,7 @@ def _assign_to_find( # noqa: PLR0913 exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], state: ID, - alloc_id: Optional[Location], + alloc_id: Location | None, ) -> Sequence[Statement]: assert isinstance(find.sequence.type_, rty.Sequence) sequence_type_id = find.sequence.type_.identifier @@ -3571,7 +3571,7 @@ def _assign_to_head_sequence( # noqa: PLR0913 exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], state: ID, - alloc_id: Optional[Location], + alloc_id: Location | None, ) -> Sequence[Statement]: assert isinstance(head.prefix_type, rty.Sequence) assert isinstance(head.type_, (rty.Integer, rty.Enumeration, rty.Message)) @@ -3765,7 +3765,7 @@ def _assign_to_comprehension( # noqa: PLR0913 exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], state: ID, - alloc_id: Optional[Location], + alloc_id: Location | None, ) -> Sequence[Statement]: assert isinstance(comprehension.type_, (rty.Sequence, rty.Aggregate)) assert isinstance(comprehension.sequence.type_, rty.Sequence) @@ -4332,7 +4332,7 @@ def _append( def check( sequence_type: ID, required_space: Expr, - precondition: Optional[Expr] = None, + precondition: Expr | None = None, ) -> list[Statement]: return [ *( @@ -4478,7 +4478,7 @@ def _write( def _check( self, expression: ir.BoolExpr, - origin: Optional[ir.Origin], + origin: ir.Origin | None, exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], ) -> Sequence[Statement]: @@ -4852,7 +4852,7 @@ def _(self, expression: ir.Agg, is_global: ty.Callable[[ID], bool]) -> Expr: @_to_ada_expr.register def _(self, expression: ir.NamedAgg, is_global: ty.Callable[[ID], bool]) -> Expr: - elements: list[tuple[Union[ID, ada.Expr], ada.Expr]] = [ + elements: list[tuple[ID | ada.Expr, ada.Expr]] = [ ( n if isinstance(n, ID) else self._to_ada_expr(n, is_global), self._to_ada_expr(e, is_global), @@ -5225,8 +5225,8 @@ def _set_opaque_field( # noqa: PLR0913 get_statements: Sequence[Statement], length: Expr, exception_handler: ExceptionHandler, - pre_declarations: Optional[Sequence[Declaration]] = None, - post_statements: Optional[Sequence[Statement]] = None, + pre_declarations: Sequence[Declaration] | None = None, + post_statements: Sequence[Statement] | None = None, ) -> Declare: pre_declarations = pre_declarations if pre_declarations else [] post_statements = post_statements if post_statements else [] @@ -5566,7 +5566,7 @@ def _declare_sequence_copy( # noqa: PLR0913 statements: Callable[[ExceptionHandler], Sequence[Statement]], exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], - alloc_id: Optional[Location], + alloc_id: Location | None, ) -> list[Statement]: # Eng/RecordFlux/RecordFlux#577 sequence_context = context_id(sequence_identifier, is_global) @@ -5621,7 +5621,7 @@ def _declare_message_field_sequence_copy( # noqa: PLR0913 target_buffer_is_smaller: bool, exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], - alloc_id: Optional[Location], + alloc_id: Location | None, ) -> Declare: # Eng/RecordFlux/RecordFlux#577 take_buffer = self._take_buffer(sequence_identifier, sequence_type, is_global) @@ -5678,7 +5678,7 @@ def _comprehension( # noqa: PLR0913 sequence_identifier: ID, sequence_type: ID, target_identifier: ID, - target_type: Union[rty.Sequence, rty.Integer, rty.Enumeration, rty.Message], + target_type: rty.Sequence | rty.Integer | rty.Enumeration | rty.Message, iterator_identifier: ID, iterator_type: ID, selector_stmts: list[ir.Stmt], @@ -5688,7 +5688,7 @@ def _comprehension( # noqa: PLR0913 exception_handler: ExceptionHandler, is_global: Callable[[ID], bool], state: ID, - alloc_id: Optional[Location], + alloc_id: Location | None, ) -> While: assert not isinstance(selector, ir.MsgAgg) @@ -5897,7 +5897,7 @@ def _comprehension( # noqa: PLR0913 def _comprehension_assign_element( # noqa: PLR0913 self, target_identifier: ID, - target_type: Union[rty.Integer, rty.Enumeration, rty.Message], + target_type: rty.Integer | rty.Enumeration | rty.Message, selector_stmts: list[ir.Stmt], selector: ir.Expr, update_context: Sequence[Statement], @@ -6086,7 +6086,7 @@ def _free_context_buffer( identifier: ID, type_: ID, is_global: Callable[[ID], bool], - alloc_id: Optional[Location], + alloc_id: Location | None, ) -> Sequence[Statement]: if self._allocator.is_externally_managed(alloc_id): return self._take_buffer(identifier, type_, is_global) @@ -6096,7 +6096,7 @@ def _free_context_buffer( *self._free_buffer(identifier, alloc_id), ] - def _free_buffer(self, identifier: ID, alloc_id: Optional[Location]) -> Sequence[Statement]: + def _free_buffer(self, identifier: ID, alloc_id: Location | None) -> Sequence[Statement]: slot = Variable("Ctx.P.Slots" * self._allocator.get_slot_ptr(alloc_id)) return [ PragmaStatement("Assert", [Equal(slot, Variable("null"))]), @@ -6116,7 +6116,7 @@ def _take_buffer( identifier: ID, type_: ID, is_global: Callable[[ID], bool], - buf: Optional[ID] = None, + buf: ID | None = None, ) -> Sequence[Statement]: context = context_id(identifier, is_global) buf = buf or buffer_id(identifier) @@ -6186,7 +6186,7 @@ def _update_context( ), ] - def _allocate_buffer(self, identifier: ID, alloc_id: Optional[Location]) -> Sequence[Statement]: + def _allocate_buffer(self, identifier: ID, alloc_id: Location | None) -> Sequence[Statement]: if self._allocator.is_externally_managed(alloc_id): return [] @@ -6204,11 +6204,11 @@ def _initialize_context( # noqa: PLR0913 identifier: ID, type_: ID, is_global: Callable[[ID], bool], - first: Optional[Expr] = None, - last: Optional[Expr] = None, - parameters: Optional[Mapping[ID, Expr]] = None, - written_last: Optional[Expr] = None, - buffer: Optional[Expr] = None, + first: Expr | None = None, + last: Expr | None = None, + parameters: Mapping[ID, Expr] | None = None, + written_last: Expr | None = None, + buffer: Expr | None = None, ) -> CallStatement: return CallStatement( type_ * "Initialize", diff --git a/rflx/graph.py b/rflx/graph.py index 21f54d736..0bb74f5bc 100644 --- a/rflx/graph.py +++ b/rflx/graph.py @@ -3,7 +3,6 @@ import re from collections.abc import Sequence from pathlib import Path -from typing import Optional from pydotplus import Dot, Edge, InvocationException, Node # type: ignore[attr-defined] @@ -100,7 +99,7 @@ def _edge_label(link: Link) -> str: return result -def create_session_graph(session: Session, ignore: Optional[Sequence[str]] = None) -> Dot: +def create_session_graph(session: Session, ignore: Sequence[str] | None = None) -> Dot: """ Return pydot graph representation of session. diff --git a/rflx/identifier.py b/rflx/identifier.py index 99b6d8827..84286ed5f 100644 --- a/rflx/identifier.py +++ b/rflx/identifier.py @@ -5,6 +5,7 @@ from rflx.rapidflux import ID as ID +# TODO(eng/recordflux/RecordFlux#1424): Replace with PEP604 union StrID = Union[str, ID] ID_PREFIX: Final = "T_" diff --git a/rflx/integration.py b/rflx/integration.py index 4d84cc9fc..787a3987a 100644 --- a/rflx/integration.py +++ b/rflx/integration.py @@ -24,19 +24,26 @@ # This is only relevant for Python 3.8. +# TODO(eng/recordflux/RecordFlux#1424): Replace remaining use of Optional +# and Union. Pydantic has issues with PEP604 type annotations in Python +# 3.8 and 3.9. + IntSize = Annotated[int, Gt(0)] class SessionSize(BaseModel): # type: ignore[misc] - default: Optional[IntSize] = Field(None, alias="Default") - global_: Optional[ty.Mapping[str, IntSize]] = Field(None, alias="Global") - local_: Optional[ty.Mapping[str, ty.Mapping[str, IntSize]]] = Field(None, alias="Local") + default: Optional[IntSize] = Field(None, alias="Default") # noqa: UP007 + global_: Optional[ty.Mapping[str, IntSize]] = Field(None, alias="Global") # noqa: UP007 + local_: Optional[ty.Mapping[str, ty.Mapping[str, IntSize]]] = Field( # noqa: UP007 + None, + alias="Local", + ) model_config = ConfigDict(extra="forbid") class SessionIntegration(BaseModel): # type: ignore[misc] - buffer_size: Optional[SessionSize] = Field(alias="Buffer_Size", default=None) + buffer_size: Optional[SessionSize] = Field(alias="Buffer_Size", default=None) # noqa: UP007 external_io_buffers: bool = Field(alias="External_IO_Buffers", default=False) model_config = ConfigDict(extra="forbid") @@ -53,7 +60,7 @@ class Integration: def defaultsize(self) -> int: return 4096 - def __init__(self, integration_files_dir: Optional[Path] = None) -> None: + def __init__(self, integration_files_dir: Path | None = None) -> None: self._packages: dict[str, IntegrationFile] = {} self._integration_files_dir = integration_files_dir @@ -102,7 +109,7 @@ def validate(self, model: Model, error: RecordFluxError) -> None: self._validate_globals(package, integration, session, error) self._validate_states(package, integration, session, error) - def get_size(self, session: ID, variable: Optional[ID], state: Optional[ID]) -> int: + def get_size(self, session: ID, variable: ID | None, state: ID | None) -> int: """ Return the requested buffer size for a variable of a given session and state. diff --git a/rflx/ir.py b/rflx/ir.py index 5bade88f4..4b944c2b1 100644 --- a/rflx/ir.py +++ b/rflx/ir.py @@ -8,7 +8,7 @@ from concurrent.futures import ProcessPoolExecutor from enum import Enum from sys import intern -from typing import TYPE_CHECKING, Optional, Protocol, TypeVar, Union +from typing import TYPE_CHECKING, Protocol, TypeVar import z3 from attr import define, field, frozen @@ -31,11 +31,11 @@ class Origin(Protocol): def __str__(self) -> str: ... # pragma: no cover @property - def location(self) -> Optional[Location]: ... # pragma: no cover + def location(self) -> Location | None: ... # pragma: no cover class ConstructedOrigin(Origin): - def __init__(self, string_representation: str, location: Optional[Location]) -> None: + def __init__(self, string_representation: str, location: Location | None) -> None: self._string_representation = string_representation self._location = location @@ -43,7 +43,7 @@ def __str__(self) -> str: return self._string_representation @property - def location(self) -> Optional[Location]: + def location(self) -> Location | None: return self._location @@ -105,7 +105,7 @@ def _check(job: ProofJob) -> ProofJob: class Cond(Base): - def __init__(self, goal: BoolExpr, facts: Optional[Sequence[Stmt]] = None) -> None: + def __init__(self, goal: BoolExpr, facts: Sequence[Stmt] | None = None) -> None: self._goal = goal self._facts = facts or [] @@ -122,8 +122,8 @@ def facts(self) -> Sequence[Stmt]: class Stmt(Base): """Statement in three-address code (TAC) format.""" - origin: Optional[Origin] - _str: Optional[str] = field(init=False, default=None) + origin: Origin | None + _str: str | None = field(init=False, default=None) def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): @@ -139,7 +139,7 @@ def __str__(self) -> str: return self._str # type: ignore[unreachable] @property - def location(self) -> Optional[Location]: + def location(self) -> Location | None: return self.origin.location if self.origin else None @property @@ -170,8 +170,8 @@ def _update_str(self) -> None: class VarDecl(Stmt): identifier: ID = field(converter=ID) type_: rty.NamedType - expression: Optional[ComplexExpr] = None - origin: Optional[Origin] = None + expression: ComplexExpr | None = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -202,7 +202,7 @@ class Assign(Stmt): target: ID = field(converter=ID) expression: Expr type_: rty.NamedType - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -252,7 +252,7 @@ class FieldAssign(Stmt): field: ID = field(converter=ID) expression: Expr type_: rty.Message - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -293,7 +293,7 @@ class Append(Stmt): sequence: ID = field(converter=ID) expression: Expr type_: rty.Sequence - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -326,7 +326,7 @@ class Extend(Stmt): sequence: ID = field(converter=ID) expression: Expr type_: rty.Sequence - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -355,7 +355,7 @@ class Reset(Stmt): identifier: ID = field(converter=ID) parameter_values: Mapping[ID, Expr] type_: rty.Any - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -389,7 +389,7 @@ def _update_str(self) -> None: class ChannelStmt(Stmt): channel: ID = field(converter=ID) expression: BasicExpr - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -425,10 +425,10 @@ class Write(ChannelStmt): @define(eq=False) class Check(Stmt): expression: BoolExpr - origin: Optional[Origin] = None + origin: Origin | None = None @property - def location(self) -> Optional[Location]: + def location(self) -> Location | None: return self.expression.location @property @@ -458,8 +458,8 @@ def _update_str(self) -> None: class Expr(Base): """Expression in three-address code (TAC) format.""" - origin: Optional[Origin] - _str: Optional[str] = field(init=False, default=None) + origin: Origin | None + _str: str | None = field(init=False, default=None) def __eq__(self, other: object) -> bool: if isinstance(other, self.__class__): @@ -486,7 +486,7 @@ def origin_str(self) -> str: return str(self) @property - def location(self) -> Optional[Location]: + def location(self) -> Location | None: return self.origin.location if self.origin else None @property @@ -548,7 +548,7 @@ class BasicBoolExpr(BasicExpr, BoolExpr): @define(eq=False) class Var(BasicExpr): identifier: ID = field(converter=ID) - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -562,7 +562,7 @@ def _update_str(self) -> None: class IntVar(Var, BasicIntExpr): identifier: ID = field(converter=ID) var_type: rty.AnyInteger - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.AnyInteger: @@ -580,7 +580,7 @@ def to_z3_expr(self) -> z3.ArithRef: @define(eq=False) class BoolVar(Var, BasicBoolExpr): identifier: ID = field(converter=ID) - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.Enumeration: @@ -599,7 +599,7 @@ def to_z3_expr(self) -> z3.BoolRef: class ObjVar(Var): identifier: ID = field(converter=ID) var_type: rty.Any - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.Any: @@ -618,7 +618,7 @@ def to_z3_expr(self) -> z3.ExprRef: class EnumLit(BasicExpr): identifier: ID = field(converter=ID) enum_type: rty.Enumeration - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.Enumeration: @@ -641,7 +641,7 @@ def _update_str(self) -> None: @define(eq=False) class IntVal(BasicIntExpr): value: int - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.UniversalInteger: @@ -664,7 +664,7 @@ def _update_str(self) -> None: @define(eq=False) class BoolVal(BasicBoolExpr): value: bool - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -684,7 +684,7 @@ def _update_str(self) -> None: class Attr(Expr): prefix: ID = field(converter=ID) prefix_type: rty.Any - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -706,7 +706,7 @@ def _update_str(self) -> None: class IntAttr(Attr, IntExpr): prefix: ID = field(converter=ID) prefix_type: rty.Any - origin: Optional[Origin] = None + origin: Origin | None = None def substituted(self, mapping: Mapping[ID, ID]) -> IntAttr: return self.__class__( @@ -783,7 +783,7 @@ def to_z3_expr(self) -> z3.BoolRef: class Head(Attr): prefix: ID = field(converter=ID) prefix_type: rty.Composite - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.Any: @@ -797,8 +797,8 @@ def to_z3_expr(self) -> z3.ExprRef: @define(eq=False) class Opaque(Attr): prefix: ID = field(converter=ID) - prefix_type: Union[rty.Message, rty.Sequence] - origin: Optional[Origin] = None + prefix_type: rty.Message | rty.Sequence + origin: Origin | None = None @property def type_(self) -> rty.Sequence: @@ -818,7 +818,7 @@ class FieldAccessAttr(Expr): message: ID = field(converter=ID) field: ID = field(converter=ID) message_type: rty.Compound - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -902,7 +902,7 @@ def _symbol(self) -> str: @define(eq=False) class UnaryExpr(Expr): expression: BasicExpr - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -915,7 +915,7 @@ def substituted(self, mapping: Mapping[ID, ID]) -> UnaryExpr: @define(eq=False) class UnaryIntExpr(UnaryExpr, IntExpr): expression: BasicIntExpr - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.AnyInteger: @@ -925,14 +925,14 @@ def type_(self) -> rty.AnyInteger: @define(eq=False) class UnaryBoolExpr(UnaryExpr, BoolExpr): expression: BasicBoolExpr - origin: Optional[Origin] = None + origin: Origin | None = None @define(eq=False) class BinaryExpr(Expr): left: BasicExpr right: BasicExpr - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -952,7 +952,7 @@ def origin_str(self) -> str: return f"{self.left.origin_str}{self._symbol}{self.right.origin_str}" @property - def location(self) -> Optional[Location]: + def location(self) -> Location | None: if self.origin is not None: return self.origin.location if self.left.origin is not None: @@ -971,7 +971,7 @@ def _symbol(self) -> str: class BinaryIntExpr(BinaryExpr, IntExpr): left: BasicIntExpr right: BasicIntExpr - origin: Optional[Origin] = None + origin: Origin | None = None @abstractmethod def preconditions( @@ -992,7 +992,7 @@ def type_(self) -> rty.AnyInteger: class BinaryBoolExpr(BinaryExpr, BoolExpr): left: BasicBoolExpr right: BasicBoolExpr - origin: Optional[Origin] = None + origin: Origin | None = None @define(eq=False) @@ -1258,14 +1258,14 @@ def _symbol(self) -> str: class Relation(BoolExpr, BinaryExpr): left: BasicExpr right: BasicExpr - origin: Optional[Origin] = None + origin: Origin | None = None @define(eq=False) class Less(Relation): left: BasicIntExpr right: BasicIntExpr - origin: Optional[Origin] = None + origin: Origin | None = None def to_z3_expr(self) -> z3.BoolRef: return self.left.to_z3_expr() < self.right.to_z3_expr() @@ -1279,7 +1279,7 @@ def _symbol(self) -> str: class LessEqual(Relation): left: BasicIntExpr right: BasicIntExpr - origin: Optional[Origin] = None + origin: Origin | None = None def to_z3_expr(self) -> z3.BoolRef: return self.left.to_z3_expr() <= self.right.to_z3_expr() @@ -1303,7 +1303,7 @@ def _symbol(self) -> str: class GreaterEqual(Relation): left: BasicIntExpr right: BasicIntExpr - origin: Optional[Origin] = None + origin: Origin | None = None def to_z3_expr(self) -> z3.BoolRef: return self.left.to_z3_expr() >= self.right.to_z3_expr() @@ -1317,7 +1317,7 @@ def _symbol(self) -> str: class Greater(Relation): left: BasicIntExpr right: BasicIntExpr - origin: Optional[Origin] = None + origin: Origin | None = None def to_z3_expr(self) -> z3.BoolRef: return self.left.to_z3_expr() > self.right.to_z3_expr() @@ -1342,7 +1342,7 @@ class Call(Expr): identifier: ID = field(converter=ID) arguments: Sequence[Expr] argument_types: Sequence[rty.Any] - origin: Optional[Origin] = None + origin: Origin | None = None _preconditions: list[Cond] = field(init=False, factory=list) @property @@ -1371,7 +1371,7 @@ class IntCall(Call, IntExpr): arguments: Sequence[Expr] argument_types: Sequence[rty.Any] type_: rty.AnyInteger - origin: Optional[Origin] = None + origin: Origin | None = None def substituted(self, mapping: Mapping[ID, ID]) -> IntCall: return self.__class__( @@ -1416,7 +1416,7 @@ class ObjCall(Call): arguments: Sequence[Expr] argument_types: Sequence[rty.Any] type_: rty.Any - origin: Optional[Origin] = None + origin: Origin | None = None def substituted(self, mapping: Mapping[ID, ID]) -> ObjCall: return self.__class__( @@ -1436,7 +1436,7 @@ class FieldAccess(Expr): message: ID = field(converter=ID) field: ID = field(converter=ID) message_type: rty.Compound - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -1467,7 +1467,7 @@ class IntFieldAccess(FieldAccess, IntExpr): message: ID = field(converter=ID) field: ID = field(converter=ID) message_type: rty.Compound - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.AnyInteger: @@ -1501,7 +1501,7 @@ class ObjFieldAccess(FieldAccess): message: ID = field(converter=ID) field: ID = field(converter=ID) message_type: rty.Compound - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.Any: @@ -1526,7 +1526,7 @@ class IfExpr(Expr): condition: BasicBoolExpr then_expr: ComplexExpr else_expr: ComplexExpr - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -1562,7 +1562,7 @@ class IntIfExpr(IfExpr, IntExpr): then_expr: ComplexIntExpr else_expr: ComplexIntExpr return_type: rty.AnyInteger - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.AnyInteger: @@ -1579,7 +1579,7 @@ class BoolIfExpr(IfExpr, BoolExpr): condition: BasicBoolExpr then_expr: ComplexBoolExpr else_expr: ComplexBoolExpr - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.Enumeration: @@ -1595,7 +1595,7 @@ def to_z3_expr(self) -> z3.BoolRef: class Conversion(Expr): target_type: rty.NamedType argument: Expr - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.Any: @@ -1626,7 +1626,7 @@ def _update_str(self) -> None: class IntConversion(Conversion, BasicIntExpr): target_type: rty.Integer argument: IntExpr - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.Integer: @@ -1670,10 +1670,10 @@ def to_z3_expr(self) -> z3.ArithRef: @define(eq=False) class Comprehension(Expr): iterator: ID = field(converter=ID) - sequence: Union[Var, FieldAccess] + sequence: Var | FieldAccess selector: ComplexExpr condition: ComplexBoolExpr - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.Aggregate: @@ -1717,10 +1717,10 @@ def _update_str(self) -> None: @define(eq=False) class Find(Expr): iterator: ID = field(converter=ID) - sequence: Union[Var, FieldAccess] + sequence: Var | FieldAccess selector: ComplexExpr condition: ComplexBoolExpr - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.Any: @@ -1764,7 +1764,7 @@ def _update_str(self) -> None: @define(eq=False) class Agg(Expr): elements: Sequence[BasicExpr] - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.Aggregate: @@ -1791,8 +1791,8 @@ def _update_str(self) -> None: def _named_agg_elements_converter( - elements: Sequence[tuple[Union[StrID, BasicExpr], BasicExpr]], -) -> Sequence[tuple[Union[ID, BasicExpr], BasicExpr]]: + elements: Sequence[tuple[StrID | BasicExpr, BasicExpr]], +) -> Sequence[tuple[ID | BasicExpr, BasicExpr]]: return [(ID(n) if isinstance(n, str) else n, e) for n, e in elements] @@ -1800,10 +1800,10 @@ def _named_agg_elements_converter( class NamedAgg(Expr): """Only used by code generator and therefore provides minimum functionality.""" - elements: Sequence[tuple[Union[ID, BasicExpr], BasicExpr]] = field( + elements: Sequence[tuple[ID | BasicExpr, BasicExpr]] = field( converter=_named_agg_elements_converter, ) - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.Any: @@ -1829,7 +1829,7 @@ def _update_str(self) -> None: @define(eq=False) class Str(Expr): string: str - origin: Optional[Origin] = None + origin: Origin | None = None @property def type_(self) -> rty.Sequence: @@ -1857,7 +1857,7 @@ class MsgAgg(Expr): identifier: ID = field(converter=ID) field_values: Mapping[ID, Expr] type_: rty.Message - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -1891,7 +1891,7 @@ class DeltaMsgAgg(Expr): identifier: ID = field(converter=ID) field_values: Mapping[ID, Expr] type_: rty.Message - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -1925,7 +1925,7 @@ class CaseExpr(Expr): expression: BasicExpr choices: Sequence[tuple[Sequence[BasicExpr], BasicExpr]] type_: rty.Any - origin: Optional[Origin] = None + origin: Origin | None = None @property def accessed_vars(self) -> list[ID]: @@ -1964,7 +1964,7 @@ class SufficientSpace(FieldAccessAttr, BoolExpr): message: ID = field(converter=ID) field: ID = field(converter=ID) message_type: rty.Message - origin: Optional[Origin] = None + origin: Origin | None = None def to_z3_expr(self) -> z3.BoolRef: return z3.Bool(str(self)) @@ -1974,7 +1974,7 @@ def to_z3_expr(self) -> z3.BoolRef: class HasElement(Attr, BoolExpr): prefix: ID = field(converter=ID) prefix_type: rty.Sequence - origin: Optional[Origin] = None + origin: Origin | None = None def to_z3_expr(self) -> z3.BoolRef: return z3.Bool(str(self)) @@ -1983,7 +1983,7 @@ def to_z3_expr(self) -> z3.BoolRef: @frozen class Decl: identifier: ID = field(converter=ID) - location: Optional[Location] + location: Location | None @frozen @@ -2004,7 +2004,7 @@ class FuncDecl(FormalDecl): arguments: Sequence[Argument] return_type: ID = field(converter=ID) type_: rty.Type - location: Optional[Location] + location: Location | None @frozen @@ -2012,7 +2012,7 @@ class ChannelDecl(FormalDecl): identifier: ID = field(converter=ID) readable: bool writable: bool - location: Optional[Location] + location: Location | None @frozen @@ -2060,18 +2060,18 @@ class ComplexBoolExpr(ComplexExpr): class Transition: target: ID = field(converter=ID) condition: ComplexExpr - description: Optional[str] - location: Optional[Location] + description: str | None + location: Location | None @frozen class State: identifier: ID = field(converter=ID) transitions: Sequence[Transition] - exception_transition: Optional[Transition] + exception_transition: Transition | None actions: Sequence[Stmt] - description: Optional[str] - location: Optional[Location] + description: str | None + location: Location | None @property def declarations(self) -> list[VarDecl]: @@ -2088,7 +2088,7 @@ class Session: declarations: Sequence[VarDecl] parameters: Sequence[FormalDecl] types: Mapping[ID, type_decl.TypeDecl] - location: Optional[Location] + location: Location | None def __init__( # noqa: PLR0913 self, @@ -2097,7 +2097,7 @@ def __init__( # noqa: PLR0913 declarations: Sequence[VarDecl], parameters: Sequence[FormalDecl], types: Mapping[ID, type_decl.TypeDecl], - location: Optional[Location], + location: Location | None, variable_id: Generator[ID, None, None], workers: int = 1, ) -> None: diff --git a/rflx/ls/lexer.py b/rflx/ls/lexer.py index 9a3c6276f..50414b31d 100644 --- a/rflx/ls/lexer.py +++ b/rflx/ls/lexer.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from functools import singledispatchmethod -from typing import Optional, Union, cast +from typing import Optional, cast from rflx import const, lang from rflx.identifier import ID @@ -26,7 +26,7 @@ class Token: """ - symbol: Optional[Symbol] + symbol: Symbol | None lexeme: str line_number: int character_offset: int @@ -35,10 +35,10 @@ class Token: @dataclass class State: imports: set[ID] - current_package: Optional[ID] + current_package: ID | None declarations: list[ID] - foreign_package: Optional[ID] - current_session: Optional[ID] + foreign_package: ID | None + current_session: ID | None top_level: bool @@ -52,7 +52,7 @@ def tokens(self) -> list[Token]: """List of tokens resulting from the previous LSLexer.tokenize calls.""" return self._tokens - def search_token(self, line_number: int, character_offset: int) -> Optional[Token]: + def search_token(self, line_number: int, character_offset: int) -> Token | None: """Return the token at the given location if it exists, otherwise return None.""" if len(self._tokens) == 0: @@ -110,8 +110,10 @@ def tokenize(self, source: str, path: str = "") -> None: state = State(set(), None, [], None, None, top_level=False) self._process_ast_node(unit.root, state) + # TODO(eng/recordflux/RecordFlux#1424): Replace remaining use of Optional + # singledispatch has issues with PEP604 type annotations in Python 3.8 and 3.9. @singledispatchmethod - def _process_ast_node(self, node: Optional[lang.RFLXNode], state: State) -> None: + def _process_ast_node(self, node: Optional[lang.RFLXNode], state: State) -> None: # noqa: UP007 if node is None: return @@ -121,7 +123,7 @@ def _process_children(self, node: lang.RFLXNode, state: State) -> None: for child in node.children: self._process_ast_node(child, state) - def _identify_symbol(self, lexeme: str, state: State) -> Optional[Symbol]: + def _identify_symbol(self, lexeme: str, state: State) -> Symbol | None: symbols: list[Symbol] = self._model.get_symbols(lexeme) if state.foreign_package is not None and state.foreign_package != ID(lexeme): @@ -203,7 +205,7 @@ def _(self, node: lang.PackageNode, state: State) -> None: # Python 3.11 directly supports single dispatch with typing.Union @_process_ast_node.register(lang.TypeDecl) @_process_ast_node.register(lang.SessionDecl) - def _(self, node: Union[lang.TypeDecl, lang.SessionDecl], state: State) -> None: + def _(self, node: lang.TypeDecl | lang.SessionDecl, state: State) -> None: partial_identifier = ID(node.f_identifier.text) identifier = ( state.current_package * partial_identifier diff --git a/rflx/ls/model.py b/rflx/ls/model.py index 9cf948806..61d13da4a 100644 --- a/rflx/ls/model.py +++ b/rflx/ls/model.py @@ -3,7 +3,6 @@ from collections import defaultdict from dataclasses import dataclass from enum import Enum, auto -from typing import Optional from rflx.common import assert_never from rflx.error import Location @@ -84,8 +83,8 @@ class Symbol: identifier: ID category: SymbolCategory - definition_location: Optional[Location] - parent: Optional[ID] + definition_location: Location | None + parent: ID | None class LSModel: diff --git a/rflx/ls/server.py b/rflx/ls/server.py index 495abf6be..a1b1d9399 100644 --- a/rflx/ls/server.py +++ b/rflx/ls/server.py @@ -6,7 +6,7 @@ import uuid from collections import defaultdict from pathlib import Path -from typing import Callable, Final, Iterable, Mapping, Optional +from typing import Callable, Final, Iterable, Mapping from urllib.parse import unquote, urlparse from lsprotocol.types import ( @@ -67,7 +67,7 @@ } -def to_lsp_location(location: Optional[error.Location]) -> Optional[Location]: +def to_lsp_location(location: error.Location | None) -> Location | None: if ( location is None or location.source is None @@ -85,7 +85,7 @@ def to_lsp_location(location: Optional[error.Location]) -> Optional[Location]: ) -def to_lsp_severity(severity: error.Severity) -> Optional[DiagnosticSeverity]: +def to_lsp_severity(severity: error.Severity) -> DiagnosticSeverity | None: if severity is error.Severity.ERROR: return DiagnosticSeverity.Error if severity is error.Severity.INFO or severity is error.Severity.NOTE: @@ -223,7 +223,7 @@ def _workspace_files(self) -> list[Path]: # TODO(eng/recordflux/RecordFlux#1424): Use typing.ParamSpec instead of ... and object def debounce( # type: ignore[misc] interval_s: int, - keyed_by: Optional[str] = None, + keyed_by: str | None = None, ) -> Callable[[Callable[..., None]], Callable[..., None]]: """ Debounce calls to this function until interval_s seconds have passed. @@ -302,7 +302,7 @@ async def did_change(ls: RecordFluxLanguageServer, params: DidChangeTextDocument async def go_to_definition( ls: RecordFluxLanguageServer, params: DefinitionParams, -) -> Optional[list[Location]]: +) -> list[Location] | None: with LSFatalErrorHandler(ls): lexer = initialize_lexer(ls, params.text_document.uri) token = lexer.search_token(params.position.line, params.position.character) diff --git a/rflx/model/cache.py b/rflx/model/cache.py index 90d10c9e9..d4f8e187c 100644 --- a/rflx/model/cache.py +++ b/rflx/model/cache.py @@ -7,7 +7,7 @@ import typing as ty from functools import lru_cache, singledispatch from pathlib import Path -from typing import Literal, Optional, TextIO +from typing import Literal, TextIO import importlib_resources @@ -187,7 +187,7 @@ def key(self) -> str: return self._full_name @property - def value(self) -> Optional[str]: + def value(self) -> str | None: return self._value diff --git a/rflx/model/declaration.py b/rflx/model/declaration.py index 0b5b0a4a2..b728dc173 100644 --- a/rflx/model/declaration.py +++ b/rflx/model/declaration.py @@ -2,7 +2,7 @@ from abc import abstractmethod from collections.abc import Callable, Generator, Sequence -from typing import ClassVar, Optional +from typing import ClassVar import rflx.typing_ as rty from rflx import expr_conv, ir @@ -18,7 +18,7 @@ class Declaration(Base): DESCRIPTIVE_NAME: ClassVar[str] - def __init__(self, identifier: StrID, location: Optional[Location] = None): + def __init__(self, identifier: StrID, location: Location | None = None): self.identifier = ID(identifier) self.location = location self._refcount = 0 @@ -56,7 +56,7 @@ def __init__( identifier: StrID, type_identifier: StrID, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ): super().__init__(identifier, location) self._type_identifier = ID(type_identifier) @@ -95,9 +95,9 @@ def __init__( self, identifier: StrID, type_identifier: StrID, - expression: Optional[Expr] = None, + expression: Expr | None = None, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ): super().__init__(identifier, type_identifier, type_, location) self.expression = expression @@ -144,7 +144,7 @@ def __init__( type_identifier: StrID, expression: Selected, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ): super().__init__(identifier, type_identifier, type_, location) self.expression = expression @@ -246,7 +246,7 @@ def __init__( arguments: Sequence[Argument], return_type: StrID, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ): super().__init__(identifier, return_type, type_, location) self._arguments = arguments @@ -292,7 +292,7 @@ def __init__( identifier: StrID, readable: bool = False, writable: bool = False, - location: Optional[Location] = None, + location: Location | None = None, ): assert readable or writable super().__init__(identifier, location) diff --git a/rflx/model/message.py b/rflx/model/message.py index 73139def4..64e5a35c4 100644 --- a/rflx/model/message.py +++ b/rflx/model/message.py @@ -8,7 +8,7 @@ from dataclasses import dataclass, field as dataclass_field from enum import Enum from functools import cached_property, partial -from typing import Callable, Optional, Union +from typing import Callable import rflx.typing_ as rty from rflx import expr, expr_proof @@ -63,7 +63,7 @@ class Link(Base): condition: expr.Expr = expr.TRUE size: expr.Expr = expr.UNDEFINED first: expr.Expr = expr.UNDEFINED - location: Optional[Location] = dataclass_field(default=None, repr=False) + location: Location | None = dataclass_field(default=None, repr=False) def __str__(self) -> str: condition = indent_next( @@ -104,9 +104,9 @@ def __init__( # noqa: PLR0913 identifier: StrID, structure: Sequence[Link], types: Mapping[Field, type_decl.TypeDecl], - checksums: Optional[Mapping[ID, Sequence[expr.Expr]]] = None, - byte_order: Optional[Union[ByteOrder, Mapping[Field, ByteOrder]]] = None, - location: Optional[Location] = None, + checksums: Mapping[ID, Sequence[expr.Expr]] | None = None, + byte_order: ByteOrder | Mapping[Field, ByteOrder] | None = None, + location: Location | None = None, skip_verification: bool = False, workers: int = 1, ) -> None: @@ -269,13 +269,13 @@ def byte_order(self) -> Mapping[Field, ByteOrder]: def copy( # noqa: PLR0913 self, - identifier: Optional[StrID] = None, - structure: Optional[Sequence[Link]] = None, - types: Optional[Mapping[Field, type_decl.TypeDecl]] = None, - checksums: Optional[Mapping[ID, Sequence[expr.Expr]]] = None, - byte_order: Optional[Union[ByteOrder, Mapping[Field, ByteOrder]]] = None, - location: Optional[Location] = None, - skip_verification: Optional[bool] = None, + identifier: StrID | None = None, + structure: Sequence[Link] | None = None, + types: Mapping[Field, type_decl.TypeDecl] | None = None, + checksums: Mapping[ID, Sequence[expr.Expr]] | None = None, + byte_order: ByteOrder | Mapping[Field, ByteOrder] | None = None, + location: Location | None = None, + skip_verification: bool | None = None, ) -> Message: return Message( identifier if identifier else self.identifier, @@ -388,7 +388,7 @@ def path_condition(self, field: Field) -> expr.Expr: return result - def field_size_opt(self, field: Field) -> Optional[expr.Number]: + def field_size_opt(self, field: Field) -> expr.Number | None: """Return field size if field size is fixed and None otherwise.""" if field == FINAL: return expr.Number(0) @@ -591,8 +591,8 @@ def is_definite(self) -> bool: def size( self, - field_values: Optional[Mapping[Field, expr.Expr]] = None, - message_instance: Optional[ID] = None, + field_values: Mapping[Field, expr.Expr] | None = None, + message_instance: ID | None = None, subpath: bool = False, ) -> expr.Expr: """ @@ -1027,7 +1027,7 @@ def _validate_structure(self, structure_fields: set[Field]) -> bool: ), ) - def has_final(field: Field, seen: Optional[set[Field]] = None) -> bool: + def has_final(field: Field, seen: set[Field] | None = None) -> bool: """Return True if the field has a path to the final field or a cycle was found.""" if seen is None: @@ -1238,7 +1238,7 @@ def substitute(expression: expr.Expr) -> expr.Expr: return (structure, checksums) - def _compute_topological_sorting(self, has_unreachable: bool) -> Optional[tuple[Field, ...]]: + def _compute_topological_sorting(self, has_unreachable: bool) -> tuple[Field, ...] | None: """Return fields topologically sorted (Kahn's algorithm).""" result: tuple[Field, ...] = () fields = [INITIAL] @@ -2284,7 +2284,7 @@ def _target_size(self, link: Link) -> expr.Expr: return link.size return self.field_size(link.target) - def _target_size_opt(self, link: Link) -> Optional[expr.Expr]: + def _target_size_opt(self, link: Link) -> expr.Expr | None: if link.size != expr.UNDEFINED: return link.size return self.field_size_opt(link.target) @@ -2296,7 +2296,7 @@ def _target_last(self, link: Link) -> expr.Expr: link.target.identifier.location, ) - def _target_last_opt(self, link: Link) -> Optional[expr.Expr]: + def _target_last_opt(self, link: Link) -> expr.Expr | None: size = self._target_size_opt(link) if not size: return None @@ -2401,11 +2401,11 @@ def __init__( # noqa: PLR0913 self, identifier: StrID, base: Message, - structure: Optional[Sequence[Link]] = None, - types: Optional[Mapping[Field, type_decl.TypeDecl]] = None, - checksums: Optional[Mapping[ID, Sequence[expr.Expr]]] = None, - byte_order: Optional[Union[ByteOrder, Mapping[Field, ByteOrder]]] = None, - location: Optional[Location] = None, + structure: Sequence[Link] | None = None, + types: Mapping[Field, type_decl.TypeDecl] | None = None, + checksums: Mapping[ID, Sequence[expr.Expr]] | None = None, + byte_order: ByteOrder | Mapping[Field, ByteOrder] | None = None, + location: Location | None = None, skip_verification: bool = False, workers: int = 1, ) -> None: @@ -2444,13 +2444,13 @@ def __init__( # noqa: PLR0913 def copy( # noqa: PLR0913 self, - identifier: Optional[StrID] = None, - structure: Optional[Sequence[Link]] = None, - types: Optional[Mapping[Field, type_decl.TypeDecl]] = None, - checksums: Optional[Mapping[ID, Sequence[expr.Expr]]] = None, - byte_order: Optional[Union[ByteOrder, Mapping[Field, ByteOrder]]] = None, - location: Optional[Location] = None, - skip_verification: Optional[bool] = None, + identifier: StrID | None = None, + structure: Sequence[Link] | None = None, + types: Mapping[Field, type_decl.TypeDecl] | None = None, + checksums: Mapping[ID, Sequence[expr.Expr]] | None = None, + byte_order: ByteOrder | Mapping[Field, ByteOrder] | None = None, + location: Location | None = None, + skip_verification: bool | None = None, ) -> DerivedMessage: return DerivedMessage( identifier if identifier else self.identifier, @@ -2472,7 +2472,7 @@ def __init__( # noqa: PLR0913 field: Field, sdu: Message, condition: expr.Expr = expr.TRUE, - location: Optional[Location] = None, + location: Location | None = None, skip_verification: bool = False, ) -> None: super().__init__( @@ -2679,16 +2679,13 @@ class UncheckedMessage(type_decl.UncheckedTypeDecl): parameter_types: Sequence[tuple[Field, ID, Sequence[tuple[ID, expr.Expr]]]] field_types: Sequence[tuple[Field, ID, Sequence[tuple[ID, expr.Expr]]]] checksums: Mapping[ID, Sequence[expr.Expr]] = dataclass_field(default_factory=dict) - byte_order: Union[ - Mapping[Field, ByteOrder], - ByteOrder, - ] = dataclass_field( + byte_order: Mapping[Field, ByteOrder] | ByteOrder = dataclass_field( # TODO(eng/recordflux/RecordFlux#1359): Fix type annotation # The type should be `dict[Field, ByteOrder]`, but the subscription of `dict` is not # supported by Python 3.8. default_factory=dict, ) - location: Optional[Location] = dataclass_field(default=None) + location: Location | None = dataclass_field(default=None) @property def fields(self) -> list[Field]: @@ -2836,7 +2833,7 @@ def checked( def merged( self, declarations: Sequence[TopLevelDeclaration], - message_arguments: Optional[Mapping[ID, Mapping[ID, expr.Expr]]] = None, + message_arguments: Mapping[ID, Mapping[ID, expr.Expr]] | None = None, ) -> UncheckedMessage: assert all_types_declared(self, declarations) @@ -2866,7 +2863,7 @@ def _check_message_arguments( argument_errors: RecordFluxError, message: Message, type_arguments: Sequence[tuple[ID, expr.Expr]], - field_type_location: Optional[Location], + field_type_location: Location | None, ) -> None: for param, arg in itertools.zip_longest(message.parameter_types, type_arguments): if arg: @@ -3240,7 +3237,7 @@ def _prune_dangling_fields( class UncheckedDerivedMessage(type_decl.UncheckedTypeDecl): identifier: ID base_identifier: ID - location: Optional[Location] = dataclass_field(default=None) + location: Location | None = dataclass_field(default=None) def checked( self, @@ -3298,7 +3295,7 @@ class UncheckedRefinement(type_decl.UncheckedTypeDecl): field: Field sdu: ID condition: expr.Expr - location: Optional[Location] = dataclass_field(default=None) + location: Location | None = dataclass_field(default=None) def __init__( # noqa: PLR0913 self, @@ -3307,7 +3304,7 @@ def __init__( # noqa: PLR0913 field: Field, sdu: ID, condition: expr.Expr = expr.TRUE, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__( ID(package) * f"__REFINEMENT__{sdu.flat}__{pdu.flat}__{field.name}__", @@ -3423,7 +3420,7 @@ def aggregate_constraints( def get_constraints( aggregate: expr.Aggregate, field: expr.Variable, - location: Optional[Location], + location: Location | None, ) -> Sequence[expr.Expr]: comp = types[Field(field.name)] @@ -3591,7 +3588,7 @@ def prove( def annotate_path( path: Sequence[Link], message_location: Location, - link_filter: Optional[Callable[[Link], bool]] = None, + link_filter: Callable[[Link], bool] | None = None, ) -> Sequence[Annotation]: link_filter = link_filter or (lambda _: True) assert message_location.end is not None diff --git a/rflx/model/session.py b/rflx/model/session.py index 3fd1999f5..4d3904a63 100644 --- a/rflx/model/session.py +++ b/rflx/model/session.py @@ -7,7 +7,7 @@ from copy import deepcopy from dataclasses import dataclass from functools import lru_cache -from typing import Final, Optional +from typing import Final from rflx import expr, expr_conv, ir, typing_ as rty from rflx.common import Base, indent, indent_next, verbose_repr @@ -28,8 +28,8 @@ def __init__( self, target: StrID, condition: expr.Expr = expr.TRUE, - description: Optional[str] = None, - location: Optional[Location] = None, + description: str | None = None, + location: Location | None = None, ): self.target = ID(target) self.condition = condition @@ -64,12 +64,12 @@ class State(Base): def __init__( # noqa: PLR0913 self, identifier: StrID, - transitions: Optional[Sequence[Transition]] = None, - exception_transition: Optional[Transition] = None, - actions: Optional[Sequence[stmt.Statement]] = None, - declarations: Optional[Sequence[decl.BasicDeclaration]] = None, - description: Optional[str] = None, - location: Optional[Location] = None, + transitions: Sequence[Transition] | None = None, + exception_transition: Transition | None = None, + actions: Sequence[stmt.Statement] | None = None, + declarations: Sequence[decl.BasicDeclaration] | None = None, + description: str | None = None, + location: Location | None = None, ): if transitions: assert transitions[-1].condition == expr.TRUE, "missing default transition" @@ -118,7 +118,7 @@ def transitions(self) -> Sequence[Transition]: return self._transitions or [] @property - def exception_transition(self) -> Optional[Transition]: + def exception_transition(self) -> Transition | None: return self._exception_transition @property @@ -351,7 +351,7 @@ def __init__( # noqa: PLR0913 declarations: Sequence[decl.BasicDeclaration], parameters: Sequence[decl.FormalDeclaration], types: Sequence[type_decl.TypeDecl], - location: Optional[Location] = None, + location: Location | None = None, workers: int = 1, ): super().__init__(identifier, location) @@ -679,7 +679,7 @@ def _validate_declarations( ) -> None: visible_declarations = dict(visible_declarations) - def undefined_type(type_identifier: StrID, location: Optional[Location]) -> None: + def undefined_type(type_identifier: StrID, location: Location | None) -> None: self.error.push( ErrorEntry( f'undefined type "{type_identifier}"', @@ -1034,7 +1034,7 @@ class UncheckedSession(UncheckedTopLevelDeclaration): states: Sequence[State] declarations: Sequence[decl.BasicDeclaration] parameters: Sequence[decl.FormalDeclaration] - location: Optional[Location] + location: Location | None def checked( self, diff --git a/rflx/model/statement.py b/rflx/model/statement.py index d16162223..a0da57f08 100644 --- a/rflx/model/statement.py +++ b/rflx/model/statement.py @@ -2,7 +2,6 @@ from abc import abstractmethod from collections.abc import Callable, Generator, Mapping, Sequence -from typing import Optional from rflx import expr_conv, ir, typing_ as rty from rflx.common import Base @@ -16,7 +15,7 @@ def __init__( self, identifier: StrID, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ): self.identifier = ID(identifier) self.type_ = type_ @@ -47,7 +46,7 @@ def __init__( identifier: StrID, expression: Expr, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(identifier, type_, location) self.expression = expression @@ -89,7 +88,7 @@ def __init__( field: StrID, expression: Expr, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(message, expression, type_, location) self.message = ID(message) @@ -164,7 +163,7 @@ def __init__( attribute: str, parameters: list[Expr], type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(identifier, type_, location) self.attribute = attribute @@ -194,7 +193,7 @@ def __init__( identifier: StrID, parameter: Expr, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(identifier, self.__class__.__name__, [parameter], type_, location) @@ -295,9 +294,9 @@ class Reset(AttributeStatement): def __init__( self, identifier: StrID, - associations: Optional[Mapping[ID, Expr]] = None, + associations: Mapping[ID, Expr] | None = None, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(identifier, self.__class__.__name__, [], type_, location) self.associations = associations or {} @@ -384,7 +383,7 @@ def __init__( identifier: StrID, parameter: Expr, type_: rty.Type = rty.UNDEFINED, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(identifier, self.__class__.__name__, [parameter], type_, location) diff --git a/rflx/model/top_level_declaration.py b/rflx/model/top_level_declaration.py index 2e5c41dfd..f1ba2bd23 100644 --- a/rflx/model/top_level_declaration.py +++ b/rflx/model/top_level_declaration.py @@ -2,7 +2,7 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Optional, Sequence +from typing import Sequence from rflx.common import Base from rflx.identifier import ID, StrID @@ -10,7 +10,7 @@ class TopLevelDeclaration(Base): - def __init__(self, identifier: StrID, location: Optional[Location] = None) -> None: + def __init__(self, identifier: StrID, location: Location | None = None) -> None: self.identifier = ID(identifier) self.location = location self.error = RecordFluxError() diff --git a/rflx/model/type_decl.py b/rflx/model/type_decl.py index 89c68b622..f82b3ceba 100644 --- a/rflx/model/type_decl.py +++ b/rflx/model/type_decl.py @@ -5,7 +5,7 @@ from collections import abc from dataclasses import dataclass from pathlib import Path -from typing import Literal, Optional +from typing import Literal import rflx.typing_ as rty from rflx import const, expr @@ -64,7 +64,7 @@ def __init__( self, identifier: StrID, size: expr.Expr, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(identifier, location) @@ -104,7 +104,7 @@ def __init__( first: expr.Expr, last: expr.Expr, size: expr.Expr, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(identifier, size, location) @@ -305,7 +305,7 @@ def __init__( # noqa: PLR0912 literals: abc.Sequence[tuple[StrID, expr.Number]], size: expr.Expr, always_valid: bool, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(identifier, size, location) @@ -507,7 +507,7 @@ def __init__( self, identifier: StrID, element_type: TypeDecl, - location: Optional[Location] = None, + location: Location | None = None, ) -> None: super().__init__(identifier, location) self.element_type = element_type @@ -627,7 +627,7 @@ def dependencies(self) -> list[TypeDecl]: class Opaque(Composite): - def __init__(self, location: Optional[Location] = None) -> None: + def __init__(self, location: Location | None = None) -> None: super().__init__(const.INTERNAL_PACKAGE * "Opaque", location) def __repr__(self) -> str: @@ -665,7 +665,7 @@ class UncheckedInteger(UncheckedTypeDecl): first: expr.Expr last: expr.Expr size: expr.Expr - location: Optional[Location] + location: Location | None def checked( self, @@ -682,7 +682,7 @@ class UncheckedEnumeration(UncheckedTypeDecl): literals: abc.Sequence[tuple[ID, expr.Number]] size: expr.Expr always_valid: bool - location: Optional[Location] + location: Location | None def checked( self, @@ -703,7 +703,7 @@ def checked( class UncheckedSequence(UncheckedTypeDecl): identifier: ID element_identifier: ID - location: Optional[Location] + location: Location | None def checked( self, @@ -731,7 +731,7 @@ def checked( @dataclass class UncheckedOpaque(UncheckedTypeDecl): identifier: ID - location: Optional[Location] + location: Location | None def checked( self, @@ -783,7 +783,7 @@ def is_builtin_type(identifier: StrID) -> bool: ) -def internal_type_identifier(identifier: ID, package: Optional[ID] = None) -> ID: +def internal_type_identifier(identifier: ID, package: ID | None = None) -> ID: """ Return the internal identifier of a type. diff --git a/rflx/pyrflx/bitstring.py b/rflx/pyrflx/bitstring.py index e6b7007cf..dfdf9308e 100644 --- a/rflx/pyrflx/bitstring.py +++ b/rflx/pyrflx/bitstring.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections.abc import Sequence -from typing import Union from typing_extensions import Self @@ -23,7 +22,7 @@ def __iadd__(self, other: Bitstring) -> Self: self._bits += other._bits return self - def __getitem__(self, key: Union[int, slice]) -> Bitstring: + def __getitem__(self, key: int | slice) -> Bitstring: if isinstance(key, slice) and isinstance(key.stop, int) and len(self._bits) < key.stop: raise IndexError return Bitstring(self._bits[key]) diff --git a/rflx/pyrflx/package.py b/rflx/pyrflx/package.py index f19bd6559..7ca157252 100644 --- a/rflx/pyrflx/package.py +++ b/rflx/pyrflx/package.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections.abc import Iterator, Mapping -from typing import Optional, Union from rflx.common import Base from rflx.identifier import StrID @@ -22,7 +21,7 @@ def name(self) -> str: def new_message( self, key: StrID, - parameters: Optional[Mapping[str, Union[bool, int, str]]] = None, + parameters: Mapping[str, bool | int | str] | None = None, ) -> MessageValue: message = self._messages[str(key)].clone() if parameters: diff --git a/rflx/pyrflx/pyrflx.py b/rflx/pyrflx/pyrflx.py index b4f007f3d..1faa88d77 100644 --- a/rflx/pyrflx/pyrflx.py +++ b/rflx/pyrflx/pyrflx.py @@ -2,7 +2,6 @@ from collections.abc import Iterable, Iterator, Mapping from pathlib import Path -from typing import Optional, Union from rflx.error import FatalError from rflx.identifier import ID, StrID @@ -18,7 +17,7 @@ class PyRFLX: def __init__( self, model: Model, - checksum_functions: Optional[Mapping[StrID, Mapping[str, ChecksumFunction]]] = None, + checksum_functions: Mapping[StrID, Mapping[str, ChecksumFunction]] | None = None, skip_message_verification: bool = False, ) -> None: """ @@ -54,8 +53,8 @@ def __init__( @classmethod def from_specs( cls, - files: Iterable[Union[str, Path]], - cache: Optional[Cache] = None, + files: Iterable[str | Path], + cache: Cache | None = None, skip_message_verification: bool = False, ) -> PyRFLX: paths = list(map(Path, files)) diff --git a/rflx/pyrflx/typevalue.py b/rflx/pyrflx/typevalue.py index d76e65325..d9970f3ae 100644 --- a/rflx/pyrflx/typevalue.py +++ b/rflx/pyrflx/typevalue.py @@ -4,7 +4,7 @@ from abc import abstractmethod from collections import abc from dataclasses import dataclass -from typing import Any, Optional, Protocol, Union +from typing import Any, Protocol, Union from rflx.common import Base from rflx.const import BUILTINS_PACKAGE @@ -56,6 +56,7 @@ class ChecksumFunction(Protocol): def __call__(self, message: bytes, **kwargs: object) -> int: ... # pragma: no cover +# TODO(eng/recordflux/RecordFlux#1424): Replace with PEP604 union ValueType = Union[ "MessageValue", ty.Sequence["TypeValue"], @@ -67,7 +68,7 @@ def __call__(self, message: bytes, **kwargs: object) -> int: ... # pragma: no c class TypeValue(Base): - _value: Optional[ValueType] = None + _value: ValueType | None = None def __init__(self, vtype: TypeDecl) -> None: self._type = vtype @@ -110,7 +111,7 @@ def assign(self, value: Any, check: bool = True) -> None: # type: ignore[misc] raise NotImplementedError @abstractmethod - def parse(self, value: Union[Bitstring, bytes], check: bool = True) -> None: + def parse(self, value: Bitstring | bytes, check: bool = True) -> None: raise NotImplementedError @property @@ -145,7 +146,7 @@ def construct( cls, vtype: TypeDecl, imported: bool = False, - refinements: Optional[abc.Sequence[RefinementValue]] = None, + refinements: abc.Sequence[RefinementValue] | None = None, ) -> TypeValue: if isinstance(vtype, Integer): return IntegerValue(vtype) @@ -207,7 +208,7 @@ def assign(self, value: int, check: bool = True) -> None: raise e self._value = value - def parse(self, value: Union[Bitstring, bytes], check: bool = True) -> None: + def parse(self, value: Bitstring | bytes, check: bool = True) -> None: if isinstance(value, bytes): value = Bitstring.from_bytes(value) self.assign(int(value), check) @@ -279,7 +280,7 @@ def assign(self, value: str, check: bool = True) -> None: assert r == TRUE self._value = (str(prefixed_value), self._type.literals[prefixed_value.name]) - def parse(self, value: Union[Bitstring, bytes], _check: bool = True) -> None: + def parse(self, value: Bitstring | bytes, _check: bool = True) -> None: if isinstance(value, bytes): value = Bitstring.from_bytes(value) value_as_number = Number(int(value)) @@ -334,7 +335,7 @@ def as_json(self) -> object: class CompositeValue(TypeValue): def __init__(self, vtype: Composite) -> None: - self._expected_size: Optional[Expr] = None + self._expected_size: Expr | None = None super().__init__(vtype) def set_expected_size(self, expected_size: Expr) -> None: @@ -342,7 +343,7 @@ def set_expected_size(self, expected_size: Expr) -> None: def _check_size_of_assigned_value( self, - value: Union[bytes, Bitstring, abc.Sequence[TypeValue]], + value: bytes | Bitstring | abc.Sequence[TypeValue], ) -> None: if isinstance(value, bytes): size_of_value = len(value) * 8 @@ -371,17 +372,17 @@ def value(self) -> ValueType: class OpaqueValue(CompositeValue): - _value: Optional[bytes] - _nested_message: Optional[MessageValue] = None + _value: bytes | None + _nested_message: MessageValue | None = None def __init__(self, vtype: Opaque) -> None: super().__init__(vtype) - self._refinement_message: Optional[MessageValue] = None + self._refinement_message: MessageValue | None = None def assign(self, value: bytes, check: bool = True) -> None: self.parse(value, check) - def parse(self, value: Union[Bitstring, bytes], check: bool = True) -> None: + def parse(self, value: Bitstring | bytes, check: bool = True) -> None: if check: self._check_size_of_assigned_value(value) if self._refinement_message is not None: @@ -411,7 +412,7 @@ def size(self) -> Expr: return Number(len(self._value) * 8) @property - def nested_message(self) -> Optional[MessageValue]: + def nested_message(self) -> MessageValue | None: return self._nested_message @property @@ -434,7 +435,7 @@ def bitstring(self) -> Bitstring: def accepted_type(self) -> type: return bytes - def as_json(self) -> Optional[bytes]: + def as_json(self) -> bytes | None: return self._value @@ -486,7 +487,7 @@ def assign(self, value: list[TypeValue], check: bool = True) -> None: self._value = value - def parse(self, value: Union[Bitstring, bytes], check: bool = True) -> None: + def parse(self, value: Bitstring | bytes, check: bool = True) -> None: self._check_size_of_assigned_value(value) if isinstance(value, bytes): value = Bitstring.from_bytes(value) @@ -563,10 +564,10 @@ class MessageValue(TypeValue): def __init__( self, model: Message, - refinements: Optional[abc.Sequence[RefinementValue]] = None, + refinements: abc.Sequence[RefinementValue] | None = None, skip_verification: bool = False, - parameters: Optional[abc.Mapping[Name, Expr]] = None, - state: Optional[MessageValue.State] = None, + parameters: abc.Mapping[Name, Expr] | None = None, + state: MessageValue.State | None = None, ) -> None: super().__init__(model) self._skip_verification = skip_verification @@ -618,7 +619,7 @@ def __init__( def add_refinement(self, refinement: RefinementValue) -> None: self._refinements = [*(self._refinements or []), refinement] - def add_parameters(self, parameters: abc.Mapping[str, Union[bool, int, str]]) -> None: + def add_parameters(self, parameters: abc.Mapping[str, bool | int | str]) -> None: expected = {p.name for p in self._type.parameter_types} added = set(parameters) @@ -763,7 +764,7 @@ def as_json(self) -> dict[str, dict[str, object]]: def _valid_refinement_condition(self, refinement: RefinementValue) -> bool: return self._simplified(refinement.condition) == TRUE - def _next_link(self, source_field_name: str) -> Optional[Link]: + def _next_link(self, source_field_name: str) -> Link | None: field = Field(source_field_name) if field == FINAL: return None @@ -800,7 +801,7 @@ def _prev_field(self, fld: str) -> str: return field return "" - def _get_size(self, fld: str) -> Optional[Number]: + def _get_size(self, fld: str) -> Number | None: typeval = self._fields[fld].typeval if isinstance(typeval, ScalarValue): return typeval.size @@ -815,7 +816,7 @@ def _get_size(self, fld: str) -> Optional[Number]: return size if isinstance(size, Number) else None return None - def _get_first(self, fld: str) -> Optional[Number]: + def _get_first(self, fld: str) -> Number | None: for l in self._type.incoming(Field(fld)): if l.first != UNDEFINED and ( self._skip_verification or self._simplified(l.condition) == TRUE @@ -845,7 +846,7 @@ def size(self) -> Number: def assign(self, value: bytes, check: bool = True) -> None: raise NotImplementedError - def parse(self, value: Union[Bitstring, bytes], _check: bool = True) -> None: + def parse(self, value: Bitstring | bytes, _check: bool = True) -> None: assert not self._skip_verification self._path.clear() if isinstance(value, bytes): @@ -923,7 +924,7 @@ def set_field_with_size(field_name: str, field_size: int) -> tuple[int, int]: def _set_unchecked( self, field_name: str, - value: Union[bytes, int, str, abc.Sequence[TypeValue]], + value: bytes | int | str | abc.Sequence[TypeValue], ) -> None: field = self._fields[field_name] field.prev = self._last_field @@ -942,8 +943,8 @@ def _set_unchecked( def _set_checked( self, field_name: str, - value: Union[bytes, int, str, abc.Sequence[TypeValue], Bitstring], - message_size: Optional[int] = None, + value: bytes | int | str | abc.Sequence[TypeValue] | Bitstring, + message_size: int | None = None, ) -> None: def set_refinement(fld: MessageValue.Field, fld_name: str) -> None: if isinstance(fld.typeval, OpaqueValue): @@ -1029,7 +1030,7 @@ def check_outgoing_condition_satisfied() -> None: def set( self, field_name: str, - value: Union[bytes, int, str, abc.Sequence[TypeValue]], + value: bytes | int | str | abc.Sequence[TypeValue], ) -> None: if self._skip_verification: self._set_unchecked(field_name, value) @@ -1048,7 +1049,7 @@ def set( def _set_parsed_value( self, field_name: str, - value: Union[bytes, int, str, abc.Sequence[TypeValue], Bitstring], + value: bytes | int | str | abc.Sequence[TypeValue] | Bitstring, message_size: int, ) -> None: self._set_checked(field_name, value, message_size) @@ -1191,7 +1192,7 @@ def _calculate_checksum(self, checksum: MessageValue.Checksum) -> int: ) raise e - arguments: dict[str, Union[str, int, bytes, tuple[int, int], list[int]]] = {} + arguments: dict[str, str | int | bytes | tuple[int, int] | list[int]] = {} for expr_tuple in checksum.parameters: if isinstance(expr_tuple.evaluated_expression, ValueRange): assert isinstance(expr_tuple.evaluated_expression.lower, Number) @@ -1337,8 +1338,8 @@ def valid_message(self) -> bool: def _update_simplified_mapping( self, - message_size: Optional[int] = None, - field: Optional[Field] = None, + message_size: int | None = None, + field: Field | None = None, ) -> None: if field: if isinstance(field.typeval, ScalarValue): @@ -1442,7 +1443,7 @@ def subst(expression: Expr) -> Expr: class Checksum: def __init__(self, field_name: str, parameters: abc.Sequence[Expr]): self.field_name = field_name - self.function: Optional[ChecksumFunction] = None + self.function: ChecksumFunction | None = None self.calculated = False @dataclass @@ -1460,10 +1461,10 @@ def __init__( # noqa: PLR0913 self, type_value: TypeValue, name: str = "", - name_variable: Optional[Variable] = None, - name_first: Optional[First] = None, - name_last: Optional[Last] = None, - name_size: Optional[Size] = None, + name_variable: Variable | None = None, + name_first: First | None = None, + name_last: Last | None = None, + name_size: Size | None = None, ): assert name or (name_variable and name_first and name_last and name_size) self.typeval = type_value @@ -1517,8 +1518,8 @@ def last(self) -> Expr: @dataclass class State: - fields: Optional[abc.Mapping[str, MessageValue.Field]] = None - checksums: Optional[abc.Mapping[str, MessageValue.Checksum]] = None + fields: abc.Mapping[str, MessageValue.Field] | None = None + checksums: abc.Mapping[str, MessageValue.Checksum] | None = None class RefinementValue: diff --git a/rflx/rapidflux/__init__.pyi b/rflx/rapidflux/__init__.pyi index 435cd9888..8e40c5965 100644 --- a/rflx/rapidflux/__init__.pyi +++ b/rflx/rapidflux/__init__.pyi @@ -1,15 +1,14 @@ from collections.abc import Sequence from enum import Enum from pathlib import Path -from typing import Optional, Union from typing_extensions import Self class ID: def __init__( self, - identifier: Union[str, Sequence[str], Self], - location: Optional[Location] = None, + identifier: str | Sequence[str] | Self, + location: Location | None = None, ) -> None: ... def __eq__(self, other: object) -> bool: ... def __lt__(self, other: object) -> bool: ... @@ -18,7 +17,7 @@ class ID: def __mul__(self: Self, other: object) -> Self: ... def __rmul__(self: Self, other: object) -> Self: ... @property - def location(self) -> Optional[Location]: ... + def location(self) -> Location | None: ... @property def parts(self) -> Sequence[str]: ... @property @@ -34,20 +33,20 @@ class Location: def __init__( self, start: tuple[int, int], - source: Optional[Path] = None, - end: Optional[tuple[int, int]] = None, + source: Path | None = None, + end: tuple[int, int] | None = None, ): ... @property - def source(self) -> Optional[Path]: ... + def source(self) -> Path | None: ... @property def start(self) -> tuple[int, int]: ... @property - def end(self) -> Optional[tuple[int, int]]: ... + def end(self) -> tuple[int, int] | None: ... @property def short(self) -> Location: ... def __lt__(self, other: object) -> bool: ... @staticmethod - def merge(locations: Sequence[Optional[Location]]) -> Location: ... + def merge(locations: Sequence[Location | None]) -> Location: ... class Severity(Enum): ERROR: Severity diff --git a/rflx/specification/parser.py b/rflx/specification/parser.py index 1dc6a1c20..b96288fc9 100644 --- a/rflx/specification/parser.py +++ b/rflx/specification/parser.py @@ -5,7 +5,7 @@ from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from pathlib import Path -from typing import Iterable, Optional, Union +from typing import Iterable import rflx.typing_ as rty from rflx import expr, lang, model @@ -118,7 +118,7 @@ def validate_handler( error.propagate() -def create_description(description: Optional[lang.Description] = None) -> Optional[str]: +def create_description(description: lang.Description | None = None) -> str | None: if description: assert isinstance(description.text, str) return description.text.split('"')[1] @@ -363,7 +363,7 @@ def create_sequence( _parameters: lang.Parameters, sequence: lang.TypeDef, filename: Path, -) -> Optional[model.UncheckedSequence]: +) -> model.UncheckedSequence | None: assert isinstance(sequence, lang.SequenceTypeDef) element_identifier = model.internal_type_identifier( create_id(error, sequence.f_element_type, filename), @@ -435,7 +435,7 @@ def create_binop(error: RecordFluxError, expression: lang.Expr, filename: Path) raise NotImplementedError(f"Invalid BinOp {expression.f_op.kind_name} => {expression.text}") -MATH_OPERATIONS: Mapping[str, type[Union[expr.BinExpr, expr.AssExpr]]] = { +MATH_OPERATIONS: Mapping[str, type[expr.BinExpr | expr.AssExpr]] = { "OpPow": expr.Pow, "OpAdd": expr.Add, "OpSub": expr.Sub, @@ -841,9 +841,9 @@ def create_case(error: RecordFluxError, expression: lang.Expr, filename: Path) - assert isinstance(expression, lang.CaseExpression) def create_choice( - value: Union[lang.AbstractID, lang.Expr], + value: lang.AbstractID | lang.Expr, filename: Path, - ) -> Union[ID, expr.Number]: + ) -> ID | expr.Number: if isinstance(value, lang.AbstractID): return create_id(error, value, filename) assert isinstance(value, lang.Expr) @@ -851,7 +851,7 @@ def create_choice( assert isinstance(result, expr.Number) return result - choices: Sequence[tuple[Sequence[Union[ID, expr.Number]], expr.Expr]] = [ + choices: Sequence[tuple[Sequence[ID | expr.Number], expr.Expr]] = [ ( [ create_choice(s, filename) @@ -1047,7 +1047,7 @@ def create_range( _parameters: lang.Parameters, rangetype: lang.TypeDef, filename: Path, -) -> Optional[model.UncheckedInteger]: +) -> model.UncheckedInteger | None: assert isinstance(rangetype, lang.RangeTypeDef) if rangetype.f_size.f_identifier.text != "Size": error.extend( @@ -1087,7 +1087,7 @@ def create_null_message( _parameters: lang.Parameters, message: lang.TypeDef, _filename: Path, -) -> Optional[model.UncheckedMessage]: +) -> model.UncheckedMessage | None: assert isinstance(message, lang.NullMessageTypeDef) return model.UncheckedMessage( identifier, @@ -1106,11 +1106,11 @@ def create_message( parameters: lang.Parameters, message: lang.TypeDef, filename: Path, -) -> Optional[model.UncheckedMessage]: +) -> model.UncheckedMessage | None: assert isinstance(message, lang.MessageTypeDef) fields = message.f_message_fields - def get_parameters(param: lang.Parameters) -> Optional[lang.ParameterList]: + def get_parameters(param: lang.Parameters) -> lang.ParameterList | None: if not param: return None assert isinstance(param.f_parameters, lang.ParameterList) @@ -1422,10 +1422,10 @@ def parse_aspects( # noqa: PLR0912 filename: Path, ) -> tuple[ Mapping[ID, Sequence[expr.Expr]], - Union[model.ByteOrder, dict[model.Field, model.ByteOrder]], + model.ByteOrder | dict[model.Field, model.ByteOrder], ]: checksum_result = {} - byte_order_result: Union[model.ByteOrder, dict[model.Field, model.ByteOrder]] = {} + byte_order_result: model.ByteOrder | dict[model.Field, model.ByteOrder] = {} grouped = defaultdict(list) for aspect in aspects: @@ -1472,7 +1472,7 @@ def create_derived_message( _parameters: lang.Parameters, derivation: lang.TypeDef, filename: Path, -) -> Optional[model.UncheckedDerivedMessage]: +) -> model.UncheckedDerivedMessage | None: assert isinstance(derivation, lang.TypeDerivationDef) base_id = create_id(error, derivation.f_base, filename) base_name = model.internal_type_identifier(base_id, identifier.parent) @@ -1489,11 +1489,11 @@ def create_enumeration( _parameters: lang.Parameters, enumeration: lang.TypeDef, filename: Path, -) -> Optional[model.UncheckedEnumeration]: +) -> model.UncheckedEnumeration | None: assert isinstance(enumeration, lang.EnumerationTypeDef) literals: list[tuple[ID, expr.Number]] = [] - def create_aspects(aspects: lang.AspectList) -> Optional[tuple[expr.Expr, bool]]: + def create_aspects(aspects: lang.AspectList) -> tuple[expr.Expr, bool] | None: always_valid = False size = None @@ -1726,9 +1726,9 @@ def package(self) -> ID: class Parser: def __init__( self, - cache: Optional[Cache] = None, + cache: Cache | None = None, workers: int = 1, - integration_files_dir: Optional[Path] = None, + integration_files_dir: Path | None = None, ) -> None: self._cache = AlwaysVerify() if cache is None else cache self._workers = workers @@ -1827,7 +1827,7 @@ def specifications(self) -> dict[str, lang.Specification]: for spec_node in self._specifications.values() } - def _parse_file(self, error: RecordFluxError, filename: Path) -> Optional[SpecificationFile]: + def _parse_file(self, error: RecordFluxError, filename: Path) -> SpecificationFile | None: logging.info("Parsing {filename}", filename=filename) source_code_str = filename.read_text() @@ -1907,7 +1907,7 @@ def _evaluate_specification( lang.TypeDef, Path, ], - Optional[model.UncheckedTypeDecl], + model.UncheckedTypeDecl | None, ], ] = { "SequenceTypeDef": create_sequence, diff --git a/rflx/specification/style.py b/rflx/specification/style.py index 602af3bfe..7d0873cb2 100644 --- a/rflx/specification/style.py +++ b/rflx/specification/style.py @@ -3,7 +3,6 @@ import re from enum import Enum from pathlib import Path -from typing import Optional from rflx.rapidflux import ErrorEntry, Location, RecordFluxError, Severity @@ -80,7 +79,7 @@ def _append( row: int, col: int, spec_file: Path, - check_type: Optional[Check] = None, + check_type: Check | None = None, ) -> None: error.push( ErrorEntry( diff --git a/rflx/typing_.py b/rflx/typing_.py index 05a174691..4a9f85185 100644 --- a/rflx/typing_.py +++ b/rflx/typing_.py @@ -3,7 +3,7 @@ from abc import abstractmethod from collections import abc from pathlib import Path -from typing import ClassVar, Final, Optional, Union +from typing import ClassVar, Final import attr @@ -80,7 +80,7 @@ class Enumeration(NamedType): DESCRIPTIVE_NAME: ClassVar[str] = "enumeration type" literals: abc.Sequence[ID] = attr.ib() always_valid: bool = attr.ib(default=False) - location: Optional[Location] = attr.ib(default=None, cmp=False) + location: Location | None = attr.ib(default=None, cmp=False) def __str__(self) -> str: return f'{self.DESCRIPTIVE_NAME} "{self.identifier}"' @@ -125,7 +125,7 @@ class Integer(AnyInteger, NamedType): DESCRIPTIVE_NAME: ClassVar[str] = "integer type" identifier: ID = attr.ib(converter=ID) bounds: Bounds = attr.ib() - location: Optional[Location] = attr.ib(default=None, cmp=False) + location: Location | None = attr.ib(default=None, cmp=False) def __str__(self) -> str: bounds = f" ({self.bounds})" if self.bounds else "" @@ -303,8 +303,8 @@ def common_type(types: abc.Sequence[Type]) -> Type: def check_type( actual: Type, - expected: Union[Type, tuple[Type, ...]], - location: Optional[Location], + expected: Type | tuple[Type, ...], + location: Location | None, description: str, ) -> RecordFluxError: assert expected, "empty expected types" @@ -347,8 +347,8 @@ def check_type( def check_type_instance( actual: Type, - expected: Union[type[Type], tuple[type[Type], ...]], - location: Optional[Location], + expected: type[Type] | tuple[type[Type], ...], + location: Location | None, description: str = "", additionnal_annotations: abc.Sequence[Annotation] | None = None, ) -> RecordFluxError: @@ -383,7 +383,7 @@ def check_type_instance( return error -def _undefined_type(location: Optional[Location], description: str = "") -> RecordFluxError: +def _undefined_type(location: Location | None, description: str = "") -> RecordFluxError: return RecordFluxError( [ ErrorEntry( diff --git a/rflx/validator.py b/rflx/validator.py index 6f635c590..b7e3dbe56 100644 --- a/rflx/validator.py +++ b/rflx/validator.py @@ -8,7 +8,7 @@ from itertools import product from pathlib import Path from types import TracebackType -from typing import Optional, TextIO, Union +from typing import TextIO from ruamel.yaml.main import YAML from typing_extensions import Self @@ -33,9 +33,9 @@ class Validator: def __init__( self, - files: Iterable[Union[str, Path]], - checksum_module: Optional[str] = None, - cache: Optional[Cache] = None, + files: Iterable[str | Path], + checksum_module: str | None = None, + cache: Cache | None = None, split_disjunctions: bool = False, ): model = self._create_model( @@ -73,9 +73,9 @@ def __init__( def validate( # noqa: PLR0913 self, message_identifier: ID, - paths_invalid: Optional[list[Path]] = None, - paths_valid: Optional[list[Path]] = None, - json_output: Optional[Path] = None, + paths_invalid: list[Path] | None = None, + paths_valid: list[Path] | None = None, + json_output: Path | None = None, abort_on_error: bool = False, coverage: bool = False, target_coverage: float = 0.00, @@ -155,9 +155,9 @@ def validate( # noqa: PLR0913 def _check_arguments( self, _message_identifier: ID, - paths_invalid: Optional[list[Path]] = None, - paths_valid: Optional[list[Path]] = None, - json_output: Optional[Path] = None, + paths_invalid: list[Path] | None = None, + paths_valid: list[Path] | None = None, + json_output: Path | None = None, _abort_on_error: bool = False, _coverage: bool = False, target_coverage: float = 0.00, @@ -290,7 +290,7 @@ def _expand_expression(expression: expr.Expr) -> list[expr.Expr]: return result @staticmethod - def _parse_checksum_module(name: Optional[str]) -> dict[StrID, dict[str, ChecksumFunction]]: + def _parse_checksum_module(name: str | None) -> dict[StrID, dict[str, ChecksumFunction]]: if name is None: return {} @@ -336,7 +336,7 @@ def _validate_message( raise ValidationError(f"{message_path} is not a regular file") parameters_path = message_path.with_suffix(".yaml") - message_parameters: dict[str, Union[bool, int, str]] = {} + message_parameters: dict[str, bool | int | str] = {} if parameters_path.is_file(): yaml = YAML() @@ -480,7 +480,7 @@ def _print_link_coverage(self) -> None: class ValidationResult: validation_success: bool parsed_message: MessageValue - parser_error: Optional[str] + parser_error: str | None message_path: Path original_message: bytes valid_original_message: bool @@ -515,9 +515,9 @@ def print_console_output(self) -> None: class OutputWriter: - file: Optional[TextIO] + file: TextIO | None - def __init__(self, file: Optional[Path]) -> None: + def __init__(self, file: Path | None) -> None: if file is not None: try: self.file = file.open("w", encoding="utf-8") @@ -534,9 +534,9 @@ def __enter__(self) -> Self: def __exit__( self, - exception_type: Optional[type[BaseException]], - exception_value: Optional[BaseException], - traceback: Optional[TracebackType], + exception_type: type[BaseException] | None, + exception_value: BaseException | None, + traceback: TracebackType | None, ) -> None: if self.file is not None: self.file.write("\n]\n") diff --git a/rflx/version.py b/rflx/version.py index 62ce1c8be..d0b0fddf6 100644 --- a/rflx/version.py +++ b/rflx/version.py @@ -2,7 +2,6 @@ import re from importlib import metadata -from typing import Optional import rflx from rflx import __version__ @@ -52,7 +51,7 @@ def is_gnat_tracker_release() -> bool: class Requirement: def __init__(self, string: str) -> None: self.name: str - self.extra: Optional[str] + self.extra: str | None match = re.match(r'([^<=> (]{1,})[^;]*(?: *; extra == [\'"](.*)[\'"])?', string) diff --git a/stubs/pydotplus.pyi b/stubs/pydotplus.pyi index fc489849c..c1816d48a 100644 --- a/stubs/pydotplus.pyi +++ b/stubs/pydotplus.pyi @@ -1,13 +1,13 @@ from collections.abc import Iterable -from typing import BinaryIO, Optional +from typing import BinaryIO class Edge: - def __init__(self, dst: str, src: str, **attr: Optional[str]): ... + def __init__(self, dst: str, src: str, **attr: str | None): ... def get_source(self) -> str: ... def get_destination(self) -> str: ... class Node: - def __init__(self, name: str, **attrs: Optional[str]): ... + def __init__(self, name: str, **attrs: str | None): ... def get_name(self) -> str: ... class Dot: @@ -17,11 +17,11 @@ class Dot: def write( self, handle: BinaryIO, - prog: Optional[str] = None, - format: Optional[str] = "raw", # noqa: A002 + prog: str | None = None, + format: str | None = "raw", # noqa: A002 ) -> None: ... def get_nodes(self) -> Iterable[Node]: ... def get_edges(self) -> Iterable[Edge]: ... - def set_graph_defaults(self, **attrs: Optional[str]) -> None: ... - def set_edge_defaults(self, **attrs: Optional[str]) -> None: ... - def set_node_defaults(self, **attrs: Optional[str]) -> None: ... + def set_graph_defaults(self, **attrs: str | None) -> None: ... + def set_edge_defaults(self, **attrs: str | None) -> None: ... + def set_node_defaults(self, **attrs: str | None) -> None: ... diff --git a/stubs/z3.pyi b/stubs/z3.pyi index 838119d11..1e85fd93a 100644 --- a/stubs/z3.pyi +++ b/stubs/z3.pyi @@ -1,7 +1,6 @@ # ruff: noqa: N802, N818 from collections.abc import Iterable -from typing import Optional class Context: ... @@ -40,16 +39,16 @@ class SortRef(AstRef): ... def DeclareSort(name: str) -> SortRef: ... def Const(name: str, sort: SortRef) -> ExprRef: ... -def Int(name: str, ctx: Optional[Context] = None) -> ArithRef: ... -def IntVal(val: int, ctx: Optional[Context] = None) -> ArithRef: ... +def Int(name: str, ctx: Context | None = None) -> ArithRef: ... +def IntVal(val: int, ctx: Context | None = None) -> ArithRef: ... def Sum(*args: ArithRef) -> ArithRef: ... def Product(*args: ArithRef) -> ArithRef: ... -def Bool(name: str, ctx: Optional[Context] = None) -> BoolRef: ... -def BoolVal(val: bool, ctx: Optional[Context] = None) -> BoolRef: ... -def Not(val: BoolRef, ctx: Optional[Context] = None) -> BoolRef: ... +def Bool(name: str, ctx: Context | None = None) -> BoolRef: ... +def BoolVal(val: bool, ctx: Context | None = None) -> BoolRef: ... +def Not(val: BoolRef, ctx: Context | None = None) -> BoolRef: ... def And(*args: BoolRef) -> BoolRef: ... def Or(*args: BoolRef) -> BoolRef: ... -def If(c: BoolRef, t: ExprRef, e: ExprRef, ctx: Optional[Context] = None) -> ExprRef: ... +def If(c: BoolRef, t: ExprRef, e: ExprRef, ctx: Context | None = None) -> ExprRef: ... def ForAll(v: Iterable[ExprRef], cond: ExprRef) -> ExprRef: ... def Exists(v: Iterable[ExprRef], cond: ExprRef) -> ExprRef: ... def simplify(e: ExprRef) -> ExprRef: ... diff --git a/tests/feature/__init__.py b/tests/feature/__init__.py index a19137693..273bb645a 100644 --- a/tests/feature/__init__.py +++ b/tests/feature/__init__.py @@ -27,38 +27,45 @@ # This is only relevant for Python 3.8. +# TODO(eng/recordflux/RecordFlux#1424): Replace remaining use of Optional +# and Union. Pydantic has issues with PEP604 type annotations in Python +# 3.8 and 3.9. + + class ConfigFile(BaseModel): # type: ignore[misc] - input: Optional[ty.Mapping[str, Optional[ty.Sequence[ty.Union[int, str]]]]] = None - output: Optional[ty.Sequence[str]] = None - sequence: Optional[str] = None - prove: Optional[ty.Sequence[str]] = None - external_io_buffers: Optional[int] = None + input: Optional[ty.Mapping[str, Optional[ty.Sequence[ty.Union[int, str]]]]] = ( # noqa: UP007 + None + ) + output: Optional[ty.Sequence[str]] = None # noqa: UP007 + sequence: Optional[str] = None # noqa: UP007 + prove: Optional[ty.Sequence[str]] = None # noqa: UP007 + external_io_buffers: Optional[int] = None # noqa: UP007 @field_validator("input") def initialize_input_if_present( cls, # noqa: N805 - value: Optional[ty.Mapping[str, ty.Sequence[str]]], + value: Optional[ty.Mapping[str, ty.Sequence[str]]], # noqa: UP007 ) -> ty.Mapping[str, ty.Sequence[str]]: return value if value is not None else {} @field_validator("output") def initialize_output_if_present( cls, # noqa: N805 - value: Optional[ty.Sequence[str]], + value: Optional[ty.Sequence[str]], # noqa: UP007 ) -> ty.Sequence[str]: return value if value is not None else [] @field_validator("prove") def initialize_prove_if_present( cls, # noqa: N805 - value: Optional[ty.Sequence[str]], + value: Optional[ty.Sequence[str]], # noqa: UP007 ) -> ty.Sequence[str]: return value if value is not None else [] @field_validator("external_io_buffers") def initialize_external_io_buffers_if_present( cls, # noqa: N805 - value: Optional[int], + value: Optional[int], # noqa: UP007 ) -> int: return value if value is not None else 0 @@ -68,7 +75,7 @@ class Config: inp: dict[str, Sequence[tuple[int, ...]]] = dataclass_field(default_factory=dict) out: Sequence[str] = dataclass_field(default_factory=list) sequence: str = dataclass_field(default="") - prove: Optional[Sequence[str]] = dataclass_field(default=None) + prove: Sequence[str] | None = dataclass_field(default=None) external_io_buffers: int = dataclass_field(default=0) diff --git a/tests/property/strategies.py b/tests/property/strategies.py index cce3f75ae..7f21f25d9 100644 --- a/tests/property/strategies.py +++ b/tests/property/strategies.py @@ -3,7 +3,7 @@ import string from collections import abc from dataclasses import dataclass -from typing import Optional, Protocol, TypeVar, Union +from typing import Protocol, TypeVar from hypothesis import assume, strategies as st @@ -184,8 +184,8 @@ def messages( # noqa: PLR0915 class FieldPair: source: Field target: Field - source_type: Optional[TypeDecl] - target_type: Optional[TypeDecl] + source_type: TypeDecl | None + target_type: TypeDecl | None def size(pair: FieldPair) -> expr.Expr: max_size = 2**29 - 1 @@ -379,7 +379,7 @@ def append_types(message: Message) -> None: @st.composite -def numbers(draw: Draw, min_value: int = 0, max_value: Optional[int] = None) -> expr.Number: +def numbers(draw: Draw, min_value: int = 0, max_value: int | None = None) -> expr.Number: return expr.Number(draw(st.integers(min_value=min_value, max_value=max_value))) @@ -390,7 +390,7 @@ def variables(draw: Draw, elements: st.SearchStrategy[str]) -> expr.Variable: @st.composite def attributes(draw: Draw, elements: st.SearchStrategy[expr.Expr]) -> expr.Expr: - sample: st.SearchStrategy[type[Union[expr.Size, expr.First, expr.Last]]] = st.sampled_from( + sample: st.SearchStrategy[type[expr.Size | expr.First | expr.Last]] = st.sampled_from( [expr.Size, expr.First, expr.Last], ) attribute = draw(sample) @@ -432,16 +432,14 @@ def mathematical_expressions(draw: Draw, elements: st.SearchStrategy[expr.Expr]) def relations(draw: Draw, elements: st.SearchStrategy[expr.Expr]) -> expr.Relation: sample: st.SearchStrategy[ type[ - Union[ - expr.Less, - expr.LessEqual, - expr.Equal, - expr.GreaterEqual, - expr.Greater, - expr.NotEqual, - expr.In, - expr.NotIn, - ] + expr.Less + | expr.LessEqual + | expr.Equal + | expr.GreaterEqual + | expr.Greater + | expr.NotEqual + | expr.In + | expr.NotIn ] ] = st.sampled_from( [ diff --git a/tests/tools/check_grammar_test.py b/tests/tools/check_grammar_test.py index bf40f7b63..4a52a3e37 100644 --- a/tests/tools/check_grammar_test.py +++ b/tests/tools/check_grammar_test.py @@ -2,7 +2,6 @@ import sys from pathlib import Path -from typing import Optional import pytest @@ -13,7 +12,7 @@ DATA_DIR = BASE_DATA_DIR / "lrm_grammar" -def check_rst(filename: Path, invalid: bool = False, examples: Optional[list[Path]] = None) -> None: +def check_rst(filename: Path, invalid: bool = False, examples: list[Path] | None = None) -> None: errors = RecordFluxError() check_spec( filename=filename, diff --git a/tests/tools/check_unit_test_file_coverage_test.py b/tests/tools/check_unit_test_file_coverage_test.py index 447935d1c..e057afb69 100644 --- a/tests/tools/check_unit_test_file_coverage_test.py +++ b/tests/tools/check_unit_test_file_coverage_test.py @@ -2,7 +2,6 @@ import sys from pathlib import Path -from typing import Optional import pytest @@ -180,8 +179,8 @@ def test_no_files( tmp_path: Path, source_files: list[str], test_files: list[str], - expected: Optional[str], - ignore: Optional[list[str]], + expected: str | None, + ignore: list[str] | None, ) -> None: source_dir = tmp_path / "source" test_dir = tmp_path / "test" @@ -230,7 +229,7 @@ def test_no_files( ("a", None, False), ], ) -def test_ignored(file: str, ignore_list: Optional[list[str]], expected: bool) -> None: +def test_ignored(file: str, ignore_list: list[str] | None, expected: bool) -> None: assert ( check_unit_test_file_coverage._ignored( # noqa: SLF001 Path(file), @@ -241,13 +240,13 @@ def test_ignored(file: str, ignore_list: Optional[list[str]], expected: bool) -> def test_cli(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - stored_source_dir: Optional[Path] = None - stored_test_dir: Optional[Path] = None + stored_source_dir: Path | None = None + stored_test_dir: Path | None = None def dummy_check_file_coverage( source_dir: Path, test_dir: Path, - ignore: Optional[list[Path]] = None, # noqa: ARG001 + ignore: list[Path] | None = None, # noqa: ARG001 ) -> None: nonlocal stored_source_dir, stored_test_dir stored_source_dir = source_dir diff --git a/tests/unit/cli_test.py b/tests/unit/cli_test.py index 5384803ab..9ece5d256 100644 --- a/tests/unit/cli_test.py +++ b/tests/unit/cli_test.py @@ -8,7 +8,7 @@ from collections.abc import Callable from io import TextIOWrapper from pathlib import Path -from typing import ClassVar, NoReturn, Optional +from typing import ClassVar, NoReturn import pytest @@ -30,8 +30,8 @@ def validator_mock( self: object, # noqa: ARG001 files: object, # noqa: ARG001 - checksum_module: Optional[str] = None, # noqa: ARG001 - cache: Optional[object] = None, # noqa: ARG001 + checksum_module: str | None = None, # noqa: ARG001 + cache: object | None = None, # noqa: ARG001 split_disjunctions: bool = False, # noqa: ARG001 ) -> None: return None diff --git a/tests/unit/generator/session_test.py b/tests/unit/generator/session_test.py index 3539d74ec..5d5566428 100644 --- a/tests/unit/generator/session_test.py +++ b/tests/unit/generator/session_test.py @@ -3,7 +3,7 @@ import typing from dataclasses import dataclass from functools import lru_cache -from typing import Callable, Optional, Sequence +from typing import Callable, Sequence import pytest import z3 @@ -729,7 +729,7 @@ class EvaluatedDeclarationStr: ) def test_session_declare( type_: rty.Type, - expression: Optional[ir.ComplexExpr], + expression: ir.ComplexExpr | None, constant: bool, session_global: bool, expected: EvaluatedDeclarationStr, @@ -830,7 +830,7 @@ def test_session_declare( ) def test_session_declare_error( type_: rty.Type, - expression: Optional[ir.ComplexExpr], + expression: ir.ComplexExpr | None, error_type: type[RecordFluxError], error_msg: str, ) -> None: diff --git a/tests/unit/identifier_test.py b/tests/unit/identifier_test.py index 42d6bfd53..a80bc40bc 100644 --- a/tests/unit/identifier_test.py +++ b/tests/unit/identifier_test.py @@ -3,7 +3,6 @@ import pickle from collections.abc import Sequence from pathlib import Path -from typing import Union import pytest @@ -38,7 +37,7 @@ def test_id_invalid_type() -> None: "A::B:C::D", ], ) -def test_id_invalid(identifier: Union[str, Sequence[str]]) -> None: +def test_id_invalid(identifier: str | Sequence[str]) -> None: with pytest.raises(FatalError, match=r"^invalid identifier$"): ID(identifier) diff --git a/tests/unit/ls/server_test.py b/tests/unit/ls/server_test.py index 18acb0e97..058c9ef80 100644 --- a/tests/unit/ls/server_test.py +++ b/tests/unit/ls/server_test.py @@ -3,7 +3,7 @@ import asyncio import re from pathlib import Path -from typing import Final, Optional +from typing import Final import pytest from lsprotocol.types import ( @@ -136,7 +136,7 @@ def test_to_lsp_location() -> None: (error.Severity.HELP, DiagnosticSeverity.Hint), ], ) -def test_to_lsp_severity(severity: error.Severity, expected: Optional[DiagnosticSeverity]) -> None: +def test_to_lsp_severity(severity: error.Severity, expected: DiagnosticSeverity | None) -> None: assert server.to_lsp_severity(severity) == expected diff --git a/tests/unit/typing__test.py b/tests/unit/typing__test.py index 677343798..490ea1a20 100644 --- a/tests/unit/typing__test.py +++ b/tests/unit/typing__test.py @@ -1,7 +1,6 @@ from __future__ import annotations from collections import abc -from typing import Union import pytest @@ -662,7 +661,7 @@ def test_check_type_error(actual: Type, expected: Type, match: str) -> None: ) def test_check_type_instance( actual: Type, - expected: Union[type[Type], tuple[type[Type], ...]], + expected: type[Type] | tuple[type[Type], ...], ) -> None: check_type_instance(actual, expected, Location((10, 20)), '"A"').propagate() @@ -694,7 +693,7 @@ def test_check_type_instance( ) def test_check_type_instance_error( actual: Type, - expected: Union[type[Type], tuple[type[Type], ...]], + expected: type[Type] | tuple[type[Type], ...], match: str, ) -> None: with pytest.raises(RecordFluxError, match=match): diff --git a/tests/unit/version_test.py b/tests/unit/version_test.py index d37af2c67..7405bca2f 100644 --- a/tests/unit/version_test.py +++ b/tests/unit/version_test.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional - import pytest import rflx @@ -47,7 +45,7 @@ def test_is_gnat_tracker_release( ('setuptools_scm<8,>=6.2; extra == "devel"', "setuptools_scm", "devel"), ], ) -def test_requirement(requirement: str, name: str, extra: Optional[str]) -> None: +def test_requirement(requirement: str, name: str, extra: str | None) -> None: r = version.Requirement(requirement) assert r.name == name assert r.extra == extra diff --git a/tests/utils.py b/tests/utils.py index a5d5bc3d5..502583774 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,7 +7,6 @@ import subprocess import textwrap from collections.abc import Iterable, Mapping, Sequence -from typing import Optional, Union import pytest @@ -45,8 +44,8 @@ def assert_message_model_error( structure: Sequence[Link], types: Mapping[Field, TypeDecl], regex: str, - checksums: Optional[Mapping[ID, Sequence[Expr]]] = None, - location: Optional[Location] = None, + checksums: Mapping[ID, Sequence[Expr]] | None = None, + location: Location | None = None, ) -> None: location = location or Location((1, 1), end=(1, 2)) check_regex(regex) @@ -80,7 +79,7 @@ def assert_session_model_error( def assert_equal_code_specs( - spec_files: Iterable[Union[str, pathlib.Path]], + spec_files: Iterable[str | pathlib.Path], expected_dir: pathlib.Path, tmp_path: pathlib.Path, accept_extra_files: bool = False, @@ -133,9 +132,9 @@ def assert_equal_code( def assert_compilable_code_specs( - spec_files: Iterable[Union[str, pathlib.Path]], + spec_files: Iterable[str | pathlib.Path], tmp_path: pathlib.Path, - prefix: Optional[str] = None, + prefix: str | None = None, ) -> None: parser = Parser() @@ -148,7 +147,7 @@ def assert_compilable_code_specs( def assert_compilable_code_string( specification: str, tmp_path: pathlib.Path, - prefix: Optional[str] = None, + prefix: str | None = None, ) -> None: parser = Parser() parser.parse_string(specification) @@ -160,8 +159,8 @@ def assert_compilable_code( # noqa: PLR0913 model: Model, integration: Integration, tmp_path: pathlib.Path, - main: Optional[str] = None, - prefix: Optional[str] = None, + main: str | None = None, + prefix: str | None = None, debug: Debug = Debug.BUILTIN, mode: str = "strict", ) -> None: @@ -184,7 +183,7 @@ def assert_executable_code( integration: Integration, tmp_path: pathlib.Path, main: str = MAIN, - prefix: Optional[str] = None, + prefix: str | None = None, debug: Debug = Debug.BUILTIN, ) -> str: assert_compilable_code( @@ -215,8 +214,8 @@ def assert_executable_code( def assert_provable_code_string( specification: str, tmp_path: pathlib.Path, - prefix: Optional[str] = None, - units: Optional[Sequence[str]] = None, + prefix: str | None = None, + units: Sequence[str] | None = None, ) -> None: parser = Parser() parser.parse_string(specification) @@ -228,9 +227,9 @@ def assert_provable_code( model: Model, integration: Integration, tmp_path: pathlib.Path, - main: Optional[str] = None, - prefix: Optional[str] = None, - units: Optional[Sequence[str]] = None, + main: str | None = None, + prefix: str | None = None, + units: Sequence[str] | None = None, ) -> None: _create_files(tmp_path, model, integration, main, prefix) @@ -259,8 +258,8 @@ def _create_files( tmp_path: pathlib.Path, model: Model, integration: Integration, - main: Optional[str] = None, - prefix: Optional[str] = None, + main: str | None = None, + prefix: str | None = None, debug: Debug = Debug.BUILTIN, ) -> None: shutil.copy("defaults.gpr", tmp_path) @@ -318,8 +317,8 @@ def _create_files( def session_main( - input_channels: Optional[dict[str, Sequence[tuple[int, ...]]]] = None, - output_channels: Optional[Sequence[str]] = None, + input_channels: dict[str, Sequence[tuple[int, ...]]] | None = None, + output_channels: Sequence[str] | None = None, external_io_buffers: int = 0, ) -> Mapping[str, str]: input_channels = input_channels or {} diff --git a/tools/check_doc.py b/tools/check_doc.py index bcde3f2b9..d79ba8119 100755 --- a/tools/check_doc.py +++ b/tools/check_doc.py @@ -19,7 +19,6 @@ import tempfile import textwrap from pathlib import Path -from typing import Optional from ruamel.yaml import YAML from ruamel.yaml.parser import ParserError @@ -58,7 +57,7 @@ class State(enum.Enum): class StyleChecker: def __init__(self, filename: Path): self._filename = filename - self._previous: Optional[tuple[int, str]] = None + self._previous: tuple[int, str] | None = None self._headings_re = re.compile(r"^(=+|-+|~+|\^+|\*+|\"+)$") def check(self, lineno: int, line: str) -> None: @@ -134,11 +133,11 @@ def __init__(self, filename: Path): def check( self, - lineno: Optional[int], + lineno: int | None, block: str, - code_type: Optional[CodeBlockType], + code_type: CodeBlockType | None, indent: int, - subtype: Optional[str] = None, + subtype: str | None = None, ) -> None: assert lineno # Remove trailing empty line as this is an error for RecordFlux style checks. It could be @@ -165,7 +164,7 @@ def check( f"{self._filename}:{lineno}: error in code block\n{error}", ) from error - def _check_rflx(self, block: str, subtype: Optional[str] = None) -> None: + def _check_rflx(self, block: str, subtype: str | None = None) -> None: try: if subtype is None: parser = Parser() @@ -178,7 +177,7 @@ def _check_rflx(self, block: str, subtype: Optional[str] = None) -> None: except RecordFluxError as rflx_error: raise CheckDocError(str(rflx_error)) from rflx_error - def _check_ada(self, block: str, subtype: Optional[str] = None) -> None: + def _check_ada(self, block: str, subtype: str | None = None) -> None: args = [] unit = "main" @@ -263,10 +262,10 @@ def check_file(filename: Path, content: str) -> bool: # noqa: PLR0912, PLR0915 found = False state = State.OUTSIDE block = "" - block_start: Optional[int] = None - doc_check_type: Optional[CodeBlockType] = None + block_start: int | None = None + doc_check_type: CodeBlockType | None = None indent: int = 0 - subtype: Optional[str] = None + subtype: str | None = None style_checker = StyleChecker(filename) code_checker = CodeChecker(filename) diff --git a/tools/check_requirements.py b/tools/check_requirements.py index cde32330c..ef30abfda 100755 --- a/tools/check_requirements.py +++ b/tools/check_requirements.py @@ -10,7 +10,6 @@ import sys from collections.abc import Sequence from pathlib import Path -from typing import Optional, Union ID_SEPARATOR = "-" ID_REGEX = r"ยง(?:[A-Z0-9]+" + ID_SEPARATOR + r")*[A-Z0-9]+" @@ -42,7 +41,7 @@ def __init__( identifier: str, description: str, referenced: bool = False, - requirements: Optional[list[Requirement]] = None, + requirements: list[Requirement] | None = None, ): self._identifier = identifier self._description = description.capitalize() @@ -78,7 +77,7 @@ def referenced(self, value: bool) -> None: self._referenced = value -def main(argv: Sequence[str]) -> Union[bool, str]: +def main(argv: Sequence[str]) -> bool | str: arg_parser = argparse.ArgumentParser() arg_parser.add_argument("requirements", metavar="REQUIREMENTS_FILE", type=Path) arg_parser.add_argument("directories", metavar="DIRECTORY", type=Path, nargs="+") diff --git a/tools/check_unit_test_file_coverage.py b/tools/check_unit_test_file_coverage.py index ec339c8e4..8baa125fa 100755 --- a/tools/check_unit_test_file_coverage.py +++ b/tools/check_unit_test_file_coverage.py @@ -12,14 +12,13 @@ import argparse import sys from pathlib import Path -from typing import Optional class CheckUnitTestFileError(Exception): pass -def _ignored(file: Path, ignore_list: Optional[list[Path]] = None) -> bool: +def _ignored(file: Path, ignore_list: list[Path] | None = None) -> bool: return ignore_list is not None and any( i.parts == file.parts[: len(i.parts)] for i in ignore_list ) @@ -28,7 +27,7 @@ def _ignored(file: Path, ignore_list: Optional[list[Path]] = None) -> bool: def check_file_coverage( source_dir: Path, test_dir: Path, - ignore: Optional[list[Path]] = None, + ignore: list[Path] | None = None, ) -> None: if not source_dir.exists(): diff --git a/tools/extract_packets.py b/tools/extract_packets.py index 3b1e3c678..e2a6e67c4 100755 --- a/tools/extract_packets.py +++ b/tools/extract_packets.py @@ -17,13 +17,12 @@ from math import ceil, log from pathlib import Path from pydoc import locate -from typing import Union import scapy.layers # type: ignore[import-untyped] from scapy.utils import hexdump, rdpcap # type: ignore[import-untyped] -def main(argv: Sequence[str]) -> Union[bool, str]: +def main(argv: Sequence[str]) -> bool | str: available_layers = [ f"{p.name}.{c}" for p in pkgutil.iter_modules(scapy.layers.__path__) diff --git a/tools/rflxlexer.py b/tools/rflxlexer.py index 8bc4f1377..e3fbd7cfc 100644 --- a/tools/rflxlexer.py +++ b/tools/rflxlexer.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Iterator, Optional +from typing import Iterator import pygments.lexer import pygments.token @@ -174,7 +174,7 @@ def get_tokens_unprocessed( for t in unit.iter_tokens() ] - prev: Optional[tuple[int, lang.SlocRange]] = None + prev: tuple[int, lang.SlocRange] | None = None result = [] for start, t in tokens: