diff --git a/Cargo.lock b/Cargo.lock index c57953ba6547..e0b118563f49 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,21 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "always-assert" version = "0.2.0" @@ -71,7 +86,7 @@ dependencies = [ "la-arena 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "lz4_flex", "rustc-hash 2.0.0", - "salsa", + "salsa 0.0.0", "semver", "span", "stdx", @@ -108,6 +123,15 @@ dependencies = [ "cfg_aliases 0.2.1", ] +[[package]] +name = "boxcar" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225450ee9328e1e828319b48a89726cffc1b0ad26fd9211ad435de9fa376acae" +dependencies = [ + "loom", +] + [[package]] name = "byteorder" version = "1.5.0" @@ -289,6 +313,15 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.20" @@ -312,7 +345,21 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" dependencies = [ "cfg-if", - "hashbrown", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", "lock_api", "once_cell", "parking_lot_core", @@ -449,6 +496,12 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "foldhash" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -473,6 +526,19 @@ version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a" +[[package]] +name = "generator" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc6bd114ceda131d3b1d665eba35788690ad37f5916457286b32ab6fd3c438dd" +dependencies = [ + "cfg-if", + "libc", + "log", + "rustversion", + "windows 0.58.0", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -496,11 +562,31 @@ version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +[[package]] +name = "hashbrown" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", +] + +[[package]] +name = "hashlink" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" +dependencies = [ + "hashbrown 0.15.2", +] + [[package]] name = "heck" -version = "0.4.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" [[package]] name = "hermit-abi" @@ -545,12 +631,12 @@ dependencies = [ "bitflags 2.7.0", "cfg", "cov-mark", - "dashmap", + "dashmap 5.5.3", "drop_bomb", "either", "expect-test", "fst", - "hashbrown", + "hashbrown 0.14.5", "hir-expand", "indexmap", "intern", @@ -584,7 +670,7 @@ dependencies = [ "cov-mark", "either", "expect-test", - "hashbrown", + "hashbrown 0.14.5", "intern", "itertools", "la-arena 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", @@ -811,7 +897,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de3fc2e30ba82dd1b3911c8de1ffc143c74a914a14e99514d7637e3099df5ea0" dependencies = [ "equivalent", - "hashbrown", + "hashbrown 0.14.5", ] [[package]] @@ -838,8 +924,8 @@ dependencies = [ name = "intern" version = "0.0.0" dependencies = [ - "dashmap", - "hashbrown", + "dashmap 5.5.3", + "hashbrown 0.14.5", "rustc-hash 2.0.0", "triomphe", ] @@ -999,6 +1085,19 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "lsp-server" version = "0.7.7" @@ -1043,6 +1142,15 @@ version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75761162ae2b0e580d7e7c390558127e5f01b4194debd6221fd8c207fc80e3f5" +[[package]] +name = "matchers" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +dependencies = [ + "regex-automata 0.1.10", +] + [[package]] name = "mbe" version = "0.0.0" @@ -1170,6 +1278,16 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e0826a989adedc2a244799e823aece04662b66609d96af8dff7ac6df9a8925d" +[[package]] +name = "nu-ansi-term" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +dependencies = [ + "overload", + "winapi", +] + [[package]] name = "nu-ansi-term" version = "0.50.1" @@ -1240,6 +1358,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "overload" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" + [[package]] name = "parking_lot" version = "0.12.3" @@ -1414,7 +1538,7 @@ dependencies = [ "indexmap", "nix", "tracing", - "windows", + "windows 0.56.0", ] [[package]] @@ -1493,6 +1617,18 @@ dependencies = [ "pulldown-cmark", ] +[[package]] +name = "query-group-macro" +version = "0.0.0" +dependencies = [ + "expect-test", + "heck", + "proc-macro2", + "quote", + "salsa 0.18.0", + "syn", +] + [[package]] name = "quote" version = "1.0.36" @@ -1648,6 +1784,50 @@ dependencies = [ "thiserror", ] +[[package]] +name = "regex" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata 0.4.9", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-automata" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" +dependencies = [ + "regex-syntax 0.6.29", +] + +[[package]] +name = "regex-automata" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.8.5", +] + +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + +[[package]] +name = "regex-syntax" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" + [[package]] name = "rowan" version = "0.15.15" @@ -1655,7 +1835,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a58fa8a7ccff2aec4f39cc45bf5f985cec7125ab271cf681c279fd00192b49" dependencies = [ "countme", - "hashbrown", + "hashbrown 0.14.5", "memoffset", "rustc-hash 1.1.0", "text-size", @@ -1760,6 +1940,12 @@ dependencies = [ "smallvec", ] +[[package]] +name = "rustversion" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" + [[package]] name = "ryu" version = "1.0.18" @@ -1780,12 +1966,37 @@ dependencies = [ "parking_lot", "rand", "rustc-hash 2.0.0", - "salsa-macros", + "salsa-macros 0.0.0", "smallvec", "tracing", "triomphe", ] +[[package]] +name = "salsa" +version = "0.18.0" +source = "git+https://github.com/salsa-rs/salsa.git#9d2a9786c45000f5fa396ad2872391e302a2836a" +dependencies = [ + "boxcar", + "crossbeam-queue", + "dashmap 6.1.0", + "hashbrown 0.15.2", + "hashlink", + "indexmap", + "parking_lot", + "rayon", + "rustc-hash 2.0.0", + "salsa-macro-rules", + "salsa-macros 0.18.0", + "smallvec", + "tracing", +] + +[[package]] +name = "salsa-macro-rules" +version = "0.18.0" +source = "git+https://github.com/salsa-rs/salsa.git#9d2a9786c45000f5fa396ad2872391e302a2836a" + [[package]] name = "salsa-macros" version = "0.0.0" @@ -1796,6 +2007,18 @@ dependencies = [ "syn", ] +[[package]] +name = "salsa-macros" +version = "0.18.0" +source = "git+https://github.com/salsa-rs/salsa.git#9d2a9786c45000f5fa396ad2872391e302a2836a" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", + "synstructure", +] + [[package]] name = "same-file" version = "1.0.6" @@ -1923,10 +2146,10 @@ dependencies = [ name = "span" version = "0.0.0" dependencies = [ - "hashbrown", + "hashbrown 0.14.5", "la-arena 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)", "rustc-hash 2.0.0", - "salsa", + "salsa 0.0.0", "stdx", "syntax", "text-size", @@ -2246,9 +2469,15 @@ version = "0.3.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" dependencies = [ + "matchers", + "nu-ansi-term 0.46.0", + "once_cell", + "regex", "sharded-slab", + "smallvec", "thread_local", "time", + "tracing", "tracing-core", "tracing-log", ] @@ -2259,7 +2488,7 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b56c62d2c80033cb36fae448730a2f2ef99410fe3ecbffc916681a32f6807dbe" dependencies = [ - "nu-ansi-term", + "nu-ansi-term 0.50.1", "tracing-core", "tracing-log", "tracing-subscriber", @@ -2405,6 +2634,22 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.9" @@ -2414,13 +2659,29 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows" version = "0.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1de69df01bdf1ead2f4ac895dc77c9351aefff65b2f3db429a343f9cbf05e132" dependencies = [ - "windows-core", + "windows-core 0.56.0", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd04d41d93c4992d421894c18c8b43496aa748dd4c081bac0dc93eb0489272b6" +dependencies = [ + "windows-core 0.58.0", "windows-targets 0.52.6", ] @@ -2430,9 +2691,22 @@ version = "0.56.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4698e52ed2d08f8658ab0c39512a7c00ee5fe2688c65f8c0a4f06750d729f2a6" dependencies = [ - "windows-implement", - "windows-interface", - "windows-result", + "windows-implement 0.56.0", + "windows-interface 0.56.0", + "windows-result 0.1.2", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-core" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba6d44ec8c2591c134257ce647b7ea6b20335bf6379a27dac5f1641fcf59f99" +dependencies = [ + "windows-implement 0.58.0", + "windows-interface 0.58.0", + "windows-result 0.2.0", + "windows-strings", "windows-targets 0.52.6", ] @@ -2447,6 +2721,17 @@ dependencies = [ "syn", ] +[[package]] +name = "windows-implement" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2bbd5b46c938e506ecbce286b6628a02171d56153ba733b6c741fc627ec9579b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-interface" version = "0.56.0" @@ -2458,6 +2743,17 @@ dependencies = [ "syn", ] +[[package]] +name = "windows-interface" +version = "0.58.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053c4c462dc91d3b1504c6fe5a726dd15e216ba718e84a0e46a88fbe5ded3515" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-result" version = "0.1.2" @@ -2467,6 +2763,25 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result 0.2.0", + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.48.0" diff --git a/crates/query-group-macro/Cargo.toml b/crates/query-group-macro/Cargo.toml new file mode 100644 index 000000000000..e81266f3e0aa --- /dev/null +++ b/crates/query-group-macro/Cargo.toml @@ -0,0 +1,23 @@ +[package] +name = "query-group-macro" +version = "0.0.0" +repository.workspace = true +description = "A macro mimicking the `#[salsa::query_group]` macro for migrating to new Salsa" + +authors.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true + +[lib] +proc-macro = true + +[dependencies] +heck = "0.5.0" +proc-macro2 = "1.0" +quote = "1.0" +syn = { version = "2.0", features = ["full", "extra-traits"] } +salsa = { git = "https://github.com/salsa-rs/salsa.git" } + +[dev-dependencies] +expect-test = "1.5.0" diff --git a/crates/query-group-macro/src/lib.rs b/crates/query-group-macro/src/lib.rs new file mode 100644 index 000000000000..6469dacff3b1 --- /dev/null +++ b/crates/query-group-macro/src/lib.rs @@ -0,0 +1,492 @@ +use core::fmt; +use std::vec; + +use proc_macro::TokenStream; +use proc_macro2::Span; +use queries::{ + GeneratedInputStruct, InputQuery, InputSetter, InputSetterWithDurability, Intern, Lookup, + Queries, SetterKind, TrackedQuery, Transparent, +}; +use quote::{format_ident, quote, ToTokens}; +use syn::spanned::Spanned; +use syn::visit_mut::VisitMut; +use syn::{parse_quote, Attribute, FnArg, ItemTrait, Path, TraitItem, TraitItemFn}; + +mod queries; + +#[proc_macro_attribute] +pub fn query_group(args: TokenStream, input: TokenStream) -> TokenStream { + match query_group_impl(args, input.clone()) { + Ok(tokens) => tokens.into(), + Err(e) => token_stream_with_error(input, e), + } +} + +#[derive(Debug)] +struct InputStructField { + name: proc_macro2::TokenStream, + ty: proc_macro2::TokenStream, +} + +impl fmt::Display for InputStructField { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.name) + } +} + +struct SalsaAttr { + name: String, + tts: TokenStream, + span: Span, +} + +impl std::fmt::Debug for SalsaAttr { + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(fmt, "{:?}", self.name) + } +} + +impl TryFrom for SalsaAttr { + type Error = syn::Attribute; + + fn try_from(attr: syn::Attribute) -> Result { + if is_not_salsa_attr_path(attr.path()) { + return Err(attr); + } + + let span = attr.span(); + + let name = attr.path().segments[1].ident.to_string(); + let tts = match attr.meta { + syn::Meta::Path(path) => path.into_token_stream(), + syn::Meta::List(ref list) => { + let tts = list + .into_token_stream() + .into_iter() + .skip(attr.path().to_token_stream().into_iter().count()); + proc_macro2::TokenStream::from_iter(tts) + } + syn::Meta::NameValue(nv) => nv.into_token_stream(), + } + .into(); + + Ok(SalsaAttr { name, tts, span }) + } +} + +fn is_not_salsa_attr_path(path: &syn::Path) -> bool { + path.segments + .first() + .map(|s| s.ident != "salsa") + .unwrap_or(true) + || path.segments.len() != 2 +} + +fn filter_attrs(attrs: Vec) -> (Vec, Vec) { + let mut other = vec![]; + let mut salsa = vec![]; + // Leave non-salsa attributes untouched. These are + // attributes that don't start with `salsa::` or don't have + // exactly two segments in their path. + for attr in attrs { + match SalsaAttr::try_from(attr) { + Ok(it) => salsa.push(it), + Err(it) => other.push(it), + } + } + (other, salsa) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum QueryKind { + Input, + Tracked, + TrackedWithSalsaStruct, + Transparent, + Interned, +} + +pub(crate) fn query_group_impl( + _args: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> Result { + let mut item_trait = match syn::parse::(input) { + Ok(path) => path, + Err(e) => return Err(e), + }; + + let supertraits = &item_trait.supertraits; + + let db_attr: Attribute = parse_quote! { + #[salsa::db] + }; + item_trait.attrs.push(db_attr); + + let trait_name_ident = &item_trait.ident.clone(); + let input_struct_name = format_ident!("{}Data", trait_name_ident); + let create_data_ident = format_ident!("create_data_{}", trait_name_ident); + + let mut input_struct_fields: Vec = vec![]; + let mut trait_methods = vec![]; + let mut setter_trait_methods = vec![]; + let mut lookup_signatures = vec![]; + let mut lookup_methods = vec![]; + + for item in item_trait.clone().items { + match item { + syn::TraitItem::Fn(method) => { + let method_name = &method.sig.ident; + let signature = &method.sig.clone(); + + let (_attrs, salsa_attrs) = filter_attrs(method.attrs); + + let mut query_kind = QueryKind::Tracked; + let mut invoke = None; + let mut cycle = None; + let mut interned_struct_path = None; + let mut lru = None; + + let params: Vec = signature.inputs.clone().into_iter().collect(); + let pat_and_tys = params + .into_iter() + .filter(|fn_arg| matches!(fn_arg, FnArg::Typed(_))) + .map(|fn_arg| match fn_arg { + FnArg::Typed(pat_type) => pat_type.clone(), + FnArg::Receiver(_) => unreachable!("this should have been filtered out"), + }) + .collect::>(); + + for SalsaAttr { name, tts, span } in salsa_attrs { + match name.as_str() { + "cycle" => { + let path = match syn::parse::>(tts) { + Ok(path) => path, + Err(e) => return Err(e), + }; + cycle = Some(path.0.clone()) + } + "input" => { + if !pat_and_tys.is_empty() { + return Err(syn::Error::new( + span, + "input methods cannot have a parameter", + )); + } + query_kind = QueryKind::Input; + } + "interned" => { + let syn::ReturnType::Type(_, ty) = &signature.output else { + return Err(syn::Error::new( + span, + "interned queries must have return type", + )); + }; + let syn::Type::Path(path) = &**ty else { + return Err(syn::Error::new( + span, + "interned queries must have return type", + )); + }; + interned_struct_path = Some(path.path.clone()); + query_kind = QueryKind::Interned; + } + "invoke" => { + let path = match syn::parse::>(tts) { + Ok(path) => path, + Err(e) => return Err(e), + }; + invoke = Some(path.0.clone()); + } + "invoke_actual" => { + let path = match syn::parse::>(tts) { + Ok(path) => path, + Err(e) => return Err(e), + }; + invoke = Some(path.0.clone()); + query_kind = QueryKind::TrackedWithSalsaStruct; + } + "lru" => { + let lru_count = match syn::parse::>(tts) { + Ok(path) => path, + Err(e) => return Err(e), + }; + let value = lru_count.0.base10_parse::()?; + + lru = Some(value); + } + "transparent" => { + query_kind = QueryKind::Transparent; + } + _ => { + return Err(syn::Error::new( + span.clone(), + format!("unknown attribute `{name}`"), + )) + } + } + } + + let syn::ReturnType::Type(_, return_ty) = signature.output.clone() else { + return Err(syn::Error::new( + signature.span(), + "Queries must have a return type", + )); + }; + + if let syn::Type::Path(ref ty_path) = *return_ty { + if matches!(query_kind, QueryKind::Input) { + let field = InputStructField { + name: method_name.to_token_stream(), + ty: ty_path.path.to_token_stream(), + }; + + input_struct_fields.push(field); + } + } + + match (query_kind, invoke) { + // input + (QueryKind::Input, None) => { + let query = InputQuery { + signature: method.sig.clone(), + create_data_ident: create_data_ident.clone(), + }; + let value = Queries::InputQuery(query); + trait_methods.push(value); + + let setter = InputSetter { + signature: method.sig.clone(), + return_type: *return_ty.clone(), + create_data_ident: create_data_ident.clone(), + }; + setter_trait_methods.push(SetterKind::Plain(setter)); + + let setter = InputSetterWithDurability { + signature: method.sig.clone(), + return_type: *return_ty.clone(), + create_data_ident: create_data_ident.clone(), + }; + setter_trait_methods.push(SetterKind::WithDurability(setter)); + } + (QueryKind::Interned, None) => { + let interned_struct_path = interned_struct_path.unwrap(); + let method = Intern { + signature: signature.clone(), + pat_and_tys: pat_and_tys.clone(), + interned_struct_path: interned_struct_path.clone(), + }; + + trait_methods.push(Queries::Intern(method)); + + let mut method = Lookup { + signature: signature.clone(), + pat_and_tys: pat_and_tys.clone(), + return_ty: *return_ty, + interned_struct_path, + }; + method.prepare_signature(); + + lookup_signatures + .push(TraitItem::Fn(make_trait_method(method.signature.clone()))); + lookup_methods.push(method); + } + // tracked function without *any* invoke. + (QueryKind::Tracked, None) => { + let method = TrackedQuery { + trait_name: trait_name_ident.clone(), + generated_struct: Some(GeneratedInputStruct { + input_struct_name: input_struct_name.clone(), + create_data_ident: create_data_ident.clone(), + }), + signature: signature.clone(), + pat_and_tys: pat_and_tys.clone(), + invoke: None, + cycle, + lru, + }; + + trait_methods.push(Queries::TrackedQuery(method)); + } + // tracked function with an invoke + (QueryKind::Tracked, Some(invoke)) => { + let method = TrackedQuery { + trait_name: trait_name_ident.clone(), + generated_struct: Some(GeneratedInputStruct { + input_struct_name: input_struct_name.clone(), + create_data_ident: create_data_ident.clone(), + }), + signature: signature.clone(), + pat_and_tys: pat_and_tys.clone(), + invoke: Some(invoke), + cycle, + lru, + }; + + trait_methods.push(Queries::TrackedQuery(method)) + } + (QueryKind::TrackedWithSalsaStruct, Some(invoke)) => { + let method = TrackedQuery { + trait_name: trait_name_ident.clone(), + generated_struct: None, + signature: signature.clone(), + pat_and_tys: pat_and_tys.clone(), + invoke: Some(invoke), + cycle, + lru, + }; + + trait_methods.push(Queries::TrackedQuery(method)) + } + (QueryKind::TrackedWithSalsaStruct, None) => unreachable!(), + (QueryKind::Transparent, None) => { + let method = Transparent { + signature: method.sig.clone(), + pat_and_tys: pat_and_tys.clone(), + invoke: None, + }; + trait_methods.push(Queries::Transparent(method)); + } + (QueryKind::Transparent, Some(invoke)) => { + let method = Transparent { + signature: method.sig.clone(), + pat_and_tys: pat_and_tys.clone(), + invoke: Some(invoke), + }; + trait_methods.push(Queries::Transparent(method)); + } + // error/invalid constructions + (QueryKind::Interned, Some(path)) => { + return Err(syn::Error::new( + path.span(), + format!("Interned queries cannot be used with an `#[invoke]`"), + )) + } + (QueryKind::Input, Some(path)) => { + return Err(syn::Error::new( + path.span(), + format!("Inputs cannot be used with an `#[invoke]`"), + )) + } + } + } + + _ => (), + } + } + + let fields = input_struct_fields + .into_iter() + .map(|input| { + let name = input.name; + let ret = input.ty; + quote! { #name: Option<#ret> } + }) + .collect::>(); + + let input_struct = quote! { + #[salsa::input] + pub(crate) struct #input_struct_name { + #(#fields),* + } + }; + + let field_params = std::iter::repeat_n(quote! { None }, fields.len()) + .collect::>(); + + let create_data_method = quote! { + #[allow(non_snake_case)] + #[salsa::tracked] + fn #create_data_ident(db: &dyn #trait_name_ident) -> #input_struct_name { + #input_struct_name::new(db, #(#field_params),*) + } + }; + + let mut setter_signatures = vec![]; + let mut setter_methods = vec![]; + for trait_item in setter_trait_methods + .iter() + .map(|method| method.to_token_stream()) + .map(|tokens| syn::parse2::(tokens).unwrap()) + { + let mut methods_sans_body = trait_item.clone(); + methods_sans_body.default = None; + methods_sans_body.semi_token = Some(syn::Token![;](trait_item.span())); + + setter_signatures.push(TraitItem::Fn(methods_sans_body)); + setter_methods.push(TraitItem::Fn(trait_item)); + } + + item_trait.items.append(&mut setter_signatures); + item_trait.items.append(&mut lookup_signatures); + + let trait_impl = quote! { + #[salsa::db] + impl #trait_name_ident for DB + where + DB: #supertraits, + { + #(#trait_methods)* + + #(#setter_methods)* + + #(#lookup_methods)* + } + }; + RemoveAttrsFromTraitMethods.visit_item_trait_mut(&mut item_trait); + + let out = quote! { + #item_trait + + #trait_impl + + #input_struct + + #create_data_method + } + .into(); + + Ok(out) +} + +/// Parenthesis helper +pub(crate) struct Parenthesized(pub(crate) T); + +impl syn::parse::Parse for Parenthesized +where + T: syn::parse::Parse, +{ + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result { + let content; + syn::parenthesized!(content in input); + content.parse::().map(Parenthesized) + } +} + +fn make_trait_method(sig: syn::Signature) -> TraitItemFn { + TraitItemFn { + attrs: vec![], + sig: sig.clone(), + semi_token: Some(syn::Token![;](sig.span())), + default: None, + } +} + +struct RemoveAttrsFromTraitMethods; + +impl VisitMut for RemoveAttrsFromTraitMethods { + fn visit_item_trait_mut(&mut self, i: &mut syn::ItemTrait) { + for item in &mut i.items { + match item { + TraitItem::Fn(trait_item_fn) => { + trait_item_fn.attrs = vec![]; + } + _ => (), + } + } + } +} + +pub(crate) fn token_stream_with_error(mut tokens: TokenStream, error: syn::Error) -> TokenStream { + tokens.extend(TokenStream::from(error.into_compile_error())); + tokens +} diff --git a/crates/query-group-macro/src/queries.rs b/crates/query-group-macro/src/queries.rs new file mode 100644 index 000000000000..7a526a42ffbe --- /dev/null +++ b/crates/query-group-macro/src/queries.rs @@ -0,0 +1,335 @@ +use quote::{ToTokens, format_ident, quote}; +use syn::{FnArg, Ident, PatType, Path, Receiver, ReturnType, Type, parse_quote}; + +pub(crate) struct TrackedQuery { + pub(crate) trait_name: Ident, + pub(crate) signature: syn::Signature, + pub(crate) pat_and_tys: Vec, + pub(crate) invoke: Option, + pub(crate) cycle: Option, + pub(crate) lru: Option, + pub(crate) generated_struct: Option, +} + +pub(crate) struct GeneratedInputStruct { + pub(crate) input_struct_name: Ident, + pub(crate) create_data_ident: Ident, +} + +impl ToTokens for TrackedQuery { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let sig = &self.signature; + let trait_name = &self.trait_name; + + let ret = &sig.output; + + let invoke = match &self.invoke { + Some(path) => path.to_token_stream(), + None => sig.ident.to_token_stream(), + }; + + let fn_ident = &sig.ident; + let shim: Ident = format_ident!("{}_shim", fn_ident); + + let annotation = match (self.cycle.clone(), self.lru) { + (Some(cycle), Some(lru)) => quote!(#[salsa::tracked(lru = #lru, recovery_fn = #cycle)]), + (Some(cycle), None) => quote!(#[salsa::tracked(recovery_fn = #cycle)]), + (None, Some(lru)) => quote!(#[salsa::tracked(lru = #lru)]), + (None, None) => quote!(#[salsa::tracked]), + }; + + let pat_and_tys = &self.pat_and_tys; + let params = self + .pat_and_tys + .iter() + .map(|pat_type| pat_type.pat.clone()) + .collect::>>(); + + let method = match &self.generated_struct { + Some(generated_struct) => { + let input_struct_name = &generated_struct.input_struct_name; + let create_data_ident = &generated_struct.create_data_ident; + + quote! { + #sig { + #annotation + fn #shim( + db: &dyn #trait_name, + _input: #input_struct_name, + #(#pat_and_tys),* + ) #ret { + #invoke(db, #(#params),*) + } + #shim(self, #create_data_ident(self), #(#params),*) + } + } + } + None => { + quote! { + #sig { + #annotation + fn #shim( + db: &dyn #trait_name, + #(#pat_and_tys),* + ) #ret { + #invoke(db, #(#params),*) + } + #shim(self, #(#params),*) + } + } + } + }; + + method.to_tokens(tokens); + } +} + +pub(crate) struct InputQuery { + pub(crate) signature: syn::Signature, + pub(crate) create_data_ident: Ident, +} + +impl ToTokens for InputQuery { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let sig = &self.signature; + let fn_ident = &sig.ident; + let create_data_ident = &self.create_data_ident; + + let method = quote! { + #sig { + let data = #create_data_ident(self); + data.#fn_ident(self).unwrap() + } + }; + method.to_tokens(tokens); + } +} + +pub(crate) struct InputSetter { + pub(crate) signature: syn::Signature, + pub(crate) return_type: syn::Type, + pub(crate) create_data_ident: Ident, +} + +impl ToTokens for InputSetter { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let sig = &mut self.signature.clone(); + + let ty = &self.return_type; + let fn_ident = &sig.ident; + let create_data_ident = &self.create_data_ident; + + let setter_ident = format_ident!("set_{}", fn_ident); + sig.ident = setter_ident.clone(); + + let value_argument: PatType = parse_quote!(__value: #ty); + sig.inputs.push(FnArg::Typed(value_argument.clone())); + + // make `&self` `&mut self` instead. + let mut_recevier: Receiver = parse_quote!(&mut self); + sig.inputs + .first_mut() + .map(|og| *og = FnArg::Receiver(mut_recevier)); + + // remove the return value. + sig.output = ReturnType::Default; + + let value = &value_argument.pat; + let method = quote! { + #sig { + use salsa::Setter; + let data = #create_data_ident(self); + data.#setter_ident(self).to(Some(#value)); + } + }; + method.to_tokens(tokens); + } +} + +pub(crate) struct InputSetterWithDurability { + pub(crate) signature: syn::Signature, + pub(crate) return_type: syn::Type, + pub(crate) create_data_ident: Ident, +} + +impl ToTokens for InputSetterWithDurability { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let sig = &mut self.signature.clone(); + + let ty = &self.return_type; + let fn_ident = &sig.ident; + let setter_ident = format_ident!("set_{}", fn_ident); + + let create_data_ident = &self.create_data_ident; + + sig.ident = format_ident!("set_{}_with_durability", fn_ident); + + let value_argument: PatType = parse_quote!(__value: #ty); + sig.inputs.push(FnArg::Typed(value_argument.clone())); + + let durability_argument: PatType = parse_quote!(durability: salsa::Durability); + sig.inputs.push(FnArg::Typed(durability_argument.clone())); + + // make `&self` `&mut self` instead. + let mut_recevier: Receiver = parse_quote!(&mut self); + sig.inputs + .first_mut() + .map(|og| *og = FnArg::Receiver(mut_recevier)); + + // remove the return value. + sig.output = ReturnType::Default; + + let value = &value_argument.pat; + let durability = &durability_argument.pat; + let method = quote! { + #sig { + use salsa::Setter; + let data = #create_data_ident(self); + data.#setter_ident(self) + .with_durability(#durability) + .to(Some(#value)); + } + }; + method.to_tokens(tokens); + } +} + +pub(crate) enum SetterKind { + Plain(InputSetter), + WithDurability(InputSetterWithDurability), +} + +impl ToTokens for SetterKind { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + match self { + SetterKind::Plain(input_setter) => input_setter.to_tokens(tokens), + SetterKind::WithDurability(input_setter_with_durability) => { + input_setter_with_durability.to_tokens(tokens) + } + } + } +} + +pub(crate) struct Transparent { + pub(crate) signature: syn::Signature, + pub(crate) pat_and_tys: Vec, + pub(crate) invoke: Option, +} + +impl ToTokens for Transparent { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let sig = &self.signature; + + let ty = self + .pat_and_tys + .iter() + .map(|pat_type| pat_type.pat.clone()) + .collect::>>(); + + let invoke = match &self.invoke { + Some(path) => path.to_token_stream(), + None => sig.ident.to_token_stream(), + }; + + let method = quote! { + #sig { + #invoke(self, #(#ty),*) + } + }; + + method.to_tokens(tokens); + } +} +pub(crate) struct Intern { + pub(crate) signature: syn::Signature, + pub(crate) pat_and_tys: Vec, + pub(crate) interned_struct_path: Path, +} + +impl ToTokens for Intern { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let sig = &self.signature; + + let ty = self + .pat_and_tys + .iter() + .map(|pat_type| pat_type.clone()) + .collect::>(); + + let interned_pat = ty.iter().next().unwrap(); + let interned_pat = &interned_pat.pat; + + let wrapper_struct = self.interned_struct_path.to_token_stream(); + + let method = quote! { + #sig { + #wrapper_struct::new(self, #interned_pat) + } + }; + + method.to_tokens(tokens); + } +} + +pub(crate) struct Lookup { + pub(crate) signature: syn::Signature, + pub(crate) pat_and_tys: Vec, + pub(crate) return_ty: Type, + pub(crate) interned_struct_path: Path, +} + +impl Lookup { + pub(crate) fn prepare_signature(&mut self) { + let sig = &self.signature; + + let ident = format_ident!("lookup_{}", sig.ident); + + let ty = self + .pat_and_tys + .iter() + .map(|pat_type| pat_type.clone()) + .collect::>(); + + let interned_key = &self.return_ty; + + let interned_pat = ty.iter().next().unwrap(); + let interned_return_ty = &interned_pat.ty; + + self.signature = parse_quote!( + fn #ident(&self, id: #interned_key) -> #interned_return_ty + ); + } +} + +impl ToTokens for Lookup { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let sig = &self.signature; + + let wrapper_struct = self.interned_struct_path.to_token_stream(); + let method = quote! { + #sig { + #wrapper_struct::ingredient(self).data(self.as_dyn_database(), id.as_id()).0.clone() + } + }; + + method.to_tokens(tokens); + } +} + +pub(crate) enum Queries { + TrackedQuery(TrackedQuery), + InputQuery(InputQuery), + Intern(Intern), + Transparent(Transparent), +} + +impl ToTokens for Queries { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + match self { + Queries::TrackedQuery(tracked_query) => tracked_query.to_tokens(tokens), + Queries::InputQuery(input_query) => input_query.to_tokens(tokens), + Queries::Transparent(transparent) => transparent.to_tokens(tokens), + Queries::Intern(intern) => intern.to_tokens(tokens), + } + } +} diff --git a/crates/query-group-macro/tests/arity.rs b/crates/query-group-macro/tests/arity.rs new file mode 100644 index 000000000000..f6b2968caaae --- /dev/null +++ b/crates/query-group-macro/tests/arity.rs @@ -0,0 +1,28 @@ +use query_group_macro::query_group; + +#[query_group] +pub trait ArityDb: salsa::Database { + fn one(&self, a: ()) -> String; + + fn two(&self, a: (), b: ()) -> String; + + fn three(&self, a: (), b: (), c: ()) -> String; + + fn none(&self) -> String; +} + +fn one(_db: &dyn ArityDb, _a: ()) -> String { + todo!() +} + +fn two(_db: &dyn ArityDb, _a: (), _b: ()) -> String { + todo!() +} + +fn three(_db: &dyn ArityDb, _a: (), _b: (), _c: ()) -> String { + todo!() +} + +fn none(_db: &dyn ArityDb) -> String { + todo!() +} diff --git a/crates/query-group-macro/tests/cycle.rs b/crates/query-group-macro/tests/cycle.rs new file mode 100644 index 000000000000..3b5d8c8f8ec8 --- /dev/null +++ b/crates/query-group-macro/tests/cycle.rs @@ -0,0 +1,283 @@ +use std::panic::UnwindSafe; + +use query_group_macro::query_group; +use expect_test::expect; +use salsa::Setter; + +/// The queries A, B, and C in `Database` can be configured +/// to invoke one another in arbitrary ways using this +/// enum. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum CycleQuery { + None, + A, + B, + C, + AthenC, +} + +#[salsa::input] +struct ABC { + a: CycleQuery, + b: CycleQuery, + c: CycleQuery, +} + +impl CycleQuery { + fn invoke(self, db: &dyn CycleDatabase, abc: ABC) -> Result<(), Error> { + match self { + CycleQuery::A => db.cycle_a(abc), + CycleQuery::B => db.cycle_b(abc), + CycleQuery::C => db.cycle_c(abc), + CycleQuery::AthenC => { + let _ = db.cycle_a(abc); + db.cycle_c(abc) + } + CycleQuery::None => Ok(()), + } + } +} + +#[salsa::input] +struct MyInput {} + +#[salsa::tracked] +fn memoized_a(db: &dyn CycleDatabase, input: MyInput) { + memoized_b(db, input) +} + +#[salsa::tracked] +fn memoized_b(db: &dyn CycleDatabase, input: MyInput) { + memoized_a(db, input) +} + +#[salsa::tracked] +fn volatile_a(db: &dyn CycleDatabase, input: MyInput) { + db.report_untracked_read(); + volatile_b(db, input) +} + +#[salsa::tracked] +fn volatile_b(db: &dyn CycleDatabase, input: MyInput) { + db.report_untracked_read(); + volatile_a(db, input) +} + +#[track_caller] +fn extract_cycle(f: impl FnOnce() + UnwindSafe) -> salsa::Cycle { + let v = std::panic::catch_unwind(f); + if let Err(d) = &v { + if let Some(cycle) = d.downcast_ref::() { + return cycle.clone(); + } + } + panic!("unexpected value: {:?}", v) +} + +#[derive(PartialEq, Eq, Hash, Clone, Debug)] +struct Error { + cycle: Vec, +} + +#[query_group] +trait CycleDatabase: salsa::Database { + #[salsa::cycle(recover_a)] + fn cycle_a(&self, abc: ABC) -> Result<(), Error>; + + #[salsa::cycle(recover_b)] + fn cycle_b(&self, abc: ABC) -> Result<(), Error>; + + fn cycle_c(&self, abc: ABC) -> Result<(), Error>; +} + +fn cycle_a(db: &dyn CycleDatabase, abc: ABC) -> Result<(), Error> { + abc.a(db).invoke(db, abc) +} + +fn recover_a( + _db: &dyn CycleDatabase, + cycle: &salsa::Cycle, + _: CycleDatabaseData, + _abc: ABC, +) -> Result<(), Error> { + Err(Error { + cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), + }) +} + +fn cycle_b(db: &dyn CycleDatabase, abc: ABC) -> Result<(), Error> { + abc.b(db).invoke(db, abc) +} + +fn recover_b( + _db: &dyn CycleDatabase, + cycle: &salsa::Cycle, + _: CycleDatabaseData, + _abc: ABC, +) -> Result<(), Error> { + Err(Error { + cycle: cycle.participant_keys().map(|k| format!("{k:?}")).collect(), + }) +} + +fn cycle_c(db: &dyn CycleDatabase, abc: ABC) -> Result<(), Error> { + abc.c(db).invoke(db, abc) +} + +#[test] +fn cycle_memoized() { + let db = salsa::DatabaseImpl::new(); + + let input = MyInput::new(&db); + let cycle = extract_cycle(|| memoized_a(&db, input)); + let expected = expect![[r#" + [ + DatabaseKeyIndex( + IngredientIndex( + 1, + ), + Id(0), + ), + DatabaseKeyIndex( + IngredientIndex( + 2, + ), + Id(0), + ), + ] + "#]]; + expected.assert_debug_eq(&cycle.all_participants(&db)); +} + +#[test] +fn inner_cycle() { + // A --> B <-- C + // ^ | + // +-----+ + let db = salsa::DatabaseImpl::new(); + + let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::B); + let err = db.cycle_c(abc); + assert!(err.is_err()); + let expected = expect![[r#" + [ + "cycle_a_shim(Id(1400))", + "cycle_b_shim(Id(1000))", + ] + "#]]; + expected.assert_debug_eq(&err.unwrap_err().cycle); +} + +#[test] +fn cycle_revalidate() { + // A --> B + // ^ | + // +-----+ + let mut db = salsa::DatabaseImpl::new(); + let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(db.cycle_a(abc).is_err()); + abc.set_b(&mut db).to(CycleQuery::A); // same value as default + assert!(db.cycle_a(abc).is_err()); +} + +#[test] +fn cycle_recovery_unchanged_twice() { + // A --> B + // ^ | + // +-----+ + let mut db = salsa::DatabaseImpl::new(); + let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(db.cycle_a(abc).is_err()); + + abc.set_c(&mut db).to(CycleQuery::A); // force new revision + assert!(db.cycle_a(abc).is_err()); +} + +#[test] +fn cycle_appears() { + let mut db = salsa::DatabaseImpl::new(); + // A --> B + let abc = ABC::new(&db, CycleQuery::B, CycleQuery::None, CycleQuery::None); + assert!(db.cycle_a(abc).is_ok()); + + // A --> B + // ^ | + // +-----+ + abc.set_b(&mut db).to(CycleQuery::A); + assert!(db.cycle_a(abc).is_err()); +} + +#[test] +fn cycle_disappears() { + let mut db = salsa::DatabaseImpl::new(); + + // A --> B + // ^ | + // +-----+ + let abc = ABC::new(&db, CycleQuery::B, CycleQuery::A, CycleQuery::None); + assert!(db.cycle_a(abc).is_err()); + + // A --> B + abc.set_b(&mut db).to(CycleQuery::None); + assert!(db.cycle_a(abc).is_ok()); +} + +#[test] +fn cycle_multiple() { + // No matter whether we start from A or B, we get the same set of participants: + let db = salsa::DatabaseImpl::new(); + + // Configuration: + // + // A --> B <-- C + // ^ | ^ + // +-----+ | + // | | + // +-----+ + // + // Here, conceptually, B encounters a cycle with A and then + // recovers. + let abc = ABC::new(&db, CycleQuery::B, CycleQuery::AthenC, CycleQuery::A); + + let c = db.cycle_c(abc); + let b = db.cycle_b(abc); + let a = db.cycle_a(abc); + let expected = expect![[r#" + ( + [ + "cycle_a_shim(Id(1000))", + "cycle_b_shim(Id(1400))", + ], + [ + "cycle_a_shim(Id(1000))", + "cycle_b_shim(Id(1400))", + ], + [ + "cycle_a_shim(Id(1000))", + "cycle_b_shim(Id(1400))", + ], + ) + "#]]; + expected.assert_debug_eq(&( + c.unwrap_err().cycle, + b.unwrap_err().cycle, + a.unwrap_err().cycle, + )); +} + +#[test] +fn cycle_mixed_1() { + let db = salsa::DatabaseImpl::new(); + // A --> B <-- C + // | ^ + // +-----+ + let abc = ABC::new(&db, CycleQuery::B, CycleQuery::C, CycleQuery::B); + + let expected = expect![[r#" + [ + "cycle_b_shim(Id(1000))", + "cycle_c_shim(Id(c00))", + ] + "#]]; + expected.assert_debug_eq(&db.cycle_c(abc).unwrap_err().cycle); +} diff --git a/crates/query-group-macro/tests/hello_world.rs b/crates/query-group-macro/tests/hello_world.rs new file mode 100644 index 000000000000..71151eabf6fe --- /dev/null +++ b/crates/query-group-macro/tests/hello_world.rs @@ -0,0 +1,129 @@ +use query_group_macro::query_group; +use expect_test::expect; + +mod logger_db; +use logger_db::LoggerDb; + +#[query_group] +pub trait HelloWorldDatabase: salsa::Database { + // input + // // input with no params + #[salsa::input] + fn input_string(&self) -> String; + + // unadorned query + fn length_query(&self, key: ()) -> usize; + + // unadorned query + fn length_query_with_no_params(&self) -> usize; + + // renamed/invoke query + #[salsa::invoke(invoke_length_query_actual)] + fn invoke_length_query(&self, key: ()) -> usize; + + // not a query. should not invoked + #[salsa::transparent] + fn transparent_length(&self, key: ()) -> usize; + + #[salsa::transparent] + #[salsa::invoke(transparent_and_invoke_length_actual)] + fn transparent_and_invoke_length(&self, key: ()) -> usize; +} + +fn length_query(db: &dyn HelloWorldDatabase, key: ()) -> usize { + let _ = key; + db.input_string().len() +} + +fn length_query_with_no_params(db: &dyn HelloWorldDatabase) -> usize { + db.input_string().len() +} + +fn invoke_length_query_actual(db: &dyn HelloWorldDatabase, key: ()) -> usize { + let _ = key; + db.input_string().len() +} + +fn transparent_length(db: &dyn HelloWorldDatabase, key: ()) -> usize { + let _ = key; + db.input_string().len() +} + +fn transparent_and_invoke_length_actual(db: &dyn HelloWorldDatabase, key: ()) -> usize { + let _ = key; + db.input_string().len() +} + +#[test] +fn unadorned_query() { + let mut db = LoggerDb::default(); + + db.set_input_string(String::from("Hello, world!")); + let len = db.length_query(()); + + assert_eq!(len, 13); + db.assert_logs(expect![[r#" + [ + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: create_data_HelloWorldDatabase(Id(0)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(DidValidateMemoizedValue { database_key: create_data_HelloWorldDatabase(Id(0)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: length_query_shim(Id(800)) })", + "salsa_event(WillCheckCancellation)", + ]"#]]); +} + +#[test] +fn invoke_query() { + let mut db = LoggerDb::default(); + + db.set_input_string(String::from("Hello, world!")); + let len = db.invoke_length_query(()); + + assert_eq!(len, 13); + db.assert_logs(expect![[r#" + [ + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: create_data_HelloWorldDatabase(Id(0)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(DidValidateMemoizedValue { database_key: create_data_HelloWorldDatabase(Id(0)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: invoke_length_query_shim(Id(800)) })", + "salsa_event(WillCheckCancellation)", + ]"#]]); +} + +#[test] +fn transparent() { + let mut db = LoggerDb::default(); + + db.set_input_string(String::from("Hello, world!")); + let len = db.transparent_length(()); + + assert_eq!(len, 13); + db.assert_logs(expect![[r#" + [ + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: create_data_HelloWorldDatabase(Id(0)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(DidValidateMemoizedValue { database_key: create_data_HelloWorldDatabase(Id(0)) })", + ]"#]]); +} + +#[test] +fn transparent_invoke() { + let mut db = LoggerDb::default(); + + db.set_input_string(String::from("Hello, world!")); + let len = db.transparent_and_invoke_length(()); + + assert_eq!(len, 13); + db.assert_logs(expect![[r#" + [ + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: create_data_HelloWorldDatabase(Id(0)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(DidValidateMemoizedValue { database_key: create_data_HelloWorldDatabase(Id(0)) })", + ]"#]]); +} diff --git a/crates/query-group-macro/tests/interned.rs b/crates/query-group-macro/tests/interned.rs new file mode 100644 index 000000000000..fe4ff90955a9 --- /dev/null +++ b/crates/query-group-macro/tests/interned.rs @@ -0,0 +1,53 @@ +use query_group_macro::query_group; + +use expect_test::expect; +use salsa::plumbing::AsId; + +mod logger_db; +use logger_db::LoggerDb; + +#[salsa::interned(no_lifetime)] +pub struct InternedString { + data: String, +} + +#[query_group] +pub trait InternedDB: salsa::Database { + #[salsa::interned] + fn intern_string(&self, data: String) -> InternedString; + + fn interned_len(&self, id: InternedString) -> usize; +} + +fn interned_len(db: &dyn InternedDB, id: InternedString +) -> usize { + db.lookup_intern_string(id).len() +} + +#[test] +fn intern_round_trip() { + let db = LoggerDb::default(); + + let id = db.intern_string(String::from("Hello, world!")); + let s = db.lookup_intern_string(id); + + assert_eq!(s.len(), 13); + db.assert_logs(expect![[r#"[]"#]]); +} + +#[test] +fn intern_with_query() { + let db = LoggerDb::default(); + + let id = db.intern_string(String::from("Hello, world!")); + let len = db.interned_len(id); + + assert_eq!(len, 13); + db.assert_logs(expect![[r#" + [ + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: create_data_InternedDB(Id(400)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: interned_len_shim(Id(c00)) })", + ]"#]]); +} diff --git a/crates/query-group-macro/tests/logger_db.rs b/crates/query-group-macro/tests/logger_db.rs new file mode 100644 index 000000000000..1d3194645c3a --- /dev/null +++ b/crates/query-group-macro/tests/logger_db.rs @@ -0,0 +1,45 @@ +use std::sync::{Arc, Mutex}; + +#[salsa::db] +#[derive(Default, Clone)] +pub struct LoggerDb { + storage: salsa::Storage, + logger: Logger, +} + +#[derive(Default, Clone)] +struct Logger { + logs: Arc>>, +} + +#[salsa::db] +impl salsa::Database for LoggerDb { + fn salsa_event(&self, event: &dyn Fn() -> salsa::Event) { + let event = event(); + match event.kind { + salsa::EventKind::WillExecute { .. } + | salsa::EventKind::WillCheckCancellation { .. } + | salsa::EventKind::DidValidateMemoizedValue { .. } + | salsa::EventKind::WillDiscardStaleOutput { .. } + | salsa::EventKind::DidDiscard { .. } => { + self.push_log(format!("salsa_event({:?})", event.kind)); + } + _ => {} + } + } +} + +impl LoggerDb { + /// Log an event from inside a tracked function. + pub fn push_log(&self, string: String) { + self.logger.logs.lock().unwrap().push(string); + } + + /// Asserts what the (formatted) logs should look like, + /// clearing the logged events. This takes `&mut self` because + /// it is meant to be run from outside any tracked functions. + pub fn assert_logs(&self, expected: expect_test::Expect) { + let logs = std::mem::take(&mut *self.logger.logs.lock().unwrap()); + expected.assert_eq(&format!("{:#?}", logs)); + } +} diff --git a/crates/query-group-macro/tests/lru.rs b/crates/query-group-macro/tests/lru.rs new file mode 100644 index 000000000000..ef3343e59831 --- /dev/null +++ b/crates/query-group-macro/tests/lru.rs @@ -0,0 +1,69 @@ +use expect_test::expect; + +mod logger_db; +use logger_db::LoggerDb; +use query_group_macro::query_group; + +#[query_group] +pub trait LruDB: salsa::Database { + // // input with no params + #[salsa::input] + fn input_string(&self) -> String; + + #[salsa::lru(16)] + fn length_query(&self, key: ()) -> usize; + + #[salsa::lru(16)] + #[salsa::invoke(invoked_query)] + fn length_query_invoke(&self, key: ()) -> usize; +} + +fn length_query(db: &dyn LruDB, key: ()) -> usize { + let _ = key; + db.input_string().len() +} + +fn invoked_query(db: &dyn LruDB, key: ()) -> usize { + let _ = key; + db.input_string().len() +} + +#[test] +fn plain_lru() { + let mut db = LoggerDb::default(); + + db.set_input_string(String::from("Hello, world!")); + let len = db.length_query(()); + + assert_eq!(len, 13); + db.assert_logs(expect![[r#" + [ + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: create_data_LruDB(Id(0)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(DidValidateMemoizedValue { database_key: create_data_LruDB(Id(0)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: length_query_shim(Id(800)) })", + "salsa_event(WillCheckCancellation)", + ]"#]]); +} + +#[test] +fn invoke_lru() { + let mut db = LoggerDb::default(); + + db.set_input_string(String::from("Hello, world!")); + let len = db.length_query_invoke(()); + + assert_eq!(len, 13); + db.assert_logs(expect![[r#" + [ + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: create_data_LruDB(Id(0)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(DidValidateMemoizedValue { database_key: create_data_LruDB(Id(0)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: length_query_invoke_shim(Id(800)) })", + "salsa_event(WillCheckCancellation)", + ]"#]]); +} diff --git a/crates/query-group-macro/tests/multiple_dbs.rs b/crates/query-group-macro/tests/multiple_dbs.rs new file mode 100644 index 000000000000..b8a376e6d522 --- /dev/null +++ b/crates/query-group-macro/tests/multiple_dbs.rs @@ -0,0 +1,25 @@ +use query_group_macro::query_group; + +#[query_group] +pub trait DatabaseOne: salsa::Database { + #[salsa::input] + fn input_string(&self) -> String; + + // unadorned query + fn length(&self, key: ()) -> usize; +} + +#[query_group] +pub trait DatabaseTwo: DatabaseOne { + fn second_length(&self, key: ()) -> usize; +} + +fn length(db: &dyn DatabaseOne, key: ()) -> usize { + let _ = key; + db.input_string().len() +} + +fn second_length(db: &dyn DatabaseTwo, key: ()) -> usize { + let _ = key; + db.input_string().len() +} diff --git a/crates/query-group-macro/tests/old_and_new.rs b/crates/query-group-macro/tests/old_and_new.rs new file mode 100644 index 000000000000..733b3793ae00 --- /dev/null +++ b/crates/query-group-macro/tests/old_and_new.rs @@ -0,0 +1,115 @@ +use expect_test::expect; + +mod logger_db; +use logger_db::LoggerDb; +use query_group_macro::query_group; + +#[salsa::input] +struct Input { + str: String, +} + +#[query_group] +trait PartialMigrationDatabase: salsa::Database { + fn length_query(&self, input: Input) -> usize; + + // renamed/invoke query + #[salsa::invoke(invoke_length_query_actual)] + fn invoke_length_query(&self, input: Input) -> usize; + + // invoke tracked function + #[salsa::invoke(invoke_length_tracked_actual)] + fn invoke_length_tracked(&self, input: Input) -> usize; +} + +fn length_query(db: &dyn PartialMigrationDatabase, input: Input) -> usize { + input.str(db).len() +} + +fn invoke_length_query_actual(db: &dyn PartialMigrationDatabase, input: Input) -> usize { + input.str(db).len() +} + +#[salsa::tracked] +fn invoke_length_tracked_actual(db: &dyn PartialMigrationDatabase, input: Input) -> usize { + input.str(db).len() +} + +#[test] +fn unadorned_query() { + let db = LoggerDb::default(); + + let input = Input::new(&db, String::from("Hello, world!")); + let len = db.length_query(input); + + assert_eq!(len, 13); + db.assert_logs(expect![[r#" + [ + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: create_data_PartialMigrationDatabase(Id(400)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: length_query_shim(Id(c00)) })", + ]"#]]); +} + +#[test] +fn invoke_query() { + let db = LoggerDb::default(); + + let input = Input::new(&db, String::from("Hello, world!")); + let len = db.invoke_length_query(input); + + assert_eq!(len, 13); + db.assert_logs(expect![[r#" + [ + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: create_data_PartialMigrationDatabase(Id(400)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: invoke_length_query_shim(Id(c00)) })", + ]"#]]); +} + +// todo: does this even make sense? +#[test] +fn invoke_tracked_query() { + let db = LoggerDb::default(); + + let input = Input::new(&db, String::from("Hello, world!")); + let len = db.invoke_length_tracked(input); + + assert_eq!(len, 13); + db.assert_logs(expect![[r#" + [ + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: create_data_PartialMigrationDatabase(Id(400)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: invoke_length_tracked_shim(Id(c00)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: invoke_length_tracked_actual(Id(0)) })", + ]"#]]); +} + +#[test] +fn new_salsa_baseline() { + let db = LoggerDb::default(); + + #[salsa::input] + struct Input { + str: String, + } + + #[salsa::tracked] + fn new_salsa_length_query(db: &dyn PartialMigrationDatabase, input: Input) -> usize { + input.str(db).len() + } + + let input = Input::new(&db, String::from("Hello, world!")); + let len = new_salsa_length_query(&db, input); + + assert_eq!(len, 13); + db.assert_logs(expect![[r#" + [ + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: new_salsa_length_query(Id(0)) })", + ]"#]]); +} diff --git a/crates/query-group-macro/tests/result.rs b/crates/query-group-macro/tests/result.rs new file mode 100644 index 000000000000..0a77bfb9a990 --- /dev/null +++ b/crates/query-group-macro/tests/result.rs @@ -0,0 +1,53 @@ +mod logger_db; +use expect_test::expect; +use logger_db::LoggerDb; + +use query_group_macro::query_group; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Error; + +#[query_group] +pub trait ResultDatabase: salsa::Database { + #[salsa::input] + fn input_string(&self) -> String; + + fn length(&self, key: ()) -> Result; + + fn length2(&self, key: ()) -> Result; +} + +fn length(db: &dyn ResultDatabase, key: ()) -> Result { + let _ = key; + Ok(db.input_string().len()) +} + +fn length2(db: &dyn ResultDatabase, key: ()) -> Result { + let _ = key; + Ok(db.input_string().len()) +} + + +#[test] +fn test_queries_with_results() { + let mut db = LoggerDb::default(); + let input = "hello"; + db.set_input_string(input.to_owned()); + assert_eq!(db.length(()), Ok(input.len())); + assert_eq!(db.length2(()), Ok(input.len())); + + db.assert_logs(expect![[r#" + [ + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: create_data_ResultDatabase(Id(0)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(DidValidateMemoizedValue { database_key: create_data_ResultDatabase(Id(0)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: length_shim(Id(800)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(WillCheckCancellation)", + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: length2_shim(Id(c00)) })", + "salsa_event(WillCheckCancellation)", + ]"#]]); +} diff --git a/crates/query-group-macro/tests/supertrait.rs b/crates/query-group-macro/tests/supertrait.rs new file mode 100644 index 000000000000..f7361eaa90fc --- /dev/null +++ b/crates/query-group-macro/tests/supertrait.rs @@ -0,0 +1,19 @@ +use query_group_macro::query_group; + +#[salsa::db] +pub trait SourceDb: salsa::Database { + /// Text of the file. + fn file_text(&self, id: usize) -> String; +} + +#[query_group] +pub trait RootDb: SourceDb { + fn parse(&self, id: usize) -> String; +} + +fn parse(db: &dyn RootDb, id: usize) -> String { + // this is the test: does the following compile? + db.file_text(id); + + String::new() +} diff --git a/crates/query-group-macro/tests/tuples.rs b/crates/query-group-macro/tests/tuples.rs new file mode 100644 index 000000000000..c0bcb9ea93da --- /dev/null +++ b/crates/query-group-macro/tests/tuples.rs @@ -0,0 +1,39 @@ +use query_group_macro::query_group; + +mod logger_db; +use expect_test::expect; +use logger_db::LoggerDb; + +#[query_group] +pub trait HelloWorldDatabase: salsa::Database { + #[salsa::input] + fn input_string(&self) -> String; + + fn length_query(&self, key: ()) -> (usize, usize); +} + +fn length_query(db: &dyn HelloWorldDatabase, key: ()) -> (usize, usize) { + let _ = key; + let len = db.input_string().len(); + (len, len) +} + +#[test] +fn query() { + let mut db = LoggerDb::default(); + + db.set_input_string(String::from("Hello, world!")); + let len = db.length_query(()); + + assert_eq!(len, (13, 13)); + db.assert_logs(expect![[r#" + [ + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: create_data_HelloWorldDatabase(Id(0)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(DidValidateMemoizedValue { database_key: create_data_HelloWorldDatabase(Id(0)) })", + "salsa_event(WillCheckCancellation)", + "salsa_event(WillExecute { database_key: length_query_shim(Id(800)) })", + "salsa_event(WillCheckCancellation)", + ]"#]]); +} diff --git a/crates/ra-salsa/ra-salsa-macros/Cargo.toml b/crates/ra-salsa/ra-salsa-macros/Cargo.toml index 5613d75c7522..3c2daacaf950 100644 --- a/crates/ra-salsa/ra-salsa-macros/Cargo.toml +++ b/crates/ra-salsa/ra-salsa-macros/Cargo.toml @@ -14,7 +14,7 @@ proc-macro = true name = "ra_salsa_macros" [dependencies] -heck = "0.4" +heck = "0.5.0" proc-macro2 = "1.0" quote = "1.0" syn = { version = "2.0", features = ["full", "extra-traits"] }