diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 7ba60a7e..f85f26ba 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#![cfg_attr(not(feature = "std"), no_std)] - extern crate alloc; extern crate proc_macro; @@ -80,24 +78,30 @@ fn generate_type(input: TokenStream2) -> Result { let parity_scale_codec = crate_name_ident("parity-scale-codec")?; let ident = &ast.ident; - ast.generics .lifetimes_mut() .for_each(|l| *l = parse_quote!('static)); let (_, ty_generics, _) = ast.generics.split_for_impl(); - let where_clause = trait_bounds::make_where_clause( + let (where_clause, compact_types) = trait_bounds::make_where_clause( ident, &ast.generics, &ast.data, &scale_info, &parity_scale_codec, )?; - let generic_type_ids = ast.generics.type_params().map(|ty| { let ty_ident = &ty.ident; - quote! { - :: #scale_info ::meta_type::<#ty_ident>() + // If this type param is used as a parameter in a Compact field, then the call must be: + // ::scale_info::meta_type::<<#ty_ident as HasCompact>::Type>() + if compact_types.contains(ty_ident) { + quote! { + :: #scale_info ::meta_type::<<#ty_ident as :: #parity_scale_codec :: HasCompact>::Type>() + } + } else { + quote! { + :: #scale_info ::meta_type::<#ty_ident>() + } } }); @@ -274,7 +278,7 @@ fn generate_variant_type(data_enum: &DataEnum, scale_info: &Ident) -> TokenStrea let variants = variants.into_iter().map(|v| { let ident = &v.ident; - let v_name = quote! {stringify!(#ident) }; + let v_name = quote! { stringify!(#ident) }; match v.fields { Fields::Named(ref fs) => { let fields = generate_fields(&fs.named); diff --git a/derive/src/trait_bounds.rs b/derive/src/trait_bounds.rs index 8a64dd85..57383bee 100644 --- a/derive/src/trait_bounds.rs +++ b/derive/src/trait_bounds.rs @@ -14,27 +14,33 @@ use alloc::vec::Vec; use proc_macro2::Ident; +use std::collections::HashSet; use syn::{ parse_quote, punctuated::Punctuated, spanned::Spanned, visit::Visit, + GenericArgument, Generics, Result, Type, WhereClause, + WherePredicate, }; /// Generates a where clause for a `TypeInfo` impl, adding `TypeInfo + 'static` bounds to all /// relevant generic types including associated types (e.g. `T::A: TypeInfo`), correctly dealing /// with self-referential types. +/// Returns a tuple with a where clause and the set of type params that appear +/// in `Compact` fields/variants in the form `Something` (type params that +/// appear as just `T` are not included) pub fn make_where_clause<'a>( input_ident: &'a Ident, generics: &'a Generics, data: &'a syn::Data, scale_info: &Ident, parity_scale_codec: &Ident, -) -> Result { +) -> Result<(WhereClause, HashSet)> { let mut where_clause = generics.where_clause.clone().unwrap_or_else(|| { WhereClause { where_token: ::default(), @@ -48,86 +54,144 @@ pub fn make_where_clause<'a>( .collect::>(); if ty_params_ids.is_empty() { - return Ok(where_clause) + return Ok((where_clause, HashSet::new())) } let types = collect_types_to_bind(input_ident, data, &ty_params_ids)?; - - types.into_iter().for_each(|(ty, is_compact)| { - // Compact types need extra bounds, T: HasCompact and ::Type: TypeInfo + 'static + let mut need_compact_bounds: HashSet = HashSet::new(); + let mut need_normal_bounds: HashSet = HashSet::new(); + let mut where_predicates: HashSet = HashSet::new(); + types.into_iter().for_each(|(ty, is_compact, _)| { + // Compact types need two bounds: + // `T: HasCompact` and + // `::Type: TypeInfo + 'static` + let generic_arguments = collect_generic_arguments(&ty); if is_compact { - where_clause - .predicates - .push(parse_quote!(#ty : :: #parity_scale_codec ::HasCompact)); - where_clause - .predicates - .push(parse_quote!(<#ty as :: #parity_scale_codec ::HasCompact>::Type : :: #scale_info ::TypeInfo + 'static)); + where_predicates.insert(parse_quote!(#ty : :: #parity_scale_codec ::HasCompact)); + where_predicates.insert(parse_quote!(<#ty as :: #parity_scale_codec ::HasCompact>::Type : :: #scale_info ::TypeInfo + 'static)); + need_compact_bounds.extend(generic_arguments); } else { - where_clause - .predicates - .push(parse_quote!(#ty : :: #scale_info ::TypeInfo + 'static)); + where_predicates.insert(parse_quote!(#ty : :: #scale_info ::TypeInfo + 'static)); + need_normal_bounds.extend(generic_arguments); } }); - + // Loop over the type params given to the type and add `TypeInfo` and + // `'static` bounds. Type params that are used in fields/variants that are + // `Compact` need only the `'static` bound. + // The reason we do this "double looping", first over generics used in + // fields/variants and then over the generics given in the type definition + // itself, is that a type can have a type parameter bound to a trait, e.g. + // `T: SomeTrait`, that is not used as-is in the fields/variants but whose + // associated type **is**. Something like this: + // `struct A { one: T::SomeAssoc, }` + // + // When deriving `TypeInfo` for `A`, the first loop above adds + // `T::SomeAssoc: TypeInfo + 'static`, but we also need + // `T: TypeInfo + 'static`. + // Hence the second loop. generics.type_params().into_iter().for_each(|type_param| { let ident = type_param.ident.clone(); + // Find the type parameter in the list of types that appear in any of the + // fields/variants and check if it is used in a `Compact` field/variant. + // If yes, only add the `'static` bound, else add both `TypeInfo` and + // `'static` bounds. + let mut bounds = type_param.bounds.clone(); - bounds.push(parse_quote!(:: #scale_info ::TypeInfo)); - bounds.push(parse_quote!('static)); - where_clause - .predicates - .push(parse_quote!( #ident : #bounds)); + if need_compact_bounds.contains(&ident) { + bounds.push(parse_quote!('static)); + } else { + bounds.push(parse_quote!(:: #scale_info ::TypeInfo)); + bounds.push(parse_quote!('static)); + } + where_predicates.insert(parse_quote!( #ident : #bounds)); }); - Ok(where_clause) + where_predicates.extend( + need_compact_bounds + .iter() + .map(|tp_ident| parse_quote!( #tp_ident : 'static )), + ); + where_predicates.extend( + need_normal_bounds.iter().map( + |tp_ident| parse_quote!( #tp_ident : :: #scale_info ::TypeInfo + 'static ), + ), + ); + where_clause.predicates.extend(where_predicates); + + Ok((where_clause, need_compact_bounds)) } -/// Visits the ast and checks if the given type contains one of the given +/// Visits the ast for a [`syn::Type`] and checks if it contains one of the given /// idents. -fn type_contains_idents(ty: &Type, idents: &[Ident]) -> bool { +fn type_contains_idents(ty: &Type, idents: &[Ident]) -> Option { struct ContainIdents<'a> { - result: bool, idents: &'a [Ident], + result: Option, } - impl<'a, 'ast> Visit<'ast> for ContainIdents<'a> { fn visit_ident(&mut self, i: &'ast Ident) { if self.idents.iter().any(|id| id == i) { - self.result = true; + self.result = Some(i.clone()); } } } let mut visitor = ContainIdents { - result: false, idents, + result: None, }; visitor.visit_type(ty); visitor.result } -/// Returns all types that must be added to the where clause with a boolean +/// Visit a `Type` and collect generic type params used. +/// Given `struct A { thing: SomeType }`, will include `T` in the set. +/// Given `struct A { thing: T }`, will **not** include `T` in the set. +fn collect_generic_arguments(ty: &Type) -> HashSet { + struct GenericsVisitor { + found: HashSet, + } + impl<'ast> Visit<'ast> for GenericsVisitor { + fn visit_generic_argument(&mut self, g: &'ast GenericArgument) { + if let GenericArgument::Type(syn::Type::Path(syn::TypePath { + path: t, .. + })) = g + { + if let Some(ident) = t.get_ident() { + self.found.insert(ident.clone()); + } + } + } + } + let mut visitor = GenericsVisitor { + found: HashSet::new(), + }; + visitor.visit_type(ty); + visitor.found +} + +/// Returns all types that must be added to the where clause, with a boolean /// indicating if the field is [`scale::Compact`] or not. fn collect_types_to_bind( input_ident: &Ident, data: &syn::Data, ty_params: &[Ident], -) -> Result> { - let types_from_fields = |fields: &Punctuated| -> Vec<(Type, bool)> { - fields - .iter() - .filter(|field| { +) -> Result)>> { + let types_from_fields = + |fields: &Punctuated| -> Vec<(Type, bool, Option)> { + fields.iter().fold(Vec::new(), |mut acc, field| { // Only add a bound if the type uses a generic. - type_contains_idents(&field.ty, &ty_params) - && - // Remove all remaining types that start/contain the input ident + let uses_generic = type_contains_idents(&field.ty, &ty_params); + // Find remaining types that start with/contain the input ident // to not have them in the where clause. - !type_contains_idents(&field.ty, &[input_ident.clone()]) + let uses_input_ident = + type_contains_idents(&field.ty, &[input_ident.clone()]); + if uses_generic.is_some() && uses_input_ident.is_none() { + acc.push((field.ty.clone(), super::is_compact(field), uses_generic)); + } + acc }) - .map(|f| (f.ty.clone(), super::is_compact(f))) - .collect() - }; + }; let types = match *data { syn::Data::Struct(ref data) => { diff --git a/test_suite/tests/derive.rs b/test_suite/tests/derive.rs index 55676c12..d0d8ae6f 100644 --- a/test_suite/tests/derive.rs +++ b/test_suite/tests/derive.rs @@ -280,6 +280,39 @@ fn scale_compact_types_work_in_enums() { assert_type!(MutilatedMultiAddress, ty); } +#[test] +fn scale_compact_types_complex() { + trait Boo { + type B: TypeInfo; + } + impl Boo for u8 { + type B = bool; + } + + #[allow(unused)] + #[derive(Encode, TypeInfo)] + struct A { + one: PhantomData, + two: PhantomData, + #[codec(compact)] + three: T, + four: T::B, + } + + let ty = Type::builder() + .path(Path::new("A", "derive")) + .type_params(tuple_meta_type![u8, u16]) + .composite( + Fields::named() + .field_of::>("one", "PhantomData") + .field_of::>("two", "PhantomData") + .compact_of::("three", "T") + .field_of::("four", "T::B"), + ); + + assert_type!(A, ty); +} + #[test] fn whitespace_scrubbing_works() { #[allow(unused)] @@ -303,6 +336,7 @@ fn ui_tests() { t.compile_fail("tests/ui/fail_unions.rs"); t.compile_fail("tests/ui/fail_use_codec_attrs_without_deriving_encode.rs"); t.compile_fail("tests/ui/fail_with_invalid_codec_attrs.rs"); + t.compile_fail("tests/ui/fail_infinite_recursion.rs"); t.pass("tests/ui/pass_with_valid_codec_attrs.rs"); t.pass("tests/ui/pass_non_static_lifetime.rs"); t.pass("tests/ui/pass_self_referential.rs"); diff --git a/test_suite/tests/json.rs b/test_suite/tests/json.rs index 393a670e..34c90667 100644 --- a/test_suite/tests/json.rs +++ b/test_suite/tests/json.rs @@ -247,7 +247,7 @@ fn test_struct_with_phantom() { use scale_info::prelude::marker::PhantomData; #[derive(TypeInfo)] struct Struct { - a: i32, + a: T, b: PhantomData, } diff --git a/test_suite/tests/ui/fail_infinite_recursion.rs b/test_suite/tests/ui/fail_infinite_recursion.rs new file mode 100644 index 00000000..b1c67e96 --- /dev/null +++ b/test_suite/tests/ui/fail_infinite_recursion.rs @@ -0,0 +1,22 @@ +use scale_info::TypeInfo; +use scale::Encode; + +#[derive(TypeInfo)] +struct Color{hue: Hue} +#[derive(TypeInfo)] +struct Texture{bump: Bump, hump: Hump} + +#[allow(unused)] +#[derive(Encode, TypeInfo)] +struct Apple { + #[codec(compact)] + one: Color, // <– works with a "naked" generic, `U`, but not like this + two: Texture, +} + +fn assert_type_info() {} + +fn main() { + // When this test fails it could mean that https://github.com/rust-lang/rust/issues/81785 is fixed + assert_type_info::>(); +} diff --git a/test_suite/tests/ui/fail_infinite_recursion.stderr b/test_suite/tests/ui/fail_infinite_recursion.stderr new file mode 100644 index 00000000..1ea4afc4 --- /dev/null +++ b/test_suite/tests/ui/fail_infinite_recursion.stderr @@ -0,0 +1,6 @@ +error[E0275]: overflow evaluating the requirement `_parity_scale_codec::Compact<_>: Decode` + | + = help: consider adding a `#![recursion_limit="256"]` attribute to your crate (`$CRATE`) + = note: required because of the requirements on the impl of `Decode` for `_parity_scale_codec::Compact<_>` + = note: 126 redundant requirements hidden + = note: required because of the requirements on the impl of `Decode` for `_parity_scale_codec::Compact<_>`