Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add #[pyo3(allow_threads)] to release the GIL in (async) functions #3610

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/3610.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `#[pyo3(allow_threads)]` to release the GIL in (async) functions
1 change: 1 addition & 0 deletions pyo3-macros-backend/src/attributes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use syn::{
};

pub mod kw {
syn::custom_keyword!(allow_threads);
syn::custom_keyword!(annotation);
syn::custom_keyword!(attribute);
syn::custom_keyword!(cancel_handle);
Expand Down
78 changes: 49 additions & 29 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};

use crate::utils::Ctx;
use crate::{
attributes,
attributes::{FromPyWithAttribute, TextSignatureAttribute, TextSignatureAttributeValue},
deprecations::{Deprecation, Deprecations},
params::{impl_arg_params, Holders},
Expand Down Expand Up @@ -379,6 +380,7 @@ pub struct FnSpec<'a> {
pub asyncness: Option<syn::Token![async]>,
pub unsafety: Option<syn::Token![unsafe]>,
pub deprecations: Deprecations<'a>,
pub allow_threads: Option<attributes::kw::allow_threads>,
}

pub fn parse_method_receiver(arg: &syn::FnArg) -> Result<SelfType> {
Expand Down Expand Up @@ -416,6 +418,7 @@ impl<'a> FnSpec<'a> {
text_signature,
name,
signature,
allow_threads,
..
} = options;

Expand Down Expand Up @@ -461,6 +464,7 @@ impl<'a> FnSpec<'a> {
asyncness: sig.asyncness,
unsafety: sig.unsafety,
deprecations,
allow_threads,
})
}

Expand Down Expand Up @@ -603,6 +607,21 @@ impl<'a> FnSpec<'a> {
bail_spanned!(name.span() => "`cancel_handle` may only be specified once");
}
}
if let Some(FnArg::Py(py_arg)) = self
.signature
.arguments
.iter()
.find(|arg| matches!(arg, FnArg::Py(_)))
{
ensure_spanned!(
self.asyncness.is_none(),
py_arg.ty.span() => "GIL token cannot be passed to async function"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 It looks like this already doesn't work, so this makes the error message nicer. One suggestion, given that the GIL is potentially an outdated idea:

Suggested change
py_arg.ty.span() => "GIL token cannot be passed to async function"
py_arg.ty.span() => "`Python<'_>` token cannot be passed to async functions"

);
ensure_spanned!(
self.allow_threads.is_none(),
py_arg.ty.span() => "GIL cannot be held in function annotated with `allow_threads`"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly here I think we want to just name the type:

Suggested change
py_arg.ty.span() => "GIL cannot be held in function annotated with `allow_threads`"
py_arg.ty.span() => "`Python<'_>` cannot be passed to a function annotated with `allow_threads`"

);
}

