Skip to content

Commit

Permalink
feat: add #[pyo3(allow_threads)] to release the GIL in (async) functions
Browse files Browse the repository at this point in the history
  • Loading branch information
wyfo committed Apr 6, 2024
1 parent a4aea23 commit eab7a6b
Show file tree
Hide file tree
Showing 16 changed files with 387 additions and 133 deletions.
1 change: 1 addition & 0 deletions newsfragments/3610.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `#[pyo3(allow_threads)]` to release the GIL in (async) functions
1 change: 1 addition & 0 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use syn::{
};

pub mod kw {
syn::custom_keyword!(allow_threads);
syn::custom_keyword!(annotation);
syn::custom_keyword!(attribute);
syn::custom_keyword!(cancel_handle);
Expand Down
75 changes: 49 additions & 26 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};

use crate::utils::Ctx;
use crate::{
attributes,
attributes::{TextSignatureAttribute, TextSignatureAttributeValue},
deprecations::{Deprecation, Deprecations},
params::{impl_arg_params, Holders},
Expand Down Expand Up @@ -278,6 +279,7 @@ pub struct FnSpec<'a> {
pub asyncness: Option<syn::Token![async]>,
pub unsafety: Option<syn::Token![unsafe]>,
pub deprecations: Deprecations<'a>,
pub allow_threads: Option<attributes::kw::allow_threads>,
}

pub fn parse_method_receiver(arg: &syn::FnArg) -> Result<SelfType> {
Expand Down Expand Up @@ -315,6 +317,7 @@ impl<'a> FnSpec<'a> {
text_signature,
name,
signature,
allow_threads,
..
} = options;

Expand Down Expand Up @@ -360,6 +363,7 @@ impl<'a> FnSpec<'a> {
asyncness: sig.asyncness,
unsafety: sig.unsafety,
deprecations,
allow_threads,
})
}

Expand Down Expand Up @@ -500,6 +504,16 @@ impl<'a> FnSpec<'a> {
bail_spanned!(arg2.name.span() => "`cancel_handle` may only be specified once");
}
}
if let Some(py_arg) = self.signature.arguments.iter().find(|arg| arg.py) {
ensure_spanned!(
self.asyncness.is_none(),
py_arg.ty.span() => "GIL token cannot be passed to async function"
);
ensure_spanned!(
self.allow_threads.is_none(),
py_arg.ty.span() => "GIL cannot be held in function annotated with `allow_threads`"
);
}

