Skip to content

Commit

Permalink
Made struct members a query. made enriched members more direct. (#6271)
Browse files Browse the repository at this point in the history
  • Loading branch information
orizi committed Aug 25, 2024
1 parent b9be69c commit 401d763
Show file tree
Hide file tree
Showing 12 changed files with 84 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_semantic::diagnostic::{NotFoundItemType, SemanticDiagnostics};
use cairo_lang_semantic::expr::inference::InferenceId;
use cairo_lang_semantic::items::function_with_body::SemanticExprLookup;
use cairo_lang_semantic::items::structure::SemanticStructEx;
use cairo_lang_semantic::items::us::SemanticUseEx;
use cairo_lang_semantic::lookup_item::{HasResolverData, LookupItemEx};
use cairo_lang_semantic::resolve::{ResolvedConcreteItem, ResolvedGenericItem, Resolver};
Expand Down Expand Up @@ -241,17 +240,15 @@ pub fn dot_completions(
// Find members of the type.
let (_, long_ty) = peel_snapshots(db, ty);
if let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct_id)) = long_ty {
db.concrete_struct_members(concrete_struct_id).ok()?.into_iter().for_each(
|(name, member)| {
let completion = CompletionItem {
label: name.to_string(),
detail: Some(member.ty.format(db.upcast())),
kind: Some(CompletionItemKind::FIELD),
..CompletionItem::default()
};
completions.push(completion);
},
);
db.concrete_struct_members(concrete_struct_id).ok()?.iter().for_each(|(name, member)| {
let completion = CompletionItem {
label: name.to_string(),
detail: Some(member.ty.format(db.upcast())),
kind: Some(CompletionItemKind::FIELD),
..CompletionItem::default()
};
completions.push(completion);
});
}
Some(completions)
}
Expand Down
3 changes: 1 addition & 2 deletions crates/cairo-lang-lowering/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use cairo_lang_diagnostics::{Diagnostics, DiagnosticsBuilder, Maybe};
use cairo_lang_filesystem::ids::FileId;
use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_semantic::items::enm::SemanticEnumEx;
use cairo_lang_semantic::items::structure::SemanticStructEx;
use cairo_lang_semantic::{self as semantic, corelib, ConcreteTypeId, TypeId, TypeLongId};
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
Expand Down Expand Up @@ -766,7 +765,7 @@ fn type_size(db: &dyn LoweringGroup, ty: TypeId) -> usize {
ConcreteTypeId::Struct(struct_id) => db
.concrete_struct_members(struct_id)
.unwrap()
.into_iter()
.iter()
.map(|(_, member)| db.type_size(member.ty))
.sum::<usize>(),
ConcreteTypeId::Enum(enum_id) => {
Expand Down
3 changes: 1 addition & 2 deletions crates/cairo-lang-lowering/src/lower/block_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::ordered_hash_set::OrderedHashSet;
use cairo_lang_utils::{require, Intern, LookupIntern};
use itertools::{chain, zip_eq, Itertools};
use semantic::items::structure::SemanticStructEx;
use semantic::{ConcreteTypeId, ExprVarMemberPath, TypeLongId};

use super::context::{LoweredExpr, LoweringContext, LoweringFlowError, LoweringResult, VarRequest};
Expand Down Expand Up @@ -157,7 +156,7 @@ impl BlockBuilder {
generators::StructMemberAccess {
input: parent_var,
member_tys: members
.into_iter()
.iter()
.map(|(_, member)| {
wrap_in_snapshots(ctx.db.upcast(), member.ty, parent_number_of_snapshots)
})
Expand Down
8 changes: 4 additions & 4 deletions crates/cairo-lang-lowering/src/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ use semantic::corelib::{
never_ty, unit_ty,
};
use semantic::items::constant::{value_as_const_value, ConstValue};
use semantic::items::structure::SemanticStructEx;
use semantic::literals::try_extract_minus_literal;
use semantic::types::{peel_snapshots, wrap_in_snapshots};
use semantic::{
Expand Down Expand Up @@ -736,7 +735,8 @@ fn lower_single_pattern(
})
.collect(),
};
for (var_id, (_, member)) in izip!(generator.add(ctx, &mut builder.statements), members)
for (var_id, (_, member)) in
izip!(generator.add(ctx, &mut builder.statements), members.iter())
{
if let Some(member_pattern) = required_members.remove(&member.id) {
let member_pattern = ctx.function_body.arenas.patterns[*member_pattern].clone();
Expand Down Expand Up @@ -1598,7 +1598,7 @@ fn lower_expr_member_access(
generators::StructMemberAccess {
input: lower_expr_to_var_usage(ctx, builder, expr.expr)?,
member_tys: members
.into_iter()
.iter()
.map(|(_, member)| wrap_in_snapshots(ctx.db.upcast(), member.ty, expr.n_snapshots))
.collect(),
member_idx,
Expand Down Expand Up @@ -1668,7 +1668,7 @@ fn lower_expr_struct_ctor(
Ok(LoweredExpr::AtVariable(
generators::StructConstruct {
inputs: members
.into_iter()
.iter()
.map(|(_, member)| member_expr_usages.remove(&member.id).unwrap())
.collect::<Result<Vec<_>, _>>()?,
ty: expr.ty,
Expand Down
8 changes: 7 additions & 1 deletion crates/cairo-lang-semantic/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,16 @@ pub trait SemanticGroup:
fn struct_members(
&self,
struct_id: StructId,
) -> Maybe<OrderedHashMap<SmolStr, semantic::Member>>;
) -> Maybe<Arc<OrderedHashMap<SmolStr, semantic::Member>>>;
/// Returns the resolution resolved_items of a struct definition.
#[salsa::invoke(items::structure::struct_definition_resolver_data)]
fn struct_definition_resolver_data(&self, structure_id: StructId) -> Maybe<Arc<ResolverData>>;
/// Returns the concrete members of a struct.
#[salsa::invoke(items::structure::concrete_struct_members)]
fn concrete_struct_members(
&self,
concrete_struct_id: types::ConcreteStructId,
) -> Maybe<Arc<OrderedHashMap<SmolStr, semantic::Member>>>;

// Enum.
// =======
Expand Down
71 changes: 34 additions & 37 deletions crates/cairo-lang-semantic/src/expr/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ use crate::items::feature_kind::extract_item_feature_config;
use crate::items::functions::function_signature_params;
use crate::items::imp::{filter_candidate_traits, infer_impl_by_self};
use crate::items::modifiers::compute_mutability;
use crate::items::structure::SemanticStructEx;
use crate::items::visibility;
use crate::literals::try_extract_minus_literal;
use crate::resolve::{ResolvedConcreteItem, ResolvedGenericItem, Resolver};
Expand Down Expand Up @@ -2058,7 +2057,7 @@ fn maybe_compute_pattern_semantic(
})?;
let pattern_param_asts = pattern_struct.params(syntax_db).elements(syntax_db);
let struct_id = concrete_struct_id.struct_id(ctx.db);
let mut members = ctx.db.concrete_struct_members(concrete_struct_id)?;
let mut members = ctx.db.concrete_struct_members(concrete_struct_id)?.as_ref().clone();
let mut used_members = UnorderedHashSet::<_>::default();
let mut get_member = |ctx: &mut ComputationContext<'_>,
member_name: SmolStr,
Expand Down Expand Up @@ -2121,8 +2120,8 @@ fn maybe_compute_pattern_semantic(
}
}
if !has_tail {
for (member_name, _) in members {
ctx.diagnostics.report(pattern_struct, MissingMember(member_name));
for (member_name, _) in members.iter() {
ctx.diagnostics.report(pattern_struct, MissingMember(member_name.clone()));
}
}
Pattern::Struct(PatternStruct {
Expand Down Expand Up @@ -2851,12 +2850,16 @@ fn enriched_members(

if let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct_id)) = long_ty {
let members = ctx.db.concrete_struct_members(concrete_struct_id)?;
for (member_name, member) in members.iter() {
res.insert(member_name.clone(), (member.clone(), 0));
if let Some(ref accessed_member_name) = accessed_member_name {
if *member_name == *accessed_member_name {
return Ok(EnrichedMembers { members: res, deref_functions });
}
if let Some(accessed_member_name) = &accessed_member_name {
if let Some(member) = members.get(accessed_member_name) {
return Ok(EnrichedMembers {
members: [(accessed_member_name.clone(), (member.clone(), 0))].into(),
deref_functions,
});
}
} else {
for (member_name, member) in members.iter() {
res.insert(member_name.clone(), (member.clone(), 0));
}
}
}
Expand All @@ -2881,28 +2884,18 @@ fn enriched_members(

// Add members of derefed types.
let mut n_deref = 0;
// If the variable is mutable, and implements DerefMut, we use DerefMut in the first iteration.
let mut use_deref_mut = match expr.clone().expr {
Expr::Var(expr_var) => {
let var_id = expr_var.var;
match ctx.semantic_defs.get(&var_id) {
Some(variable) if variable.is_mut() => {
compute_deref_method_function_call_data(ctx, expr.clone(), true).is_ok()
}
_ => false,
}
}
let base_var = match &expr.expr {
Expr::Var(expr_var) => Some(expr_var.var),
Expr::MemberAccess(ExprMemberAccess { member_path: Some(member_path), .. }) => {
let var_id = member_path.base_var();
match ctx.semantic_defs.get(&var_id) {
Some(variable) if variable.is_mut() => {
compute_deref_method_function_call_data(ctx, expr.clone(), true).is_ok()
}
_ => false,
}
Some(member_path.base_var())
}
_ => false,
_ => None,
};
// If the variable is mutable, and implements DerefMut, we use DerefMut in the first iteration.
let mut use_deref_mut = base_var
.filter(|var_id| matches!(ctx.semantic_defs.get(var_id), Some(var) if var.is_mut()))
.is_some()
&& compute_deref_method_function_call_data(ctx, expr.clone(), true).is_ok();

while let Ok((function_id, _, cur_expr, mutability)) =
compute_deref_method_function_call_data(ctx, expr, use_deref_mut)
Expand All @@ -2928,14 +2921,18 @@ fn enriched_members(
expr = ExprAndId { expr: derefed_expr.clone(), id: ctx.arenas.exprs.alloc(derefed_expr) };
if let TypeLongId::Concrete(ConcreteTypeId::Struct(concrete_struct_id)) = long_ty {
let members = ctx.db.concrete_struct_members(concrete_struct_id)?;
for (member_name, member) in members.iter() {
// Insert member if there is not already a member with the same name.
if res.get(&member_name.clone()).is_none() {
res.insert(member_name.clone(), (member.clone(), n_deref));
if let Some(ref accessed_member_name) = accessed_member_name {
if *member_name == *accessed_member_name {
return Ok(EnrichedMembers { members: res, deref_functions });
}
if let Some(accessed_member_name) = &accessed_member_name {
if let Some(member) = members.get(accessed_member_name) {
return Ok(EnrichedMembers {
members: [(accessed_member_name.clone(), (member.clone(), n_deref))].into(),
deref_functions,
});
}
} else {
for (member_name, member) in members.iter() {
// Insert member if there is not already a member with the same name.
if !res.contains_key(member_name) {
res.insert(member_name.clone(), (member.clone(), n_deref));
}
}
}
Expand Down
1 change: 0 additions & 1 deletion crates/cairo-lang-semantic/src/items/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use smol_str::SmolStr;

use super::functions::{GenericFunctionId, GenericFunctionWithBodyId};
use super::imp::ImplId;
use super::structure::SemanticStructEx;
use crate::corelib::{
core_box_ty, core_felt252_ty, core_nonzero_ty, get_core_trait, get_core_ty_by_name,
try_extract_nz_wrapped_type, validate_literal, CoreTraitContext, LiteralError,
Expand Down
1 change: 0 additions & 1 deletion crates/cairo-lang-semantic/src/items/imp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ use super::impl_alias::{
impl_alias_generic_params_data_helper, impl_alias_semantic_data_cycle_helper,
impl_alias_semantic_data_helper, ImplAliasData,
};
use super::structure::SemanticStructEx;
use super::trt::{
ConcreteTraitConstantId, ConcreteTraitGenericFunctionId, ConcreteTraitGenericFunctionLongId,
};
Expand Down
48 changes: 24 additions & 24 deletions crates/cairo-lang-semantic/src/items/structure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use cairo_lang_proc_macros::{DebugWithDb, SemanticObject};
use cairo_lang_syntax::attribute::structured::{Attribute, AttributeListStructurize};
use cairo_lang_syntax::node::{Terminal, TypedStablePtr, TypedSyntaxNode};
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::{Intern, LookupIntern, Upcast};
use cairo_lang_utils::{Intern, LookupIntern};
use smol_str::SmolStr;

use super::attribute::SemanticQueryAttrs;
Expand Down Expand Up @@ -147,7 +147,7 @@ pub fn struct_declaration_resolver_data(
#[debug_db(dyn SemanticGroup + 'static)]
pub struct StructDefinitionData {
diagnostics: Diagnostics<SemanticDiagnostic>,
members: OrderedHashMap<SmolStr, Member>,
members: Arc<OrderedHashMap<SmolStr, Member>>,
resolver_data: Arc<ResolverData>,
}
#[derive(Clone, Debug, PartialEq, Eq, DebugWithDb, SemanticObject)]
Expand Down Expand Up @@ -213,7 +213,11 @@ pub fn priv_struct_definition_data(
}

let resolver_data = Arc::new(resolver.data);
Ok(StructDefinitionData { diagnostics: diagnostics.build(), members, resolver_data })
Ok(StructDefinitionData {
diagnostics: diagnostics.build(),
members: members.into(),
resolver_data,
})
}

/// Query implementation of [crate::db::SemanticGroup::struct_definition_diagnostics].
Expand Down Expand Up @@ -248,7 +252,7 @@ pub fn struct_definition_diagnostics(
pub fn struct_members(
db: &dyn SemanticGroup,
struct_id: StructId,
) -> Maybe<OrderedHashMap<SmolStr, Member>> {
) -> Maybe<Arc<OrderedHashMap<SmolStr, Member>>> {
Ok(db.priv_struct_definition_data(struct_id)?.members)
}

Expand All @@ -260,30 +264,26 @@ pub fn struct_definition_resolver_data(
Ok(db.priv_struct_definition_data(struct_id)?.resolver_data)
}

pub trait SemanticStructEx<'a>: Upcast<dyn SemanticGroup + 'a> {
fn concrete_struct_members(
&self,
concrete_struct_id: ConcreteStructId,
) -> Maybe<OrderedHashMap<SmolStr, semantic::Member>> {
// TODO(spapini): Uphold the invariant that constructed ConcreteEnumId instances
// always have the correct number of generic arguments.
let db = self.upcast();
let generic_params = db.struct_generic_params(concrete_struct_id.struct_id(db))?;
let generic_args = concrete_struct_id.lookup_intern(db).generic_args;
let substitution = GenericSubstitution::new(&generic_params, &generic_args);
/// Query implementation of [crate::db::SemanticGroup::concrete_struct_members].
pub fn concrete_struct_members(
db: &dyn SemanticGroup,
concrete_struct_id: ConcreteStructId,
) -> Maybe<Arc<OrderedHashMap<SmolStr, semantic::Member>>> {
// TODO(spapini): Uphold the invariant that constructed ConcreteEnumId instances
// always have the correct number of generic arguments.
let generic_params = db.struct_generic_params(concrete_struct_id.struct_id(db))?;
let generic_args = concrete_struct_id.lookup_intern(db).generic_args;
let substitution = GenericSubstitution::new(&generic_params, &generic_args);

let generic_members =
self.upcast().struct_members(concrete_struct_id.struct_id(self.upcast()))?;
let generic_members = db.struct_members(concrete_struct_id.struct_id(db))?;
Ok(Arc::new(
generic_members
.into_iter()
.iter()
.map(|(name, member)| {
let ty =
SubstitutionRewriter { db, substitution: &substitution }.rewrite(member.ty)?;
let member = semantic::Member { ty, ..member };
Ok((name, member))
Ok((name.clone(), semantic::Member { ty, ..member.clone() }))
})
.collect::<Maybe<_>>()
}
.collect::<Maybe<_>>()?,
))
}

impl<'a, T: Upcast<dyn SemanticGroup + 'a> + ?Sized> SemanticStructEx<'a> for T {}
2 changes: 1 addition & 1 deletion crates/cairo-lang-semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ pub fn type_size_info(db: &dyn SemanticGroup, ty: TypeId) -> Maybe<TypeSizeInfor
TypeLongId::Concrete(concrete_type_id) => match concrete_type_id {
ConcreteTypeId::Struct(id) => {
let mut zero_sized = true;
for (_, member) in db.struct_members(id.struct_id(db))? {
for (_, member) in db.struct_members(id.struct_id(db))?.iter() {
if db.type_size_info(member.ty)? != TypeSizeInformation::ZeroSized {
zero_sized = false;
}
Expand Down
9 changes: 3 additions & 6 deletions crates/cairo-lang-sierra-generator/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use cairo_lang_diagnostics::Maybe;
use cairo_lang_lowering::ids::SemanticFunctionIdEx;
use cairo_lang_semantic as semantic;
use cairo_lang_semantic::items::enm::SemanticEnumEx;
use cairo_lang_semantic::items::structure::SemanticStructEx;
use cairo_lang_sierra::extensions::snapshot::snapshot_ty;
use cairo_lang_sierra::ids::UserTypeId;
use cairo_lang_sierra::program::{ConcreteTypeLongId, GenericArg as SierraGenericArg};
Expand Down Expand Up @@ -186,11 +185,9 @@ pub fn type_dependencies(
) -> Maybe<Arc<[semantic::TypeId]>> {
Ok(match type_id.lookup_intern(db) {
semantic::TypeLongId::Concrete(ty) => match ty {
semantic::ConcreteTypeId::Struct(structure) => db
.concrete_struct_members(structure)?
.into_iter()
.map(|(_, member)| member.ty)
.collect(),
semantic::ConcreteTypeId::Struct(structure) => {
db.concrete_struct_members(structure)?.iter().map(|(_, member)| member.ty).collect()
}
semantic::ConcreteTypeId::Enum(enm) => {
db.concrete_enum_variants(enm)?.into_iter().map(|variant| variant.ty).collect()
}
Expand Down
1 change: 0 additions & 1 deletion crates/cairo-lang-starknet/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use cairo_lang_semantic::db::SemanticGroup;
use cairo_lang_semantic::items::attribute::SemanticQueryAttrs;
use cairo_lang_semantic::items::enm::SemanticEnumEx;
use cairo_lang_semantic::items::imp::{ImplLongId, ImplLookupContext};
use cairo_lang_semantic::items::structure::SemanticStructEx;
use cairo_lang_semantic::types::{get_impl_at_context, ConcreteEnumLongId, ConcreteStructLongId};
use cairo_lang_semantic::{
ConcreteTraitLongId, ConcreteTypeId, GenericArgumentId, GenericParam, Mutability, Signature,
Expand Down

0 comments on commit 401d763

Please sign in to comment.