if self.asyncness.is_some() {
ensure_spanned!(
Expand All @@ -612,8 +631,21 @@ impl<'a> FnSpec<'a> {
}

let rust_call = |args: Vec<TokenStream>, holders: &mut Holders| {
let mut self_arg = || self.tp.self_arg(cls, ExtractErrorMode::Raise, holders, ctx);

let allow_threads = self.allow_threads.is_some();
let mut self_arg = || {
let self_arg = self.tp.self_arg(cls, ExtractErrorMode::Raise, holders, ctx);
if self_arg.is_empty() {
self_arg
} else {
let self_checker = holders.push_gil_refs_checker(self_arg.span());
quote! {
#pyo3_path::impl_::deprecations::inspect_type(#self_arg &#self_checker),
}
}
};
let arg_names = (0..args.len())
.map(|i| format_ident!("arg_{}", i))
.collect::<Vec<_>>();
let call = if self.asyncness.is_some() {
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
Expand All @@ -625,9 +657,6 @@ impl<'a> FnSpec<'a> {
Some(cls) => quote!(Some(<#cls as #pyo3_path::PyTypeInfo>::NAME)),
None => quote!(None),
};
let arg_names = (0..args.len())
.map(|i| format_ident!("arg_{}", i))
.collect::<Vec<_>>();
let future = match self.tp {
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => {
quote! {{
Expand All @@ -645,18 +674,7 @@ impl<'a> FnSpec<'a> {
}
_ => {
let self_arg = self_arg();
if self_arg.is_empty() {
quote! { function(#(#args),*) }
} else {
let self_checker = holders.push_gil_refs_checker(self_arg.span());
quote! {
function(
// NB #self_arg includes a comma, so none inserted here
#pyo3_path::impl_::deprecations::inspect_type(#self_arg &#self_checker),
#(#args),*
)
}
}
quote!(function(#self_arg #(#args),*))
}
};
let mut call = quote! {{
Expand All @@ -665,6 +683,7 @@ impl<'a> FnSpec<'a> {
#pyo3_path::intern!(py, stringify!(#python_name)),
#qualname_prefix,
#throw_callback,
#allow_threads,
async move { #pyo3_path::impl_::wrap::OkWrap::wrap(future.await) },
)
}};
Expand All @@ -676,20 +695,21 @@ impl<'a> FnSpec<'a> {
}};
}
call
} else {
} else if allow_threads {
let self_arg = self_arg();
if self_arg.is_empty() {
quote! { function(#(#args),*) }
let (self_arg_name, self_arg_decl) = if self_arg.is_empty() {
(quote!(), quote!())
} else {
let self_checker = holders.push_gil_refs_checker(self_arg.span());
quote! {
function(
// NB #self_arg includes a comma, so none inserted here
#pyo3_path::impl_::deprecations::inspect_type(#self_arg &#self_checker),
#(#args),*
)
}
}
(quote!(__self,), quote! { let (__self,) = (#self_arg); })
};
quote! {{
#self_arg_decl
#(let #arg_names = #args;)*
py.allow_threads(|| function(#self_arg_name #(#arg_names),*))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a question in my head here about which of the `allow_threads` variants we will use here after #3646 (cc @adamreichold). I think it would be possible to create a scoped TLS issue by setting up a TLS scope and then calling into Python. Similar thinking applies to the async case below. I guess we have to stick to the principle that soundness is important and use the "safe but slow" form for this attribute. If users want to go unsafe, they can release the GIL themselves. This feels the right way to do it, I think.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This feels the right way to do it, I think.

Yes, I think so too. Especially since we have already accumulated a significant amount of high risk changes for 0.22 and we can always optimize things in following releases.

}}
} else {
let self_arg = self_arg();
quote!(function(#self_arg #(#args),*))
};
quotes::map_result_into_ptr(quotes::ok_wrap(call, ctx), ctx)
};
Expand Down
2 changes: 2 additions & 0 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,7 @@ fn complex_enum_struct_variant_new<'a>(
asyncness: None,
unsafety: None,
deprecations: Deprecations::new(ctx),
allow_threads: None,
};

crate::pymethod::impl_py_method_def_new(&variant_cls_type, &spec, ctx)
Expand All @@ -1199,6 +1200,7 @@ fn complex_enum_variant_field_getter<'a>(
asyncness: None,
unsafety: None,
deprecations: Deprecations::new(ctx),
allow_threads: None,
};

let property_type = crate::pymethod::PropertyType::Function {
Expand Down
12 changes: 10 additions & 2 deletions pyo3-macros-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ pub struct PyFunctionOptions {
pub signature: Option<SignatureAttribute>,
pub text_signature: Option<TextSignatureAttribute>,
pub krate: Option<CrateAttribute>,
pub allow_threads: Option<attributes::kw::allow_threads>,
}

impl Parse for PyFunctionOptions {
Expand All @@ -99,7 +100,8 @@ impl Parse for PyFunctionOptions {

while !input.is_empty() {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name)
if lookahead.peek(attributes::kw::allow_threads)
|| lookahead.peek(attributes::kw::name)
|| lookahead.peek(attributes::kw::pass_module)
|| lookahead.peek(attributes::kw::signature)
|| lookahead.peek(attributes::kw::text_signature)
Expand All @@ -121,6 +123,7 @@ impl Parse for PyFunctionOptions {
}

pub enum PyFunctionOption {
AllowThreads(attributes::kw::allow_threads),
Name(NameAttribute),
PassModule(attributes::kw::pass_module),
Signature(SignatureAttribute),
Expand All @@ -131,7 +134,9 @@ pub enum PyFunctionOption {
impl Parse for PyFunctionOption {
fn parse(input: ParseStream<'_>) -> Result<Self> {
let lookahead = input.lookahead1();
if lookahead.peek(attributes::kw::name) {
if lookahead.peek(attributes::kw::allow_threads) {
input.parse().map(PyFunctionOption::AllowThreads)
} else if lookahead.peek(attributes::kw::name) {
input.parse().map(PyFunctionOption::Name)
} else if lookahead.peek(attributes::kw::pass_module) {
input.parse().map(PyFunctionOption::PassModule)
Expand Down Expand Up @@ -171,6 +176,7 @@ impl PyFunctionOptions {
}
for attr in attrs {
match attr {
PyFunctionOption::AllowThreads(allow_threads) => set_option!(allow_threads),
PyFunctionOption::Name(name) => set_option!(name),
PyFunctionOption::PassModule(pass_module) => set_option!(pass_module),
PyFunctionOption::Signature(signature) => set_option!(signature),
Expand Down Expand Up @@ -198,6 +204,7 @@ pub fn impl_wrap_pyfunction(
) -> syn::Result<TokenStream> {
check_generic(&func.sig)?;
let PyFunctionOptions {
allow_threads,
pass_module,
name,
signature,
Expand Down Expand Up @@ -247,6 +254,7 @@ pub fn impl_wrap_pyfunction(
python_name,
signature,
text_signature,
allow_threads,
asyncness: func.sig.asyncness,
unsafety: func.sig.unsafety,
deprecations: Deprecations::new(ctx),
Expand Down
1 change: 1 addition & 0 deletions pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ pub fn pymethods(attr: TokenStream, input: TokenStream) -> TokenStream {
/// | `#[pyo3(name = "...")]` | Defines the name of the function in Python. |
/// | `#[pyo3(text_signature = "...")]` | Defines the `__text_signature__` attribute of the function in Python. |
/// | `#[pyo3(pass_module)]` | Passes the module containing the function as a `&PyModule` first argument to the function. |
/// | `#[pyo3(allow_threads)]` | Releases the GIL while the function body runs or, for an `async fn`, each time the returned future is polled. |
///
/// For more on exposing functions see the [function section of the guide][1].
///
Expand Down
63 changes: 50 additions & 13 deletions src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,54 @@ use crate::{
pub(crate) mod cancel;
mod waker;

use crate::marker::Ungil;
pub use cancel::CancelHandle;

const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";

/// Future-like object that [`Coroutine`] stores behind a `dyn` pointer and polls.
///
/// Unlike [`Future::poll`], `poll` here receives the `Python<'_>` token alongside the
/// [`Waker`] and yields a `PyResult<PyObject>` directly.
trait CoroutineFuture: Send {
/// Polls the underlying future once; `waker` is the waker to notify when progress
/// can be made, and `py` is the token available for converting the output.
fn poll(self: Pin<&mut Self>, py: Python<'_>, waker: &Waker) -> Poll<PyResult<PyObject>>;
}

/// Blanket implementation: any `Send` future resolving to `Result<T, E>` can be
/// polled as a [`CoroutineFuture`], with its output converted to Python values.
impl<F, T, E> CoroutineFuture for F
where
    F: Future<Output = Result<T, E>> + Send,
    T: IntoPy<PyObject> + Send,
    E: Into<PyErr> + Send,
{
    fn poll(self: Pin<&mut Self>, py: Python<'_>, waker: &Waker) -> Poll<PyResult<PyObject>> {
        let mut cx = Context::from_waker(waker);
        // Explicitly call `Future::poll` (not `CoroutineFuture::poll`) on the wrapped future.
        match Future::poll(self, &mut cx) {
            Poll::Ready(Ok(value)) => Poll::Ready(Ok(value.into_py(py))),
            Poll::Ready(Err(error)) => Poll::Ready(Err(error.into())),
            Poll::Pending => Poll::Pending,
        }
    }
}

/// Wrapper around a future whose [`CoroutineFuture`] implementation polls it inside
/// `Python::allow_threads`.
struct AllowThreads<F> {
// The wrapped future; it is only ever accessed through a pinned projection.
future: F,
}

/// Polls the wrapped future inside `Python::allow_threads`, then converts the
/// result to Python values once the closure has returned.
impl<F, T, E> CoroutineFuture for AllowThreads<F>
where
    F: Future<Output = Result<T, E>> + Send + Ungil,
    T: IntoPy<PyObject> + Send + Ungil,
    E: Into<PyErr> + Send + Ungil,
{
    fn poll(self: Pin<&mut Self>, py: Python<'_>, waker: &Waker) -> Poll<PyResult<PyObject>> {
        // SAFETY: future field is pinned when self is; the projection only borrows
        // the field and never moves it.
        let inner = unsafe { self.map_unchecked_mut(|wrapper| &mut wrapper.future) };
        // The `Context` is built inside the closure so that only the pinned future
        // and the `&Waker` are captured.
        let outcome = py.allow_threads(|| {
            let mut cx = Context::from_waker(waker);
            Future::poll(inner, &mut cx)
        });
        match outcome {
            Poll::Ready(Ok(value)) => Poll::Ready(Ok(value.into_py(py))),
            Poll::Ready(Err(error)) => Poll::Ready(Err(error.into())),
            Poll::Pending => Poll::Pending,
        }
    }
}

/// Python coroutine wrapping a [`Future`].
#[pyclass(crate = "crate")]
pub struct Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
future: Option<Pin<Box<dyn CoroutineFuture>>>,
waker: Option<Arc<AsyncioWaker>>,
}

Expand All @@ -46,23 +83,23 @@ impl Coroutine {
name: Option<Py<PyString>>,
qualname_prefix: Option<&'static str>,
throw_callback: Option<ThrowCallback>,
allow_threads: bool,
future: F,
) -> Self
where
F: Future<Output = Result<T, E>> + Send + 'static,
T: IntoPy<PyObject>,
E: Into<PyErr>,
F: Future<Output = Result<T, E>> + Send + Ungil + 'static,
T: IntoPy<PyObject> + Send + Ungil,
E: Into<PyErr> + Send + Ungil,
{
let wrap = async move {
let obj = future.await.map_err(Into::into)?;
// SAFETY: GIL is acquired when future is polled (see `Coroutine::poll`)
Ok(obj.into_py(unsafe { Python::assume_gil_acquired() }))
};
Self {
name,
qualname_prefix,
throw_callback,
future: Some(Box::pin(wrap)),
future: Some(if allow_threads {
Box::pin(AllowThreads { future })
} else {
Box::pin(future)
}),
waker: None,
}
}
Expand All @@ -88,10 +125,10 @@ impl Coroutine {
} else {
self.waker = Some(Arc::new(AsyncioWaker::new()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can split this here and do something like

let waker = self
    .waker
    .get_or_insert_with(|| Arc::new(AsyncioWaker::new()));

after the (possible) reset, to avoid the unwrap()s below?

}
let waker = Waker::from(self.waker.clone().unwrap());
// poll the Rust future and forward its results if ready
// poll the future and forward its results if ready
// polling is UnwindSafe because the future is dropped in case of panic
let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker));
let waker = Waker::from(self.waker.clone().unwrap());
let poll = || future_rs.as_mut().poll(py, &waker);
match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
Ok(Poll::Ready(res)) => {
self.close();
Expand Down
Loading
Loading