diff --git a/Cargo.lock b/Cargo.lock index e1159c7120..f78fbfede2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -45,6 +45,8 @@ dependencies = [ "acir", "acvm_blackbox_solver", "ark-bls12-381", + "ark-bn254", + "bn254_blackbox_solver", "brillig_vm", "indexmap 1.9.3", "num-bigint", @@ -52,6 +54,7 @@ dependencies = [ "serde", "thiserror", "tracing", + "zkhash", ] [[package]] @@ -539,6 +542,18 @@ dependencies = [ "typenum", ] +[[package]] +name = "bitvec" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" +dependencies = [ + "funty", + "radium", + "tap", + "wyz", +] + [[package]] name = "blake2" version = "0.10.6" @@ -548,6 +563,17 @@ dependencies = [ "digest", ] +[[package]] +name = "blake2b_simd" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23285ad32269793932e830392f2fe2f83e26488fd3ec778883a93c8323735780" +dependencies = [ + "arrayref", + "arrayvec", + "constant_time_eq", +] + [[package]] name = "blake3" version = "1.5.0" @@ -570,6 +596,19 @@ dependencies = [ "generic-array", ] +[[package]] +name = "bls12_381" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3c196a77437e7cc2fb515ce413a6401291578b5afc8ecb29a3c7ab957f05941" +dependencies = [ + "ff 0.12.1", + "group 0.12.1", + "pairing", + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "bn254_blackbox_solver" version = "0.49.0" @@ -1360,9 +1399,9 @@ dependencies = [ "crypto-bigint", "der", "digest", - "ff", + "ff 0.12.1", "generic-array", - "group", + "group 0.12.1", "pkcs8", "rand_core 0.6.4", "sec1", @@ -1465,6 +1504,18 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d013fc25338cc558c5c2cfbad646908fb23591e2404481826742b651c9af7160" dependencies = [ + "bitvec", + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "ff" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ded41244b729663b1e574f1b4fb731469f69f79c17667b5d776b16cda0479449" +dependencies = [ + "bitvec", "rand_core 0.6.4", "subtle", ] @@ -1569,6 +1620,12 @@ dependencies = [ "libc", ] +[[package]] +name = "funty" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" + [[package]] name = "futures" version = "0.1.31" @@ -1764,7 +1821,19 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5dfbfb3a6cfbd390d5c9564ab283a0349b9b9fcd46a706c1eb10e0db70bfbac7" dependencies = [ - "ff", + "ff 0.12.1", + "memuse", + "rand_core 0.6.4", + "subtle", +] + +[[package]] +name = "group" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" +dependencies = [ + "ff 0.13.0", "rand_core 0.6.4", "subtle", ] @@ -1775,6 +1844,29 @@ version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" +[[package]] +name = "halo2" +version = "0.1.0-beta.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a23c779b38253fe1538102da44ad5bd5378495a61d2c4ee18d64eaa61ae5995" +dependencies = [ + "halo2_proofs", +] + +[[package]] +name = "halo2_proofs" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e925780549adee8364c7f2b685c753f6f3df23bde520c67416e93bf615933760" +dependencies = [ + "blake2b_simd", + "ff 0.12.1", + "group 0.12.1", + "pasta_curves 0.4.1", + "rand_core 0.6.4", + "rayon", +] + [[package]] name = "hashbrown" version = "0.11.2" @@ -2252,6 +2344,20 @@ dependencies = [ "unicase", ] +[[package]] +name = "jubjub" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a575df5f985fe1cd5b2b05664ff6accfc46559032b954529fd225a2168d27b0f" +dependencies = [ + "bitvec", + "bls12_381", + "ff 0.12.1", + "group 0.12.1", + "rand_core 0.6.4", + "subtle", +] + [[package]] name = "k256" version = "0.11.6" @@ -2329,6 +2435,9 @@ name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +dependencies = [ + "spin", +] [[package]] name = "libaes" @@ -2464,6 +2573,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "memuse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2145869435ace5ea6ea3d35f59be559317ec9a0d04e1812d5f185a87b6d36f1a" + [[package]] name = "miniz_oxide" version = "0.7.1" @@ -3100,6 +3215,15 @@ dependencies = [ "sha2", ] +[[package]] +name = "pairing" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135590d8bdba2b31346f9cd1fb2a912329f5135e832a4f422942eb6ead8b6b3b" +dependencies = [ + "group 0.12.1", +] + [[package]] name = "parking_lot" version = "0.11.2" @@ -3148,6 +3272,36 @@ dependencies = [ "windows-targets 0.48.1", ] +[[package]] +name = "pasta_curves" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cc65faf8e7313b4b1fbaa9f7ca917a0eed499a9663be71477f87993604341d8" +dependencies = [ + "blake2b_simd", + "ff 0.12.1", + "group 0.12.1", + "lazy_static", + "rand 0.8.5", + "static_assertions", + "subtle", +] + +[[package]] +name = "pasta_curves" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e57598f73cc7e1b2ac63c79c517b31a0877cd7c402cdcaa311b5208de7a095" +dependencies = [ + "blake2b_simd", + "ff 0.13.0", + "group 0.13.0", + "lazy_static", + "rand 0.8.5", + "static_assertions", + "subtle", +] + [[package]] name = "paste" version = "1.0.14" @@ -3473,6 +3627,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "radium" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" + [[package]] name = "radix_trie" version = "0.2.1" @@ -4193,6 +4353,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "spin" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" + [[package]] name = "spki" version = "0.6.0" @@ -4209,6 +4375,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "str-buf" version = "1.0.6" @@ -4310,6 +4482,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "tempfile" version = "3.8.0" @@ -5178,6 +5356,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "wyz" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" +dependencies = [ + "tap", +] + [[package]] name = "zerocopy" version = "0.7.32" @@ -5217,3 +5404,30 @@ dependencies = [ "quote", "syn 2.0.64", ] + +[[package]] +name = "zkhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4352d1081da6922701401cdd4cbf29a2723feb4cfabb5771f6fee8e9276da1c7" +dependencies = [ + "ark-ff", + "ark-std", + "bitvec", + "blake2", + "bls12_381", + "byteorder", + "cfg-if 1.0.0", + "group 0.12.1", + "group 0.13.0", + "halo2", + "hex", + "jubjub", + "lazy_static", + "pasta_curves 0.5.1", + "rand 0.8.5", + "serde", + "sha2", + "sha3", + "subtle", +] diff --git a/Cargo.toml b/Cargo.toml index 7d9b3254c5..52cb1012b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -166,6 +166,14 @@ rust-embed = "6.6.0" # See https://ritik-mishra.medium.com/resolving-the-wasm-pack-error-locals-exceed-maximum-ec3a9d96685b opt-level = 1 +# release mode with extra checks, e.g. overflow checks +[profile.release-pedantic] +inherits = "release" +overflow-checks = true + +[profile.test] +inherits = "dev" +overflow-checks = true [profile.size] inherits = "release" diff --git a/acvm-repo/acir/src/circuit/mod.rs b/acvm-repo/acir/src/circuit/mod.rs index 43984e4a92..f700fefe0c 100644 --- a/acvm-repo/acir/src/circuit/mod.rs +++ b/acvm-repo/acir/src/circuit/mod.rs @@ -153,9 +153,28 @@ pub struct ResolvedOpcodeLocation { /// map opcodes to debug information related to their context. pub enum OpcodeLocation { Acir(usize), + // TODO(https://github.com/noir-lang/noir/issues/5792): We can not get rid of this enum field entirely just yet as this format is still + // used for resolving assert messages which is a breaking serialization change. Brillig { acir_index: usize, brillig_index: usize }, } +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] +pub struct BrilligOpcodeLocation(pub usize); + +impl OpcodeLocation { + // Utility method to allow easily comparing a resolved Brillig location and a debug Brillig location. + // This method is useful when fetching Brillig debug locations as this does not need an ACIR index, + // and just need the Brillig index. + pub fn to_brillig_location(self) -> Option { + match self { + OpcodeLocation::Brillig { brillig_index, .. } => { + Some(BrilligOpcodeLocation(brillig_index)) + } + _ => None, + } + } +} + impl std::fmt::Display for OpcodeLocation { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { @@ -204,6 +223,13 @@ impl FromStr for OpcodeLocation { } } +impl std::fmt::Display for BrilligOpcodeLocation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let index = self.0; + write!(f, "{index}") + } +} + impl Circuit { pub fn num_vars(&self) -> u32 { self.current_witness_index + 1 diff --git a/acvm-repo/acvm/Cargo.toml b/acvm-repo/acvm/Cargo.toml index bf1170ce07..ea80dbeedb 100644 --- a/acvm-repo/acvm/Cargo.toml +++ b/acvm-repo/acvm/Cargo.toml @@ -13,8 +13,6 @@ repository.workspace = true [lints] workspace = true -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] num-bigint.workspace = true thiserror.workspace = true @@ -41,4 +39,7 @@ bls12_381 = [ [dev-dependencies] ark-bls12-381 = { version = "^0.4.0", default-features = false, features = ["curve"] } +ark-bn254.workspace = true +bn254_blackbox_solver.workspace = true proptest.workspace = true +zkhash = { version = "^0.2.0", default-features = false } diff --git a/acvm-repo/acvm/tests/solver.proptest-regressions b/acvm-repo/acvm/tests/solver.proptest-regressions index 35627c1fba..d5b09c8c00 100644 --- a/acvm-repo/acvm/tests/solver.proptest-regressions +++ b/acvm-repo/acvm/tests/solver.proptest-regressions @@ -4,6 +4,7 @@ # # It is recommended to check this file in to source control so that # everyone who runs the test benefits from these saved cases. +cc 960460afabe9cd5f293b310d0aa3cc55f79163d4445fb4fd24f148c0c70ef421 # shrinks to inputs = [(7006800039331019243688393824145158009389575060470190384886350246083750305789, false), (2¹⁶×69288181877743077422831840787272084000814053518824660036579689151244970, false), (300437348917483825612039528102296693334006129135490594023152856084953868810, false), (4823492773360499854802554302585328525608528551020562050189864554765904117770, false), (4824670368033820130439612125278697378018346703067216363823090518963541115050, false), (4541974568008122685711264049174891850838557765495784399449231366824048721930, false), (300437350040540685984720241501165014133014058360545849645894867666537650617, true), (163192779580560125232076252958207681919, true), (-2542062050439421159413867956702051849160671674920746658341296409298182239337, true), (175503121977504223703263610086399928344, true), (260495679943662138441041500775727067764, true), (256267345976523334436046522042864460684, false), (-3293060738351338103745720847983311414454700585429643550776950744744641792278, true), (-2738558284890938529700946417088255944936801663233427893951742481234449937863, true), (-9408903287898353981993159772236368632269873213738312580426228276387594806467, false), (-4094435362802425278873069427818373025789998862201415928058861770747528163349, false), (112103815807797910215405155129290739992, true), (652959048122537987900699228406811595117461985442100769568930484681117184927, true), (25648069884835098934129943086492583018, true), (182381658164976581567218989771782914043, true), (276027589667111318145641746681175535935, true), (228979537701935045218138086489994361738, true), (-7189629140037494234317011123662291175105203248014412116387794444668331938972, true), (2⁴×21045079095550098765213744333690206045, true), (2⁴×272957279228827760019939917744285068022450150396463247147001692315953383164, true), (278169782777797975865517348961398261204, true), (-9350891260062047521505428797610570271715889226434172923277698254137606657637, true), (31753891497170711449574490616043022555, true), (5826530521679896062199745042089932193960338071930275815357730199930663038971, false), (6333148601890872882199727702821458893692619379008420418245130578951125712813, true), (11591994241860040248819868426028245994970593822060591090353134114607771350840, false), (285019902865519749239467179980780356895, true), (-1841690738801312317622577110475383143336981923160464720298620703157843518801, false), (5677035692440022901187907185898133134673347474144834237621131124787028412481, false), (3817657503746267247293557834223189311, false), (3524986425011502884950940801392664786470720156074841155252661063475974853061, true), (3804027911019145431046523949797604749769206523757137367319384533460898442710, false), (66232570440565238307407257297826182007, true), (10950346695884721979708694328592923581621839140812284359518482850256345903994, true), (-2555270735135738372913055831827559435131548066765611995534606037449334265987, false), (254175244691305364094581273598120793718, true), (290269702902160220362408941750532374247, false), (123377132611149223348638391192338552487, true), (11181151842165891780257230406582701746739259352010643125205313477915192331122, true), (-2⁴×409367997994623731443501907254201822154034077116501532446678462812476345374, true), (2⁴×544614809509366338085123708986493598, false), (2⁴×17078170458432298388079487507615195387, true), (149630116761755329336202584260592884865, false), (2⁴×10091764197139943452626559964081430454, true), (2534364375519095349027150248721158492583764392404691624469146116049080450061, true), (603871911530068460120598744369974193816624318623555065050680487659787773114, true), (160084452435578937990759565211574469929, false), (218093312507148512669194680815491642329, true), (300736043265743439077982281106014100410, true), (18523147759780982956728003842757122081, true), (306248530673006182475517947158145534649, true), (102222390890150910453739963495986974375, false), (257196139219119037476292875587496141911, true), (331081228043791906616479284936784004527, false), (3016805042897548992579828613322067728372560923112623438800394966350526781639, false), (-572714658027596419754534645583807426372197275429130874987048126395899381912, true), (-729547819819366432764794677355722188538208598242889613088699865869727118200, true), (821714036325498522863927403152236109, false), (4787751092911790208417000875909331970360784454334971842120679656099485179967, false), (1782142348917496268507120173995265725105472252018261170186756542207304517779, true), (12386003977677513084491398210232899807, false), (-2⁴×618559586542445675091456136080909891843100214761559602644497715652929490680, true), (4825546553395317448131256122977630559401382302668933235792589235215581675013, true), (140081105962524159719688476469460050188, true), (10859915619506763272622767997975661267764177865434405229649588030297621763336, false), (5224965182422450783905844900991128046, false), (71305303954419714098999884166539099370, false), (1774498925999897022672533369155675274519638992774354242677460742588824167893, true), (7303027640230419310139844673790234369436236422573534223908300465886393464159, false), (-6886271862152701865078616271612468719595307788386883226924940208040368361967, false), (10731912210227615778772646231985571152124249989191718989152918741848181957943, true), (198890466402904980166209748345696486119, true), (-5302367762468092462990244835112194335091450476355385352485920702032880189295, true), (2⁸×10099407372054304360516381277909507981224443414132090997901220204171639546, false), (168440361528147182399972323599534901404, true), (-6814422589767867668898643051011344807889685848964308955561956885501315175751, false), (316318615743355850862007524903255359480, true), (2712091589809687549346821393688518190583490946219058843320554352096717761470, true), (271066251172692433277334139422005324431, true), (5518750645785029524081697870015943313782873802072211951410725469084236770902, true), (1598180654131654153665855630440993899751870982081743673174495880400279439835, true), (11087487398111971618312866830968963281800717504531211030365913622073456199882, true), (111117995977996342335283103148096717294, true), (-1295976581376625868027817481511298593289394037483836095148092178778026996567, true), (83596219152712990240747906837305397145, true), (1228889622953345470420372628777807948517612731014485410422479520987606375826, true), (-8913468232804310348258666527777442833035058410912856005485334678693975514181, false), (-2⁴×278169562875273203973590188765589611858625386111898684152446666758264013985, true), (119076628552542975308949066183764445758, false), (2⁴×14050391655614718300501271078314774634, false), (194148774408232082873511641818616649590, false), (8194443279687593339348520601136879720103397639149864695196588804167374140678, false), (-354687261929126904957780450034290767854037900960107070732201978247851494917, false), (-4179720645035540348033514199762031592916784883648049359344050592582034531797, false), (252905126534820884873153821542654918918, false), (1623505226178950044083922559187169870608127387658796034641797863725138722258, false), (4666866662074690176801634318736786643725641691879698870430777699162065406876, true), (2⁸×435424843840757189423437488726199338, true), (-4965389845430595378911449436549917702549453083792442397843742095714939156996, true), (14382174601370788118680121014383071020, true), (18435597849872688842745199082065115774, true), (62402934393754095161223999760955177676, true), (3342319610924388381676805660071709048562626908839360397545889815909430346935, false), (320934966917535193944554370897413442767, false), (-7042080042894669842698869952298956796788859350146750227466425877912934318731, false), (730331371245847611804891620969657153593889954959738225897355098167792980949, false), (4302619576720736253790944194743764744133983784456893791762467136654139077499, true), (8873437849856151332967264853182096823478282147593198413978830303785491549165, true), (90112803136689598890112508644364307477, false), (5515586277356922699974022986439856116085865722104320017169902845027519269630, true), (193123430947542232681959839788943833622, false), (313518414676988815146392681741405591399, false), (1744503198465993468536485478138286257957976156358242481202409192902804688819, true), (232270086930796904432134012473980400372, false), (54252678944965710450270323121364517361, false), (227966767610810610640465794393680628548, true), (195275454883093736141594026466958950870, false), (41489184570808717515029072032722768645, false), (-1376299006702054572367600313792943879784167154440024183554470146184948865317, false), (-4342520612104393712790888025887841267334261124780864718056663459023881749401, false), (-2195492958636161600210780697531407029867553813415393642723791298864497552233, true), (332196823653718041971704314790182081950, true), (6692397022125437472229833939879677974689961856437724474341917567424805187801, true), (-2383455130457463032029147766306601662937360882870225696738568438241981024563, true), (6717167117084866258340101046663161160224757566359450082951177418440696737772, true), (-1642389908034348896239060205015333283837806723968213066158348166269363356454, true), (6906326300271352814414395025112363945375522026620021146077167231059213695969, false), (62420300707619024606657497955183365269, false), (310840635777201053115164797040466639857, true)] cc e4dd0e141df173f5dfdfb186bba4154247ec284b71d8f294fa3282da953a0e92 # shrinks to x = 0, y = 1 cc 419ed6fdf1bf1f2513889c42ec86c665c9d0500ceb075cbbd07f72444dbd78c6 # shrinks to x = 266672725 cc 0810fc9e126b56cf0a0ddb25e0dc498fa3b2f1980951550403479fc01c209833 # shrinks to modulus = [71, 253, 124, 216, 22, 140, 32, 60, 141, 202, 113, 104, 145, 106, 129, 151, 93, 88, 129, 129, 182, 69, 80, 184, 41, 160, 49, 225, 114, 78, 100, 48], zero_or_ones_constant = false, use_constant = false @@ -11,3 +12,8 @@ cc 735ee9beb1a1dbb82ded6f30e544d7dfde149957e5d45a8c96fc65a690b6b71c # shrinks to cc ca81bc11114a2a2b34021f44ecc1e10cb018e35021ef4d728e07a6791dad38d6 # shrinks to (xs, modulus) = ([(0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (49, false)], [71, 253, 124, 216, 22, 140, 32, 60, 141, 202, 113, 104, 145, 106, 129, 151, 93, 88, 129, 129, 182, 69, 80, 184, 41, 160, 49, 225, 114, 78, 100, 48]) cc 6c1d571a0111e6b4c244dc16da122ebab361e77b71db7770d638076ab21a717b # shrinks to (xs, modulus) = ([(0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (49, false)], [71, 253, 124, 216, 22, 140, 32, 60, 141, 202, 113, 104, 145, 106, 129, 151, 93, 88, 129, 129, 182, 69, 80, 184, 41, 160, 49, 225, 114, 78, 100, 48]) cc ccb7061ab6b85e2554d00bf03d74204977ed7a4109d7e2d5c6b5aaa2179cfaf9 # shrinks to (xs, modulus) = ([(0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (0, false), (49, false)], [71, 253, 124, 216, 22, 140, 32, 60, 141, 202, 113, 104, 145, 106, 129, 151, 93, 88, 129, 129, 182, 69, 80, 184, 41, 160, 49, 225, 114, 78, 100, 48]) +cc 853d774f6d69809a63e3121ccdbbb780db42acb861cb6cada63247446932a321 # shrinks to inputs_distinct_inputs = ([(0, false)], [(9374390252263900826, false)]) +cc 4fc0bd347d9f4967801e2e30c746d2f6c012882911f72e7e816d350a742ced28 # shrinks to inputs_distinct_inputs = ([(2⁸×61916068613087029720904767285796661, false)], [(2⁸×220343640628484768581538005104492351, false)]) +cc 04d8571793600c2023d7aba2d1dd8f0e2c82b6010130d95a193df02b07977712 # shrinks to inputs_distinct_inputs = ([], [(0, true)]) +cc dbc57772b9450371db70f8aa06d10502bb1aef030448c6df467465937bc8916a # shrinks to inputs_distinct_inputs = ([(295, false), (0, false), (0, false), (0, false), (0, false), (0, false)], [(295, false), (0, false), (328, false), (237, true), (484, true), (69, false)]) +cc ef68d2dc6f0d366dd69edf8eec02a7b9cd7d6888983cea45496516b6effca813 # shrinks to inputs_distinct_inputs = ([(40, false), (471, false), (56, false), (35, false), (104, false), (232, false), (252, false), (131, false), (437, true), (354, false), (235, false), (316, true), (364, true), (242, false), (436, true), (298, true), (360, true), (174, true), (295, false), (250, true), (178, true), (426, false), (78, false), (217, true), (296, true), (371, false), (349, true), (445, false), (221, false), (409, false), (59, false), (511, true), (482, false)], [(136, true), (228, true), (193, true), (190, true), (15, false), (399, false), (54, false), (195, true), (258, true), (99, false), (83, false), (383, true), (456, true), (409, true), (347, false), (183, false), (371, true), (410, true), (439, true), (175, true), (445, false), (165, false), (70, false), (2⁴×22, true), (339, true), (161, true), (313, false), (2⁴×23, true), (275, true), (278, true), (294, true), (284, true), (262, false)]) diff --git a/acvm-repo/acvm/tests/solver.rs b/acvm-repo/acvm/tests/solver.rs index a1b8b62f8b..2a06e07f09 100644 --- a/acvm-repo/acvm/tests/solver.rs +++ b/acvm-repo/acvm/tests/solver.rs @@ -1,4 +1,5 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, HashSet}; +use std::sync::Arc; use acir::{ acir_field::GenericFieldElement, @@ -14,13 +15,14 @@ use acir::{ use acvm::pwg::{ACVMStatus, ErrorLocation, ForeignCallWaitInfo, OpcodeResolutionError, ACVM}; use acvm_blackbox_solver::StubbedBlackBoxSolver; +use bn254_blackbox_solver::{field_from_hex, Bn254BlackBoxSolver, POSEIDON2_CONFIG}; use brillig_vm::brillig::HeapValueType; use proptest::arbitrary::any; use proptest::prelude::*; use proptest::result::maybe_ok; - -// Reenable these test cases once we move the brillig implementation of inversion down into the acvm stdlib. +use proptest::sample::select; +use zkhash::poseidon2::poseidon2_params::Poseidon2Params; #[test] fn bls12_381_circuit() { @@ -728,6 +730,248 @@ fn memory_operations() { assert_eq!(witness_map[&Witness(8)], FieldElement::from(6u128)); } +fn allowed_bigint_moduli() -> Vec> { + let bn254_fq: Vec = vec![ + 0x47, 0xFD, 0x7C, 0xD8, 0x16, 0x8C, 0x20, 0x3C, 0x8d, 0xca, 0x71, 0x68, 0x91, 0x6a, 0x81, + 0x97, 0x5d, 0x58, 0x81, 0x81, 0xb6, 0x45, 0x50, 0xb8, 0x29, 0xa0, 0x31, 0xe1, 0x72, 0x4e, + 0x64, 0x30, + ]; + let bn254_fr: Vec = vec![ + 1, 0, 0, 240, 147, 245, 225, 67, 145, 112, 185, 121, 72, 232, 51, 40, 93, 88, 129, 129, + 182, 69, 80, 184, 41, 160, 49, 225, 114, 78, 100, 48, + ]; + let secpk1_fr: Vec = vec![ + 0x41, 0x41, 0x36, 0xD0, 0x8C, 0x5E, 0xD2, 0xBF, 0x3B, 0xA0, 0x48, 0xAF, 0xE6, 0xDC, 0xAE, + 0xBA, 0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, + ]; + let secpk1_fq: Vec = vec![ + 0x2F, 0xFC, 0xFF, 0xFF, 0xFE, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, + ]; + let secpr1_fq: Vec = vec![ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0xFF, 0xFF, + 0xFF, 0xFF, + ]; + let secpr1_fr: Vec = vec![ + 81, 37, 99, 252, 194, 202, 185, 243, 132, 158, 23, 167, 173, 250, 230, 188, 255, 255, 255, + 255, 255, 255, 255, 255, 0, 0, 0, 0, 255, 255, 255, 255, + ]; + + vec![bn254_fq, bn254_fr, secpk1_fr, secpk1_fq, secpr1_fq, secpr1_fr] +} + +/// Whether to use a FunctionInput::constant or FunctionInput::witness: +/// +/// (value, use_constant) +type ConstantOrWitness = (FieldElement, bool); + +// For each ConstantOrWitness, +// - If use_constant, then convert to a FunctionInput::constant +// - Otherwise, convert to FunctionInput::witness +// + With the Witness index as (input_index + offset) +fn constant_or_witness_to_function_inputs( + xs: Vec, + offset: usize, + num_bits: Option, +) -> Vec> { + let num_bits = num_bits.unwrap_or(FieldElement::max_num_bits()); + xs.into_iter() + .enumerate() + .map(|(i, (x, use_constant))| { + if use_constant { + FunctionInput::constant(x, num_bits) + } else { + FunctionInput::witness(Witness((i + offset) as u32), num_bits) + } + }) + .collect() +} + +// Convert ConstantOrWitness's back to FieldElement's by dropping the bool's +fn drop_use_constant(input: &[ConstantOrWitness]) -> Vec { + input.iter().map(|x| x.0).collect() +} + +// equivalent values (ignoring use_constant) +fn drop_use_constant_eq(x: &[ConstantOrWitness], y: &[ConstantOrWitness]) -> bool { + drop_use_constant(x) == drop_use_constant(y) +} + +// Convert FieldElement's to ConstantOrWitness's by making all of them witnesses +fn use_witnesses(inputs: Vec) -> Vec { + inputs.into_iter().map(|input| (input, false)).collect() +} + +fn solve_array_input_blackbox_call( + inputs: Vec, + num_outputs: usize, + num_bits: Option, + f: F, +) -> Vec +where + F: FnOnce((Vec>, Vec)) -> BlackBoxFuncCall, +{ + let initial_witness_vec: Vec<_> = + inputs.iter().enumerate().map(|(i, (x, _))| (Witness(i as u32), *x)).collect(); + let outputs: Vec<_> = (0..num_outputs) + .map(|i| Witness((i + inputs.len()) as u32)) // offset past the indices of inputs + .collect(); + let initial_witness = WitnessMap::from(BTreeMap::from_iter(initial_witness_vec)); + + let inputs = constant_or_witness_to_function_inputs(inputs, 0, num_bits); + let op = Opcode::BlackBoxFuncCall(f((inputs.clone(), outputs.clone()))); + let opcodes = vec![op]; + let unconstrained_functions = vec![]; + let mut acvm = + ACVM::new(&Bn254BlackBoxSolver, &opcodes, initial_witness, &unconstrained_functions, &[]); + let solver_status = acvm.solve(); + assert_eq!(solver_status, ACVMStatus::Solved); + let witness_map = acvm.finalize(); + + outputs + .iter() + .map(|witness| *witness_map.get(witness).expect("all witnesses to be set")) + .collect() +} + +prop_compose! { + fn bigint_with_modulus()(modulus in select(allowed_bigint_moduli())) + (inputs in proptest::collection::vec(any::<(u8, bool)>(), modulus.len()), modulus in Just(modulus)) + -> (Vec, Vec) { + let inputs = inputs.into_iter().zip(modulus.iter()).map(|((input, use_constant), modulus_byte)| { + (FieldElement::from(input.clamp(0, *modulus_byte) as u128), use_constant) + }).collect(); + (inputs, modulus) + } +} + +prop_compose! { + fn bigint_pair_with_modulus()(inputs_modulus in bigint_with_modulus()) + (second_inputs in proptest::collection::vec(any::<(u8, bool)>(), inputs_modulus.1.len()), inputs_modulus in Just(inputs_modulus)) + -> (Vec, Vec, Vec) { + let (inputs, modulus) = inputs_modulus; + let second_inputs = second_inputs.into_iter().zip(modulus.iter()).map(|((input, use_constant), modulus_byte)| { + (FieldElement::from(input.clamp(0, *modulus_byte) as u128), use_constant) + }).collect(); + (inputs, second_inputs, modulus) + } +} + +prop_compose! { + fn bigint_triple_with_modulus()(inputs_pair_modulus in bigint_pair_with_modulus()) + (third_inputs in proptest::collection::vec(any::<(u8, bool)>(), inputs_pair_modulus.2.len()), inputs_pair_modulus in Just(inputs_pair_modulus)) + -> (Vec, Vec, Vec, Vec) { + let (inputs, second_inputs, modulus) = inputs_pair_modulus; + let third_inputs = third_inputs.into_iter().zip(modulus.iter()).map(|((input, use_constant), modulus_byte)| { + (FieldElement::from(input.clamp(0, *modulus_byte) as u128), use_constant) + }).collect(); + (inputs, second_inputs, third_inputs, modulus) + } +} + +fn bigint_add_op() -> BlackBoxFuncCall { + BlackBoxFuncCall::BigIntAdd { lhs: 0, rhs: 1, output: 2 } +} + +fn bigint_mul_op() -> BlackBoxFuncCall { + BlackBoxFuncCall::BigIntMul { lhs: 0, rhs: 1, output: 2 } +} + +fn bigint_sub_op() -> BlackBoxFuncCall { + BlackBoxFuncCall::BigIntSub { lhs: 0, rhs: 1, output: 2 } +} + +fn bigint_div_op() -> BlackBoxFuncCall { + BlackBoxFuncCall::BigIntDiv { lhs: 0, rhs: 1, output: 2 } +} + +// Input is a BigInt, represented as a LE Vec of u8-range FieldElement's along +// with their use_constant values. +// +// Output is a zeroed BigInt that matches the input BigInt's +// - Byte length +// - use_constant values +fn bigint_zeroed(inputs: &[ConstantOrWitness]) -> Vec { + inputs.iter().map(|(_, use_constant)| (FieldElement::zero(), *use_constant)).collect() +} + +// bigint_zeroed, but returns one +fn bigint_to_one(inputs: &[ConstantOrWitness]) -> Vec { + let mut one = bigint_zeroed(inputs); + // little-endian + one[0] = (FieldElement::one(), one[0].1); + one +} + +// Using the given BigInt modulus, solve the following circuit: +// - Convert xs, ys to BigInt's with ID's 0, 1, resp. +// - If the middle_op is present, run it +// + Input BigInt ID's: 0, 1 +// + Output BigInt ID: 2 +// - If the middle_op is missing, the output BigInt ID is 0 +// - Run BigIntToLeBytes on the output BigInt ID +// - Output the resulting Vec of LE bytes +fn bigint_solve_binary_op_opt( + middle_op: Option>, + modulus: Vec, + lhs: Vec, + rhs: Vec, +) -> Vec { + let initial_witness_vec: Vec<_> = lhs + .iter() + .chain(rhs.iter()) + .enumerate() + .map(|(i, (x, _))| (Witness(i as u32), *x)) + .collect(); + let output_witnesses: Vec<_> = initial_witness_vec + .iter() + .take(lhs.len()) + .enumerate() + .map(|(index, _)| Witness((index + 2 * lhs.len()) as u32)) // offset past the indices of lhs, rhs + .collect(); + let initial_witness = WitnessMap::from(BTreeMap::from_iter(initial_witness_vec)); + + let lhs = constant_or_witness_to_function_inputs(lhs, 0, None); + let rhs = constant_or_witness_to_function_inputs(rhs, lhs.len(), None); + + let to_op_input = if middle_op.is_some() { 2 } else { 0 }; + + let bigint_from_lhs_op = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::BigIntFromLeBytes { + inputs: lhs, + modulus: modulus.clone(), + output: 0, + }); + let bigint_from_rhs_op = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::BigIntFromLeBytes { + inputs: rhs, + modulus: modulus.clone(), + output: 1, + }); + let bigint_to_op = Opcode::BlackBoxFuncCall(BlackBoxFuncCall::BigIntToLeBytes { + input: to_op_input, + outputs: output_witnesses.clone(), + }); + + let mut opcodes = vec![bigint_from_lhs_op, bigint_from_rhs_op]; + if let Some(middle_op) = middle_op { + opcodes.push(Opcode::BlackBoxFuncCall(middle_op)); + } + opcodes.push(bigint_to_op); + + let unconstrained_functions = vec![]; + let mut acvm = + ACVM::new(&StubbedBlackBoxSolver, &opcodes, initial_witness, &unconstrained_functions, &[]); + let solver_status = acvm.solve(); + assert_eq!(solver_status, ACVMStatus::Solved); + let witness_map = acvm.finalize(); + output_witnesses + .iter() + .map(|witness| *witness_map.get(witness).expect("all witnesses to be set")) + .collect() +} + // Solve the given BlackBoxFuncCall with witnesses: 1, 2 as x, y, resp. #[cfg(test)] fn solve_blackbox_func_call( @@ -735,25 +979,26 @@ fn solve_blackbox_func_call( Option, Option, ) -> BlackBoxFuncCall, - x: (FieldElement, bool), // if false, use a Witness - y: (FieldElement, bool), // if false, use a Witness + lhs: (FieldElement, bool), // if false, use a Witness + rhs: (FieldElement, bool), // if false, use a Witness ) -> FieldElement { - let (x, x_constant) = x; - let (y, y_constant) = y; + let (lhs, lhs_constant) = lhs; + let (rhs, rhs_constant) = rhs; - let initial_witness = WitnessMap::from(BTreeMap::from_iter([(Witness(1), x), (Witness(2), y)])); + let initial_witness = + WitnessMap::from(BTreeMap::from_iter([(Witness(1), lhs), (Witness(2), rhs)])); - let mut lhs = None; - if x_constant { - lhs = Some(x); + let mut lhs_opt = None; + if lhs_constant { + lhs_opt = Some(lhs); } - let mut rhs = None; - if y_constant { - rhs = Some(y); + let mut rhs_opt = None; + if rhs_constant { + rhs_opt = Some(rhs); } - let op = Opcode::BlackBoxFuncCall(blackbox_func_call(lhs, rhs)); + let op = Opcode::BlackBoxFuncCall(blackbox_func_call(lhs_opt, rhs_opt)); let opcodes = vec![op]; let unconstrained_functions = vec![]; let mut acvm = @@ -765,6 +1010,205 @@ fn solve_blackbox_func_call( witness_map[&Witness(3)] } +// N inputs +// 32 outputs +fn sha256_op( + function_inputs_and_outputs: (Vec>, Vec), +) -> BlackBoxFuncCall { + let (function_inputs, outputs) = function_inputs_and_outputs; + BlackBoxFuncCall::SHA256 { + inputs: function_inputs, + outputs: outputs.try_into().expect("SHA256 returns 32 outputs"), + } +} + +// N inputs +// 32 outputs +fn blake2s_op( + function_inputs_and_outputs: (Vec>, Vec), +) -> BlackBoxFuncCall { + let (function_inputs, outputs) = function_inputs_and_outputs; + BlackBoxFuncCall::Blake2s { + inputs: function_inputs, + outputs: outputs.try_into().expect("Blake2s returns 32 outputs"), + } +} + +// N inputs +// 32 outputs +fn blake3_op( + function_inputs_and_outputs: (Vec>, Vec), +) -> BlackBoxFuncCall { + let (function_inputs, outputs) = function_inputs_and_outputs; + BlackBoxFuncCall::Blake3 { + inputs: function_inputs, + outputs: outputs.try_into().expect("Blake3 returns 32 outputs"), + } +} + +// variable inputs +// 32 outputs +fn keccak256_op( + function_inputs_and_outputs: (Vec>, Vec), +) -> BlackBoxFuncCall { + let (function_inputs, outputs) = function_inputs_and_outputs; + let function_inputs_len = function_inputs.len(); + BlackBoxFuncCall::Keccak256 { + inputs: function_inputs, + var_message_size: FunctionInput::constant( + function_inputs_len.into(), + FieldElement::max_num_bits(), + ), + outputs: outputs.try_into().expect("Keccak256 returns 32 outputs"), + } +} + +// var_message_size is the number of bytes to take +// from the input. Note: if `var_message_size` +// is more than the number of bytes in the input, +// then an error is returned. +// +// variable inputs +// 32 outputs +fn keccak256_invalid_message_size_op( + function_inputs_and_outputs: (Vec>, Vec), +) -> BlackBoxFuncCall { + let (function_inputs, outputs) = function_inputs_and_outputs; + let function_inputs_len = function_inputs.len(); + BlackBoxFuncCall::Keccak256 { + inputs: function_inputs, + var_message_size: FunctionInput::constant( + (function_inputs_len - 1).into(), + FieldElement::max_num_bits(), + ), + outputs: outputs.try_into().expect("Keccak256 returns 32 outputs"), + } +} + +// 25 inputs +// 25 outputs +fn keccakf1600_op( + function_inputs_and_outputs: (Vec>, Vec), +) -> BlackBoxFuncCall { + let (function_inputs, outputs) = function_inputs_and_outputs; + BlackBoxFuncCall::Keccakf1600 { + inputs: function_inputs.try_into().expect("Keccakf1600 expects 25 inputs"), + outputs: outputs.try_into().expect("Keccakf1600 returns 25 outputs"), + } +} + +// N inputs +// N outputs +fn poseidon2_permutation_op( + function_inputs_and_outputs: (Vec>, Vec), +) -> BlackBoxFuncCall { + let (inputs, outputs) = function_inputs_and_outputs; + let len = inputs.len() as u32; + BlackBoxFuncCall::Poseidon2Permutation { inputs, outputs, len } +} + +// N inputs +// N outputs +fn poseidon2_permutation_invalid_len_op( + function_inputs_and_outputs: (Vec>, Vec), +) -> BlackBoxFuncCall { + let (inputs, outputs) = function_inputs_and_outputs; + let len = (inputs.len() as u32) + 1; + BlackBoxFuncCall::Poseidon2Permutation { inputs, outputs, len } +} + +// 24 inputs (16 + 8) +// 8 outputs +fn sha256_compression_op( + function_inputs_and_outputs: (Vec>, Vec), +) -> BlackBoxFuncCall { + let (function_inputs, outputs) = function_inputs_and_outputs; + let mut function_inputs = function_inputs.into_iter(); + let inputs = core::array::from_fn(|_| function_inputs.next().unwrap()); + let hash_values = core::array::from_fn(|_| function_inputs.next().unwrap()); + BlackBoxFuncCall::Sha256Compression { + inputs: Box::new(inputs), + hash_values: Box::new(hash_values), + outputs: outputs.try_into().unwrap(), + } +} + +fn into_repr_vec(fields: T) -> Vec +where + T: IntoIterator, +{ + fields.into_iter().map(|field| field.into_repr()).collect() +} + +fn into_repr_mat(fields: T) -> Vec> +where + T: IntoIterator, + U: IntoIterator, +{ + fields.into_iter().map(|field| into_repr_vec(field)).collect() +} + +fn run_both_poseidon2_permutations( + inputs: Vec, +) -> (Vec, Vec) { + let result = solve_array_input_blackbox_call( + inputs.clone(), + inputs.len(), + None, + poseidon2_permutation_op, + ); + + let poseidon2_t = POSEIDON2_CONFIG.t as usize; + let poseidon2_d = 5; + let rounds_f = POSEIDON2_CONFIG.rounds_f as usize; + let rounds_p = POSEIDON2_CONFIG.rounds_p as usize; + let mat_internal_diag_m_1 = into_repr_vec(POSEIDON2_CONFIG.internal_matrix_diagonal); + let mat_internal = vec![]; + let round_constants = into_repr_mat(POSEIDON2_CONFIG.round_constant); + + let external_poseidon2 = + zkhash::poseidon2::poseidon2::Poseidon2::new(&Arc::new(Poseidon2Params::new( + poseidon2_t, + poseidon2_d, + rounds_f, + rounds_p, + &mat_internal_diag_m_1, + &mat_internal, + &round_constants, + ))); + + let expected_result = + external_poseidon2.permutation(&into_repr_vec(drop_use_constant(&inputs))); + (into_repr_vec(result), expected_result) +} + +// Using the given BigInt modulus, solve the following circuit: +// - Convert xs, ys to BigInt's with ID's 0, 1, resp. +// - Run the middle_op: +// + Input BigInt ID's: 0, 1 +// + Output BigInt ID: 2 +// - Run BigIntToLeBytes on the output BigInt ID +// - Output the resulting Vec of LE bytes +fn bigint_solve_binary_op( + middle_op: BlackBoxFuncCall, + modulus: Vec, + lhs: Vec, + rhs: Vec, +) -> Vec { + bigint_solve_binary_op_opt(Some(middle_op), modulus, lhs, rhs) +} + +// Using the given BigInt modulus, solve the following circuit: +// - Convert the input to a BigInt with ID 0 +// - Run BigIntToLeBytes on BigInt ID 0 +// - Output the resulting Vec of LE bytes +fn bigint_solve_from_to_le_bytes( + modulus: Vec, + inputs: Vec, +) -> Vec { + bigint_solve_binary_op_opt(None, modulus, inputs, vec![]) +} + fn function_input_from_option( witness: Witness, opt_constant: Option, @@ -827,6 +1271,33 @@ fn prop_assert_zero_l( (solve_blackbox_func_call(op, op_zero, x), FieldElement::zero()) } +// Test that varying one of the inputs produces a different result +// +// (is the op injective for the given inputs?, failure string) +fn prop_assert_injective( + inputs: Vec, + distinct_inputs: Vec, + num_outputs: usize, + num_bits: Option, + op: F, +) -> (bool, String) +where + F: FnOnce((Vec>, Vec)) -> BlackBoxFuncCall + + Clone, +{ + let equal_inputs = drop_use_constant_eq(&inputs, &distinct_inputs); + let message = format!("not injective:\n{:?}\n{:?}", &inputs, &distinct_inputs); + let outputs_not_equal = + solve_array_input_blackbox_call(inputs, num_outputs, num_bits, op.clone()) + != solve_array_input_blackbox_call(distinct_inputs, num_outputs, num_bits, op); + (equal_inputs || outputs_not_equal, message) +} + +fn field_element_ones() -> FieldElement { + let exponent: FieldElement = (253_u128).into(); + FieldElement::from(2u128).pow(&exponent) - FieldElement::one() +} + prop_compose! { // Use both `u128` and hex proptest strategies fn field_element() @@ -841,11 +1312,174 @@ prop_compose! { } } -fn field_element_ones() -> FieldElement { - let exponent: FieldElement = (253_u128).into(); - FieldElement::from(2u128).pow(&exponent) - FieldElement::one() +prop_compose! { + fn any_distinct_inputs(max_input_bits: Option, min_size: usize, max_size: usize) + (size_and_patch in any::<(usize, usize, usize)>()) // NOTE: macro ambiguity when using (x: T) + (inputs_distinct_inputs in + (proptest::collection::vec(any::<(u128, bool)>(), std::cmp::max(min_size, size_and_patch.0 % max_size)), + proptest::collection::vec(any::<(u128, bool)>(), std::cmp::max(min_size, size_and_patch.0 % max_size))), + size_and_patch in Just(size_and_patch)) + -> (Vec, Vec) { + let (_size, patch_location, patch_value) = size_and_patch; + let (inputs, distinct_inputs) = inputs_distinct_inputs; + let to_input = |(x, use_constant)| { + let modulus = if let Some(max_input_bits) = max_input_bits { + 2u128 << max_input_bits + } else { + 1 + }; + (FieldElement::from(x % modulus), use_constant) + }; + let inputs: Vec<_> = inputs.into_iter().map(to_input).collect(); + let mut distinct_inputs: Vec<_> = distinct_inputs.into_iter().map(to_input).collect(); + + // if equivalent w/o use_constant, patch with the patch_value + if drop_use_constant_eq(&inputs, &distinct_inputs) { + let distinct_inputs_len = distinct_inputs.len(); + let positive_patch_value = std::cmp::max(patch_value, 1); + if distinct_inputs_len != 0 { + distinct_inputs[patch_location % distinct_inputs_len].0 += FieldElement::from(positive_patch_value) + } else { + distinct_inputs.push((FieldElement::zero(), true)) + } + } + + (inputs, distinct_inputs) + } +} + +#[test] +fn poseidon2_permutation_zeroes() { + let use_constants: [bool; 4] = [false; 4]; + let inputs: Vec<_> = [FieldElement::zero(); 4].into_iter().zip(use_constants).collect(); + let (result, expected_result) = run_both_poseidon2_permutations(inputs); + + let internal_expected_result = vec![ + field_from_hex("18DFB8DC9B82229CFF974EFEFC8DF78B1CE96D9D844236B496785C698BC6732E"), + field_from_hex("095C230D1D37A246E8D2D5A63B165FE0FADE040D442F61E25F0590E5FB76F839"), + field_from_hex("0BB9545846E1AFA4FA3C97414A60A20FC4949F537A68CCECA34C5CE71E28AA59"), + field_from_hex("18A4F34C9C6F99335FF7638B82AEED9018026618358873C982BBDDE265B2ED6D"), + ]; + + assert_eq!(expected_result, into_repr_vec(internal_expected_result)); + assert_eq!(result, expected_result); +} + +#[test] +fn sha256_zeros() { + let results = solve_array_input_blackbox_call(vec![], 32, None, sha256_op); + let expected_results: Vec<_> = vec![ + 227, 176, 196, 66, 152, 252, 28, 20, 154, 251, 244, 200, 153, 111, 185, 36, 39, 174, 65, + 228, 100, 155, 147, 76, 164, 149, 153, 27, 120, 82, 184, 85, + ] + .into_iter() + .map(|x: u128| FieldElement::from(x)) + .collect(); + assert_eq!(results, expected_results); +} + +#[test] +fn sha256_compression_zeros() { + let results = solve_array_input_blackbox_call( + [(FieldElement::zero(), false); 24].try_into().unwrap(), + 8, + None, + sha256_compression_op, + ); + let expected_results: Vec<_> = vec![ + 2091193876, 1113340840, 3461668143, 3254913767, 3068490961, 2551409935, 2927503052, + 3205228454, + ] + .into_iter() + .map(|x: u128| FieldElement::from(x)) + .collect(); + assert_eq!(results, expected_results); +} + +#[test] +fn blake2s_zeros() { + let results = solve_array_input_blackbox_call(vec![], 32, None, blake2s_op); + let expected_results: Vec<_> = vec![ + 105, 33, 122, 48, 121, 144, 128, 148, 225, 17, 33, 208, 66, 53, 74, 124, 31, 85, 182, 72, + 44, 161, 165, 30, 27, 37, 13, 253, 30, 208, 238, 249, + ] + .into_iter() + .map(|x: u128| FieldElement::from(x)) + .collect(); + assert_eq!(results, expected_results); +} + +#[test] +fn blake3_zeros() { + let results = solve_array_input_blackbox_call(vec![], 32, None, blake3_op); + let expected_results: Vec<_> = vec![ + 175, 19, 73, 185, 245, 249, 161, 166, 160, 64, 77, 234, 54, 220, 201, 73, 155, 203, 37, + 201, 173, 193, 18, 183, 204, 154, 147, 202, 228, 31, 50, 98, + ] + .into_iter() + .map(|x: u128| FieldElement::from(x)) + .collect(); + assert_eq!(results, expected_results); +} + +#[test] +fn keccak256_zeros() { + let results = solve_array_input_blackbox_call(vec![], 32, None, keccak256_op); + let expected_results: Vec<_> = vec![ + 197, 210, 70, 1, 134, 247, 35, 60, 146, 126, 125, 178, 220, 199, 3, 192, 229, 0, 182, 83, + 202, 130, 39, 59, 123, 250, 216, 4, 93, 133, 164, 112, + ] + .into_iter() + .map(|x: u128| FieldElement::from(x)) + .collect(); + assert_eq!(results, expected_results); +} + +#[test] +fn keccakf1600_zeros() { + let results = solve_array_input_blackbox_call( + [(FieldElement::zero(), false); 25].into(), + 25, + Some(64), + keccakf1600_op, + ); + let expected_results: Vec<_> = vec![ + 17376452488221285863, + 9571781953733019530, + 15391093639620504046, + 13624874521033984333, + 10027350355371872343, + 18417369716475457492, + 10448040663659726788, + 10113917136857017974, + 12479658147685402012, + 3500241080921619556, + 16959053435453822517, + 12224711289652453635, + 9342009439668884831, + 4879704952849025062, + 140226327413610143, + 424854978622500449, + 7259519967065370866, + 7004910057750291985, + 13293599522548616907, + 10105770293752443592, + 10668034807192757780, + 1747952066141424100, + 1654286879329379778, + 8500057116360352059, + 16929593379567477321, + ] + .into_iter() + .map(|x: u128| FieldElement::from(x)) + .collect(); + + assert_eq!(results, expected_results); } +// NOTE: an "average" bigint is large, so consider increasing the number of proptest shrinking +// iterations (from the default 1024) to reach a simplified case, e.g. +// PROPTEST_MAX_SHRINK_ITERS=1024000 proptest! { #[test] @@ -910,4 +1544,329 @@ proptest! { let (lhs, rhs) = prop_assert_zero_l(and_op, zero, x); prop_assert_eq!(lhs, rhs); } + + #[test] + fn poseidon2_permutation_matches_external_impl(inputs in proptest::collection::vec(field_element(), 4)) { + let (result, expected_result) = run_both_poseidon2_permutations(inputs); + prop_assert_eq!(result, expected_result) + } + + #[test] + fn sha256_injective(inputs_distinct_inputs in any_distinct_inputs(None, 0, 32)) { + let (inputs, distinct_inputs) = inputs_distinct_inputs; + let (result, message) = prop_assert_injective(inputs, distinct_inputs, 32, None, sha256_op); + prop_assert!(result, "{}", message); + } + + #[test] + fn sha256_compression_injective(inputs_distinct_inputs in any_distinct_inputs(None, 24, 24)) { + let (inputs, distinct_inputs) = inputs_distinct_inputs; + if inputs.len() == 24 && distinct_inputs.len() == 24 { + let (result, message) = prop_assert_injective(inputs, distinct_inputs, 8, None, sha256_compression_op); + prop_assert!(result, "{}", message); + } + } + + #[test] + fn blake2s_injective(inputs_distinct_inputs in any_distinct_inputs(None, 0, 32)) { + let (inputs, distinct_inputs) = inputs_distinct_inputs; + let (result, message) = prop_assert_injective(inputs, distinct_inputs, 32, None, blake2s_op); + prop_assert!(result, "{}", message); + } + + #[test] + fn blake3_injective(inputs_distinct_inputs in any_distinct_inputs(None, 0, 32)) { + let (inputs, distinct_inputs) = inputs_distinct_inputs; + let (result, message) = prop_assert_injective(inputs, distinct_inputs, 32, None, blake3_op); + prop_assert!(result, "{}", message); + } + + #[test] + fn keccak256_injective(inputs_distinct_inputs in any_distinct_inputs(Some(8), 0, 32)) { + let (inputs, distinct_inputs) = inputs_distinct_inputs; + let (result, message) = prop_assert_injective(inputs, distinct_inputs, 32, Some(32), keccak256_op); + prop_assert!(result, "{}", message); + } + + // TODO(https://github.com/noir-lang/noir/issues/5689): doesn't fail with a user error + // The test failing with "not injective" demonstrates that it returns constant output instead + // of failing with a user error. + #[test] + #[should_panic(expected = "Test failed: not injective")] + fn keccak256_invalid_message_size_fails(inputs_distinct_inputs in any_distinct_inputs(Some(8), 0, 32)) { + let (inputs, distinct_inputs) = inputs_distinct_inputs; + let (result, message) = prop_assert_injective(inputs, distinct_inputs, 32, Some(8), keccak256_invalid_message_size_op); + prop_assert!(result, "{}", message); + } + + #[test] + fn keccakf1600_injective(inputs_distinct_inputs in any_distinct_inputs(Some(8), 25, 25)) { + let (inputs, distinct_inputs) = inputs_distinct_inputs; + assert_eq!(inputs.len(), 25); + assert_eq!(distinct_inputs.len(), 25); + let (result, message) = prop_assert_injective(inputs, distinct_inputs, 25, Some(64), keccakf1600_op); + prop_assert!(result, "{}", message); + } + + // TODO(https://github.com/noir-lang/noir/issues/5699): wrong failure message + #[test] + #[should_panic(expected = "Failure(BlackBoxFunctionFailed(Poseidon2Permutation, \"the number of inputs does not match specified length. 6 != 7\"))")] + fn poseidon2_permutation_invalid_size_fails(inputs_distinct_inputs in any_distinct_inputs(None, 6, 6)) { + let (inputs, distinct_inputs) = inputs_distinct_inputs; + let (result, message) = prop_assert_injective(inputs, distinct_inputs, 1, None, poseidon2_permutation_invalid_len_op); + prop_assert!(result, "{}", message); + } + + #[test] + fn bigint_from_to_le_bytes_zero_one(modulus in select(allowed_bigint_moduli()), zero_or_ones_constant: bool, use_constant: bool) { + let zero_function_input = if zero_or_ones_constant { + FieldElement::one() + } else { + FieldElement::zero() + }; + let zero_or_ones: Vec<_> = modulus.iter().map(|_| (zero_function_input, use_constant)).collect(); + let expected_results = drop_use_constant(&zero_or_ones); + let results = bigint_solve_from_to_le_bytes(modulus.clone(), zero_or_ones); + prop_assert_eq!(results, expected_results) + } + + #[test] + fn bigint_from_to_le_bytes((input, modulus) in bigint_with_modulus()) { + let expected_results: Vec<_> = drop_use_constant(&input); + let results = bigint_solve_from_to_le_bytes(modulus.clone(), input); + prop_assert_eq!(results, expected_results) + } + + #[test] + // TODO(https://github.com/noir-lang/noir/issues/5580): desired behavior? + fn bigint_from_to_le_bytes_extra_input_bytes((input, modulus) in bigint_with_modulus(), extra_bytes_len: u8, extra_bytes in proptest::collection::vec(any::<(u8, bool)>(), u8::MAX as usize)) { + let mut input = input; + let mut extra_bytes: Vec<_> = extra_bytes.into_iter().take(extra_bytes_len as usize).map(|(x, use_constant)| (FieldElement::from(x as u128), use_constant)).collect(); + input.append(&mut extra_bytes); + let expected_results: Vec<_> = drop_use_constant(&input); + let results = bigint_solve_from_to_le_bytes(modulus.clone(), input); + prop_assert_eq!(results, expected_results) + } + + #[test] + // TODO(https://github.com/noir-lang/noir/issues/5580): desired behavior? + #[should_panic(expected = "Test failed: assertion failed: `(left == right)`")] + fn bigint_from_to_le_bytes_bigger_than_u8((input, modulus) in bigint_with_modulus(), patch_location: usize, larger_value: u16, use_constant: bool) { + let mut input = input; + let patch_location = patch_location % input.len(); + let larger_value = FieldElement::from(std::cmp::max((u8::MAX as u16) + 1, larger_value) as u128); + input[patch_location] = (larger_value, use_constant); + let expected_results: Vec<_> = drop_use_constant(&input); + let results = bigint_solve_from_to_le_bytes(modulus.clone(), input); + prop_assert_eq!(results, expected_results) + } + + #[test] + // TODO(https://github.com/noir-lang/noir/issues/5578): this test attempts to use a guaranteed-invalid BigInt modulus + // #[should_panic(expected = "attempt to add with overflow")] + fn bigint_from_to_le_bytes_disallowed_modulus(mut modulus in select(allowed_bigint_moduli()), patch_location: usize, patch_amount: u8, zero_or_ones_constant: bool, use_constant: bool) { + let allowed_moduli: HashSet> = allowed_bigint_moduli().into_iter().collect(); + let mut patch_location = patch_location % modulus.len(); + let patch_amount = patch_amount.clamp(1, u8::MAX); + while allowed_moduli.contains(&modulus) { + modulus[patch_location] = patch_amount.wrapping_add(modulus[patch_location]); + patch_location += 1; + patch_location %= modulus.len(); + } + + let zero_function_input = if zero_or_ones_constant { + FieldElement::zero() + } else { + FieldElement::one() + }; + let zero: Vec<_> = modulus.iter().map(|_| (zero_function_input, use_constant)).collect(); + let expected_results: Vec<_> = drop_use_constant(&zero); + let results = bigint_solve_from_to_le_bytes(modulus.clone(), zero); + + prop_assert_eq!(results, expected_results) + } + + #[test] + fn bigint_add_commutative((xs, ys, modulus) in bigint_pair_with_modulus()) { + let lhs_results = bigint_solve_binary_op(bigint_add_op(), modulus.clone(), xs.clone(), ys.clone()); + let rhs_results = bigint_solve_binary_op(bigint_add_op(), modulus, ys, xs); + + prop_assert_eq!(lhs_results, rhs_results) + } + + #[test] + fn bigint_mul_commutative((xs, ys, modulus) in bigint_pair_with_modulus()) { + let lhs_results = bigint_solve_binary_op(bigint_mul_op(), modulus.clone(), xs.clone(), ys.clone()); + let rhs_results = bigint_solve_binary_op(bigint_mul_op(), modulus, ys, xs); + + prop_assert_eq!(lhs_results, rhs_results) + } + + #[test] + fn bigint_add_associative((xs, ys, zs, modulus) in bigint_triple_with_modulus()) { + // f(f(xs, ys), zs) == + let op_xs_ys = bigint_solve_binary_op(bigint_add_op(), modulus.clone(), xs.clone(), ys.clone()); + let xs_ys = use_witnesses(op_xs_ys); + let op_xs_ys_op_zs = bigint_solve_binary_op(bigint_add_op(), modulus.clone(), xs_ys, zs.clone()); + + // f(xs, f(ys, zs)) + let op_ys_zs = bigint_solve_binary_op(bigint_add_op(), modulus.clone(), ys.clone(), zs.clone()); + let ys_zs = use_witnesses(op_ys_zs); + let op_xs_op_ys_zs = bigint_solve_binary_op(bigint_add_op(), modulus, xs, ys_zs); + + prop_assert_eq!(op_xs_ys_op_zs, op_xs_op_ys_zs) + } + + #[test] + fn bigint_mul_associative((xs, ys, zs, modulus) in bigint_triple_with_modulus()) { + // f(f(xs, ys), zs) == + let op_xs_ys = bigint_solve_binary_op(bigint_mul_op(), modulus.clone(), xs.clone(), ys.clone()); + let xs_ys = use_witnesses(op_xs_ys); + let op_xs_ys_op_zs = bigint_solve_binary_op(bigint_mul_op(), modulus.clone(), xs_ys, zs.clone()); + + // f(xs, f(ys, zs)) + let op_ys_zs = bigint_solve_binary_op(bigint_mul_op(), modulus.clone(), ys.clone(), zs.clone()); + let ys_zs = use_witnesses(op_ys_zs); + let op_xs_op_ys_zs = bigint_solve_binary_op(bigint_mul_op(), modulus, xs, ys_zs); + + prop_assert_eq!(op_xs_ys_op_zs, op_xs_op_ys_zs) + } + + #[test] + fn bigint_mul_add_distributive((xs, ys, zs, modulus) in bigint_triple_with_modulus()) { + // xs * (ys + zs) == + let add_ys_zs = bigint_solve_binary_op(bigint_add_op(), modulus.clone(), ys.clone(), zs.clone()); + let add_ys_zs = use_witnesses(add_ys_zs); + let mul_xs_add_ys_zs = bigint_solve_binary_op(bigint_mul_op(), modulus.clone(), xs.clone(), add_ys_zs); + + // xs * ys + xs * zs + let mul_xs_ys = bigint_solve_binary_op(bigint_mul_op(), modulus.clone(), xs.clone(), ys); + let mul_xs_ys = use_witnesses(mul_xs_ys); + let mul_xs_zs = bigint_solve_binary_op(bigint_mul_op(), modulus.clone(), xs, zs); + let mul_xs_zs = use_witnesses(mul_xs_zs); + let add_mul_xs_ys_mul_xs_zs = bigint_solve_binary_op(bigint_add_op(), modulus, mul_xs_ys, mul_xs_zs); + + prop_assert_eq!(mul_xs_add_ys_zs, add_mul_xs_ys_mul_xs_zs) + } + + + #[test] + fn bigint_add_zero_l((xs, modulus) in bigint_with_modulus()) { + let zero = bigint_zeroed(&xs); + let expected_results = drop_use_constant(&xs); + let results = bigint_solve_binary_op(bigint_add_op(), modulus, zero, xs); + + prop_assert_eq!(results, expected_results) + } + + #[test] + fn bigint_mul_zero_l((xs, modulus) in bigint_with_modulus()) { + let zero = bigint_zeroed(&xs); + let expected_results = drop_use_constant(&zero); + let results = bigint_solve_binary_op(bigint_mul_op(), modulus, zero, xs); + prop_assert_eq!(results, expected_results) + } + + #[test] + fn bigint_mul_one_l((xs, modulus) in bigint_with_modulus()) { + let one = bigint_to_one(&xs); + let expected_results: Vec<_> = drop_use_constant(&xs); + let results = bigint_solve_binary_op(bigint_mul_op(), modulus, one, xs); + prop_assert_eq!(results, expected_results) + } + + #[test] + fn bigint_sub_self((xs, modulus) in bigint_with_modulus()) { + let expected_results = drop_use_constant(&bigint_zeroed(&xs)); + let results = bigint_solve_binary_op(bigint_sub_op(), modulus, xs.clone(), xs); + prop_assert_eq!(results, expected_results) + } + + #[test] + fn bigint_sub_zero((xs, modulus) in bigint_with_modulus()) { + let zero = bigint_zeroed(&xs); + let expected_results: Vec<_> = drop_use_constant(&xs); + let results = bigint_solve_binary_op(bigint_sub_op(), modulus, xs, zero); + prop_assert_eq!(results, expected_results) + } + + #[test] + fn bigint_sub_one((xs, modulus) in bigint_with_modulus()) { + let one = bigint_to_one(&xs); + let expected_results: Vec<_> = drop_use_constant(&xs); + let results = bigint_solve_binary_op(bigint_sub_op(), modulus, xs, one); + prop_assert!(results != expected_results, "{:?} == {:?}", results, expected_results) + } + + #[test] + fn bigint_div_self((xs, modulus) in bigint_with_modulus()) { + let one = drop_use_constant(&bigint_to_one(&xs)); + let results = bigint_solve_binary_op(bigint_div_op(), modulus, xs.clone(), xs); + prop_assert_eq!(results, one) + } + + #[test] + // TODO(https://github.com/noir-lang/noir/issues/5645) + fn bigint_div_by_zero((xs, modulus) in bigint_with_modulus()) { + let zero = bigint_zeroed(&xs); + let expected_results = drop_use_constant(&zero); + let results = bigint_solve_binary_op(bigint_div_op(), modulus, xs, zero); + prop_assert_eq!(results, expected_results) + } + + #[test] + fn bigint_div_one((xs, modulus) in bigint_with_modulus()) { + let one = bigint_to_one(&xs); + let expected_results = drop_use_constant(&xs); + let results = bigint_solve_binary_op(bigint_div_op(), modulus, xs, one); + prop_assert_eq!(results, expected_results) + } + + #[test] + fn bigint_div_zero((xs, modulus) in bigint_with_modulus()) { + let zero = bigint_zeroed(&xs); + let expected_results = drop_use_constant(&zero); + let results = bigint_solve_binary_op(bigint_div_op(), modulus, zero, xs); + prop_assert_eq!(results, expected_results) + } + + #[test] + fn bigint_add_sub((xs, ys, modulus) in bigint_pair_with_modulus()) { + let expected_results = drop_use_constant(&xs); + let add_results = bigint_solve_binary_op(bigint_add_op(), modulus.clone(), xs, ys.clone()); + let add_bigint = use_witnesses(add_results); + let results = bigint_solve_binary_op(bigint_sub_op(), modulus, add_bigint, ys); + + prop_assert_eq!(results, expected_results) + } + + #[test] + fn bigint_sub_add((xs, ys, modulus) in bigint_pair_with_modulus()) { + let expected_results = drop_use_constant(&xs); + let sub_results = bigint_solve_binary_op(bigint_sub_op(), modulus.clone(), xs, ys.clone()); + let add_bigint = use_witnesses(sub_results); + let results = bigint_solve_binary_op(bigint_add_op(), modulus, add_bigint, ys); + + prop_assert_eq!(results, expected_results) + } + + #[test] + fn bigint_div_mul((xs, ys, modulus) in bigint_pair_with_modulus()) { + let expected_results = drop_use_constant(&xs); + let div_results = bigint_solve_binary_op(bigint_div_op(), modulus.clone(), xs, ys.clone()); + let div_bigint = use_witnesses(div_results); + let results = bigint_solve_binary_op(bigint_mul_op(), modulus, div_bigint, ys); + + prop_assert_eq!(results, expected_results) + } + + #[test] + fn bigint_mul_div((xs, ys, modulus) in bigint_pair_with_modulus()) { + let expected_results = drop_use_constant(&xs); + let mul_results = bigint_solve_binary_op(bigint_mul_op(), modulus.clone(), xs, ys.clone()); + let mul_bigint = use_witnesses(mul_results); + let results = bigint_solve_binary_op(bigint_div_op(), modulus, mul_bigint, ys); + + prop_assert_eq!(results, expected_results) + } } diff --git a/acvm-repo/bn254_blackbox_solver/src/lib.rs b/acvm-repo/bn254_blackbox_solver/src/lib.rs index 6897116e90..43ee6a9ddd 100644 --- a/acvm-repo/bn254_blackbox_solver/src/lib.rs +++ b/acvm-repo/bn254_blackbox_solver/src/lib.rs @@ -13,7 +13,7 @@ mod schnorr; use ark_ec::AffineRepr; pub use embedded_curve_ops::{embedded_curve_add, multi_scalar_mul}; pub use generator::generators::derive_generators; -pub use poseidon2::poseidon2_permutation; +pub use poseidon2::{field_from_hex, poseidon2_permutation, Poseidon2Config, POSEIDON2_CONFIG}; // Temporary hack, this ensure that we always use a bn254 field here // without polluting the feature flags of the `acir_field` crate. diff --git a/acvm-repo/bn254_blackbox_solver/src/poseidon2.rs b/acvm-repo/bn254_blackbox_solver/src/poseidon2.rs index 18ed0b1d8a..dd3e8b725c 100644 --- a/acvm-repo/bn254_blackbox_solver/src/poseidon2.rs +++ b/acvm-repo/bn254_blackbox_solver/src/poseidon2.rs @@ -16,26 +16,26 @@ pub(crate) struct Poseidon2<'a> { config: &'a Poseidon2Config, } -struct Poseidon2Config { - t: u32, - rounds_f: u32, - rounds_p: u32, - internal_matrix_diagonal: [FieldElement; 4], - round_constant: [[FieldElement; 4]; 64], +pub struct Poseidon2Config { + pub t: u32, + pub rounds_f: u32, + pub rounds_p: u32, + pub internal_matrix_diagonal: [FieldElement; 4], + pub round_constant: [[FieldElement; 4]; 64], } -fn field_from_hex(hex: &str) -> FieldElement { +pub fn field_from_hex(hex: &str) -> FieldElement { FieldElement::from_be_bytes_reduce(&hex::decode(hex).expect("Should be passed only valid hex")) } lazy_static! { - static ref INTERNAL_MATRIX_DIAGONAL: [FieldElement; 4] = [ + pub static ref INTERNAL_MATRIX_DIAGONAL: [FieldElement; 4] = [ field_from_hex("10dc6e9c006ea38b04b1e03b4bd9490c0d03f98929ca1d7fb56821fd19d3b6e7"), field_from_hex("0c28145b6a44df3e0149b3d0a30b3bb599df9756d4dd9b84a86b38cfb45a740b"), field_from_hex("00544b8338791518b2c7645a50392798b21f75bb60e3596170067d00141cac15"), field_from_hex("222c01175718386f2e2e82eb122789e352e105a3b8fa852613bc534433ee428b"), ]; - static ref ROUND_CONSTANT: [[FieldElement; 4]; 64] = [ + pub static ref ROUND_CONSTANT: [[FieldElement; 4]; 64] = [ [ field_from_hex("19b849f69450b06848da1d39bd5e4a4302bb86744edc26238b0878e269ed23e5"), field_from_hex("265ddfe127dd51bd7239347b758f0a1320eb2cc7450acc1dad47f80c8dcf34d6"), @@ -421,7 +421,7 @@ lazy_static! { field_from_hex("176563472456aaa746b694c60e1823611ef39039b2edc7ff391e6f2293d2c404"), ], ]; - static ref POSEIDON2_CONFIG: Poseidon2Config = Poseidon2Config { + pub static ref POSEIDON2_CONFIG: Poseidon2Config = Poseidon2Config { t: 4, rounds_f: 8, rounds_p: 56, diff --git a/aztec_macros/src/transforms/note_interface.rs b/aztec_macros/src/transforms/note_interface.rs index 46ed75620a..8df1d128c6 100644 --- a/aztec_macros/src/transforms/note_interface.rs +++ b/aztec_macros/src/transforms/note_interface.rs @@ -88,6 +88,7 @@ pub fn generate_note_interface_impl( let mut note_fields = vec![]; let note_interface_generics = trait_impl .trait_generics + .ordered_args .iter() .map(|gen| match gen.typ.clone() { UnresolvedTypeData::Named(path, _, _) => Ok(path.last_name().to_string()), @@ -120,7 +121,7 @@ pub fn generate_note_interface_impl( ident("header"), make_type(UnresolvedTypeData::Named( chained_dep!("aztec", "note", "note_header", "NoteHeader"), - vec![], + Default::default(), false, )), ); diff --git a/aztec_macros/src/transforms/storage.rs b/aztec_macros/src/transforms/storage.rs index ce82b4d4b6..7dd21f1a8a 100644 --- a/aztec_macros/src/transforms/storage.rs +++ b/aztec_macros/src/transforms/storage.rs @@ -1,8 +1,9 @@ use acvm::acir::AcirField; use noirc_errors::Span; use noirc_frontend::ast::{ - BlockExpression, Expression, ExpressionKind, FunctionDefinition, Ident, Literal, NoirFunction, - NoirStruct, Pattern, StatementKind, TypeImpl, UnresolvedType, UnresolvedTypeData, + BlockExpression, Expression, ExpressionKind, FunctionDefinition, GenericTypeArgs, Ident, + Literal, NoirFunction, NoirStruct, Pattern, StatementKind, TypeImpl, UnresolvedType, + UnresolvedTypeData, }; use noirc_frontend::{ graph::CrateId, @@ -54,13 +55,13 @@ pub fn check_for_storage_definition( fn inject_context_in_storage_field(field: &mut UnresolvedType) -> Result<(), AztecMacroError> { match &mut field.typ { UnresolvedTypeData::Named(path, generics, _) => { - generics.push(make_type(UnresolvedTypeData::Named( + generics.ordered_args.push(make_type(UnresolvedTypeData::Named( ident_path("Context"), - vec![], + GenericTypeArgs::default(), false, ))); match path.last_name() { - "Map" => inject_context_in_storage_field(&mut generics[1]), + "Map" => inject_context_in_storage_field(&mut generics.ordered_args[1]), _ => Ok(()), } } @@ -144,7 +145,10 @@ pub fn generate_storage_field_constructor( generate_storage_field_constructor( // Map is expected to have three generic parameters: key, value and context (i.e. // Map. Here `get(1)` fetches the value type. - &(type_ident.clone(), generics.get(1).unwrap().clone()), + &( + type_ident.clone(), + generics.ordered_args.get(1).unwrap().clone(), + ), variable("slot"), )?, ), @@ -219,8 +223,11 @@ pub fn generate_storage_implementation( // This is the type over which the impl is generic. let generic_context_ident = ident("Context"); - let generic_context_type = - make_type(UnresolvedTypeData::Named(ident_path("Context"), vec![], true)); + let generic_context_type = make_type(UnresolvedTypeData::Named( + ident_path("Context"), + GenericTypeArgs::default(), + true, + )); let init = NoirFunction::normal(FunctionDefinition::normal( &ident("init"), @@ -231,13 +238,12 @@ pub fn generate_storage_implementation( &return_type(chained_path!("Self")), )); + let ordered_args = vec![generic_context_type.clone()]; + let generics = GenericTypeArgs { ordered_args, named_args: Vec::new() }; + let storage_impl = TypeImpl { object_type: UnresolvedType { - typ: UnresolvedTypeData::Named( - chained_path!(storage_struct_name), - vec![generic_context_type.clone()], - true, - ), + typ: UnresolvedTypeData::Named(chained_path!(storage_struct_name), generics, true), span: Span::default(), }, type_span: Span::default(), diff --git a/aztec_macros/src/utils/parse_utils.rs b/aztec_macros/src/utils/parse_utils.rs index 6b5db103c0..4c6cbb10d9 100644 --- a/aztec_macros/src/utils/parse_utils.rs +++ b/aztec_macros/src/utils/parse_utils.rs @@ -2,12 +2,13 @@ use noirc_frontend::{ ast::{ ArrayLiteral, AssignStatement, BlockExpression, CallExpression, CastExpression, ConstrainStatement, ConstructorExpression, Expression, ExpressionKind, ForLoopStatement, - ForRange, FunctionReturnType, Ident, IfExpression, IndexExpression, InfixExpression, - LValue, Lambda, LetStatement, Literal, MemberAccessExpression, MethodCallExpression, - ModuleDeclaration, NoirFunction, NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, Path, - PathSegment, Pattern, PrefixExpression, Statement, StatementKind, TraitImplItem, TraitItem, - TypeImpl, UnresolvedGeneric, UnresolvedGenerics, UnresolvedTraitConstraint, UnresolvedType, - UnresolvedTypeData, UnresolvedTypeExpression, UseTree, UseTreeKind, + ForRange, FunctionReturnType, GenericTypeArgs, Ident, IfExpression, IndexExpression, + InfixExpression, LValue, Lambda, LetStatement, Literal, MemberAccessExpression, + MethodCallExpression, ModuleDeclaration, NoirFunction, NoirStruct, NoirTrait, + NoirTraitImpl, NoirTypeAlias, Path, PathSegment, Pattern, PrefixExpression, Statement, + StatementKind, TraitImplItem, TraitItem, TypeImpl, UnresolvedGeneric, UnresolvedGenerics, + UnresolvedTraitConstraint, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, + UseTree, UseTreeKind, }, parser::{Item, ItemKind, ParsedSubModule, ParserError}, ParsedModule, @@ -297,6 +298,14 @@ fn empty_unresolved_types(unresolved_types: &mut [UnresolvedType]) { } } +fn empty_type_args(generics: &mut GenericTypeArgs) { + empty_unresolved_types(&mut generics.ordered_args); + for (name, typ) in &mut generics.named_args { + empty_ident(name); + empty_unresolved_type(typ); + } +} + fn empty_unresolved_type(unresolved_type: &mut UnresolvedType) { unresolved_type.span = Default::default(); @@ -318,11 +327,11 @@ fn empty_unresolved_type(unresolved_type: &mut UnresolvedType) { } UnresolvedTypeData::Named(path, unresolved_types, _) => { empty_path(path); - empty_unresolved_types(unresolved_types); + empty_type_args(unresolved_types); } UnresolvedTypeData::TraitAsType(path, unresolved_types) => { empty_path(path); - empty_unresolved_types(unresolved_types); + empty_type_args(unresolved_types); } UnresolvedTypeData::MutableReference(unresolved_type) => { empty_unresolved_type(unresolved_type) @@ -543,5 +552,10 @@ fn empty_unresolved_type_expression(unresolved_type_expression: &mut UnresolvedT empty_unresolved_type_expression(rhs); } UnresolvedTypeExpression::Constant(_, _) => (), + UnresolvedTypeExpression::AsTraitPath(path) => { + empty_unresolved_type(&mut path.typ); + empty_path(&mut path.trait_path); + empty_ident(&mut path.impl_item); + } } } diff --git a/compiler/noirc_driver/src/lib.rs b/compiler/noirc_driver/src/lib.rs index 467bda2ca8..cb3a4d25c9 100644 --- a/compiler/noirc_driver/src/lib.rs +++ b/compiler/noirc_driver/src/lib.rs @@ -123,6 +123,12 @@ pub struct CompileOptions { /// Temporary flag to enable the experimental arithmetic generics feature #[arg(long, hide = true)] pub arithmetic_generics: bool, + + /// Flag to turn off the compiler check for under constrained values. + /// Warning: This can improve compilation speed but can also lead to correctness errors. + /// This check should always be run on production code. + #[arg(long)] + pub skip_underconstrained_check: bool, } pub fn parse_expression_width(input: &str) -> Result { @@ -574,6 +580,7 @@ pub fn compile_no_check( ExpressionWidth::default() }, emit_ssa: if options.emit_ssa { Some(context.package_build_path.clone()) } else { None }, + skip_underconstrained_check: options.skip_underconstrained_check, }; let SsaProgramArtifact { program, debug, warnings, names, brillig_names, error_types, .. } = diff --git a/compiler/noirc_errors/src/debug_info.rs b/compiler/noirc_errors/src/debug_info.rs index 1a254175c0..b480d20fde 100644 --- a/compiler/noirc_errors/src/debug_info.rs +++ b/compiler/noirc_errors/src/debug_info.rs @@ -1,4 +1,5 @@ use acvm::acir::circuit::brillig::BrilligFunctionId; +use acvm::acir::circuit::BrilligOpcodeLocation; use acvm::acir::circuit::OpcodeLocation; use acvm::compiler::AcirTransformationMap; @@ -98,8 +99,8 @@ pub struct DebugInfo { /// that they should be serialized to/from strings. #[serde_as(as = "BTreeMap")] pub locations: BTreeMap>, - #[serde_as(as = "BTreeMap<_, BTreeMap>")] - pub brillig_locations: BTreeMap>>, + pub brillig_locations: + BTreeMap>>, pub variables: DebugVariables, pub functions: DebugFunctions, pub types: DebugTypes, @@ -116,7 +117,10 @@ pub struct OpCodesCount { impl DebugInfo { pub fn new( locations: BTreeMap>, - brillig_locations: BTreeMap>>, + brillig_locations: BTreeMap< + BrilligFunctionId, + BTreeMap>, + >, variables: DebugVariables, functions: DebugFunctions, types: DebugTypes, diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs index b4e55a52a3..54dbd4716a 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs @@ -735,7 +735,7 @@ impl<'block> BrilligBlock<'block> { 1, ); } - Instruction::EnableSideEffects { .. } => { + Instruction::EnableSideEffectsIf { .. } => { todo!("enable_side_effects not supported by brillig") } Instruction::IfElse { .. } => { diff --git a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs index fca1f60544..c17088a5d8 100644 --- a/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs +++ b/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_directive.rs @@ -63,7 +63,7 @@ pub(crate) fn directive_invert() -> GeneratedBrillig { /// /// This is equivalent to the Noir (pseudo)code /// -/// ```ignore +/// ```text /// fn quotient(a: T, b: T) -> (T,T) { /// (a/b, a-a/b*b) /// } diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index 9daf98e606..2d138c13f7 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -64,6 +64,9 @@ pub struct SsaEvaluatorOptions { /// Dump the unoptimized SSA to the supplied path if it exists pub emit_ssa: Option, + + /// Skip the check for under constrained values + pub skip_underconstrained_check: bool, } pub(crate) struct ArtifactsAndWarnings(Artifacts, Vec); @@ -111,13 +114,19 @@ pub(crate) fn optimize_into_acir( .run_pass(Ssa::inline_functions_with_no_predicates, "After Inlining:") .run_pass(Ssa::remove_if_else, "After Remove IfElse:") .run_pass(Ssa::fold_constants, "After Constant Folding:") - .run_pass(Ssa::remove_enable_side_effects, "After EnableSideEffects removal:") + .run_pass(Ssa::remove_enable_side_effects, "After EnableSideEffectsIf removal:") .run_pass(Ssa::fold_constants_using_constraints, "After Constraint Folding:") .run_pass(Ssa::dead_instruction_elimination, "After Dead Instruction Elimination:") .run_pass(Ssa::array_set_optimization, "After Array Set Optimizations:") .finish(); - let ssa_level_warnings = ssa.check_for_underconstrained_values(); + let ssa_level_warnings = if options.skip_underconstrained_check { + vec![] + } else { + time("After Check for Underconstrained Values", options.print_codegen_timings, || { + ssa.check_for_underconstrained_values() + }) + }; let brillig = time("SSA to Brillig", options.print_codegen_timings, || { ssa.to_brillig(options.enable_brillig_logging) }); diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs index 70e6c923ce..a27354d7cb 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs @@ -492,7 +492,7 @@ impl AcirContext { self.sub_var(sum, mul) } else { // Implement OR in terms of AND - // (NOT a) NAND (NOT b) => a OR b + // (NOT a) AND (NOT b) => NOT (a OR b) let a = self.not_var(lhs, typ.clone())?; let b = self.not_var(rhs, typ.clone())?; let a_and_b = self.and_var(a, b, typ.clone())?; @@ -1957,6 +1957,9 @@ impl AcirContext { } Some(optional_value) => { let mut values = Vec::new(); + if let AcirValue::DynamicArray(_) = optional_value { + unreachable!("Dynamic array should already be initialized"); + } self.initialize_array_inner(&mut values, optional_value)?; values } @@ -1986,8 +1989,16 @@ impl AcirContext { self.initialize_array_inner(witnesses, value)?; } } - AcirValue::DynamicArray(_) => { - unreachable!("Dynamic array should already be initialized"); + AcirValue::DynamicArray(AcirDynamicArray { block_id, len, .. }) => { + let dynamic_array_values = try_vecmap(0..len, |i| { + let index_var = self.add_constant(i); + + let read = self.read_from_memory(block_id, &index_var)?; + Ok::(AcirValue::Var(read, AcirType::field())) + })?; + for value in dynamic_array_values { + self.initialize_array_inner(witnesses, value)?; + } } } Ok(()) diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs index 2e61a82d5b..0cad7b9c97 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs @@ -1,6 +1,6 @@ //! `GeneratedAcir` is constructed as part of the `acir_gen` pass to accumulate all of the ACIR //! program as it is being converted from SSA form. -use std::collections::BTreeMap; +use std::{collections::BTreeMap, u32}; use crate::{ brillig::{brillig_gen::brillig_directive, brillig_ir::artifact::GeneratedBrillig}, @@ -11,7 +11,7 @@ use acvm::acir::{ circuit::{ brillig::{BrilligFunctionId, BrilligInputs, BrilligOutputs}, opcodes::{BlackBoxFuncCall, FunctionInput, Opcode as AcirOpcode}, - AssertionPayload, OpcodeLocation, + AssertionPayload, BrilligOpcodeLocation, OpcodeLocation, }, native_types::Witness, BlackBoxFunc, @@ -53,7 +53,7 @@ pub(crate) struct GeneratedAcir { /// Brillig function id -> Opcodes locations map /// This map is used to prevent redundant locations being stored for the same Brillig entry point. - pub(crate) brillig_locations: BTreeMap, + pub(crate) brillig_locations: BTreeMap, /// Source code location of the current instruction being processed /// None if we do not know the location @@ -77,6 +77,8 @@ pub(crate) struct GeneratedAcir { /// Correspondence between an opcode index (in opcodes) and the source code call stack which generated it pub(crate) type OpcodeToLocationsMap = BTreeMap; +pub(crate) type BrilligOpcodeToLocationsMap = BTreeMap; + #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] pub(crate) enum BrilligStdlibFunc { Inverse, @@ -591,6 +593,7 @@ impl GeneratedAcir { return; } + // TODO(https://github.com/noir-lang/noir/issues/5792) for (brillig_index, message) in generated_brillig.assert_messages.iter() { self.assertion_payloads.insert( OpcodeLocation::Brillig { @@ -606,13 +609,10 @@ impl GeneratedAcir { } for (brillig_index, call_stack) in generated_brillig.locations.iter() { - self.brillig_locations.entry(brillig_function_index).or_default().insert( - OpcodeLocation::Brillig { - acir_index: self.opcodes.len() - 1, - brillig_index: *brillig_index, - }, - call_stack.clone(), - ); + self.brillig_locations + .entry(brillig_function_index) + .or_default() + .insert(BrilligOpcodeLocation(*brillig_index), call_stack.clone()); } } @@ -626,6 +626,7 @@ impl GeneratedAcir { OpcodeLocation::Acir(index) => index, _ => panic!("should not have brillig index"), }; + match &mut self.opcodes[acir_index] { AcirOpcode::BrilligCall { id, .. } => *id = brillig_function_index, _ => panic!("expected brillig call opcode"), diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs index 346d6fcd92..37ec43fb13 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs @@ -706,7 +706,7 @@ impl<'a> Context<'a> { self.convert_ssa_truncate(*value, *bit_size, *max_bit_size, dfg)?; self.define_result_var(dfg, instruction_id, result_acir_var); } - Instruction::EnableSideEffects { condition } => { + Instruction::EnableSideEffectsIf { condition } => { let acir_var = self.convert_numeric_value(*condition, dfg)?; self.current_side_effects_enabled_var = acir_var; } @@ -1165,11 +1165,15 @@ impl<'a> Context<'a> { let index_var = self.convert_numeric_value(index, dfg)?; let index_var = self.get_flattened_index(&array_typ, array_id, index_var, dfg)?; - // predicate_index = index*predicate + (1-predicate)*offset - let offset = self.acir_context.add_constant(offset); - let sub = self.acir_context.sub_var(index_var, offset)?; - let pred = self.acir_context.mul_var(sub, self.current_side_effects_enabled_var)?; - let predicate_index = self.acir_context.add_var(pred, offset)?; + let predicate_index = if dfg.is_safe_index(index, array_id) { + index_var + } else { + // index*predicate + (1-predicate)*offset + let offset = self.acir_context.add_constant(offset); + let sub = self.acir_context.sub_var(index_var, offset)?; + let pred = self.acir_context.mul_var(sub, self.current_side_effects_enabled_var)?; + self.acir_context.add_var(pred, offset)? + }; let new_value = if let Some(store) = store_value { let store_value = self.convert_value(store, dfg); diff --git a/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs b/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs index 24fcb8f61d..79db4e645e 100644 --- a/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs +++ b/compiler/noirc_evaluator/src/ssa/checks/check_for_underconstrained_values.rs @@ -255,7 +255,7 @@ impl Context { } Instruction::Allocate { .. } | Instruction::DecrementRc { .. } - | Instruction::EnableSideEffects { .. } + | Instruction::EnableSideEffectsIf { .. } | Instruction::IncrementRc { .. } | Instruction::RangeCheck { .. } => {} } diff --git a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs index 49184bf4c6..8cc42241d9 100644 --- a/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/function_builder/mod.rs @@ -337,7 +337,7 @@ impl FunctionBuilder { /// Insert an enable_side_effects_if instruction. These are normally only automatically /// inserted during the flattening pass when branching is removed. pub(crate) fn insert_enable_side_effects_if(&mut self, condition: ValueId) { - self.insert_instruction(Instruction::EnableSideEffects { condition }, None); + self.insert_instruction(Instruction::EnableSideEffectsIf { condition }, None); } /// Terminates the current block with the given terminator instruction diff --git a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs index f06f46d7af..d79916a9e1 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/dfg.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/dfg.rs @@ -472,6 +472,14 @@ impl DataFlowGraph { } } + /// A constant index less than the array length is safe + pub(crate) fn is_safe_index(&self, index: ValueId, array: ValueId) -> bool { + #[allow(clippy::match_like_matches_macro)] + match (self.type_of_value(array), self.get_numeric_constant(index)) { + (Type::Array(_, len), Some(index)) if index.to_u128() < (len as u128) => true, + _ => false, + } + } /// Sets the terminator instruction for the given basic block pub(crate) fn set_block_terminator( &mut self, diff --git a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs index 1b3466c76f..c3cd27bf17 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/instruction.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/instruction.rs @@ -218,13 +218,23 @@ pub(crate) enum Instruction { Store { address: ValueId, value: ValueId }, /// Provides a context for all instructions that follow up until the next - /// `EnableSideEffects` is encountered, for stating a condition that determines whether + /// `EnableSideEffectsIf` is encountered, for stating a condition that determines whether /// such instructions are allowed to have side-effects. /// + /// For example, + /// ```text + /// EnableSideEffectsIf condition0; + /// code0; + /// EnableSideEffectsIf condition1; + /// code1; + /// ``` + /// - `code0` will have side effects iff `condition0` evaluates to `true` + /// - `code1` will have side effects iff `condition1` evaluates to `true` + /// /// This instruction is only emitted after the cfg flattening pass, and is used to annotate /// instruction regions with an condition that corresponds to their position in the CFG's /// if-branching structure. - EnableSideEffects { condition: ValueId }, + EnableSideEffectsIf { condition: ValueId }, /// Retrieve a value from an array at the given index ArrayGet { array: ValueId, index: ValueId }, @@ -249,6 +259,17 @@ pub(crate) enum Instruction { DecrementRc { value: ValueId }, /// Merge two values returned from opposite branches of a conditional into one. + /// + /// ```text + /// if then_condition { + /// then_value + /// } else { // else_condition = !then_condition + /// else_value + /// } + /// ``` + /// + /// Where we save the result of !then_condition so that we have the same + /// ValueId for it each time. IfElse { then_condition: ValueId, then_value: ValueId, @@ -279,7 +300,7 @@ impl Instruction { | Instruction::IncrementRc { .. } | Instruction::DecrementRc { .. } | Instruction::RangeCheck { .. } - | Instruction::EnableSideEffects { .. } => InstructionResultType::None, + | Instruction::EnableSideEffectsIf { .. } => InstructionResultType::None, Instruction::Allocate { .. } | Instruction::Load { .. } | Instruction::ArrayGet { .. } @@ -306,7 +327,7 @@ impl Instruction { match self { // These either have side-effects or interact with memory - EnableSideEffects { .. } + EnableSideEffectsIf { .. } | Allocate | Load { .. } | Store { .. } @@ -362,7 +383,7 @@ impl Instruction { Constrain(..) | Store { .. } - | EnableSideEffects { .. } + | EnableSideEffectsIf { .. } | IncrementRc { .. } | DecrementRc { .. } | RangeCheck { .. } => false, @@ -396,16 +417,12 @@ impl Instruction { true } - // `ArrayGet`s which read from "known good" indices from an array don't need a predicate. Instruction::ArrayGet { array, index } => { - #[allow(clippy::match_like_matches_macro)] - match (dfg.type_of_value(*array), dfg.get_numeric_constant(*index)) { - (Type::Array(_, len), Some(index)) if index.to_u128() < (len as u128) => false, - _ => true, - } + // `ArrayGet`s which read from "known good" indices from an array should not need a predicate. + !dfg.is_safe_index(*index, *array) } - Instruction::EnableSideEffects { .. } | Instruction::ArraySet { .. } => true, + Instruction::EnableSideEffectsIf { .. } | Instruction::ArraySet { .. } => true, Instruction::Call { func, .. } => match dfg[*func] { Value::Function(_) => true, @@ -470,8 +487,8 @@ impl Instruction { Instruction::Store { address, value } => { Instruction::Store { address: f(*address), value: f(*value) } } - Instruction::EnableSideEffects { condition } => { - Instruction::EnableSideEffects { condition: f(*condition) } + Instruction::EnableSideEffectsIf { condition } => { + Instruction::EnableSideEffectsIf { condition: f(*condition) } } Instruction::ArrayGet { array, index } => { Instruction::ArrayGet { array: f(*array), index: f(*index) } @@ -545,7 +562,7 @@ impl Instruction { f(*index); f(*value); } - Instruction::EnableSideEffects { condition } => { + Instruction::EnableSideEffectsIf { condition } => { f(*condition); } Instruction::IncrementRc { value } @@ -682,11 +699,11 @@ impl Instruction { Instruction::Call { func, arguments } => { simplify_call(*func, arguments, dfg, block, ctrl_typevars, call_stack) } - Instruction::EnableSideEffects { condition } => { + Instruction::EnableSideEffectsIf { condition } => { if let Some(last) = dfg[block].instructions().last().copied() { let last = &mut dfg[last]; - if matches!(last, Instruction::EnableSideEffects { .. }) { - *last = Instruction::EnableSideEffects { condition: *condition }; + if matches!(last, Instruction::EnableSideEffectsIf { .. }) { + *last = Instruction::EnableSideEffectsIf { condition: *condition }; return Remove; } } diff --git a/compiler/noirc_evaluator/src/ssa/ir/printer.rs b/compiler/noirc_evaluator/src/ssa/ir/printer.rs index 656bd26620..e8c9d01988 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/printer.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/printer.rs @@ -176,7 +176,7 @@ fn display_instruction_inner( Instruction::Store { address, value } => { writeln!(f, "store {} at {}", show(*value), show(*address)) } - Instruction::EnableSideEffects { condition } => { + Instruction::EnableSideEffectsIf { condition } => { writeln!(f, "enable_side_effects {}", show(*condition)) } Instruction::ArrayGet { array, index } => { diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 160105d27e..c8f6d201d8 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -159,9 +159,9 @@ impl Context { *side_effects_enabled_var, ); - // If we just inserted an `Instruction::EnableSideEffects`, we need to update `side_effects_enabled_var` + // If we just inserted an `Instruction::EnableSideEffectsIf`, we need to update `side_effects_enabled_var` // so that we use the correct set of constrained values in future. - if let Instruction::EnableSideEffects { condition } = instruction { + if let Instruction::EnableSideEffectsIf { condition } = instruction { *side_effects_enabled_var = condition; }; } diff --git a/compiler/noirc_evaluator/src/ssa/opt/die.rs b/compiler/noirc_evaluator/src/ssa/opt/die.rs index 1aa0c2efbd..b980406211 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/die.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/die.rs @@ -245,7 +245,7 @@ impl Context { let instruction_id = *instruction_id; let instruction = &function.dfg[instruction_id]; - if let Instruction::EnableSideEffects { condition } = instruction { + if let Instruction::EnableSideEffectsIf { condition } = instruction { side_effects_condition = Some(*condition); // We still need to keep the EnableSideEffects instruction diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs index 288e41cb99..72ed02b00a 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg.rs @@ -11,7 +11,7 @@ //! elimination (DIE) pass. //! //! Though CFG information is lost during this pass, some key information is retained in the form -//! of `EnableSideEffect` instructions. Each time the flattening pass enters and exits a branch of +//! of `EnableSideEffectsIf` instructions. Each time the flattening pass enters and exits a branch of //! a jmpif, an instruction is inserted to capture a condition that is analogous to the activeness //! of the program point. For example: //! @@ -573,7 +573,7 @@ impl<'f> Context<'f> { } /// Checks the branch condition on the top of the stack and uses it to build and insert an - /// `EnableSideEffects` instruction into the entry block. + /// `EnableSideEffectsIf` instruction into the entry block. /// /// If the stack is empty, a "true" u1 constant is taken to be the active condition. This is /// necessary for re-enabling side-effects when re-emerging to a branch depth of 0. @@ -584,7 +584,7 @@ impl<'f> Context<'f> { self.inserter.function.dfg.make_constant(FieldElement::one(), Type::unsigned(1)) } }; - let enable_side_effects = Instruction::EnableSideEffects { condition }; + let enable_side_effects = Instruction::EnableSideEffectsIf { condition }; let call_stack = self.inserter.function.dfg.get_value_call_stack(condition); self.insert_instruction_with_typevars(enable_side_effects, None, call_stack); } diff --git a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs index 90e24a1d5e..7c2db62b0e 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/flatten_cfg/value_merger.rs @@ -374,7 +374,7 @@ impl<'a> ValueMerger<'a> { for (index, element_type, condition) in changed_indices { let typevars = Some(vec![element_type.clone()]); - let instruction = Instruction::EnableSideEffects { condition }; + let instruction = Instruction::EnableSideEffectsIf { condition }; self.insert_instruction(instruction); let mut get_element = |array, typevars| { @@ -398,7 +398,7 @@ impl<'a> ValueMerger<'a> { array = self.insert_array_set(array, index, value, Some(condition)).first(); } - let instruction = Instruction::EnableSideEffects { condition: current_condition }; + let instruction = Instruction::EnableSideEffectsIf { condition: current_condition }; self.insert_instruction(instruction); Some(array) } diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index d78399a3e6..1ff593a153 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -502,7 +502,7 @@ impl<'function> PerFunctionContext<'function> { } None => self.push_instruction(*id), }, - Instruction::EnableSideEffects { condition } => { + Instruction::EnableSideEffectsIf { condition } => { side_effects_enabled = Some(self.translate_value(*condition)); self.push_instruction(*id); } diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs index 224060e131..a56786b260 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_enable_side_effects.rs @@ -1,13 +1,13 @@ -//! The goal of the "remove enable side effects" optimization pass is to delay any [Instruction::EnableSideEffects] +//! The goal of the "remove enable side effects" optimization pass is to delay any [Instruction::EnableSideEffectsIf] //! instructions such that they cover the minimum number of instructions possible. //! //! The pass works as follows: -//! - Insert instructions until an [Instruction::EnableSideEffects] is encountered, save this [InstructionId]. +//! - Insert instructions until an [Instruction::EnableSideEffectsIf] is encountered, save this [InstructionId]. //! - Continue inserting instructions until either -//! - Another [Instruction::EnableSideEffects] is encountered, if so then drop the previous [InstructionId] in favour +//! - Another [Instruction::EnableSideEffectsIf] is encountered, if so then drop the previous [InstructionId] in favour //! of this one. -//! - An [Instruction] with side-effects is encountered, if so then insert the currently saved [Instruction::EnableSideEffects] -//! before the [Instruction]. Continue inserting instructions until the next [Instruction::EnableSideEffects] is encountered. +//! - An [Instruction] with side-effects is encountered, if so then insert the currently saved [Instruction::EnableSideEffectsIf] +//! before the [Instruction]. Continue inserting instructions until the next [Instruction::EnableSideEffectsIf] is encountered. use std::collections::HashSet; use acvm::{acir::AcirField, FieldElement}; @@ -70,10 +70,10 @@ impl Context { for instruction_id in instructions { let instruction = &function.dfg[instruction_id]; - // If we run into another `Instruction::EnableSideEffects` before encountering any + // If we run into another `Instruction::EnableSideEffectsIf` before encountering any // instructions with side effects then we can drop the instruction we're holding and - // continue with the new `Instruction::EnableSideEffects`. - if let Instruction::EnableSideEffects { condition } = instruction { + // continue with the new `Instruction::EnableSideEffectsIf`. + if let Instruction::EnableSideEffectsIf { condition } = instruction { // If this instruction isn't changing the currently active condition then we can ignore it. if active_condition == *condition { continue; @@ -98,7 +98,7 @@ impl Context { } // If we hit an instruction which is affected by the side effects var then we must insert the - // `Instruction::EnableSideEffects` before we insert this new instruction. + // `Instruction::EnableSideEffectsIf` before we insert this new instruction. if Self::responds_to_side_effects_var(&function.dfg, instruction) { if let Some(enable_side_effects_instruction_id) = last_side_effects_enabled_instruction.take() @@ -140,7 +140,7 @@ impl Context { | IncrementRc { .. } | DecrementRc { .. } => false, - EnableSideEffects { .. } + EnableSideEffectsIf { .. } | ArrayGet { .. } | ArraySet { .. } | Allocate diff --git a/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs b/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs index b1ca5fa25a..cc02faeb3d 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/remove_if_else.rs @@ -128,7 +128,7 @@ impl Context { self.slice_sizes.insert(result, old_capacity); function.dfg[block].instructions_mut().push(instruction); } - Instruction::EnableSideEffects { condition } => { + Instruction::EnableSideEffectsIf { condition } => { current_conditional = *condition; function.dfg[block].instructions_mut().push(instruction); } diff --git a/compiler/noirc_frontend/src/ast/mod.rs b/compiler/noirc_frontend/src/ast/mod.rs index 781630571e..6f6d5cbccd 100644 --- a/compiler/noirc_frontend/src/ast/mod.rs +++ b/compiler/noirc_frontend/src/ast/mod.rs @@ -112,10 +112,10 @@ pub enum UnresolvedTypeData { Parenthesized(Box), /// A Named UnresolvedType can be a struct type or a type variable - Named(Path, Vec, /*is_synthesized*/ bool), + Named(Path, GenericTypeArgs, /*is_synthesized*/ bool), /// A Trait as return type or parameter of function, including its generics - TraitAsType(Path, Vec), + TraitAsType(Path, GenericTypeArgs), /// &mut T MutableReference(Box), @@ -151,6 +151,46 @@ pub struct UnresolvedType { pub span: Span, } +/// An argument to a generic type or trait. +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub enum GenericTypeArg { + /// An ordered argument, e.g. `` + Ordered(UnresolvedType), + + /// A named argument, e.g. ``. + /// Used for associated types. + Named(Ident, UnresolvedType), +} + +#[derive(Debug, Default, PartialEq, Eq, Clone, Hash)] +pub struct GenericTypeArgs { + /// Each ordered argument, e.g. `` + pub ordered_args: Vec, + + /// All named arguments, e.g. ``. + /// Used for associated types. + pub named_args: Vec<(Ident, UnresolvedType)>, +} + +impl GenericTypeArgs { + pub fn is_empty(&self) -> bool { + self.ordered_args.is_empty() && self.named_args.is_empty() + } +} + +impl From> for GenericTypeArgs { + fn from(args: Vec) -> Self { + let mut this = GenericTypeArgs::default(); + for arg in args { + match arg { + GenericTypeArg::Ordered(typ) => this.ordered_args.push(typ), + GenericTypeArg::Named(name, typ) => this.named_args.push((name, typ)), + } + } + this + } +} + /// Type wrapper for a member access pub struct UnaryRhsMemberAccess { pub method_or_field: Ident, @@ -176,6 +216,7 @@ pub enum UnresolvedTypeExpression { Box, Span, ), + AsTraitPath(Box), } impl Recoverable for UnresolvedType { @@ -184,6 +225,32 @@ impl Recoverable for UnresolvedType { } } +impl std::fmt::Display for GenericTypeArg { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + GenericTypeArg::Ordered(typ) => typ.fmt(f), + GenericTypeArg::Named(name, typ) => write!(f, "{name} = {typ}"), + } + } +} + +impl std::fmt::Display for GenericTypeArgs { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if self.is_empty() { + Ok(()) + } else { + let mut args = vecmap(&self.ordered_args, ToString::to_string).join(", "); + + if !self.ordered_args.is_empty() && !self.named_args.is_empty() { + args += ", "; + } + + args += &vecmap(&self.named_args, |(name, typ)| format!("{name} = {typ}")).join(", "); + write!(f, "<{args}>") + } + } +} + impl std::fmt::Display for UnresolvedTypeData { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { use UnresolvedTypeData::*; @@ -195,22 +262,8 @@ impl std::fmt::Display for UnresolvedTypeData { Signedness::Signed => write!(f, "i{num_bits}"), Signedness::Unsigned => write!(f, "u{num_bits}"), }, - Named(s, args, _) => { - let args = vecmap(args, |arg| ToString::to_string(&arg.typ)); - if args.is_empty() { - write!(f, "{s}") - } else { - write!(f, "{}<{}>", s, args.join(", ")) - } - } - TraitAsType(s, args) => { - let args = vecmap(args, |arg| ToString::to_string(&arg.typ)); - if args.is_empty() { - write!(f, "impl {s}") - } else { - write!(f, "impl {}<{}>", s, args.join(", ")) - } - } + Named(s, args, _) => write!(f, "{s}{args}"), + TraitAsType(s, args) => write!(f, "impl {s}{args}"), Tuple(elements) => { let elements = vecmap(elements, ToString::to_string); write!(f, "({})", elements.join(", ")) @@ -263,6 +316,7 @@ impl std::fmt::Display for UnresolvedTypeExpression { UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, _) => { write!(f, "({lhs} {op} {rhs})") } + UnresolvedTypeExpression::AsTraitPath(path) => write!(f, "{path}"), } } } @@ -334,6 +388,9 @@ impl UnresolvedTypeExpression { UnresolvedTypeExpression::Variable(path) => path.span(), UnresolvedTypeExpression::Constant(_, span) => *span, UnresolvedTypeExpression::BinaryOperation(_, _, _, span) => *span, + UnresolvedTypeExpression::AsTraitPath(path) => { + path.trait_path.span.merge(path.impl_item.span()) + } } } @@ -376,6 +433,9 @@ impl UnresolvedTypeExpression { }; Ok(UnresolvedTypeExpression::BinaryOperation(lhs, op, rhs, expr.span)) } + ExpressionKind::AsTraitPath(path) => { + Ok(UnresolvedTypeExpression::AsTraitPath(Box::new(path))) + } _ => Err(expr), } } diff --git a/compiler/noirc_frontend/src/ast/statement.rs b/compiler/noirc_frontend/src/ast/statement.rs index 1803319707..edccf545a0 100644 --- a/compiler/noirc_frontend/src/ast/statement.rs +++ b/compiler/noirc_frontend/src/ast/statement.rs @@ -7,8 +7,8 @@ use iter_extended::vecmap; use noirc_errors::{Span, Spanned}; use super::{ - BlockExpression, Expression, ExpressionKind, IndexExpression, MemberAccessExpression, - MethodCallExpression, UnresolvedType, + BlockExpression, Expression, ExpressionKind, GenericTypeArgs, IndexExpression, + MemberAccessExpression, MethodCallExpression, UnresolvedType, }; use crate::elaborator::types::SELF_TYPE_NAME; use crate::lexer::token::SpannedToken; @@ -371,6 +371,7 @@ impl UseTree { pub struct AsTraitPath { pub typ: UnresolvedType, pub trait_path: Path, + pub trait_generics: GenericTypeArgs, pub impl_item: Ident, } diff --git a/compiler/noirc_frontend/src/ast/traits.rs b/compiler/noirc_frontend/src/ast/traits.rs index f8f8ef667b..e3221f287d 100644 --- a/compiler/noirc_frontend/src/ast/traits.rs +++ b/compiler/noirc_frontend/src/ast/traits.rs @@ -10,6 +10,8 @@ use crate::ast::{ use crate::macros_api::SecondaryAttribute; use crate::node_interner::TraitId; +use super::GenericTypeArgs; + /// AST node for trait definitions: /// `trait name { ... items ... }` #[derive(Clone, Debug)] @@ -62,7 +64,8 @@ pub struct NoirTraitImpl { pub impl_generics: UnresolvedGenerics, pub trait_name: Path, - pub trait_generics: Vec, + + pub trait_generics: GenericTypeArgs, pub object_type: UnresolvedType, @@ -88,7 +91,7 @@ pub struct UnresolvedTraitConstraint { pub struct TraitBound { pub trait_path: Path, pub trait_id: Option, // initially None, gets assigned during DC - pub trait_generics: Vec, + pub trait_generics: GenericTypeArgs, } #[derive(Clone, Debug)] @@ -179,21 +182,13 @@ impl Display for UnresolvedTraitConstraint { impl Display for TraitBound { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let generics = vecmap(&self.trait_generics, |generic| generic.to_string()); - if !generics.is_empty() { - write!(f, "{}<{}>", self.trait_path, generics.join(", ")) - } else { - write!(f, "{}", self.trait_path) - } + write!(f, "{}{}", self.trait_path, self.trait_generics) } } impl Display for NoirTraitImpl { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let generics = vecmap(&self.trait_generics, |generic| generic.to_string()); - let generics = generics.join(", "); - - writeln!(f, "impl {}<{}> for {} {{", self.trait_name, generics, self.object_type)?; + writeln!(f, "impl {}{} for {} {{", self.trait_name, self.trait_generics, self.object_type)?; for item in self.items.iter() { let item = item.to_string(); diff --git a/compiler/noirc_frontend/src/elaborator/comptime.rs b/compiler/noirc_frontend/src/elaborator/comptime.rs index 2b78c02e53..01b4585640 100644 --- a/compiler/noirc_frontend/src/elaborator/comptime.rs +++ b/compiler/noirc_frontend/src/elaborator/comptime.rs @@ -284,13 +284,14 @@ impl<'context> Elaborator<'context> { }); } TopLevelStatement::TraitImpl(mut trait_impl) => { - let methods = dc_mod::collect_trait_impl_functions( - self.interner, - &mut trait_impl, - self.crate_id, - self.file, - self.local_module, - ); + let (methods, associated_types, associated_constants) = + dc_mod::collect_trait_impl_items( + self.interner, + &mut trait_impl, + self.crate_id, + self.file, + self.local_module, + ); generated_items.trait_impls.push(UnresolvedTraitImpl { file_id: self.file, @@ -301,6 +302,8 @@ impl<'context> Elaborator<'context> { methods, generics: trait_impl.impl_generics, where_clause: trait_impl.where_clause, + associated_types, + associated_constants, // These last fields are filled in later trait_id: None, diff --git a/compiler/noirc_frontend/src/elaborator/expressions.rs b/compiler/noirc_frontend/src/elaborator/expressions.rs index 65e94c4fcf..cf0b4f4071 100644 --- a/compiler/noirc_frontend/src/elaborator/expressions.rs +++ b/compiler/noirc_frontend/src/elaborator/expressions.rs @@ -11,7 +11,7 @@ use crate::{ hir::{ comptime::{self, InterpreterError}, resolution::errors::ResolverError, - type_check::TypeCheckError, + type_check::{generics::TraitGenerics, TypeCheckError}, }, hir_def::{ expr::{ @@ -397,7 +397,7 @@ impl<'context> Elaborator<'context> { // so that the backend doesn't need to worry about methods // TODO: update object_type here? let ((function_id, function_name), function_call) = method_call.into_function_call( - &method_ref, + method_ref, object_type, is_macro_call, location, @@ -620,7 +620,7 @@ impl<'context> Elaborator<'context> { let constraint = TraitConstraint { typ: operand_type.clone(), trait_id: trait_id.trait_id, - trait_generics: Vec::new(), + trait_generics: TraitGenerics::default(), span, }; self.push_trait_constraint(constraint, expr_id); diff --git a/compiler/noirc_frontend/src/elaborator/lints.rs b/compiler/noirc_frontend/src/elaborator/lints.rs index a4140043ac..78df10fa94 100644 --- a/compiler/noirc_frontend/src/elaborator/lints.rs +++ b/compiler/noirc_frontend/src/elaborator/lints.rs @@ -236,9 +236,9 @@ pub(crate) fn overflowing_int( }, HirExpression::Prefix(expr) => { overflowing_int(interner, &expr.rhs, annotated_type); - if expr.operator == UnaryOp::Minus { + if expr.operator == UnaryOp::Minus && annotated_type.is_unsigned() { errors.push(TypeCheckError::InvalidUnaryOp { - kind: "annotated_type".to_string(), + kind: annotated_type.to_string(), span, }); } diff --git a/compiler/noirc_frontend/src/elaborator/mod.rs b/compiler/noirc_frontend/src/elaborator/mod.rs index bba87f9a93..53b4653607 100644 --- a/compiler/noirc_frontend/src/elaborator/mod.rs +++ b/compiler/noirc_frontend/src/elaborator/mod.rs @@ -4,7 +4,7 @@ use std::{ }; use crate::{ - ast::{FunctionKind, UnresolvedTraitConstraint}, + ast::{FunctionKind, GenericTypeArgs, UnresolvedTraitConstraint}, hir::{ def_collector::dc_crate::{ filter_literal_globals, CompilationError, ImplMap, UnresolvedGlobal, UnresolvedStruct, @@ -13,7 +13,7 @@ use crate::{ def_map::DefMaps, resolution::{errors::ResolverError, path_resolver::PathResolver}, scope::ScopeForest as GenericScopeForest, - type_check::TypeCheckError, + type_check::{generics::TraitGenerics, TypeCheckError}, }, hir_def::{ expr::{HirCapturedVar, HirIdent}, @@ -40,7 +40,7 @@ use crate::{ Context, }, hir_def::function::{FuncMeta, HirFunction}, - macros_api::{Param, Path, UnresolvedType, UnresolvedTypeData}, + macros_api::{Param, Path, UnresolvedTypeData}, node_interner::TraitImplId, }; use crate::{ @@ -122,6 +122,9 @@ pub struct Elaborator<'context> { /// to the corresponding trait impl ID. current_trait_impl: Option, + /// The trait we're currently resolving, if we are resolving one. + current_trait: Option, + /// In-resolution names /// /// This needs to be a set because we can have multiple in-resolution @@ -210,6 +213,7 @@ impl<'context> Elaborator<'context> { debug_comptime_in_file, unresolved_globals: BTreeMap::new(), enable_arithmetic_generics, + current_trait: None, } } @@ -486,7 +490,8 @@ impl<'context> Elaborator<'context> { self.verify_trait_constraint( &constraint.typ, constraint.trait_id, - &constraint.trait_generics, + &constraint.trait_generics.ordered, + &constraint.trait_generics.named, expr_id, span, ); @@ -502,7 +507,7 @@ impl<'context> Elaborator<'context> { fn desugar_impl_trait_arg( &mut self, trait_path: Path, - trait_generics: Vec, + trait_generics: GenericTypeArgs, generics: &mut Vec, trait_constraints: &mut Vec, ) -> Type { @@ -682,33 +687,13 @@ impl<'context> Elaborator<'context> { bound: &TraitBound, typ: Type, ) -> Option { - let the_trait = self.lookup_trait_or_error(bound.trait_path.clone())?; - - let resolved_generics = &the_trait.generics.clone(); - assert_eq!(resolved_generics.len(), bound.trait_generics.len()); - let generics_with_types = resolved_generics.iter().zip(&bound.trait_generics); - let trait_generics = vecmap(generics_with_types, |(generic, typ)| { - self.resolve_type_inner(typ.clone(), &generic.kind) - }); - let the_trait = self.lookup_trait_or_error(bound.trait_path.clone())?; let trait_id = the_trait.id; + let span = bound.trait_path.span; - let span = bound.trait_path.span(); - - let expected_generics = the_trait.generics.len(); - let actual_generics = trait_generics.len(); - - if actual_generics != expected_generics { - let item_name = the_trait.name.to_string(); - self.push_err(ResolverError::IncorrectGenericCount { - span, - item_name, - actual: actual_generics, - expected: expected_generics, - }); - } + let (ordered, named) = self.resolve_type_args(bound.trait_generics.clone(), trait_id, span); + let trait_generics = TraitGenerics { ordered, named }; Some(TraitConstraint { typ, trait_id, trait_generics, span }) } @@ -1028,11 +1013,7 @@ impl<'context> Elaborator<'context> { if let Some(trait_id) = trait_impl.trait_id { self.generics = trait_impl.resolved_generics.clone(); - let where_clause = trait_impl - .where_clause - .iter() - .flat_map(|item| self.resolve_trait_constraint(item)) - .collect::>(); + let where_clause = self.resolve_trait_constraints(&trait_impl.where_clause); self.collect_trait_impl_methods(trait_id, trait_impl, &where_clause); @@ -1050,9 +1031,9 @@ impl<'context> Elaborator<'context> { ident: trait_impl.trait_path.last_ident(), typ: self_type.clone(), trait_id, - trait_generics: trait_generics.clone(), + trait_generics, file: trait_impl.file_id, - where_clause, + where_clause: where_clause.clone(), methods, }); @@ -1061,7 +1042,6 @@ impl<'context> Elaborator<'context> { if let Err((prev_span, prev_file)) = self.interner.add_trait_implementation( self_type.clone(), trait_id, - trait_generics, trait_impl.impl_id.expect("impl_id should be set in define_function_metas"), generics, resolved_trait_impl, @@ -1381,7 +1361,7 @@ impl<'context> Elaborator<'context> { let trait_id = self.resolve_trait_by_path(trait_impl.trait_path.clone()); trait_impl.trait_id = trait_id; - let unresolved_type = &trait_impl.object_type; + let unresolved_type = trait_impl.object_type.clone(); self.add_generics(&trait_impl.generics); trait_impl.resolved_generics = self.generics.clone(); @@ -1391,24 +1371,28 @@ impl<'context> Elaborator<'context> { method.def.where_clause.append(&mut trait_impl.where_clause.clone()); } + // Add each associated type to the list of named type arguments + let mut trait_generics = trait_impl.trait_generics.clone(); + trait_generics.named_args.extend(self.take_unresolved_associated_types(trait_impl)); + + let impl_id = self.interner.next_trait_impl_id(); + self.current_trait_impl = Some(impl_id); + // Fetch trait constraints here - let trait_generics = trait_impl + let (ordered_generics, named_generics) = trait_impl .trait_id - .and_then(|trait_id| self.resolve_trait_impl_generics(trait_impl, trait_id)) - .unwrap_or_else(|| { - // We still resolve as to continue type checking - vecmap(&trait_impl.trait_generics, |generic| self.resolve_type(generic.clone())) - }); + .map(|trait_id| { + self.resolve_type_args(trait_generics, trait_id, trait_impl.trait_path.span) + }) + .unwrap_or_default(); - trait_impl.resolved_trait_generics = trait_generics; + trait_impl.resolved_trait_generics = ordered_generics; + self.interner.set_associated_types_for_impl(impl_id, named_generics); - let self_type = self.resolve_type(unresolved_type.clone()); + let self_type = self.resolve_type(unresolved_type); self.self_type = Some(self_type.clone()); trait_impl.methods.self_type = Some(self_type); - let impl_id = self.interner.next_trait_impl_id(); - self.current_trait_impl = Some(impl_id); - self.define_function_metas_for_functions(&mut trait_impl.methods); trait_impl.resolved_object_type = self.self_type.take(); diff --git a/compiler/noirc_frontend/src/elaborator/patterns.rs b/compiler/noirc_frontend/src/elaborator/patterns.rs index bd44e087e7..06c153d4c1 100644 --- a/compiler/noirc_frontend/src/elaborator/patterns.rs +++ b/compiler/noirc_frontend/src/elaborator/patterns.rs @@ -585,25 +585,7 @@ impl<'context> Elaborator<'context> { // will replace each trait generic with a fresh type variable, rather than // the type used in the trait constraint (if it exists). See #4088. if let ImplKind::TraitMethod(_, constraint, assumed) = &ident.impl_kind { - let the_trait = self.interner.get_trait(constraint.trait_id); - assert_eq!(the_trait.generics.len(), constraint.trait_generics.len()); - - for (param, arg) in the_trait.generics.iter().zip(&constraint.trait_generics) { - // Avoid binding t = t - if !arg.occurs(param.type_var.id()) { - bindings.insert(param.type_var.id(), (param.type_var.clone(), arg.clone())); - } - } - - // If the trait impl is already assumed to exist we should add any type bindings for `Self`. - // Otherwise `self` will be replaced with a fresh type variable, which will require the user - // to specify a redundant type annotation. - if *assumed { - bindings.insert( - the_trait.self_type_typevar_id, - (the_trait.self_type_typevar.clone(), constraint.typ.clone()), - ); - } + self.bind_generics_from_trait_constraint(constraint, *assumed, &mut bindings); } // An identifiers type may be forall-quantified in the case of generic functions. @@ -622,6 +604,7 @@ impl<'context> Elaborator<'context> { let span = self.interner.expr_span(&expr_id); let location = self.interner.expr_location(&expr_id); + // This instantiates a trait's generics as well which need to be set // when the constraint below is later solved for when the function is // finished. How to link the two? @@ -643,10 +626,9 @@ impl<'context> Elaborator<'context> { if let ImplKind::TraitMethod(_, mut constraint, assumed) = ident.impl_kind { constraint.apply_bindings(&bindings); if assumed { - let trait_impl = TraitImplKind::Assumed { - object_type: constraint.typ, - trait_generics: constraint.trait_generics, - }; + let trait_generics = constraint.trait_generics.clone(); + let object_type = constraint.typ; + let trait_impl = TraitImplKind::Assumed { object_type, trait_generics }; self.interner.select_impl_for_expression(expr_id, trait_impl); } else { // Currently only one impl can be selected per expr_id, so this diff --git a/compiler/noirc_frontend/src/elaborator/statements.rs b/compiler/noirc_frontend/src/elaborator/statements.rs index da4492eb21..0bb8641b6b 100644 --- a/compiler/noirc_frontend/src/elaborator/statements.rs +++ b/compiler/noirc_frontend/src/elaborator/statements.rs @@ -66,36 +66,31 @@ impl<'context> Elaborator<'context> { ) -> (HirStatement, Type) { let expr_span = let_stmt.expression.span; let (expression, expr_type) = self.elaborate_expression(let_stmt.expression); - let annotated_type = self.resolve_type(let_stmt.r#type); + let annotated_type = self.resolve_inferred_type(let_stmt.r#type); let definition = match global_id { None => DefinitionKind::Local(Some(expression)), Some(id) => DefinitionKind::Global(id), }; - // First check if the LHS is unspecified - // If so, then we give it the same type as the expression - let r#type = if annotated_type != Type::Error { - // Now check if LHS is the same type as the RHS - // Importantly, we do not coerce any types implicitly - self.unify_with_coercions(&expr_type, &annotated_type, expression, expr_span, || { - TypeCheckError::TypeMismatch { - expected_typ: annotated_type.to_string(), - expr_typ: expr_type.to_string(), - expr_span, - } - }); - if annotated_type.is_integer() { - let errors = lints::overflowing_int(self.interner, &expression, &annotated_type); - for error in errors { - self.push_err(error); - } + // Now check if LHS is the same type as the RHS + // Importantly, we do not coerce any types implicitly + self.unify_with_coercions(&expr_type, &annotated_type, expression, expr_span, || { + TypeCheckError::TypeMismatch { + expected_typ: annotated_type.to_string(), + expr_typ: expr_type.to_string(), + expr_span, } - annotated_type - } else { - expr_type - }; + }); + + if annotated_type.is_integer() { + let errors = lints::overflowing_int(self.interner, &expression, &annotated_type); + for error in errors { + self.push_err(error); + } + } + let r#type = annotated_type; let pattern = self.elaborate_pattern_and_store_ids( let_stmt.pattern, r#type.clone(), @@ -424,7 +419,7 @@ impl<'context> Elaborator<'context> { // If we get here the type has no field named 'access.rhs'. // Now we specialize the error message based on whether we know the object type in question yet. if let Type::TypeVariable(..) = &lhs_type { - self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); + self.push_err(TypeCheckError::TypeAnnotationsNeededForFieldAccess { span }); } else if lhs_type != Type::Error { self.push_err(TypeCheckError::AccessUnknownMember { lhs_type, diff --git a/compiler/noirc_frontend/src/elaborator/trait_impls.rs b/compiler/noirc_frontend/src/elaborator/trait_impls.rs index 20719b9f09..aa7e1cb89c 100644 --- a/compiler/noirc_frontend/src/elaborator/trait_impls.rs +++ b/compiler/noirc_frontend/src/elaborator/trait_impls.rs @@ -1,16 +1,16 @@ use crate::{ + ast::UnresolvedTypeExpression, graph::CrateId, hir::def_collector::{dc_crate::UnresolvedTraitImpl, errors::DefCollectorErrorKind}, + macros_api::{Ident, UnresolvedType, UnresolvedTypeData}, + node_interner::TraitImplId, ResolvedGeneric, }; use crate::{ hir::def_collector::errors::DuplicateType, - hir_def::{ - traits::{TraitConstraint, TraitFunction}, - types::Generics, - }, + hir_def::traits::{TraitConstraint, TraitFunction}, node_interner::{FuncId, TraitId}, - Type, TypeBindings, + Type, }; use noirc_errors::Location; @@ -28,6 +28,8 @@ impl<'context> Elaborator<'context> { self.local_module = trait_impl.module_id; self.file = trait_impl.file_id; + let impl_id = trait_impl.impl_id.expect("impl_id should be set in define_function_metas"); + // In this Vec methods[i] corresponds to trait.methods[i]. If the impl has no implementation // for a particular method, the default implementation will be added at that slot. let mut ordered_methods = Vec::new(); @@ -38,7 +40,6 @@ impl<'context> Elaborator<'context> { // set of function ids that have a corresponding method in the trait let mut func_ids_in_trait = HashSet::default(); - let trait_generics = &self.interner.get_trait(trait_id).generics.clone(); // Temporarily take ownership of the trait's methods so we can iterate over them // while also mutating the interner let the_trait = self.interner.get_trait_mut(trait_id); @@ -82,7 +83,8 @@ impl<'context> Elaborator<'context> { method, trait_impl_where_clause, &trait_impl.resolved_trait_generics, - trait_generics, + trait_id, + impl_id, ); func_ids_in_trait.insert(*func_id); @@ -138,16 +140,15 @@ impl<'context> Elaborator<'context> { func_id: &FuncId, method: &TraitFunction, trait_impl_where_clause: &[TraitConstraint], - impl_trait_generics: &[Type], - trait_generics: &Generics, + trait_impl_generics: &[Type], + trait_id: TraitId, + impl_id: TraitImplId, ) { - let mut bindings = TypeBindings::new(); - for (trait_generic, impl_trait_generic) in trait_generics.iter().zip(impl_trait_generics) { - bindings.insert( - trait_generic.type_var.id(), - (trait_generic.type_var.clone(), impl_trait_generic.clone()), - ); - } + // First get the general trait to impl bindings. + // Then we'll need to add the bindings for this specific method. + let self_type = self.self_type.as_ref().unwrap().clone(); + let mut bindings = + self.interner.trait_to_impl_bindings(trait_id, impl_id, trait_impl_generics, self_type); let override_meta = self.interner.function_meta(func_id); // Substitute each generic on the trait function with the corresponding generic on the impl function @@ -163,11 +164,9 @@ impl<'context> Elaborator<'context> { let mut substituted_method_ids = HashSet::default(); for method_constraint in method.trait_constraints.iter() { let substituted_constraint_type = method_constraint.typ.substitute(&bindings); - let substituted_trait_generics = method_constraint - .trait_generics - .iter() - .map(|generic| generic.substitute(&bindings)) - .collect::>(); + let substituted_trait_generics = + method_constraint.trait_generics.map(|generic| generic.substitute(&bindings)); + substituted_method_ids.insert(( substituted_constraint_type, method_constraint.trait_id, @@ -222,4 +221,26 @@ impl<'context> Elaborator<'context> { }); } } + + pub(super) fn take_unresolved_associated_types( + &mut self, + trait_impl: &mut UnresolvedTraitImpl, + ) -> Vec<(Ident, UnresolvedType)> { + let mut associated_types = Vec::new(); + for (name, _, expr) in trait_impl.associated_constants.drain(..) { + let span = expr.span; + let typ = match UnresolvedTypeExpression::from_expr(expr, span) { + Ok(expr) => UnresolvedTypeData::Expression(expr).with_span(span), + Err(error) => { + self.push_err(error); + UnresolvedTypeData::Error.with_span(span) + } + }; + associated_types.push((name, typ)); + } + for (name, typ) in trait_impl.associated_types.drain(..) { + associated_types.push((name, typ)); + } + associated_types + } } diff --git a/compiler/noirc_frontend/src/elaborator/traits.rs b/compiler/noirc_frontend/src/elaborator/traits.rs index 5240774625..f651630baa 100644 --- a/compiler/noirc_frontend/src/elaborator/traits.rs +++ b/compiler/noirc_frontend/src/elaborator/traits.rs @@ -7,14 +7,8 @@ use crate::{ ast::{ FunctionKind, TraitItem, UnresolvedGeneric, UnresolvedGenerics, UnresolvedTraitConstraint, }, - hir::{ - def_collector::dc_crate::{CompilationError, UnresolvedTrait, UnresolvedTraitImpl}, - type_check::TypeCheckError, - }, - hir_def::{ - function::Parameters, - traits::{TraitConstant, TraitFunction, TraitType}, - }, + hir::{def_collector::dc_crate::UnresolvedTrait, type_check::TypeCheckError}, + hir_def::{function::Parameters, traits::TraitFunction}, macros_api::{ BlockExpression, FunctionDefinition, FunctionReturnType, Ident, ItemVisibility, NodeInterner, NoirFunction, Param, Pattern, UnresolvedType, Visibility, @@ -30,18 +24,19 @@ impl<'context> Elaborator<'context> { pub fn collect_traits(&mut self, traits: &BTreeMap) { for (trait_id, unresolved_trait) in traits { self.recover_generics(|this| { + this.current_trait = Some(*trait_id); + let resolved_generics = this.interner.get_trait(*trait_id).generics.clone(); this.add_existing_generics( &unresolved_trait.trait_def.generics, &resolved_generics, ); - // Resolve order - // 1. Trait Types ( Trait constants can have a trait type, therefore types before constants) - let _ = this.resolve_trait_types(unresolved_trait); - // 2. Trait Constants ( Trait's methods can use trait types & constants, therefore they should be after) - let _ = this.resolve_trait_constants(unresolved_trait); - // 3. Trait Methods + // Each associated type in this trait is also an implicit generic + for associated_type in &this.interner.get_trait(*trait_id).associated_types { + this.generics.push(associated_type.clone()); + } + let methods = this.resolve_trait_methods(*trait_id, unresolved_trait); this.interner.update_trait(*trait_id, |trait_def| { @@ -57,19 +52,8 @@ impl<'context> Elaborator<'context> { self.interner.try_add_prefix_operator_trait(*trait_id); } } - } - fn resolve_trait_types(&mut self, _unresolved_trait: &UnresolvedTrait) -> Vec { - // TODO - vec![] - } - - fn resolve_trait_constants( - &mut self, - _unresolved_trait: &UnresolvedTrait, - ) -> Vec { - // TODO - vec![] + self.current_trait = None; } fn resolve_trait_methods( @@ -207,31 +191,6 @@ impl<'context> Elaborator<'context> { // Don't check the scope tree for unused variables, they can't be used in a declaration anyway. self.generics.truncate(old_generic_count); } - - pub fn resolve_trait_impl_generics( - &mut self, - trait_impl: &UnresolvedTraitImpl, - trait_id: TraitId, - ) -> Option> { - let trait_def = self.interner.get_trait(trait_id); - let resolved_generics = trait_def.generics.clone(); - if resolved_generics.len() != trait_impl.trait_generics.len() { - self.push_err(CompilationError::TypeError(TypeCheckError::GenericCountMismatch { - item: trait_def.name.to_string(), - expected: resolved_generics.len(), - found: trait_impl.trait_generics.len(), - span: trait_impl.trait_path.span(), - })); - - return None; - } - - let generics = trait_impl.trait_generics.iter().zip(resolved_generics.iter()); - let mapped = generics.map(|(generic, resolved_generic)| { - self.resolve_type_inner(generic.clone(), &resolved_generic.kind) - }); - Some(mapped.collect()) - } } /// Checks that the type of a function in a trait impl matches the type @@ -264,24 +223,18 @@ pub(crate) fn check_trait_impl_method_matches_declaration( let definition_type = meta.typ.as_monotype(); - let impl_ = + let impl_id = meta.trait_impl.expect("Trait impl function should have a corresponding trait impl"); // If the trait implementation is not defined in the interner then there was a previous // error in resolving the trait path and there is likely no trait for this impl. - let Some(impl_) = interner.try_get_trait_implementation(impl_) else { + let Some(impl_) = interner.try_get_trait_implementation(impl_id) else { return errors; }; let impl_ = impl_.borrow(); let trait_info = interner.get_trait(impl_.trait_id); - let mut bindings = TypeBindings::new(); - bindings.insert( - trait_info.self_type_typevar_id, - (trait_info.self_type_typevar.clone(), impl_.typ.clone()), - ); - if trait_info.generics.len() != impl_.trait_generics.len() { let expected = trait_info.generics.len(); let found = impl_.trait_generics.len(); @@ -291,9 +244,12 @@ pub(crate) fn check_trait_impl_method_matches_declaration( } // Substitute each generic on the trait with the corresponding generic on the impl - for (generic, arg) in trait_info.generics.iter().zip(&impl_.trait_generics) { - bindings.insert(generic.type_var.id(), (generic.type_var.clone(), arg.clone())); - } + let mut bindings = interner.trait_to_impl_bindings( + impl_.trait_id, + impl_id, + &impl_.trait_generics, + impl_.typ.clone(), + ); // If this is None, the trait does not have the corresponding function. // This error should have been caught in name resolution already so we don't diff --git a/compiler/noirc_frontend/src/elaborator/types.rs b/compiler/noirc_frontend/src/elaborator/types.rs index f74d7449cd..44bded6b92 100644 --- a/compiler/noirc_frontend/src/elaborator/types.rs +++ b/compiler/noirc_frontend/src/elaborator/types.rs @@ -1,19 +1,23 @@ -use std::{collections::BTreeMap, rc::Rc}; +use std::{borrow::Cow, collections::BTreeMap, rc::Rc}; use acvm::acir::AcirField; use iter_extended::vecmap; use noirc_errors::{Location, Span}; +use rustc_hash::FxHashMap as HashMap; use crate::{ ast::{ - BinaryOpKind, IntegerBitSize, UnresolvedGeneric, UnresolvedGenerics, - UnresolvedTypeExpression, + AsTraitPath, BinaryOpKind, GenericTypeArgs, IntegerBitSize, UnresolvedGeneric, + UnresolvedGenerics, UnresolvedTypeExpression, }, hir::{ comptime::{Interpreter, Value}, def_map::ModuleDefId, resolution::errors::ResolverError, - type_check::{NoMatchingImplFoundError, Source, TypeCheckError}, + type_check::{ + generics::{Generic, TraitGenerics}, + NoMatchingImplFoundError, Source, TypeCheckError, + }, }, hir_def::{ expr::{ @@ -21,17 +25,18 @@ use crate::{ HirPrefixExpression, }, function::{FuncMeta, Parameters}, - traits::TraitConstraint, + traits::{NamedType, TraitConstraint}, }, macros_api::{ - HirExpression, HirLiteral, HirStatement, NodeInterner, Path, PathKind, SecondaryAttribute, - Signedness, UnaryOp, UnresolvedType, UnresolvedTypeData, + HirExpression, HirLiteral, HirStatement, Ident, NodeInterner, Path, PathKind, + SecondaryAttribute, Signedness, UnaryOp, UnresolvedType, UnresolvedTypeData, }, node_interner::{ - DefinitionKind, DependencyId, ExprId, FuncId, GlobalId, TraitId, TraitImplKind, - TraitMethodId, + DefinitionKind, DependencyId, ExprId, FuncId, GlobalId, ImplSearchErrorKind, TraitId, + TraitImplKind, TraitMethodId, }, - Generics, Kind, ResolvedGeneric, Type, TypeBinding, TypeVariable, TypeVariableKind, + Generics, Kind, ResolvedGeneric, Type, TypeBinding, TypeBindings, TypeVariable, + TypeVariableKind, }; use super::{lints, Elaborator}; @@ -115,7 +120,11 @@ impl<'context> Elaborator<'context> { } Quoted(quoted) => Type::Quoted(quoted), Unit => Type::Unit, - Unspecified => Type::Error, + Unspecified => { + let span = typ.span; + self.push_err(TypeCheckError::UnspecifiedType { span }); + Type::Error + } Error => Type::Error, Named(path, args, _) => self.resolve_named_type(path, args), TraitAsType(path, args) => self.resolve_trait_as_type(path, args), @@ -148,17 +157,14 @@ impl<'context> Elaborator<'context> { } Parenthesized(typ) => self.resolve_type_inner(*typ, kind), Resolved(id) => self.interner.get_quoted_type(id).clone(), - AsTraitPath(_) => todo!("Resolve AsTraitPath"), + AsTraitPath(path) => self.resolve_as_trait_path(*path), }; - let unresolved_span = typ.span; - let location = Location::new(named_path_span.unwrap_or(unresolved_span), self.file); - + let location = Location::new(named_path_span.unwrap_or(typ.span), self.file); match resolved_type { Type::Struct(ref struct_type, _) => { // Record the location of the type reference self.interner.push_type_ref_location(resolved_type.clone(), location); - if !is_synthetic { self.interner.add_struct_reference( struct_type.borrow().id, @@ -202,7 +208,27 @@ impl<'context> Elaborator<'context> { self.generics.iter().find(|generic| generic.name.as_ref() == target_name) } - fn resolve_named_type(&mut self, path: Path, args: Vec) -> Type { + // Resolve Self::Foo to an associated type on the current trait or trait impl + fn lookup_associated_type_on_self(&self, path: &Path) -> Option { + if path.segments.len() == 2 && path.first_name() == SELF_TYPE_NAME { + if let Some(trait_id) = self.current_trait { + let the_trait = self.interner.get_trait(trait_id); + if let Some(typ) = the_trait.get_associated_type(path.last_name()) { + return Some(typ.clone().as_named_generic()); + } + } + + if let Some(impl_id) = self.current_trait_impl { + let name = path.last_name(); + if let Some(typ) = self.interner.find_associated_type_for_impl(impl_id, name) { + return Some(typ.clone()); + } + } + } + None + } + + fn resolve_named_type(&mut self, path: Path, args: GenericTypeArgs) -> Type { if args.is_empty() { if let Some(typ) = self.lookup_generic_or_global_type(&path) { return typ; @@ -224,28 +250,18 @@ impl<'context> Elaborator<'context> { } else if name == WILDCARD_TYPE { return self.interner.next_type_variable(); } + } else if let Some(typ) = self.lookup_associated_type_on_self(&path) { + if !args.is_empty() { + self.push_err(ResolverError::GenericsOnAssociatedType { span: path.span() }); + } + return typ; } let span = path.span(); if let Some(type_alias) = self.lookup_type_alias(path.clone()) { - let type_alias = type_alias.borrow(); - let actual_generic_count = args.len(); - let expected_generic_count = type_alias.generics.len(); - let type_alias_string = type_alias.to_string(); - let id = type_alias.id; - - let mut args = vecmap(type_alias.generics.iter().zip(args), |(generic, arg)| { - self.resolve_type_inner(arg, &generic.kind) - }); - - self.verify_generics_count( - expected_generic_count, - actual_generic_count, - &mut args, - span, - || type_alias_string, - ); + let id = type_alias.borrow().id; + let (args, _) = self.resolve_type_args(args, id, path.span()); if let Some(item) = self.current_item { self.interner.add_type_alias_dependency(item, id); @@ -260,8 +276,7 @@ impl<'context> Elaborator<'context> { // equal to another type alias. Fixing this fully requires an analysis to create a DFG // of definition ordering, but for now we have an explicit check here so that we at // least issue an error that the type was not found instead of silently passing. - let alias = self.interner.get_type_alias(id); - return Type::Alias(alias, args); + return Type::Alias(type_alias, args); } match self.lookup_struct_or_error(path) { @@ -274,9 +289,6 @@ impl<'context> Elaborator<'context> { return Type::Error; } - let expected_generic_count = struct_type.borrow().generics.len(); - let actual_generic_count = args.len(); - if !self.in_contract() && self .interner @@ -289,18 +301,7 @@ impl<'context> Elaborator<'context> { }); } - let mut args = - vecmap(struct_type.borrow().generics.iter().zip(args), |(generic, arg)| { - self.resolve_type_inner(arg, &generic.kind) - }); - - self.verify_generics_count( - expected_generic_count, - actual_generic_count, - &mut args, - span, - || struct_type.borrow().to_string(), - ); + let (args, _) = self.resolve_type_args(args, struct_type.borrow(), span); if let Some(current_item) = self.current_item { let dependency_id = struct_type.borrow().id; @@ -313,44 +314,99 @@ impl<'context> Elaborator<'context> { } } - fn resolve_trait_as_type(&mut self, path: Path, args: Vec) -> Type { + fn resolve_trait_as_type(&mut self, path: Path, args: GenericTypeArgs) -> Type { // Fetch information needed from the trait as the closure for resolving all the `args` // requires exclusive access to `self` - let trait_as_type_info = self - .lookup_trait_or_error(path) - .map(|t| (t.id, Rc::new(t.name.to_string()), t.generics.clone())); - - if let Some((id, name, resolved_generics)) = trait_as_type_info { - assert_eq!(resolved_generics.len(), args.len()); - let generics_with_types = resolved_generics.iter().zip(args); - let args = vecmap(generics_with_types, |(generic, typ)| { - self.resolve_type_inner(typ, &generic.kind) - }); - Type::TraitAsType(id, Rc::new(name.to_string()), args) + let span = path.span; + let trait_as_type_info = self.lookup_trait_or_error(path).map(|t| t.id); + + if let Some(id) = trait_as_type_info { + let (ordered, named) = self.resolve_type_args(args, id, span); + let name = self.interner.get_trait(id).name.to_string(); + let generics = TraitGenerics { ordered, named }; + Type::TraitAsType(id, Rc::new(name), generics) } else { Type::Error } } - fn verify_generics_count( + pub(super) fn resolve_type_args( &mut self, - expected_count: usize, - actual_count: usize, - args: &mut Vec, + mut args: GenericTypeArgs, + item: impl Generic, span: Span, - type_name: impl FnOnce() -> String, - ) { - if actual_count != expected_count { - self.push_err(ResolverError::IncorrectGenericCount { + ) -> (Vec, Vec) { + let expected_kinds = item.generics(self.interner); + + if args.ordered_args.len() != expected_kinds.len() { + self.push_err(TypeCheckError::GenericCountMismatch { + item: item.item_name(self.interner), + expected: expected_kinds.len(), + found: args.ordered_args.len(), span, - item_name: type_name(), - actual: actual_count, - expected: expected_count, }); + let error_type = UnresolvedTypeData::Error.with_span(span); + args.ordered_args.resize(expected_kinds.len(), error_type); + } - // Fix the generic count so we can continue typechecking - args.resize_with(expected_count, || Type::Error); + let ordered_args = expected_kinds.iter().zip(args.ordered_args); + let ordered = + vecmap(ordered_args, |(generic, typ)| self.resolve_type_inner(typ, &generic.kind)); + + let mut associated = Vec::new(); + + if item.accepts_named_type_args() { + associated = self.resolve_associated_type_args(args.named_args, item, span); + } else if !args.named_args.is_empty() { + let item_kind = item.item_kind(); + self.push_err(ResolverError::NamedTypeArgs { span, item_kind }); + } + + (ordered, associated) + } + + fn resolve_associated_type_args( + &mut self, + args: Vec<(Ident, UnresolvedType)>, + item: impl Generic, + span: Span, + ) -> Vec { + let mut seen_args = HashMap::default(); + let mut required_args = item.named_generics(self.interner); + let mut resolved = Vec::with_capacity(required_args.len()); + + // Go through each argument to check if it is in our required_args list. + // If it is remove it from the list, otherwise issue an error. + for (name, typ) in args { + let index = + required_args.iter().position(|item| item.name.as_ref() == &name.0.contents); + + let Some(index) = index else { + if let Some(prev_span) = seen_args.get(&name.0.contents).copied() { + self.push_err(TypeCheckError::DuplicateNamedTypeArg { name, prev_span }); + } else { + let item = item.item_name(self.interner); + self.push_err(TypeCheckError::NoSuchNamedTypeArg { name, item }); + } + continue; + }; + + // Remove the argument from the required list so we remember that we already have it + let expected = required_args.remove(index); + seen_args.insert(name.0.contents.clone(), name.span()); + + let typ = self.resolve_type_inner(typ, &expected.kind); + resolved.push(NamedType { name, typ }); + } + + // Anything that hasn't been removed yet is missing + for generic in required_args { + let item = item.item_name(self.interner); + let name = generic.name.clone(); + self.push_err(TypeCheckError::MissingNamedTypeArg { item, span, name }); } + + resolved } pub fn lookup_generic_or_global_type(&mut self, path: &Path) -> Option { @@ -360,6 +416,8 @@ impl<'context> Elaborator<'context> { let generic = generic.clone(); return Some(Type::NamedGeneric(generic.type_var, generic.name, generic.kind)); } + } else if let Some(typ) = self.lookup_associated_type_on_self(path) { + return Some(typ); } // If we cannot find a local generic of the same name, try to look up a global @@ -407,6 +465,49 @@ impl<'context> Elaborator<'context> { } } } + UnresolvedTypeExpression::AsTraitPath(path) => self.resolve_as_trait_path(*path), + } + } + + fn resolve_as_trait_path(&mut self, path: AsTraitPath) -> Type { + let span = path.trait_path.span; + let Some(trait_id) = self.resolve_trait_by_path(path.trait_path.clone()) else { + // Error should already be pushed in the None case + return Type::Error; + }; + + let (ordered, named) = self.resolve_type_args(path.trait_generics.clone(), trait_id, span); + let object_type = self.resolve_type(path.typ.clone()); + + match self.interner.lookup_trait_implementation(&object_type, trait_id, &ordered, &named) { + Ok(impl_kind) => self.get_associated_type_from_trait_impl(path, impl_kind), + Err(constraints) => { + self.push_trait_constraint_error(&object_type, constraints, span); + Type::Error + } + } + } + + fn get_associated_type_from_trait_impl( + &mut self, + path: AsTraitPath, + impl_kind: TraitImplKind, + ) -> Type { + let associated_types = match impl_kind { + TraitImplKind::Assumed { trait_generics, .. } => Cow::Owned(trait_generics.named), + TraitImplKind::Normal(impl_id) => { + Cow::Borrowed(self.interner.get_associated_types_for_impl(impl_id)) + } + }; + + match associated_types.iter().find(|named| named.name == path.impl_item) { + Some(generic) => generic.typ.clone(), + None => { + let name = path.impl_item.clone(); + let item = format!("<{} as {}>", path.typ, path.trait_path); + self.push_err(TypeCheckError::NoSuchNamedTypeArg { name, item }); + Type::Error + } } } @@ -428,17 +529,8 @@ impl<'context> Elaborator<'context> { if name == SELF_TYPE_NAME { let the_trait = self.interner.get_trait(trait_id); let method = the_trait.find_method(method.0.contents.as_str())?; - - let constraint = TraitConstraint { - typ: self.self_type.clone()?, - trait_generics: Type::from_generics(&vecmap(&the_trait.generics, |generic| { - generic.type_var.clone() - })), - trait_id, - span: path.span(), - }; - - return Some((method, constraint, false)); + let constraint = the_trait.as_constraint(path.span); + return Some((method, constraint, true)); } } None @@ -454,17 +546,9 @@ impl<'context> Elaborator<'context> { ) -> Option<(TraitMethodId, TraitConstraint, bool)> { let func_id: FuncId = self.lookup(path.clone()).ok()?; let meta = self.interner.function_meta(&func_id); - let trait_id = meta.trait_id?; - let the_trait = self.interner.get_trait(trait_id); + let the_trait = self.interner.get_trait(meta.trait_id?); let method = the_trait.find_method(path.last_name())?; - let constraint = TraitConstraint { - typ: Type::TypeVariable(the_trait.self_type_typevar.clone(), TypeVariableKind::Normal), - trait_generics: Type::from_generics(&vecmap(&the_trait.generics, |generic| { - generic.type_var.clone() - })), - trait_id, - span: path.span(), - }; + let constraint = the_trait.as_constraint(path.span); Some((method, constraint, false)) } @@ -674,7 +758,7 @@ impl<'context> Elaborator<'context> { /// Insert as many dereference operations as necessary to automatically dereference a method /// call object to its base value type T. pub(super) fn insert_auto_dereferences(&mut self, object: ExprId, typ: Type) -> (ExprId, Type) { - if let Type::MutableReference(element) = typ { + if let Type::MutableReference(element) = typ.follow_bindings() { let location = self.interner.id_location(object); let object = self.interner.push_expr(HirExpression::Prefix(HirPrefixExpression { @@ -957,9 +1041,6 @@ impl<'context> Elaborator<'context> { // Matches on TypeVariable must be first so that we follow any type // bindings. (TypeVariable(int, _), other) | (other, TypeVariable(int, _)) => { - if let TypeBinding::Bound(binding) = &*int.borrow() { - return self.infix_operand_type_rules(binding, op, other, span); - } if op.kind == BinaryOpKind::ShiftLeft || op.kind == BinaryOpKind::ShiftRight { self.unify( rhs_type, @@ -974,6 +1055,9 @@ impl<'context> Elaborator<'context> { }; return Ok((lhs_type.clone(), use_impl)); } + if let TypeBinding::Bound(binding) = &*int.borrow() { + return self.infix_operand_type_rules(binding, op, other, span); + } let use_impl = self.bind_type_variables_for_infix(lhs_type, op, rhs_type, span); Ok((other.clone(), use_impl)) } @@ -1151,7 +1235,7 @@ impl<'context> Elaborator<'context> { let the_trait = self.interner.get_trait(trait_method_id.trait_id); let object_type = object_type.substitute(&bindings); bindings.insert( - the_trait.self_type_typevar_id, + the_trait.self_type_typevar.id(), (the_trait.self_type_typevar.clone(), object_type.clone()), ); self.interner.select_impl_for_expression( @@ -1248,7 +1332,7 @@ impl<'context> Elaborator<'context> { // The type variable must be unbound at this point since follow_bindings was called Type::TypeVariable(_, TypeVariableKind::Normal) => { - self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); + self.push_err(TypeCheckError::TypeAnnotationsNeededForMethodCall { span }); None } @@ -1283,10 +1367,9 @@ impl<'context> Elaborator<'context> { if method.name.0.contents == method_name { let trait_method = TraitMethodId { trait_id: constraint.trait_id, method_index }; - return Some(HirMethodReference::TraitMethodId( - trait_method, - constraint.trait_generics.clone(), - )); + + let generics = constraint.trait_generics.clone(); + return Some(HirMethodReference::TraitMethodId(trait_method, generics)); } } } @@ -1440,7 +1523,16 @@ impl<'context> Elaborator<'context> { let func_span = self.interner.expr_span(&body_id); // XXX: We could be more specific and return the span of the last stmt, however stmts do not have spans yet if let Type::TraitAsType(trait_id, _, generics) = declared_return_type { - if self.interner.lookup_trait_implementation(&body_type, *trait_id, generics).is_err() { + if self + .interner + .lookup_trait_implementation( + &body_type, + *trait_id, + &generics.ordered, + &generics.named, + ) + .is_err() + { self.push_err(TypeCheckError::TypeMismatchWithSource { expected: declared_return_type.clone(), actual: body_type, @@ -1491,22 +1583,47 @@ impl<'context> Elaborator<'context> { object_type: &Type, trait_id: TraitId, trait_generics: &[Type], + associated_types: &[NamedType], function_ident_id: ExprId, span: Span, ) { - match self.interner.lookup_trait_implementation(object_type, trait_id, trait_generics) { + match self.interner.lookup_trait_implementation( + object_type, + trait_id, + trait_generics, + associated_types, + ) { Ok(impl_kind) => { self.interner.select_impl_for_expression(function_ident_id, impl_kind); } - Err(erroring_constraints) => { - if erroring_constraints.is_empty() { - self.push_err(TypeCheckError::TypeAnnotationsNeeded { span }); - } else if let Some(error) = - NoMatchingImplFoundError::new(self.interner, erroring_constraints, span) + Err(error) => self.push_trait_constraint_error(object_type, error, span), + } + } + + fn push_trait_constraint_error( + &mut self, + object_type: &Type, + error: ImplSearchErrorKind, + span: Span, + ) { + match error { + ImplSearchErrorKind::TypeAnnotationsNeededOnObjectType => { + self.push_err(TypeCheckError::TypeAnnotationsNeededForMethodCall { span }); + } + ImplSearchErrorKind::Nested(constraints) => { + if let Some(error) = NoMatchingImplFoundError::new(self.interner, constraints, span) { self.push_err(TypeCheckError::NoMatchingImplFound(error)); } } + ImplSearchErrorKind::MultipleMatching(candidates) => { + let object_type = object_type.clone(); + self.push_err(TypeCheckError::MultipleMatchingImpls { + object_type, + span, + candidates, + }); + } } } @@ -1567,9 +1684,12 @@ impl<'context> Elaborator<'context> { | Type::Forall(_, _) => (), Type::TraitAsType(_, _, args) => { - for arg in args { + for arg in &args.ordered { Self::find_numeric_generics_in_type(arg, found); } + for arg in &args.named { + Self::find_numeric_generics_in_type(&arg.typ, found); + } } Type::Array(length, element_type) => { @@ -1665,6 +1785,50 @@ impl<'context> Elaborator<'context> { } } } + + pub fn bind_generics_from_trait_constraint( + &mut self, + constraint: &TraitConstraint, + assumed: bool, + bindings: &mut TypeBindings, + ) { + let the_trait = self.interner.get_trait(constraint.trait_id); + assert_eq!(the_trait.generics.len(), constraint.trait_generics.ordered.len()); + + for (param, arg) in the_trait.generics.iter().zip(&constraint.trait_generics.ordered) { + // Avoid binding t = t + if !arg.occurs(param.type_var.id()) { + bindings.insert(param.type_var.id(), (param.type_var.clone(), arg.clone())); + } + } + + let mut associated_types = the_trait.associated_types.clone(); + assert_eq!(associated_types.len(), constraint.trait_generics.named.len()); + + for arg in &constraint.trait_generics.named { + let i = associated_types + .iter() + .position(|typ| *typ.name == arg.name.0.contents) + .unwrap_or_else(|| { + unreachable!("Expected to find associated type named {}", arg.name) + }); + + let param = associated_types.swap_remove(i); + + // Avoid binding t = t + if !arg.typ.occurs(param.type_var.id()) { + bindings.insert(param.type_var.id(), (param.type_var.clone(), arg.typ.clone())); + } + } + + // If the trait impl is already assumed to exist we should add any type bindings for `Self`. + // Otherwise `self` will be replaced with a fresh type variable, which will require the user + // to specify a redundant type annotation. + if assumed { + let self_type = the_trait.self_type_typevar.clone(); + bindings.insert(self_type.id(), (self_type, constraint.typ.clone())); + } + } } /// Gives an error if a user tries to create a mutable reference diff --git a/compiler/noirc_frontend/src/hir/comptime/errors.rs b/compiler/noirc_frontend/src/hir/comptime/errors.rs index b7b4909023..fd916485ea 100644 --- a/compiler/noirc_frontend/src/hir/comptime/errors.rs +++ b/compiler/noirc_frontend/src/hir/comptime/errors.rs @@ -188,11 +188,19 @@ pub enum InterpreterError { FunctionAlreadyResolved { location: Location, }, + MultipleMatchingImpls { + object_type: Type, + candidates: Vec, + location: Location, + }, Unimplemented { item: String, location: Location, }, + TypeAnnotationsNeededForMethodCall { + location: Location, + }, // These cases are not errors, they are just used to prevent us from running more code // until the loop can be resumed properly. These cases will never be displayed to users. @@ -257,8 +265,10 @@ impl InterpreterError { | InterpreterError::ContinueNotInLoop { location, .. } | InterpreterError::TraitDefinitionMustBeAPath { location } | InterpreterError::FailedToResolveTraitDefinition { location } - | InterpreterError::FailedToResolveTraitBound { location, .. } => *location, - InterpreterError::FunctionAlreadyResolved { location, .. } => *location, + | InterpreterError::FailedToResolveTraitBound { location, .. } + | InterpreterError::FunctionAlreadyResolved { location, .. } + | InterpreterError::MultipleMatchingImpls { location, .. } + | InterpreterError::TypeAnnotationsNeededForMethodCall { location } => *location, InterpreterError::FailedToParseMacro { error, file, .. } => { Location::new(error.span(), *file) @@ -527,6 +537,26 @@ impl<'a> From<&'a InterpreterError> for CustomDiagnostic { .to_string(); CustomDiagnostic::simple_error(msg, secondary, location.span) } + InterpreterError::MultipleMatchingImpls { object_type, candidates, location } => { + let message = format!("Multiple trait impls match the object type `{object_type}`"); + let secondary = "Ambiguous impl".to_string(); + let mut error = CustomDiagnostic::simple_error(message, secondary, location.span); + for (i, candidate) in candidates.iter().enumerate() { + error.add_note(format!("Candidate {}: `{candidate}`", i + 1)); + } + error + } + InterpreterError::TypeAnnotationsNeededForMethodCall { location } => { + let mut error = CustomDiagnostic::simple_error( + "Object type is unknown in method call".to_string(), + "Type must be known by this point to know which method to call".to_string(), + location.span, + ); + let message = + "Try adding a type annotation for the object type before this method call"; + error.add_note(message.to_string()); + error + } } } } diff --git a/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs b/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs index 07c5c1a0c7..1c03184a8f 100644 --- a/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs +++ b/compiler/noirc_frontend/src/hir/comptime/hir_to_display_ast.rs @@ -3,8 +3,8 @@ use noirc_errors::{Span, Spanned}; use crate::ast::{ ArrayLiteral, AssignStatement, BlockExpression, CallExpression, CastExpression, ConstrainKind, - ConstructorExpression, ExpressionKind, ForLoopStatement, ForRange, Ident, IfExpression, - IndexExpression, InfixExpression, LValue, Lambda, LetStatement, Literal, + ConstructorExpression, ExpressionKind, ForLoopStatement, ForRange, GenericTypeArgs, Ident, + IfExpression, IndexExpression, InfixExpression, LValue, Lambda, LetStatement, Literal, MemberAccessExpression, MethodCallExpression, Path, PathSegment, Pattern, PrefixExpression, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, }; @@ -300,7 +300,8 @@ impl Type { } Type::Struct(def, generics) => { let struct_def = def.borrow(); - let generics = vecmap(generics, |generic| generic.to_display_ast()); + let ordered_args = vecmap(generics, |generic| generic.to_display_ast()); + let generics = GenericTypeArgs { ordered_args, named_args: Vec::new() }; let name = Path::from_ident(struct_def.name.clone()); UnresolvedTypeData::Named(name, generics, false) } @@ -308,7 +309,8 @@ impl Type { // Keep the alias name instead of expanding this in case the // alias' definition was changed let type_def = type_def.borrow(); - let generics = vecmap(generics, |generic| generic.to_display_ast()); + let ordered_args = vecmap(generics, |generic| generic.to_display_ast()); + let generics = GenericTypeArgs { ordered_args, named_args: Vec::new() }; let name = Path::from_ident(type_def.name.clone()); UnresolvedTypeData::Named(name, generics, false) } @@ -335,13 +337,17 @@ impl Type { } } Type::TraitAsType(_, name, generics) => { - let generics = vecmap(generics, |generic| generic.to_display_ast()); + let ordered_args = vecmap(&generics.ordered, |generic| generic.to_display_ast()); + let named_args = vecmap(&generics.named, |named_type| { + (named_type.name.clone(), named_type.typ.to_display_ast()) + }); + let generics = GenericTypeArgs { ordered_args, named_args }; let name = Path::from_single(name.as_ref().clone(), Span::default()); UnresolvedTypeData::TraitAsType(name, generics) } Type::NamedGeneric(_var, name, _kind) => { let name = Path::from_single(name.as_ref().clone(), Span::default()); - UnresolvedTypeData::TraitAsType(name, Vec::new()) + UnresolvedTypeData::Named(name, GenericTypeArgs::default(), true) } Type::Function(args, ret, env, unconstrained) => { let args = vecmap(args, |arg| arg.to_display_ast()); diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs index 14c217b2aa..e577904fb5 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin.rs @@ -5,20 +5,22 @@ use std::{ use acvm::{AcirField, FieldElement}; use builtin_helpers::{ - check_argument_count, check_function_not_yet_resolved, check_one_argument, - check_three_arguments, check_two_arguments, get_expr, get_function_def, get_module, get_quoted, - get_slice, get_struct, get_trait_constraint, get_trait_def, get_trait_impl, get_tuple, - get_type, get_u32, hir_pattern_to_tokens, mutate_func_meta_type, parse, parse_tokens, - replace_func_meta_parameters, replace_func_meta_return_type, + block_expression_to_value, check_argument_count, check_function_not_yet_resolved, + check_one_argument, check_three_arguments, check_two_arguments, get_expr, get_function_def, + get_module, get_quoted, get_slice, get_struct, get_trait_constraint, get_trait_def, + get_trait_impl, get_tuple, get_type, get_u32, get_unresolved_type, hir_pattern_to_tokens, + mutate_func_meta_type, parse, parse_tokens, replace_func_meta_parameters, + replace_func_meta_return_type, }; +use im::Vector; use iter_extended::{try_vecmap, vecmap}; use noirc_errors::Location; use rustc_hash::FxHashMap as HashMap; use crate::{ ast::{ - ArrayLiteral, ExpressionKind, FunctionKind, FunctionReturnType, IntegerBitSize, Literal, - UnaryOp, UnresolvedType, UnresolvedTypeData, Visibility, + ArrayLiteral, Expression, ExpressionKind, FunctionKind, FunctionReturnType, IntegerBitSize, + Literal, UnaryOp, UnresolvedType, UnresolvedTypeData, Visibility, }, hir_def::function::FunctionBody, macros_api::{ModuleDefId, NodeInterner, Signedness}, @@ -47,13 +49,17 @@ impl<'local, 'context> Interpreter<'local, 'context> { "array_len" => array_len(interner, arguments, location), "as_slice" => as_slice(interner, arguments, location), "expr_as_array" => expr_as_array(arguments, return_type, location), + "expr_as_assign" => expr_as_assign(arguments, return_type, location), "expr_as_binary_op" => expr_as_binary_op(arguments, return_type, location), "expr_as_bool" => expr_as_bool(arguments, return_type, location), + "expr_as_cast" => expr_as_cast(arguments, return_type, location), + "expr_as_comptime" => expr_as_comptime(arguments, return_type, location), "expr_as_function_call" => expr_as_function_call(arguments, return_type, location), "expr_as_if" => expr_as_if(arguments, return_type, location), "expr_as_index" => expr_as_index(arguments, return_type, location), "expr_as_integer" => expr_as_integer(arguments, return_type, location), "expr_as_member_access" => expr_as_member_access(arguments, return_type, location), + "expr_as_method_call" => expr_as_method_call(arguments, return_type, location), "expr_as_repeated_element_array" => { expr_as_repeated_element_array(arguments, return_type, location) } @@ -63,6 +69,9 @@ impl<'local, 'context> Interpreter<'local, 'context> { "expr_as_slice" => expr_as_slice(arguments, return_type, location), "expr_as_tuple" => expr_as_tuple(arguments, return_type, location), "expr_as_unary_op" => expr_as_unary_op(arguments, return_type, location), + "expr_as_unsafe" => expr_as_unsafe(arguments, return_type, location), + "expr_is_break" => expr_is_break(arguments, location), + "expr_is_continue" => expr_is_continue(arguments, location), "is_unconstrained" => Ok(Value::Bool(true)), "function_def_name" => function_def_name(interner, arguments, location), "function_def_parameters" => function_def_parameters(interner, arguments, location), @@ -119,6 +128,7 @@ impl<'local, 'context> Interpreter<'local, 'context> { "type_is_bool" => type_is_bool(arguments, location), "type_is_field" => type_is_field(arguments, location), "type_of" => type_of(arguments, location), + "unresolved_type_is_field" => unresolved_type_is_field(arguments, location), "zeroed" => zeroed(return_type), _ => { let item = format!("Comptime evaluation for builtin function {name}"); @@ -383,6 +393,7 @@ fn quoted_as_trait_constraint( elaborator.resolve_trait_bound(&trait_bound, Type::Unit) }) .ok_or(InterpreterError::FailedToResolveTraitBound { trait_bound, location })?; + Ok(Value::TraitConstraint(bound.trait_id, bound.trait_generics)) } @@ -541,7 +552,12 @@ fn type_get_trait_impl( let typ = get_type(typ)?; let (trait_id, generics) = get_trait_constraint(constraint)?; - let option_value = match interner.try_lookup_trait_implementation(&typ, trait_id, &generics) { + let option_value = match interner.try_lookup_trait_implementation( + &typ, + trait_id, + &generics.ordered, + &generics.named, + ) { Ok((TraitImplKind::Normal(trait_impl_id), _)) => Some(Value::TraitImpl(trait_impl_id)), _ => None, }; @@ -560,7 +576,9 @@ fn type_implements( let typ = get_type(typ)?; let (trait_id, generics) = get_trait_constraint(constraint)?; - let implements = interner.try_lookup_trait_implementation(&typ, trait_id, &generics).is_ok(); + let implements = interner + .try_lookup_trait_implementation(&typ, trait_id, &generics.ordered, &generics.named) + .is_ok(); Ok(Value::Bool(implements)) } @@ -684,6 +702,16 @@ fn trait_impl_trait_generic_args( Ok(Value::Slice(trait_generics, slice_type)) } +// fn is_field(self) -> bool +fn unresolved_type_is_field( + arguments: Vec<(Value, Location)>, + location: Location, +) -> IResult { + let self_argument = check_one_argument(arguments, location)?; + let typ = get_unresolved_type(self_argument)?; + Ok(Value::Bool(matches!(typ, UnresolvedTypeData::FieldElement))) +} + // fn zeroed() -> T fn zeroed(return_type: Type) -> IResult { match return_type { @@ -760,7 +788,7 @@ fn zeroed(return_type: Type) -> IResult { | Type::InfixExpr(..) | Type::Quoted(_) | Type::Error - | Type::TraitAsType(_, _, _) + | Type::TraitAsType(..) | Type::NamedGeneric(_, _, _) => Ok(Value::Zeroed(return_type)), } } @@ -782,6 +810,23 @@ fn expr_as_array( }) } +// fn as_assign(self) -> Option<(Expr, Expr)> +fn expr_as_assign( + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + expr_as(arguments, return_type, location, |expr| { + if let ExprValue::Statement(StatementKind::Assign(assign)) = expr { + let lhs = Value::lvalue(assign.lvalue); + let rhs = Value::expression(assign.expression.kind); + Some(Value::Tuple(vec![lhs, rhs])) + } else { + None + } + }) +} + // fn as_binary_op(self) -> Option<(Expr, BinaryOp, Expr)> fn expr_as_binary_op( arguments: Vec<(Value, Location)>, @@ -815,6 +860,21 @@ fn expr_as_binary_op( }) } +// fn as_block(self) -> Option<[Expr]> +fn expr_as_block( + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + expr_as(arguments, return_type, location, |expr| { + if let ExprValue::Expression(ExpressionKind::Block(block_expr)) = expr { + Some(block_expression_to_value(block_expr)) + } else { + None + } + }) +} + // fn as_bool(self) -> Option fn expr_as_bool( arguments: Vec<(Value, Location)>, @@ -830,6 +890,55 @@ fn expr_as_bool( }) } +// fn as_cast(self) -> Option<(Expr, UnresolvedType)> +fn expr_as_cast( + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + expr_as(arguments, return_type, location, |expr| { + if let ExprValue::Expression(ExpressionKind::Cast(cast)) = expr { + let lhs = Value::expression(cast.lhs.kind); + let typ = Value::UnresolvedType(cast.r#type.typ); + Some(Value::Tuple(vec![lhs, typ])) + } else { + None + } + }) +} + +// fn as_comptime(self) -> Option<[Expr]> +fn expr_as_comptime( + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + use ExpressionKind::Block; + + expr_as(arguments, return_type, location, |expr| { + if let ExprValue::Expression(ExpressionKind::Comptime(block_expr, _)) = expr { + Some(block_expression_to_value(block_expr)) + } else if let ExprValue::Statement(StatementKind::Comptime(statement)) = expr { + let typ = Type::Slice(Box::new(Type::Quoted(QuotedType::Expr))); + + // comptime { ... } as a statement wraps a block expression, + // and in that case we return the block expression statements + // (comptime as a statement can also be comptime for, but in that case we'll + // return the for statement as a single expression) + if let StatementKind::Expression(Expression { kind: Block(block), .. }) = statement.kind + { + Some(block_expression_to_value(block)) + } else { + let mut elements = Vector::new(); + elements.push_back(Value::statement(statement.kind)); + Some(Value::Slice(elements, typ)) + } + } else { + None + } + }) +} + // fn as_function_call(self) -> Option<(Expr, [Expr])> fn expr_as_function_call( arguments: Vec<(Value, Location)>, @@ -919,10 +1028,46 @@ fn expr_as_member_access( return_type: Type, location: Location, ) -> IResult { - expr_as(arguments, return_type, location, |expr| { - if let ExpressionKind::MemberAccess(member_access) = expr { + expr_as(arguments, return_type, location, |expr| match expr { + ExpressionKind::MemberAccess(member_access) => { let tokens = Rc::new(vec![Token::Ident(member_access.rhs.0.contents.clone())]); Some(Value::Tuple(vec![Value::Expr(member_access.lhs.kind), Value::Quoted(tokens)])) + } + ExprValue::LValue(crate::ast::LValue::MemberAccess { object, field_name, span: _ }) => { + let tokens = Rc::new(vec![Token::Ident(field_name.0.contents.clone())]); + Some(Value::Tuple(vec![Value::lvalue(*object), Value::Quoted(tokens)])) + } + _ => None, + }) +} + +// fn as_method_call(self) -> Option<(Expr, Quoted, [UnresolvedType], [Expr])> +fn expr_as_method_call( + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + expr_as(arguments, return_type, location, |expr| { + if let ExprValue::Expression(ExpressionKind::MethodCall(method_call)) = expr { + let object = Value::expression(method_call.object.kind); + + let name_tokens = + Rc::new(vec![Token::Ident(method_call.method_name.0.contents.clone())]); + let name = Value::Quoted(name_tokens); + + let generics = method_call.generics.unwrap_or_default().into_iter(); + let generics = generics.map(|generic| Value::UnresolvedType(generic.typ)).collect(); + let generics = Value::Slice( + generics, + Type::Slice(Box::new(Type::Quoted(QuotedType::UnresolvedType))), + ); + + let arguments = method_call.arguments.into_iter(); + let arguments = arguments.map(|argument| Value::expression(argument.kind)).collect(); + let arguments = + Value::Slice(arguments, Type::Slice(Box::new(Type::Quoted(QuotedType::Expr)))); + + Some(Value::Tuple(vec![object, name, generics, arguments])) } else { None } @@ -1038,6 +1183,35 @@ fn expr_as_unary_op( }) } +// fn as_unsafe(self) -> Option<[Expr]> +fn expr_as_unsafe( + arguments: Vec<(Value, Location)>, + return_type: Type, + location: Location, +) -> IResult { + expr_as(arguments, return_type, location, |expr| { + if let ExprValue::Expression(ExpressionKind::Unsafe(block_expr, _)) = expr { + Some(block_expression_to_value(block_expr)) + } else { + None + } + }) +} + +// fn is_break(self) -> bool +fn expr_is_break(arguments: Vec<(Value, Location)>, location: Location) -> IResult { + let self_argument = check_one_argument(arguments, location)?; + let expr_value = get_expr(self_argument)?; + Ok(Value::Bool(matches!(expr_value, ExprValue::Statement(StatementKind::Break)))) +} + +// fn is_continue(self) -> bool +fn expr_is_continue(arguments: Vec<(Value, Location)>, location: Location) -> IResult { + let self_argument = check_one_argument(arguments, location)?; + let expr_value = get_expr(self_argument)?; + Ok(Value::Bool(matches!(expr_value, ExprValue::Statement(StatementKind::Continue)))) +} + // Helper function for implementing the `expr_as_...` functions. fn expr_as( arguments: Vec<(Value, Location)>, @@ -1356,12 +1530,9 @@ fn trait_def_as_trait_constraint( let argument = check_one_argument(arguments, location)?; let trait_id = get_trait_def(argument)?; - let the_trait = interner.get_trait(trait_id); - let trait_generics = vecmap(&the_trait.generics, |generic| { - Type::NamedGeneric(generic.type_var.clone(), generic.name.clone(), generic.kind.clone()) - }); + let constraint = interner.get_trait(trait_id).as_constraint(location.span); - Ok(Value::TraitConstraint(trait_id, trait_generics)) + Ok(Value::TraitConstraint(trait_id, constraint.trait_generics)) } /// Creates a value that holds an `Option`. diff --git a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs index 81abc4e76f..a409731a5e 100644 --- a/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs +++ b/compiler/noirc_frontend/src/hir/comptime/interpreter/builtin/builtin_helpers.rs @@ -4,7 +4,7 @@ use acvm::FieldElement; use noirc_errors::Location; use crate::{ - ast::{IntegerBitSize, Signedness}, + ast::{BlockExpression, IntegerBitSize, Signedness, UnresolvedTypeData}, hir::{ comptime::{ errors::IResult, @@ -12,6 +12,7 @@ use crate::{ Interpreter, InterpreterError, Value, }, def_map::ModuleId, + type_check::generics::TraitGenerics, }, hir_def::{ function::{FuncMeta, FunctionBody}, @@ -171,7 +172,7 @@ pub(crate) fn get_struct((value, location): (Value, Location)) -> IResult IResult<(TraitId, Vec)> { +) -> IResult<(TraitId, TraitGenerics)> { match value { Value::TraitConstraint(trait_id, generics) => Ok((trait_id, generics)), value => type_mismatch(value, Type::Quoted(QuotedType::TraitConstraint), location), @@ -206,6 +207,15 @@ pub(crate) fn get_quoted((value, location): (Value, Location)) -> IResult IResult { + match value { + Value::UnresolvedType(typ) => Ok(typ), + value => type_mismatch(value, Type::Quoted(QuotedType::UnresolvedType), location), + } +} + fn type_mismatch(value: Value, expected: Type, location: Location) -> IResult { let actual = value.get_type().into_owned(); Err(InterpreterError::TypeMismatch { expected, actual, location }) @@ -350,3 +360,11 @@ pub(super) fn replace_func_meta_return_type(typ: &mut Type, return_type: Type) { _ => {} } } + +pub(super) fn block_expression_to_value(block_expr: BlockExpression) -> Value { + let typ = Type::Slice(Box::new(Type::Quoted(QuotedType::Expr))); + let statements = block_expr.statements.into_iter(); + let statements = statements.map(|statement| Value::statement(statement.kind)).collect(); + + Value::Slice(statements, typ) +} diff --git a/compiler/noirc_frontend/src/hir/comptime/value.rs b/compiler/noirc_frontend/src/hir/comptime/value.rs index 21957f3eb2..18f482585e 100644 --- a/compiler/noirc_frontend/src/hir/comptime/value.rs +++ b/compiler/noirc_frontend/src/hir/comptime/value.rs @@ -9,10 +9,10 @@ use strum_macros::Display; use crate::{ ast::{ - ArrayLiteral, BlockExpression, ConstructorExpression, Ident, IntegerBitSize, Signedness, - Statement, StatementKind, + ArrayLiteral, BlockExpression, ConstructorExpression, Ident, IntegerBitSize, LValue, + Signedness, Statement, StatementKind, UnresolvedTypeData, }, - hir::def_map::ModuleId, + hir::{def_map::ModuleId, type_check::generics::TraitGenerics}, hir_def::{ expr::{HirArrayLiteral, HirConstructorExpression, HirIdent, HirLambda, ImplKind}, traits::TraitConstraint, @@ -58,7 +58,7 @@ pub enum Value { /// be inserted into separate files entirely. Quoted(Rc>), StructDefinition(StructId), - TraitConstraint(TraitId, /* trait generics */ Vec), + TraitConstraint(TraitId, TraitGenerics), TraitDefinition(TraitId), TraitImpl(TraitImplId), FunctionDefinition(FuncId), @@ -66,12 +66,14 @@ pub enum Value { Type(Type), Zeroed(Type), Expr(ExprValue), + UnresolvedType(UnresolvedTypeData), } #[derive(Debug, Clone, PartialEq, Eq, Display)] pub enum ExprValue { Expression(ExpressionKind), Statement(StatementKind), + LValue(LValue), } impl Value { @@ -83,6 +85,10 @@ impl Value { Value::Expr(ExprValue::Statement(statement)) } + pub(crate) fn lvalue(lvaue: LValue) -> Self { + Value::Expr(ExprValue::LValue(lvaue)) + } + pub(crate) fn get_type(&self) -> Cow { Cow::Owned(match self { Value::Unit => Type::Unit, @@ -128,6 +134,7 @@ impl Value { Value::Type(_) => Type::Quoted(QuotedType::Type), Value::Zeroed(typ) => return Cow::Borrowed(typ), Value::Expr(_) => Type::Quoted(QuotedType::Expr), + Value::UnresolvedType(_) => Type::Quoted(QuotedType::UnresolvedType), }) } @@ -254,7 +261,8 @@ impl Value { statements: vec![Statement { kind: statement, span: location.span }], }) } - Value::Pointer(..) + Value::Expr(ExprValue::LValue(_)) + | Value::Pointer(..) | Value::StructDefinition(_) | Value::TraitConstraint(..) | Value::TraitDefinition(_) @@ -262,6 +270,7 @@ impl Value { | Value::FunctionDefinition(_) | Value::Zeroed(_) | Value::Type(_) + | Value::UnresolvedType(_) | Value::ModuleDefinition(_) => { let typ = self.get_type().into_owned(); let value = self.display(interner).to_string(); @@ -386,6 +395,7 @@ impl Value { | Value::FunctionDefinition(_) | Value::Zeroed(_) | Value::Type(_) + | Value::UnresolvedType(_) | Value::ModuleDefinition(_) => { let typ = self.get_type().into_owned(); let value = self.display(interner).to_string(); @@ -546,7 +556,8 @@ impl<'value, 'interner> Display for ValuePrinter<'value, 'interner> { write!(f, "{}", def.name) } Value::TraitConstraint(trait_id, generics) => { - write!(f, "{}", display_trait_id_and_generics(self.interner, trait_id, generics)) + let trait_ = self.interner.get_trait(*trait_id); + write!(f, "{}{generics}", trait_.name) } Value::TraitDefinition(trait_id) => { let trait_ = self.interner.get_trait(*trait_id); @@ -588,29 +599,13 @@ impl<'value, 'interner> Display for ValuePrinter<'value, 'interner> { Value::Type(typ) => write!(f, "{}", typ), Value::Expr(ExprValue::Expression(expr)) => write!(f, "{}", expr), Value::Expr(ExprValue::Statement(statement)) => write!(f, "{}", statement), + Value::Expr(ExprValue::LValue(lvalue)) => write!(f, "{}", lvalue), + Value::UnresolvedType(typ) => write!(f, "{}", typ), } } } -fn display_trait_id_and_generics( - interner: &NodeInterner, - trait_id: &TraitId, - generics: &Vec, -) -> String { - let trait_ = interner.get_trait(*trait_id); - let generic_string = vecmap(generics, ToString::to_string).join(", "); - if generics.is_empty() { - format!("{}", trait_.name) - } else { - format!("{}<{generic_string}>", trait_.name) - } -} - fn display_trait_constraint(interner: &NodeInterner, trait_constraint: &TraitConstraint) -> String { - let trait_constraint_string = display_trait_id_and_generics( - interner, - &trait_constraint.trait_id, - &trait_constraint.trait_generics, - ); - format!("{}: {}", trait_constraint.typ, trait_constraint_string) + let trait_ = interner.get_trait(trait_constraint.trait_id); + format!("{}: {}{}", trait_constraint.typ, trait_.name, trait_constraint.trait_generics) } diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs index 8c62eb431b..a961de628a 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs @@ -12,16 +12,16 @@ use crate::{Generics, Type}; use crate::hir::resolution::import::{resolve_import, ImportDirective, PathResolution}; use crate::hir::Context; -use crate::macros_api::{MacroError, MacroProcessor}; +use crate::macros_api::{Expression, MacroError, MacroProcessor}; use crate::node_interner::{ FuncId, GlobalId, ModuleAttributes, NodeInterner, ReferenceId, StructId, TraitId, TraitImplId, TypeAliasId, }; use crate::ast::{ - ExpressionKind, Ident, LetStatement, Literal, NoirFunction, NoirStruct, NoirTrait, - NoirTypeAlias, Path, PathKind, PathSegment, UnresolvedGenerics, UnresolvedTraitConstraint, - UnresolvedType, + ExpressionKind, GenericTypeArgs, Ident, LetStatement, Literal, NoirFunction, NoirStruct, + NoirTrait, NoirTypeAlias, Path, PathKind, PathSegment, UnresolvedGenerics, + UnresolvedTraitConstraint, UnresolvedType, }; use crate::parser::{ParserError, SortedModule}; @@ -75,13 +75,16 @@ pub struct UnresolvedTrait { pub struct UnresolvedTraitImpl { pub file_id: FileId, pub module_id: LocalModuleId, - pub trait_generics: Vec, + pub trait_generics: GenericTypeArgs, pub trait_path: Path, pub object_type: UnresolvedType, pub methods: UnresolvedFunctions, pub generics: UnresolvedGenerics, pub where_clause: Vec, + pub associated_types: Vec<(Ident, UnresolvedType)>, + pub associated_constants: Vec<(Ident, UnresolvedType, Expression)>, + // Every field after this line is filled in later in the elaborator pub trait_id: Option, pub impl_id: Option, diff --git a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs index 03ab9fa3a7..459c486937 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs @@ -1,4 +1,5 @@ use std::path::Path; +use std::rc::Rc; use std::vec; use acvm::{AcirField, FieldElement}; @@ -13,7 +14,8 @@ use crate::ast::{ NoirStruct, NoirTrait, NoirTraitImpl, NoirTypeAlias, Pattern, TraitImplItem, TraitItem, TypeImpl, }; -use crate::macros_api::NodeInterner; +use crate::hir::resolution::errors::ResolverError; +use crate::macros_api::{Expression, NodeInterner, UnresolvedType, UnresolvedTypeData}; use crate::node_interner::ModuleAttributes; use crate::{ graph::CrateId, @@ -22,6 +24,7 @@ use crate::{ node_interner::{FunctionModifiers, TraitId, TypeAliasId}, parser::{SortedModule, SortedSubModule}, }; +use crate::{Generics, Kind, ResolvedGeneric, Type, TypeVariable}; use super::{ dc_crate::{ @@ -162,13 +165,14 @@ impl<'a> ModCollector<'a> { for mut trait_impl in impls { let trait_name = trait_impl.trait_name.clone(); - let mut unresolved_functions = collect_trait_impl_functions( - &mut context.def_interner, - &mut trait_impl, - krate, - self.file_id, - self.module_id, - ); + let (mut unresolved_functions, associated_types, associated_constants) = + collect_trait_impl_items( + &mut context.def_interner, + &mut trait_impl, + krate, + self.file_id, + self.module_id, + ); let module = ModuleId { krate, local_id: self.module_id }; @@ -186,6 +190,8 @@ impl<'a> ModCollector<'a> { generics: trait_impl.impl_generics, where_clause: trait_impl.where_clause, trait_generics: trait_impl.trait_generics, + associated_constants, + associated_types, // These last fields are filled later on trait_id: None, @@ -461,6 +467,8 @@ impl<'a> ModCollector<'a> { }; let mut method_ids = HashMap::default(); + let mut associated_types = Generics::new(); + for trait_item in &trait_definition.items { match trait_item { TraitItem::Function { @@ -521,7 +529,7 @@ impl<'a> ModCollector<'a> { } } } - TraitItem::Constant { name, .. } => { + TraitItem::Constant { name, typ, default_value: _ } => { let global_id = context.def_interner.push_empty_global( name.clone(), trait_id.0.local_id, @@ -542,10 +550,19 @@ impl<'a> ModCollector<'a> { second_def, }; errors.push((error.into(), self.file_id)); + } else { + let type_variable_id = context.def_interner.next_type_variable_id(); + let typ = self.resolve_associated_constant_type(typ, &mut errors); + + associated_types.push(ResolvedGeneric { + name: Rc::new(name.to_string()), + type_var: TypeVariable::unbound(type_variable_id), + kind: Kind::Numeric(Box::new(typ)), + span: name.span(), + }); } } TraitItem::Type { name } => { - // TODO(nickysn or alexvitkov): implement context.def_interner.push_empty_type_alias and get an id, instead of using TypeAliasId::dummy_id() if let Err((first_def, second_def)) = self.def_collector.def_map.modules [trait_id.0.local_id.0] .declare_type_alias(name.clone(), TypeAliasId::dummy_id()) @@ -556,6 +573,14 @@ impl<'a> ModCollector<'a> { second_def, }; errors.push((error.into(), self.file_id)); + } else { + let type_variable_id = context.def_interner.next_type_variable_id(); + associated_types.push(ResolvedGeneric { + name: Rc::new(name.to_string()), + type_var: TypeVariable::unbound(type_variable_id), + kind: Kind::Normal, + span: name.span(), + }); } } } @@ -564,7 +589,6 @@ impl<'a> ModCollector<'a> { let resolved_generics = context.resolve_generics(&trait_definition.generics, &mut errors, self.file_id); - // And store the TraitId -> TraitType mapping somewhere it is reachable let unresolved = UnresolvedTrait { file_id: self.file_id, module_id: self.module_id, @@ -573,7 +597,12 @@ impl<'a> ModCollector<'a> { method_ids, fns_with_default_impl: unresolved_functions, }; - context.def_interner.push_empty_trait(trait_id, &unresolved, resolved_generics); + context.def_interner.push_empty_trait( + trait_id, + &unresolved, + resolved_generics, + associated_types, + ); if context.def_interner.is_in_lsp_mode() { let parent_module_id = ModuleId { krate, local_id: self.module_id }; @@ -782,6 +811,23 @@ impl<'a> ModCollector<'a> { Ok(mod_id) } + + fn resolve_associated_constant_type( + &self, + typ: &UnresolvedType, + errors: &mut Vec<(CompilationError, FileId)>, + ) -> Type { + match &typ.typ { + UnresolvedTypeData::FieldElement => Type::FieldElement, + UnresolvedTypeData::Integer(sign, bits) => Type::Integer(*sign, *bits), + _ => { + let span = typ.span; + let error = ResolverError::AssociatedConstantsMustBeNumeric { span }; + errors.push((error.into(), self.file_id)); + Type::Error + } + } + } } fn find_module( @@ -870,28 +916,43 @@ fn is_native_field(str: &str) -> bool { } } -pub(crate) fn collect_trait_impl_functions( +type AssociatedTypes = Vec<(Ident, UnresolvedType)>; +type AssociatedConstants = Vec<(Ident, UnresolvedType, Expression)>; + +/// Returns a tuple of (methods, associated types, associated constants) +pub(crate) fn collect_trait_impl_items( interner: &mut NodeInterner, trait_impl: &mut NoirTraitImpl, krate: CrateId, file_id: FileId, local_id: LocalModuleId, -) -> UnresolvedFunctions { +) -> (UnresolvedFunctions, AssociatedTypes, AssociatedConstants) { let mut unresolved_functions = UnresolvedFunctions { file_id, functions: Vec::new(), trait_id: None, self_type: None }; + let mut associated_types = Vec::new(); + let mut associated_constants = Vec::new(); + let module = ModuleId { krate, local_id }; for item in std::mem::take(&mut trait_impl.items) { - if let TraitImplItem::Function(impl_method) = item { - let func_id = interner.push_empty_fn(); - let location = Location::new(impl_method.span(), file_id); - interner.push_function(func_id, &impl_method.def, module, location); - unresolved_functions.push_fn(local_id, func_id, impl_method); + match item { + TraitImplItem::Function(impl_method) => { + let func_id = interner.push_empty_fn(); + let location = Location::new(impl_method.span(), file_id); + interner.push_function(func_id, &impl_method.def, module, location); + unresolved_functions.push_fn(local_id, func_id, impl_method); + } + TraitImplItem::Constant(name, typ, expr) => { + associated_constants.push((name, typ, expr)); + } + TraitImplItem::Type { name, alias } => { + associated_types.push((name, alias)); + } } } - unresolved_functions + (unresolved_functions, associated_types, associated_constants) } pub(crate) fn collect_global( diff --git a/compiler/noirc_frontend/src/hir/def_collector/errors.rs b/compiler/noirc_frontend/src/hir/def_collector/errors.rs index 9e9471c0cb..e705d7b6fa 100644 --- a/compiler/noirc_frontend/src/hir/def_collector/errors.rs +++ b/compiler/noirc_frontend/src/hir/def_collector/errors.rs @@ -1,5 +1,6 @@ use crate::ast::{Ident, Path, UnresolvedTypeData}; use crate::hir::resolution::import::PathResolutionError; +use crate::hir::type_check::generics::TraitGenerics; use noirc_errors::CustomDiagnostic as Diagnostic; use noirc_errors::FileDiagnostic; @@ -76,7 +77,7 @@ pub enum DefCollectorErrorKind { ImplIsStricterThanTrait { constraint_typ: crate::Type, constraint_name: String, - constraint_generics: Vec, + constraint_generics: TraitGenerics, constraint_span: Span, trait_method_name: String, trait_method_span: Span, @@ -280,18 +281,11 @@ impl<'a> From<&'a DefCollectorErrorKind> for Diagnostic { ) } DefCollectorErrorKind::ImplIsStricterThanTrait { constraint_typ, constraint_name, constraint_generics, constraint_span, trait_method_name, trait_method_span } => { - let mut constraint_name_with_generics = constraint_name.to_owned(); - if !constraint_generics.is_empty() { - constraint_name_with_generics.push('<'); - for generic in constraint_generics.iter() { - constraint_name_with_generics.push_str(generic.to_string().as_str()); - } - constraint_name_with_generics.push('>'); - } + let constraint = format!("{}{}", constraint_name, constraint_generics); let mut diag = Diagnostic::simple_error( "impl has stricter requirements than trait".to_string(), - format!("impl has extra requirement `{constraint_typ}: {constraint_name_with_generics}`"), + format!("impl has extra requirement `{constraint_typ}: {constraint}`"), *constraint_span, ); diag.add_secondary(format!("definition of `{trait_method_name}` from trait"), *trait_method_span); diff --git a/compiler/noirc_frontend/src/hir/resolution/errors.rs b/compiler/noirc_frontend/src/hir/resolution/errors.rs index cfaa2063c4..0aad50d13b 100644 --- a/compiler/noirc_frontend/src/hir/resolution/errors.rs +++ b/compiler/noirc_frontend/src/hir/resolution/errors.rs @@ -58,8 +58,8 @@ pub enum ResolverError { NonStructWithGenerics { span: Span }, #[error("Cannot apply generics on Self type")] GenericsOnSelfType { span: Span }, - #[error("Incorrect amount of arguments to {item_name}")] - IncorrectGenericCount { span: Span, item_name: String, actual: usize, expected: usize }, + #[error("Cannot apply generics on an associated type")] + GenericsOnAssociatedType { span: Span }, #[error("{0}")] ParserError(Box), #[error("Cannot create a mutable reference to {variable}, it was declared to be immutable")] @@ -116,6 +116,10 @@ pub enum ResolverError { NonFunctionInAnnotation { span: Span }, #[error("Type `{typ}` was inserted into the generics list from a macro, but is not a generic")] MacroResultInGenericsListNotAGeneric { span: Span, typ: Type }, + #[error("Named type arguments aren't allowed in a {item_kind}")] + NamedTypeArgs { span: Span, item_kind: &'static str }, + #[error("Associated constants may only be a field or integer type")] + AssociatedConstantsMustBeNumeric { span: Span }, } impl ResolverError { @@ -281,16 +285,11 @@ impl<'a> From<&'a ResolverError> for Diagnostic { "Use an explicit type name or apply the generics at the start of the impl instead".into(), *span, ), - ResolverError::IncorrectGenericCount { span, item_name, actual, expected } => { - let expected_plural = if *expected == 1 { "" } else { "s" }; - let actual_plural = if *actual == 1 { "is" } else { "are" }; - - Diagnostic::simple_error( - format!("`{item_name}` has {expected} generic argument{expected_plural} but {actual} {actual_plural} given here"), - "Incorrect number of generic arguments".into(), - *span, - ) - } + ResolverError::GenericsOnAssociatedType { span } => Diagnostic::simple_error( + "Generic Associated Types (GATs) are currently unsupported in Noir".into(), + "Cannot apply generics to an associated type".into(), + *span, + ), ResolverError::ParserError(error) => error.as_ref().into(), ResolverError::MutableReferenceToImmutableVariable { variable, span } => { Diagnostic::simple_error(format!("Cannot mutably reference the immutable variable {variable}"), format!("{variable} is immutable"), *span) @@ -467,6 +466,20 @@ impl<'a> From<&'a ResolverError> for Diagnostic { *span, ) } + ResolverError::NamedTypeArgs { span, item_kind } => { + Diagnostic::simple_error( + format!("Named type arguments aren't allowed on a {item_kind}"), + "Named type arguments are only allowed for associated types on traits".to_string(), + *span, + ) + } + ResolverError::AssociatedConstantsMustBeNumeric { span } => { + Diagnostic::simple_error( + "Associated constants may only be a field or integer type".to_string(), + "Only numeric constants are allowed".to_string(), + *span, + ) + } } } } diff --git a/compiler/noirc_frontend/src/hir/resolution/import.rs b/compiler/noirc_frontend/src/hir/resolution/import.rs index 761da6c361..b820e4664e 100644 --- a/compiler/noirc_frontend/src/hir/resolution/import.rs +++ b/compiler/noirc_frontend/src/hir/resolution/import.rs @@ -374,9 +374,9 @@ fn resolve_external_dep( resolve_path_to_ns(&dep_directive, dep_module.krate, importing_crate, def_maps, path_references) } -// Issue an error if the given private function is being called from a non-child module, or -// if the given pub(crate) function is being called from another crate -fn can_reference_module_id( +// Returns false if the given private function is being called from a non-child module, or +// if the given pub(crate) function is being called from another crate. Otherwise returns true. +pub fn can_reference_module_id( def_maps: &BTreeMap, importing_crate: CrateId, current_module: LocalModuleId, diff --git a/compiler/noirc_frontend/src/hir/type_check/errors.rs b/compiler/noirc_frontend/src/hir/type_check/errors.rs index de5d146713..1764284375 100644 --- a/compiler/noirc_frontend/src/hir/type_check/errors.rs +++ b/compiler/noirc_frontend/src/hir/type_check/errors.rs @@ -1,5 +1,6 @@ +use std::rc::Rc; + use acvm::FieldElement; -use iter_extended::vecmap; use noirc_errors::CustomDiagnostic as Diagnostic; use noirc_errors::Span; use thiserror::Error; @@ -9,6 +10,7 @@ use crate::hir::resolution::errors::ResolverError; use crate::hir_def::expr::HirBinaryOp; use crate::hir_def::traits::TraitConstraint; use crate::hir_def::types::Type; +use crate::macros_api::Ident; use crate::macros_api::NodeInterner; #[derive(Error, Debug, Clone, PartialEq, Eq)] @@ -102,8 +104,12 @@ pub enum TypeCheckError { second_type: String, second_index: usize, }, - #[error("Cannot infer type of expression, type annotations needed before this point")] - TypeAnnotationsNeeded { span: Span }, + #[error("Object type is unknown in method call")] + TypeAnnotationsNeededForMethodCall { span: Span }, + #[error("Object type is unknown in field access")] + TypeAnnotationsNeededForFieldAccess { span: Span }, + #[error("Multiple trait impls may apply to this object type")] + MultipleMatchingImpls { object_type: Type, candidates: Vec, span: Span }, #[error("use of deprecated function {name}")] CallDeprecated { name: String, note: Option, span: Span }, #[error("{0}")] @@ -158,6 +164,16 @@ pub enum TypeCheckError { MacroReturningNonExpr { typ: Type, span: Span }, #[error("turbofish (`::<_>`) usage at this position isn't supported yet")] UnsupportedTurbofishUsage { span: Span }, + #[error("`{name}` has already been specified")] + DuplicateNamedTypeArg { name: Ident, prev_span: Span }, + #[error("`{item}` has no associated type named `{name}`")] + NoSuchNamedTypeArg { name: Ident, item: String }, + #[error("`{item}` is missing the associated type `{name}`")] + MissingNamedTypeArg { name: Rc, item: String, span: Span }, + #[error("Internal compiler error: type unspecified for value")] + UnspecifiedType { span: Span }, + #[error("Binding `{typ}` here to the `_` inside would create a cyclic type")] + CyclicType { typ: Type, span: Span }, } #[derive(Debug, Clone, PartialEq, Eq)] @@ -278,11 +294,33 @@ impl<'a> From<&'a TypeCheckError> for Diagnostic { format!("return type is {typ}"), *span, ), - TypeCheckError::TypeAnnotationsNeeded { span } => Diagnostic::simple_error( - "Expression type is ambiguous".to_string(), - "Type must be known at this point".to_string(), - *span, - ), + TypeCheckError::TypeAnnotationsNeededForMethodCall { span } => { + let mut error = Diagnostic::simple_error( + "Object type is unknown in method call".to_string(), + "Type must be known by this point to know which method to call".to_string(), + *span, + ); + error.add_note("Try adding a type annotation for the object type before this method call".to_string()); + error + }, + TypeCheckError::TypeAnnotationsNeededForFieldAccess { span } => { + let mut error = Diagnostic::simple_error( + "Object type is unknown in field access".to_string(), + "Type must be known by this point".to_string(), + *span, + ); + error.add_note("Try adding a type annotation for the object type before this expression".to_string()); + error + }, + TypeCheckError::MultipleMatchingImpls { object_type, candidates, span } => { + let message = format!("Multiple trait impls match the object type `{object_type}`"); + let secondary = "Ambiguous impl".to_string(); + let mut error = Diagnostic::simple_error(message, secondary, *span); + for (i, candidate) in candidates.iter().enumerate() { + error.add_note(format!("Candidate {}: `{candidate}`", i + 1)); + } + error + }, TypeCheckError::ResolverError(error) => error.into(), TypeCheckError::TypeMismatchWithSource { expected, actual, span, source } => { let message = match source { @@ -360,12 +398,32 @@ impl<'a> From<&'a TypeCheckError> for Diagnostic { let msg = "turbofish (`::<_>`) usage at this position isn't supported yet"; Diagnostic::simple_error(msg.to_string(), "".to_string(), *span) }, + TypeCheckError::DuplicateNamedTypeArg { name, prev_span } => { + let msg = format!("`{name}` has already been specified"); + let mut error = Diagnostic::simple_error(msg.to_string(), "".to_string(), name.span()); + error.add_secondary(format!("`{name}` previously specified here"), *prev_span); + error + }, + TypeCheckError::NoSuchNamedTypeArg { name, item } => { + let msg = format!("`{item}` has no associated type named `{name}`"); + Diagnostic::simple_error(msg.to_string(), "".to_string(), name.span()) + }, + TypeCheckError::MissingNamedTypeArg { name, item, span } => { + let msg = format!("`{item}` is missing the associated type `{name}`"); + Diagnostic::simple_error(msg.to_string(), "".to_string(), *span) + }, TypeCheckError::Unsafe { span } => { Diagnostic::simple_warning(error.to_string(), String::new(), *span) } TypeCheckError::UnsafeFn { span } => { Diagnostic::simple_warning(error.to_string(), String::new(), *span) } + TypeCheckError::UnspecifiedType { span } => { + Diagnostic::simple_error(error.to_string(), String::new(), *span) + } + TypeCheckError::CyclicType { typ: _, span } => { + Diagnostic::simple_error(error.to_string(), "Cyclic types have unlimited size and are prohibited in Noir".into(), *span) + } } } } @@ -404,11 +462,7 @@ impl NoMatchingImplFoundError { .into_iter() .map(|constraint| { let r#trait = interner.try_get_trait(constraint.trait_id)?; - let mut name = r#trait.name.to_string(); - if !constraint.trait_generics.is_empty() { - let generics = vecmap(&constraint.trait_generics, ToString::to_string); - name += &format!("<{}>", generics.join(", ")); - } + let name = format!("{}{}", r#trait.name, constraint.trait_generics); Some((constraint.typ, name)) }) .collect::>>()?; diff --git a/compiler/noirc_frontend/src/hir/type_check/generics.rs b/compiler/noirc_frontend/src/hir/type_check/generics.rs new file mode 100644 index 0000000000..379c53944e --- /dev/null +++ b/compiler/noirc_frontend/src/hir/type_check/generics.rs @@ -0,0 +1,165 @@ +use std::cell::Ref; + +use iter_extended::vecmap; + +use crate::{ + hir_def::traits::NamedType, + macros_api::NodeInterner, + node_interner::{TraitId, TypeAliasId}, + ResolvedGeneric, StructType, Type, +}; + +/// Represents something that can be generic over type variables +/// such as a trait, struct type, or type alias. +/// +/// Used primarily by `Elaborator::resolve_type_args` so that we can +/// have one function to do this for struct types, type aliases, traits, etc. +pub trait Generic { + /// The name of this kind of item, for error messages. E.g. "trait", "struct type". + fn item_kind(&self) -> &'static str; + + /// The name of this item, usually named by a user. E.g. "Foo" for "struct Foo {}" + fn item_name(&self, interner: &NodeInterner) -> String; + + /// Each ordered generic on this type, excluding any named generics. + fn generics(&self, interner: &NodeInterner) -> Vec; + + /// True if this item kind can ever accept named type arguments. + /// Currently, this is only true for traits. Structs & aliases can never have named args. + fn accepts_named_type_args(&self) -> bool; + + fn named_generics(&self, interner: &NodeInterner) -> Vec; +} + +impl Generic for TraitId { + fn item_kind(&self) -> &'static str { + "trait" + } + + fn item_name(&self, interner: &NodeInterner) -> String { + interner.get_trait(*self).name.to_string() + } + + fn generics(&self, interner: &NodeInterner) -> Vec { + interner.get_trait(*self).generics.clone() + } + + fn accepts_named_type_args(&self) -> bool { + true + } + + fn named_generics(&self, interner: &NodeInterner) -> Vec { + interner.get_trait(*self).associated_types.clone() + } +} + +impl Generic for TypeAliasId { + fn item_kind(&self) -> &'static str { + "type alias" + } + + fn item_name(&self, interner: &NodeInterner) -> String { + interner.get_type_alias(*self).borrow().name.to_string() + } + + fn generics(&self, interner: &NodeInterner) -> Vec { + interner.get_type_alias(*self).borrow().generics.clone() + } + + fn accepts_named_type_args(&self) -> bool { + false + } + + fn named_generics(&self, _interner: &NodeInterner) -> Vec { + Vec::new() + } +} + +impl Generic for Ref<'_, StructType> { + fn item_kind(&self) -> &'static str { + "struct" + } + + fn item_name(&self, _interner: &NodeInterner) -> String { + self.name.to_string() + } + + fn generics(&self, _interner: &NodeInterner) -> Vec { + self.generics.clone() + } + + fn accepts_named_type_args(&self) -> bool { + false + } + + fn named_generics(&self, _interner: &NodeInterner) -> Vec { + Vec::new() + } +} + +/// TraitGenerics are different from regular generics in that they can +/// also contain associated type arguments. +#[derive(Default, PartialEq, Eq, Clone, Hash, Ord, PartialOrd)] +pub struct TraitGenerics { + pub ordered: Vec, + pub named: Vec, +} + +impl TraitGenerics { + pub fn map(&self, mut f: impl FnMut(&Type) -> Type) -> TraitGenerics { + let ordered = vecmap(&self.ordered, &mut f); + let named = + vecmap(&self.named, |named| NamedType { name: named.name.clone(), typ: f(&named.typ) }); + TraitGenerics { ordered, named } + } +} + +impl std::fmt::Display for TraitGenerics { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt_trait_generics(self, f, false) + } +} + +impl std::fmt::Debug for TraitGenerics { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fmt_trait_generics(self, f, true) + } +} + +fn fmt_trait_generics( + generics: &TraitGenerics, + f: &mut std::fmt::Formatter<'_>, + debug: bool, +) -> std::fmt::Result { + if !generics.ordered.is_empty() || !generics.named.is_empty() { + write!(f, "<")?; + for (i, typ) in generics.ordered.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + + if debug { + write!(f, "{typ:?}")?; + } else { + write!(f, "{typ}")?; + } + } + + if !generics.ordered.is_empty() && !generics.named.is_empty() { + write!(f, ", ")?; + } + + for (i, named) in generics.named.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + + if debug { + write!(f, "{} = {:?}", named.name, named.typ)?; + } else { + write!(f, "{} = {}", named.name, named.typ)?; + } + } + } + Ok(()) +} diff --git a/compiler/noirc_frontend/src/hir/type_check/mod.rs b/compiler/noirc_frontend/src/hir/type_check/mod.rs index b6efa17a52..f45b68dd81 100644 --- a/compiler/noirc_frontend/src/hir/type_check/mod.rs +++ b/compiler/noirc_frontend/src/hir/type_check/mod.rs @@ -1,12 +1,5 @@ -//! This file contains type_check_func, the entry point to the type checking pass (for each function). -//! -//! The pass structure of type checking is relatively straightforward. It is a single pass through -//! the HIR of each function and outputs the inferred type of each HIR node into the NodeInterner, -//! keyed by the ID of the node. -//! -//! Although this algorithm features inference via TypeVariables, there is no generalization step -//! as all functions are required to give their full signatures. Closures are inferred but are -//! never generalized and thus cannot be used polymorphically. mod errors; +pub mod generics; + pub use self::errors::Source; pub use errors::{NoMatchingImplFoundError, TypeCheckError}; diff --git a/compiler/noirc_frontend/src/hir_def/expr.rs b/compiler/noirc_frontend/src/hir_def/expr.rs index 8137e74da3..40c16d0035 100644 --- a/compiler/noirc_frontend/src/hir_def/expr.rs +++ b/compiler/noirc_frontend/src/hir_def/expr.rs @@ -3,6 +3,7 @@ use fm::FileId; use noirc_errors::Location; use crate::ast::{BinaryOp, BinaryOpKind, Ident, UnaryOp}; +use crate::hir::type_check::generics::TraitGenerics; use crate::node_interner::{DefinitionId, ExprId, FuncId, NodeInterner, StmtId, TraitMethodId}; use crate::token::Tokens; use crate::Shared; @@ -199,7 +200,7 @@ pub enum HirMethodReference { /// Or a method can come from a Trait impl block, in which case /// the actual function called will depend on the instantiated type, /// which can be only known during monomorphization. - TraitMethodId(TraitMethodId, /*trait generics:*/ Vec), + TraitMethodId(TraitMethodId, TraitGenerics), } impl HirMethodCallExpression { @@ -208,7 +209,7 @@ impl HirMethodCallExpression { /// Returns ((func_var_id, func_var), call_expr) pub fn into_function_call( mut self, - method: &HirMethodReference, + method: HirMethodReference, object_type: Type, is_macro_call: bool, location: Location, @@ -219,17 +220,17 @@ impl HirMethodCallExpression { let (id, impl_kind) = match method { HirMethodReference::FuncId(func_id) => { - (interner.function_definition_id(*func_id), ImplKind::NotATraitMethod) + (interner.function_definition_id(func_id), ImplKind::NotATraitMethod) } - HirMethodReference::TraitMethodId(method_id, generics) => { - let id = interner.trait_method_id(*method_id); + HirMethodReference::TraitMethodId(method_id, trait_generics) => { + let id = interner.trait_method_id(method_id); let constraint = TraitConstraint { typ: object_type, trait_id: method_id.trait_id, - trait_generics: generics.clone(), + trait_generics, span: location.span, }; - (id, ImplKind::TraitMethod(*method_id, constraint, false)) + (id, ImplKind::TraitMethod(method_id, constraint, false)) } }; let func_var = HirIdent { location, id, impl_kind }; diff --git a/compiler/noirc_frontend/src/hir_def/traits.rs b/compiler/noirc_frontend/src/hir_def/traits.rs index 9d820b9443..0572ba403a 100644 --- a/compiler/noirc_frontend/src/hir_def/traits.rs +++ b/compiler/noirc_frontend/src/hir_def/traits.rs @@ -1,12 +1,14 @@ +use iter_extended::vecmap; use rustc_hash::FxHashMap as HashMap; use crate::ast::{Ident, NoirFunction}; -use crate::TypeVariableId; +use crate::hir::type_check::generics::TraitGenerics; use crate::{ graph::CrateId, node_interner::{FuncId, TraitId, TraitMethodId}, Generics, Type, TypeBindings, TypeVariable, }; +use crate::{ResolvedGeneric, TypeVariableKind}; use fm::FileId; use noirc_errors::{Location, Span}; @@ -24,15 +26,20 @@ pub struct TraitFunction { #[derive(Clone, Debug, PartialEq, Eq)] pub struct TraitConstant { pub name: Ident, - pub ty: Type, + pub typ: Type, pub span: Span, } -#[derive(Clone, Debug, PartialEq, Eq)] -pub struct TraitType { +#[derive(Clone, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct NamedType { pub name: Ident, - pub ty: Type, - pub span: Span, + pub typ: Type, +} + +impl std::fmt::Display for NamedType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{} = {}", self.name, self.typ) + } } /// Represents a trait in the type system. Each instance of this struct @@ -54,8 +61,7 @@ pub struct Trait { /// the information needed to create the full TraitFunction. pub method_ids: HashMap, - pub constants: Vec, - pub types: Vec, + pub associated_types: Generics, pub name: Ident, pub generics: Generics, @@ -65,7 +71,6 @@ pub struct Trait { /// to this TypeVariable. Then when we check if the types of trait impl elements /// match the definition in the trait, we bind this TypeVariable to whatever /// the correct Self type is for that particular impl block. - pub self_type_typevar_id: TypeVariableId, pub self_type_typevar: TypeVariable, } @@ -74,7 +79,15 @@ pub struct TraitImpl { pub ident: Ident, pub typ: Type, pub trait_id: TraitId, + + /// Any ordered type arguments on the trait this impl is for. + /// E.g. `A, B` in `impl Foo for Bar` + /// + /// Note that named arguments (associated types) are stored separately + /// in the NodeInterner. This is because they're required to resolve types + /// before the impl as a whole is finished resolving. pub trait_generics: Vec, + pub file: FileId, pub methods: Vec, // methods[i] is the implementation of trait.methods[i] for Type typ @@ -89,21 +102,21 @@ pub struct TraitImpl { pub struct TraitConstraint { pub typ: Type, pub trait_id: TraitId, - pub trait_generics: Vec, + pub trait_generics: TraitGenerics, pub span: Span, } impl TraitConstraint { - pub fn new(typ: Type, trait_id: TraitId, trait_generics: Vec, span: Span) -> Self { - Self { typ, trait_id, trait_generics, span } - } - pub fn apply_bindings(&mut self, type_bindings: &TypeBindings) { self.typ = self.typ.substitute(type_bindings); - for typ in &mut self.trait_generics { + for typ in &mut self.trait_generics.ordered { *typ = typ.substitute(type_bindings); } + + for named in &mut self.trait_generics.named { + named.typ = named.typ.substitute(type_bindings); + } } } @@ -132,6 +145,35 @@ impl Trait { } None } + + pub fn get_associated_type(&self, last_name: &str) -> Option<&ResolvedGeneric> { + self.associated_types.iter().find(|typ| typ.name.as_ref() == last_name) + } + + /// Returns both the ordered generics of this type, and its named, associated types. + /// These types are all as-is and are not instantiated. + pub fn get_generics(&self) -> (Vec, Vec) { + let ordered = vecmap(&self.generics, |generic| generic.clone().as_named_generic()); + let named = vecmap(&self.associated_types, |generic| generic.clone().as_named_generic()); + (ordered, named) + } + + /// Returns a TraitConstraint for this trait using Self as the object + /// type and the uninstantiated generics for any trait generics. + pub fn as_constraint(&self, span: Span) -> TraitConstraint { + let ordered = vecmap(&self.generics, |generic| generic.clone().as_named_generic()); + let named = vecmap(&self.associated_types, |generic| { + let name = Ident::new(generic.name.to_string(), span); + NamedType { name, typ: generic.clone().as_named_generic() } + }); + + TraitConstraint { + typ: Type::TypeVariable(self.self_type_typevar.clone(), TypeVariableKind::Normal), + trait_generics: TraitGenerics { ordered, named }, + trait_id: self.id, + span, + } + } } impl std::fmt::Display for Trait { diff --git a/compiler/noirc_frontend/src/hir_def/types.rs b/compiler/noirc_frontend/src/hir_def/types.rs index d6d114c707..807666f9af 100644 --- a/compiler/noirc_frontend/src/hir_def/types.rs +++ b/compiler/noirc_frontend/src/hir_def/types.rs @@ -7,7 +7,7 @@ use std::{ use crate::{ ast::IntegerBitSize, - hir::type_check::TypeCheckError, + hir::type_check::{generics::TraitGenerics, TypeCheckError}, node_interner::{ExprId, NodeInterner, TraitId, TypeAliasId}, }; use iter_extended::vecmap; @@ -19,7 +19,10 @@ use crate::{ node_interner::StructId, }; -use super::expr::{HirCallExpression, HirExpression, HirIdent}; +use super::{ + expr::{HirCallExpression, HirExpression, HirIdent}, + traits::NamedType, +}; #[derive(PartialEq, Eq, Clone, Hash, Ord, PartialOrd)] pub enum Type { @@ -78,7 +81,7 @@ pub enum Type { /// `impl Trait` when used in a type position. /// These are only matched based on the TraitId. The trait name parameter is only /// used for displaying error messages using the name of the trait. - TraitAsType(TraitId, /*name:*/ Rc, /*generics:*/ Vec), + TraitAsType(TraitId, Rc, TraitGenerics), /// NamedGenerics are the 'T' or 'U' in a user-defined generic function /// like `fn foo(...) {}`. Unlike TypeVariables, they cannot be bound over. @@ -152,6 +155,7 @@ pub enum QuotedType { TraitConstraint, TraitDefinition, TraitImpl, + UnresolvedType, FunctionDefinition, Module, } @@ -540,7 +544,7 @@ impl TypeVariable { }; if binding.occurs(id) { - Err(TypeCheckError::TypeAnnotationsNeeded { span }) + Err(TypeCheckError::CyclicType { span, typ: binding }) } else { *self.1.borrow_mut() = TypeBinding::Bound(binding); Ok(()) @@ -646,12 +650,7 @@ impl std::fmt::Display for Type { } } Type::TraitAsType(_id, name, generics) => { - write!(f, "impl {}", name)?; - if !generics.is_empty() { - let generics = vecmap(generics, ToString::to_string).join(", "); - write!(f, "<{generics}>")?; - } - Ok(()) + write!(f, "impl {}{}", name, generics) } Type::Tuple(elements) => { let elements = vecmap(elements, ToString::to_string); @@ -744,6 +743,7 @@ impl std::fmt::Display for QuotedType { QuotedType::TraitDefinition => write!(f, "TraitDefinition"), QuotedType::TraitConstraint => write!(f, "TraitConstraint"), QuotedType::TraitImpl => write!(f, "TraitImpl"), + QuotedType::UnresolvedType => write!(f, "UnresolvedType"), QuotedType::FunctionDefinition => write!(f, "FunctionDefinition"), QuotedType::Module => write!(f, "Module"), } @@ -860,8 +860,9 @@ impl Type { | Type::Forall(_, _) | Type::Quoted(_) => false, - Type::TraitAsType(_, _, args) => { - args.iter().any(|generic| generic.contains_numeric_typevar(target_id)) + Type::TraitAsType(_, _, generics) => { + generics.ordered.iter().any(|generic| generic.contains_numeric_typevar(target_id)) + || generics.named.iter().any(|typ| typ.typ.contains_numeric_typevar(target_id)) } Type::Array(length, elem) => { elem.contains_numeric_typevar(target_id) || named_generic_id_matches_target(length) @@ -936,9 +937,12 @@ impl Type { } Type::TraitAsType(_, _, args) => { - for arg in args.iter() { + for arg in args.ordered.iter() { arg.find_numeric_type_vars(found_names); } + for arg in args.named.iter() { + arg.typ.find_numeric_type_vars(found_names); + } } Type::Array(length, elem) => { elem.find_numeric_type_vars(found_names); @@ -1629,6 +1633,15 @@ impl Type { } else { Err(UnificationError) } + } else if let InfixExpr(lhs, op, rhs) = other { + if let Some(inverse) = op.inverse() { + // Handle cases like `4 = a + b` by trying to solve to `a = 4 - b` + let new_type = InfixExpr(Box::new(Constant(*value)), inverse, rhs.clone()); + new_type.try_unify(lhs, bindings)?; + Ok(()) + } else { + Err(UnificationError) + } } else { Err(UnificationError) } @@ -1656,12 +1669,17 @@ impl Type { pub fn canonicalize(&self) -> Type { match self.follow_bindings() { Type::InfixExpr(lhs, op, rhs) => { - if let Some(value) = self.evaluate_to_u32() { - return Type::Constant(value); + // evaluate_to_u32 also calls canonicalize so if we just called + // `self.evaluate_to_u32()` we'd get infinite recursion. + if let (Some(lhs), Some(rhs)) = (lhs.evaluate_to_u32(), rhs.evaluate_to_u32()) { + return Type::Constant(op.function(lhs, rhs)); } let lhs = lhs.canonicalize(); let rhs = rhs.canonicalize(); + if let Some(result) = Self::try_simplify_addition(&lhs, op, &rhs) { + return result; + } if let Some(result) = Self::try_simplify_subtraction(&lhs, op, &rhs) { return result; @@ -1719,6 +1737,26 @@ impl Type { } } + /// Try to simplify an addition expression of `lhs + rhs`. + /// + /// - Simplifies `(a - b) + b` to `a`. + fn try_simplify_addition(lhs: &Type, op: BinaryTypeOperator, rhs: &Type) -> Option { + use BinaryTypeOperator::*; + match lhs { + Type::InfixExpr(l_lhs, l_op, l_rhs) => { + if op == Addition && *l_op == Subtraction { + // TODO: Propagate type bindings. Can do in another PR, this one is large enough. + let unifies = l_rhs.try_unify(rhs, &mut TypeBindings::new()); + if unifies.is_ok() { + return Some(l_lhs.as_ref().clone()); + } + } + None + } + _ => None, + } + } + /// Try to simplify a subtraction expression of `lhs - rhs`. /// /// - Simplifies `(a + C1) - C2` to `a + (C1 - C2)` if C1 and C2 are constants. @@ -1881,10 +1919,10 @@ impl Type { } } - match self { - Type::TypeVariable(_, TypeVariableKind::Constant(size)) => Some(*size), + match self.canonicalize() { + Type::TypeVariable(_, TypeVariableKind::Constant(size)) => Some(size), Type::Array(len, _elem) => len.evaluate_to_u32(), - Type::Constant(x) => Some(*x), + Type::Constant(x) => Some(x), Type::InfixExpr(lhs, op, rhs) => { let lhs = lhs.evaluate_to_u32()?; let rhs = rhs.evaluate_to_u32()?; @@ -1919,11 +1957,13 @@ impl Type { /// Retrieves the type of the given field name /// Panics if the type is not a struct or tuple. pub fn get_field_type(&self, field_name: &str) -> Option { - match self { - Type::Struct(def, args) => def.borrow().get_field(field_name, args).map(|(typ, _)| typ), + match self.follow_bindings() { + Type::Struct(def, args) => { + def.borrow().get_field(field_name, &args).map(|(typ, _)| typ) + } Type::Tuple(fields) => { - let mut fields = fields.iter().enumerate(); - fields.find(|(i, _)| i.to_string() == *field_name).map(|(_, typ)| typ).cloned() + let mut fields = fields.into_iter().enumerate(); + fields.find(|(i, _)| i.to_string() == *field_name).map(|(_, typ)| typ) } _ => None, } @@ -2145,11 +2185,15 @@ impl Type { element.substitute_helper(type_bindings, substitute_bound_typevars), )), - Type::TraitAsType(s, name, args) => { - let args = vecmap(args, |arg| { + Type::TraitAsType(s, name, generics) => { + let ordered = vecmap(&generics.ordered, |arg| { arg.substitute_helper(type_bindings, substitute_bound_typevars) }); - Type::TraitAsType(*s, name.clone(), args) + let named = vecmap(&generics.named, |arg| { + let typ = arg.typ.substitute_helper(type_bindings, substitute_bound_typevars); + NamedType { name: arg.name.clone(), typ } + }); + Type::TraitAsType(*s, name.clone(), TraitGenerics { ordered, named }) } Type::InfixExpr(lhs, op, rhs) => { let lhs = lhs.substitute_helper(type_bindings, substitute_bound_typevars); @@ -2178,11 +2222,13 @@ impl Type { let field_occurs = fields.occurs(target_id); len_occurs || field_occurs } - Type::Struct(_, generic_args) - | Type::Alias(_, generic_args) - | Type::TraitAsType(_, _, generic_args) => { + Type::Struct(_, generic_args) | Type::Alias(_, generic_args) => { generic_args.iter().any(|arg| arg.occurs(target_id)) } + Type::TraitAsType(_, _, args) => { + args.ordered.iter().any(|arg| arg.occurs(target_id)) + || args.named.iter().any(|arg| arg.typ.occurs(target_id)) + } Type::Tuple(fields) => fields.iter().any(|field| field.occurs(target_id)), Type::NamedGeneric(type_var, _, _) | Type::TypeVariable(type_var, _) => { match &*type_var.borrow() { @@ -2259,8 +2305,12 @@ impl Type { MutableReference(element) => MutableReference(Box::new(element.follow_bindings())), TraitAsType(s, name, args) => { - let args = vecmap(args, |arg| arg.follow_bindings()); - TraitAsType(*s, name.clone(), args) + let ordered = vecmap(&args.ordered, |arg| arg.follow_bindings()); + let named = vecmap(&args.named, |arg| NamedType { + name: arg.name.clone(), + typ: arg.typ.follow_bindings(), + }); + TraitAsType(*s, name.clone(), TraitGenerics { ordered, named }) } InfixExpr(lhs, op, rhs) => { let lhs = lhs.follow_bindings(); @@ -2329,9 +2379,12 @@ impl Type { } } Type::TraitAsType(_, _, generics) => { - for generic in generics { + for generic in &mut generics.ordered { generic.replace_named_generics_with_type_variables(); } + for generic in &mut generics.named { + generic.typ.replace_named_generics_with_type_variables(); + } } Type::NamedGeneric(var, _, _) => { let type_binding = var.borrow(); @@ -2417,6 +2470,17 @@ impl BinaryTypeOperator { fn is_commutative(self) -> bool { matches!(self, BinaryTypeOperator::Addition | BinaryTypeOperator::Multiplication) } + + /// Return the operator that will "undo" this operation if applied to the rhs + fn inverse(self) -> Option { + match self { + BinaryTypeOperator::Addition => Some(BinaryTypeOperator::Subtraction), + BinaryTypeOperator::Subtraction => Some(BinaryTypeOperator::Addition), + BinaryTypeOperator::Multiplication => Some(BinaryTypeOperator::Division), + BinaryTypeOperator::Division => Some(BinaryTypeOperator::Multiplication), + BinaryTypeOperator::Modulo => None, + } + } } impl TypeVariableKind { @@ -2485,7 +2549,7 @@ impl From<&Type> for PrintableType { PrintableType::Struct { fields, name: struct_type.name.to_string() } } Type::Alias(alias, args) => alias.borrow().get_type(args).into(), - Type::TraitAsType(_, _, _) => unreachable!(), + Type::TraitAsType(..) => unreachable!(), Type::Tuple(types) => PrintableType::Tuple { types: vecmap(types, |typ| typ.into()) }, Type::TypeVariable(_, _) => unreachable!(), Type::NamedGeneric(..) => unreachable!(), @@ -2547,14 +2611,7 @@ impl std::fmt::Debug for Type { write!(f, "{}<{}>", alias.borrow(), args.join(", ")) } } - Type::TraitAsType(_id, name, generics) => { - write!(f, "impl {}", name)?; - if !generics.is_empty() { - let generics = vecmap(generics, |arg| format!("{:?}", arg)).join(", "); - write!(f, "<{generics}>")?; - } - Ok(()) - } + Type::TraitAsType(_id, name, generics) => write!(f, "impl {}{:?}", name, generics), Type::Tuple(elements) => { let elements = vecmap(elements, |arg| format!("{:?}", arg)); write!(f, "({})", elements.join(", ")) diff --git a/compiler/noirc_frontend/src/lexer/token.rs b/compiler/noirc_frontend/src/lexer/token.rs index b9b4bdef9a..8ee0fca295 100644 --- a/compiler/noirc_frontend/src/lexer/token.rs +++ b/compiler/noirc_frontend/src/lexer/token.rs @@ -930,6 +930,7 @@ pub enum Keyword { TypeType, Unchecked, Unconstrained, + UnresolvedType, Unsafe, Use, Where, @@ -984,6 +985,7 @@ impl fmt::Display for Keyword { Keyword::TypeType => write!(f, "Type"), Keyword::Unchecked => write!(f, "unchecked"), Keyword::Unconstrained => write!(f, "unconstrained"), + Keyword::UnresolvedType => write!(f, "UnresolvedType"), Keyword::Unsafe => write!(f, "unsafe"), Keyword::Use => write!(f, "use"), Keyword::Where => write!(f, "where"), @@ -1041,6 +1043,7 @@ impl Keyword { "StructDefinition" => Keyword::StructDefinition, "unchecked" => Keyword::Unchecked, "unconstrained" => Keyword::Unconstrained, + "UnresolvedType" => Keyword::UnresolvedType, "unsafe" => Keyword::Unsafe, "use" => Keyword::Use, "where" => Keyword::Where, diff --git a/compiler/noirc_frontend/src/monomorphization/ast.rs b/compiler/noirc_frontend/src/monomorphization/ast.rs index f7bcfd58b7..eb6b4bf7bd 100644 --- a/compiler/noirc_frontend/src/monomorphization/ast.rs +++ b/compiler/noirc_frontend/src/monomorphization/ast.rs @@ -1,3 +1,5 @@ +use std::fmt::Display; + use acvm::FieldElement; use iter_extended::vecmap; use noirc_errors::{ @@ -67,6 +69,12 @@ pub struct LocalId(pub u32); #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct FuncId(pub u32); +impl Display for FuncId { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + #[derive(Debug, Clone, Hash)] pub struct Ident { pub location: Option, @@ -357,16 +365,12 @@ impl Program { FuncId(0) } - pub fn take_main_body(&mut self) -> Expression { - self.take_function_body(FuncId(0)) - } - /// Takes a function body by replacing it with `false` and /// returning the previous value pub fn take_function_body(&mut self, function: FuncId) -> Expression { - let main = &mut self.functions[function.0 as usize]; - let replacement = Expression::Literal(Literal::Bool(false)); - std::mem::replace(&mut main.body, replacement) + let function_definition = &mut self[function]; + let replacement = Expression::Block(vec![]); + std::mem::replace(&mut function_definition.body, replacement) } } diff --git a/compiler/noirc_frontend/src/monomorphization/errors.rs b/compiler/noirc_frontend/src/monomorphization/errors.rs index df61c138c0..665bf26f7b 100644 --- a/compiler/noirc_frontend/src/monomorphization/errors.rs +++ b/compiler/noirc_frontend/src/monomorphization/errors.rs @@ -1,11 +1,11 @@ use noirc_errors::{CustomDiagnostic, FileDiagnostic, Location}; -use crate::hir::comptime::InterpreterError; +use crate::{hir::comptime::InterpreterError, Type}; #[derive(Debug)] pub enum MonomorphizationError { - UnknownArrayLength { location: Location }, - TypeAnnotationsNeeded { location: Location }, + UnknownArrayLength { length: Type, location: Location }, + NoDefaultType { location: Location }, InternalError { message: &'static str, location: Location }, InterpreterError(InterpreterError), } @@ -13,9 +13,9 @@ pub enum MonomorphizationError { impl MonomorphizationError { fn location(&self) -> Location { match self { - MonomorphizationError::UnknownArrayLength { location } + MonomorphizationError::UnknownArrayLength { location, .. } | MonomorphizationError::InternalError { location, .. } - | MonomorphizationError::TypeAnnotationsNeeded { location } => *location, + | MonomorphizationError::NoDefaultType { location, .. } => *location, MonomorphizationError::InterpreterError(error) => error.get_location(), } } @@ -32,16 +32,20 @@ impl From for FileDiagnostic { impl MonomorphizationError { fn into_diagnostic(self) -> CustomDiagnostic { - let message = match self { - MonomorphizationError::UnknownArrayLength { .. } => { - "Length of generic array could not be determined." + let message = match &self { + MonomorphizationError::UnknownArrayLength { length, .. } => { + format!("ICE: Could not determine array length `{length}`") } - MonomorphizationError::TypeAnnotationsNeeded { .. } => "Type annotations needed", - MonomorphizationError::InterpreterError(error) => return (&error).into(), - MonomorphizationError::InternalError { message, .. } => message, + MonomorphizationError::NoDefaultType { location } => { + let message = "Type annotation needed".into(); + let secondary = "Could not determine type of generic argument".into(); + return CustomDiagnostic::simple_error(message, secondary, location.span); + } + MonomorphizationError::InterpreterError(error) => return error.into(), + MonomorphizationError::InternalError { message, .. } => message.to_string(), }; let location = self.location(); - CustomDiagnostic::simple_error(message.into(), String::new(), location.span) + CustomDiagnostic::simple_error(message, String::new(), location.span) } } diff --git a/compiler/noirc_frontend/src/monomorphization/mod.rs b/compiler/noirc_frontend/src/monomorphization/mod.rs index 510b81d9ac..edb831b215 100644 --- a/compiler/noirc_frontend/src/monomorphization/mod.rs +++ b/compiler/noirc_frontend/src/monomorphization/mod.rs @@ -11,7 +11,7 @@ use crate::ast::{FunctionKind, IntegerBitSize, Signedness, UnaryOp, Visibility}; use crate::hir::comptime::InterpreterError; use crate::hir::type_check::NoMatchingImplFoundError; -use crate::node_interner::ExprId; +use crate::node_interner::{ExprId, ImplSearchErrorKind}; use crate::{ debug::DebugInstrumenter, hir_def::{ @@ -569,7 +569,7 @@ impl<'interner> Monomorphizer<'interner> { let length = length.evaluate_to_u32().ok_or_else(|| { let location = self.interner.expr_location(&array); - MonomorphizationError::UnknownArrayLength { location } + MonomorphizationError::UnknownArrayLength { location, length } })?; let contents = try_vecmap(0..length, |_| self.expr(repeated_element))?; @@ -936,7 +936,10 @@ impl<'interner> Monomorphizer<'interner> { let element = Box::new(Self::convert_type(element.as_ref(), location)?); let length = match length.evaluate_to_u32() { Some(length) => length, - None => return Err(MonomorphizationError::TypeAnnotationsNeeded { location }), + None => { + let length = length.as_ref().clone(); + return Err(MonomorphizationError::UnknownArrayLength { location, length }); + } }; ast::Type::Array(length, element) } @@ -969,7 +972,7 @@ impl<'interner> Monomorphizer<'interner> { // and within a larger generic type. let default = match kind.default_type() { Some(typ) => typ, - None => return Err(MonomorphizationError::TypeAnnotationsNeeded { location }), + None => return Err(MonomorphizationError::NoDefaultType { location }), }; let monomorphized_default = Self::convert_type(&default, location)?; @@ -1029,11 +1032,12 @@ impl<'interner> Monomorphizer<'interner> { ast::Type::MutableReference(Box::new(element)) } - HirType::Forall(_, _) - | HirType::Constant(_) - | HirType::InfixExpr(..) - | HirType::Error => { - unreachable!("Unexpected type {} found", typ) + HirType::Forall(_, _) | HirType::Constant(_) | HirType::InfixExpr(..) => { + unreachable!("Unexpected type {typ} found") + } + HirType::Error => { + let message = "Unexpected Type::Error found during monomorphization"; + return Err(MonomorphizationError::InternalError { message, location }); } HirType::Quoted(_) => unreachable!("Tried to translate Code type into runtime code"), }) @@ -1073,7 +1077,7 @@ impl<'interner> Monomorphizer<'interner> { // and within a larger generic type. let default = match kind.default_type() { Some(typ) => typ, - None => return Err(MonomorphizationError::TypeAnnotationsNeeded { location }), + None => return Err(MonomorphizationError::NoDefaultType { location }), }; Self::check_type(&default, location) @@ -1945,28 +1949,37 @@ pub fn resolve_trait_method( let impl_id = match trait_impl { TraitImplKind::Normal(impl_id) => impl_id, TraitImplKind::Assumed { object_type, trait_generics } => { + let location = interner.expr_location(&expr_id); match interner.lookup_trait_implementation( &object_type, method.trait_id, - &trait_generics, + &trait_generics.ordered, + &trait_generics.named, ) { Ok(TraitImplKind::Normal(impl_id)) => impl_id, Ok(TraitImplKind::Assumed { .. }) => { - let location = interner.expr_location(&expr_id); return Err(InterpreterError::NoImpl { location }); } - Err(constraints) => { - let location = interner.expr_location(&expr_id); + Err(ImplSearchErrorKind::TypeAnnotationsNeededOnObjectType) => { + return Err(InterpreterError::TypeAnnotationsNeededForMethodCall { location }); + } + Err(ImplSearchErrorKind::Nested(constraints)) => { if let Some(error) = NoMatchingImplFoundError::new(interner, constraints, location.span) { let file = location.file; return Err(InterpreterError::NoMatchingImplFound { error, file }); } else { - let location = interner.expr_location(&expr_id); return Err(InterpreterError::NoImpl { location }); } } + Err(ImplSearchErrorKind::MultipleMatching(candidates)) => { + return Err(InterpreterError::MultipleMatchingImpls { + object_type, + location, + candidates, + }); + } } } }; diff --git a/compiler/noirc_frontend/src/monomorphization/printer.rs b/compiler/noirc_frontend/src/monomorphization/printer.rs index 9b1eeecdc1..b6421b26a0 100644 --- a/compiler/noirc_frontend/src/monomorphization/printer.rs +++ b/compiler/noirc_frontend/src/monomorphization/printer.rs @@ -19,7 +19,7 @@ impl AstPrinter { write!( f, "fn {}$f{}({}) -> {} {{", - function.name, function.id.0, params, function.return_type + function.name, function.id, params, function.return_type )?; self.indent_level += 1; self.print_expr_expect_block(&function.body, f)?; @@ -291,7 +291,7 @@ impl Display for Definition { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { match self { Definition::Local(id) => write!(f, "l{}", id.0), - Definition::Function(id) => write!(f, "f{}", id.0), + Definition::Function(id) => write!(f, "f{}", id), Definition::Builtin(name) => write!(f, "{name}"), Definition::LowLevel(name) => write!(f, "{name}"), Definition::Oracle(name) => write!(f, "{name}"), diff --git a/compiler/noirc_frontend/src/node_interner.rs b/compiler/noirc_frontend/src/node_interner.rs index f9c6921ac8..2c0426f693 100644 --- a/compiler/noirc_frontend/src/node_interner.rs +++ b/compiler/noirc_frontend/src/node_interner.rs @@ -19,6 +19,8 @@ use crate::hir::comptime; use crate::hir::def_collector::dc_crate::CompilationError; use crate::hir::def_collector::dc_crate::{UnresolvedStruct, UnresolvedTrait, UnresolvedTypeAlias}; use crate::hir::def_map::{LocalModuleId, ModuleId}; +use crate::hir::type_check::generics::TraitGenerics; +use crate::hir_def::traits::NamedType; use crate::macros_api::ModuleDefId; use crate::macros_api::UnaryOp; use crate::QuotedType; @@ -133,6 +135,11 @@ pub struct NodeInterner { next_trait_implementation_id: usize, + /// The associated types for each trait impl. + /// This is stored outside of the TraitImpl object since it is required before that object is + /// created, when resolving the type signature of each method in the impl. + trait_impl_associated_types: HashMap>, + /// Trait implementations on each type. This is expected to always have the same length as /// `self.trait_implementations`. /// @@ -307,10 +314,17 @@ pub enum TraitImplKind { /// /// The reference `Into::into(x)` would have inferred generics, but /// `x.into()` with a `X: Into` in scope would not. - trait_generics: Vec, + trait_generics: TraitGenerics, }, } +/// When searching for a trait impl, these are the types of errors we can expect +pub enum ImplSearchErrorKind { + TypeAnnotationsNeededOnObjectType, + Nested(Vec), + MultipleMatching(Vec), +} + /// Represents the methods on a given type that each share the same name. /// /// Methods are split into inherent methods and trait methods. If there is @@ -610,6 +624,7 @@ impl Default for NodeInterner { reference_modules: HashMap::default(), auto_import_names: HashMap::default(), comptime_scopes: vec![HashMap::default()], + trait_impl_associated_types: HashMap::default(), } } } @@ -651,21 +666,18 @@ impl NodeInterner { type_id: TraitId, unresolved_trait: &UnresolvedTrait, generics: Generics, + associated_types: Generics, ) { - let self_type_typevar_id = self.next_type_variable_id(); - let new_trait = Trait { id: type_id, name: unresolved_trait.trait_def.name.clone(), crate_id: unresolved_trait.crate_id, location: Location::new(unresolved_trait.trait_def.span, unresolved_trait.file_id), generics, - self_type_typevar_id, - self_type_typevar: TypeVariable::unbound(self_type_typevar_id), + self_type_typevar: TypeVariable::unbound(self.next_type_variable_id()), methods: Vec::new(), method_ids: unresolved_trait.method_ids.clone(), - constants: Vec::new(), - types: Vec::new(), + associated_types, }; self.traits.insert(type_id, new_trait); @@ -1377,9 +1389,14 @@ impl NodeInterner { object_type: &Type, trait_id: TraitId, trait_generics: &[Type], - ) -> Result> { - let (impl_kind, bindings) = - self.try_lookup_trait_implementation(object_type, trait_id, trait_generics)?; + trait_associated_types: &[NamedType], + ) -> Result { + let (impl_kind, bindings) = self.try_lookup_trait_implementation( + object_type, + trait_id, + trait_generics, + trait_associated_types, + )?; Type::apply_type_bindings(bindings); Ok(impl_kind) @@ -1394,23 +1411,14 @@ impl NodeInterner { ) -> Vec<&TraitImplKind> { let trait_impl = self.trait_implementation_map.get(&trait_id); - trait_impl - .map(|trait_impl| { - trait_impl - .iter() - .filter_map(|(typ, impl_kind)| match &typ { - Type::Forall(_, typ) => { - if typ.deref() == object_type { - Some(impl_kind) - } else { - None - } - } - _ => None, - }) - .collect() - }) - .unwrap_or_default() + let trait_impl = trait_impl.map(|trait_impl| { + let impls = trait_impl.iter().filter_map(|(typ, impl_kind)| match &typ { + Type::Forall(_, typ) => (typ.deref() == object_type).then_some(impl_kind), + _ => None, + }); + impls.collect() + }); + trait_impl.unwrap_or_default() } /// Similar to `lookup_trait_implementation` but does not apply any type bindings on success. @@ -1423,12 +1431,14 @@ impl NodeInterner { object_type: &Type, trait_id: TraitId, trait_generics: &[Type], - ) -> Result<(TraitImplKind, TypeBindings), Vec> { + trait_associated_types: &[NamedType], + ) -> Result<(TraitImplKind, TypeBindings), ImplSearchErrorKind> { let mut bindings = TypeBindings::new(); let impl_kind = self.lookup_trait_implementation_helper( object_type, trait_id, trait_generics, + trait_associated_types, &mut bindings, IMPL_SEARCH_RECURSION_LIMIT, )?; @@ -1445,63 +1455,73 @@ impl NodeInterner { object_type: &Type, trait_id: TraitId, trait_generics: &[Type], + trait_associated_types: &[NamedType], type_bindings: &mut TypeBindings, recursion_limit: u32, - ) -> Result> { + ) -> Result { let make_constraint = || { - TraitConstraint::new( - object_type.clone(), + let ordered = trait_generics.to_vec(); + let named = trait_associated_types.to_vec(); + TraitConstraint { + typ: object_type.clone(), trait_id, - trait_generics.to_vec(), - Span::default(), - ) + trait_generics: TraitGenerics { ordered, named }, + span: Span::default(), + } }; + let nested_error = || ImplSearchErrorKind::Nested(vec![make_constraint()]); + // Prevent infinite recursion when looking for impls if recursion_limit == 0 { - return Err(vec![make_constraint()]); + return Err(nested_error()); } let object_type = object_type.substitute(type_bindings); // If the object type isn't known, just return an error saying type annotations are needed. if object_type.is_bindable() { - return Err(Vec::new()); + return Err(ImplSearchErrorKind::TypeAnnotationsNeededOnObjectType); } - let impls = - self.trait_implementation_map.get(&trait_id).ok_or_else(|| vec![make_constraint()])?; + let impls = self.trait_implementation_map.get(&trait_id).ok_or_else(nested_error)?; let mut matching_impls = Vec::new(); + let mut where_clause_error = None; - let mut where_clause_errors = Vec::new(); - - for (existing_object_type2, impl_kind) in impls { + for (existing_object_type, impl_kind) in impls { // Bug: We're instantiating only the object type's generics here, not all of the trait's generics like we need to let (existing_object_type, instantiation_bindings) = - existing_object_type2.instantiate(self); + existing_object_type.instantiate(self); let mut fresh_bindings = type_bindings.clone(); - let mut check_trait_generics = |impl_generics: &[Type]| { - trait_generics.iter().zip(impl_generics).all(|(trait_generic, impl_generic2)| { - let impl_generic = impl_generic2.substitute(&instantiation_bindings); - trait_generic.try_unify(&impl_generic, &mut fresh_bindings).is_ok() - }) - }; + let mut check_trait_generics = + |impl_generics: &[Type], impl_associated_types: &[NamedType]| { + trait_generics.iter().zip(impl_generics).all(|(trait_generic, impl_generic)| { + let impl_generic = impl_generic.force_substitute(&instantiation_bindings); + trait_generic.try_unify(&impl_generic, &mut fresh_bindings).is_ok() + }) && trait_associated_types.iter().zip(impl_associated_types).all( + |(trait_generic, impl_generic)| { + let impl_generic2 = + impl_generic.typ.force_substitute(&instantiation_bindings); + trait_generic.typ.try_unify(&impl_generic2, &mut fresh_bindings).is_ok() + }, + ) + }; - let generics_match = match impl_kind { + let trait_generics = match impl_kind { TraitImplKind::Normal(id) => { let shared_impl = self.get_trait_implementation(*id); let shared_impl = shared_impl.borrow(); - check_trait_generics(&shared_impl.trait_generics) - } - TraitImplKind::Assumed { trait_generics, .. } => { - check_trait_generics(trait_generics) + let named = self.get_associated_types_for_impl(*id).to_vec(); + let ordered = shared_impl.trait_generics.clone(); + TraitGenerics { named, ordered } } + TraitImplKind::Assumed { trait_generics, .. } => trait_generics.clone(), }; - if !generics_match { + if !check_trait_generics(&trait_generics.ordered, &trait_generics.named) { continue; } @@ -1510,34 +1530,48 @@ impl NodeInterner { let trait_impl = self.get_trait_implementation(*impl_id); let trait_impl = trait_impl.borrow(); - if let Err(errors) = self.validate_where_clause( + if let Err(error) = self.validate_where_clause( &trait_impl.where_clause, &mut fresh_bindings, &instantiation_bindings, recursion_limit, ) { // Only keep the first errors we get from a failing where clause - if where_clause_errors.is_empty() { - where_clause_errors.extend(errors); + if where_clause_error.is_none() { + where_clause_error = Some(error); } continue; } } - matching_impls.push((impl_kind.clone(), fresh_bindings)); + let constraint = TraitConstraint { + typ: existing_object_type, + trait_id, + trait_generics, + span: Span::default(), + }; + matching_impls.push((impl_kind.clone(), fresh_bindings, constraint)); } } if matching_impls.len() == 1 { - let (impl_, fresh_bindings) = matching_impls.pop().unwrap(); + let (impl_, fresh_bindings, _) = matching_impls.pop().unwrap(); *type_bindings = fresh_bindings; Ok(impl_) } else if matching_impls.is_empty() { - where_clause_errors.push(make_constraint()); - Err(where_clause_errors) + let mut errors = match where_clause_error { + Some((_, ImplSearchErrorKind::Nested(errors))) => errors, + Some((constraint, _other)) => vec![constraint], + None => vec![], + }; + errors.push(make_constraint()); + Err(ImplSearchErrorKind::Nested(errors)) } else { - // multiple matching impls, type annotations needed - Err(vec![]) + let impls = vecmap(matching_impls, |(_, _, constraint)| { + let name = &self.get_trait(constraint.trait_id).name; + format!("{}: {name}{}", constraint.typ, constraint.trait_generics) + }); + Err(ImplSearchErrorKind::MultipleMatching(impls)) } } @@ -1549,27 +1583,36 @@ impl NodeInterner { type_bindings: &mut TypeBindings, instantiation_bindings: &TypeBindings, recursion_limit: u32, - ) -> Result<(), Vec> { + ) -> Result<(), (TraitConstraint, ImplSearchErrorKind)> { for constraint in where_clause { // Instantiation bindings are generally safe to force substitute into the same type. // This is needed here to undo any bindings done to trait methods by monomorphization. - // Otherwise, an impl for (A, B) could get narrowed to only an impl for e.g. (u8, u16). + // Otherwise, an impl for any (A, B) could get narrowed to only an impl for e.g. (u8, u16). let constraint_type = constraint.typ.force_substitute(instantiation_bindings).substitute(type_bindings); - let trait_generics = vecmap(&constraint.trait_generics, |generic| { + let trait_generics = vecmap(&constraint.trait_generics.ordered, |generic| { generic.force_substitute(instantiation_bindings).substitute(type_bindings) }); + let trait_associated_types = vecmap(&constraint.trait_generics.named, |generic| { + let typ = generic.typ.force_substitute(instantiation_bindings); + NamedType { name: generic.name.clone(), typ: typ.substitute(type_bindings) } + }); + + // We can ignore any associated types on the constraint since those should not affect + // which impl we choose. self.lookup_trait_implementation_helper( &constraint_type, constraint.trait_id, &trait_generics, + &trait_associated_types, // Use a fresh set of type bindings here since the constraint_type originates from // our impl list, which we don't want to bind to. type_bindings, recursion_limit - 1, - )?; + ) + .map_err(|error| (constraint.clone(), error))?; } Ok(()) @@ -1587,10 +1630,16 @@ impl NodeInterner { &mut self, object_type: Type, trait_id: TraitId, - trait_generics: Vec, + trait_generics: TraitGenerics, ) -> bool { // Make sure there are no overlapping impls - if self.try_lookup_trait_implementation(&object_type, trait_id, &trait_generics).is_ok() { + let existing = self.try_lookup_trait_implementation( + &object_type, + trait_id, + &trait_generics.ordered, + &trait_generics.named, + ); + if existing.is_ok() { return false; } @@ -1604,7 +1653,6 @@ impl NodeInterner { &mut self, object_type: Type, trait_id: TraitId, - trait_generics: Vec, impl_id: TraitImplId, impl_generics: GenericTypeVars, trait_impl: Shared, @@ -1626,6 +1674,9 @@ impl NodeInterner { let instantiated_object_type = object_type.substitute(&substitutions); + let trait_generics = &trait_impl.borrow().trait_generics; + let associated_types = self.get_associated_types_for_impl(impl_id); + // Ignoring overlapping `TraitImplKind::Assumed` impls here is perfectly fine. // It should never happen since impls are defined at global scope, but even // if they were, we should never prevent defining a new impl because a 'where' @@ -1633,7 +1684,8 @@ impl NodeInterner { if let Ok((TraitImplKind::Normal(existing), _)) = self.try_lookup_trait_implementation( &instantiated_object_type, trait_id, - &trait_generics, + trait_generics, + associated_types, ) { let existing_impl = self.get_trait_implementation(existing); let existing_impl = existing_impl.borrow(); @@ -2021,6 +2073,59 @@ impl NodeInterner { pub fn is_in_lsp_mode(&self) -> bool { self.lsp_mode } + + pub fn set_associated_types_for_impl( + &mut self, + impl_id: TraitImplId, + associated_types: Vec, + ) { + self.trait_impl_associated_types.insert(impl_id, associated_types); + } + + pub fn get_associated_types_for_impl(&self, impl_id: TraitImplId) -> &[NamedType] { + &self.trait_impl_associated_types[&impl_id] + } + + pub fn find_associated_type_for_impl( + &self, + impl_id: TraitImplId, + type_name: &str, + ) -> Option<&Type> { + let types = self.trait_impl_associated_types.get(&impl_id)?; + types.iter().find(|typ| typ.name.0.contents == type_name).map(|typ| &typ.typ) + } + + /// Return a set of TypeBindings to bind types from the parent trait to those from the trait impl. + pub fn trait_to_impl_bindings( + &self, + trait_id: TraitId, + impl_id: TraitImplId, + trait_impl_generics: &[Type], + impl_self_type: Type, + ) -> TypeBindings { + let mut bindings = TypeBindings::new(); + let the_trait = self.get_trait(trait_id); + let trait_generics = the_trait.generics.clone(); + + let self_type_var = the_trait.self_type_typevar.clone(); + bindings.insert(self_type_var.id(), (self_type_var, impl_self_type)); + + for (trait_generic, trait_impl_generic) in trait_generics.iter().zip(trait_impl_generics) { + let type_var = trait_generic.type_var.clone(); + bindings.insert(type_var.id(), (type_var, trait_impl_generic.clone())); + } + + // Now that the normal bindings are added, we still need to bind the associated types + let impl_associated_types = self.get_associated_types_for_impl(impl_id); + let trait_associated_types = &the_trait.associated_types; + + for (trait_type, impl_type) in trait_associated_types.iter().zip(impl_associated_types) { + let type_variable = trait_type.type_var.clone(); + bindings.insert(type_variable.id(), (type_variable, impl_type.typ.clone())); + } + + bindings + } } impl Methods { diff --git a/compiler/noirc_frontend/src/parser/errors.rs b/compiler/noirc_frontend/src/parser/errors.rs index ebb58ddc22..2e38d7ae83 100644 --- a/compiler/noirc_frontend/src/parser/errors.rs +++ b/compiler/noirc_frontend/src/parser/errors.rs @@ -64,6 +64,10 @@ pub enum ParserErrorReason { ForbiddenNumericGenericType, #[error("Invalid call data identifier, must be a number. E.g `call_data(0)`")] InvalidCallDataIdentifier, + #[error("Associated types are not allowed in paths")] + AssociatedTypesNotAllowedInPaths, + #[error("Associated types are not allowed on a method call")] + AssociatedTypesNotAllowedInMethodCalls, } /// Represents a parsing error, or a parsing error in the making. diff --git a/compiler/noirc_frontend/src/parser/parser.rs b/compiler/noirc_frontend/src/parser/parser.rs index b86c2c46c9..56c80ee1ce 100644 --- a/compiler/noirc_frontend/src/parser/parser.rs +++ b/compiler/noirc_frontend/src/parser/parser.rs @@ -36,9 +36,9 @@ use super::{ }; use super::{spanned, Item, ItemKind}; use crate::ast::{ - BinaryOp, BinaryOpKind, BlockExpression, ForLoopStatement, ForRange, Ident, IfExpression, - InfixExpression, LValue, Literal, ModuleDeclaration, NoirTypeAlias, Param, Path, Pattern, - Recoverable, Statement, TypeImpl, UnaryRhsMemberAccess, UnaryRhsMethodCall, UseTree, + BinaryOp, BinaryOpKind, BlockExpression, ForLoopStatement, ForRange, GenericTypeArgs, Ident, + IfExpression, InfixExpression, LValue, Literal, ModuleDeclaration, NoirTypeAlias, Param, Path, + Pattern, Recoverable, Statement, TypeImpl, UnaryRhsMemberAccess, UnaryRhsMethodCall, UseTree, UseTreeKind, Visibility, }; use crate::ast::{ @@ -333,7 +333,9 @@ fn self_parameter() -> impl NoirParser { .map(|(pattern_keyword, ident_span)| { let ident = Ident::new("self".to_string(), ident_span); let path = Path::from_single("Self".to_owned(), ident_span); - let mut self_type = UnresolvedTypeData::Named(path, vec![], true).with_span(ident_span); + let no_args = GenericTypeArgs::default(); + let mut self_type = + UnresolvedTypeData::Named(path, no_args, true).with_span(ident_span); let mut pattern = Pattern::Identifier(ident); match pattern_keyword { @@ -902,10 +904,15 @@ where let method_call_rhs = turbofish .then(just(Token::Bang).or_not()) .then(parenthesized(expression_list(expr_parser.clone()))) - .map(|((turbofish, macro_call), args)| UnaryRhsMethodCall { - turbofish, - macro_call: macro_call.is_some(), - args, + .validate(|((turbofish, macro_call), args), span, emit| { + if turbofish.as_ref().map_or(false, |generics| !generics.named_args.is_empty()) { + let reason = ParserErrorReason::AssociatedTypesNotAllowedInMethodCalls; + emit(ParserError::with_reason(reason, span)); + } + + let macro_call = macro_call.is_some(); + let turbofish = turbofish.map(|generics| generics.ordered_args); + UnaryRhsMethodCall { turbofish, macro_call, args } }); // `.foo` or `.foo(args)` in `atom.foo` or `atom.foo(args)` diff --git a/compiler/noirc_frontend/src/parser/parser/path.rs b/compiler/noirc_frontend/src/parser/parser/path.rs index ae3a1bc0b9..ea121c6f6d 100644 --- a/compiler/noirc_frontend/src/parser/parser/path.rs +++ b/compiler/noirc_frontend/src/parser/parser/path.rs @@ -7,6 +7,7 @@ use chumsky::prelude::*; use super::keyword; use super::primitives::{ident, path_segment, path_segment_no_turbofish}; +use super::types::generic_type_args; pub(super) fn path<'a>( type_parser: impl NoirParser + 'a, @@ -54,14 +55,16 @@ pub(super) fn as_trait_path<'a>( just(Token::Less) .ignore_then(type_parser.clone()) .then_ignore(keyword(Keyword::As)) - .then(path(type_parser)) + .then(path(type_parser.clone())) + .then(generic_type_args(type_parser)) .then_ignore(just(Token::Greater)) .then_ignore(just(Token::DoubleColon)) .then(ident()) - .validate(|((typ, trait_path), impl_item), span, emit| { - let reason = ParserErrorReason::ExperimentalFeature("Fully qualified trait impl paths"); - emit(ParserError::with_reason(reason, span)); - AsTraitPath { typ, trait_path, impl_item } + .map(|(((typ, trait_path), trait_generics), impl_item)| AsTraitPath { + typ, + trait_path, + trait_generics, + impl_item, }) } diff --git a/compiler/noirc_frontend/src/parser/parser/primitives.rs b/compiler/noirc_frontend/src/parser/parser/primitives.rs index 25f693bf50..9145fb945c 100644 --- a/compiler/noirc_frontend/src/parser/parser/primitives.rs +++ b/compiler/noirc_frontend/src/parser/parser/primitives.rs @@ -1,7 +1,8 @@ use chumsky::prelude::*; -use crate::ast::{ExpressionKind, Ident, PathSegment, UnaryOp}; +use crate::ast::{ExpressionKind, GenericTypeArgs, Ident, PathSegment, UnaryOp}; use crate::macros_api::UnresolvedType; +use crate::parser::ParserErrorReason; use crate::{ parser::{labels::ParsingRuleLabel, ExprParser, NoirParser, ParserError}, token::{Keyword, Token, TokenKind}, @@ -36,10 +37,14 @@ pub(super) fn token_kind(token_kind: TokenKind) -> impl NoirParser { pub(super) fn path_segment<'a>( type_parser: impl NoirParser + 'a, ) -> impl NoirParser + 'a { - ident().then(turbofish(type_parser)).map_with_span(|(ident, generics), span| PathSegment { - ident, - generics, - span, + ident().then(turbofish(type_parser)).validate(|(ident, generics), span, emit| { + if generics.as_ref().map_or(false, |generics| !generics.named_args.is_empty()) { + let reason = ParserErrorReason::AssociatedTypesNotAllowedInPaths; + emit(ParserError::with_reason(reason, span)); + } + + let generics = generics.map(|generics| generics.ordered_args); + PathSegment { ident, generics, span } }) } @@ -95,7 +100,7 @@ where pub(super) fn turbofish<'a>( type_parser: impl NoirParser + 'a, -) -> impl NoirParser>> + 'a { +) -> impl NoirParser> + 'a { just(Token::DoubleColon).ignore_then(required_generic_type_args(type_parser)).or_not() } diff --git a/compiler/noirc_frontend/src/parser/parser/traits.rs b/compiler/noirc_frontend/src/parser/parser/traits.rs index 0cf5e63f5f..bf5a4b4d0b 100644 --- a/compiler/noirc_frontend/src/parser/parser/traits.rs +++ b/compiler/noirc_frontend/src/parser/parser/traits.rs @@ -111,15 +111,10 @@ fn trait_function_declaration() -> impl NoirParser { /// trait_type_declaration: 'type' ident generics fn trait_type_declaration() -> impl NoirParser { - keyword(Keyword::Type).ignore_then(ident()).then_ignore(just(Token::Semicolon)).validate( - |name, span, emit| { - emit(ParserError::with_reason( - ParserErrorReason::ExperimentalFeature("Associated types"), - span, - )); - TraitItem::Type { name } - }, - ) + keyword(Keyword::Type) + .ignore_then(ident()) + .then_ignore(just(Token::Semicolon)) + .map(|name| TraitItem::Type { name }) } /// Parses a trait implementation, implementing a particular trait for a type. diff --git a/compiler/noirc_frontend/src/parser/parser/types.rs b/compiler/noirc_frontend/src/parser/parser/types.rs index cb7271a416..c655ab8c5a 100644 --- a/compiler/noirc_frontend/src/parser/parser/types.rs +++ b/compiler/noirc_frontend/src/parser/parser/types.rs @@ -1,11 +1,12 @@ use super::path::{as_trait_path, path_no_turbofish}; -use super::primitives::token_kind; +use super::primitives::{ident, token_kind}; use super::{ expression_with_precedence, keyword, nothing, parenthesized, NoirParser, ParserError, ParserErrorReason, Precedence, }; use crate::ast::{ - Expression, Recoverable, UnresolvedType, UnresolvedTypeData, UnresolvedTypeExpression, + Expression, GenericTypeArg, GenericTypeArgs, Recoverable, UnresolvedType, UnresolvedTypeData, + UnresolvedTypeExpression, }; use crate::QuotedType; @@ -27,16 +28,7 @@ pub(super) fn parse_type_inner<'a>( int_type(), bool_type(), string_type(), - expr_type(), - struct_definition_type(), - trait_constraint_type(), - trait_definition_type(), - trait_impl_type(), - function_definition_type(), - module_type(), - top_level_item_type(), - type_of_quoted_types(), - quoted_type(), + comptime_type(), resolved_type(), format_string_type(recursive_type_parser.clone()), named_type(recursive_type_parser.clone()), @@ -82,6 +74,22 @@ pub(super) fn bool_type() -> impl NoirParser { keyword(Keyword::Bool).map_with_span(|_, span| UnresolvedTypeData::Bool.with_span(span)) } +pub(super) fn comptime_type() -> impl NoirParser { + choice(( + expr_type(), + struct_definition_type(), + trait_constraint_type(), + trait_definition_type(), + trait_impl_type(), + unresolved_type_type(), + function_definition_type(), + module_type(), + type_of_quoted_types(), + top_level_item_type(), + quoted_type(), + )) +} + /// This is the type `Expr` - the type of a quoted, untyped expression object used for macros pub(super) fn expr_type() -> impl NoirParser { keyword(Keyword::Expr) @@ -113,6 +121,12 @@ pub(super) fn trait_impl_type() -> impl NoirParser { .map_with_span(|_, span| UnresolvedTypeData::Quoted(QuotedType::TraitImpl).with_span(span)) } +pub(super) fn unresolved_type_type() -> impl NoirParser { + keyword(Keyword::UnresolvedType).map_with_span(|_, span| { + UnresolvedTypeData::Quoted(QuotedType::UnresolvedType).with_span(span) + }) +} + pub(super) fn function_definition_type() -> impl NoirParser { keyword(Keyword::FunctionDefinition).map_with_span(|_, span| { UnresolvedTypeData::Quoted(QuotedType::FunctionDefinition).with_span(span) @@ -213,25 +227,37 @@ pub(super) fn named_trait<'a>( pub(super) fn generic_type_args<'a>( type_parser: impl NoirParser + 'a, -) -> impl NoirParser> + 'a { +) -> impl NoirParser + 'a { required_generic_type_args(type_parser).or_not().map(Option::unwrap_or_default) } pub(super) fn required_generic_type_args<'a>( type_parser: impl NoirParser + 'a, -) -> impl NoirParser> + 'a { - type_parser +) -> impl NoirParser + 'a { + let generic_type_arg = type_parser .clone() + .then_ignore(one_of([Token::Comma, Token::Greater]).rewind()) + .or(type_expression_validated()); + + let named_arg = ident() + .then_ignore(just(Token::Assign)) + .then(generic_type_arg.clone()) + .map(|(name, typ)| GenericTypeArg::Named(name, typ)); + + // We need to parse named arguments first since otherwise when we see + // `Foo = Bar`, just `Foo` is a valid type, and we'd parse an ordered + // generic before erroring that an `=` is invalid after an ordered generic. + choice((named_arg, generic_type_arg.map(GenericTypeArg::Ordered))) + .boxed() // Without checking for a terminating ',' or '>' here we may incorrectly // parse a generic `N * 2` as just the type `N` then fail when there is no // separator afterward. Failing early here ensures we try the `type_expression` // parser afterward. - .then_ignore(one_of([Token::Comma, Token::Greater]).rewind()) - .or(type_expression_validated()) .separated_by(just(Token::Comma)) .allow_trailing() .at_least(1) .delimited_by(just(Token::Less), just(Token::Greater)) + .map(GenericTypeArgs::from) } pub(super) fn array_type<'a>( diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index bba596ed19..cc4aae7f44 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -2227,7 +2227,7 @@ fn impl_stricter_than_trait_different_trait_generics() { { assert!(matches!(constraint_typ.to_string().as_str(), "A")); assert!(matches!(constraint_name.as_str(), "T2")); - assert!(matches!(constraint_generics[0].to_string().as_str(), "B")); + assert!(matches!(constraint_generics.ordered[0].to_string().as_str(), "B")); } else { panic!("Expected DefCollectorErrorKind::ImplIsStricterThanTrait but got {:?}", errors[0].0); } @@ -2889,16 +2889,14 @@ fn incorrect_generic_count_on_struct_impl() { let errors = get_program_errors(src); assert_eq!(errors.len(), 1); - let CompilationError::ResolverError(ResolverError::IncorrectGenericCount { - actual, - expected, - .. + let CompilationError::TypeError(TypeCheckError::GenericCountMismatch { + found, expected, .. }) = errors[0].0 else { panic!("Expected an incorrect generic count mismatch error, got {:?}", errors[0].0); }; - assert_eq!(actual, 1); + assert_eq!(found, 1); assert_eq!(expected, 0); } @@ -2913,16 +2911,14 @@ fn incorrect_generic_count_on_type_alias() { let errors = get_program_errors(src); assert_eq!(errors.len(), 1); - let CompilationError::ResolverError(ResolverError::IncorrectGenericCount { - actual, - expected, - .. + let CompilationError::TypeError(TypeCheckError::GenericCountMismatch { + found, expected, .. }) = errors[0].0 else { panic!("Expected an incorrect generic count mismatch error, got {:?}", errors[0].0); }; - assert_eq!(actual, 1); + assert_eq!(found, 1); assert_eq!(expected, 0); } @@ -3114,3 +3110,80 @@ fn trait_impl_for_a_type_that_implements_another_trait_with_another_impl_used() "#; assert_no_errors(src); } + +#[test] +fn impl_missing_associated_type() { + let src = r#" + trait Foo { + type Assoc; + } + + impl Foo for () {} + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + assert!(matches!( + &errors[0].0, + CompilationError::TypeError(TypeCheckError::MissingNamedTypeArg { .. }) + )); +} + +#[test] +fn as_trait_path_syntax_resolves_outside_impl() { + let src = r#" + trait Foo { + type Assoc; + } + + struct Bar {} + + impl Foo for Bar { + type Assoc = i32; + } + + fn main() { + // AsTraitPath syntax is a bit silly when associated types + // are explicitly specified + let _: i64 = 1 as >::Assoc; + } + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + use CompilationError::TypeError; + use TypeCheckError::TypeMismatch; + let TypeError(TypeMismatch { expected_typ, expr_typ, .. }) = errors[0].0.clone() else { + panic!("Expected TypeMismatch error, found {:?}", errors[0].0); + }; + + assert_eq!(expected_typ, "i64".to_string()); + assert_eq!(expr_typ, "i32".to_string()); +} + +#[test] +fn as_trait_path_syntax_no_impl() { + let src = r#" + trait Foo { + type Assoc; + } + + struct Bar {} + + impl Foo for Bar { + type Assoc = i32; + } + + fn main() { + let _: i64 = 1 as >::Assoc; + } + "#; + + let errors = get_program_errors(src); + assert_eq!(errors.len(), 1); + + use CompilationError::TypeError; + assert!(matches!(&errors[0].0, TypeError(TypeCheckError::NoMatchingImplFound { .. }))); +} diff --git a/cspell.json b/cspell.json index e98b11f2e8..293523d7c1 100644 --- a/cspell.json +++ b/cspell.json @@ -20,6 +20,7 @@ "barretenberg", "barustenberg", "bbup", + "bignum", "bincode", "bindgen", "bitand", @@ -128,6 +129,7 @@ "memfs", "memset", "merkle", + "metaprogramming", "metas", "microcontroller", "minreq", @@ -168,6 +170,7 @@ "pseudocode", "pubkey", "quantile", + "quasiquote", "rangemap", "repr", "reqwest", @@ -206,6 +209,8 @@ "typevar", "typevars", "udiv", + "umap", + "underconstrained", "uninstantiated", "unnormalized", "unoptimized", @@ -218,6 +223,7 @@ "wasi", "wasmer", "Weierstraß", + "zkhash", "zshell", "Linea" ], diff --git a/docs/docs/noir/concepts/comptime.md b/docs/docs/noir/concepts/comptime.md new file mode 100644 index 0000000000..2b5c29538b --- /dev/null +++ b/docs/docs/noir/concepts/comptime.md @@ -0,0 +1,262 @@ +--- +title: Compile-time Code & Metaprogramming +description: Learn how to use metaprogramming in Noir to create macros or derive your own traits +keywords: [Noir, comptime, compile-time, metaprogramming, macros, quote, unquote] +sidebar_position: 15 +--- + +# Overview + +Metaprogramming in Noir is comprised of three parts: +1. `comptime` code +2. Quoting and unquoting +3. The metaprogramming API in `std::meta` + +Each of these are explained in more detail in the next sections but the wide picture is that +`comptime` allows us to write code which runs at compile-time. In this `comptime` code we +can quote and unquote snippets of the program, manipulate them, and insert them in other +parts of the program. Comptime functions which do this are said to be macros. Additionally, +there's a compile-time API of built-in types and functions provided by the compiler which allows +for greater analysis and modification of programs. + +--- + +# Comptime + +`comptime` is a new keyword in Noir which marks an item as executing or existing at compile-time. It can be used in several ways: + +- `comptime fn` to define functions which execute exclusively during compile-time. +- `comptime global` to define a global variable which is evaluated at compile-time. + - Unlike runtime globals, `comptime global`s can be mutable. +- `comptime { ... }` to execute a block of statements during compile-time. +- `comptime let` to define a variable whose value is evaluated at compile-time. +- `comptime for` to run a for loop at compile-time. Syntax sugar for `comptime { for .. }`. + +## Scoping + +Note that while in a `comptime` context, any runtime variables _local to the current function_ are never visible. + +## Evaluating + +Evaluation rules of `comptime` follows the normal unconstrained evaluation rules for other Noir code. There are a few things to note though: + +- Certain built-in functions may not be available, although more may be added over time. +- Evaluation order of global items is currently unspecified. For example, given the following two functions we can't guarantee +which `println` will execute first. The ordering of the two printouts will be arbitrary, but should be stable across multiple compilations with the same `nargo` version as long as the program is also unchanged. + +```rust +fn one() { + comptime { println("one"); } +} + +fn two() { + comptime { println("two"); } +} +``` + +- Since evaluation order is unspecified, care should be taken when using mutable globals so that they do not rely on a particular ordering. +For example, using globals to generate unique ids should be fine but relying on certain ids always being produced (especially after edits to the program) should be avoided. +- Although most ordering of globals is unspecified, two are: + - Dependencies of a crate will always be evaluated before the dependent crate. + - Any annotations on a function will be run before the function itself is resolved. This is to allow the annotation to modify the function if necessary. Note that if the + function itself was called at compile-time previously, it will already be resolved and cannot be modified. To prevent accidentally calling functions you wish to modify + at compile-time, it may be helpful to sort your `comptime` annotation functions into a different crate along with any dependencies they require. + +## Lowering + +When a `comptime` value is used in runtime code it must be lowered into a runtime value. This means replacing the expression with the literal that it evaluated to. For example, the code: + +```rust +struct Foo { array: [Field; 2], len: u32 } + +fn main() { + println(comptime { + let mut foo = std::mem::zeroed::(); + foo.array[0] = 4; + foo.len = 1; + foo + }); +} +``` + +will be converted to the following after `comptime` expressions are evaluated: + +```rust +struct Foo { array: [Field; 2], len: u32 } + +fn main() { + println(Foo { array: [4, 0], len: 1 }); +} +``` + +Not all types of values can be lowered. For example, `Type`s and `TypeDefinition`s (among other types) cannot be lowered at all. + +```rust +fn main() { + // There's nothing we could inline here to create a Type value at runtime + // let _ = get_type!(); +} + +comptime fn get_type() -> Type { ... } +``` + +--- + +# (Quasi) Quote + +Macros in Noir are `comptime` functions which return code as a value which is inserted into the call site when it is lowered there. +A code value in this case is of type `Quoted` and can be created by a `quote { ... }` expression. +More specifically, the code value `quote` creates is a token stream - a representation of source code as a series of words, numbers, string literals, or operators. +For example, the expression `quote { Hi "there reader"! }` would quote three tokens: the word "hi", the string "there reader", and an exclamation mark. +You'll note that snippets that would otherwise be invalid syntax can still be quoted. + +When a `Quoted` value is used in runtime code, it is lowered into a `quote { ... }` expression. Since this expression is only valid +in compile-time code however, we'd get an error if we tried this. Instead, we can use macro insertion to insert each token into the +program at that point, and parse it as an expression. To do this, we have to add a `!` after the function name returning the `Quoted` value. +If the value was created locally and there is no function returning it, `std::meta::unquote!(_)` can be used instead. +Calling such a function at compile-time without `!` will just return the `Quoted` value to be further manipulated. For example: + +#include_code quote-example noir_stdlib/src/meta/mod.nr rust + +For those familiar with quoting from other languages (primarily lisps), Noir's `quote` is actually a _quasiquote_. +This means we can escape the quoting by using the unquote operator to splice values in the middle of quoted code. + +# Unquote + +The unquote operator `$` is usable within a `quote` expression. +It takes a variable as an argument, evaluates the variable, and splices the resulting value into the quoted token stream at that point. For example, + +```rust +comptime { + let x = 1 + 2; + let y = quote { $x + 4 }; +} +``` + +The value of `y` above will be the token stream containing `3`, `+`, and `4`. We can also use this to combine `Quoted` values into larger token streams: + +```rust +comptime { + let x = quote { 1 + 2 }; + let y = quote { $x + 4 }; +} +``` + +The value of `y` above is now the token stream containing five tokens: `1 + 2 + 4`. + +Note that to unquote something, a variable name _must_ follow the `$` operator in a token stream. +If it is an expression (even a parenthesized one), it will do nothing. Most likely a parse error will be given when the macro is later unquoted. + +Unquoting can also be avoided by escaping the `$` with a backslash: + +``` +comptime { + let x = quote { 1 + 2 }; + + // y contains the four tokens: `$x + 4` + let y = quote { \$x + 4 }; +} +``` + +--- + +# Annotations + +Annotations provide a way to run a `comptime` function on an item in the program. +When you use an annotation, the function with the same name will be called with that item as an argument: + +```rust +#[my_struct_annotation] +struct Foo {} + +comptime fn my_struct_annotation(s: StructDefinition) { + println("Called my_struct_annotation!"); +} + +#[my_function_annotation] +fn foo() {} + +comptime fn my_function_annotation(f: FunctionDefinition) { + println("Called my_function_annotation!"); +} +``` + +Anything returned from one of these functions will be inserted at top-level along with the original item. +Note that expressions are not valid at top-level so you'll get an error trying to return `3` or similar just as if you tried to write a program containing `3; struct Foo {}`. +You can insert other top-level items such as traits, structs, or functions this way though. +For example, this is the mechanism used to insert additional trait implementations into the program when deriving a trait impl from a struct: + +#include_code derive-field-count-example noir_stdlib/src/meta/mod.nr rust + +## Calling annotations with additional arguments + +Arguments may optionally be given to annotations. +When this is done, these additional arguments are passed to the annotation function after the item argument. + +#include_code annotation-arguments-example noir_stdlib/src/meta/mod.nr rust + +We can also take any number of arguments by adding the `varargs` annotation: + +#include_code annotation-varargs-example noir_stdlib/src/meta/mod.nr rust + +--- + +# Comptime API + +Although `comptime`, `quote`, and unquoting provide a flexible base for writing macros, +Noir's true metaprogramming ability comes from being able to interact with the compiler through a compile-time API. +This API can be accessed through built-in functions in `std::meta` as well as on methods of several `comptime` types. + +The following is an incomplete list of some `comptime` types along with some useful methods on them. + +- `Quoted`: A token stream +- `Type`: The type of a Noir type + - `fn implements(self, constraint: TraitConstraint) -> bool` + - Returns true if `self` implements the given trait constraint +- `Expr`: A syntactically valid expression. Can be used to recur on a program's parse tree to inspect how it is structured. + - Methods: + - `fn as_function_call(self) -> Option<(Expr, [Expr])>` + - If this is a function call expression, return `(function, arguments)` + - `fn as_block(self) -> Option<[Expr]>` + - If this is a block, return each statement in the block +- `FunctionDefinition`: A function definition + - Methods: + - `fn parameters(self) -> [(Quoted, Type)]` + - Returns a slice of `(name, type)` pairs for each parameter +- `StructDefinition`: A struct definition + - Methods: + - `fn as_type(self) -> Type` + - Returns this `StructDefinition` as a `Type`. Any generics are kept as-is + - `fn generics(self) -> [Quoted]` + - Return the name of each generic on this struct + - `fn fields(self) -> [(Quoted, Type)]` + - Return the name and type of each field +- `TraitConstraint`: A trait constraint such as `From` +- `UnresolvedType`: A syntactic notation that refers to a Noir type that hasn't been resolved yet + +There are many more functions available by exploring the `std::meta` module and its submodules. +Using these methods is the key to writing powerful metaprogramming libraries. + +--- + +# Example: Derive + +Using all of the above, we can write a `derive` macro that behaves similarly to Rust's but is not built into the language. +From the user's perspective it will look like this: + +```rust +// Example usage +#[derive(Default, Eq, Ord)] +struct MyStruct { my_field: u32 } +``` + +To implement `derive` we'll have to create a `comptime` function that accepts +a variable amount of traits. + +#include_code derive_example noir_stdlib/src/meta/mod.nr rust + +Registering a derive function could be done as follows: + +#include_code derive_via noir_stdlib/src/meta/mod.nr rust + +#include_code big-derive-usage-example noir_stdlib/src/meta/mod.nr rust diff --git a/docs/docs/noir/concepts/traits.md b/docs/docs/noir/concepts/traits.md index 51305b38c1..597c62c737 100644 --- a/docs/docs/noir/concepts/traits.md +++ b/docs/docs/noir/concepts/traits.md @@ -225,6 +225,66 @@ fn main() { } ``` +### Associated Types and Constants + +Traits also support associated types and constraints which can be thought of as additional generics that are referred to by name. + +Here's an example of a trait with an associated type `Foo` and a constant `Bar`: + +```rust +trait MyTrait { + type Foo; + + let Bar: u32; +} +``` + +Now when we're implementing `MyTrait` we also have to provide values for `Foo` and `Bar`: + +```rust +impl MyTrait for Field { + type Foo = i32; + + let Bar: u32 = 11; +} +``` + +Since associated constants can also be used in a type position, its values are limited to only other +expression kinds allowed in numeric generics. + +Note that currently all associated types and constants must be explicitly specified in a trait constraint. +If we leave out any, we'll get an error that we're missing one: + +```rust +// Error! Constraint is missing associated constant for `Bar` +fn foo(x: T) where T: MyTrait { + ... +} +``` + +Because all associated types and constants must be explicitly specified, they are essentially named generics, +although this is set to change in the future. Future versions of Noir will allow users to elide associated types +in trait constraints similar to Rust. When this is done, you may still refer to their value with the `::AssociatedType` +syntax: + +```rust +// Only valid in future versions of Noir: +fn foo(x: T) where T: MyTrait { + let _: ::Foo = ...; +} +``` + +The type as trait syntax is possible in Noir today but is less useful when each type must be explicitly specified anyway: + +```rust +fn foo(x: T) where T: MyTrait { + // Works, but could just use F directly + let _: >::Foo = ...; + + let _: F = ...; +} +``` + ## Trait Methods With No `self` A trait can contain any number of methods, each of which have access to the `Self` type which represents each type diff --git a/docs/docs/noir/standard_library/bigint.md b/docs/docs/noir/standard_library/bigint.md index 54d791b82d..cc7d6e1c8d 100644 --- a/docs/docs/noir/standard_library/bigint.md +++ b/docs/docs/noir/standard_library/bigint.md @@ -15,6 +15,11 @@ The BigInt module in the standard library exposes some class of integers which d The module can currently be considered as `Field`s with fixed modulo sizes used by a set of elliptic curves, in addition to just the native curve. [More work](https://github.com/noir-lang/noir/issues/510) is needed to achieve arbitrarily sized big integers. +:::note + +`nargo` can be built with `--profile release-pedantic` to enable extra overflow checks which may affect `BigInt` results in some cases. +Consider the [`noir-bignum`](https://github.com/noir-lang/noir-bignum) library for an optimized alternative approach. + ::: Currently 6 classes of integers (i.e 'big' prime numbers) are available in the module, namely: diff --git a/docs/docs/noir/standard_library/meta/expr.md b/docs/docs/noir/standard_library/meta/expr.md new file mode 100644 index 0000000000..0a32b2b04f --- /dev/null +++ b/docs/docs/noir/standard_library/meta/expr.md @@ -0,0 +1,163 @@ +--- +title: Expr +--- + +`std::meta::expr` contains methods on the built-in `Expr` type for quoted, syntactically valid expressions. + +## Methods + +### as_array + +#include_code as_array noir_stdlib/src/meta/expr.nr rust + +If this expression is an array, this returns a slice of each element in the array. + +### as_assign + +#include_code as_assign noir_stdlib/src/meta/expr.nr rust + +If this expression is an assignment, this returns a tuple with the left hand side +and right hand side in order. + +### as_binary_op + +#include_code as_binary_op noir_stdlib/src/meta/expr.nr rust + +If this expression is a binary operator operation ` `, +return the left-hand side, operator, and the right-hand side of the operation. + +### as_block + +#include_code as_block noir_stdlib/src/meta/expr.nr rust + +If this expression is a block `{ stmt1; stmt2; ...; stmtN }`, return +a slice containing each statement. + +### as_bool + +#include_code as_bool noir_stdlib/src/meta/expr.nr rust + +If this expression is a boolean literal, return that literal. + +### as_comptime + +#include_code as_comptime noir_stdlib/src/meta/expr.nr rust + +If this expression is a `comptime { stmt1; stmt2; ...; stmtN }` block, +return each statement in the block. + +### as_function_call + +#include_code as_function_call noir_stdlib/src/meta/expr.nr rust + +If this expression is a function call `foo(arg1, ..., argN)`, return +the function and a slice of each argument. + +### as_if + +#include_code as_if noir_stdlib/src/meta/expr.nr rust + +If this expression is an `if condition { then_branch } else { else_branch }`, +return the condition, then branch, and else branch. If there is no else branch, +`None` is returned for that branch instead. + +### as_index + +#include_code as_index noir_stdlib/src/meta/expr.nr rust + +If this expression is an index into an array `array[index]`, return the +array and the index. + +### as_integer + +#include_code as_integer noir_stdlib/src/meta/expr.nr rust + +If this element is an integer literal, return the integer as a field +as well as whether the integer is negative (true) or not (false). + +### as_member_access + +#include_code as_member_access noir_stdlib/src/meta/expr.nr rust + +If this expression is a member access `foo.bar`, return the struct/tuple +expression and the field. The field will be represented as a quoted value. + +### as_method_call + +#include_code as_method_call noir_stdlib/src/meta/expr.nr rust + +If this expression is a method call `foo.bar::(arg1, ..., argN)`, return +the receiver, method name, a slice of each generic argument, and a slice of each argument. + +### as_repeated_element_array + +#include_code as_repeated_element_array noir_stdlib/src/meta/expr.nr rust + +If this expression is a repeated element array `[elem; length]`, return +the repeated element and the length expressions. + +### as_repeated_element_slice + +#include_code as_repeated_element_slice noir_stdlib/src/meta/expr.nr rust + +If this expression is a repeated element slice `[elem; length]`, return +the repeated element and the length expressions. + +### as_slice + +#include_code as_slice noir_stdlib/src/meta/expr.nr rust + +If this expression is a slice literal `&[elem1, ..., elemN]`, +return each element of the slice. + +### as_tuple + +#include_code as_tuple noir_stdlib/src/meta/expr.nr rust + +If this expression is a tuple `(field1, ..., fieldN)`, +return each element of the tuple. + +### as_unary_op + +#include_code as_unary_op noir_stdlib/src/meta/expr.nr rust + +If this expression is a unary operation ` `, +return the unary operator as well as the right-hand side expression. + +### as_unsafe + +#include_code as_unsafe noir_stdlib/src/meta/expr.nr rust + +If this expression is an `unsafe { stmt1; ...; stmtN }` block, +return each statement inside in a slice. + +### has_semicolon + +#include_code has_semicolon noir_stdlib/src/meta/expr.nr rust + +`true` if this expression is trailed by a semicolon. E.g. + +``` +comptime { + let expr1 = quote { 1 + 2 }.as_expr().unwrap(); + let expr2 = quote { 1 + 2; }.as_expr().unwrap(); + + assert(expr1.as_binary_op().is_some()); + assert(expr2.as_binary_op().is_some()); + + assert(!expr1.has_semicolon()); + assert(expr2.has_semicolon()); +} +``` + +### is_break + +#include_code is_break noir_stdlib/src/meta/expr.nr rust + +`true` if this expression is `break`. + +### is_continue + +#include_code is_continue noir_stdlib/src/meta/expr.nr rust + +`true` if this expression is `continue`. diff --git a/docs/docs/noir/standard_library/meta/function_def.md b/docs/docs/noir/standard_library/meta/function_def.md new file mode 100644 index 0000000000..4b359a9d34 --- /dev/null +++ b/docs/docs/noir/standard_library/meta/function_def.md @@ -0,0 +1,55 @@ +--- +title: FunctionDefinition +--- + +`std::meta::function_def` contains methods on the built-in `FunctionDefinition` type representing +a function definition in the source program. + +## Methods + +### name + +#include_code name noir_stdlib/src/meta/function_def.nr rust + +Returns the name of the function. + +### parameters + +#include_code parameters noir_stdlib/src/meta/function_def.nr rust + +Returns each parameter of the function as a tuple of (parameter pattern, parameter type). + +### return_type + +#include_code return_type noir_stdlib/src/meta/function_def.nr rust + +The return type of the function. + +### set_body + +#include_code set_body noir_stdlib/src/meta/function_def.nr rust + +Mutate the function body to a new expression. This is only valid +on functions in the current crate which have not yet been resolved. +This means any functions called at compile-time are invalid targets for this method. + +Requires the new body to be a valid expression. + +### set_parameters + +#include_code set_parameters noir_stdlib/src/meta/function_def.nr rust + +Mutates the function's parameters to a new set of parameters. This is only valid +on functions in the current crate which have not yet been resolved. +This means any functions called at compile-time are invalid targets for this method. + +Expects a slice of (parameter pattern, parameter type) for each parameter. Requires +each parameter pattern to be a syntactically valid parameter. + +### set_return_type + +#include_code set_return_type noir_stdlib/src/meta/function_def.nr rust + +Mutates the function's return type to a new type. This is only valid +on functions in the current crate which have not yet been resolved. +This means any functions called at compile-time are invalid targets for this method. diff --git a/docs/docs/noir/standard_library/meta/index.md b/docs/docs/noir/standard_library/meta/index.md new file mode 100644 index 0000000000..db0e5d0e41 --- /dev/null +++ b/docs/docs/noir/standard_library/meta/index.md @@ -0,0 +1,141 @@ +--- +title: Metaprogramming +description: Noir's Metaprogramming API +keywords: [metaprogramming, comptime, macros, macro, quote, unquote] +--- + +`std::meta` is the entry point for Noir's metaprogramming API. This consists of `comptime` functions +and types used for inspecting and modifying Noir programs. + +## Functions + +### type_of + +#include_code type_of noir_stdlib/src/meta/mod.nr rust + +Returns the type of a variable at compile-time. + +Example: +```rust +comptime { + let x: i32 = 1; + let x_type: Type = std::meta::type_of(x); + + assert_eq(x_type, quote { i32 }.as_type()); +} +``` + +### unquote + +#include_code unquote noir_stdlib/src/meta/mod.nr rust + +Unquotes the passed-in token stream where this function was called. + +Example: +```rust +comptime { + let code = quote { 1 + 2 }; + + // let x = 1 + 2; + let x = unquote!(code); +} +``` + +### derive + +#include_code derive noir_stdlib/src/meta/mod.nr rust + +Attribute placed on struct definitions. + +Creates a trait impl for each trait passed in as an argument. +To do this, the trait must have a derive handler registered +with `derive_via` beforehand. The traits in the stdlib that +can be derived this way are `Eq`, `Ord`, `Default`, and `Hash`. + +Example: +```rust +#[derive(Eq, Default)] +struct Foo { + x: i32, + y: T, +} + +fn main() { + let foo1 = Foo::default(); + let foo2 = Foo { x: 0, y: &[0] }; + assert_eq(foo1, foo2); +} +``` + +### derive_via + +#include_code derive_via_signature noir_stdlib/src/meta/mod.nr rust + +Attribute placed on trait definitions. + +Registers a function to create impls for the given trait +when the trait is used in a `derive` call. Users may use +this to register their own functions to enable their traits +to be derived by `derive`. + +Because this function requires a function as an argument which +should produce a trait impl for any given struct, users may find +it helpful to use a function like `std::meta::make_trait_impl` to +help creating these impls. + +Example: +```rust +#[derive_via(derive_do_nothing)] +trait DoNothing { + fn do_nothing(self); +} + +comptime fn derive_do_nothing(s: StructDefinition) -> Quoted { + let typ = s.as_type(); + quote { + impl DoNothing for $typ { + fn do_nothing(self) { + println("Nothing"); + } + } + } +} +``` + +As another example, `derive_eq` in the stdlib is used to derive the `Eq` +trait for any struct. It makes use of `make_trait_impl` to do this: + +#include_code derive_eq noir_stdlib/src/cmp.nr rust + +### make_trait_impl + +#include_code make_trait_impl noir_stdlib/src/meta/mod.nr rust + +A helper function to more easily create trait impls while deriving traits. + +Note that this function only works for traits which: +1. Have only one method +2. Have no generics on the trait itself. + - E.g. Using this on a trait such as `trait Foo { ... }` will result in the + generated impl incorrectly missing the `T` generic. + +If your trait fits these criteria then `make_trait_impl` is likely the easiest +way to write your derive handler. The arguments are as follows: + +- `s`: The struct to make the impl for +- `trait_name`: The name of the trait to derive. E.g. `quote { Eq }`. +- `function_signature`: The signature of the trait method to derive. E.g. `fn eq(self, other: Self) -> bool`. +- `for_each_field`: An operation to be performed on each field. E.g. `|name| quote { (self.$name == other.$name) }`. +- `join_fields_with`: A separator to join each result of `for_each_field` with. + E.g. `quote { & }`. You can also use an empty `quote {}` for no separator. +- `body`: The result of the field operations are passed into this function for any final processing. + This is the place to insert any setup/teardown code the trait requires. If the trait doesn't require + any such code, you can return the body as-is: `|body| body`. + +Example deriving `Hash`: + +#include_code derive_hash noir_stdlib/src/hash/mod.nr rust + +Example deriving `Ord`: + +#include_code derive_ord noir_stdlib/src/cmp.nr rust diff --git a/docs/docs/noir/standard_library/meta/module.md b/docs/docs/noir/standard_library/meta/module.md new file mode 100644 index 0000000000..d283f2da8b --- /dev/null +++ b/docs/docs/noir/standard_library/meta/module.md @@ -0,0 +1,27 @@ +--- +title: Module +--- + +`std::meta::module` contains methods on the built-in `Module` type which represents a module in the source program. +Note that this type represents a module generally, it isn't limited to only `mod my_submodule { ... }` +declarations in the source program. + +## Methods + +### name + +#include_code name noir_stdlib/src/meta/module.nr rust + +Returns the name of the module. + +### functions + +#include_code functions noir_stdlib/src/meta/module.nr rust + +Returns each function in the module. + +### is_contract + +#include_code is_contract noir_stdlib/src/meta/module.nr rust + +`true` if this module is a contract module (was declared via `contract foo { ... }`). diff --git a/docs/docs/noir/standard_library/meta/op.md b/docs/docs/noir/standard_library/meta/op.md new file mode 100644 index 0000000000..37d4cb746a --- /dev/null +++ b/docs/docs/noir/standard_library/meta/op.md @@ -0,0 +1,134 @@ +--- +title: UnaryOp and BinaryOp +--- + +`std::meta::op` contains the `UnaryOp` and `BinaryOp` types as well as methods on them. +These types are used to represent a unary or binary operator respectively in Noir source code. + +## Types + +### UnaryOp + +Represents a unary operator. One of `-`, `!`, `&mut`, or `*`. + +### Methods + +#### is_minus + +#include_code is_minus noir_stdlib/src/meta/op.nr rust + +Returns `true` if this operator is `-`. + +#### is_not + +#include_code is_not noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `!` + +#### is_mutable_reference + +#include_code is_mutable_reference noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `&mut` + +#### is_dereference + +#include_code is_dereference noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `*` + +### BinaryOp + +Represents a binary operator. One of `+`, `-`, `*`, `/`, `%`, `==`, `!=`, `<`, `<=`, `>`, `>=`, `&`, `|`, `^`, `>>`, or `<<`. + +### Methods + +#### is_add + +#include_code is_add noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `+` + +#### is_subtract + +#include_code is_subtract noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `-` + +#### is_multiply + +#include_code is_multiply noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `*` + +#### is_divide + +#include_code is_divide noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `/` + +#### is_modulo + +#include_code is_modulo noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `%` + +#### is_equal + +#include_code is_equal noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `==` + +#### is_not_equal + +#include_code is_not_equal noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `!=` + +#### is_less_than + +#include_code is_less_than noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `<` + +#### is_less_than_or_equal + +#include_code is_less_than_or_equal noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `<=` + +#### is_greater_than + +#include_code is_greater_than noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `>` + +#### is_greater_than_or_equal + +#include_code is_greater_than_or_equal noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `>=` + +#### is_and + +#include_code is_and noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `&` + +#### is_or + +#include_code is_or noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `|` + +#### is_shift_right + +#include_code is_shift_right noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `>>` + +#### is_shift_left + +#include_code is_shift_right noir_stdlib/src/meta/op.nr rust + +`true` if this operator is `<<` diff --git a/docs/docs/noir/standard_library/meta/quoted.md b/docs/docs/noir/standard_library/meta/quoted.md new file mode 100644 index 0000000000..bf79f2e5d9 --- /dev/null +++ b/docs/docs/noir/standard_library/meta/quoted.md @@ -0,0 +1,57 @@ +--- +title: Quoted +--- + +`std::meta::quoted` contains methods on the built-in `Quoted` type which represents +quoted token streams and is the result of the `quote { ... }` expression. + +## Methods + +### as_expr + +#include_code as_expr noir_stdlib/src/meta/quoted.nr rust + +Parses the quoted token stream as an expression. Returns `Option::none()` if +the expression failed to parse. + +Example: + +#include_code as_expr_example test_programs/noir_test_success/comptime_expr/src/main.nr rust + +### as_module + +#include_code as_module noir_stdlib/src/meta/quoted.nr rust + +Interprets this token stream as a module path leading to the name of a module. +Returns `Option::none()` if the module isn't found or this token stream cannot be parsed as a path. + +Example: + +#include_code as_module_example test_programs/compile_success_empty/comptime_module/src/main.nr rust + +### as_trait_constraint + +#include_code as_trait_constraint noir_stdlib/src/meta/quoted.nr rust + +Interprets this token stream as a trait constraint (without an object type). +Note that this function panics instead of returning `Option::none()` if the token +stream does not parse and resolve to a valid trait constraint. + +Example: + +#include_code implements_example test_programs/compile_success_empty/comptime_type/src/main.nr rust + +### as_type + +#include_code as_type noir_stdlib/src/meta/quoted.nr rust + +Interprets this token stream as a resolved type. Panics if the token +stream doesn't parse to a type or if the type isn't a valid type in scope. + +#include_code implements_example test_programs/compile_success_empty/comptime_type/src/main.nr rust + +## Trait Implementations + +```rust +impl Eq for Quoted +``` diff --git a/docs/docs/noir/standard_library/meta/struct_def.md b/docs/docs/noir/standard_library/meta/struct_def.md new file mode 100644 index 0000000000..ab3ea4e069 --- /dev/null +++ b/docs/docs/noir/standard_library/meta/struct_def.md @@ -0,0 +1,45 @@ +--- +title: StructDefinition +--- + +`std::meta::struct_def` contains methods on the built-in `StructDefinition` type. +This type corresponds to `struct Name { field1: Type1, ... }` items in the source program. + +## Methods + +### as_type + +#include_code as_type noir_stdlib/src/meta/struct_def.nr rust + +Returns this struct as a type in the source program. If this struct has +any generics, the generics are also included as-is. + +### generics + +#include_code generics noir_stdlib/src/meta/struct_def.nr rust + +Returns each generic on this struct. + +Example: + +``` +#[example] +struct Foo { + bar: [T; 2], + baz: Baz, +} + +comptime fn example(foo: StructDefinition) { + assert_eq(foo.generics().len(), 2); + + // Fails because `T` isn't in scope + // let t = quote { T }.as_type(); + // assert_eq(foo.generics()[0], t); +} +``` + +### fields + +#include_code fields noir_stdlib/src/meta/struct_def.nr rust + +Returns each field of this struct as a pair of (field name, field type). diff --git a/docs/docs/noir/standard_library/meta/trait_constraint.md b/docs/docs/noir/standard_library/meta/trait_constraint.md new file mode 100644 index 0000000000..3106f732b5 --- /dev/null +++ b/docs/docs/noir/standard_library/meta/trait_constraint.md @@ -0,0 +1,17 @@ +--- +title: TraitConstraint +--- + +`std::meta::trait_constraint` contains methods on the built-in `TraitConstraint` type which represents +a trait constraint that can be used to search for a trait implementation. This is similar +syntactically to just the trait itself, but can also contain generic arguments. E.g. `Eq`, `Default`, +`BuildHasher`. + +This type currently has no public methods but it can be used alongside `Type` in `implements` or `get_trait_impl`. + +## Trait Implementations + +```rust +impl Eq for TraitConstraint +impl Hash for TraitConstraint +``` diff --git a/docs/docs/noir/standard_library/meta/trait_def.md b/docs/docs/noir/standard_library/meta/trait_def.md new file mode 100644 index 0000000000..b6e8bf4ff7 --- /dev/null +++ b/docs/docs/noir/standard_library/meta/trait_def.md @@ -0,0 +1,22 @@ +--- +title: TraitDefinition +--- + +`std::meta::trait_def` contains methods on the built-in `TraitDefinition` type. This type +represents trait definitions such as `trait Foo { .. }` at the top-level of a program. + +## Methods + +### as_trait_constraint + +#include_code as_trait_constraint noir_stdlib/src/meta/trait_def.nr rust + +Converts this trait into a trait constraint. If there are any generics on this +trait, they will be kept as-is without instantiating or replacing them. + +## Trait Implementations + +```rust +impl Eq for TraitDefinition +impl Hash for TraitDefinition +``` diff --git a/docs/docs/noir/standard_library/meta/trait_impl.md b/docs/docs/noir/standard_library/meta/trait_impl.md new file mode 100644 index 0000000000..659c6aad71 --- /dev/null +++ b/docs/docs/noir/standard_library/meta/trait_impl.md @@ -0,0 +1,52 @@ +--- +title: TraitImpl +--- + +`std::meta::trait_impl` contains methods on the built-in `TraitImpl` type which represents a trait +implementation such as `impl Foo for Bar { ... }`. + +## Methods + +### trait_generic_args + +#include_code trait_generic_args noir_stdlib/src/meta/trait_impl.nr rust + +Returns any generic arguments on the trait of this trait implementation, if any. + +```rs +impl Foo for Bar { ... } + +comptime { + let bar_type = quote { Bar }.as_type(); + let foo = quote { Foo }.as_trait_constraint(); + + let my_impl: TraitImpl = bar_type.get_trait_impl(foo).unwrap(); + + let generics = my_impl.trait_generic_args(); + assert_eq(generics.len(), 2); + + assert_eq(generics[0], quote { i32 }.as_type()); + assert_eq(generics[1], quote { Field }.as_type()); +} +``` + +### methods + +#include_code methods noir_stdlib/src/meta/trait_impl.nr rust + +Returns each method in this trait impl. + +Example: + +```rs +comptime { + let i32_type = quote { i32 }.as_type(); + let eq = quote { Eq }.as_trait_constraint(); + + let impl_eq_for_i32: TraitImpl = i32_type.get_trait_impl(eq).unwrap(); + let methods = impl_eq_for_i32.methods(); + + assert_eq(methods.len(), 1); + assert_eq(methods[0].name(), quote { eq }); +} +``` diff --git a/docs/docs/noir/standard_library/meta/typ.md b/docs/docs/noir/standard_library/meta/typ.md new file mode 100644 index 0000000000..bad6435e94 --- /dev/null +++ b/docs/docs/noir/standard_library/meta/typ.md @@ -0,0 +1,126 @@ +--- +title: Type +--- + +`std::meta::typ` contains methods on the built-in `Type` type used for representing +a type in the source program. + +## Methods + +### as_array + +#include_code as_array noir_stdlib/src/meta/typ.nr rust + +If this type is an array, return a pair of (element type, size type). + +Example: + +```rust +comptime { + let array_type = quote { [Field; 3] }.as_type(); + let (field_type, three_type) = array_type.as_array().unwrap(); + + assert(field_type.is_field()); + assert_eq(three_type.as_constant().unwrap(), 3); +} +``` + +### as_constant + +#include_code as_constant noir_stdlib/src/meta/typ.nr rust + +If this type is a constant integer (such as the `3` in the array type `[Field; 3]`), +return the numeric constant. + +### as_integer + +#include_code as_integer noir_stdlib/src/meta/typ.nr rust + +If this is an integer type, return a boolean which is `true` +if the type is signed, as well as the number of bits of this integer type. + +### as_slice + +#include_code as_slice noir_stdlib/src/meta/typ.nr rust + +If this is a slice type, return the element type of the slice. + +### as_struct + +#include_code as_struct noir_stdlib/src/meta/typ.nr rust + +If this is a struct type, returns the struct in addition to +any generic arguments on this type. + +### as_tuple + +#include_code as_tuple noir_stdlib/src/meta/typ.nr rust + +If this is a tuple type, returns each element type of the tuple. + +### get_trait_impl + +#include_code get_trait_impl noir_stdlib/src/meta/typ.nr rust + +Retrieves the trait implementation that implements the given +trait constraint for this type. If the trait constraint is not +found, `None` is returned. Note that since the concrete trait implementation +for a trait constraint specified from a `where` clause is unknown, +this function will return `None` in these cases. If you only want to know +whether a type implements a trait, use `implements` instead. + +Example: + +```rust +comptime { + let field_type = quote { Field }.as_type(); + let default = quote { Default }.as_trait_constraint(); + + let the_impl: TraitImpl = field_type.get_trait_impl(default).unwrap(); + assert(the_impl.methods().len(), 1); +} +``` + +### implements + +#include_code implements noir_stdlib/src/meta/typ.nr rust + +`true` if this type implements the given trait. Note that unlike +`get_trait_impl` this will also return true for any `where` constraints +in scope. + +Example: + +```rust +fn foo() where T: Default { + comptime { + let field_type = quote { Field }.as_type(); + let default = quote { Default }.as_trait_constraint(); + assert(field_type.implements(default)); + + let t = quote { T }.as_type(); + assert(t.implements(default)); + } +} +``` + +### is_bool + +#include_code is_bool noir_stdlib/src/meta/typ.nr rust + +`true` if this type is `bool`. + +### is_field + +#include_code is_field noir_stdlib/src/meta/typ.nr rust + +`true` if this type is `Field`. + +## Trait Implementations + +```rust +impl Eq for Type +``` +Note that this is syntactic equality, this is not the same as whether two types will type check +to be the same type. Unless type inference or generics are being used however, users should not +typically have to worry about this distinction. diff --git a/docs/docs/noir/standard_library/meta/unresolved_type.md b/docs/docs/noir/standard_library/meta/unresolved_type.md new file mode 100644 index 0000000000..9c61f91dee --- /dev/null +++ b/docs/docs/noir/standard_library/meta/unresolved_type.md @@ -0,0 +1,13 @@ +--- +title: UnresolvedType +--- + +`std::meta::unresolved_type` contains methods on the built-in `UnresolvedType` type for the syntax of types. + +## Methods + +### is_field + +#include_code is_field noir_stdlib/src/meta/unresolved_type.nr rust + +Returns true if this type refers to the Field type. diff --git a/noir_stdlib/src/cmp.nr b/noir_stdlib/src/cmp.nr index ec979d6075..b7f473429a 100644 --- a/noir_stdlib/src/cmp.nr +++ b/noir_stdlib/src/cmp.nr @@ -7,12 +7,14 @@ trait Eq { } // docs:end:eq-trait +// docs:start:derive_eq comptime fn derive_eq(s: StructDefinition) -> Quoted { let signature = quote { fn eq(_self: Self, _other: Self) -> bool }; let for_each_field = |name| quote { (_self.$name == _other.$name) }; let body = |fields| fields; crate::meta::make_trait_impl(s, quote { Eq }, signature, for_each_field, quote { & }, body) } +// docs:end:derive_eq impl Eq for Field { fn eq(self, other: Field) -> bool { self == other } } @@ -118,6 +120,7 @@ trait Ord { } // docs:end:ord-trait +// docs:start:derive_ord comptime fn derive_ord(s: StructDefinition) -> Quoted { let signature = quote { fn cmp(_self: Self, _other: Self) -> std::cmp::Ordering }; let for_each_field = |name| quote { @@ -132,6 +135,7 @@ comptime fn derive_ord(s: StructDefinition) -> Quoted { }; crate::meta::make_trait_impl(s, quote { Ord }, signature, for_each_field, quote {}, body) } +// docs:end:derive_ord // Note: Field deliberately does not implement Ord diff --git a/noir_stdlib/src/hash/mod.nr b/noir_stdlib/src/hash/mod.nr index d77b655398..657e1cd830 100644 --- a/noir_stdlib/src/hash/mod.nr +++ b/noir_stdlib/src/hash/mod.nr @@ -144,12 +144,14 @@ trait Hash { fn hash(self, state: &mut H) where H: Hasher; } +// docs:start:derive_hash comptime fn derive_hash(s: StructDefinition) -> Quoted { let name = quote { Hash }; let signature = quote { fn hash(_self: Self, _state: &mut H) where H: std::hash::Hasher }; let for_each_field = |name| quote { _self.$name.hash(_state); }; crate::meta::make_trait_impl(s, name, signature, for_each_field, quote {}, |fields| fields) } +// docs:end:derive_hash // Hasher trait shall be implemented by algorithms to provide hash-agnostic means. // TODO: consider making the types generic here ([u8], [Field], etc.) diff --git a/noir_stdlib/src/meta/expr.nr b/noir_stdlib/src/meta/expr.nr index 94889b7c3d..b99383c605 100644 --- a/noir_stdlib/src/meta/expr.nr +++ b/noir_stdlib/src/meta/expr.nr @@ -4,43 +4,112 @@ use crate::meta::op::BinaryOp; impl Expr { #[builtin(expr_as_array)] + // docs:start:as_array fn as_array(self) -> Option<[Expr]> {} + // docs:end:as_array + + #[builtin(expr_as_assign)] + // docs:start:as_assign + fn as_assign(self) -> Option<(Expr, Expr)> {} + // docs:end:as_assign #[builtin(expr_as_integer)] + // docs:start:as_integer fn as_integer(self) -> Option<(Field, bool)> {} + // docs:end:as_integer #[builtin(expr_as_binary_op)] + // docs:start:as_binary_op fn as_binary_op(self) -> Option<(Expr, BinaryOp, Expr)> {} + // docs:end:as_binary_op + + #[builtin(expr_as_block)] + // docs:start:as_block + fn as_block(self) -> Option<[Expr]> {} + // docs:end:as_block #[builtin(expr_as_bool)] + // docs:start:as_bool fn as_bool(self) -> Option {} + // docs:end:as_bool + + #[builtin(expr_as_cast)] + fn as_cast(self) -> Option<(Expr, UnresolvedType)> {} + + #[builtin(expr_as_comptime)] + // docs:start:as_comptime + fn as_comptime(self) -> Option<[Expr]> {} + // docs:end:as_comptime #[builtin(expr_as_function_call)] + // docs:start:as_function_call fn as_function_call(self) -> Option<(Expr, [Expr])> {} + // docs:end:as_function_call #[builtin(expr_as_if)] + // docs:start:as_if fn as_if(self) -> Option<(Expr, Expr, Option)> {} + // docs:end:as_if #[builtin(expr_as_index)] + // docs:start:as_index fn as_index(self) -> Option<(Expr, Expr)> {} + // docs:end:as_index #[builtin(expr_as_member_access)] + // docs:start:as_member_access fn as_member_access(self) -> Option<(Expr, Quoted)> {} + // docs:end:as_member_access + + #[builtin(expr_as_method_call)] + // docs:start:as_method_call + fn as_method_call(self) -> Option<(Expr, Quoted, [UnresolvedType], [Expr])> {} + // docs:end:as_method_call #[builtin(expr_as_repeated_element_array)] + // docs:start:as_repeated_element_array fn as_repeated_element_array(self) -> Option<(Expr, Expr)> {} + // docs:end:as_repeated_element_array #[builtin(expr_as_repeated_element_slice)] + // docs:start:as_repeated_element_slice fn as_repeated_element_slice(self) -> Option<(Expr, Expr)> {} + // docs:end:as_repeated_element_slice #[builtin(expr_as_slice)] + // docs:start:as_slice fn as_slice(self) -> Option<[Expr]> {} + // docs:end:as_slice #[builtin(expr_as_tuple)] + // docs:start:as_tuple fn as_tuple(self) -> Option<[Expr]> {} + // docs:end:as_tuple #[builtin(expr_as_unary_op)] + // docs:start:as_unary_op fn as_unary_op(self) -> Option<(UnaryOp, Expr)> {} + // docs:end:as_unary_op + + #[builtin(expr_as_unsafe)] + // docs:start:as_unsafe + fn as_unsafe(self) -> Option<[Expr]> {} + // docs:end:as_unsafe + + #[builtin(expr_has_semicolon)] + // docs:start:has_semicolon + fn has_semicolon(self) -> bool {} + // docs:end:has_semicolon + + #[builtin(expr_is_break)] + // docs:start:is_break + fn is_break(self) -> bool {} + // docs:end:is_break + + #[builtin(expr_is_continue)] + // docs:start:is_continue + fn is_continue(self) -> bool {} + // docs:end:is_continue } mod tests { diff --git a/noir_stdlib/src/meta/function_def.nr b/noir_stdlib/src/meta/function_def.nr index 2b5ddd008e..7ac8803e7e 100644 --- a/noir_stdlib/src/meta/function_def.nr +++ b/noir_stdlib/src/meta/function_def.nr @@ -1,19 +1,31 @@ impl FunctionDefinition { #[builtin(function_def_name)] + // docs:start:name fn name(self) -> Quoted {} + // docs:end:name #[builtin(function_def_parameters)] + // docs:start:parameters fn parameters(self) -> [(Quoted, Type)] {} + // docs:end:parameters #[builtin(function_def_return_type)] + // docs:start:return_type fn return_type(self) -> Type {} + // docs:end:return_type #[builtin(function_def_set_body)] + // docs:start:set_body fn set_body(self, body: Quoted) {} + // docs:end:set_body #[builtin(function_def_set_parameters)] + // docs:start:set_parameters fn set_parameters(self, parameters: [(Quoted, Type)]) {} + // docs:end:set_parameters #[builtin(function_def_set_return_type)] + // docs:start:set_return_type fn set_return_type(self, return_type: Type) {} + // docs:end:set_return_type } diff --git a/noir_stdlib/src/meta/mod.nr b/noir_stdlib/src/meta/mod.nr index d16a8648bc..be1b12540c 100644 --- a/noir_stdlib/src/meta/mod.nr +++ b/noir_stdlib/src/meta/mod.nr @@ -1,7 +1,3 @@ -use crate::collections::umap::UHashMap; -use crate::hash::BuildHasherDefault; -use crate::hash::poseidon2::Poseidon2Hasher; - mod expr; mod function_def; mod module; @@ -12,24 +8,43 @@ mod trait_def; mod trait_impl; mod typ; mod quoted; +mod unresolved_type; /// Calling unquote as a macro (via `unquote!(arg)`) will unquote /// its argument. Since this is the effect `!` already does, `unquote` /// itself does not need to do anything besides return its argument. +// docs:start:unquote pub comptime fn unquote(code: Quoted) -> Quoted { + // docs:end:unquote code } /// Returns the type of any value #[builtin(type_of)] +// docs:start:type_of pub comptime fn type_of(x: T) -> Type {} +// docs:end:type_of + +// docs:start:derive_example +// These are needed for the unconstrained hashmap we're using to store derive functions +use crate::collections::umap::UHashMap; +use crate::hash::BuildHasherDefault; +use crate::hash::poseidon2::Poseidon2Hasher; +// A derive function is one that given a struct definition can +// create us a quoted trait impl from it. type DeriveFunction = fn(StructDefinition) -> Quoted; +// We'll keep a global HANDLERS map to keep track of the derive handler for each trait comptime mut global HANDLERS: UHashMap> = UHashMap::default(); +// Given a struct and a slice of traits to derive, create trait impls for each. +// This function is as simple as iterating over the slice, checking if we have a trait +// handler registered for the given trait, calling it, and appending the result. +// docs:start:derive #[varargs] pub comptime fn derive(s: StructDefinition, traits: [TraitDefinition]) -> Quoted { + // docs:end:derive let mut result = quote {}; for trait_to_derive in traits { @@ -44,10 +59,16 @@ pub comptime fn derive(s: StructDefinition, traits: [TraitDefinition]) -> Quoted result } +// docs:end:derive_example +// docs:start:derive_via +// To register a handler for a trait, just add it to our handlers map +// docs:start:derive_via_signature pub comptime fn derive_via(t: TraitDefinition, f: DeriveFunction) { + // docs:end:derive_via_signature HANDLERS.insert(t, f); } +// docs:end:derive_via /// `make_impl` is a helper function to make a simple impl, usually while deriving a trait. /// This impl has a couple assumptions: @@ -61,6 +82,7 @@ pub comptime fn derive_via(t: TraitDefinition, f: DeriveFunction) { /// any final processing - e.g. wrapping each field in a `StructConstructor { .. }` expression. /// /// See `derive_eq` and `derive_default` for example usage. +// docs:start:make_trait_impl pub comptime fn make_trait_impl( s: StructDefinition, trait_name: Quoted, @@ -69,6 +91,7 @@ pub comptime fn make_trait_impl( join_fields_with: Quoted, body: fn[Env2](Quoted) -> Quoted ) -> Quoted { + // docs:end:make_trait_impl let typ = s.as_type(); let impl_generics = s.generics().map(|g| quote { $g }).join(quote {,}); let where_clause = s.generics().map(|name| quote { $name: $trait_name }).join(quote {,}); @@ -90,3 +113,124 @@ pub comptime fn make_trait_impl( } } } + +mod tests { + // docs:start:quote-example + comptime fn quote_one() -> Quoted { + quote { 1 } + } + + fn returning_versus_macro_insertion() { + comptime + { + // let _a: Quoted = quote { 1 }; + let _a: Quoted = quote_one(); + + // let _b: i32 = 1; + let _b: i32 = quote_one!(); + } + } + // docs:end:quote-example + + // docs:start:derive-field-count-example + trait FieldCount { + fn field_count() -> u32; + } + + #[derive_field_count] + struct Bar { x: Field, y: [Field; 2] } + + comptime fn derive_field_count(s: StructDefinition) -> Quoted { + let typ = s.as_type(); + let field_count = s.fields().len(); + quote { + impl FieldCount for $typ { + fn field_count() -> u32 { + $field_count + } + } + } + } + // docs:end:derive-field-count-example + + // docs:start:annotation-arguments-example + #[assert_field_is_type(quote { i32 }.as_type())] + struct MyStruct { my_field: i32 } + + comptime fn assert_field_is_type(s: StructDefinition, typ: Type) { + // Assert the first field in `s` has type `typ` + let fields = s.fields(); + assert_eq(fields[0].1, typ); + } + // docs:end:annotation-arguments-example + + // docs:start:annotation-varargs-example + #[assert_three_args(1, 2, 3)] + struct MyOtherStruct { my_other_field: u32 } + + #[varargs] + comptime fn assert_three_args(_s: StructDefinition, args: [Field]) { + assert_eq(args.len(), 3); + } + // docs:end:annotation-varargs-example + + // docs:start:big-derive-usage-example + // Finally, to register a handler we call the above function as an annotation + // with our handler function. + #[derive_via(derive_do_nothing)] + trait DoNothing { + fn do_nothing(self); + } + + comptime fn derive_do_nothing(s: StructDefinition) -> Quoted { + // This is simplified since we don't handle generics or where clauses! + // In a real example we'd likely also need to introduce each of + // `s.generics()` as well as a trait constraint for each generic + // to ensure they also implement the trait. + let typ = s.as_type(); + quote { + impl DoNothing for $typ { + fn do_nothing(self) { + // Traits can't tell us what to do + println("something"); + } + } + } + } + + // Since `DoNothing` is a simple trait which: + // 1. Only has one method + // 2. Does not have any generics on the trait itself + // We can use `std::meta::make_trait_impl` to help us out. + // This helper function will generate our impl for us along with any + // necessary where clauses and still provides a flexible interface + // for us to work on each field on the struct. + comptime fn derive_do_nothing_alt(s: StructDefinition) -> Quoted { + let trait_name = quote { DoNothing }; + let method_signature = quote { fn do_nothing(self) }; + + // Call `do_nothing` recursively on each field in the struct + let for_each_field = |field_name| quote { self.$field_name.do_nothing(); }; + + // Some traits like Eq want to join each field expression with something like `&`. + // We don't need that here + let join_fields_with = quote {}; + + // The body function is a spot to insert any extra setup/teardown needed. + // We'll insert our println here. Since we recur on each field, we should see + // one println for the struct itself, followed by a println for every field (recursively). + let body = |body| quote { + println("something"); + $body + }; + crate::meta::make_trait_impl( + s, + trait_name, + method_signature, + for_each_field, + join_fields_with, + body + ) + } + // docs:end:big-derive-usage-example +} diff --git a/noir_stdlib/src/meta/module.nr b/noir_stdlib/src/meta/module.nr index ee00f36080..6ea3ca55fb 100644 --- a/noir_stdlib/src/meta/module.nr +++ b/noir_stdlib/src/meta/module.nr @@ -1,10 +1,16 @@ impl Module { #[builtin(module_is_contract)] +// docs:start:is_contract fn is_contract(self) -> bool {} + // docs:end:is_contract #[builtin(module_functions)] +// docs:start:functions fn functions(self) -> [FunctionDefinition] {} + // docs:end:functions #[builtin(module_name)] +// docs:start:name fn name(self) -> Quoted {} + // docs:end:name } diff --git a/noir_stdlib/src/meta/op.nr b/noir_stdlib/src/meta/op.nr index ebd89677c5..9c892c4d80 100644 --- a/noir_stdlib/src/meta/op.nr +++ b/noir_stdlib/src/meta/op.nr @@ -3,19 +3,27 @@ struct UnaryOp { } impl UnaryOp { - fn is_minus(self) -> bool { + // docs:start:is_minus + pub fn is_minus(self) -> bool { + // docs:end:is_minus self.op == 0 } - fn is_not(self) -> bool { + // docs:start:is_not + pub fn is_not(self) -> bool { + // docs:end:is_not self.op == 1 } - fn is_mutable_reference(self) -> bool { + // docs:start:is_mutable_reference + pub fn is_mutable_reference(self) -> bool { + // docs:end:is_mutable_reference self.op == 2 } - fn is_dereference(self) -> bool { + // docs:start:is_dereference + pub fn is_dereference(self) -> bool { + // docs:end:is_dereference self.op == 3 } } @@ -25,67 +33,99 @@ struct BinaryOp { } impl BinaryOp { - fn is_add(self) -> bool { + // docs:start:is_add + pub fn is_add(self) -> bool { + // docs:end:is_add self.op == 0 } - fn is_subtract(self) -> bool { + // docs:start:is_subtract + pub fn is_subtract(self) -> bool { + // docs:end:is_subtract self.op == 1 } - fn is_multiply(self) -> bool { + // docs:start:is_multiply + pub fn is_multiply(self) -> bool { + // docs:end:is_multiply self.op == 2 } - fn is_divide(self) -> bool { + // docs:start:is_divide + pub fn is_divide(self) -> bool { + // docs:end:is_divide self.op == 3 } - fn is_equal(self) -> bool { + // docs:start:is_equal + pub fn is_equal(self) -> bool { + // docs:end:is_equal self.op == 4 } - fn is_not_equal(self) -> bool { + // docs:start:is_not_equal + pub fn is_not_equal(self) -> bool { + // docs:end:is_not_equal self.op == 5 } - fn is_less(self) -> bool { + // docs:start:is_less_than + pub fn is_less_than(self) -> bool { + // docs:end:is_less_than self.op == 6 } - fn is_less_equal(self) -> bool { + // docs:start:is_less_than_or_equal + pub fn is_less_than_or_equal(self) -> bool { + // docs:end:is_less_than_or_equal self.op == 7 } - fn is_greater(self) -> bool { + // docs:start:is_greater_than + pub fn is_greater_than(self) -> bool { + // docs:end:is_greater_than self.op == 8 } - fn is_greater_or_equal(self) -> bool { + // docs:start:is_greater_than_or_equal + pub fn is_greater_than_or_equal(self) -> bool { + // docs:end:is_greater_than_or_equal self.op == 9 } - fn is_and(self) -> bool { + // docs:start:is_and + pub fn is_and(self) -> bool { + // docs:end:is_and self.op == 10 } - fn is_or(self) -> bool { + // docs:start:is_or + pub fn is_or(self) -> bool { + // docs:end:is_or self.op == 11 } - fn is_xor(self) -> bool { + // docs:start:is_xor + pub fn is_xor(self) -> bool { + // docs:end:is_xor self.op == 12 } - fn is_shift_right(self) -> bool { + // docs:start:is_shift_right + pub fn is_shift_right(self) -> bool { + // docs:end:is_shift_right self.op == 13 } - fn is_shift_left(self) -> bool { + // docs:start:is_shift_left + pub fn is_shift_left(self) -> bool { + // docs:end:is_shift_left self.op == 14 } - fn is_modulo(self) -> bool { + // docs:start:is_modulo + pub fn is_modulo(self) -> bool { + // docs:end:is_modulo self.op == 15 } } diff --git a/noir_stdlib/src/meta/quoted.nr b/noir_stdlib/src/meta/quoted.nr index cccc3fe0f1..9fd1e9026b 100644 --- a/noir_stdlib/src/meta/quoted.nr +++ b/noir_stdlib/src/meta/quoted.nr @@ -3,16 +3,24 @@ use crate::option::Option; impl Quoted { #[builtin(quoted_as_expr)] +// docs:start:as_expr fn as_expr(self) -> Option {} + // docs:end:as_expr #[builtin(quoted_as_module)] +// docs:start:as_module fn as_module(self) -> Option {} + // docs:end:as_module #[builtin(quoted_as_trait_constraint)] +// docs:start:as_trait_constraint fn as_trait_constraint(self) -> TraitConstraint {} + // docs:end:as_trait_constraint #[builtin(quoted_as_type)] +// docs:start:as_type fn as_type(self) -> Type {} + // docs:end:as_type } impl Eq for Quoted { diff --git a/noir_stdlib/src/meta/struct_def.nr b/noir_stdlib/src/meta/struct_def.nr index 8d3f9ceb8a..60fdeba21a 100644 --- a/noir_stdlib/src/meta/struct_def.nr +++ b/noir_stdlib/src/meta/struct_def.nr @@ -2,14 +2,20 @@ impl StructDefinition { /// Return a syntactic version of this struct definition as a type. /// For example, `as_type(quote { type Foo { ... } })` would return `Foo` #[builtin(struct_def_as_type)] +// docs:start:as_type fn as_type(self) -> Type {} + // docs:end:as_type /// Return each generic on this struct. #[builtin(struct_def_generics)] +// docs:start:generics fn generics(self) -> [Type] {} + // docs:end:generics /// Returns (name, type) pairs of each field in this struct. Each type is as-is /// with any generic arguments unchanged. #[builtin(struct_def_fields)] +// docs:start:fields fn fields(self) -> [(Quoted, Type)] {} + // docs:end:fields } diff --git a/noir_stdlib/src/meta/trait_def.nr b/noir_stdlib/src/meta/trait_def.nr index ca381cb8e1..c26b571240 100644 --- a/noir_stdlib/src/meta/trait_def.nr +++ b/noir_stdlib/src/meta/trait_def.nr @@ -3,7 +3,9 @@ use crate::cmp::Eq; impl TraitDefinition { #[builtin(trait_def_as_trait_constraint)] +// docs:start:as_trait_constraint fn as_trait_constraint(_self: Self) -> TraitConstraint {} + // docs:end:as_trait_constraint } impl Eq for TraitDefinition { diff --git a/noir_stdlib/src/meta/trait_impl.nr b/noir_stdlib/src/meta/trait_impl.nr index 2f82ee5f43..15b02eac6b 100644 --- a/noir_stdlib/src/meta/trait_impl.nr +++ b/noir_stdlib/src/meta/trait_impl.nr @@ -1,7 +1,11 @@ impl TraitImpl { #[builtin(trait_impl_trait_generic_args)] +// docs:start:trait_generic_args fn trait_generic_args(self) -> [Type] {} + // docs:end:trait_generic_args #[builtin(trait_impl_methods)] +// docs:start:methods fn methods(self) -> [FunctionDefinition] {} + // docs:end:methods } diff --git a/noir_stdlib/src/meta/typ.nr b/noir_stdlib/src/meta/typ.nr index 67ad2a9673..a3f35b28e4 100644 --- a/noir_stdlib/src/meta/typ.nr +++ b/noir_stdlib/src/meta/typ.nr @@ -3,34 +3,54 @@ use crate::option::Option; impl Type { #[builtin(type_as_array)] +// docs:start:as_array fn as_array(self) -> Option<(Type, Type)> {} + // docs:end:as_array #[builtin(type_as_constant)] +// docs:start:as_constant fn as_constant(self) -> Option {} + // docs:end:as_constant #[builtin(type_as_integer)] +// docs:start:as_integer fn as_integer(self) -> Option<(bool, u8)> {} + // docs:end:as_integer #[builtin(type_as_slice)] +// docs:start:as_slice fn as_slice(self) -> Option {} + // docs:end:as_slice #[builtin(type_as_struct)] +// docs:start:as_struct fn as_struct(self) -> Option<(StructDefinition, [Type])> {} + // docs:end:as_struct #[builtin(type_as_tuple)] +// docs:start:as_tuple fn as_tuple(self) -> Option<[Type]> {} + // docs:end:as_tuple #[builtin(type_get_trait_impl)] +// docs:start:get_trait_impl fn get_trait_impl(self, constraint: TraitConstraint) -> Option {} + // docs:end:get_trait_impl #[builtin(type_implements)] +// docs:start:implements fn implements(self, constraint: TraitConstraint) -> bool {} + // docs:end:implements #[builtin(type_is_bool)] +// docs:start:is_bool fn is_bool(self) -> bool {} + // docs:end:is_bool #[builtin(type_is_field)] +// docs:start:is_field fn is_field(self) -> bool {} + // docs:end:is_field } impl Eq for Type { diff --git a/noir_stdlib/src/meta/unresolved_type.nr b/noir_stdlib/src/meta/unresolved_type.nr new file mode 100644 index 0000000000..2589174ed6 --- /dev/null +++ b/noir_stdlib/src/meta/unresolved_type.nr @@ -0,0 +1,6 @@ +impl UnresolvedType { + #[builtin(unresolved_type_is_field)] + // docs:start:is_field + fn is_field(self) -> bool {} + // docs:end:is_field +} diff --git a/test_programs/compile_success_empty/associated_types/Nargo.toml b/test_programs/compile_success_empty/associated_types/Nargo.toml new file mode 100644 index 0000000000..99b8e1b2d4 --- /dev/null +++ b/test_programs/compile_success_empty/associated_types/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "associated_types" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/compile_success_empty/associated_types/src/main.nr b/test_programs/compile_success_empty/associated_types/src/main.nr new file mode 100644 index 0000000000..dbc6a393ec --- /dev/null +++ b/test_programs/compile_success_empty/associated_types/src/main.nr @@ -0,0 +1,28 @@ +trait Collection { + type Elem; + + fn cget(self, index: u32) -> Option; + + fn ctake(self, index: u32) -> Self::Elem { + self.cget(index).unwrap() + } +} + +impl Collection for [T; N] { + type Elem = T; + + fn cget(self, index: u32) -> Option { + if index < self.len() { + Option::some(self[index]) + } else { + Option::none() + } + } +} + +fn main() { + // Use zeroed here so that we don't help by adding another type constraint. + // We should know Elem = Field from the associated type alone. + let array = [1, 2, 3, 0, 5]; + assert_eq(array.ctake(3), std::mem::zeroed()); +} diff --git a/test_programs/compile_success_empty/comptime_module/src/main.nr b/test_programs/compile_success_empty/comptime_module/src/main.nr index e9c9817cfd..8d834381fe 100644 --- a/test_programs/compile_success_empty/comptime_module/src/main.nr +++ b/test_programs/compile_success_empty/comptime_module/src/main.nr @@ -24,3 +24,18 @@ fn main() { assert_eq(bar.name(), quote { bar }); } } + +// docs:start:as_module_example +mod baz { + mod qux {} +} + +#[test] +fn as_module_test() { + comptime + { + let my_mod = quote { baz::qux }.as_module().unwrap(); + assert_eq(my_mod.name(), quote { qux }); + } +} +// docs:end:as_module_example diff --git a/test_programs/compile_success_empty/comptime_type/src/main.nr b/test_programs/compile_success_empty/comptime_type/src/main.nr index f0b53a392e..6d98d1d173 100644 --- a/test_programs/compile_success_empty/comptime_type/src/main.nr +++ b/test_programs/compile_success_empty/comptime_type/src/main.nr @@ -113,6 +113,7 @@ fn main() { } } +// docs:start:implements_example fn function_with_where(_x: T) where T: SomeTrait { comptime { @@ -123,3 +124,4 @@ fn function_with_where(_x: T) where T: SomeTrait { assert(t.get_trait_impl(some_trait_i32).is_none()); } } +// docs:end:implements_example diff --git a/test_programs/compile_success_empty/regression_5823/Nargo.toml b/test_programs/compile_success_empty/regression_5823/Nargo.toml new file mode 100644 index 0000000000..a2de5c954b --- /dev/null +++ b/test_programs/compile_success_empty/regression_5823/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "regression_5823" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/compile_success_empty/regression_5823/src/main.nr b/test_programs/compile_success_empty/regression_5823/src/main.nr new file mode 100644 index 0000000000..f615564fae --- /dev/null +++ b/test_programs/compile_success_empty/regression_5823/src/main.nr @@ -0,0 +1,5 @@ +fn main() { + let x = 1 as u64; + let y = 2 as u8; + assert_eq(x << y, 4); +} diff --git a/test_programs/compile_success_empty/serialize/Nargo.toml b/test_programs/compile_success_empty/serialize/Nargo.toml new file mode 100644 index 0000000000..2cf87765b8 --- /dev/null +++ b/test_programs/compile_success_empty/serialize/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "serialize" +type = "bin" +authors = [""] +compiler_version = ">=0.32.0" + +[dependencies] diff --git a/test_programs/compile_success_empty/serialize/src/main.nr b/test_programs/compile_success_empty/serialize/src/main.nr new file mode 100644 index 0000000000..79114c5b56 --- /dev/null +++ b/test_programs/compile_success_empty/serialize/src/main.nr @@ -0,0 +1,59 @@ +trait Serialize { + let Size: u32; + + // Note that Rust disallows referencing constants here! + fn serialize(self) -> [Field; Self::Size]; +} + +impl Serialize for (A, B) where A: Serialize, B: Serialize { + // let Size = ::Size + ::Size; + let Size = AS + BS; + + fn serialize(self: Self) -> [Field; Self::Size] { + let mut array: [Field; Self::Size] = std::mem::zeroed(); + let a = self.0.serialize(); + let b = self.1.serialize(); + + for i in 0 .. a.len() { + array[i] = a[i]; + } + for i in 0 .. b.len() { + array[i + a.len()] = b[i]; + } + array + } +} + +impl Serialize for [T; N] where T: Serialize { + // let Size = ::Size * N; + let Size = TS * N; + + fn serialize(self: Self) -> [Field; Self::Size] { + let mut array: [Field; Self::Size] = std::mem::zeroed(); + let mut array_i = 0; + + for elem in self { + let elem_fields = elem.serialize(); + + for i in 0 .. elem_fields.len() { + array[array_i] = elem_fields[i]; + array_i += 1; + } + } + + array + } +} + +impl Serialize for Field { + let Size = 1; + + fn serialize(self) -> [Field; Self::Size] { + [self] + } +} + +fn main() { + let x = ((1, [2, 3, 4]), [5, 6, 7, 8]); + assert_eq(x.serialize().len(), 8); +} diff --git a/test_programs/execution_failure/bigint_from_too_many_le_bytes/Nargo.toml b/test_programs/execution_failure/bigint_from_too_many_le_bytes/Nargo.toml new file mode 100644 index 0000000000..cbdfc2d83d --- /dev/null +++ b/test_programs/execution_failure/bigint_from_too_many_le_bytes/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "bigint_from_too_many_le_bytes" +type = "bin" +authors = [""] +compiler_version = ">=0.31.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/execution_failure/bigint_from_too_many_le_bytes/src/main.nr b/test_programs/execution_failure/bigint_from_too_many_le_bytes/src/main.nr new file mode 100644 index 0000000000..2d4587ee3d --- /dev/null +++ b/test_programs/execution_failure/bigint_from_too_many_le_bytes/src/main.nr @@ -0,0 +1,22 @@ +use std::bigint::{bn254_fq, BigInt}; + +// TODO(https://github.com/noir-lang/noir/issues/5580): decide whether this is desired behavior +// +// Fails at execution time: +// +// error: Assertion failed: 'Index out of bounds' +// ┌─ std/cmp.nr:35:34 +// │ +// 35 │ result &= self[i].eq(other[i]); +// │ -------- +// │ +// = Call stack: +// 1. /Users/michaelklein/Coding/rust/noir/test_programs/compile_failure/bigint_from_too_many_le_bytes/src/main.nr:7:12 +// 2. std/cmp.nr:35:34 +// Failed assertion +fn main() { + let bytes: [u8] = bn254_fq.push_front(0x00); + let bigint = BigInt::from_le_bytes(bytes, bn254_fq); + let result_bytes = bigint.to_le_bytes(); + assert(bytes == result_bytes.as_slice()); +} diff --git a/test_programs/execution_success/nested_dyn_array_regression_5782/Nargo.toml b/test_programs/execution_success/nested_dyn_array_regression_5782/Nargo.toml new file mode 100644 index 0000000000..b5cdd19e18 --- /dev/null +++ b/test_programs/execution_success/nested_dyn_array_regression_5782/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "nested_dyn_array_regression_5782" +type = "bin" +authors = [""] +compiler_version = ">=0.33.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/execution_success/nested_dyn_array_regression_5782/Prover.toml b/test_programs/execution_success/nested_dyn_array_regression_5782/Prover.toml new file mode 100644 index 0000000000..de2960def0 --- /dev/null +++ b/test_programs/execution_success/nested_dyn_array_regression_5782/Prover.toml @@ -0,0 +1,2 @@ +array = [5, 10] +i = 1 diff --git a/test_programs/execution_success/nested_dyn_array_regression_5782/src/main.nr b/test_programs/execution_success/nested_dyn_array_regression_5782/src/main.nr new file mode 100644 index 0000000000..b6a1238a9d --- /dev/null +++ b/test_programs/execution_success/nested_dyn_array_regression_5782/src/main.nr @@ -0,0 +1,13 @@ +fn main(mut array: [Field; 2], i: u32) { + assert_eq(array[i - 1], 5); + assert_eq(array[i], 10); + + array[i] = 2; + + let array2 = [array, array]; + + assert_eq(array2[0][0], 5); + assert_eq(array2[0][i], 2); + assert_eq(array2[i][0], 5); + assert_eq(array2[i][i], 2); +} diff --git a/test_programs/execution_success/sha256/Prover.toml b/test_programs/execution_success/sha256/Prover.toml index c4df1b749b..b4bf416237 100644 --- a/test_programs/execution_success/sha256/Prover.toml +++ b/test_programs/execution_success/sha256/Prover.toml @@ -34,3 +34,5 @@ result = [ 0x73, 0x2b, ] +input = [0, 0] +toggle = false \ No newline at end of file diff --git a/test_programs/execution_success/sha256/src/main.nr b/test_programs/execution_success/sha256/src/main.nr index 4f999d349f..29bc9ac371 100644 --- a/test_programs/execution_success/sha256/src/main.nr +++ b/test_programs/execution_success/sha256/src/main.nr @@ -10,11 +10,16 @@ // Not yet here: For R1CS, it is more about manipulating arithmetic gates to get performance // This can be done in ACIR! -fn main(x: Field, result: [u8; 32]) { +fn main(x: Field, result: [u8; 32], input: [u8; 2], toggle: bool) { // We use the `as` keyword here to denote the fact that we want to take just the first byte from the x Field // The padding is taken care of by the program // docs:start:sha256_var let digest = std::hash::sha256_var([x as u8], 1); // docs:end:sha256_var assert(digest == result); + + // variable size + let size: Field = 1 + toggle as Field; + let var_sha = std::hash::sha256_var(input, size as u64); + assert(var_sha == std::hash::sha256_var(input, 1)); } diff --git a/test_programs/noir_test_success/comptime_expr/Nargo.toml b/test_programs/noir_test_success/comptime_expr/Nargo.toml new file mode 100644 index 0000000000..a40da9c5f2 --- /dev/null +++ b/test_programs/noir_test_success/comptime_expr/Nargo.toml @@ -0,0 +1,7 @@ +[package] +name = "comptime_expr" +type = "bin" +authors = [""] +compiler_version = ">=0.31.0" + +[dependencies] \ No newline at end of file diff --git a/test_programs/noir_test_success/comptime_expr/src/main.nr b/test_programs/noir_test_success/comptime_expr/src/main.nr new file mode 100644 index 0000000000..329e97dc9d --- /dev/null +++ b/test_programs/noir_test_success/comptime_expr/src/main.nr @@ -0,0 +1,313 @@ +mod tests { + use std::meta::op::UnaryOp; + use std::meta::op::BinaryOp; + + #[test] + fn test_expr_as_array() { + comptime + { + let expr = quote { [1, 2, 4] }.as_expr().unwrap(); + let elems = expr.as_array().unwrap(); + assert_eq(elems.len(), 3); + assert_eq(elems[0].as_integer().unwrap(), (1, false)); + assert_eq(elems[1].as_integer().unwrap(), (2, false)); + assert_eq(elems[2].as_integer().unwrap(), (4, false)); + } + } + + #[test] + fn test_expr_as_assign() { + comptime + { + let expr = quote { { a = 1; } }.as_expr().unwrap(); + let exprs = expr.as_block().unwrap(); + let (_lhs, rhs) = exprs[0].as_assign().unwrap(); + assert_eq(rhs.as_integer().unwrap(), (1, false)); + } + } + + #[test] + fn test_expr_as_block() { + comptime + { + let expr = quote { { 1; 4; 23 } }.as_expr().unwrap(); + let exprs = expr.as_block().unwrap(); + assert_eq(exprs.len(), 3); + assert_eq(exprs[0].as_integer().unwrap(), (1, false)); + assert_eq(exprs[1].as_integer().unwrap(), (4, false)); + assert_eq(exprs[2].as_integer().unwrap(), (23, false)); + + assert(exprs[0].has_semicolon()); + assert(exprs[1].has_semicolon()); + assert(!exprs[2].has_semicolon()); + } + } + + #[test] + fn test_expr_as_method_call() { + comptime + { + let expr = quote { foo.bar::(3, 4) }.as_expr().unwrap(); + let (_object, name, generics, arguments) = expr.as_method_call().unwrap(); + + assert_eq(name, quote { bar }); + + assert_eq(generics.len(), 1); + assert(generics[0].is_field()); + + assert_eq(arguments.len(), 2); + assert_eq(arguments[0].as_integer().unwrap(), (3, false)); + assert_eq(arguments[1].as_integer().unwrap(), (4, false)); + } + } + + #[test] + fn test_expr_as_integer() { + comptime + { + let expr = quote { 1 }.as_expr().unwrap(); + assert_eq((1, false), expr.as_integer().unwrap()); + + let expr = quote { -2 }.as_expr().unwrap(); + assert_eq((2, true), expr.as_integer().unwrap()); + } + } + + #[test] + fn test_expr_as_binary_op() { + comptime + { + assert(get_binary_op(quote { x + y }).is_add()); + assert(get_binary_op(quote { x - y }).is_subtract()); + assert(get_binary_op(quote { x * y }).is_multiply()); + assert(get_binary_op(quote { x / y }).is_divide()); + assert(get_binary_op(quote { x == y }).is_equal()); + assert(get_binary_op(quote { x != y }).is_not_equal()); + assert(get_binary_op(quote { x < y }).is_less_than()); + assert(get_binary_op(quote { x <= y }).is_less_than_or_equal()); + assert(get_binary_op(quote { x > y }).is_greater_than()); + assert(get_binary_op(quote { x >= y }).is_greater_than_or_equal()); + assert(get_binary_op(quote { x & y }).is_and()); + assert(get_binary_op(quote { x | y }).is_or()); + assert(get_binary_op(quote { x ^ y }).is_xor()); + assert(get_binary_op(quote { x >> y }).is_shift_right()); + assert(get_binary_op(quote { x << y }).is_shift_left()); + assert(get_binary_op(quote { x % y }).is_modulo()); + } + } + + #[test] + fn test_expr_as_bool() { + comptime + { + let expr = quote { false }.as_expr().unwrap(); + assert(expr.as_bool().unwrap() == false); + + let expr = quote { true }.as_expr().unwrap(); + assert_eq(expr.as_bool().unwrap(), true); + } + } + + #[test] + fn test_expr_as_cast() { + comptime + { + let expr = quote { 1 as Field }.as_expr().unwrap(); + let (expr, typ) = expr.as_cast().unwrap(); + assert_eq(expr.as_integer().unwrap(), (1, false)); + assert(typ.is_field()); + } + } + + #[test] + fn test_expr_as_comptime() { + comptime + { + let expr = quote { comptime { 1; 4; 23 } }.as_expr().unwrap(); + let exprs = expr.as_comptime().unwrap(); + assert_eq(exprs.len(), 3); + } + } + + #[test] + fn test_expr_as_comptime_as_statement() { + comptime + { + let expr = quote { { comptime { 1; 4; 23 } } }.as_expr().unwrap(); + let exprs = expr.as_block().unwrap(); + assert_eq(exprs.len(), 1); + + let exprs = exprs[0].as_comptime().unwrap(); + assert_eq(exprs.len(), 3); + } + } + + // This test can't only be around the comptime block since that will cause + // `nargo fmt` to remove the comptime keyword. + // docs:start:as_expr_example + #[test] + fn test_expr_as_function_call() { + comptime + { + let expr = quote { foo(42) }.as_expr().unwrap(); + let (_function, args) = expr.as_function_call().unwrap(); + assert_eq(args.len(), 1); + assert_eq(args[0].as_integer().unwrap(), (42, false)); + } + } + // docs:end:as_expr_example + + #[test] + fn test_expr_as_if() { + comptime + { + let expr = quote { if 1 { 2 } }.as_expr().unwrap(); + let (_condition, _consequence, alternative) = expr.as_if().unwrap(); + assert(alternative.is_none()); + + let expr = quote { if 1 { 2 } else { 3 } }.as_expr().unwrap(); + let (_condition, _consequence, alternative) = expr.as_if().unwrap(); + assert(alternative.is_some()); + } + } + + #[test] + fn test_expr_as_index() { + comptime + { + let expr = quote { foo[bar] }.as_expr().unwrap(); + assert(expr.as_index().is_some()); + } + } + + #[test] + fn test_expr_as_member_access() { + comptime + { + let expr = quote { foo.bar }.as_expr().unwrap(); + let (_, name) = expr.as_member_access().unwrap(); + assert_eq(name, quote { bar }); + } + } + + #[test] + fn test_expr_as_member_access_with_an_lvalue() { + comptime + { + let expr = quote { { foo.bar = 1; } }.as_expr().unwrap(); + let exprs = expr.as_block().unwrap(); + let (lhs, _rhs) = exprs[0].as_assign().unwrap(); + let (_, name) = lhs.as_member_access().unwrap(); + assert_eq(name, quote { bar }); + } + } + + #[test] + fn test_expr_as_repeated_element_array() { + comptime + { + let expr = quote { [1; 3] }.as_expr().unwrap(); + let (expr, length) = expr.as_repeated_element_array().unwrap(); + assert_eq(expr.as_integer().unwrap(), (1, false)); + assert_eq(length.as_integer().unwrap(), (3, false)); + } + } + + #[test] + fn test_expr_as_repeated_element_slice() { + comptime + { + let expr = quote { &[1; 3] }.as_expr().unwrap(); + let (expr, length) = expr.as_repeated_element_slice().unwrap(); + assert_eq(expr.as_integer().unwrap(), (1, false)); + assert_eq(length.as_integer().unwrap(), (3, false)); + } + } + + #[test] + fn test_expr_as_slice() { + comptime + { + let expr = quote { &[1, 3, 5] }.as_expr().unwrap(); + let elems = expr.as_slice().unwrap(); + assert_eq(elems.len(), 3); + assert_eq(elems[0].as_integer().unwrap(), (1, false)); + assert_eq(elems[1].as_integer().unwrap(), (3, false)); + assert_eq(elems[2].as_integer().unwrap(), (5, false)); + } + } + + #[test] + fn test_expr_as_tuple() { + comptime + { + let expr = quote { (1, 2) }.as_expr().unwrap(); + let tuple_exprs = expr.as_tuple().unwrap(); + assert_eq(tuple_exprs.len(), 2); + } + } + + #[test] + fn test_expr_as_unary_op() { + comptime + { + assert(get_unary_op(quote { -x }).is_minus()); + assert(get_unary_op(quote { !x }).is_not()); + assert(get_unary_op(quote { &mut x }).is_mutable_reference()); + assert(get_unary_op(quote { *x }).is_dereference()); + } + } + + #[test] + fn test_expr_as_unsafe() { + comptime + { + let expr = quote { unsafe { 1; 4; 23 } }.as_expr().unwrap(); + let exprs = expr.as_unsafe().unwrap(); + assert_eq(exprs.len(), 3); + } + } + + #[test] + fn test_expr_is_break() { + comptime + { + let expr = quote { { break; } }.as_expr().unwrap(); + let exprs = expr.as_block().unwrap(); + assert(exprs[0].is_break()); + } + } + + #[test] + fn test_expr_is_continue() { + comptime + { + let expr = quote { { continue; } }.as_expr().unwrap(); + let exprs = expr.as_block().unwrap(); + assert(exprs[0].is_continue()); + } + } + + #[test] + fn test_automatically_unwraps_parenthesized_expression() { + comptime + { + let expr = quote { ((if 1 { 2 })) }.as_expr().unwrap(); + assert(expr.as_if().is_some()); + } + } + + comptime fn get_unary_op(quoted: Quoted) -> UnaryOp { + let expr = quoted.as_expr().unwrap(); + let (op, _) = expr.as_unary_op().unwrap(); + op + } + + comptime fn get_binary_op(quoted: Quoted) -> BinaryOp { + let expr = quoted.as_expr().unwrap(); + let (_, op, _) = expr.as_binary_op().unwrap(); + op + } +} + +fn main() {} diff --git a/test_programs/rebuild.sh b/test_programs/rebuild.sh index 452c91bfc0..1f2e199e81 100755 --- a/test_programs/rebuild.sh +++ b/test_programs/rebuild.sh @@ -80,7 +80,7 @@ fi rm -f "$current_dir/rebuild.log" # Process directories in parallel -parallel -j0 process_dir {} "$current_dir" ::: "${dirs_to_process[@]}" +parallel -j7 process_dir {} "$current_dir" ::: ${dirs_to_process[@]} # Check rebuild.log for failures if [ -f "$current_dir/rebuild.log" ]; then diff --git a/tooling/debugger/src/context.rs b/tooling/debugger/src/context.rs index 890732b579..0d348cf172 100644 --- a/tooling/debugger/src/context.rs +++ b/tooling/debugger/src/context.rs @@ -442,16 +442,15 @@ impl<'a, B: BlackBoxFunctionSolver> DebugContext<'a, B> { self.debug_artifact.debug_symbols[debug_location.circuit_id as usize] .opcode_location(&debug_location.opcode_location) .unwrap_or_else(|| { - if let Some(brillig_function_id) = debug_location.brillig_function_id { + if let (Some(brillig_function_id), Some(brillig_location)) = ( + debug_location.brillig_function_id, + debug_location.opcode_location.to_brillig_location(), + ) { let brillig_locations = self.debug_artifact.debug_symbols [debug_location.circuit_id as usize] .brillig_locations .get(&brillig_function_id); - brillig_locations - .unwrap() - .get(&debug_location.opcode_location) - .cloned() - .unwrap_or_default() + brillig_locations.unwrap().get(&brillig_location).cloned().unwrap_or_default() } else { vec![] } @@ -660,8 +659,9 @@ impl<'a, B: BlackBoxFunctionSolver> DebugContext<'a, B> { fn get_current_acir_index(&self) -> Option { self.get_current_debug_location().map(|debug_location| { match debug_location.opcode_location { - OpcodeLocation::Acir(acir_index) => acir_index, - OpcodeLocation::Brillig { acir_index, .. } => acir_index, + OpcodeLocation::Acir(acir_index) | OpcodeLocation::Brillig { acir_index, .. } => { + acir_index + } } }) } @@ -893,8 +893,19 @@ fn build_source_to_opcode_debug_mappings( ); for (brillig_function_id, brillig_locations_map) in &debug_symbols.brillig_locations { + let brillig_locations_map = brillig_locations_map + .iter() + .map(|(key, val)| { + ( + // TODO: this is a temporary placeholder until the debugger is updated to handle the new brillig debug locations. + OpcodeLocation::Brillig { acir_index: 0, brillig_index: key.0 }, + val.clone(), + ) + }) + .collect(); + add_opcode_locations_map( - brillig_locations_map, + &brillig_locations_map, &mut result, &simple_files, circuit_id, diff --git a/tooling/lsp/src/requests/completion.rs b/tooling/lsp/src/requests/completion.rs index 28388230e9..c61f92795a 100644 --- a/tooling/lsp/src/requests/completion.rs +++ b/tooling/lsp/src/requests/completion.rs @@ -15,16 +15,20 @@ use lsp_types::{CompletionItem, CompletionItemKind, CompletionParams, Completion use noirc_errors::{Location, Span}; use noirc_frontend::{ ast::{ - AsTraitPath, BlockExpression, ConstructorExpression, Expression, ExpressionKind, - ForLoopStatement, Ident, IfExpression, LValue, Lambda, LetStatement, - MemberAccessExpression, NoirFunction, NoirStruct, NoirTraitImpl, Path, PathKind, - PathSegment, Pattern, Statement, StatementKind, TraitItem, TypeImpl, UnresolvedGeneric, - UnresolvedGenerics, UnresolvedType, UnresolvedTypeData, UseTree, UseTreeKind, + AsTraitPath, BlockExpression, CallExpression, ConstructorExpression, Expression, + ExpressionKind, ForLoopStatement, Ident, IfExpression, ItemVisibility, LValue, Lambda, + LetStatement, MemberAccessExpression, MethodCallExpression, NoirFunction, NoirStruct, + NoirTraitImpl, Path, PathKind, PathSegment, Pattern, Statement, StatementKind, TraitItem, + TypeImpl, UnresolvedGeneric, UnresolvedGenerics, UnresolvedType, UnresolvedTypeData, + UseTree, UseTreeKind, }, graph::{CrateId, Dependency}, hir::{ def_map::{CrateDefMap, LocalModuleId, ModuleId}, - resolution::path_resolver::{PathResolver, StandardPathResolver}, + resolution::{ + import::can_reference_module_id, + path_resolver::{PathResolver, StandardPathResolver}, + }, }, hir_def::traits::Trait, macros_api::{ModuleDefId, NodeInterner}, @@ -338,6 +342,71 @@ impl<'a> NodeFinder<'a> { } } + pub(super) fn find_in_call_expression(&mut self, call_expression: &CallExpression) { + // Check if it's this case: + // + // foo::b>|<(...) + // + // In this case we want to suggest items in foo but if they are functions + // we don't want to insert arguments, because they are already there (even if + // they could be wrong) just because inserting them would lead to broken code. + if let ExpressionKind::Variable(path) = &call_expression.func.kind { + if self.includes_span(path.span) { + self.find_in_path_impl(path, RequestedItems::AnyItems, true); + return; + } + } + + // Check if it's this case: + // + // foo.>|<(...) + // + // "foo." is actually broken, but it's parsed as "foo", so this is seen + // as "foo(...)" but if we are at a dot right after "foo" it means it's + // the above case and we want to suggest methods of foo's type. + let after_dot = self.byte == Some(b'.'); + if after_dot && call_expression.func.span.end() as usize == self.byte_index - 1 { + let location = Location::new(call_expression.func.span, self.file); + if let Some(typ) = self.interner.type_at_location(location) { + let typ = typ.follow_bindings(); + let prefix = ""; + self.complete_type_fields_and_methods(&typ, prefix, FunctionCompletionKind::Name); + return; + } + } + + self.find_in_expression(&call_expression.func); + self.find_in_expressions(&call_expression.arguments); + } + + pub(super) fn find_in_method_call_expression( + &mut self, + method_call_expression: &MethodCallExpression, + ) { + // Check if it's this case: + // + // foo.b>|<(...) + // + // In this case we want to suggest items in foo but if they are functions + // we don't want to insert arguments, because they are already there (even if + // they could be wrong) just because inserting them would lead to broken code. + if self.includes_span(method_call_expression.method_name.span()) { + let location = Location::new(method_call_expression.object.span, self.file); + if let Some(typ) = self.interner.type_at_location(location) { + let typ = typ.follow_bindings(); + let prefix = method_call_expression.method_name.to_string(); + let offset = + self.byte_index - method_call_expression.method_name.span().start() as usize; + let prefix = prefix[0..offset].to_string(); + self.complete_type_fields_and_methods(&typ, &prefix, FunctionCompletionKind::Name); + return; + } + } + + self.find_in_expression(&method_call_expression.object); + self.find_in_expressions(&method_call_expression.arguments); + } + fn find_in_block_expression(&mut self, block_expression: &BlockExpression) { let old_local_variables = self.local_variables.clone(); for statement in &block_expression.statements { @@ -418,7 +487,11 @@ impl<'a> NodeFinder<'a> { { let typ = self.interner.definition_type(definition_id); let prefix = ""; - self.complete_type_fields_and_methods(&typ, prefix); + self.complete_type_fields_and_methods( + &typ, + prefix, + FunctionCompletionKind::NameAndParameters, + ); } } } @@ -508,7 +581,11 @@ impl<'a> NodeFinder<'a> { if let Some(typ) = self.interner.type_at_location(location) { let typ = typ.follow_bindings(); let prefix = ""; - self.complete_type_fields_and_methods(&typ, prefix); + self.complete_type_fields_and_methods( + &typ, + prefix, + FunctionCompletionKind::NameAndParameters, + ); } } } @@ -570,7 +647,11 @@ impl<'a> NodeFinder<'a> { if let Some(typ) = self.interner.type_at_location(location) { let typ = typ.follow_bindings(); let prefix = ident.to_string().to_case(Case::Snake); - self.complete_type_fields_and_methods(&typ, &prefix); + self.complete_type_fields_and_methods( + &typ, + &prefix, + FunctionCompletionKind::NameAndParameters, + ); return; } } @@ -628,11 +709,11 @@ impl<'a> NodeFinder<'a> { } UnresolvedTypeData::Named(path, unresolved_types, _) => { self.find_in_path(path, RequestedItems::OnlyTypes); - self.find_in_unresolved_types(unresolved_types); + self.find_in_type_args(unresolved_types); } UnresolvedTypeData::TraitAsType(path, unresolved_types) => { self.find_in_path(path, RequestedItems::OnlyTypes); - self.find_in_unresolved_types(unresolved_types); + self.find_in_type_args(unresolved_types); } UnresolvedTypeData::MutableReference(unresolved_type) => { self.find_in_unresolved_type(unresolved_type); @@ -663,15 +744,61 @@ impl<'a> NodeFinder<'a> { } fn find_in_path(&mut self, path: &Path, requested_items: RequestedItems) { - // Only offer completions if we are right at the end of the path - if self.byte_index != path.span.end() as usize { + self.find_in_path_impl(path, requested_items, false); + } + + fn find_in_path_impl( + &mut self, + path: &Path, + requested_items: RequestedItems, + mut in_the_middle: bool, + ) { + if !self.includes_span(path.span) { return; } let after_colons = self.byte == Some(b':'); - let mut idents: Vec = - path.segments.iter().map(|segment| segment.ident.clone()).collect(); + let mut idents: Vec = Vec::new(); + + // Find in which ident we are in, and in which part of it + // (it could be that we are completting in the middle of an ident) + for segment in &path.segments { + let ident = &segment.ident; + + // Check if we are at the end of the ident + if self.byte_index == ident.span().end() as usize { + idents.push(ident.clone()); + break; + } + + // Check if we are in the middle of an ident + if self.includes_span(ident.span()) { + // If so, take the substring and push that as the list of idents + // we'll do autocompletion for + let offset = self.byte_index - ident.span().start() as usize; + let substring = ident.0.contents[0..offset].to_string(); + let ident = Ident::new( + substring, + Span::from(ident.span().start()..ident.span().start() + offset as u32), + ); + idents.push(ident); + in_the_middle = true; + break; + } + + idents.push(ident.clone()); + + // Stop if the cursor is right after this ident and '::' + if after_colons && self.byte_index == ident.span().end() as usize + 2 { + break; + } + } + + if idents.len() < path.segments.len() { + in_the_middle = true; + } + let prefix; let at_root; @@ -688,6 +815,21 @@ impl<'a> NodeFinder<'a> { let is_single_segment = !after_colons && idents.is_empty() && path.kind == PathKind::Plain; let module_id; + let module_completion_kind = if after_colons || !idents.is_empty() { + ModuleCompletionKind::DirectChildren + } else { + ModuleCompletionKind::AllVisibleItems + }; + + // When completing in the middle of an ident, we don't want to complete + // with function parameters because there might already be function parameters, + // and in the middle of a path it leads to code that won't compile + let function_completion_kind = if in_the_middle { + FunctionCompletionKind::Name + } else { + FunctionCompletionKind::NameAndParameters + }; + if idents.is_empty() { module_id = self.module_id; } else { @@ -703,6 +845,7 @@ impl<'a> NodeFinder<'a> { &Type::Struct(struct_type, vec![]), &prefix, FunctionKind::Any, + function_completion_kind, ); return; } @@ -713,25 +856,28 @@ impl<'a> NodeFinder<'a> { ModuleDefId::TypeAliasId(type_alias_id) => { let type_alias = self.interner.get_type_alias(type_alias_id); let type_alias = type_alias.borrow(); - self.complete_type_methods(&type_alias.typ, &prefix, FunctionKind::Any); + self.complete_type_methods( + &type_alias.typ, + &prefix, + FunctionKind::Any, + function_completion_kind, + ); return; } ModuleDefId::TraitId(trait_id) => { let trait_ = self.interner.get_trait(trait_id); - self.complete_trait_methods(trait_, &prefix, FunctionKind::Any); + self.complete_trait_methods( + trait_, + &prefix, + FunctionKind::Any, + function_completion_kind, + ); return; } ModuleDefId::GlobalId(_) => return, } } - let module_completion_kind = if after_colons { - ModuleCompletionKind::DirectChildren - } else { - ModuleCompletionKind::AllVisibleItems - }; - let function_completion_kind = FunctionCompletionKind::NameAndParameters; - self.complete_in_module( module_id, &prefix, @@ -746,7 +892,7 @@ impl<'a> NodeFinder<'a> { match requested_items { RequestedItems::AnyItems => { self.local_variables_completion(&prefix); - self.builtin_functions_completion(&prefix); + self.builtin_functions_completion(&prefix, function_completion_kind); self.builtin_values_completion(&prefix); } RequestedItems::OnlyTypes => { @@ -754,7 +900,7 @@ impl<'a> NodeFinder<'a> { self.type_parameters_completion(&prefix); } } - self.complete_auto_imports(&prefix, requested_items); + self.complete_auto_imports(&prefix, requested_items, function_completion_kind); } } @@ -925,17 +1071,30 @@ impl<'a> NodeFinder<'a> { }; } - fn complete_type_fields_and_methods(&mut self, typ: &Type, prefix: &str) { + fn complete_type_fields_and_methods( + &mut self, + typ: &Type, + prefix: &str, + function_completion_kind: FunctionCompletionKind, + ) { match typ { Type::Struct(struct_type, generics) => { self.complete_struct_fields(&struct_type.borrow(), generics, prefix); } Type::MutableReference(typ) => { - return self.complete_type_fields_and_methods(typ, prefix); + return self.complete_type_fields_and_methods( + typ, + prefix, + function_completion_kind, + ); } Type::Alias(type_alias, _) => { let type_alias = type_alias.borrow(); - return self.complete_type_fields_and_methods(&type_alias.typ, prefix); + return self.complete_type_fields_and_methods( + &type_alias.typ, + prefix, + function_completion_kind, + ); } Type::Tuple(types) => { self.complete_tuple_fields(types); @@ -959,10 +1118,21 @@ impl<'a> NodeFinder<'a> { | Type::Error => (), } - self.complete_type_methods(typ, prefix, FunctionKind::SelfType(typ)); + self.complete_type_methods( + typ, + prefix, + FunctionKind::SelfType(typ), + function_completion_kind, + ); } - fn complete_type_methods(&mut self, typ: &Type, prefix: &str, function_kind: FunctionKind) { + fn complete_type_methods( + &mut self, + typ: &Type, + prefix: &str, + function_kind: FunctionKind, + function_completion_kind: FunctionCompletionKind, + ) { let Some(methods_by_name) = self.interner.get_type_methods(typ) else { return; }; @@ -972,7 +1142,7 @@ impl<'a> NodeFinder<'a> { if name_matches(name, prefix) { if let Some(completion_item) = self.function_completion_item( func_id, - FunctionCompletionKind::NameAndParameters, + function_completion_kind, function_kind, ) { self.completion_items.push(completion_item); @@ -988,14 +1158,13 @@ impl<'a> NodeFinder<'a> { trait_: &Trait, prefix: &str, function_kind: FunctionKind, + function_completion_kind: FunctionCompletionKind, ) { for (name, func_id) in &trait_.method_ids { if name_matches(name, prefix) { - if let Some(completion_item) = self.function_completion_item( - *func_id, - FunctionCompletionKind::NameAndParameters, - function_kind, - ) { + if let Some(completion_item) = + self.function_completion_item(*func_id, function_completion_kind, function_kind) + { self.completion_items.push(completion_item); self.suggested_module_def_ids.insert(ModuleDefId::FunctionId(*func_id)); } @@ -1072,29 +1241,33 @@ impl<'a> NodeFinder<'a> { if name_matches(name, prefix) { let per_ns = module_data.find_name(ident); - if let Some((module_def_id, _, _)) = per_ns.types { - if let Some(completion_item) = self.module_def_id_completion_item( - module_def_id, - name.clone(), - function_completion_kind, - function_kind, - requested_items, - ) { - self.completion_items.push(completion_item); - self.suggested_module_def_ids.insert(module_def_id); + if let Some((module_def_id, visibility, _)) = per_ns.types { + if is_visible(module_id, self.module_id, visibility, self.def_maps) { + if let Some(completion_item) = self.module_def_id_completion_item( + module_def_id, + name.clone(), + function_completion_kind, + function_kind, + requested_items, + ) { + self.completion_items.push(completion_item); + self.suggested_module_def_ids.insert(module_def_id); + } } } - if let Some((module_def_id, _, _)) = per_ns.values { - if let Some(completion_item) = self.module_def_id_completion_item( - module_def_id, - name.clone(), - function_completion_kind, - function_kind, - requested_items, - ) { - self.completion_items.push(completion_item); - self.suggested_module_def_ids.insert(module_def_id); + if let Some((module_def_id, visibility, _)) = per_ns.values { + if is_visible(module_id, self.module_id, visibility, self.def_maps) { + if let Some(completion_item) = self.module_def_id_completion_item( + module_def_id, + name.clone(), + function_completion_kind, + function_kind, + requested_items, + ) { + self.completion_items.push(completion_item); + self.suggested_module_def_ids.insert(module_def_id); + } } } } @@ -1218,6 +1391,21 @@ fn module_def_id_from_reference_id(reference_id: ReferenceId) -> Option, +) -> bool { + can_reference_module_id( + def_maps, + current_module_id.krate, + current_module_id.local_id, + target_module_id, + visibility, + ) +} + #[cfg(test)] mod completion_name_matches_tests { use crate::requests::completion::name_matches; diff --git a/tooling/lsp/src/requests/completion/auto_import.rs b/tooling/lsp/src/requests/completion/auto_import.rs index 8d7824502c..7c56a0758c 100644 --- a/tooling/lsp/src/requests/completion/auto_import.rs +++ b/tooling/lsp/src/requests/completion/auto_import.rs @@ -17,7 +17,12 @@ use super::{ }; impl<'a> NodeFinder<'a> { - pub(super) fn complete_auto_imports(&mut self, prefix: &str, requested_items: RequestedItems) { + pub(super) fn complete_auto_imports( + &mut self, + prefix: &str, + requested_items: RequestedItems, + function_completion_kind: FunctionCompletionKind, + ) { let current_module_parent_id = get_parent_module_id(self.def_maps, self.module_id); for (name, entries) in self.interner.get_auto_import_names() { @@ -33,7 +38,7 @@ impl<'a> NodeFinder<'a> { let Some(mut completion_item) = self.module_def_id_completion_item( *module_def_id, name.clone(), - FunctionCompletionKind::NameAndParameters, + function_completion_kind, FunctionKind::Any, requested_items, ) else { diff --git a/tooling/lsp/src/requests/completion/builtins.rs b/tooling/lsp/src/requests/completion/builtins.rs index 75eba7fb3c..b9c4ce2358 100644 --- a/tooling/lsp/src/requests/completion/builtins.rs +++ b/tooling/lsp/src/requests/completion/builtins.rs @@ -4,19 +4,38 @@ use strum::IntoEnumIterator; use super::{ completion_items::{simple_completion_item, snippet_completion_item}, + kinds::FunctionCompletionKind, name_matches, NodeFinder, }; impl<'a> NodeFinder<'a> { - pub(super) fn builtin_functions_completion(&mut self, prefix: &str) { + pub(super) fn builtin_functions_completion( + &mut self, + prefix: &str, + function_completion_kind: FunctionCompletionKind, + ) { for keyword in Keyword::iter() { if let Some(func) = keyword_builtin_function(&keyword) { if name_matches(func.name, prefix) { + let description = Some(func.description.to_string()); + let label; + let insert_text; + match function_completion_kind { + FunctionCompletionKind::Name => { + label = func.name.to_string(); + insert_text = func.name.to_string(); + } + FunctionCompletionKind::NameAndParameters => { + label = format!("{}(…)", func.name); + insert_text = format!("{}({})", func.name, func.parameters); + } + } + self.completion_items.push(snippet_completion_item( - format!("{}(…)", func.name), + label, CompletionItemKind::FUNCTION, - format!("{}({})", func.name, func.parameters), - Some(func.description.to_string()), + insert_text, + description, )); } } @@ -76,6 +95,7 @@ pub(super) fn keyword_builtin_type(keyword: &Keyword) -> Option<&'static str> { Keyword::TraitDefinition => Some("TraitDefinition"), Keyword::TraitImpl => Some("TraitImpl"), Keyword::TypeType => Some("Type"), + Keyword::UnresolvedType => Some("UnresolvedType"), Keyword::As | Keyword::Assert @@ -183,6 +203,7 @@ pub(super) fn keyword_builtin_function(keyword: &Keyword) -> Option u64 { 0 } + pub fn bar(x: i32) -> u64 { 0 } + fn bar_is_private(x: i32) -> u64 { 0 } } use foo::>|< "#; @@ -1638,4 +1639,156 @@ mod completion_tests { }) ); } + + #[test] + async fn test_auto_import_from_std() { + let src = r#" + fn main() { + compute_merkle_roo>|< + } + "#; + let items = get_completions(src).await; + assert_eq!(items.len(), 1); + + let item = &items[0]; + assert_eq!(item.label, "compute_merkle_root(…)"); + assert_eq!( + item.label_details.as_ref().unwrap().detail, + Some("(use std::merkle::compute_merkle_root)".to_string()), + ); + } + + #[test] + async fn test_completes_after_first_letter_of_path() { + let src = r#" + fn main() { + h>||||||<() + } + "#; + assert_completion_excluding_auto_import( + src, + vec![simple_completion_item( + "bar", + CompletionItemKind::FUNCTION, + Some("fn(self)".to_string()), + )], + ) + .await; + } } diff --git a/tooling/lsp/src/requests/completion/traversal.rs b/tooling/lsp/src/requests/completion/traversal.rs index b487c8baf3..e3bd8ffadf 100644 --- a/tooling/lsp/src/requests/completion/traversal.rs +++ b/tooling/lsp/src/requests/completion/traversal.rs @@ -2,9 +2,9 @@ /// traversing the AST without any additional logic. use noirc_frontend::{ ast::{ - ArrayLiteral, AssignStatement, CallExpression, CastExpression, ConstrainStatement, - Expression, ForRange, FunctionReturnType, IndexExpression, InfixExpression, Literal, - MethodCallExpression, NoirTrait, NoirTypeAlias, TraitImplItem, UnresolvedType, + ArrayLiteral, AssignStatement, CastExpression, ConstrainStatement, Expression, ForRange, + FunctionReturnType, GenericTypeArgs, IndexExpression, InfixExpression, Literal, NoirTrait, + NoirTypeAlias, TraitImplItem, UnresolvedType, }, ParsedModule, }; @@ -89,19 +89,6 @@ impl<'a> NodeFinder<'a> { self.find_in_expression(&index_expression.index); } - pub(super) fn find_in_call_expression(&mut self, call_expression: &CallExpression) { - self.find_in_expression(&call_expression.func); - self.find_in_expressions(&call_expression.arguments); - } - - pub(super) fn find_in_method_call_expression( - &mut self, - method_call_expression: &MethodCallExpression, - ) { - self.find_in_expression(&method_call_expression.object); - self.find_in_expressions(&method_call_expression.arguments); - } - pub(super) fn find_in_cast_expression(&mut self, cast_expression: &CastExpression) { self.find_in_expression(&cast_expression.lhs); } @@ -117,6 +104,13 @@ impl<'a> NodeFinder<'a> { } } + pub(super) fn find_in_type_args(&mut self, generics: &GenericTypeArgs) { + self.find_in_unresolved_types(&generics.ordered_args); + for (_name, typ) in &generics.named_args { + self.find_in_unresolved_type(typ); + } + } + pub(super) fn find_in_function_return_type(&mut self, return_type: &FunctionReturnType) { match return_type { noirc_frontend::ast::FunctionReturnType::Default(_) => (), diff --git a/tooling/lsp/src/requests/document_symbol.rs b/tooling/lsp/src/requests/document_symbol.rs index 5d2635b354..bda246f7c9 100644 --- a/tooling/lsp/src/requests/document_symbol.rs +++ b/tooling/lsp/src/requests/document_symbol.rs @@ -359,12 +359,22 @@ impl<'a> DocumentSymbolCollector<'a> { trait_name.push_str(&noir_trait_impl.trait_name.to_string()); if !noir_trait_impl.trait_generics.is_empty() { trait_name.push('<'); - for (index, generic) in noir_trait_impl.trait_generics.iter().enumerate() { + for (index, generic) in noir_trait_impl.trait_generics.ordered_args.iter().enumerate() { if index > 0 { trait_name.push_str(", "); } trait_name.push_str(&generic.to_string()); } + for (index, (name, generic)) in + noir_trait_impl.trait_generics.named_args.iter().enumerate() + { + if index > 0 { + trait_name.push_str(", "); + } + trait_name.push_str(&name.0.contents); + trait_name.push_str(" = "); + trait_name.push_str(&generic.to_string()); + } trait_name.push('>'); } diff --git a/tooling/lsp/src/requests/hover.rs b/tooling/lsp/src/requests/hover.rs index 11a296ed4f..bde1674efb 100644 --- a/tooling/lsp/src/requests/hover.rs +++ b/tooling/lsp/src/requests/hover.rs @@ -425,9 +425,12 @@ impl<'a> TypeLinksGatherer<'a> { Type::TraitAsType(trait_id, _, generics) => { let some_trait = self.interner.get_trait(*trait_id); self.gather_trait_links(some_trait); - for generic in generics { + for generic in &generics.ordered { self.gather_type_links(generic); } + for named_type in &generics.named { + self.gather_type_links(&named_type.typ); + } } Type::NamedGeneric(var, _, _) => { self.gather_type_variable_links(var); diff --git a/tooling/lsp/src/requests/signature_help.rs b/tooling/lsp/src/requests/signature_help.rs index c2c6918554..8aa74fe990 100644 --- a/tooling/lsp/src/requests/signature_help.rs +++ b/tooling/lsp/src/requests/signature_help.rs @@ -146,6 +146,7 @@ impl<'a> SignatureFinder<'a> { // Otherwise, the call must be a reference to an fn type if let Some(mut typ) = self.interner.type_at_location(location) { + typ = typ.follow_bindings(); if let Type::Forall(_, forall_typ) = typ { typ = *forall_typ; } diff --git a/tooling/nargo/src/errors.rs b/tooling/nargo/src/errors.rs index f9668653d0..b5571ff775 100644 --- a/tooling/nargo/src/errors.rs +++ b/tooling/nargo/src/errors.rs @@ -158,13 +158,16 @@ fn extract_locations_from_error( debug[resolved_location.acir_function_index] .opcode_location(&resolved_location.opcode_location) .unwrap_or_else(|| { - if let Some(brillig_function_id) = brillig_function_id { + if let (Some(brillig_function_id), Some(brillig_location)) = ( + brillig_function_id, + &resolved_location.opcode_location.to_brillig_location(), + ) { let brillig_locations = debug[resolved_location.acir_function_index] .brillig_locations .get(&brillig_function_id); brillig_locations .unwrap() - .get(&resolved_location.opcode_location) + .get(brillig_location) .cloned() .unwrap_or_default() } else { diff --git a/tooling/profiler/src/cli/gates_flamegraph_cmd.rs b/tooling/profiler/src/cli/gates_flamegraph_cmd.rs index 0fa12239d0..d5fefc4ecd 100644 --- a/tooling/profiler/src/cli/gates_flamegraph_cmd.rs +++ b/tooling/profiler/src/cli/gates_flamegraph_cmd.rs @@ -22,7 +22,7 @@ pub(crate) struct GatesFlamegraphCommand { backend_path: String, /// Command to get a gates report from the backend. Defaults to "gates" - #[clap(long, short, default_value = "gates")] + #[clap(long, short = 'g', default_value = "gates")] backend_gates_command: String, #[arg(trailing_var_arg = true, allow_hyphen_values = true)] @@ -87,6 +87,7 @@ fn run_with_provider( opcode: AcirOrBrilligOpcode::Acir(opcode), call_stack: vec![OpcodeLocation::Acir(index)], count: gates, + brillig_function_id: None, }) .collect(); diff --git a/tooling/profiler/src/cli/opcodes_flamegraph_cmd.rs b/tooling/profiler/src/cli/opcodes_flamegraph_cmd.rs index d7f3cbb9b8..863d45b96d 100644 --- a/tooling/profiler/src/cli/opcodes_flamegraph_cmd.rs +++ b/tooling/profiler/src/cli/opcodes_flamegraph_cmd.rs @@ -1,5 +1,6 @@ use std::path::{Path, PathBuf}; +use acir::circuit::brillig::BrilligFunctionId; use acir::circuit::{Circuit, Opcode, OpcodeLocation}; use clap::Args; use color_eyre::eyre::{self, Context}; @@ -20,7 +21,7 @@ pub(crate) struct OpcodesFlamegraphCommand { #[clap(long, short)] output: String, - /// Wether to skip brillig functions + /// Whether to skip brillig functions #[clap(long, short, action)] skip_brillig: bool, } @@ -62,6 +63,7 @@ fn run_with_generator( opcode: AcirOrBrilligOpcode::Acir(opcode.clone()), call_stack: vec![OpcodeLocation::Acir(index)], count: 1, + brillig_function_id: None, }) .collect(); @@ -101,6 +103,7 @@ fn run_with_generator( brillig_index, }], count: 1, + brillig_function_id: Some(BrilligFunctionId(brillig_fn_index as u32)), }) .collect(); diff --git a/tooling/profiler/src/flamegraph.rs b/tooling/profiler/src/flamegraph.rs index da76f9b993..488079de50 100644 --- a/tooling/profiler/src/flamegraph.rs +++ b/tooling/profiler/src/flamegraph.rs @@ -1,6 +1,7 @@ use std::path::Path; use std::{collections::BTreeMap, io::BufWriter}; +use acir::circuit::brillig::BrilligFunctionId; use acir::circuit::OpcodeLocation; use acir::AcirField; use color_eyre::eyre::{self}; @@ -19,6 +20,7 @@ pub(crate) struct Sample { pub(crate) opcode: AcirOrBrilligOpcode, pub(crate) call_stack: Vec, pub(crate) count: usize, + pub(crate) brillig_function_id: Option, } #[derive(Debug, Default)] @@ -90,9 +92,24 @@ fn generate_folded_sorted_lines<'files, F: AcirField>( let mut location_names: Vec = sample .call_stack .into_iter() - .flat_map(|opcode_location| debug_symbols.locations.get(&opcode_location)) - .flatten() - .map(|location| location_to_callsite_label(*location, files)) + .flat_map(|opcode_location| { + debug_symbols.opcode_location(&opcode_location).unwrap_or_else(|| { + if let (Some(brillig_function_id), Some(brillig_location)) = + (sample.brillig_function_id, opcode_location.to_brillig_location()) + { + let brillig_locations = + debug_symbols.brillig_locations.get(&brillig_function_id); + if let Some(brillig_locations) = brillig_locations { + brillig_locations.get(&brillig_location).cloned().unwrap_or_default() + } else { + vec![] + } + } else { + vec![] + } + }) + }) + .map(|location| location_to_callsite_label(location, files)) .collect(); if location_names.is_empty() { @@ -286,11 +303,13 @@ mod tests { opcode: AcirOrBrilligOpcode::Acir(AcirOpcode::AssertZero(Expression::default())), call_stack: vec![OpcodeLocation::Acir(0)], count: 10, + brillig_function_id: None, }, Sample { opcode: AcirOrBrilligOpcode::Acir(AcirOpcode::AssertZero(Expression::default())), call_stack: vec![OpcodeLocation::Acir(1)], count: 20, + brillig_function_id: None, }, Sample { opcode: AcirOrBrilligOpcode::Acir(AcirOpcode::MemoryInit { @@ -300,6 +319,7 @@ mod tests { }), call_stack: vec![OpcodeLocation::Acir(2)], count: 30, + brillig_function_id: None, }, ];