Skip to content

Commit

Permalink
support native async in trait
Browse files Browse the repository at this point in the history
  • Loading branch information
andylokandy committed Jul 21, 2023
1 parent 1f5a2c0 commit cc5fef9
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 284 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "logcall"
version = "0.1.0"
version = "0.1.1"
edition = "2021"
authors = ["andylokandy <andylokandy@hotmail.com>"]
description = "An attribute macro that logs the return value from function call."
Expand Down
305 changes: 22 additions & 283 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,8 @@ extern crate proc_macro;
extern crate proc_macro_error;

use proc_macro2::Span;
use proc_macro2::TokenStream;
use proc_macro2::TokenTree;
use quote::format_ident;
use quote::quote_spanned;
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::visit_mut::VisitMut;
use syn::Ident;
use syn::*;

Expand Down Expand Up @@ -69,8 +64,13 @@ pub fn logcall(
AsyncTraitKind::Async(async_expr) => {
// fallback if we couldn't find the '__async_trait' binding, might be
// useful for crates exhibiting the same behaviors as async-trait
let instrumented_block =
gen_block(&async_expr.block, true, &input.sig.ident.to_string(), args);
let instrumented_block = gen_block(
&async_expr.block,
true,
false,
&input.sig.ident.to_string(),
args,
);
let async_attrs = &async_expr.attrs;
quote! {
Box::pin(#(#async_attrs) * #instrumented_block )
Expand All @@ -81,30 +81,24 @@ pub fn logcall(
gen_block(
&input.block,
input.sig.asyncness.is_some(),
input.sig.asyncness.is_some(),
&input.sig.ident.to_string(),
args,
)
};

let ItemFn {
attrs,
vis,
mut sig,
..
attrs, vis, sig, ..
} = input;

if sig.asyncness.is_some() {
let has_self = has_self_in_sig(&mut sig);
transform_sig(&mut sig, has_self, true);
}

let Signature {
output: return_type,
inputs: params,
unsafety,
constness,
abi,
ident,
asyncness,
generics:
Generics {
params: gen_params,
Expand All @@ -116,7 +110,7 @@ pub fn logcall(

quote::quote!(
#(#attrs) *
#vis #constness #unsafety #abi fn #ident<#gen_params>(#params) #return_type
#vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #return_type
#where_clause
{
#func_body
Expand All @@ -129,20 +123,29 @@ pub fn logcall(
fn gen_block(
block: &Block,
async_context: bool,
async_keyword: bool,
fn_name: &str,
args: Args,
) -> proc_macro2::TokenStream {
// Generate the instrumented function body.
// If the function is an `async fn`, this will wrap it in an async block.
if async_context {
let log = gen_log(&args.level, fn_name, "__ret_value");
quote_spanned!(block.span()=>
let block = quote_spanned!(block.span()=>
async move {
let __ret_value = #block;
#log;
__ret_value
}
)
);

if async_keyword {
quote_spanned!(block.span()=>
#block.await
)
} else {
block
}
} else {
let log = gen_log(&args.level, fn_name, "__ret_value");
quote_spanned!(block.span()=>
Expand All @@ -165,270 +168,6 @@ fn gen_log(level: &str, fn_name: &str, return_value: &str) -> proc_macro2::Token
)
}

fn transform_sig(sig: &mut Signature, has_self: bool, is_local: bool) {
sig.fn_token.span = sig.asyncness.take().unwrap().span;

let ret = match &sig.output {
ReturnType::Default => quote!(()),
ReturnType::Type(_, ret) => quote!(#ret),
};

let default_span = sig
.ident
.span()
.join(sig.paren_token.span)
.unwrap_or_else(|| sig.ident.span());

let mut lifetimes = CollectLifetimes::new("'life", default_span);
for arg in sig.inputs.iter_mut() {
match arg {
FnArg::Receiver(arg) => lifetimes.visit_receiver_mut(arg),
FnArg::Typed(arg) => lifetimes.visit_type_mut(&mut arg.ty),
}
}

for param in sig.generics.params.iter() {
match param {
GenericParam::Type(param) => {
let param = &param.ident;
let span = param.span();
where_clause_or_default(&mut sig.generics.where_clause)
.predicates
.push(parse_quote_spanned!(span=> #param: 'logcall));
}
GenericParam::Lifetime(param) => {
let param = &param.lifetime;
let span = param.span();
where_clause_or_default(&mut sig.generics.where_clause)
.predicates
.push(parse_quote_spanned!(span=> #param: 'logcall));
}
GenericParam::Const(_) => {}
}
}

if sig.generics.lt_token.is_none() {
sig.generics.lt_token = Some(Token![<](sig.ident.span()));
}
if sig.generics.gt_token.is_none() {
sig.generics.gt_token = Some(Token![>](sig.paren_token.span));
}

for (idx, elided) in lifetimes.elided.iter().enumerate() {
sig.generics.params.insert(idx, parse_quote!(#elided));
where_clause_or_default(&mut sig.generics.where_clause)
.predicates
.push(parse_quote_spanned!(elided.span()=> #elided: 'logcall));
}

sig.generics
.params
.insert(0, parse_quote_spanned!(default_span=> 'logcall));

if has_self {
let bound_span = sig.ident.span();
let bound = match sig.inputs.iter().next() {
Some(FnArg::Receiver(Receiver {
reference: Some(_),
mutability: None,
..
})) => Ident::new("Sync", bound_span),
Some(FnArg::Typed(arg))
if match (arg.pat.as_ref(), arg.ty.as_ref()) {
(Pat::Ident(pat), Type::Reference(ty)) => {
pat.ident == "self" && ty.mutability.is_none()
}
_ => false,
} =>
{
Ident::new("Sync", bound_span)
}
_ => Ident::new("Send", bound_span),
};

let where_clause = where_clause_or_default(&mut sig.generics.where_clause);
where_clause.predicates.push(if is_local {
parse_quote_spanned!(bound_span=> Self: 'logcall)
} else {
parse_quote_spanned!(bound_span=> Self: ::core::marker::#bound + 'logcall)
});
}

for (i, arg) in sig.inputs.iter_mut().enumerate() {
match arg {
FnArg::Receiver(Receiver {
reference: Some(_), ..
}) => {}
FnArg::Receiver(arg) => arg.mutability = None,
FnArg::Typed(arg) => {
if let Pat::Ident(ident) = &mut *arg.pat {
ident.by_ref = None;
ident.mutability = None;
} else {
let positional = positional_arg(i, &arg.pat);
let m = mut_pat(&mut arg.pat);
arg.pat = parse_quote!(#m #positional);
}
}
}
}

let ret_span = sig.ident.span();
let bounds = if is_local {
quote_spanned!(ret_span=> 'logcall)
} else {
quote_spanned!(ret_span=> ::core::marker::Send + 'logcall)
};
sig.output = parse_quote_spanned! {ret_span=>
-> impl ::core::future::Future<Output = #ret> + #bounds
};
}

struct CollectLifetimes {
pub elided: Vec<Lifetime>,
pub explicit: Vec<Lifetime>,
pub name: &'static str,
pub default_span: Span,
}

impl CollectLifetimes {
pub fn new(name: &'static str, default_span: Span) -> Self {
CollectLifetimes {
elided: Vec::new(),
explicit: Vec::new(),
name,
default_span,
}
}

fn visit_opt_lifetime(&mut self, lifetime: &mut Option<Lifetime>) {
match lifetime {
None => *lifetime = Some(self.next_lifetime(None)),
Some(lifetime) => self.visit_lifetime(lifetime),
}
}

fn visit_lifetime(&mut self, lifetime: &mut Lifetime) {
if lifetime.ident == "_" {
*lifetime = self.next_lifetime(lifetime.span());
} else {
self.explicit.push(lifetime.clone());
}
}

fn next_lifetime<S: Into<Option<Span>>>(&mut self, span: S) -> Lifetime {
let name = format!("{}{}", self.name, self.elided.len());
let span = span.into().unwrap_or(self.default_span);
let life = Lifetime::new(&name, span);
self.elided.push(life.clone());
life
}
}

impl VisitMut for CollectLifetimes {
fn visit_receiver_mut(&mut self, arg: &mut Receiver) {
if let Some((_, lifetime)) = &mut arg.reference {
self.visit_opt_lifetime(lifetime);
}
}

fn visit_type_reference_mut(&mut self, ty: &mut TypeReference) {
self.visit_opt_lifetime(&mut ty.lifetime);
visit_mut::visit_type_reference_mut(self, ty);
}

fn visit_generic_argument_mut(&mut self, gen: &mut GenericArgument) {
if let GenericArgument::Lifetime(lifetime) = gen {
self.visit_lifetime(lifetime);
}
visit_mut::visit_generic_argument_mut(self, gen);
}
}

fn positional_arg(i: usize, pat: &Pat) -> Ident {
format_ident!("__arg{}", i, span = pat.span())
}

fn mut_pat(pat: &mut Pat) -> Option<Token![mut]> {
let mut visitor = HasMutPat(None);
visitor.visit_pat_mut(pat);
visitor.0
}

fn has_self_in_sig(sig: &mut Signature) -> bool {
let mut visitor = HasSelf(false);
visitor.visit_signature_mut(sig);
visitor.0
}

fn has_self_in_token_stream(tokens: TokenStream) -> bool {
tokens.into_iter().any(|tt| match tt {
TokenTree::Ident(ident) => ident == "Self",
TokenTree::Group(group) => has_self_in_token_stream(group.stream()),
_ => false,
})
}

struct HasMutPat(Option<Token![mut]>);

impl VisitMut for HasMutPat {
fn visit_pat_ident_mut(&mut self, i: &mut PatIdent) {
if let Some(m) = i.mutability {
self.0 = Some(m);
} else {
visit_mut::visit_pat_ident_mut(self, i);
}
}
}

struct HasSelf(bool);

impl VisitMut for HasSelf {
fn visit_expr_path_mut(&mut self, expr: &mut ExprPath) {
self.0 |= expr.path.segments[0].ident == "Self";
visit_mut::visit_expr_path_mut(self, expr);
}

fn visit_pat_path_mut(&mut self, pat: &mut PatPath) {
self.0 |= pat.path.segments[0].ident == "Self";
visit_mut::visit_pat_path_mut(self, pat);
}

fn visit_type_path_mut(&mut self, ty: &mut TypePath) {
self.0 |= ty.path.segments[0].ident == "Self";
visit_mut::visit_type_path_mut(self, ty);
}

fn visit_receiver_mut(&mut self, _arg: &mut Receiver) {
self.0 = true;
}

fn visit_item_mut(&mut self, _: &mut Item) {
// Do not recurse into nested items.
}

fn visit_macro_mut(&mut self, mac: &mut Macro) {
if !contains_fn(mac.tokens.clone()) {
self.0 |= has_self_in_token_stream(mac.tokens.clone());
}
}
}

fn contains_fn(tokens: TokenStream) -> bool {
tokens.into_iter().any(|tt| match tt {
TokenTree::Ident(ident) => ident == "fn",
TokenTree::Group(group) => contains_fn(group.stream()),
_ => false,
})
}

fn where_clause_or_default(clause: &mut Option<WhereClause>) -> &mut WhereClause {
clause.get_or_insert_with(|| WhereClause {
where_token: Default::default(),
predicates: Punctuated::new(),
})
}

enum AsyncTraitKind<'a> {
// old construction. Contains the function
Function(&'a ItemFn),
Expand Down
Loading

0 comments on commit cc5fef9

Please sign in to comment.