diff --git a/crates/red_knot_python_semantic/resources/mdtest/unpacking.md b/crates/red_knot_python_semantic/resources/mdtest/unpacking.md index 1ee1c5edf0094..49815a52fc529 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/unpacking.md +++ b/crates/red_knot_python_semantic/resources/mdtest/unpacking.md @@ -472,3 +472,104 @@ def _(arg: tuple[int, str] | Iterable): reveal_type(a) # revealed: int | bytes reveal_type(b) # revealed: str | bytes ``` + +## For statement + +Unpacking in a `for` statement. + +### Same types + +```py +def _(arg: tuple[tuple[int, int], tuple[int, int]]): + for a, b in arg: + reveal_type(a) # revealed: int + reveal_type(b) # revealed: int +``` + +### Mixed types (1) + +```py +def _(arg: tuple[tuple[int, int], tuple[int, str]]): + for a, b in arg: + reveal_type(a) # revealed: int + reveal_type(b) # revealed: int | str +``` + +### Mixed types (2) + +```py +def _(arg: tuple[tuple[int, str], tuple[str, int]]): + for a, b in arg: + reveal_type(a) # revealed: int | str + reveal_type(b) # revealed: str | int +``` + +### Mixed types (3) + +```py +def _(arg: tuple[tuple[int, int, int], tuple[int, str, bytes], tuple[int, int, str]]): + for a, b, c in arg: + reveal_type(a) # revealed: int + reveal_type(b) # revealed: int | str + reveal_type(c) # revealed: int | bytes | str +``` + +### Same literal values + +```py +for a, b in ((1, 2), (3, 4)): + reveal_type(a) # revealed: Literal[1, 3] + reveal_type(b) # revealed: Literal[2, 4] +``` + +### Mixed literal values (1) + +```py +for a, b in ((1, 2), ("a", "b")): + reveal_type(a) # revealed: Literal[1] | Literal["a"] + reveal_type(b) # revealed: Literal[2] | Literal["b"] +``` + +### Mixed literal values (2) + +```py +# error: "Object of type `Literal[1]` is not iterable" +# error: "Object of type `Literal[2]` is not iterable" +# error: "Object of type `Literal[4]` is not iterable" +for a, b in (1, 2, (3, "a"), 4, (5, "b"), "c"): + reveal_type(a) # revealed: Unknown | Literal[3, 5] | LiteralString 
+ reveal_type(b) # revealed: Unknown | Literal["a", "b"] +``` + +### Custom iterator (1) + +```py +class Iterator: + def __next__(self) -> tuple[int, int]: + return (1, 2) + +class Iterable: + def __iter__(self) -> Iterator: + return Iterator() + +for a, b in Iterable(): + reveal_type(a) # revealed: int + reveal_type(b) # revealed: int +``` + +### Custom iterator (2) + +```py +class Iterator: + def __next__(self) -> bytes: + return b"" + +class Iterable: + def __iter__(self) -> Iterator: + return Iterator() + +def _(arg: tuple[tuple[int, str], Iterable]): + for a, b in arg: + reveal_type(a) # revealed: int | bytes + reveal_type(b) # revealed: str | bytes +``` diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 379ffaebab8c0..8be253cb27a78 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -29,7 +29,7 @@ use crate::semantic_index::use_def::{ FlowSnapshot, ScopedConstraintId, ScopedVisibilityConstraintId, UseDefMapBuilder, }; use crate::semantic_index::SemanticIndex; -use crate::unpack::Unpack; +use crate::unpack::{Unpack, UnpackValue}; use crate::visibility_constraints::VisibilityConstraint; use crate::Db; @@ -810,7 +810,7 @@ where unsafe { AstNodeRef::new(self.module.clone(), target) }, - value, + UnpackValue::Assign(value), countme::Count::default(), )), }) @@ -1021,7 +1021,9 @@ where orelse, }, ) => { - self.add_standalone_expression(iter); + debug_assert_eq!(&self.current_assignments, &[]); + + let iter_expr = self.add_standalone_expression(iter); self.visit_expr(iter); self.record_ambiguous_visibility(); @@ -1029,10 +1031,37 @@ where let pre_loop = self.flow_snapshot(); let saved_break_states = std::mem::take(&mut self.loop_break_states); - debug_assert_eq!(&self.current_assignments, &[]); - self.push_assignment(for_stmt.into()); + let current_assignment = match &**target { + 
ast::Expr::List(_) | ast::Expr::Tuple(_) => Some(CurrentAssignment::For { + node: for_stmt, + first: true, + unpack: Some(Unpack::new( + self.db, + self.file, + self.current_scope(), + #[allow(unsafe_code)] + unsafe { + AstNodeRef::new(self.module.clone(), target) + }, + UnpackValue::Iterable(iter_expr), + countme::Count::default(), + )), + }), + ast::Expr::Name(_) => Some(CurrentAssignment::For { + node: for_stmt, + unpack: None, + first: false, + }), + _ => None, + }; + + if let Some(current_assignment) = current_assignment { + self.push_assignment(current_assignment); + } self.visit_expr(target); - self.pop_assignment(); + if current_assignment.is_some() { + self.pop_assignment(); + } // TODO: Definitions created by loop variables // (and definitions created inside the body) @@ -1283,12 +1312,18 @@ where Some(CurrentAssignment::AugAssign(aug_assign)) => { self.add_definition(symbol, aug_assign); } - Some(CurrentAssignment::For(node)) => { + Some(CurrentAssignment::For { + node, + first, + unpack, + }) => { self.add_definition( symbol, ForStmtDefinitionNodeRef { + unpack, + first, iterable: &node.iter, - target: name_node, + name: name_node, is_async: node.is_async, }, ); @@ -1324,7 +1359,9 @@ where } } - if let Some(CurrentAssignment::Assign { first, .. }) = self.current_assignment_mut() + if let Some( + CurrentAssignment::Assign { first, .. } | CurrentAssignment::For { first, .. 
}, + ) = self.current_assignment_mut() { *first = false; } @@ -1566,7 +1603,11 @@ enum CurrentAssignment<'a> { }, AnnAssign(&'a ast::StmtAnnAssign), AugAssign(&'a ast::StmtAugAssign), - For(&'a ast::StmtFor), + For { + node: &'a ast::StmtFor, + first: bool, + unpack: Option>, + }, Named(&'a ast::ExprNamed), Comprehension { node: &'a ast::Comprehension, @@ -1590,12 +1631,6 @@ impl<'a> From<&'a ast::StmtAugAssign> for CurrentAssignment<'a> { } } -impl<'a> From<&'a ast::StmtFor> for CurrentAssignment<'a> { - fn from(value: &'a ast::StmtFor) -> Self { - Self::For(value) - } -} - impl<'a> From<&'a ast::ExprNamed> for CurrentAssignment<'a> { fn from(value: &'a ast::ExprNamed) -> Self { Self::Named(value) diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 6b9d57ca071cc..d39c918a82046 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -225,8 +225,10 @@ pub(crate) struct WithItemDefinitionNodeRef<'a> { #[derive(Copy, Clone, Debug)] pub(crate) struct ForStmtDefinitionNodeRef<'a> { + pub(crate) unpack: Option>, pub(crate) iterable: &'a ast::Expr, - pub(crate) target: &'a ast::ExprName, + pub(crate) name: &'a ast::ExprName, + pub(crate) first: bool, pub(crate) is_async: bool, } @@ -298,12 +300,16 @@ impl<'db> DefinitionNodeRef<'db> { DefinitionKind::AugmentedAssignment(AstNodeRef::new(parsed, augmented_assignment)) } DefinitionNodeRef::For(ForStmtDefinitionNodeRef { + unpack, iterable, - target, + name, + first, is_async, }) => DefinitionKind::For(ForStmtDefinitionKind { + target: TargetKind::from(unpack), iterable: AstNodeRef::new(parsed.clone(), iterable), - target: AstNodeRef::new(parsed, target), + name: AstNodeRef::new(parsed, name), + first, is_async, }), DefinitionNodeRef::Comprehension(ComprehensionDefinitionNodeRef { @@ -382,10 +388,12 @@ impl<'db> DefinitionNodeRef<'db> 
{ Self::AnnotatedAssignment(node) => node.into(), Self::AugmentedAssignment(node) => node.into(), Self::For(ForStmtDefinitionNodeRef { + unpack: _, iterable: _, - target, + name, + first: _, is_async: _, - }) => target.into(), + }) => name.into(), Self::Comprehension(ComprehensionDefinitionNodeRef { target, .. }) => target.into(), Self::VariadicPositionalParameter(node) => node.into(), Self::VariadicKeywordParameter(node) => node.into(), @@ -452,7 +460,7 @@ pub enum DefinitionKind<'db> { Assignment(AssignmentDefinitionKind<'db>), AnnotatedAssignment(AstNodeRef), AugmentedAssignment(AstNodeRef), - For(ForStmtDefinitionKind), + For(ForStmtDefinitionKind<'db>), Comprehension(ComprehensionDefinitionKind), VariadicPositionalParameter(AstNodeRef), VariadicKeywordParameter(AstNodeRef), @@ -477,7 +485,7 @@ impl Ranged for DefinitionKind<'_> { DefinitionKind::Assignment(assignment) => assignment.name().range(), DefinitionKind::AnnotatedAssignment(assign) => assign.target.range(), DefinitionKind::AugmentedAssignment(aug_assign) => aug_assign.target.range(), - DefinitionKind::For(for_stmt) => for_stmt.target().range(), + DefinitionKind::For(for_stmt) => for_stmt.name().range(), DefinitionKind::Comprehension(comp) => comp.target().range(), DefinitionKind::VariadicPositionalParameter(parameter) => parameter.name.range(), DefinitionKind::VariadicKeywordParameter(parameter) => parameter.name.range(), @@ -665,22 +673,32 @@ impl WithItemDefinitionKind { } #[derive(Clone, Debug)] -pub struct ForStmtDefinitionKind { +pub struct ForStmtDefinitionKind<'db> { + target: TargetKind<'db>, iterable: AstNodeRef, - target: AstNodeRef, + name: AstNodeRef, + first: bool, is_async: bool, } -impl ForStmtDefinitionKind { +impl<'db> ForStmtDefinitionKind<'db> { pub(crate) fn iterable(&self) -> &ast::Expr { self.iterable.node() } - pub(crate) fn target(&self) -> &ast::ExprName { - self.target.node() + pub(crate) fn target(&self) -> TargetKind<'db> { + self.target } - pub(crate) fn is_async(&self) -> 
bool { + pub(crate) fn name(&self) -> &ast::ExprName { + self.name.node() + } + + pub(crate) const fn is_first(&self) -> bool { + self.first + } + + pub(crate) const fn is_async(&self) -> bool { self.is_async } } @@ -756,12 +774,6 @@ impl From<&ast::StmtAugAssign> for DefinitionNodeKey { } } -impl From<&ast::StmtFor> for DefinitionNodeKey { - fn from(value: &ast::StmtFor) -> Self { - Self(NodeKey::from_node(value)) - } -} - impl From<&ast::Parameter> for DefinitionNodeKey { fn from(node: &ast::Parameter) -> Self { Self(NodeKey::from_node(node)) diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index d689e30338339..6111eaf72f6cc 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -42,7 +42,7 @@ use crate::module_resolver::{file_to_module, resolve_module}; use crate::semantic_index::ast_ids::{HasScopedExpressionId, HasScopedUseId, ScopedExpressionId}; use crate::semantic_index::definition::{ AssignmentDefinitionKind, Definition, DefinitionKind, DefinitionNodeKey, - ExceptHandlerDefinitionKind, TargetKind, + ExceptHandlerDefinitionKind, ForStmtDefinitionKind, TargetKind, }; use crate::semantic_index::expression::Expression; use crate::semantic_index::semantic_index; @@ -198,17 +198,11 @@ pub(crate) fn infer_expression_types<'db>( fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> UnpackResult<'db> { let file = unpack.file(db); let _span = - tracing::trace_span!("infer_unpack_types", unpack=?unpack.as_id(), file=%file.path(db)) + tracing::trace_span!("infer_unpack_types", range=?unpack.range(db), file=%file.path(db)) .entered(); - let value = unpack.value(db); - let scope = unpack.scope(db); - - let result = infer_expression_types(db, value); - let value_ty = result.expression_ty(value.node_ref(db).scoped_expression_id(db, scope)); - - let mut unpacker = Unpacker::new(db, scope); - unpacker.unpack(unpack.target(db), 
value_ty); + let mut unpacker = Unpacker::new(db, unpack.scope(db)); + unpacker.unpack(unpack.target(db), unpack.value(db)); unpacker.finish() } @@ -710,12 +704,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.infer_augment_assignment_definition(augmented_assignment.node(), definition); } DefinitionKind::For(for_statement_definition) => { - self.infer_for_statement_definition( - for_statement_definition.target(), - for_statement_definition.iterable(), - for_statement_definition.is_async(), - definition, - ); + self.infer_for_statement_definition(for_statement_definition, definition); } DefinitionKind::NamedExpression(named_expression) => { self.infer_named_expression_definition(named_expression.node(), definition); @@ -1833,18 +1822,22 @@ impl<'db> TypeInferenceBuilder<'db> { } = assignment; for target in targets { - self.infer_assignment_target(target, value); + self.infer_target(target, value); } } + /// Infer the definition type involved in a `target` expression. + /// + /// This is used for assignment statements, for statements, etc. with a single or multiple + /// targets (unpacking). // TODO: Remove the `value` argument once we handle all possible assignment targets. - fn infer_assignment_target(&mut self, target: &ast::Expr, value: &ast::Expr) { + fn infer_target(&mut self, target: &ast::Expr, value: &ast::Expr) { match target { ast::Expr::Name(name) => self.infer_definition(name), ast::Expr::List(ast::ExprList { elts, .. }) | ast::Expr::Tuple(ast::ExprTuple { elts, .. 
}) => { for element in elts { - self.infer_assignment_target(element, value); + self.infer_target(element, value); } } _ => { @@ -1863,10 +1856,7 @@ impl<'db> TypeInferenceBuilder<'db> { let value = assignment.value(); let name = assignment.name(); - self.infer_standalone_expression(value); - - let value_ty = self.expression_ty(value); - let name_ast_id = name.scoped_expression_id(self.db(), self.scope()); + let value_ty = self.infer_standalone_expression(value); let mut target_ty = match assignment.target() { TargetKind::Sequence(unpack) => { @@ -1877,6 +1867,7 @@ impl<'db> TypeInferenceBuilder<'db> { self.context.extend(unpacked); } + let name_ast_id = name.scoped_expression_id(self.db(), self.scope()); unpacked.get(name_ast_id).unwrap_or(Type::Unknown) } TargetKind::Name => value_ty, @@ -2104,36 +2095,41 @@ impl<'db> TypeInferenceBuilder<'db> { is_async: _, } = for_statement; - // TODO more complex assignment targets - if let ast::Expr::Name(name) = &**target { - self.infer_definition(name); - } else { - self.infer_standalone_expression(iter); - self.infer_expression(target); - } + self.infer_target(target, iter); self.infer_body(body); self.infer_body(orelse); } fn infer_for_statement_definition( &mut self, - target: &ast::ExprName, - iterable: &ast::Expr, - is_async: bool, + for_stmt: &ForStmtDefinitionKind<'db>, definition: Definition<'db>, ) { + let iterable = for_stmt.iterable(); + let name = for_stmt.name(); + let iterable_ty = self.infer_standalone_expression(iterable); - let loop_var_value_ty = if is_async { + let loop_var_value_ty = if for_stmt.is_async() { todo_type!("async iterables/iterators") } else { - iterable_ty - .iterate(self.db()) - .unwrap_with_diagnostic(&self.context, iterable.into()) + match for_stmt.target() { + TargetKind::Sequence(unpack) => { + let unpacked = infer_unpack_types(self.db(), unpack); + if for_stmt.is_first() { + self.context.extend(unpacked); + } + let name_ast_id = name.scoped_expression_id(self.db(), self.scope()); + 
+ unpacked.get(name_ast_id).unwrap_or(Type::Unknown) + } + TargetKind::Name => iterable_ty + .iterate(self.db()) + .unwrap_with_diagnostic(&self.context, iterable.into()), + } }; - self.store_expression_type(target, loop_var_value_ty); - self.add_binding(target.into(), definition, loop_var_value_ty); + self.store_expression_type(name, loop_var_value_ty); + self.add_binding(name.into(), definition, loop_var_value_ty); } fn infer_while_statement(&mut self, while_statement: &ast::StmtWhile) { diff --git a/crates/red_knot_python_semantic/src/types/unpacker.rs b/crates/red_knot_python_semantic/src/types/unpacker.rs index c22fdda4378c1..3786069a85ed4 100644 --- a/crates/red_knot_python_semantic/src/types/unpacker.rs +++ b/crates/red_knot_python_semantic/src/types/unpacker.rs @@ -6,7 +6,8 @@ use ruff_python_ast::{self as ast, AnyNodeRef}; use crate::semantic_index::ast_ids::{HasScopedExpressionId, ScopedExpressionId}; use crate::semantic_index::symbol::ScopeId; -use crate::types::{todo_type, Type, TypeCheckDiagnostics}; +use crate::types::{infer_expression_types, todo_type, Type, TypeCheckDiagnostics}; +use crate::unpack::UnpackValue; use crate::Db; use super::context::{InferContext, WithDiagnostics}; @@ -32,13 +33,24 @@ impl<'db> Unpacker<'db> { self.context.db() } - /// Unpack the value type to the target expression. - pub(crate) fn unpack(&mut self, target: &ast::Expr, value_ty: Type<'db>) { + /// Unpack the value to the target expression. + pub(crate) fn unpack(&mut self, target: &ast::Expr, value: UnpackValue<'db>) { debug_assert!( matches!(target, ast::Expr::List(_) | ast::Expr::Tuple(_)), "Unpacking target must be a list or tuple expression" ); + let mut value_ty = infer_expression_types(self.db(), value.expression()) + .expression_ty(value.scoped_expression_id(self.db(), self.scope)); + + if value.is_iterable() { + // If the value is an iterable, then the type that needs to be unpacked is the type of the + // elements yielded by iterating over it. 
+ value_ty = value_ty + .iterate(self.db()) + .unwrap_with_diagnostic(&self.context, value.as_any_node_ref(self.db())); + } + self.unpack_inner(target, value_ty); } diff --git a/crates/red_knot_python_semantic/src/unpack.rs b/crates/red_knot_python_semantic/src/unpack.rs index 25ad40231c5c0..e0b3be92c4b45 100644 --- a/crates/red_knot_python_semantic/src/unpack.rs +++ b/crates/red_knot_python_semantic/src/unpack.rs @@ -1,7 +1,9 @@ use ruff_db::files::File; -use ruff_python_ast::{self as ast}; +use ruff_python_ast::{self as ast, AnyNodeRef}; +use ruff_text_size::{Ranged, TextRange}; use crate::ast_node_ref::AstNodeRef; +use crate::semantic_index::ast_ids::{HasScopedExpressionId, ScopedExpressionId}; use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{FileScopeId, ScopeId}; use crate::Db; @@ -41,7 +43,7 @@ pub(crate) struct Unpack<'db> { /// The ingredient representing the value expression of the unpacking. For example, in /// `(a, b) = (1, 2)`, the value expression is `(1, 2)`. #[no_eq] - pub(crate) value: Expression<'db>, + pub(crate) value: UnpackValue<'db>, #[no_eq] count: countme::Count>, @@ -52,4 +54,48 @@ impl<'db> Unpack<'db> { pub(crate) fn scope(self, db: &'db dyn Db) -> ScopeId<'db> { self.file_scope(db).to_scope_id(db, self.file(db)) } + + /// Returns the range of the unpack target expression. + pub(crate) fn range(self, db: &'db dyn Db) -> TextRange { + self.target(db).range() + } +} + +/// The expression that is being unpacked. +#[derive(Clone, Copy, Debug)] +pub(crate) enum UnpackValue<'db> { + /// An iterable expression like the one in a `for` loop or a comprehension. + Iterable(Expression<'db>), + /// An expression that is being assigned to a target. + Assign(Expression<'db>), +} + +impl<'db> UnpackValue<'db> { + /// Returns `true` if the value is an iterable expression. 
+ pub(crate) const fn is_iterable(self) -> bool { + matches!(self, UnpackValue::Iterable(_)) + } + + /// Returns the underlying [`Expression`] that is being unpacked. + pub(crate) const fn expression(self) -> Expression<'db> { + match self { + UnpackValue::Assign(expr) | UnpackValue::Iterable(expr) => expr, + } + } + + /// Returns the [`ScopedExpressionId`] of the underlying expression. + pub(crate) fn scoped_expression_id( + self, + db: &'db dyn Db, + scope: ScopeId<'db>, + ) -> ScopedExpressionId { + self.expression() + .node_ref(db) + .scoped_expression_id(db, scope) + } + + /// Returns the expression as an [`AnyNodeRef`]. + pub(crate) fn as_any_node_ref(self, db: &'db dyn Db) -> AnyNodeRef<'db> { + self.expression().node_ref(db).node().into() + } } diff --git a/crates/red_knot_workspace/tests/check.rs b/crates/red_knot_workspace/tests/check.rs index 733bf2efdd0d3..d7d0a14b11d39 100644 --- a/crates/red_knot_workspace/tests/check.rs +++ b/crates/red_knot_workspace/tests/check.rs @@ -182,11 +182,11 @@ impl<'db> PullTypesVisitor<'db> { } } - fn visit_assign_target(&mut self, target: &Expr) { + fn visit_target(&mut self, target: &Expr) { match target { Expr::List(ast::ExprList { elts, .. }) | Expr::Tuple(ast::ExprTuple { elts, .. 
}) => { for element in elts { - self.visit_assign_target(element); + self.visit_target(element); } } _ => self.visit_expr(target), @@ -205,8 +205,18 @@ impl SourceOrderVisitor<'_> for PullTypesVisitor<'_> { } Stmt::Assign(assign) => { for target in &assign.targets { - self.visit_assign_target(target); + self.visit_target(target); } + // TODO + //self.visit_expr(&assign.value); + return; + } + Stmt::For(for_stmt) => { + self.visit_target(&for_stmt.target); + // TODO + //self.visit_expr(&for_stmt.iter); + self.visit_body(&for_stmt.body); + self.visit_body(&for_stmt.orelse); return; } Stmt::AnnAssign(_) @@ -214,7 +224,6 @@ impl SourceOrderVisitor<'_> for PullTypesVisitor<'_> { | Stmt::Delete(_) | Stmt::AugAssign(_) | Stmt::TypeAlias(_) - | Stmt::For(_) | Stmt::While(_) | Stmt::If(_) | Stmt::With(_)