diff --git a/newsfragments/3287.added.md b/newsfragments/3287.added.md new file mode 100644 index 00000000000..bde61a4b506 --- /dev/null +++ b/newsfragments/3287.added.md @@ -0,0 +1 @@ +`#[new]` methods may now return `Py` in order to return existing instances diff --git a/src/pyclass_init.rs b/src/pyclass_init.rs index 57f86665843..534e5fedd5c 100644 --- a/src/pyclass_init.rs +++ b/src/pyclass_init.rs @@ -1,7 +1,7 @@ //! Contains initialization utilities for `#[pyclass]`. use crate::callback::IntoPyCallbackOutput; use crate::impl_::pyclass::{PyClassBaseType, PyClassDict, PyClassThreadChecker, PyClassWeakRef}; -use crate::{ffi, PyCell, PyClass, PyErr, PyResult, Python}; +use crate::{ffi, IntoPyPointer, Py, PyCell, PyClass, PyErr, PyResult, Python}; use crate::{ ffi::PyTypeObject, pycell::{ @@ -134,9 +134,14 @@ impl PyObjectInit for PyNativeTypeInitializer { /// ); /// }); /// ``` -pub struct PyClassInitializer { - init: T, - super_init: ::Initializer, +pub struct PyClassInitializer(PyClassInitializerImpl); + +enum PyClassInitializerImpl { + Existing(Py), + New { + init: T, + super_init: ::Initializer, + }, } impl PyClassInitializer { @@ -144,7 +149,7 @@ impl PyClassInitializer { /// /// It is recommended to use `add_subclass` instead of this method for most usage. pub fn new(init: T, super_init: ::Initializer) -> Self { - Self { init, super_init } + Self(PyClassInitializerImpl::New { init, super_init }) } /// Constructs a new initializer from an initializer for the base class. @@ -242,13 +247,18 @@ impl PyObjectInit for PyClassInitializer { contents: MaybeUninit>, } - let obj = self.super_init.into_new_object(py, subtype)?; + let (init, super_init) = match self.0 { + PyClassInitializerImpl::Existing(value) => return Ok(value.into_ptr()), + PyClassInitializerImpl::New { init, super_init } => (init, super_init), + }; + + let obj = super_init.into_new_object(py, subtype)?; let cell: *mut PartiallyInitializedPyCell = obj as _; std::ptr::write( (*cell).contents.as_mut_ptr(), PyCellContents { - value: ManuallyDrop::new(UnsafeCell::new(self.init)), + value: ManuallyDrop::new(UnsafeCell::new(init)), borrow_checker: ::Storage::new(), thread_checker: T::ThreadChecker::new(), dict: T::Dict::INIT, @@ -284,6 +294,13 @@ where } } +impl From> for PyClassInitializer { + #[inline] + fn from(value: Py) -> PyClassInitializer { + PyClassInitializer(PyClassInitializerImpl::Existing(value)) + } +} + // Implementation used by proc macros to allow anything convertible to PyClassInitializer to be // the return value of pyclass #[new] method (optionally wrapped in `Result`). impl IntoPyCallbackOutput> for U diff --git a/tests/test_class_new.rs b/tests/test_class_new.rs index b9b0d152086..ff159c610f8 100644 --- a/tests/test_class_new.rs +++ b/tests/test_class_new.rs @@ -2,6 +2,7 @@ use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use pyo3::sync::GILOnceCell; use pyo3::types::IntoPyDict; #[pyclass] @@ -204,3 +205,62 @@ fn new_with_custom_error() { assert_eq!(err.to_string(), "ValueError: custom error"); }); } + +#[pyclass] +struct NewExisting { + #[pyo3(get)] + num: usize, +} + +#[pymethods] +impl NewExisting { + #[new] + fn new(py: pyo3::Python<'_>, val: usize) -> pyo3::Py { + static PRE_BUILT: GILOnceCell<[pyo3::Py; 2]> = GILOnceCell::new(); + let existing = PRE_BUILT.get_or_init(py, || { + [ + pyo3::PyCell::new(py, NewExisting { num: 0 }) + .unwrap() + .into(), + pyo3::PyCell::new(py, NewExisting { num: 1 }) + .unwrap() + .into(), + ] + }); + + if val < existing.len() { + return existing[val].clone_ref(py); + } + + pyo3::PyCell::new(py, NewExisting { num: val }) + .unwrap() + .into() + } +} + +#[test] +fn test_new_existing() { + Python::with_gil(|py| { + let typeobj = py.get_type::(); + + let obj1 = typeobj.call1((0,)).unwrap(); + let obj2 = typeobj.call1((0,)).unwrap(); + let obj3 = typeobj.call1((1,)).unwrap(); + let obj4 = typeobj.call1((1,)).unwrap(); + let obj5 = typeobj.call1((2,)).unwrap(); + let obj6 = typeobj.call1((2,)).unwrap(); + + assert!(obj1.getattr("num").unwrap().extract::().unwrap() == 0); + assert!(obj2.getattr("num").unwrap().extract::().unwrap() == 0); + assert!(obj3.getattr("num").unwrap().extract::().unwrap() == 1); + assert!(obj4.getattr("num").unwrap().extract::().unwrap() == 1); + assert!(obj5.getattr("num").unwrap().extract::().unwrap() == 2); + assert!(obj6.getattr("num").unwrap().extract::().unwrap() == 2); + + assert!(obj1.is(obj2)); + assert!(obj3.is(obj4)); + assert!(!obj1.is(obj3)); + assert!(!obj1.is(obj5)); + assert!(!obj5.is(obj6)); + }); +}