Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: add #[pyo3(allow_threads)] to release the GIL in (async) functions
  • Loading branch information
wyfo committed Apr 25, 2024
commit 08948df6c10682f111fab3caa47c1530bee0d525
1 change: 1 addition & 0 deletions newsfragments/3610.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `#[pyo3(allow_threads)]` to release the GIL in (async) functions
1 change: 1 addition & 0 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use syn::{
};

pub mod kw {
syn::custom_keyword!(allow_threads);
syn::custom_keyword!(annotation);
syn::custom_keyword!(attribute);
syn::custom_keyword!(cancel_handle);
Expand Down
78 changes: 49 additions & 29 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};

use crate::utils::Ctx;
use crate::{
attributes,
attributes::{FromPyWithAttribute, TextSignatureAttribute, TextSignatureAttributeValue},
deprecations::{Deprecation, Deprecations},
params::{impl_arg_params, Holders},
Expand Down Expand Up @@ -379,6 +380,7 @@ pub struct FnSpec<'a> {
pub asyncness: Option<syn::Token![async]>,
pub unsafety: Option<syn::Token![unsafe]>,
pub deprecations: Deprecations<'a>,
pub allow_threads: Option<attributes::kw::allow_threads>,
}

pub fn parse_method_receiver(arg: &syn::FnArg) -> Result<SelfType> {
Expand Down Expand Up @@ -416,6 +418,7 @@ impl<'a> FnSpec<'a> {
text_signature,
name,
signature,
allow_threads,
..
} = options;

Expand Down Expand Up @@ -461,6 +464,7 @@ impl<'a> FnSpec<'a> {
asyncness: sig.asyncness,
unsafety: sig.unsafety,
deprecations,
allow_threads,
})
}

Expand Down Expand Up @@ -603,6 +607,21 @@ impl<'a> FnSpec<'a> {
bail_spanned!(name.span() => "`cancel_handle` may only be specified once");
}
}
if let Some(FnArg::Py(py_arg)) = self
.signature
.arguments
.iter()
.find(|arg| matches!(arg, FnArg::Py(_)))
{
ensure_spanned!(
self.asyncness.is_none(),
py_arg.ty.span() => "GIL token cannot be passed to async function"
);
ensure_spanned!(
self.allow_threads.is_none(),
py_arg.ty.span() => "GIL cannot be held in function annotated with `allow_threads`"
);
}

