Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: properly capture lvalues in closure environments (#2120) #2257

Merged
merged 1 commit into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions crates/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,10 @@
fn resolve_lvalue(&mut self, lvalue: LValue) -> HirLValue {
match lvalue {
LValue::Ident(ident) => {
HirLValue::Ident(self.find_variable_or_default(&ident).0, Type::Error)
let ident = self.find_variable_or_default(&ident);
self.resolve_local_variable(ident.0, ident.1);

HirLValue::Ident(ident.0, Type::Error)
}
LValue::MemberAccess { object, field_name } => {
let object = Box::new(self.resolve_lvalue(*object));
Expand Down Expand Up @@ -1018,8 +1021,8 @@
self.interner.push_definition_type(hir_ident.id, typ);
}
}
// We ignore the above definition kinds because only local variables can be captured by closures.
DefinitionKind::Local(_) => {
// only local variables can be captured by closures.
self.resolve_local_variable(hir_ident, var_scope_index);
}
}
Expand Down Expand Up @@ -1925,7 +1928,7 @@
println(f"I want to print {0}");

let new_val = 10;
println(f"randomstring{new_val}{new_val}");

Check warning on line 1931 in crates/noirc_frontend/src/hir/resolution/resolver.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (randomstring)
}
fn println<T>(x : T) -> T {
x
Expand Down
48 changes: 31 additions & 17 deletions crates/noirc_frontend/src/monomorphization/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
pub mod printer;

struct LambdaContext {
env_ident: Box<ast::Expression>,
env_ident: ast::Ident,
captures: Vec<HirCapturedVar>,
}

Expand Down Expand Up @@ -552,13 +552,26 @@
ast::Expression::Block(definitions)
}

/// Find a captured variable in the innermost closure
fn lookup_captured(&mut self, id: node_interner::DefinitionId) -> Option<ast::Expression> {
/// Find a captured variable in the innermost closure, and construct an expression
fn lookup_captured_expr(&mut self, id: node_interner::DefinitionId) -> Option<ast::Expression> {
let ctx = self.lambda_envs_stack.last()?;
ctx.captures
.iter()
.position(|capture| capture.ident.id == id)
.map(|index| ast::Expression::ExtractTupleField(ctx.env_ident.clone(), index))
ctx.captures.iter().position(|capture| capture.ident.id == id).map(|index| {
ast::Expression::ExtractTupleField(
Box::new(ast::Expression::Ident(ctx.env_ident.clone())),
index,
)
})
}

/// Find a captured variable in the innermost closure construct a LValue
fn lookup_captured_lvalue(&mut self, id: node_interner::DefinitionId) -> Option<ast::LValue> {
let ctx = self.lambda_envs_stack.last()?;
ctx.captures.iter().position(|capture| capture.ident.id == id).map(|index| {
ast::LValue::MemberAccess {
object: Box::new(ast::LValue::Ident(ctx.env_ident.clone())),
field_index: index,
}
})
}

/// A local (ie non-global) ident only
Expand Down Expand Up @@ -599,7 +612,7 @@
}
}
DefinitionKind::Global(expr_id) => self.expr(*expr_id),
DefinitionKind::Local(_) => self.lookup_captured(ident.id).unwrap_or_else(|| {
DefinitionKind::Local(_) => self.lookup_captured_expr(ident.id).unwrap_or_else(|| {
let ident = self.local_ident(&ident).unwrap();
ast::Expression::Ident(ident)
}),
Expand Down Expand Up @@ -961,7 +974,9 @@

fn lvalue(&mut self, lvalue: HirLValue) -> ast::LValue {
match lvalue {
HirLValue::Ident(ident, _) => ast::LValue::Ident(self.local_ident(&ident).unwrap()),
HirLValue::Ident(ident, _) => self
.lookup_captured_lvalue(ident.id)
.unwrap_or_else(|| ast::LValue::Ident(self.local_ident(&ident).unwrap())),
HirLValue::MemberAccess { object, field_index, .. } => {
let field_index = field_index.unwrap();
let object = Box::new(self.lvalue(*object));
Expand Down Expand Up @@ -1031,7 +1046,7 @@
expr: node_interner::ExprId,
) -> (ast::Expression, ast::Expression) {
// returns (<closure setup>, <closure variable>)
// which can be used directly in callsites or transformed

Check warning on line 1049 in crates/noirc_frontend/src/monomorphization/mod.rs

View workflow job for this annotation

GitHub Actions / Spellcheck / Spellcheck

Unknown word (callsites)
// directly to a single `Expression`
// for other cases by `lambda` which is called by `expr`
//
Expand Down Expand Up @@ -1065,7 +1080,7 @@
match capture.transitive_capture_index {
Some(field_index) => match self.lambda_envs_stack.last() {
Some(lambda_ctx) => ast::Expression::ExtractTupleField(
lambda_ctx.env_ident.clone(),
Box::new(ast::Expression::Ident(lambda_ctx.env_ident.clone())),
field_index,
),
None => unreachable!(
Expand Down Expand Up @@ -1096,18 +1111,16 @@
let mutable = false;
let definition = Definition::Local(env_local_id);

let env_ident = ast::Expression::Ident(ast::Ident {
let env_ident = ast::Ident {
location,
mutable,
definition,
name: env_name.to_string(),
typ: env_typ.clone(),
});
};

self.lambda_envs_stack.push(LambdaContext {
env_ident: Box::new(env_ident.clone()),
captures: lambda.captures,
});
self.lambda_envs_stack
.push(LambdaContext { env_ident: env_ident.clone(), captures: lambda.captures });
let body = self.expr(lambda.body);
self.lambda_envs_stack.pop();

Expand All @@ -1129,7 +1142,8 @@
let function = ast::Function { id, name, parameters, body, return_type, unconstrained };
self.push_function(id, function);

let lambda_value = ast::Expression::Tuple(vec![env_ident, lambda_fn]);
let lambda_value =
ast::Expression::Tuple(vec![ast::Expression::Ident(env_ident), lambda_fn]);
let block_local_id = self.next_local_id();
let block_ident_name = "closure_variable";
let block_let_stmt = ast::Expression::Let(ast::Let {
Expand Down