Skip to content

Commit

Permalink
[red-knot] Add control flow for for loops (#13318)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexWaygood committed Sep 10, 2024
1 parent e6b927a commit b93d0ab
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 10 deletions.
39 changes: 34 additions & 5 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -575,14 +575,23 @@ where
self.flow_merge(pre_if);
}
}
ast::Stmt::While(node) => {
self.visit_expr(&node.test);
ast::Stmt::While(ast::StmtWhile {
test,
body,
orelse,
range: _,
}) => {
self.visit_expr(test);

let pre_loop = self.flow_snapshot();

// Save aside any break states from an outer loop
let saved_break_states = std::mem::take(&mut self.loop_break_states);
self.visit_body(&node.body);

// TODO: definitions created inside the body should be fully visible
// to other statements/expressions inside the body --Alex/Carl
self.visit_body(body);

// Get the break states from the body of this loop, and restore the saved outer
// ones.
let break_states =
Expand All @@ -591,7 +600,7 @@ where
// We may execute the `else` clause without ever executing the body, so merge in
// the pre-loop state before visiting `else`.
self.flow_merge(pre_loop);
self.visit_body(&node.orelse);
self.visit_body(orelse);

// Breaking out of a while loop bypasses the `else` clause, so merge in the break
// states after visiting `else`.
Expand Down Expand Up @@ -625,15 +634,35 @@ where
orelse,
},
) => {
// TODO add control flow similar to `ast::Stmt::While` above
self.add_standalone_expression(iter);
self.visit_expr(iter);

let pre_loop = self.flow_snapshot();
let saved_break_states = std::mem::take(&mut self.loop_break_states);

debug_assert!(self.current_assignment.is_none());
self.current_assignment = Some(for_stmt.into());
self.visit_expr(target);
self.current_assignment = None;

// TODO: Definitions created by loop variables
// (and definitions created inside the body)
// are fully visible to other statements/expressions inside the body --Alex/Carl
self.visit_body(body);

let break_states =
std::mem::replace(&mut self.loop_break_states, saved_break_states);

// We may execute the `else` clause without ever executing the body, so merge in
// the pre-loop state before visiting `else`.
self.flow_merge(pre_loop);
self.visit_body(orelse);

// Breaking out of a `for` loop bypasses the `else` clause, so merge in the break
// states after visiting `else`.
for break_state in break_states {
self.flow_merge(break_state);
}
}
ast::Stmt::Match(ast::StmtMatch {
subject,
Expand Down
95 changes: 90 additions & 5 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4271,7 +4271,92 @@ mod tests {
",
)?;

assert_public_ty(&db, "src/a.py", "x", "int");
assert_public_ty(&db, "src/a.py", "x", "Unbound | int");

Ok(())
}

#[test]
fn for_loop_with_previous_definition() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
x = 'foo'
for x in IntIterable():
pass
",
)?;

assert_public_ty(&db, "src/a.py", "x", r#"Literal["foo"] | int"#);

Ok(())
}

#[test]
fn for_loop_no_break() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
for x in IntIterable():
pass
else:
x = 'foo'
",
)?;

// The `for` loop can never break, so the `else` clause will always be executed,
// meaning that the visible definition by the end of the scope is solely determined
// by the `else` clause
assert_public_ty(&db, "src/a.py", "x", r#"Literal["foo"]"#);

Ok(())
}

#[test]
fn for_loop_may_break() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
for x in IntIterable():
if x > 5:
break
else:
x = 'foo'
",
)?;

assert_public_ty(&db, "src/a.py", "x", r#"int | Literal["foo"]"#);

Ok(())
}
Expand All @@ -4292,7 +4377,7 @@ mod tests {
",
)?;

assert_public_ty(&db, "src/a.py", "x", "int");
assert_public_ty(&db, "src/a.py", "x", "Unbound | int");

Ok(())
}
Expand Down Expand Up @@ -4320,7 +4405,7 @@ mod tests {
",
)?;

assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown");
assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unbound | Unknown");

Ok(())
}
Expand All @@ -4347,7 +4432,7 @@ mod tests {
)?;

// TODO(Alex) async iterables/iterators!
assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unknown");
assert_scope_ty(&db, "src/a.py", &["foo"], "x", "Unbound | Unknown");

Ok(())
}
Expand All @@ -4368,7 +4453,7 @@ mod tests {
&db,
"src/a.py",
"x",
r#"Literal[1] | Literal["a"] | Literal[b"foo"]"#,
r#"Unbound | Literal[1] | Literal["a"] | Literal[b"foo"]"#,
);

Ok(())
Expand Down

0 comments on commit b93d0ab

Please sign in to comment.