Skip to content

Commit

Permalink
Fix fatal error when comparing opaque fields
Browse files Browse the repository at this point in the history
Ref. eng/recordflux/RecordFlux#1497
  • Loading branch information
treiher committed Dec 12, 2023
1 parent a7f4d1e commit c91993c
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 60 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Fixed

- Fatal error when comparing opaque fields (AdaCore/RecordFlux#1294, eng/recordflux/RecordFlux#1497)

### Removed

- Verification of message bit coverage (eng/recordflux/RecordFlux#1495)
Expand Down
26 changes: 19 additions & 7 deletions rflx/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -3861,14 +3861,26 @@ def _to_ir_basic_expr(
assert isinstance(expression.type_, rty.Any)
result_expr = ir.ObjVar(result_id, expression.type_, origin=expression)

result_type = result_expr.type_
if isinstance(result_expr.type_, rty.Aggregate):
# TODO(eng/recordflux/RecordFlux#1497): Support comparisons of opaque fields
result_stmts = [ # pragma: no cover
*result.stmts,
ir.VarDecl(
result_id,
rty.OPAQUE,
ir.ComplexExpr([], result.expr),
origin=expression,
),
]
else:
result_type = result_expr.type_

assert isinstance(result_type, rty.NamedType)
assert isinstance(result_type, rty.NamedType)

result_stmts = [
*result.stmts,
ir.VarDecl(result_id, result_type, None, origin=expression),
ir.Assign(result_id, result.expr, result_type, origin=expression),
]
result_stmts = [
*result.stmts,
ir.VarDecl(result_id, result_type, None, origin=expression),
ir.Assign(result_id, result.expr, result_type, origin=expression),
]

return (result_stmts, result_expr)
94 changes: 59 additions & 35 deletions rflx/generator/allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def _scope(state: ir.State, var_id: ID) -> Optional[ID]:
return state.identifier.name
return None

def _allocate_local_slots( # noqa: PLR0912
def _allocate_local_slots(
self,
) -> list[SlotInfo]:
"""
Expand All @@ -378,58 +378,82 @@ class AllocationRequirement:
location: Optional[Location]
size: int

alloc_requirements_per_state: list[list[AllocationRequirement]] = []
for s in self._session.states:
state_requirements = []
for a in s.actions:
if isinstance(a, ir.VarDecl) and self._needs_allocation(a.type_):
state_requirements.append(
def determine_allocation_requirements(
statements: Sequence[ir.Stmt],
state: ir.State,
) -> list[AllocationRequirement]:
alloc_requirements = []

for statement in statements:
if isinstance(statement, ir.VarDecl) and self._needs_allocation(statement.type_):
alloc_requirements.append(
AllocationRequirement(
a.location,
self.get_size(a.identifier, s.identifier.name),
statement.location,
self.get_size(statement.identifier, state.identifier.name),
),
)
if (
isinstance(a, ir.Assign)
and isinstance(a.expression, ir.Comprehension)
and isinstance(a.expression.sequence.type_, rty.Sequence)
and isinstance(a.expression.sequence.type_.element, rty.Message)
and isinstance(a.expression.sequence, (ir.Var, ir.FieldAccess))
isinstance(statement, ir.Assign)
and isinstance(statement.expression, ir.Comprehension)
and isinstance(statement.expression.sequence.type_, rty.Sequence)
and isinstance(statement.expression.sequence.type_.element, rty.Message)
and isinstance(statement.expression.sequence, (ir.Var, ir.FieldAccess))
):
if isinstance(a.expression.sequence, ir.FieldAccess):
identifier = a.expression.sequence.message
if isinstance(statement.expression.sequence, ir.FieldAccess):
identifier = statement.expression.sequence.message
else:
assert isinstance(a.expression.sequence, ir.Var)
identifier = a.expression.sequence.identifier
state_requirements.append(
assert isinstance(statement.expression.sequence, ir.Var)
identifier = statement.expression.sequence.identifier
alloc_requirements.append(
AllocationRequirement(
a.location,
self.get_size(identifier, self._scope(s, identifier)),
statement.location,
self.get_size(identifier, self._scope(state, identifier)),
),
)
if isinstance(a, ir.Assign) and isinstance(a.expression, ir.Head):
identifier = a.expression.prefix
state_requirements.append(
if isinstance(statement, ir.Assign) and isinstance(statement.expression, ir.Head):
identifier = statement.expression.prefix
alloc_requirements.append(
AllocationRequirement(
a.location,
self.get_size(identifier, self._scope(s, identifier)),
statement.location,
self.get_size(identifier, self._scope(state, identifier)),
),
)
if isinstance(a, ir.Assign) and isinstance(a.expression, ir.Find):
if isinstance(a.expression.sequence, ir.Var):
identifier = a.expression.sequence.identifier
elif isinstance(a.expression.sequence, ir.FieldAccess):
identifier = a.expression.sequence.message
if isinstance(statement, ir.Assign) and isinstance(statement.expression, ir.Find):
if isinstance(statement.expression.sequence, ir.Var):
identifier = statement.expression.sequence.identifier
elif isinstance(statement.expression.sequence, ir.FieldAccess):
identifier = statement.expression.sequence.message
else:
assert False
state_requirements.append(
alloc_requirements.append(
AllocationRequirement(
a.location,
self.get_size(identifier, self._scope(s, identifier)),
statement.location,
self.get_size(identifier, self._scope(state, identifier)),
),
)
if isinstance(statement, ir.Assign) and isinstance(
statement.expression,
(ir.Comprehension, ir.Find),
):
alloc_requirements.extend(
determine_allocation_requirements(
statement.expression.selector.stmts,
state,
),
)
alloc_requirements.extend(
determine_allocation_requirements(
statement.expression.condition.stmts,
state,
),
)

return alloc_requirements

alloc_requirements_per_state.append(state_requirements)
alloc_requirements_per_state = [
determine_allocation_requirements(state.actions, state)
for state in self._session.states
]

for state_requirements in alloc_requirements_per_state:
state_requirements.sort(key=lambda x: x.size, reverse=True)
Expand Down
5 changes: 3 additions & 2 deletions rflx/generator/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,7 @@ def _create_uninitialized_function(
)
for declaration in composite_globals
if isinstance(declaration.type_, (rty.Message, rty.Sequence))
and declaration.type_ != rty.OPAQUE
],
*(
[
Expand Down Expand Up @@ -812,7 +813,7 @@ def _create_states(
]

for d in declarations:
if isinstance(d.type_, (rty.Message, rty.Sequence)):
if isinstance(d.type_, (rty.Message, rty.Sequence)) and d.type_ != rty.OPAQUE:
identifier = context_id(d.identifier, is_global)
type_identifier = self._ada_type(d.type_.identifier)
invariant.extend(
Expand Down Expand Up @@ -2329,7 +2330,7 @@ def _assign( # noqa: PLR0913
ir.CaseExpr,
),
) and (
isinstance(expression.type_, (rty.AnyInteger, rty.Enumeration))
isinstance(expression.type_, (rty.AnyInteger, rty.Enumeration, rty.Aggregate))
or expression.type_ == rty.OPAQUE
):
assert isinstance(
Expand Down
66 changes: 50 additions & 16 deletions rflx/model/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from rflx import expression as expr, ir, typing_ as rty
from rflx.common import Base, indent, indent_next, verbose_repr
from rflx.error import Location, Severity, Subsystem
from rflx.error import Location, RecordFluxError, Severity, Subsystem
from rflx.identifier import ID, StrID, id_generator

from . import (
Expand Down Expand Up @@ -674,6 +674,35 @@ def _validate_actions(
actions: Sequence[stmt.Statement],
declarations: Mapping[ID, decl.Declaration],
local_declarations: Mapping[ID, decl.Declaration],
) -> None:
self._validate_io_states(actions, local_declarations)

for a in actions:
try:
type_ = declarations[a.identifier].type_
except KeyError:
type_ = rty.Undefined()

self.error.extend(
a.check_type(
type_,
lambda x: self._typify_variable(x, declarations),
),
)

self._reference_variable_declaration(a.variables(), declarations)

if isinstance(a, stmt.Assignment):
a.expression.substituted(lambda e: error_on_unsupported_expression(e, self.error))
else:
assert isinstance(a, stmt.AttributeStatement)
for e in a.parameters:
e.substituted(lambda e: error_on_unsupported_expression(e, self.error))

def _validate_io_states(
self,
actions: Sequence[stmt.Statement],
local_declarations: Mapping[ID, decl.Declaration],
) -> None:
io_statements = [a for a in actions if isinstance(a, stmt.ChannelAttributeStatement)]

Expand Down Expand Up @@ -769,21 +798,6 @@ def _validate_actions(
],
)

for a in actions:
try:
type_ = declarations[a.identifier].type_
except KeyError:
type_ = rty.Undefined()

self.error.extend(
a.check_type(
type_,
lambda x: self._typify_variable(x, declarations),
),
)

self._reference_variable_declaration(a.variables(), declarations)

def _validate_transitions(
self,
state: State,
Expand All @@ -794,6 +808,8 @@ def _validate_transitions(
self.error.extend(t.condition.check_type(rty.BOOLEAN))
self._reference_variable_declaration(t.condition.variables(), declarations)

t.condition.substituted(lambda e: error_on_unsupported_expression(e, self.error))

if not state.exception_transition and state.has_exceptions:
self.error.extend(
[
Expand Down Expand Up @@ -932,3 +948,21 @@ def checked(


FINAL_STATE: Final[State] = State("Final")


def error_on_unsupported_expression(expression: expr.Expr, error: RecordFluxError) -> expr.Expr:
# TODO(eng/recordflux/RecordFlux#1497): Support comparisons of opaque fields
if isinstance(expression, (expr.Equal, expr.NotEqual)):
for e in [expression.left, expression.right]:
if isinstance(e, expr.Selected) and e.type_ == rty.OPAQUE:
error.extend(
[
(
"comparisons of opaque fields not yet supported",
Subsystem.MODEL,
Severity.ERROR,
expression.left.location,
),
],
)
return expression
57 changes: 57 additions & 0 deletions tests/compilation/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from rflx.error import RecordFluxError
from rflx.generator import Debug
from rflx.integration import Integration
from rflx.model import Model, Session, State, Transition
Expand Down Expand Up @@ -923,3 +924,59 @@ def test_session_message_field_access_in_transition(tmp_path: Path) -> None:
""",
tmp_path,
)


@pytest.mark.parametrize(
("global_decl", "local_decl", "value"),
[
("Key : Opaque := [0, 1, 0];", "", "Key"),
("", "Key : Opaque := [0, 1, 0];", "Key"),
("", "", "[0, 1, 0]"),
],
)
def test_session_comparing_opaque_values_in_comprehension(
global_decl: str,
local_decl: str,
value: str,
tmp_path: Path,
) -> None:
spec = f"""\
package Test is
type Message is
message
Key : Opaque
with Size => 3 * 8;
end message;
type Messages is sequence of Message;
generic
session Session is
{global_decl}
begin
state S is
Ms_1 : Messages;
Ms_2 : Messages;
{local_decl}
begin
Ms_2 :=
[for M in Ms_1
if M.Key = {value} =>
M];
transition
goto null
exception
goto null
end S;
end Session;
end Test;
"""
# TODO(eng/recordflux/RecordFlux#1497): Support comparisons of opaque fields
with pytest.raises(
RecordFluxError,
match=r"^<stdin>:22:17: model: error: comparisons of opaque fields not yet supported$",
):
utils.assert_compilable_code_string(spec, tmp_path)
Loading

0 comments on commit c91993c

Please sign in to comment.