Skip to content
30 changes: 19 additions & 11 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#![cfg_attr(not(feature = "std"), no_std)]

extern crate alloc;
extern crate proc_macro;

Expand Down Expand Up @@ -53,7 +51,7 @@ use syn::{
Variant,
};

#[proc_macro_derive(TypeInfo)]
#[proc_macro_derive(TypeInfo, attributes(scale_info))]
pub fn type_info(input: TokenStream) -> TokenStream {
match generate(input.into()) {
Ok(output) => output.into(),
Expand All @@ -68,21 +66,31 @@ fn generate(input: TokenStream2) -> Result<TokenStream2> {
}

fn generate_type(input: TokenStream2) -> Result<TokenStream2> {
let ast: DeriveInput = syn::parse2(input.clone())?;
let mut ast: DeriveInput = syn::parse2(input.clone())?;

utils::check_attributes(&ast)?;

let scale_info = crate_name_ident("scale-info")?;
let parity_scale_codec = crate_name_ident("parity-scale-codec")?;

let ident = &ast.ident;

let where_clause = if let Some(custom_bounds) = utils::custom_trait_bounds(&ast.attrs)
{
let where_clause = ast.generics.make_where_clause();
where_clause.predicates.extend(custom_bounds);
where_clause.clone()
} else {
trait_bounds::make_where_clause(
ident,
&ast.generics,
&ast.data,
&scale_info,
&parity_scale_codec,
)?
};

let (impl_generics, ty_generics, _) = ast.generics.split_for_impl();
let where_clause = trait_bounds::make_where_clause(
ident,
&ast.generics,
&ast.data,
&scale_info,
&parity_scale_codec,
)?;

let generic_type_ids = ast.generics.type_params().map(|ty| {
let ty_ident = &ty.ident;
Expand Down
103 changes: 87 additions & 16 deletions derive/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,56 @@
use proc_macro2::TokenStream;
use quote::quote;
use syn::{
parse::Parse,
punctuated::Punctuated,
spanned::Spanned,
token,
AttrStyle,
Attribute,
DeriveInput,
Lit,
Meta,
NestedMeta,
Variant,
};

/// Trait bounds.
pub type TraitBounds = Punctuated<syn::WherePredicate, token::Comma>;

/// Parse `name(T: Bound, N: Bound)` as a custom trait bound.
struct CustomTraitBound<N> {
_name: N,
_paren_token: token::Paren,
bounds: TraitBounds,
}

impl<N: Parse> Parse for CustomTraitBound<N> {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let content;
Ok(Self {
_name: input.parse()?,
_paren_token: syn::parenthesized!(content in input),
bounds: content.parse_terminated(syn::WherePredicate::parse)?,
})
}
}

syn::custom_keyword!(bounds);

/// Look for a `#[scale_info(bounds(…))]`in the given attributes.
///
/// If found, use the given trait bounds when deriving the `TypeInfo` trait.
pub fn custom_trait_bounds(attrs: &[Attribute]) -> Option<TraitBounds> {
scale_info_meta_item(attrs.iter(), |meta: CustomTraitBound<bounds>| {
Some(meta.bounds)
})
}

/// Look for a `#[codec(index = $int)]` attribute on a variant. If no attribute
/// is found, fall back to the discriminant or just the variant index.
pub fn variant_index(v: &Variant, i: usize) -> TokenStream {
// first look for an attribute
let index = find_meta_item(v.attrs.iter(), |meta| {
let index = codec_meta_item(v.attrs.iter(), |meta| {
if let NestedMeta::Meta(Meta::NameValue(ref nv)) = meta {
if nv.path.is_ident("index") {
if let Lit::Int(ref v) = nv.lit {
Expand Down Expand Up @@ -62,7 +98,7 @@ pub fn is_compact(field: &syn::Field) -> bool {
.attrs
.iter()
.filter(|attr| attr.style == AttrStyle::Outer);
find_meta_item(outer_attrs, |meta| {
codec_meta_item(outer_attrs, |meta| {
if let NestedMeta::Meta(Meta::Path(ref path)) = meta {
if path.is_ident("compact") {
return Some(())
Expand All @@ -76,7 +112,7 @@ pub fn is_compact(field: &syn::Field) -> bool {

/// Look for a `#[codec(skip)]` in the given attributes.
pub fn should_skip(attrs: &[Attribute]) -> bool {
find_meta_item(attrs.iter(), |meta| {
codec_meta_item(attrs.iter(), |meta| {
if let NestedMeta::Meta(Meta::Path(ref path)) = meta {
if path.is_ident("skip") {
return Some(path.span())
Expand All @@ -88,22 +124,57 @@ pub fn should_skip(attrs: &[Attribute]) -> bool {
.is_some()
}

fn find_meta_item<'a, F, R, I>(itr: I, pred: F) -> Option<R>
fn codec_meta_item<'a, F, R, I, M>(itr: I, pred: F) -> Option<R>
where
F: Fn(&NestedMeta) -> Option<R> + Clone,
F: FnMut(M) -> Option<R> + Clone,
I: Iterator<Item = &'a Attribute>,
M: Parse,
{
itr.filter_map(|attr| {
if attr.path.is_ident("codec") {
if let Meta::List(ref meta_list) = attr
.parse_meta()
.expect("scale-info: Bad index in `#[codec(index = …)]`, see `parity-scale-codec` error")
{
return meta_list.nested.iter().filter_map(pred.clone()).next()
}
}
find_meta_item("codec", itr, pred)
}

None
fn scale_info_meta_item<'a, F, R, I, M>(itr: I, pred: F) -> Option<R>
where
F: FnMut(M) -> Option<R> + Clone,
I: Iterator<Item = &'a Attribute>,
M: Parse,
{
find_meta_item("scale_info", itr, pred)
}

fn find_meta_item<'a, F, R, I, M>(kind: &str, mut itr: I, mut pred: F) -> Option<R>
where
F: FnMut(M) -> Option<R> + Clone,
I: Iterator<Item = &'a Attribute>,
M: Parse,
{
itr.find_map(|attr| {
attr.path
.is_ident(kind)
.then(|| pred(attr.parse_args().ok()?))
.flatten()
})
.next()
}

/// Ensure attributes are correctly applied. This *must* be called before using
/// any of the attribute finder methods or the macro may panic if it encounters
/// misapplied attributes.
/// `#[scale_info(bounds())]` is the only accepted attribute.
pub fn check_attributes(input: &DeriveInput) -> syn::Result<()> {
for attr in &input.attrs {
check_top_attribute(attr)?;
}
Ok(())
}

// Only `#[scale_info(bounds())]` is a valid top attribute.
fn check_top_attribute(attr: &Attribute) -> syn::Result<()> {
if attr.path.is_ident("scale_info") {
match attr.parse_args::<CustomTraitBound<bounds>>() {
Ok(_) => Ok(()),
Err(e) => Err(syn::Error::new(attr.span(), format!("Invalid attribute: {:?}. Only `#[scale_info(bounds(…))]` is a valid top attribute", e)))
}
} else {
Ok(())
}
}
35 changes: 35 additions & 0 deletions test_suite/tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,41 @@ fn whitespace_scrubbing_works() {
assert_type!(A, ty);
}

#[test]
fn custom_bounds() {
// TODO: this test is dumb. It's a copy of Basti's equivalent in `parity-scale-codec` but I
// don't think it can work for us. I need a proper example of when custom bounds are needed.
// As-is, this test is simply setting the same bounds as the derive would have, which is pretty
// pointless.
#[allow(unused)]
#[derive(TypeInfo)]
#[scale_info(bounds(T: Default + TypeInfo + 'static, N: TypeInfo + 'static))]
struct Hey<T, N> {
ciao: Greet<T>,
ho: N,
}

#[derive(TypeInfo)]
#[scale_info(bounds(T: TypeInfo + 'static))]
struct Greet<T> {
marker: PhantomData<T>,
}

#[derive(TypeInfo, Default)]
struct SomeType;

let ty = Type::builder()
.path(Path::new("Hey", "derive"))
.type_params(tuple_meta_type!(SomeType, u16))
.composite(
Fields::named()
.field_of::<Greet<SomeType>>("ciao", "Greet<T>")
.field_of::<u16>("ho", "N"),
);

assert_type!(Hey<SomeType, u16>, ty);
}

#[rustversion::nightly]
#[test]
fn ui_tests() {
Expand Down