if self.asyncness.is_some() {
ensure_spanned!(
Expand All @@ -509,8 +523,18 @@ impl<'a> FnSpec<'a> {
}

let rust_call = |args: Vec<TokenStream>, holders: &mut Holders| {
let mut self_arg = || self.tp.self_arg(cls, ExtractErrorMode::Raise, holders, ctx);

let allow_threads = self.allow_threads.is_some();
let mut self_arg = || {
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise, holders, ctx);
if self_arg.is_empty() {
self_arg
} else {
let self_checker = holders.push_gil_refs_checker(self_arg.span());
quote! {
#pyo3_path::impl_::deprecations::inspect_type(#self_arg &#self_checker),
}
}
};
let call = if self.asyncness.is_some() {
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
Expand Down Expand Up @@ -542,18 +566,7 @@ impl<'a> FnSpec<'a> {
}
_ => {
let self_arg = self_arg();
if self_arg.is_empty() {
quote! { function(#(#args),*) }
} else {
let self_checker = holders.push_gil_refs_checker(self_arg.span());
quote! {
function(
// NB #self_arg includes a comma, so none inserted here
#pyo3_path::impl_::deprecations::inspect_type(#self_arg &#self_checker),
#(#args),*
)
}
}
quote!(function(#self_arg #(#args),*))
}
};
let mut call = quote! {{
Expand All @@ -562,6 +575,7 @@ impl<'a> FnSpec<'a> {
#pyo3_path::intern!(py, stringify!(#python_name)),
#qualname_prefix,
#throw_callback,
#allow_threads,
async move { #pyo3_path::impl_::wrap::OkWrap::wrap(future.await) },
)
}};
Expand All @@ -573,20 +587,29 @@ impl<'a> FnSpec<'a> {
}};
}
call
} else {
} else if allow_threads {
let self_arg = self_arg();
if self_arg.is_empty() {
quote! { function(#(#args),*) }
let (self_arg_name, self_arg_decl) = if self_arg.is_empty() {
(quote!(), quote!())
} else {
let self_checker = holders.push_gil_refs_checker(self_arg.span());
quote! {
function(
// NB #self_arg includes a comma, so none inserted here
#pyo3_path::impl_::deprecations::inspect_type(#self_arg &#self_checker),
#(#args),*
)
}
}
(quote!(__self,), quote! { let __self = (#self_arg).0; })
};
let arg_names: Vec<Ident> = (0..args.len())
.map(|i| format_ident!("__arg{}", i))
.collect();
let arg_decls: Vec<TokenStream> = args
.into_iter()
.zip(&arg_names)
.map(|(arg, name)| quote! { let #name = #arg; })
.collect();
quote! {{
#self_arg_decl
#(#arg_decls)*
py.allow_threads(|| function(#self_arg_name #(#arg_names),*))
}}
} else {
let self_arg = self_arg();
quote!(function(#self_arg #(#args),*))
};
quotes::map_result_into_ptr(quotes::ok_wrap(call, ctx), ctx)
};
Expand Down
2 changes: 2 additions & 0 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1188,6 +1188,7 @@ fn complex_enum_struct_variant_new<'a>(
asyncness: None,
unsafety: None,
deprecations: Deprecations::new(ctx),
allow_threads: None,
};

crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec, ctx)
Expand All @@ -1213,6 +1214,7 @@ fn complex_enum_variant_field_getter<'a>(
asyncness: None,
unsafety: None,
deprecations: Deprecations::new(ctx),
allow_threads: None,
};

let property_type = crate::pymethod::PropertyType::Function {
Expand Down
12 changes: 10 additions & 2 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ pub struct PyFunctionOptions {
pub signature: Option<SignatureAttribute>,
pub text_signature: Option<TextSignatureAttribute>,
pub krate: Option<CrateAttribute>,
pub allow_threads: Option<attributes::kw::allow_threads>,
}

impl Parse for PyFunctionOptions {
Expand All @@ -99,7 +100,8 @@ impl Parse for PyFunctionOptions {

while !input.is_empty() {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name)
if lookahead.peek(attributes::kw::allow_threads)
|| lookahead.peek(attributes::kw::name)
|| lookahead.peek(attributes::kw::pass_module)
|| lookahead.peek(attributes::kw::signature)
|| lookahead.peek(attributes::kw::text_signature)
Expand All @@ -121,6 +123,7 @@ impl Parse for PyFunctionOptions {
}

pub enum PyFunctionOption {
AllowThreads(attributes::kw::allow_threads),
Name(NameAttribute),
PassModule(attributes::kw::pass_module),
Signature(SignatureAttribute),
Expand All @@ -131,7 +134,9 @@ pub enum PyFunctionOption {
impl Parse for PyFunctionOption {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name) {
if lookahead.peek(attributes::kw::allow_threads) {
input.parse().map(PyFunctionOption::AllowThreads)
} else if lookahead.peek(attributes::kw::name) {
input.parse().map(PyFunctionOption::Name)
} else if lookahead.peek(attributes::kw::pass_module) {
input.parse().map(PyFunctionOption::PassModule)
Expand Down Expand Up @@ -171,6 +176,7 @@ impl PyFunctionOptions {
}
for attr in attrs {
match attr {
PyFunctionOption::AllowThreads(allow_threads) => set_option!(allow_threads),
PyFunctionOption::Name(name) => set_option!(name),
PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
PyFunctionOption::Signature(signature) => set_option!(signature),
Expand Down Expand Up @@ -198,6 +204,7 @@ pub fn impl_wrap_pyfunction(
) -> syn::Result<TokenStream> {
check_generic(&func.sig)?;
let PyFunctionOptions {
allow_threads,
pass_module,
name,
signature,
Expand Down Expand Up @@ -247,6 +254,7 @@ pub fn impl_wrap_pyfunction(
python_name,
signature,
text_signature,
allow_threads,
asyncness: func.sig.asyncness,
unsafety: func.sig.unsafety,
deprecations: Deprecations::new(ctx),
Expand Down
1 change: 1 addition & 0 deletions pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ pub fn pymethods(attr: TokenStream, input: TokenStream) -> TokenStream {
/// | `#[pyo3(name = "...")]` | Defines the name of the function in Python. |
/// | `#[pyo3(text_signature = "...")]` | Defines the `__text_signature__` attribute of the function in Python. |
/// | `#[pyo3(pass_module)]` | Passes the module containing the function as a `&PyModule` first argument to the function. |
/// | `#[pyo3(allow_threads)]` | Release the GIL in the function body, or each time the returned future is polled for `async fn` |
///
/// For more on exposing functions see the [function section of the guide][1].
///
Expand Down
63 changes: 50 additions & 13 deletions src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,54 @@ use crate::{
pub(crate) mod cancel;
mod waker;

use crate::marker::Ungil;
pub use cancel::CancelHandle;

const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";

trait CoroutineFuture: Send {
fn poll(self: Pin<&mut Self>, py: Python<'_>, waker: &Waker) -> Poll<PyResult<PyObject>>;
}

impl<F, T, E> CoroutineFuture for F
where
F: Future<Output = Result<T, E>> + Send,
T: IntoPy<PyObject> + Send,
E: Into<PyErr> + Send,
{
fn poll(self: Pin<&mut Self>, py: Python<'_>, waker: &Waker) -> Poll<PyResult<PyObject>> {
self.poll(&mut Context::from_waker(waker))
.map_ok(|obj| obj.into_py(py))
.map_err(Into::into)
}
}

struct AllowThreads<F> {
future: F,
}

impl<F, T, E> CoroutineFuture for AllowThreads<F>
where
F: Future<Output = Result<T, E>> + Send + Ungil,
T: IntoPy<PyObject> + Send + Ungil,
E: Into<PyErr> + Send + Ungil,
{
fn poll(self: Pin<&mut Self>, py: Python<'_>, waker: &Waker) -> Poll<PyResult<PyObject>> {
// SAFETY: future field is pinned when self is
let future = unsafe { self.map_unchecked_mut(|a| &mut a.future) };
py.allow_threads(|| future.poll(&mut Context::from_waker(waker)))
.map_ok(|obj| obj.into_py(py))
.map_err(Into::into)
}
}

/// Python coroutine wrapping a [`Future`].
#[pyclass(crate = "crate")]
pub struct Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
future: Option<Pin<Box<dyn CoroutineFuture>>>,
waker: Option<Arc<AsyncioWaker>>,
}

Expand All @@ -46,23 +83,23 @@ impl Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
allow_threads: bool,
future: F,
) -> Self
where
F: Future<Output = Result<T, E>> + Send + 'static,
T: IntoPy<PyObject>,
E: Into<PyErr>,
F: Future<Output = Result<T, E>> + Send + Ungil + 'static,
T: IntoPy<PyObject> + Send + Ungil,
E: Into<PyErr> + Send + Ungil,
{
let wrap = async move {
let obj = future.await.map_err(Into::into)?;
// SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`)
Ok(obj.into_py(unsafe { Python::assume_gil_acquired() }))
};
Self {
name,
qualname_prefix,
throw_callback,
future: Some(Box::pin(wrap)),
future: Some(if allow_threads {
Box::pin(AllowThreads { future })
} else {
Box::pin(future)
}),
waker: None,
}
}
Expand All @@ -88,10 +125,10 @@ impl Coroutine {
} else {
self.waker = Some(Arc::new(AsyncioWaker::new()));
}
let waker = Waker::from(self.waker.clone().unwrap());
// poll the Rust future and forward its results if ready
// poll the future and forward its results if ready
// polling is UnwindSafe because the future is dropped in case of panic
let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker));
let waker = Waker::from(self.waker.clone().unwrap());
let poll = || future_rs.as_mut().poll(py, &waker);
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
Ok(Poll::Ready(res)) => {
self.close();
Expand Down
Loading

0 comments on commit eab7a6b

Please sign in to comment.