Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
WIP
  • Loading branch information
dvdplm committed Feb 26, 2021
commit c2d565af11c75b70cf18b6669532fd10e224c308
15 changes: 8 additions & 7 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,31 +80,32 @@ fn generate_type(input: TokenStream2) -> Result<TokenStream2> {
let parity_scale_codec = crate_name_ident("parity-scale-codec")?;

let ident = &ast.ident;

println!("[generate_type] START {:?}", ident);
ast.generics
.lifetimes_mut()
.for_each(|l| *l = parse_quote!('static));

let (_, ty_generics, _) = ast.generics.split_for_impl();
let (where_clause, types) = trait_bounds::make_where_clause(
let (where_clause, compact_types) = trait_bounds::make_where_clause(
ident,
&ast.generics,
&ast.data,
&scale_info,
&parity_scale_codec,
)?;

println!("[generate_type] compact_types={:?}", compact_types);
let generic_type_ids = ast.generics.type_params().map(|ty| {
// If this is used in a Compact field, then the call must be: ::scale_info::meta_type::<<#ty_ident as HasCompact>::Type>()
let ty_ident = &ty.ident;
let is_compact = types.as_ref().unwrap().iter().filter(|infos| if let Some(i) = &infos.2 { i == ty_ident } else {false}).any(|infos| infos.1 );
if is_compact {
println!("[DDDD] Adding call to meta_type with as HasCompact");
// let is_compact = types.as_ref().unwrap().iter().filter(|infos| if let Some(i) = &infos.2 { i == ty_ident } else {false}).any(|infos| infos.1 );
// if is_compact {
if compact_types.contains(ty_ident) {
println!("[generate_type] Adding call to meta_type with as HasCompact");
quote! {
:: #scale_info ::meta_type::<<#ty_ident as :: #parity_scale_codec :: HasCompact>::Type>()
}
} else {
println!("[DDDD] Adding normal call to meta_type");
println!("[generate_type] Adding normal call to meta_type");
quote! {
:: #scale_info ::meta_type::<#ty_ident>()
}
Expand Down
281 changes: 232 additions & 49 deletions derive/src/trait_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashSet;

use alloc::vec::Vec;
use proc_macro2::Ident;
use syn::{
Expand All @@ -29,12 +31,13 @@ use syn::{
/// relevant generic types including associated types (e.g. `T::A: TypeInfo`), correctly dealing
/// with self-referential types.
pub fn make_where_clause<'a>(
input_ident: &'a Ident,
generics: &'a Generics,
data: &'a syn::Data,
scale_info: &Ident,
parity_scale_codec: &Ident,
) -> Result<(WhereClause, Option<Vec<(Type, bool, Option<Ident>)>>)> {
input_ident: &'a Ident, // The type we're deriving TypeInfo for
generics: &'a Generics, // The type params to our type
data: &'a syn::Data, // The body of the type
scale_info: &Ident, // Crate name
parity_scale_codec: &Ident, /* Crate name
* ) -> Result<(WhereClause, Option<Vec<(Type, bool, Option<Ident>)>>)> { */
) -> Result<(WhereClause, HashSet<Ident>)> {
let mut where_clause = generics.where_clause.clone().unwrap_or_else(|| {
WhereClause {
where_token: <syn::Token![where]>::default(),
Expand All @@ -48,65 +51,162 @@ pub fn make_where_clause<'a>(
.collect::<Vec<Ident>>();

if ty_params_ids.is_empty() {
return Ok((where_clause, None))
return Ok((where_clause, HashSet::new()))
}

let types = collect_types_to_bind(input_ident, data, &ty_params_ids)?;
let types2 = types.clone();
println!("[DDD] nr types appearing in source={:?}", types.len());
// println!("[DDD] nr types appearing in source={:?}", types.len());
let mut need_compact_bounds: HashSet<Ident> = HashSet::new();
let mut need_normal_bounds: HashSet<Ident> = HashSet::new();
use syn::WherePredicate;
let mut where_predicates: HashSet<WherePredicate> = HashSet::new();
types.into_iter().for_each(|(ty, is_compact, _)| {
// Compact types need extra bounds, T: HasCompact and <T as
// HasCompact>::Type: TypeInfo + 'static
// Compact types need two bounds: `T: HasCompact` and `<T as
// HasCompact>::Type: TypeInfo + 'static`

// TODO: should pass in a `&mut HashSet` from outside the loop to ensure we avoid repetitions
let type_params_in_type = collect_type_params(&ty);
if is_compact {
println!("[DDD] ty={:?} is Compact, adding HasCompact bounds", ty);
where_clause
.predicates
.push(parse_quote!(#ty : :: #parity_scale_codec ::HasCompact));
where_clause
.predicates
.push(parse_quote!(<#ty as :: #parity_scale_codec ::HasCompact>::Type : :: #scale_info ::TypeInfo + 'static));
println!("[make_where_clause] is Compact, ty={:?} adding HasCompact bounds", ty);
println!("[make_where_clause] is Compact, type_params_in_type={:?}", type_params_in_type);
where_predicates.insert(parse_quote!(#ty : :: #parity_scale_codec ::HasCompact));
where_predicates.insert(parse_quote!(<#ty as :: #parity_scale_codec ::HasCompact>::Type : :: #scale_info ::TypeInfo + 'static));

// where_clause
// .predicates
// .push(parse_quote!(#ty : :: #parity_scale_codec ::HasCompact));
// where_clause
// .predicates
// .push(parse_quote!(<#ty as :: #parity_scale_codec ::HasCompact>::Type : :: #scale_info ::TypeInfo + 'static));
// for tp in type_params_in_type {
// where_clause.predicates.push(parse_quote!( #tp : 'static ))
// }
need_compact_bounds.extend(type_params_in_type);
} else {
println!("[DDD] ty={:?} is NOT Compact, adding TypeInfo bound", ty);
where_clause
.predicates
.push(parse_quote!(#ty : :: #scale_info ::TypeInfo + 'static));
// println!("[DDD] ty={:?} is NOT Compact, adding TypeInfo bound", ty);
where_predicates.insert(parse_quote!(#ty : :: #scale_info ::TypeInfo + 'static));

// where_clause
// .predicates
// .push(parse_quote!(#ty : :: #scale_info ::TypeInfo + 'static));
// for tp in type_params_in_type {
// where_clause.predicates.push(parse_quote!( #tp : :: #scale_info ::TypeInfo + 'static ))
// }
need_normal_bounds.extend(type_params_in_type);
}
});
println!(
"[make_where_clause] need_compact_bounds={:?}",
need_compact_bounds
);
println!(
"[make_where_clause] need_normal_bounds={:?}",
need_normal_bounds
);
// TODO: make issue:
// Loop over the type params given to the type and add `TypeInfo` and
// `'static` bounds. Type params that are used in fields/variants that are
// `Compact` need only the `'static` bound.
// Note that if a type parameter appears "naked" in a field/variant, e.g.
// `struct A<T> { one: T }`, we will end up adding the bounds twice, once
// above as it appears in a field and once again below as it's one of the
// type params of `A`. Type params that appear as parameters to other types,
// e.g. `struct A<T> { one: PhantomData<T> }` do not get double bounds.

// The reason we do this "double looping" over first generics in used in
// fields/variants and then generics is that a type can have a type
// parameter bound to a trait, e.g. `T: SomeTrait`, that is not used as-is
// in the fields/variants but whose associated type **is**. Something like
// this:
// `struct A<T: SomeTrait> { one: T::SomeAssoc, }`
// When deriving `TypeInfo` for `A`, the first loop above adds
// `T::SomeAssoc: TypeInfo + 'static`, but we also need
// `T: TypeInfo + 'static`.
// Hence the second loop.
generics.type_params().into_iter().for_each(|type_param| {
let ident = type_param.ident.clone();
// Find `ident` in `types`, check if it is Compact. If yes, skip, else add bounds
let is_compact = types2.iter().filter(|ty| if let Some(i) = &ty.2 { i == &ident } else {false}).any(|ty| ty.1 );
if is_compact {
println!("[DDD] ident {:?} is used for a field that is Compact; not adding TypeInfo bound", ident);
let mut bounds = type_param.bounds.clone();
bounds.push(parse_quote!('static));
where_clause
.predicates
.push(parse_quote!( #ident : #bounds));
// Find the type parameter in the list of types that appear in the
// fields/variants and check if it is used in a `Compact` field/variant.
// If yes, only add the `'static` bound, else add both `TypeInfo` and
// `'static` bounds.

let mut bounds = type_param.bounds.clone();
if need_compact_bounds.contains(&ident) {
bounds.push(parse_quote!('static));
} else {
// I wonder if we need further checks. As is this leads to double bounds, as the bound is added as part of the generics too. Investigate.
println!("[DDD] ident {:?} is used for a field that is NOT Compact. Adding bounds", ident);
let mut bounds = type_param.bounds.clone();
bounds.push(parse_quote!(:: #scale_info ::TypeInfo));
bounds.push(parse_quote!('static));
where_clause
.predicates
.push(parse_quote!( #ident : #bounds));
}
where_predicates.insert(parse_quote!( #ident : #bounds));

// let mut bounds = type_param.bounds.clone();
// if need_compact_bounds.contains(&ident) {
// bounds.push(parse_quote!('static));
// where_clause
// .predicates
// .push(parse_quote!( #ident : #bounds));
// } else {
// // let mut bounds = type_param.bounds.clone();
// bounds.push(parse_quote!(:: #scale_info ::TypeInfo));
// bounds.push(parse_quote!('static));
// where_clause
// .predicates
// .push(parse_quote!( #ident : #bounds));
// }

// let is_compact = types2
// .iter()
// .filter(|ty| {
// if let Some(i) = &ty.2 {
// i == &ident
// } else {
// false
// }
// })
// .any(|ty| ty.1 );
// if is_compact {
// // println!("[DDD] ident {:?} is used for a field that is Compact; not adding TypeInfo bound", ident);
// let mut bounds = type_param.bounds.clone();
// bounds.push(parse_quote!('static));
// where_clause
// .predicates
// .push(parse_quote!( #ident : #bounds));

// } else {
// // println!("[DDD] ident {:?} is used for a field that is NOT Compact. Adding bounds", ident);
// let mut bounds = type_param.bounds.clone();
// bounds.push(parse_quote!(:: #scale_info ::TypeInfo));
// bounds.push(parse_quote!('static));
// where_clause
// .predicates
// .push(parse_quote!( #ident : #bounds));
// }
});

Ok((where_clause, Some(types2)))
where_predicates.extend(
need_compact_bounds
.iter()
.map(|tp_ident| parse_quote!( #tp_ident : 'static )),
);
where_predicates.extend(
need_normal_bounds.iter().map(
|tp_ident| parse_quote!( #tp_ident : :: #scale_info ::TypeInfo + 'static ),
),
);
where_clause.predicates.extend(where_predicates);

Ok((where_clause, need_compact_bounds))
}

/// Visits the ast and checks if the given type contains one of the given
/// idents.
// TODO: what if the type contains more than one of the idents?
// TODO: return just `Option<Ident>`?
fn type_contains_idents(ty: &Type, idents: &[Ident]) -> (bool, Option<Ident>) {
struct ContainIdents<'a> {
result: (bool, Option<Ident>),
idents: &'a [Ident],
}

impl<'a, 'ast> Visit<'ast> for ContainIdents<'a> {
fn visit_ident(&mut self, i: &'ast Ident) {
if self.idents.iter().any(|id| id == i) {
Expand All @@ -123,27 +223,110 @@ fn type_contains_idents(ty: &Type, idents: &[Ident]) -> (bool, Option<Ident>) {
visitor.result
}

// Should call this on a syn:Field instead so we have access to the Attributes, that way we could collect both `Compact` and generics usage?
// fn collect_type_params(ty: &Type) -> Vec<Ident> {
fn collect_type_params(ty: &Type) -> HashSet<Ident> {
use syn::{
GenericArgument,
Variant,
};
println!("[collect_type_params] ty={:?}", ty);
struct Bla {
result: HashSet<Ident>,
}
impl<'ast> Visit<'ast> for Bla {
fn visit_generic_argument(&mut self, g: &'ast GenericArgument) {
println!("visit_generic_argument={:?}", g);
if let GenericArgument::Type(syn::Type::Path(syn::TypePath {
path: t, ..
})) = g
{
println!("visit_generic_argument, type={:?}", t.get_ident());

if let Some(ident) = t.get_ident() {
self.result.insert(ident.clone());
}
}
}
}
let mut visitor = Bla {
result: HashSet::new(),
};
visitor.visit_type(ty);
println!(
"[collect_type_params] found type params={:?}",
visitor.result
);

visitor.result
}

fn visit_fields(ident: &Ident, data: &syn::Data) {
use syn::{
Attribute,
Field,
Fields,
};
struct Bla {
result: HashSet<Ident>,
}

impl<'ast> Visit<'ast> for Bla {
fn visit_fields(&mut self, fs: &'ast Fields) {
println!("visit_fields={:#?}", fs);
for f in fs {
self.visit_field(f)
}
}
// Nope
fn visit_attribute(&mut self, node: &'ast Attribute) {
println!("visit_attribute={:#?}", node);
}
// Nope
fn visit_field(&mut self, i: &'ast Field) {
println!("visit_field (sing)={:#?}", i);
for attr in &i.attrs {
self.visit_attribute(&attr);
}
}
}

let mut visitor = Bla {
result: HashSet::new(),
};
visitor.visit_data(data);
}

/// Returns all types that must be added to the where clause with a boolean
/// indicating if the field is [`scale::Compact`] or not.
fn collect_types_to_bind(
input_ident: &Ident,
data: &syn::Data,
ty_params: &[Ident],
) -> Result<Vec<(Type, bool, Option<Ident>)>> {
let types_from_fields = |fields: &Punctuated<syn::Field, _>| -> Vec<(Type, bool, Option<Ident>)> {
fields
.iter()
.filter(|field| {
// Only add a bound if the type uses a generic.
type_contains_idents(&field.ty, &ty_params).0
// Experiment:
// visit_fields(input_ident, data);
let types_from_fields =
|fields: &Punctuated<syn::Field, _>| -> Vec<(Type, bool, Option<Ident>)> {
fields
.iter()
.filter(|field| {
// Only add a bound if the type uses a generic.
type_contains_idents(&field.ty, &ty_params).0
&&
// Remove all remaining types that start/contain the input ident
// to not have them in the where clause.
!type_contains_idents(&field.ty, &[input_ident.clone()]).0
})
.map(|f| (f.ty.clone(), super::is_compact(f), type_contains_idents(&f.ty, &ty_params).1))
.collect()
};
})
.map(|f| {
(
f.ty.clone(),
super::is_compact(f),
type_contains_idents(&f.ty, &ty_params).1,
)
})
.collect()
};

let types = match *data {
syn::Data::Struct(ref data) => {
Expand Down
Loading