Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#![cfg_attr(not(feature = "std"), no_std)]

extern crate alloc;
extern crate proc_macro;

Expand Down Expand Up @@ -80,24 +78,30 @@ fn generate_type(input: TokenStream2) -> Result<TokenStream2> {
let parity_scale_codec = crate_name_ident("parity-scale-codec")?;

let ident = &ast.ident;

ast.generics
.lifetimes_mut()
.for_each(|l| *l = parse_quote!('static));

let (_, ty_generics, _) = ast.generics.split_for_impl();
let where_clause = trait_bounds::make_where_clause(
let (where_clause, compact_types) = trait_bounds::make_where_clause(
ident,
&ast.generics,
&ast.data,
&scale_info,
&parity_scale_codec,
)?;

let generic_type_ids = ast.generics.type_params().map(|ty| {
let ty_ident = &ty.ident;
quote! {
:: #scale_info ::meta_type::<#ty_ident>()
// If this type param is used as a parameter in a Compact field, then the call must be:
// ::scale_info::meta_type::<<#ty_ident as HasCompact>::Type>()
if compact_types.contains(ty_ident) {
quote! {
:: #scale_info ::meta_type::<<#ty_ident as :: #parity_scale_codec :: HasCompact>::Type>()
}
} else {
quote! {
:: #scale_info ::meta_type::<#ty_ident>()
}
}
});

Expand Down Expand Up @@ -274,7 +278,7 @@ fn generate_variant_type(data_enum: &DataEnum, scale_info: &Ident) -> TokenStrea

let variants = variants.into_iter().map(|v| {
let ident = &v.ident;
let v_name = quote! {stringify!(#ident) };
let v_name = quote! { stringify!(#ident) };
match v.fields {
Fields::Named(ref fs) => {
let fields = generate_fields(&fs.named);
Expand Down
146 changes: 105 additions & 41 deletions derive/src/trait_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,27 +14,33 @@

use alloc::vec::Vec;
use proc_macro2::Ident;
use std::collections::HashSet;
use syn::{
parse_quote,
punctuated::Punctuated,
spanned::Spanned,
visit::Visit,
GenericArgument,
Generics,
Result,
Type,
WhereClause,
WherePredicate,
};

/// Generates a where clause for a `TypeInfo` impl, adding `TypeInfo + 'static` bounds to all
/// relevant generic types including associated types (e.g. `T::A: TypeInfo`), correctly dealing
/// with self-referential types.
/// Returns a tuple with a where clause and the set of type params that appear
/// in `Compact` fields/variants in the form `Something<T>` (type params that
/// appear as just `T` are not included)
pub fn make_where_clause<'a>(
input_ident: &'a Ident,
generics: &'a Generics,
data: &'a syn::Data,
scale_info: &Ident,
parity_scale_codec: &Ident,
) -> Result<WhereClause> {
) -> Result<(WhereClause, HashSet<Ident>)> {
let mut where_clause = generics.where_clause.clone().unwrap_or_else(|| {
WhereClause {
where_token: <syn::Token![where]>::default(),
Expand All @@ -48,86 +54,144 @@ pub fn make_where_clause<'a>(
.collect::<Vec<Ident>>();

if ty_params_ids.is_empty() {
return Ok(where_clause)
return Ok((where_clause, HashSet::new()))
}

let types = collect_types_to_bind(input_ident, data, &ty_params_ids)?;

types.into_iter().for_each(|(ty, is_compact)| {
// Compact types need extra bounds, T: HasCompact and <T as
// HasCompact>::Type: TypeInfo + 'static
let mut need_compact_bounds: HashSet<Ident> = HashSet::new();
let mut need_normal_bounds: HashSet<Ident> = HashSet::new();
let mut where_predicates: HashSet<WherePredicate> = HashSet::new();
types.into_iter().for_each(|(ty, is_compact, _)| {
// Compact types need two bounds:
// `T: HasCompact` and
// `<T as HasCompact>::Type: TypeInfo + 'static`
let generic_arguments = collect_generic_arguments(&ty);
if is_compact {
where_clause
.predicates
.push(parse_quote!(#ty : :: #parity_scale_codec ::HasCompact));
where_clause
.predicates
.push(parse_quote!(<#ty as :: #parity_scale_codec ::HasCompact>::Type : :: #scale_info ::TypeInfo + 'static));
where_predicates.insert(parse_quote!(#ty : :: #parity_scale_codec ::HasCompact));
where_predicates.insert(parse_quote!(<#ty as :: #parity_scale_codec ::HasCompact>::Type : :: #scale_info ::TypeInfo + 'static));
need_compact_bounds.extend(generic_arguments);
} else {
where_clause
.predicates
.push(parse_quote!(#ty : :: #scale_info ::TypeInfo + 'static));
where_predicates.insert(parse_quote!(#ty : :: #scale_info ::TypeInfo + 'static));
need_normal_bounds.extend(generic_arguments);
}
});

// Loop over the type params given to the type and add `TypeInfo` and
// `'static` bounds. Type params that are used in fields/variants that are
// `Compact` need only the `'static` bound.
// The reason we do this "double looping", first over generics used in
// fields/variants and then over the generics given in the type definition
// itself, is that a type can have a type parameter bound to a trait, e.g.
// `T: SomeTrait`, that is not used as-is in the fields/variants but whose
// associated type **is**. Something like this:
// `struct A<T: SomeTrait> { one: T::SomeAssoc, }`
//
// When deriving `TypeInfo` for `A`, the first loop above adds
// `T::SomeAssoc: TypeInfo + 'static`, but we also need
// `T: TypeInfo + 'static`.
// Hence the second loop.
generics.type_params().into_iter().for_each(|type_param| {
let ident = type_param.ident.clone();
// Find the type parameter in the list of types that appear in any of the
// fields/variants and check if it is used in a `Compact` field/variant.
// If yes, only add the `'static` bound, else add both `TypeInfo` and
// `'static` bounds.

let mut bounds = type_param.bounds.clone();
bounds.push(parse_quote!(:: #scale_info ::TypeInfo));
bounds.push(parse_quote!('static));
where_clause
.predicates
.push(parse_quote!( #ident : #bounds));
if need_compact_bounds.contains(&ident) {
bounds.push(parse_quote!('static));
} else {
bounds.push(parse_quote!(:: #scale_info ::TypeInfo));
bounds.push(parse_quote!('static));
}
where_predicates.insert(parse_quote!( #ident : #bounds));
});

Ok(where_clause)
where_predicates.extend(
need_compact_bounds
.iter()
.map(|tp_ident| parse_quote!( #tp_ident : 'static )),
);
where_predicates.extend(
need_normal_bounds.iter().map(
|tp_ident| parse_quote!( #tp_ident : :: #scale_info ::TypeInfo + 'static ),
),
);
where_clause.predicates.extend(where_predicates);

Ok((where_clause, need_compact_bounds))
}

/// Visits the ast and checks if the given type contains one of the given
/// Visits the ast for a [`syn::Type`] and checks if it contains one of the given
/// idents.
fn type_contains_idents(ty: &Type, idents: &[Ident]) -> bool {
fn type_contains_idents(ty: &Type, idents: &[Ident]) -> Option<Ident> {
struct ContainIdents<'a> {
result: bool,
idents: &'a [Ident],
result: Option<Ident>,
}

impl<'a, 'ast> Visit<'ast> for ContainIdents<'a> {
fn visit_ident(&mut self, i: &'ast Ident) {
if self.idents.iter().any(|id| id == i) {
self.result = true;
self.result = Some(i.clone());
}
}
}

let mut visitor = ContainIdents {
result: false,
idents,
result: None,
};
visitor.visit_type(ty);
visitor.result
}

/// Returns all types that must be added to the where clause with a boolean
/// Visit a `Type` and collect generic type params used.
/// Given `struct A<T> { thing: SomeType<T> }`, will include `T` in the set.
/// Given `struct A<T> { thing: T }`, will **not** include `T` in the set.
fn collect_generic_arguments(ty: &Type) -> HashSet<Ident> {
struct GenericsVisitor {
found: HashSet<Ident>,
}
impl<'ast> Visit<'ast> for GenericsVisitor {
fn visit_generic_argument(&mut self, g: &'ast GenericArgument) {
if let GenericArgument::Type(syn::Type::Path(syn::TypePath {
path: t, ..
})) = g
{
if let Some(ident) = t.get_ident() {
self.found.insert(ident.clone());
}
}
}
}
let mut visitor = GenericsVisitor {
found: HashSet::new(),
};
visitor.visit_type(ty);
visitor.found
}

/// Returns all types that must be added to the where clause, with a boolean
/// indicating if the field is [`scale::Compact`] or not.
fn collect_types_to_bind(
input_ident: &Ident,
data: &syn::Data,
ty_params: &[Ident],
) -> Result<Vec<(Type, bool)>> {
let types_from_fields = |fields: &Punctuated<syn::Field, _>| -> Vec<(Type, bool)> {
fields
.iter()
.filter(|field| {
) -> Result<Vec<(Type, bool, Option<Ident>)>> {
let types_from_fields =
|fields: &Punctuated<syn::Field, _>| -> Vec<(Type, bool, Option<Ident>)> {
fields.iter().fold(Vec::new(), |mut acc, field| {
// Only add a bound if the type uses a generic.
type_contains_idents(&field.ty, &ty_params)
&&
// Remove all remaining types that start/contain the input ident
let uses_generic = type_contains_idents(&field.ty, &ty_params);
// Find remaining types that start with/contain the input ident
// to not have them in the where clause.
!type_contains_idents(&field.ty, &[input_ident.clone()])
let uses_input_ident =
type_contains_idents(&field.ty, &[input_ident.clone()]);
if uses_generic.is_some() && uses_input_ident.is_none() {
acc.push((field.ty.clone(), super::is_compact(field), uses_generic));
}
acc
})
.map(|f| (f.ty.clone(), super::is_compact(f)))
.collect()
};
};

let types = match *data {
syn::Data::Struct(ref data) => {
Expand Down
34 changes: 34 additions & 0 deletions test_suite/tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,39 @@ fn scale_compact_types_work_in_enums() {
assert_type!(MutilatedMultiAddress<u8, u16>, ty);
}

#[test]
fn scale_compact_types_complex() {
trait Boo {
type B: TypeInfo;
}
impl Boo for u8 {
type B = bool;
}

#[allow(unused)]
#[derive(Encode, TypeInfo)]
struct A<T: Boo, U> {
one: PhantomData<T>,
two: PhantomData<U>,
#[codec(compact)]
three: T,
four: T::B,
}

let ty = Type::builder()
.path(Path::new("A", "derive"))
.type_params(tuple_meta_type![u8, u16])
.composite(
Fields::named()
.field_of::<PhantomData<u8>>("one", "PhantomData<T>")
.field_of::<PhantomData<u16>>("two", "PhantomData<U>")
.compact_of::<u8>("three", "T")
.field_of::<bool>("four", "T::B"),
);

assert_type!(A<u8, u16>, ty);
}

#[test]
fn whitespace_scrubbing_works() {
#[allow(unused)]
Expand All @@ -303,6 +336,7 @@ fn ui_tests() {
t.compile_fail("tests/ui/fail_unions.rs");
t.compile_fail("tests/ui/fail_use_codec_attrs_without_deriving_encode.rs");
t.compile_fail("tests/ui/fail_with_invalid_codec_attrs.rs");
t.compile_fail("tests/ui/fail_infinite_recursion.rs");
t.pass("tests/ui/pass_with_valid_codec_attrs.rs");
t.pass("tests/ui/pass_non_static_lifetime.rs");
t.pass("tests/ui/pass_self_referential.rs");
Expand Down
2 changes: 1 addition & 1 deletion test_suite/tests/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ fn test_struct_with_phantom() {
use scale_info::prelude::marker::PhantomData;
#[derive(TypeInfo)]
struct Struct<T> {
a: i32,
a: T,
b: PhantomData<T>,
}

Expand Down
22 changes: 22 additions & 0 deletions test_suite/tests/ui/fail_infinite_recursion.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use scale_info::TypeInfo;
use scale::Encode;

#[derive(TypeInfo)]
struct Color<Hue>{hue: Hue}
#[derive(TypeInfo)]
struct Texture<Bump, Hump>{bump: Bump, hump: Hump}

#[allow(unused)]
#[derive(Encode, TypeInfo)]
struct Apple<T, U> {
#[codec(compact)]
one: Color<U>, // <– works with a "naked" generic, `U`, but not like this
two: Texture<T, U>,
}

fn assert_type_info<T: TypeInfo + 'static>() {}

fn main() {
// When this test fails it could mean that https://github.com/rust-lang/rust/issues/81785 is fixed
assert_type_info::<Apple<u8, u16>>();
}
6 changes: 6 additions & 0 deletions test_suite/tests/ui/fail_infinite_recursion.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
error[E0275]: overflow evaluating the requirement `_parity_scale_codec::Compact<_>: Decode`
|
= help: consider adding a `#![recursion_limit="256"]` attribute to your crate (`$CRATE`)
= note: required because of the requirements on the impl of `Decode` for `_parity_scale_codec::Compact<_>`
= note: 126 redundant requirements hidden
= note: required because of the requirements on the impl of `Decode` for `_parity_scale_codec::Compact<_>`