Skip to content

Commit fda65f2

Browse files
ascjonesdvdplm
andauthored
Derive TypeInfo for fields with associated types without bounds (#20)
* Add bounds for generic type param * Fmt * Remove redundant clone * Make clippy happy * Fmt * Remove readding of type to bounds * Unused imports * Doc tweaks (#37) * Docs and nits * Update derive/src/trait_bounds.rs Co-authored-by: Andrew Jones <[email protected]> * Adapt and simplify code to scale-info's needs Add a few trybuild tests (WIP) * Resolve todo * Fmt * Add ui test for Unions * Only run trybuild-tests on nightly * Unify and simply collect_types_to_bind() * Move a few trivial tests to trybuild tests instead * Add more trybuild tests * remove trivial test add trybuild test files * Make type_contains_idents more self-contained * Obey the fmt Co-authored-by: David <[email protected]>
1 parent bb031f6 commit fda65f2

13 files changed

+353
-6
lines changed

derive/src/lib.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ extern crate alloc;
1818
extern crate proc_macro;
1919

2020
mod impl_wrapper;
21+
mod trait_bounds;
2122

2223
use alloc::{
2324
string::{
@@ -34,7 +35,6 @@ use syn::{
3435
Error,
3536
Result,
3637
},
37-
parse_quote,
3838
punctuated::Punctuated,
3939
token::Comma,
4040
Data,
@@ -66,12 +66,9 @@ fn generate(input: TokenStream2) -> Result<TokenStream2> {
6666
fn generate_type(input: TokenStream2) -> Result<TokenStream2> {
6767
let mut ast: DeriveInput = syn::parse2(input.clone())?;
6868

69-
ast.generics.type_params_mut().for_each(|p| {
70-
p.bounds.push(parse_quote!(::scale_info::TypeInfo));
71-
p.bounds.push(parse_quote!('static));
72-
});
73-
7469
let ident = &ast.ident;
70+
trait_bounds::add(ident, &mut ast.generics, &ast.data)?;
71+
7572
let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
7673
let generic_type_ids = ast.generics.type_params().map(|ty| {
7774
let ty_ident = &ty.ident;

derive/src/trait_bounds.rs

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
// Copyright 2019-2020 Parity Technologies (UK) Ltd.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use alloc::vec::Vec;
16+
use proc_macro2::Ident;
17+
use syn::{
18+
parse_quote,
19+
punctuated::Punctuated,
20+
spanned::Spanned,
21+
visit::Visit,
22+
Generics,
23+
Result,
24+
Type,
25+
};
26+
27+
/// Adds a `TypeInfo + 'static` bound to all relevant generic types including
28+
/// associated types (e.g. `T::A: TypeInfo`), correctly dealing with
29+
/// self-referential types.
30+
pub fn add(input_ident: &Ident, generics: &mut Generics, data: &syn::Data) -> Result<()> {
31+
let ty_params = generics.type_params_mut().fold(Vec::new(), |mut acc, p| {
32+
p.bounds.push(parse_quote!(::scale_info::TypeInfo));
33+
p.bounds.push(parse_quote!('static));
34+
acc.push(p.ident.clone());
35+
acc
36+
});
37+
38+
if ty_params.is_empty() {
39+
return Ok(())
40+
}
41+
42+
let types = collect_types_to_bind(input_ident, data, &ty_params)?;
43+
44+
if !types.is_empty() {
45+
let where_clause = generics.make_where_clause();
46+
47+
types.into_iter().for_each(|ty| {
48+
where_clause
49+
.predicates
50+
.push(parse_quote!(#ty : ::scale_info::TypeInfo + 'static))
51+
});
52+
}
53+
54+
Ok(())
55+
}
56+
57+
/// Visits the ast and checks if the given type contains one of the given
58+
/// idents.
59+
fn type_contains_idents(ty: &Type, idents: &[Ident]) -> bool {
60+
struct ContainIdents<'a> {
61+
result: bool,
62+
idents: &'a [Ident],
63+
}
64+
65+
impl<'a, 'ast> Visit<'ast> for ContainIdents<'a> {
66+
fn visit_ident(&mut self, i: &'ast Ident) {
67+
if self.idents.iter().any(|id| id == i) {
68+
self.result = true;
69+
}
70+
}
71+
}
72+
73+
let mut visitor = ContainIdents {
74+
result: false,
75+
idents,
76+
};
77+
visitor.visit_type(ty);
78+
visitor.result
79+
}
80+
81+
/// Returns all types that must be added to the where clause with the respective
82+
/// trait bound.
83+
fn collect_types_to_bind(
84+
input_ident: &Ident,
85+
data: &syn::Data,
86+
ty_params: &[Ident],
87+
) -> Result<Vec<Type>> {
88+
let types_from_fields = |fields: &Punctuated<syn::Field, _>| -> Vec<syn::Type> {
89+
fields
90+
.iter()
91+
.filter(|field| {
92+
// Only add a bound if the type uses a generic.
93+
type_contains_idents(&field.ty, &ty_params)
94+
&&
95+
// Remove all remaining types that start/contain the input ident
96+
// to not have them in the where clause.
97+
!type_contains_idents(&field.ty, &[input_ident.clone()])
98+
})
99+
.map(|f| f.ty.clone())
100+
.collect()
101+
};
102+
103+
let types = match *data {
104+
syn::Data::Struct(ref data) => {
105+
match &data.fields {
106+
syn::Fields::Named(syn::FieldsNamed { named: fields, .. })
107+
| syn::Fields::Unnamed(syn::FieldsUnnamed {
108+
unnamed: fields, ..
109+
}) => types_from_fields(fields),
110+
syn::Fields::Unit => Vec::new(),
111+
}
112+
}
113+
114+
syn::Data::Enum(ref data) => {
115+
data.variants
116+
.iter()
117+
.flat_map(|variant| {
118+
match &variant.fields {
119+
syn::Fields::Named(syn::FieldsNamed {
120+
named: fields, ..
121+
})
122+
| syn::Fields::Unnamed(syn::FieldsUnnamed {
123+
unnamed: fields,
124+
..
125+
}) => types_from_fields(fields),
126+
syn::Fields::Unit => Vec::new(),
127+
}
128+
})
129+
.collect()
130+
}
131+
132+
syn::Data::Union(ref data) => {
133+
return Err(syn::Error::new(
134+
data.union_token.span(),
135+
"Union types are not supported.",
136+
))
137+
}
138+
};
139+
140+
Ok(types)
141+
}

test_suite/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,5 @@ scale = { package = "parity-scale-codec", version = "1.3", default-features = fa
1616
serde = "1.0"
1717
serde_json = "1.0"
1818
pretty_assertions = "0.6.1"
19+
trybuild = "1"
20+
rustversion = "1"

test_suite/tests/derive.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,40 @@ fn fields_with_type_alias() {
179179

180180
assert_type!(S, ty);
181181
}
182+
183+
#[test]
184+
fn associated_types_derive_without_bounds() {
185+
trait Types {
186+
type A;
187+
}
188+
#[allow(unused)]
189+
#[derive(TypeInfo)]
190+
struct Assoc<T: Types> {
191+
a: T::A,
192+
}
193+
194+
#[derive(TypeInfo)]
195+
enum ConcreteTypes {}
196+
impl Types for ConcreteTypes {
197+
type A = bool;
198+
}
199+
200+
let struct_type = Type::builder()
201+
.path(Path::new("Assoc", "derive"))
202+
.type_params(tuple_meta_type!(ConcreteTypes))
203+
.composite(Fields::named().field_of::<bool>("a", "T::A"));
204+
205+
assert_type!(Assoc<ConcreteTypes>, struct_type);
206+
}
207+
208+
#[rustversion::nightly]
209+
#[test]
210+
fn ui_tests() {
211+
let t = trybuild::TestCases::new();
212+
t.compile_fail("tests/ui/fail_missing_derive.rs");
213+
t.compile_fail("tests/ui/fail_non_static_lifetime.rs");
214+
t.compile_fail("tests/ui/fail_unions.rs");
215+
t.pass("tests/ui/pass_self_referential.rs");
216+
t.pass("tests/ui/pass_basic_generic_type.rs");
217+
t.pass("tests/ui/pass_complex_generic_self_referential_type.rs");
218+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
use scale_info::TypeInfo;
2+
3+
enum PawType<Paw> {
4+
Big(Paw),
5+
Small(Paw),
6+
}
7+
#[derive(TypeInfo)]
8+
struct Cat<Tail, Ear, Paw> {
9+
tail: Tail,
10+
ears: [Ear; 3],
11+
paws: PawType<Paw>,
12+
}
13+
14+
fn assert_type_info<T: TypeInfo + 'static>() {}
15+
16+
fn main() {
17+
assert_type_info::<Cat<bool, u8, u16>>();
18+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
error[E0277]: the trait bound `PawType<u16>: TypeInfo` is not satisfied
2+
--> $DIR/fail_missing_derive.rs:17:5
3+
|
4+
14 | fn assert_type_info<T: TypeInfo + 'static>() {}
5+
| -------- required by this bound in `assert_type_info`
6+
...
7+
17 | assert_type_info::<Cat<bool, u8, u16>>();
8+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `TypeInfo` is not implemented for `PawType<u16>`
9+
|
10+
= note: required because of the requirements on the impl of `TypeInfo` for `Cat<bool, u8, u16>`
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
use scale_info::TypeInfo;
2+
3+
#[derive(TypeInfo)]
4+
struct Me<'a> {
5+
me: &'a Me<'a>,
6+
}
7+
8+
fn assert_type_info<T: TypeInfo + 'static>() {}
9+
10+
fn main() {
11+
assert_type_info::<Me>();
12+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
error[E0477]: the type `Me<'a>` does not fulfill the required lifetime
2+
--> $DIR/fail_non_static_lifetime.rs:3:10
3+
|
4+
3 | #[derive(TypeInfo)]
5+
| ^^^^^^^^ in this macro invocation
6+
|
7+
::: $WORKSPACE/derive/src/lib.rs
8+
|
9+
| pub fn type_info(input: TokenStream) -> TokenStream {
10+
| --------------------------------------------------- in this expansion of `#[derive(TypeInfo)]`
11+
|
12+
= note: type must satisfy the static lifetime
13+
14+
error[E0477]: the type `&'a Me<'a>` does not fulfill the required lifetime
15+
--> $DIR/fail_non_static_lifetime.rs:3:10
16+
|
17+
3 | #[derive(TypeInfo)]
18+
| ^^^^^^^^ in this macro invocation
19+
|
20+
::: $WORKSPACE/derive/src/lib.rs
21+
|
22+
| pub fn type_info(input: TokenStream) -> TokenStream {
23+
| --------------------------------------------------- in this expansion of `#[derive(TypeInfo)]`
24+
|
25+
= note: type must satisfy the static lifetime

test_suite/tests/ui/fail_unions.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
use scale_info::TypeInfo;
2+
3+
#[derive(TypeInfo)]
4+
#[repr(C)]
5+
union Commonwealth {
6+
a: u8,
7+
b: f32,
8+
}
9+
10+
fn assert_type_info<T: TypeInfo + 'static>() {}
11+
12+
fn main() {
13+
assert_type_info::<Commonwealth>();
14+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
error: Unions not supported
2+
--> $DIR/fail_unions.rs:4:1
3+
|
4+
4 | / #[repr(C)]
5+
5 | | union Commonwealth {
6+
6 | | a: u8,
7+
7 | | b: f32,
8+
8 | | }
9+
| |_^
10+
11+
error[E0277]: the trait bound `Commonwealth: TypeInfo` is not satisfied
12+
--> $DIR/fail_unions.rs:13:24
13+
|
14+
10 | fn assert_type_info<T: TypeInfo + 'static>() {}
15+
| -------- required by this bound in `assert_type_info`
16+
...
17+
13 | assert_type_info::<Commonwealth>();
18+
| ^^^^^^^^^^^^ the trait `TypeInfo` is not implemented for `Commonwealth`

0 commit comments

Comments
 (0)