Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

defmt-test: Modify attributes in place and handle #[cfg] #383

Merged
merged 2 commits into from
Feb 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 70 additions & 56 deletions firmware/defmt-test/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{parse, spanned::Spanned, Block, FnArg, Ident, Item, ItemMod, Path, ReturnType, Type};
use syn::{parse, spanned::Spanned, Attribute, Item, ItemFn, ItemMod, ReturnType, Type};

#[proc_macro_attribute]
pub fn tests(args: TokenStream, input: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -37,24 +37,24 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
let mut imports = vec![];
for item in items {
match item {
Item::Fn(f) => {
Item::Fn(mut f) => {
let mut test_kind = None;
let mut should_error = false;

for attr in &f.attrs {
if path_is_ident(&attr.path, "init") {
f.attrs.retain(|attr| {
if attr.path.is_ident("init") {
test_kind = Some(Attr::Init);
} else if path_is_ident(&attr.path, "test") {
false
} else if attr.path.is_ident("test") {
test_kind = Some(Attr::Test);
} else if path_is_ident(&attr.path, "should_error") {
false
} else if attr.path.is_ident("should_error") {
should_error = true;
false
} else {
return Err(parse::Error::new(
attr.span(),
"only attributes `#[test]`, `#[init]` and `#[should_error]` are accepted",
));
true
}
}
});

let attr = match test_kind {
Some(it) => it,
Expand Down Expand Up @@ -89,16 +89,12 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
));
}

let state = match f.sig.output {
let state = match &f.sig.output {
ReturnType::Default => None,
ReturnType::Type(.., ty) => Some(ty),
ReturnType::Type(.., ty) => Some(ty.clone()),
};

init = Some(Init {
block: f.block,
ident: f.sig.ident,
state,
});
init = Some(Init { func: f, state });
}

Attr::Test => {
Expand All @@ -115,10 +111,7 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
// NOTE we cannot check the argument type matches `init.state` at this
// point
if let Some(ty) = get_mutable_reference_type(arg).cloned() {
Some(Input {
arg: arg.clone(),
ty,
})
Some(Input { ty })
} else {
// was not `&mut T`
return Err(parse::Error::new(
Expand All @@ -130,16 +123,10 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
None
};

let ret_ty = match f.sig.output {
ReturnType::Default => syn::parse_str("()").unwrap(),
ReturnType::Type(_, ty) => (*ty).clone(),
};

tests.push(Test {
block: f.block,
ident: f.sig.ident,
cfgs: extract_cfgs(&f.attrs),
func: f,
input,
ret_ty,
should_error,
})
}
Expand All @@ -163,12 +150,12 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
let ident = module.ident;
let mut state_ty = None;
let (init_fn, init_expr) = if let Some(init) = init {
let init_ident = init.ident;
let init_block = init.block;
let init_func = &init.func;
let init_ident = &init.func.sig.ident;
state_ty = init.state;

(
Some(quote!(fn #init_ident() -> #state_ty #init_block)),
Some(quote!(#init_func)),
Some(quote!(#[allow(dead_code)] let mut state = #init_ident();)),
)
} else {
Expand All @@ -178,8 +165,8 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
let mut unit_test_calls = vec![];
for test in &tests {
let should_error = test.should_error;
let ident = &test.ident;
let span = ident.span();
let ident = &test.func.sig.ident;
let span = test.func.sig.ident.span();
let call = if let Some(input) = test.input.as_ref() {
if let Some(state) = &state_ty {
if input.ty != **state {
Expand All @@ -203,26 +190,48 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
#krate::export::check_outcome(#call, #should_error);
));
}
let unit_test_names = tests.iter().map(|test| &test.ident);
let unit_test_inputs = tests
.iter()
.map(|test| test.input.as_ref().map(|input| &input.arg));
let unit_test_outputs = tests.iter().map(|test| &test.ret_ty);
let unit_test_blocks = tests.iter().map(|test| &test.block);

let test_functions = tests.iter().map(|test| &test.func);
let test_cfgs = tests.iter().map(|test| &test.cfgs);
let declare_test_count = {
let test_cfgs = test_cfgs.clone();
quote!(
// We can't evaluate `#[cfg]`s in the macro, but this works too.
const __DEFMT_TEST_COUNT: usize = {
let mut counter = 0;
#(
#(#test_cfgs)*
{ counter += 1; }
)*
counter
};
)
};
let unit_test_running = tests
.iter()
.enumerate()
.map(|(i, test)| format!("({}/{}) running `{}`...", i + 1, tests.len(), test.ident))
.map(|test| {
format!(
"({{=usize}}/{{=usize}}) running `{}`...",
test.func.sig.ident
)
})
.collect::<Vec<_>>();
Ok(quote!(mod #ident {
#(#imports)*
// TODO use `cortex-m-rt::entry` here to get the `static mut` transform
#[export_name = "main"]
unsafe extern "C" fn __defmt_test_entry() -> ! {
#declare_test_count
#init_expr

let mut __defmt_test_number: usize = 1;
#(
defmt::info!(#unit_test_running);
#unit_test_calls
#(#test_cfgs)*
{
defmt::info!(#unit_test_running, __defmt_test_number, __DEFMT_TEST_COUNT);
#unit_test_calls
__defmt_test_number += 1;
}
)*

defmt::info!("all tests passed!");
Expand All @@ -232,7 +241,7 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
#init_fn

#(
fn #unit_test_names(#unit_test_inputs) -> #unit_test_outputs #unit_test_blocks
#test_functions
)*
})
.into())
Expand All @@ -245,28 +254,21 @@ enum Attr {
}

struct Init {
block: Box<Block>,
ident: Ident,
func: ItemFn,
state: Option<Box<Type>>,
}

struct Test {
block: Box<Block>,
ident: Ident,
func: ItemFn,
cfgs: Vec<Attribute>,
input: Option<Input>,
ret_ty: Type,
should_error: bool,
}

struct Input {
arg: FnArg,
ty: Type,
}

fn path_is_ident(path: &Path, s: &str) -> bool {
path.get_ident().map(|ident| ident == s).unwrap_or(false)
}

// NOTE doesn't check the parameters or the return type
fn check_fn_sig(sig: &syn::Signature) -> Result<(), ()> {
if sig.constness.is_none()
Expand Down Expand Up @@ -298,3 +300,15 @@ fn get_mutable_reference_type(arg: &syn::FnArg) -> Option<&Type> {
None
}
}

fn extract_cfgs(attrs: &[Attribute]) -> Vec<Attribute> {
let mut cfgs = vec![];

for attr in attrs {
if attr.path.is_ident("cfg") {
cfgs.push(attr.clone());
}
}

cfgs
}
10 changes: 10 additions & 0 deletions firmware/qemu/src/bin/defmt-test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ mod tests {
use core::u8::MAX;
use defmt::{assert, assert_eq};

#[init]
fn init() {}

#[test]
fn assert_true() -> () {
assert!(true);
Expand All @@ -18,11 +21,18 @@ mod tests {
assert_eq!(255, MAX);
}

#[cfg(not(never))]
#[test]
fn result() -> Result<(), ()> {
Ok(())
}

#[cfg(never)]
#[test]
fn doesnt_compile() {
jonas-schievink marked this conversation as resolved.
Show resolved Hide resolved
because::this::doesnt::exist();
}

#[test]
#[should_error]
fn should_error() -> Result<(), ()> {
Expand Down