[red-knot] Add support for unpacking for target (#15058)
## Summary

Related to #13773 

This PR adds support for unpacking `for` statement targets.

This involves updating the `value` field in the `Unpack` target to use
an enum that records where the value expression came from.
This is because, for an iterable expression, we need to unpack the
iterator type, while for an assignment statement we need to unpack the value
type itself. This needs to be done in the unpack query.
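
As a rough illustration of the distinction above, the following Python sketch models the two origins of the value expression. The names `UnpackOrigin` and `type_to_unpack`, and the representation of types as nested tuples of names, are hypothetical and chosen only for illustration; the actual change is the Rust `UnpackValue` enum with its `Assign` and `Iterable` variants shown in the diff.

```python
from enum import Enum, auto


class UnpackOrigin(Enum):
    """Hypothetical stand-in for the Rust `UnpackValue` enum."""

    ASSIGN = auto()    # value comes from `a, b = value`
    ITERABLE = auto()  # value comes from `for a, b in iterable`


def type_to_unpack(origin: UnpackOrigin, value_type: tuple) -> list:
    """Return the candidate types whose elements are bound to the targets.

    Types are modelled as nested tuples of type names, e.g. ("int", "str")
    stands for `tuple[int, str]`. This is a toy model, not red-knot's logic.
    """
    if origin is UnpackOrigin.ASSIGN:
        # An assignment unpacks the value type itself.
        return [value_type]
    # A `for` loop unpacks what iteration yields: for a tuple, any element.
    return list(value_type)


pair = ("int", "str")
assign_candidates = type_to_unpack(UnpackOrigin.ASSIGN, pair)
for_candidates = type_to_unpack(UnpackOrigin.ITERABLE, (pair, ("str", "int")))
```

In this toy model the assignment case has a single candidate, the value type, while the `for` case has one candidate per element of the iterated tuple.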

### Question

One of the ways unpacking works in a `for` statement is by looking at the
union of the types: if the iterable expression is a tuple, then
the iterator type will be the union of all the types in the tuple. This
means that the test cases for unpacking in a `for`
statement will also implicitly test the union unpacking logic. I was
wondering if it makes sense to merge these cases and only add the ones
that are specific to either the union unpacking or the `for` statement unpacking
logic.
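
To make the union behaviour concrete, here is a small runtime sketch (not the type checker's implementation) that records which runtime types show up at each target position when looping over a tuple of tuples. The helper name `types_per_position` is made up for this example.

```python
def types_per_position(iterable):
    """Collect the runtime type names seen at each target position."""
    seen = []
    for item in iterable:
        for index, element in enumerate(item):
            if index == len(seen):
                seen.append(set())
            seen[index].add(type(element).__name__)
    return seen


# Mirrors `tuple[tuple[int, int], tuple[int, str]]` from the test cases.
positions = types_per_position(((1, 2), (3, "x")))
```

The first position only ever sees `int`, while the second sees both `int` and `str`, which corresponds to the `int` and `int | str` results revealed in the "Mixed types (1)" test.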

## Test Plan

Add test cases involving iterating over a tuple type. I've intentionally
left out certain cases for now and I'm curious to hear any thoughts on
the question above.
dhruvmanila authored Dec 23, 2024
1 parent b6c8f5d commit 113c804
Showing 7 changed files with 294 additions and 83 deletions.
101 changes: 101 additions & 0 deletions crates/red_knot_python_semantic/resources/mdtest/unpacking.md
@@ -472,3 +472,104 @@ def _(arg: tuple[int, str] | Iterable):
reveal_type(a) # revealed: int | bytes
reveal_type(b) # revealed: str | bytes
```

## For statement

Unpacking in a `for` statement.

### Same types

```py
def _(arg: tuple[tuple[int, int], tuple[int, int]]):
    for a, b in arg:
        reveal_type(a)  # revealed: int
        reveal_type(b)  # revealed: int
```

### Mixed types (1)

```py
def _(arg: tuple[tuple[int, int], tuple[int, str]]):
    for a, b in arg:
        reveal_type(a)  # revealed: int
        reveal_type(b)  # revealed: int | str
```

### Mixed types (2)

```py
def _(arg: tuple[tuple[int, str], tuple[str, int]]):
    for a, b in arg:
        reveal_type(a)  # revealed: int | str
        reveal_type(b)  # revealed: str | int
```

### Mixed types (3)

```py
def _(arg: tuple[tuple[int, int, int], tuple[int, str, bytes], tuple[int, int, str]]):
    for a, b, c in arg:
        reveal_type(a)  # revealed: int
        reveal_type(b)  # revealed: int | str
        reveal_type(c)  # revealed: int | bytes | str
```

### Same literal values

```py
for a, b in ((1, 2), (3, 4)):
    reveal_type(a)  # revealed: Literal[1, 3]
    reveal_type(b)  # revealed: Literal[2, 4]
```

### Mixed literal values (1)

```py
for a, b in ((1, 2), ("a", "b")):
    reveal_type(a)  # revealed: Literal[1] | Literal["a"]
    reveal_type(b)  # revealed: Literal[2] | Literal["b"]
```

### Mixed literals values (2)

```py
# error: "Object of type `Literal[1]` is not iterable"
# error: "Object of type `Literal[2]` is not iterable"
# error: "Object of type `Literal[4]` is not iterable"
for a, b in (1, 2, (3, "a"), 4, (5, "b"), "c"):
    reveal_type(a)  # revealed: Unknown | Literal[3, 5] | LiteralString

@T-256

T-256 Dec 23, 2024

Contributor

Why is `LiteralString` inferred here?

Traceback (most recent call last):
  File "<pyshell#29>", line 1, in <module>
    for a,b in ("c",):
ValueError: not enough values to unpack (expected 2, got 1)

@dhruvmanila

dhruvmanila Dec 24, 2024

Author Member

The `LiteralString` comes from unpacking the `"c"`; it's not a single-element tuple but just a string. So,

In [2]: a, b = "ab"

In [3]: a
Out[3]: 'a'

In [4]: b
Out[4]: 'b'

Notice that only `a` has the type `LiteralString` in its union but not `b`, because the `"c"` string has only one character, so it can only be unpacked to one variable.

We'll emit a diagnostic for the string, `Not enough values to unpack (expected 2, got 1)`, in a follow-up PR.

    reveal_type(b)  # revealed: Unknown | Literal["a", "b"]
```
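
The runtime behaviour discussed in the review thread above can be reproduced with a short sketch. The helper `try_unpack_pair` is hypothetical and exists only for this illustration; it shows that a string unpacks one character at a time, and that a one-character string cannot fill two targets.

```python
def try_unpack_pair(value):
    """Unpack `value` into two targets, returning the pair or the error text."""
    try:
        a, b = value
    except ValueError as error:
        return str(error)
    return (a, b)


two_chars = try_unpack_pair("ab")  # a string is iterable, one character at a time
one_char = try_unpack_pair("c")    # too short for two targets
```

Here `two_chars` is a pair of single-character strings, whereas `one_char` holds the text of the `ValueError` raised at runtime, the same error shown in the traceback above.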

### Custom iterator (1)

```py
class Iterator:
    def __next__(self) -> tuple[int, int]:
        return (1, 2)

class Iterable:
    def __iter__(self) -> Iterator:
        return Iterator()

for a, b in Iterable():
    reveal_type(a)  # revealed: int
    reveal_type(b)  # revealed: int
```

### Custom iterator (2)

```py
class Iterator:
    def __next__(self) -> bytes:
        return b""

class Iterable:
    def __iter__(self) -> Iterator:
        return Iterator()

def _(arg: tuple[tuple[int, str], Iterable]):
    for a, b in arg:
        reveal_type(a)  # revealed: int | bytes
        reveal_type(b)  # revealed: str | bytes
```
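
For readers less familiar with the iterator protocol that the custom iterator tests rely on, here is a minimal runtime sketch. The classes `PairIterator` and `PairIterable` are hypothetical variants of the test classes; unlike the test fixtures, they stop after a fixed number of items so the loop terminates when actually run.

```python
class PairIterator:
    """Hypothetical iterator that yields a fixed number of (int, int) pairs."""

    def __init__(self, count: int) -> None:
        self.remaining = count

    def __next__(self) -> tuple[int, int]:
        if self.remaining == 0:
            raise StopIteration
        self.remaining -= 1
        return (1, 2)


class PairIterable:
    def __init__(self, count: int) -> None:
        self.count = count

    def __iter__(self) -> PairIterator:
        return PairIterator(self.count)


collected = []
# The `for` statement calls `__iter__` once, then `__next__` per iteration,
# and unpacks each returned tuple into `a` and `b`.
for a, b in PairIterable(3):
    collected.append((a, b))
```

The value unpacked into the targets is whatever `__next__` returns, which is why the tests reveal the types from the `__next__` return annotation.
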
67 changes: 51 additions & 16 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
@@ -29,7 +29,7 @@ use crate::semantic_index::use_def::{
FlowSnapshot, ScopedConstraintId, ScopedVisibilityConstraintId, UseDefMapBuilder,
};
use crate::semantic_index::SemanticIndex;
use crate::unpack::Unpack;
use crate::unpack::{Unpack, UnpackValue};
use crate::visibility_constraints::VisibilityConstraint;
use crate::Db;

@@ -810,7 +810,7 @@ where
unsafe {
AstNodeRef::new(self.module.clone(), target)
},
value,
UnpackValue::Assign(value),
countme::Count::default(),
)),
})
@@ -1021,18 +1021,47 @@ where
orelse,
},
) => {
self.add_standalone_expression(iter);
debug_assert_eq!(&self.current_assignments, &[]);

let iter_expr = self.add_standalone_expression(iter);
self.visit_expr(iter);

self.record_ambiguous_visibility();

let pre_loop = self.flow_snapshot();
let saved_break_states = std::mem::take(&mut self.loop_break_states);

debug_assert_eq!(&self.current_assignments, &[]);
self.push_assignment(for_stmt.into());
let current_assignment = match &**target {
ast::Expr::List(_) | ast::Expr::Tuple(_) => Some(CurrentAssignment::For {
node: for_stmt,
first: true,
unpack: Some(Unpack::new(
self.db,
self.file,
self.current_scope(),
#[allow(unsafe_code)]
unsafe {
AstNodeRef::new(self.module.clone(), target)
},
UnpackValue::Iterable(iter_expr),
countme::Count::default(),
)),
}),
ast::Expr::Name(_) => Some(CurrentAssignment::For {
node: for_stmt,
unpack: None,
first: false,
}),
_ => None,
};

if let Some(current_assignment) = current_assignment {
self.push_assignment(current_assignment);
}
self.visit_expr(target);
self.pop_assignment();
if current_assignment.is_some() {
self.pop_assignment();
}

// TODO: Definitions created by loop variables
// (and definitions created inside the body)
@@ -1283,12 +1312,18 @@ where
Some(CurrentAssignment::AugAssign(aug_assign)) => {
self.add_definition(symbol, aug_assign);
}
Some(CurrentAssignment::For(node)) => {
Some(CurrentAssignment::For {
node,
first,
unpack,
}) => {
self.add_definition(
symbol,
ForStmtDefinitionNodeRef {
unpack,
first,
iterable: &node.iter,
target: name_node,
name: name_node,
is_async: node.is_async,
},
);
@@ -1324,7 +1359,9 @@ where
}
}

if let Some(CurrentAssignment::Assign { first, .. }) = self.current_assignment_mut()
if let Some(
CurrentAssignment::Assign { first, .. } | CurrentAssignment::For { first, .. },
) = self.current_assignment_mut()
{
*first = false;
}
@@ -1566,7 +1603,11 @@ enum CurrentAssignment<'a> {
},
AnnAssign(&'a ast::StmtAnnAssign),
AugAssign(&'a ast::StmtAugAssign),
For(&'a ast::StmtFor),
For {
node: &'a ast::StmtFor,
first: bool,
unpack: Option<Unpack<'a>>,
},
Named(&'a ast::ExprNamed),
Comprehension {
node: &'a ast::Comprehension,
@@ -1590,12 +1631,6 @@ impl<'a> From<&'a ast::StmtAugAssign> for CurrentAssignment<'a> {
}
}

impl<'a> From<&'a ast::StmtFor> for CurrentAssignment<'a> {
fn from(value: &'a ast::StmtFor) -> Self {
Self::For(value)
}
}

impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> {
fn from(value: &'a ast::ExprNamed) -> Self {
Self::Named(value)
50 changes: 31 additions & 19 deletions crates/red_knot_python_semantic/src/semantic_index/definition.rs
@@ -225,8 +225,10 @@ pub(crate) struct WithItemDefinitionNodeRef<'a> {

#[derive(Copy, Clone, Debug)]
pub(crate) struct ForStmtDefinitionNodeRef<'a> {
pub(crate) unpack: Option<Unpack<'a>>,
pub(crate) iterable: &'a ast::Expr,
pub(crate) target: &'a ast::ExprName,
pub(crate) name: &'a ast::ExprName,
pub(crate) first: bool,
pub(crate) is_async: bool,
}

@@ -298,12 +300,16 @@ impl<'db> DefinitionNodeRef<'db> {
DefinitionKind::AugmentedAssignment(AstNodeRef::new(parsed, augmented_assignment))
}
DefinitionNodeRef::For(ForStmtDefinitionNodeRef {
unpack,
iterable,
target,
name,
first,
is_async,
}) => DefinitionKind::For(ForStmtDefinitionKind {
target: TargetKind::from(unpack),
iterable: AstNodeRef::new(parsed.clone(), iterable),
target: AstNodeRef::new(parsed, target),
name: AstNodeRef::new(parsed, name),
first,
is_async,
}),
DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef {
@@ -382,10 +388,12 @@ impl<'db> DefinitionNodeRef<'db> {
Self::AnnotatedAssignment(node) => node.into(),
Self::AugmentedAssignment(node) => node.into(),
Self::For(ForStmtDefinitionNodeRef {
unpack: _,
iterable: _,
target,
name,
first: _,
is_async: _,
}) => target.into(),
}) => name.into(),
Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(),
Self::VariadicPositionalParameter(node) => node.into(),
Self::VariadicKeywordParameter(node) => node.into(),
@@ -452,7 +460,7 @@ pub enum DefinitionKind<'db> {
Assignment(AssignmentDefinitionKind<'db>),
AnnotatedAssignment(AstNodeRef<ast::StmtAnnAssign>),
AugmentedAssignment(AstNodeRef<ast::StmtAugAssign>),
For(ForStmtDefinitionKind),
For(ForStmtDefinitionKind<'db>),
Comprehension(ComprehensionDefinitionKind),
VariadicPositionalParameter(AstNodeRef<ast::Parameter>),
VariadicKeywordParameter(AstNodeRef<ast::Parameter>),
@@ -477,7 +485,7 @@ impl Ranged for DefinitionKind<'_> {
DefinitionKind::Assignment(assignment) => assignment.name().range(),
DefinitionKind::AnnotatedAssignment(assign) => assign.target.range(),
DefinitionKind::AugmentedAssignment(aug_assign) => aug_assign.target.range(),
DefinitionKind::For(for_stmt) => for_stmt.target().range(),
DefinitionKind::For(for_stmt) => for_stmt.name().range(),
DefinitionKind::Comprehension(comp) => comp.target().range(),
DefinitionKind::VariadicPositionalParameter(parameter) => parameter.name.range(),
DefinitionKind::VariadicKeywordParameter(parameter) => parameter.name.range(),
@@ -665,22 +673,32 @@ impl WithItemDefinitionKind {
}

#[derive(Clone, Debug)]
pub struct ForStmtDefinitionKind {
pub struct ForStmtDefinitionKind<'db> {
target: TargetKind<'db>,
iterable: AstNodeRef<ast::Expr>,
target: AstNodeRef<ast::ExprName>,
name: AstNodeRef<ast::ExprName>,
first: bool,
is_async: bool,
}

impl ForStmtDefinitionKind {
impl<'db> ForStmtDefinitionKind<'db> {
pub(crate) fn iterable(&self) -> &ast::Expr {
self.iterable.node()
}

pub(crate) fn target(&self) -> &ast::ExprName {
self.target.node()
pub(crate) fn target(&self) -> TargetKind<'db> {
self.target
}

pub(crate) fn is_async(&self) -> bool {
pub(crate) fn name(&self) -> &ast::ExprName {
self.name.node()
}

pub(crate) const fn is_first(&self) -> bool {
self.first
}

pub(crate) const fn is_async(&self) -> bool {
self.is_async
}
}
@@ -756,12 +774,6 @@ impl From<&ast::StmtAugAssign> for DefinitionNodeKey {
}
}

impl From<&ast::StmtFor> for DefinitionNodeKey {
fn from(value: &ast::StmtFor) -> Self {
Self(NodeKey::from_node(value))
}
}

impl From<&ast::Parameter> for DefinitionNodeKey {
fn from(node: &ast::Parameter) -> Self {
Self(NodeKey::from_node(node))