if self.asyncness.is_some() {
ensure_spanned!(
Expand All @@ -612,8 +631,21 @@ impl<'a> FnSpec<'a> {
}

let rust_call = |args: Vec<TokenStream>, holders: &mut Holders| {
let mut self_arg = || self.tp.self_arg(cls, ExtractErrorMode::Raise, holders, ctx);

let allow_threads = self.allow_threads.is_some();
let mut self_arg = || {
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise, holders, ctx);
if self_arg.is_empty() {
self_arg
} else {
let self_checker = holders.push_gil_refs_checker(self_arg.span());
quote! {
#pyo3_path::impl_::deprecations::inspect_type(#self_arg &#self_checker),
}
}
};
let arg_names = (0..args.len())
.map(|i| format_ident!("arg_{}", i))
.collect::<Vec<_>>();
let call = if self.asyncness.is_some() {
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
Expand All @@ -625,9 +657,6 @@ impl<'a> FnSpec<'a> {
Some(cls) => quote!(Some(<#cls as #pyo3_path::PyTypeInfo>::NAME)),
None => quote!(None),
};
let arg_names = (0..args.len())
.map(|i| format_ident!("arg_{}", i))
.collect::<Vec<_>>();
let future = match self.tp {
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => {
quote! {{
Expand All @@ -645,18 +674,7 @@ impl<'a> FnSpec<'a> {
}
_ => {
let self_arg = self_arg();
if self_arg.is_empty() {
quote! { function(#(#args),*) }
} else {
let self_checker = holders.push_gil_refs_checker(self_arg.span());
quote! {
function(
// NB #self_arg includes a comma, so none inserted here
#pyo3_path::impl_::deprecations::inspect_type(#self_arg &#self_checker),
#(#args),*
)
}
}
quote!(function(#self_arg #(#args),*))
}
};
let mut call = quote! {{
Expand All @@ -665,6 +683,7 @@ impl<'a> FnSpec<'a> {
#pyo3_path::intern!(py, stringify!(#python_name)),
#qualname_prefix,
#throw_callback,
#allow_threads,
async move { #pyo3_path::impl_::wrap::OkWrap::wrap(future.await) },
)
}};
Expand All @@ -676,20 +695,21 @@ impl<'a> FnSpec<'a> {
}};
}
call
} else {
} else if allow_threads {
let self_arg = self_arg();
if self_arg.is_empty() {
quote! { function(#(#args),*) }
let (self_arg_name, self_arg_decl) = if self_arg.is_empty() {
(quote!(), quote!())
} else {
let self_checker = holders.push_gil_refs_checker(self_arg.span());
quote! {
function(
// NB #self_arg includes a comma, so none inserted here
#pyo3_path::impl_::deprecations::inspect_type(#self_arg &#self_checker),
#(#args),*
)
}
}
(quote!(__self,), quote! { let (__self,) = (#self_arg); })
};
quote! {{
#self_arg_decl
#(let #arg_names = #args;)*
py.allow_threads(|| function(#self_arg_name #(#arg_names),*))
}}
} else {
let self_arg = self_arg();
quote!(function(#self_arg #(#args),*))
};
quotes::map_result_into_ptr(quotes::ok_wrap(call, ctx), ctx)
};
Expand Down
2 changes: 2 additions & 0 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,7 @@ fn complex_enum_struct_variant_new<'a>(
asyncness: None,
unsafety: None,
deprecations: Deprecations::new(ctx),
allow_threads: None,
};

crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec, ctx)
Expand All @@ -1199,6 +1200,7 @@ fn complex_enum_variant_field_getter<'a>(
asyncness: None,
unsafety: None,
deprecations: Deprecations::new(ctx),
allow_threads: None,
};

let property_type = crate::pymethod::PropertyType::Function {
Expand Down
12 changes: 10 additions & 2 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ pub struct PyFunctionOptions {
pub signature: Option<SignatureAttribute>,
pub text_signature: Option<TextSignatureAttribute>,
pub krate: Option<CrateAttribute>,
pub allow_threads: Option<attributes::kw::allow_threads>,
}

impl Parse for PyFunctionOptions {
Expand All @@ -99,7 +100,8 @@ impl Parse for PyFunctionOptions {

while !input.is_empty() {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name)
if lookahead.peek(attributes::kw::allow_threads)
|| lookahead.peek(attributes::kw::name)
|| lookahead.peek(attributes::kw::pass_module)
|| lookahead.peek(attributes::kw::signature)
|| lookahead.peek(attributes::kw::text_signature)
Expand All @@ -121,6 +123,7 @@ impl Parse for PyFunctionOptions {
}

pub enum PyFunctionOption {
AllowThreads(attributes::kw::allow_threads),
Name(NameAttribute),
PassModule(attributes::kw::pass_module),
Signature(SignatureAttribute),
Expand All @@ -131,7 +134,9 @@ pub enum PyFunctionOption {
impl Parse for PyFunctionOption {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name) {
if lookahead.peek(attributes::kw::allow_threads) {
input.parse().map(PyFunctionOption::AllowThreads)
} else if lookahead.peek(attributes::kw::name) {
input.parse().map(PyFunctionOption::Name)
} else if lookahead.peek(attributes::kw::pass_module) {
input.parse().map(PyFunctionOption::PassModule)
Expand Down Expand Up @@ -171,6 +176,7 @@ impl PyFunctionOptions {
}
for attr in attrs {
match attr {
PyFunctionOption::AllowThreads(allow_threads) => set_option!(allow_threads),
PyFunctionOption::Name(name) => set_option!(name),
PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
PyFunctionOption::Signature(signature) => set_option!(signature),
Expand Down Expand Up @@ -198,6 +204,7 @@ pub fn impl_wrap_pyfunction(
) -> syn::Result<TokenStream> {
check_generic(&func.sig)?;
let PyFunctionOptions {
allow_threads,
pass_module,
name,
signature,
Expand Down Expand Up @@ -247,6 +254,7 @@ pub fn impl_wrap_pyfunction(
python_name,
signature,
text_signature,
allow_threads,
asyncness: func.sig.asyncness,
unsafety: func.sig.unsafety,
deprecations: Deprecations::new(ctx),
Expand Down
1 change: 1 addition & 0 deletions pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ pub fn pymethods(attr: TokenStream, input: TokenStream) -> TokenStream {
/// | `#[pyo3(name = "...")]` | Defines the name of the function in Python. |
/// | `#[pyo3(text_signature = "...")]` | Defines the `__text_signature__` attribute of the function in Python. |
/// | `#[pyo3(pass_module)]` | Passes the module containing the function as a `&PyModule` first argument to the function. |
/// | `#[pyo3(allow_threads)]` | Release the GIL in the function body, or each time the returned future is polled for `async fn` |
///
/// For more on exposing functions see the [function section of the guide][1].
///
Expand Down
63 changes: 50 additions & 13 deletions src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,54 @@ use crate::{
pub(crate) mod cancel;
mod waker;

use crate::marker::Ungil;
pub use cancel::CancelHandle;

const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";

trait CoroutineFuture: Send {
fn poll(self: Pin<&mut Self>, py: Python<'_>, waker: &Waker) -> Poll<PyResult<PyObject>>;
}

impl<F, T, E> CoroutineFuture for F
where
F: Future<Output = Result<T, E>> + Send,
T: IntoPy<PyObject> + Send,
E: Into<PyErr> + Send,
{
fn poll(self: Pin<&mut Self>, py: Python<'_>, waker: &Waker) -> Poll<PyResult<PyObject>> {
self.poll(&mut Context::from_waker(waker))
.map_ok(|obj| obj.into_py(py))
.map_err(Into::into)
}
}

struct AllowThreads<F> {
future: F,
}

impl<F, T, E> CoroutineFuture for AllowThreads<F>
where
F: Future<Output = Result<T, E>> + Send + Ungil,
T: IntoPy<PyObject> + Send + Ungil,
E: Into<PyErr> + Send + Ungil,
{
fn poll(self: Pin<&mut Self>, py: Python<'_>, waker: &Waker) -> Poll<PyResult<PyObject>> {
// SAFETY: future field is pinned when self is
let future = unsafe { self.map_unchecked_mut(|a| &mut a.future) };
py.allow_threads(|| future.poll(&mut Context::from_waker(waker)))
.map_ok(|obj| obj.into_py(py))
.map_err(Into::into)
}
}

/// Python coroutine wrapping a [`Future`].
#[pyclass(crate = "crate")]
pub struct Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
future: Option<Pin<Box<dyn CoroutineFuture>>>,
waker: Option<Arc<AsyncioWaker>>,
}

Expand All @@ -46,23 +83,23 @@ impl Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
allow_threads: bool,
future: F,
) -> Self
where
F: Future<Output = Result<T, E>> + Send + 'static,
T: IntoPy<PyObject>,
E: Into<PyErr>,
F: Future<Output = Result<T, E>> + Send + Ungil + 'static,
T: IntoPy<PyObject> + Send + Ungil,
E: Into<PyErr> + Send + Ungil,
{
let wrap = async move {
let obj = future.await.map_err(Into::into)?;
// SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`)
Ok(obj.into_py(unsafe { Python::assume_gil_acquired() }))
};
Self {
name,
qualname_prefix,
throw_callback,
future: Some(Box::pin(wrap)),
future: Some(if allow_threads {
Box::pin(AllowThreads { future })
} else {
Box::pin(future)
}),
waker: None,
}
}
Expand All @@ -88,10 +125,10 @@ impl Coroutine {
} else {
self.waker = Some(Arc::new(AsyncioWaker::new()));
}
let waker = Waker::from(self.waker.clone().unwrap());
// poll the Rust future and forward its results if ready
// poll the future and forward its results if ready
// polling is UnwindSafe because the future is dropped in case of panic
let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker));
let waker = Waker::from(self.waker.clone().unwrap());
let poll = || future_rs.as_mut().poll(py, &waker);
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
Ok(Poll::Ready(res)) => {
self.close();
Expand Down
Loading