Skip to content
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ quote = "1.0"
syn = { version = "1.0", features = ["derive", "visit", "visit-mut", "extra-traits"] }
proc-macro2 = "1.0"
proc-macro-crate = "0.1.5"
hashbrown = "0.9.1"
18 changes: 12 additions & 6 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,24 +80,30 @@ fn generate_type(input: TokenStream2) -> Result<TokenStream2> {
let parity_scale_codec = crate_name_ident("parity-scale-codec")?;

let ident = &ast.ident;

ast.generics
.lifetimes_mut()
.for_each(|l| *l = parse_quote!('static));

let (_, ty_generics, _) = ast.generics.split_for_impl();
let where_clause = trait_bounds::make_where_clause(
let (where_clause, compact_types) = trait_bounds::make_where_clause(
ident,
&ast.generics,
&ast.data,
&scale_info,
&parity_scale_codec,
)?;

let generic_type_ids = ast.generics.type_params().map(|ty| {
let ty_ident = &ty.ident;
quote! {
:: #scale_info ::meta_type::<#ty_ident>()
// If this type param is used as a parameter in a Compact field, then the call must be:
// ::scale_info::meta_type::<<#ty_ident as HasCompact>::Type>()
if compact_types.contains(ty_ident) {
quote! {
:: #scale_info ::meta_type::<<#ty_ident as :: #parity_scale_codec :: HasCompact>::Type>()
}
} else {
quote! {
:: #scale_info ::meta_type::<#ty_ident>()
}
}
});

Expand Down Expand Up @@ -274,7 +280,7 @@ fn generate_variant_type(data_enum: &DataEnum, scale_info: &Ident) -> TokenStrea

let variants = variants.into_iter().map(|v| {
let ident = &v.ident;
let v_name = quote! {stringify!(#ident) };
let v_name = quote! { stringify!(#ident) };
match v.fields {
Fields::Named(ref fs) => {
let fields = generate_fields(&fs.named);
Expand Down
146 changes: 105 additions & 41 deletions derive/src/trait_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,34 @@
// limitations under the License.

use alloc::vec::Vec;
use hashbrown::HashSet;
use proc_macro2::Ident;
use syn::{
parse_quote,
punctuated::Punctuated,
spanned::Spanned,
visit::Visit,
GenericArgument,
Generics,
Result,
Type,
WhereClause,
WherePredicate,
};

/// Generates a where clause for a `TypeInfo` impl, adding `TypeInfo + 'static` bounds to all
/// relevant generic types including associated types (e.g. `T::A: TypeInfo`), correctly dealing
/// with self-referential types.
/// Returns a tuple with a where clause and the set of type params that appear
/// in `Compact` fields/variants in the form `Something<T>` (type params that
/// appear as just `T` are not included)
pub fn make_where_clause<'a>(
input_ident: &'a Ident,
generics: &'a Generics,
data: &'a syn::Data,
scale_info: &Ident,
parity_scale_codec: &Ident,
) -> Result<WhereClause> {
) -> Result<(WhereClause, HashSet<Ident>)> {
let mut where_clause = generics.where_clause.clone().unwrap_or_else(|| {
WhereClause {
where_token: <syn::Token![where]>::default(),
Expand All @@ -48,86 +54,144 @@ pub fn make_where_clause<'a>(
.collect::<Vec<Ident>>();

if ty_params_ids.is_empty() {
return Ok(where_clause)
return Ok((where_clause, HashSet::new()))
}

let types = collect_types_to_bind(input_ident, data, &ty_params_ids)?;

types.into_iter().for_each(|(ty, is_compact)| {
// Compact types need extra bounds, T: HasCompact and <T as
// HasCompact>::Type: TypeInfo + 'static
let mut need_compact_bounds: HashSet<Ident> = HashSet::new();
let mut need_normal_bounds: HashSet<Ident> = HashSet::new();
let mut where_predicates: HashSet<WherePredicate> = HashSet::new();
types.into_iter().for_each(|(ty, is_compact, _)| {
// Compact types need two bounds:
// `T: HasCompact` and
// `<T as HasCompact>::Type: TypeInfo + 'static`
let generic_arguments = collect_generic_arguments(&ty);
if is_compact {
where_clause
.predicates
.push(parse_quote!(#ty : :: #parity_scale_codec ::HasCompact));
where_clause
.predicates
.push(parse_quote!(<#ty as :: #parity_scale_codec ::HasCompact>::Type : :: #scale_info ::TypeInfo + 'static));
where_predicates.insert(parse_quote!(#ty : :: #parity_scale_codec ::HasCompact));
where_predicates.insert(parse_quote!(<#ty as :: #parity_scale_codec ::HasCompact>::Type : :: #scale_info ::TypeInfo + 'static));
need_compact_bounds.extend(generic_arguments);
} else {
where_clause
.predicates
.push(parse_quote!(#ty : :: #scale_info ::TypeInfo + 'static));
where_predicates.insert(parse_quote!(#ty : :: #scale_info ::TypeInfo + 'static));
need_normal_bounds.extend(generic_arguments);
}
});

// Loop over the type params given to the type and add `TypeInfo` and
// `'static` bounds. Type params that are used in fields/variants that are
// `Compact` need only the `'static` bound.
// The reason we do this "double looping", first over generics used in
// fields/variants and then over the generics given in the type definition
// itself, is that a type can have a type parameter bound to a trait, e.g.
// `T: SomeTrait`, that is not used as-is in the fields/variants but whose
// associated type **is**. Something like this:
// `struct A<T: SomeTrait> { one: T::SomeAssoc, }`
//
// When deriving `TypeInfo` for `A`, the first loop above adds
// `T::SomeAssoc: TypeInfo + 'static`, but we also need
// `T: TypeInfo + 'static`.
// Hence the second loop.
generics.type_params().into_iter().for_each(|type_param| {
let ident = type_param.ident.clone();
// Find the type parameter in the list of types that appear in any of the
// fields/variants and check if it is used in a `Compact` field/variant.
// If yes, only add the `'static` bound, else add both `TypeInfo` and
// `'static` bounds.

let mut bounds = type_param.bounds.clone();
bounds.push(parse_quote!(:: #scale_info ::TypeInfo));
bounds.push(parse_quote!('static));
where_clause
.predicates
.push(parse_quote!( #ident : #bounds));
if need_compact_bounds.contains(&ident) {
bounds.push(parse_quote!('static));
} else {
bounds.push(parse_quote!(:: #scale_info ::TypeInfo));
bounds.push(parse_quote!('static));
}
where_predicates.insert(parse_quote!( #ident : #bounds));
});

Ok(where_clause)
where_predicates.extend(
need_compact_bounds
.iter()
.map(|tp_ident| parse_quote!( #tp_ident : 'static )),
);
where_predicates.extend(
need_normal_bounds.iter().map(
|tp_ident| parse_quote!( #tp_ident : :: #scale_info ::TypeInfo + 'static ),
),
);
where_clause.predicates.extend(where_predicates);

Ok((where_clause, need_compact_bounds))
}

/// Visits the ast and checks if the given type contains one of the given
/// Visits the ast for a [`syn::Type`] and checks if it contains one of the given
/// idents.
fn type_contains_idents(ty: &Type, idents: &[Ident]) -> bool {
fn type_contains_idents(ty: &Type, idents: &[Ident]) -> Option<Ident> {
struct ContainIdents<'a> {
result: bool,
idents: &'a [Ident],
result: Option<Ident>,
}

impl<'a, 'ast> Visit<'ast> for ContainIdents<'a> {
fn visit_ident(&mut self, i: &'ast Ident) {
if self.idents.iter().any(|id| id == i) {
self.result = true;
self.result = Some(i.clone());
}
}
}

let mut visitor = ContainIdents {
result: false,
idents,
result: None,
};
visitor.visit_type(ty);
visitor.result
}

/// Returns all types that must be added to the where clause with a boolean
/// Visit a `Type` and collect generic type params used.
/// Given `struct A<T> { thing: SomeType<T> }`, will include `T` in the set.
/// Given `struct A<T> { thing: T }`, will **not** include `T` in the set.
fn collect_generic_arguments(ty: &Type) -> HashSet<Ident> {
struct GenericsVisitor {
found: HashSet<Ident>,
}
impl<'ast> Visit<'ast> for GenericsVisitor {
fn visit_generic_argument(&mut self, g: &'ast GenericArgument) {
if let GenericArgument::Type(syn::Type::Path(syn::TypePath {
path: t, ..
})) = g
{
if let Some(ident) = t.get_ident() {
self.found.insert(ident.clone());
}
}
}
}
let mut visitor = GenericsVisitor {
found: HashSet::new(),
};
visitor.visit_type(ty);
visitor.found
}

/// Returns all types that must be added to the where clause, with a boolean
/// indicating if the field is [`scale::Compact`] or not.
fn collect_types_to_bind(
input_ident: &Ident,
data: &syn::Data,
ty_params: &[Ident],
) -> Result<Vec<(Type, bool)>> {
let types_from_fields = |fields: &Punctuated<syn::Field, _>| -> Vec<(Type, bool)> {
fields
.iter()
.filter(|field| {
) -> Result<Vec<(Type, bool, Option<Ident>)>> {
let types_from_fields =
|fields: &Punctuated<syn::Field, _>| -> Vec<(Type, bool, Option<Ident>)> {
fields.iter().fold(Vec::new(), |mut acc, field| {
// Only add a bound if the type uses a generic.
type_contains_idents(&field.ty, &ty_params)
&&
// Remove all remaining types that start/contain the input ident
let uses_generic = type_contains_idents(&field.ty, &ty_params);
// Find remaining types that start with/contain the input ident
// to not have them in the where clause.
!type_contains_idents(&field.ty, &[input_ident.clone()])
let uses_input_ident =
type_contains_idents(&field.ty, &[input_ident.clone()]);
if uses_generic.is_some() && uses_input_ident.is_none() {
acc.push((field.ty.clone(), super::is_compact(field), uses_generic));
}
acc
})
.map(|f| (f.ty.clone(), super::is_compact(f)))
.collect()
};
};

let types = match *data {
syn::Data::Struct(ref data) => {
Expand Down
34 changes: 34 additions & 0 deletions test_suite/tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,39 @@ fn scale_compact_types_work_in_enums() {
assert_type!(MutilatedMultiAddress<u8, u16>, ty);
}

#[test]
fn scale_compact_types_complex() {
trait Boo {
type B: TypeInfo;
}
impl Boo for u8 {
type B = bool;
}

#[allow(unused)]
#[derive(Encode, TypeInfo)]
struct A<T: Boo, U> {
one: PhantomData<T>,
two: PhantomData<U>,
#[codec(compact)]
three: T,
four: T::B,
}

let ty = Type::builder()
.path(Path::new("A", "derive"))
.type_params(tuple_meta_type![u8, u16])
.composite(
Fields::named()
.field_of::<PhantomData<u8>>("one", "PhantomData<T>")
.field_of::<PhantomData<u16>>("two", "PhantomData<U>")
.compact_of::<u8>("three", "T")
.field_of::<bool>("four", "T::B"),
);

assert_type!(A<u8, u16>, ty);
}

#[test]
fn whitespace_scrubbing_works() {
#[allow(unused)]
Expand All @@ -303,6 +336,7 @@ fn ui_tests() {
t.compile_fail("tests/ui/fail_unions.rs");
t.compile_fail("tests/ui/fail_use_codec_attrs_without_deriving_encode.rs");
t.compile_fail("tests/ui/fail_with_invalid_codec_attrs.rs");
t.compile_fail("tests/ui/fail_infinite_recursion.rs");
t.pass("tests/ui/pass_with_valid_codec_attrs.rs");
t.pass("tests/ui/pass_non_static_lifetime.rs");
t.pass("tests/ui/pass_self_referential.rs");
Expand Down
2 changes: 1 addition & 1 deletion test_suite/tests/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ fn test_struct_with_phantom() {
use scale_info::prelude::marker::PhantomData;
#[derive(TypeInfo)]
struct Struct<T> {
a: i32,
a: T,
b: PhantomData<T>,
}

Expand Down
22 changes: 22 additions & 0 deletions test_suite/tests/ui/fail_infinite_recursion.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use scale_info::TypeInfo;
use scale::Encode;

#[derive(TypeInfo)]
struct Color<Hue>{hue: Hue}
#[derive(TypeInfo)]
struct Texture<Bump, Hump>{bump: Bump, hump: Hump}

#[allow(unused)]
#[derive(Encode, TypeInfo)]
struct Apple<T, U> {
#[codec(compact)]
one: Color<U>, // <– works with a "naked" generic, `U`, but not like this
two: Texture<T, U>,
}

fn assert_type_info<T: TypeInfo + 'static>() {}

fn main() {
// When this test fails it could mean that https://github.com/rust-lang/rust/issues/81785 is fixed
assert_type_info::<Apple<u8, u16>>();
}
6 changes: 6 additions & 0 deletions test_suite/tests/ui/fail_infinite_recursion.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
error[E0275]: overflow evaluating the requirement `_parity_scale_codec::Compact<_>: Decode`
|
= help: consider adding a `#![recursion_limit="256"]` attribute to your crate (`$CRATE`)
= note: required because of the requirements on the impl of `Decode` for `_parity_scale_codec::Compact<_>`
= note: 126 redundant requirements hidden
= note: required because of the requirements on the impl of `Decode` for `_parity_scale_codec::Compact<_>`