Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
support match_args for tuple enum
  • Loading branch information
newcomertv committed May 16, 2024
commit 91c457f02c3c52076c937559aa0586f46cf601d9
95 changes: 63 additions & 32 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -980,21 +980,21 @@ fn impl_complex_enum(
}

Ok(quote! {
#pytypeinfo
#pytypeinfo

#pyclass_impls
#pyclass_impls

#[doc(hidden)]
#[allow(non_snake_case)]
impl #cls {}
#[doc(hidden)]
#[allow(non_snake_case)]
impl #cls {}

#(#variant_cls_zsts)*
#(#variant_cls_zsts)*

#(#variant_cls_pytypeinfos)*
#(#variant_cls_pytypeinfos)*

#(#variant_cls_pyclass_impls)*
#(#variant_cls_pyclass_impls)*

#(#variant_cls_impls)*
#(#variant_cls_impls)*
})
}

Expand Down Expand Up @@ -1073,8 +1073,8 @@ fn impl_complex_enum_tuple_variant_field_getters(
variant_cls_type: &syn::Type,
variant_ident: &&Ident,
field_names: &mut Vec<Ident>,
fields_with_types: &mut Vec<TokenStream>,
) -> Result<(Vec<MethodAndMethodDef>, Vec<TokenStream>)> {
fields_types: &mut Vec<syn::Type>,
) -> Result<(Vec<MethodAndMethodDef>, Vec<syn::ImplItemFn>)> {
let Ctx { pyo3_path } = ctx;

let mut field_getters = vec![];
Expand All @@ -1083,7 +1083,6 @@ fn impl_complex_enum_tuple_variant_field_getters(
for (index, field) in variant.fields.iter().enumerate() {
let field_name = format_ident!("_{}", index);
let field_type = field.ty;
let field_with_type = quote! { #field_name : #field_type };

let field_getter =
complex_enum_variant_field_getter(&variant_cls_type, &field_name, field.span, ctx)?;
Expand All @@ -1099,7 +1098,7 @@ fn impl_complex_enum_tuple_variant_field_getters(
})
.collect();

let field_getter_impl = quote! {
let field_getter_impl: syn::ImplItemFn = parse_quote! {
fn #field_name(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<#field_type> {
match &*slf.into_super() {
#enum_name::#variant_ident ( #(#field_access_tokens), *) => Ok(val.clone()),
Expand All @@ -1109,7 +1108,7 @@ fn impl_complex_enum_tuple_variant_field_getters(
};

field_names.push(field_name);
fields_with_types.push(field_with_type);
fields_types.push(field_type.clone());
field_getters.push(field_getter);
field_getter_impls.push(field_getter_impl);
}
Expand All @@ -1125,14 +1124,12 @@ fn impl_complex_enum_tuple_variant_len(
) -> Result<(MethodAndSlotDef, syn::ImplItemFn)> {
let Ctx { pyo3_path } = ctx;

let len_method_impl = quote! {
let mut len_method_impl: syn::ImplItemFn = parse_quote! {
fn __len__(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<usize> {
Ok(#num_fields)
}
};

let mut len_method_impl: syn::ImplItemFn = syn::parse2(len_method_impl).unwrap();

let variant_len =
crate::pymethod::impl_py_slot_def(&variant_cls_type, ctx, &mut len_method_impl.sig)?;

Expand Down Expand Up @@ -1161,27 +1158,54 @@ fn impl_complex_enum_tuple_variant_getitem(
})
.collect();

let matcher = quote! {
let py = slf.py();
match idx {
#( #match_arms, )*
_ => Err(pyo3::exceptions::PyIndexError::new_err("tuple index out of range")),
}
};

let get_item_method_impl = quote! {
let mut get_item_method_impl: syn::ImplItemFn = parse_quote! {
fn __getitem__(slf: #pyo3_path::PyRef<Self>, idx: usize) -> #pyo3_path::PyResult< #pyo3_path::PyObject> {
#matcher
let py = slf.py();
match idx {
#( #match_arms, )*
_ => Err(pyo3::exceptions::PyIndexError::new_err("tuple index out of range")),
}
}
};

let mut get_item_method_impl: syn::ImplItemFn = syn::parse2(get_item_method_impl).unwrap();
let variant_getitem =
crate::pymethod::impl_py_slot_def(&variant_cls_type, ctx, &mut get_item_method_impl.sig)?;

Ok((variant_getitem, get_item_method_impl))
}

fn impl_complex_enum_tuple_variant_match_args(
ctx: &Ctx,
variant_cls_type: &syn::Type,
field_names: &mut Vec<Ident>,
) -> Result<(MethodAndMethodDef, syn::ImplItemConst)> {
let args_tp = field_names.iter().map(|_| {
quote! { &'static str }
});

let match_args_const_impl: syn::ImplItemConst = parse_quote! {
const __match_args__: ( #(#args_tp),* ) = (
#(stringify!(#field_names),)*
);
};

let spec = ConstSpec {
rust_ident: format_ident!("__match_args__"),
attributes: ConstAttributes {
is_class_attr: true,
name: Some(NameAttribute {
kw: syn::parse_quote! { name },
value: NameLitStr(format_ident!("__match_args__")),
}),
deprecations: Deprecations::new(ctx),
},
};

let variant_match_args = gen_py_const(variant_cls_type, &spec, ctx);

Ok((variant_match_args, match_args_const_impl))
}

fn impl_complex_enum_tuple_variant_cls(
enum_name: &syn::Ident,
variant: &PyClassEnumTupleVariant<'_>,
Expand All @@ -1196,16 +1220,16 @@ fn impl_complex_enum_tuple_variant_cls(

// represents the index of the field
let mut field_names: Vec<Ident> = vec![];
let mut fields_with_types: Vec<TokenStream> = vec![];
let mut field_types: Vec<syn::Type> = vec![];

let (field_getters, field_getter_impls) = impl_complex_enum_tuple_variant_field_getters(
let (mut field_getters, field_getter_impls) = impl_complex_enum_tuple_variant_field_getters(
ctx,
variant,
enum_name,
&variant_cls_type,
&variant_ident,
&mut field_names,
&mut fields_with_types,
&mut field_types,
)?;

let num_fields = variant.fields.len();
Expand All @@ -1220,11 +1244,16 @@ fn impl_complex_enum_tuple_variant_cls(

slots.push(variant_getitem);

let (variant_match_args, match_args_method_impl) =
impl_complex_enum_tuple_variant_match_args(ctx, &variant_cls_type, &mut field_names)?;

field_getters.push(variant_match_args);

let cls_impl = quote! {
#[doc(hidden)]
#[allow(non_snake_case)]
impl #variant_cls {
fn __pymethod_constructor__(py: #pyo3_path::Python<'_>, #(#fields_with_types,)*) -> #pyo3_path::PyClassInitializer<#variant_cls> {
fn __pymethod_constructor__(py: #pyo3_path::Python<'_>, #(#field_names : #field_types,)*) -> #pyo3_path::PyClassInitializer<#variant_cls> {
let base_value = #enum_name::#variant_ident ( #(#field_names,)* );
#pyo3_path::PyClassInitializer::from(base_value).add_subclass(#variant_cls)
}
Expand All @@ -1233,6 +1262,8 @@ fn impl_complex_enum_tuple_variant_cls(

#getitem_method_impl

#match_args_method_impl

#(#field_getter_impls)*
}
};
Expand Down
Loading