From 03dc96bff17bde94c98860996c6746acf3a64470 Mon Sep 17 00:00:00 2001 From: Georg Brandl Date: Fri, 25 Feb 2022 20:39:45 +0100 Subject: [PATCH] Py/PyAny: remove PartialEq impl and add is() (#2183) --- CHANGELOG.md | 6 ++++-- guide/src/migration.md | 13 +++++++++++++ src/conversion.rs | 2 +- src/err/mod.rs | 35 +++++++++++++++++++++------------- src/impl_/extract_argument.rs | 2 +- src/instance.rs | 16 +++++++++------- src/types/any.rs | 9 +++++++++ src/types/boolobject.rs | 4 ++-- src/types/dict.rs | 8 ++++---- src/types/iterator.rs | 2 +- src/types/mod.rs | 9 --------- src/types/sequence.rs | 4 ++-- src/types/string.rs | 6 +++--- tests/test_sequence.rs | 10 ++++++---- tests/test_sequence_pyproto.rs | 10 ++++++---- tests/test_serde.rs | 4 ++-- 16 files changed, 85 insertions(+), 55 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ae1684a413..b2ae4ce53e2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,7 +35,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `PyMapping::contains` method (`in` operator for `PyMapping`). [#2133](https://github.com/PyO3/pyo3/pull/2133) - Add garbage collection magic methods `__traverse__` and `__clear__` to `#[pymethods]`. [#2159](https://github.com/PyO3/pyo3/pull/2159) - Add support for `from_py_with` on struct tuples and enums to override the default from-Python conversion. [#2181](https://github.com/PyO3/pyo3/pull/2181) -- Add `eq`, `ne`, `lt`, `le`, `gt`, `ge` methods to `PyAny` that wrap `rich_compare`. +- Add `eq`, `ne`, `lt`, `le`, `gt`, `ge` methods to `PyAny` that wrap `rich_compare`. [#2175](https://github.com/PyO3/pyo3/pull/2175) +- Add `Py::is` and `PyAny::is` methods to check for object identity. [#2183](https://github.com/PyO3/pyo3/pull/2183) ### Changed @@ -81,7 +82,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed - Remove all functionality deprecated in PyO3 0.14. [#2007](https://github.com/PyO3/pyo3/pull/2007) -- Remove `Default` impl for `PyMethodDef` [2166](https://github.com/PyO3/pyo3/pull/2166) +- Remove `Default` impl for `PyMethodDef`. [#2166](https://github.com/PyO3/pyo3/pull/2166) +- Remove `PartialEq` impl for `Py` and `PyAny` (use the new `is()` instead). [#2183](https://github.com/PyO3/pyo3/pull/2183) ### Fixed diff --git a/guide/src/migration.md b/guide/src/migration.md index 19fb91bb3db..f7cd6bf39a8 100644 --- a/guide/src/migration.md +++ b/guide/src/migration.md @@ -62,6 +62,19 @@ impl MyClass { } ``` +### Removed `PartialEq` for object wrappers + +The Python object wrappers `Py` and `PyAny` had implementations of `PartialEq` +so that `object_a == object_b` would compare the Python objects for pointer +equality, which corresponds to the `is` operator, not the `==` operator in +Python. This has been removed in favor of a new method: use +`object_a.is(object_b)`. This also has the advantage of not requiring the same +wrapper type for `object_a` and `object_b`; you can now directly compare a +`Py` with a `&PyAny` without having to convert. + +To check for Python object equality (the Python `==` operator), use the new +method `eq()`. + ### Container magic methods now match Python behavior In PyO3 0.15, `__getitem__`, `__setitem__` and `__delitem__` in `#[pymethods]` would generate only the _mapping_ implementation for a `#[pyclass]`. To match the Python behavior, these methods now generate both the _mapping_ **and** _sequence_ implementations. diff --git a/src/conversion.rs b/src/conversion.rs index 6d9e74b0120..276f52d10b5 100644 --- a/src/conversion.rs +++ b/src/conversion.rs @@ -561,7 +561,7 @@ mod tests { Python::with_gil(|py| { let list = PyList::new(py, &[1, 2, 3]); let val = unsafe { ::try_from_unchecked(list.as_ref()) }; - assert_eq!(list, val); + assert!(list.is(val)); }); } diff --git a/src/err/mod.rs b/src/err/mod.rs index f45f4952e75..286c02e03d1 100644 --- a/src/err/mod.rs +++ b/src/err/mod.rs @@ -187,7 +187,7 @@ impl PyErr { /// /// Python::with_gil(|py| { /// let err: PyErr = PyTypeError::new_err(("some type error",)); - /// assert_eq!(err.get_type(py), PyType::new::(py)); + /// assert!(err.get_type(py).is(PyType::new::(py))); /// }); /// ``` pub fn get_type<'py>(&'py self, py: Python<'py>) -> &'py PyType { @@ -231,7 +231,7 @@ impl PyErr { /// /// Python::with_gil(|py| { /// let err = PyTypeError::new_err(("some type error",)); - /// assert_eq!(err.traceback(py), None); + /// assert!(err.traceback(py).is_none()); /// }); /// ``` pub fn traceback<'py>(&'py self, py: Python<'py>) -> Option<&'py PyTraceback> { @@ -469,9 +469,12 @@ impl PyErr { /// Python::with_gil(|py| { /// let err: PyErr = PyTypeError::new_err(("some type error",)); /// let err_clone = err.clone_ref(py); - /// assert_eq!(err.get_type(py), err_clone.get_type(py)); - /// assert_eq!(err.value(py), err_clone.value(py)); - /// assert_eq!(err.traceback(py), err_clone.traceback(py)); + /// assert!(err.get_type(py).is(err_clone.get_type(py))); + /// assert!(err.value(py).is(err_clone.value(py))); + /// match err.traceback(py) { + /// None => assert!(err_clone.traceback(py).is_none()), + /// Some(tb) => assert!(err_clone.traceback(py).unwrap().is(tb)), + /// } /// }); /// ``` #[inline] @@ -706,7 +709,7 @@ fn exceptions_must_derive_from_base_exception(py: Python) -> PyErr { mod tests { use super::PyErrState; use crate::exceptions; - use crate::{PyErr, Python}; + use crate::{AsPyPointer, PyErr, Python}; #[test] fn no_error() { @@ -857,16 +860,22 @@ mod tests { fn deprecations() { let err = exceptions::PyValueError::new_err("an error"); Python::with_gil(|py| { - assert_eq!(err.ptype(py), err.get_type(py)); - assert_eq!(err.pvalue(py), err.value(py)); - assert_eq!(err.instance(py), err.value(py)); - assert_eq!(err.ptraceback(py), err.traceback(py)); + assert_eq!(err.ptype(py).as_ptr(), err.get_type(py).as_ptr()); + assert_eq!(err.pvalue(py).as_ptr(), err.value(py).as_ptr()); + assert_eq!(err.instance(py).as_ptr(), err.value(py).as_ptr()); + assert_eq!( + err.ptraceback(py).map(|t| t.as_ptr()), + err.traceback(py).map(|t| t.as_ptr()) + ); assert_eq!( - err.clone_ref(py).into_instance(py).as_ref(py), - err.value(py) + err.clone_ref(py).into_instance(py).as_ref(py).as_ptr(), + err.value(py).as_ptr() + ); + assert_eq!( + PyErr::from_instance(err.value(py)).value(py).as_ptr(), + err.value(py).as_ptr() ); - assert_eq!(PyErr::from_instance(err.value(py)).value(py), err.value(py)); }); } } diff --git a/src/impl_/extract_argument.rs b/src/impl_/extract_argument.rs index 9d5333a7c64..481003a60b0 100644 --- a/src/impl_/extract_argument.rs +++ b/src/impl_/extract_argument.rs @@ -102,7 +102,7 @@ pub fn from_py_with_with_default<'py, T>( #[doc(hidden)] #[cold] pub fn argument_extraction_error(py: Python, arg_name: &str, error: PyErr) -> PyErr { - if error.get_type(py) == PyTypeError::type_object(py) { + if error.get_type(py).is(PyTypeError::type_object(py)) { let remapped_error = PyTypeError::new_err(format!("argument '{}': {}", arg_name, error.value(py))); remapped_error.set_cause(py, error.cause(py)); diff --git a/src/instance.rs b/src/instance.rs index 937b9891714..773d62510f5 100644 --- a/src/instance.rs +++ b/src/instance.rs @@ -463,6 +463,15 @@ where } impl Py { + /// Returns whether `self` and `other` point to the same object. To compare + /// the equality of two objects (the `==` operator), use [`eq`](PyAny::eq). + /// + /// This is equivalent to the Python expression `self is other`. + #[inline] + pub fn is(&self, o: &U) -> bool { + self.as_ptr() == o.as_ptr() + } + /// Gets the reference count of the `ffi::PyObject` pointer. #[inline] pub fn get_refcnt(&self, _py: Python) -> isize { @@ -829,13 +838,6 @@ where } } -impl PartialEq for Py { - #[inline] - fn eq(&self, o: &Py) -> bool { - self.0 == o.0 - } -} - /// If the GIL is held this increments `self`'s reference count. /// Otherwise this registers the [`Py`]`` instance to have its reference count /// incremented the next time PyO3 acquires the GIL. diff --git a/src/types/any.rs b/src/types/any.rs index c1968525f22..4c217c2552d 100644 --- a/src/types/any.rs +++ b/src/types/any.rs @@ -87,6 +87,15 @@ impl PyAny { ::try_from(self) } + /// Returns whether `self` and `other` point to the same object. To compare + /// the equality of two objects (the `==` operator), use [`eq`](PyAny::eq). + /// + /// This is equivalent to the Python expression `self is other`. + #[inline] + pub fn is(&self, other: &T) -> bool { + self.as_ptr() == other.as_ptr() + } + /// Determines whether this object has the given attribute. /// /// This is equivalent to the Python expression `hasattr(self, attr_name)`. diff --git a/src/types/boolobject.rs b/src/types/boolobject.rs index 98fc233c0ca..4e17e2a5d24 100644 --- a/src/types/boolobject.rs +++ b/src/types/boolobject.rs @@ -69,7 +69,7 @@ mod tests { assert!(PyBool::new(py, true).is_true()); let t: &PyAny = PyBool::new(py, true).into(); assert!(t.extract::().unwrap()); - assert_eq!(true.to_object(py), PyBool::new(py, true).into()); + assert!(true.to_object(py).is(PyBool::new(py, true))); }); } @@ -79,7 +79,7 @@ mod tests { assert!(!PyBool::new(py, false).is_true()); let t: &PyAny = PyBool::new(py, false).into(); assert!(!t.extract::().unwrap()); - assert_eq!(false.to_object(py), PyBool::new(py, false).into()); + assert!(false.to_object(py).is(PyBool::new(py, false))); }); } } diff --git a/src/types/dict.rs b/src/types/dict.rs index 955204ebccf..4e2877def99 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -387,7 +387,7 @@ mod tests { Python::with_gil(|py| { let dict = [(7, 32)].into_py_dict(py); assert_eq!(32, dict.get_item(7i32).unwrap().extract::().unwrap()); - assert_eq!(None, dict.get_item(8i32)); + assert!(dict.get_item(8i32).is_none()); let map: HashMap = [(7, 32)].iter().cloned().collect(); assert_eq!(map, dict.extract().unwrap()); let map: BTreeMap = [(7, 32)].iter().cloned().collect(); @@ -426,7 +426,7 @@ mod tests { let ndict = dict.copy().unwrap(); assert_eq!(32, ndict.get_item(7i32).unwrap().extract::().unwrap()); - assert_eq!(None, ndict.get_item(8i32)); + assert!(ndict.get_item(8i32).is_none()); }); } @@ -464,7 +464,7 @@ mod tests { let ob = v.to_object(py); let dict = ::try_from(ob.as_ref(py)).unwrap(); assert_eq!(32, dict.get_item(7i32).unwrap().extract::().unwrap()); - assert_eq!(None, dict.get_item(8i32)); + assert!(dict.get_item(8i32).is_none()); }); } @@ -527,7 +527,7 @@ mod tests { let dict = ::try_from(ob.as_ref(py)).unwrap(); assert!(dict.del_item(7i32).is_ok()); assert_eq!(0, dict.len()); - assert_eq!(None, dict.get_item(7i32)); + assert!(dict.get_item(7i32).is_none()); }); } diff --git a/src/types/iterator.rs b/src/types/iterator.rs index 02d437c8bd4..97faf61a11b 100644 --- a/src/types/iterator.rs +++ b/src/types/iterator.rs @@ -213,7 +213,7 @@ def fibonacci(target): Python::with_gil(|py| { let obj: Py = vec![10, 20].to_object(py).as_ref(py).iter().unwrap().into(); let iter: &PyIterator = PyIterator::try_from(obj.as_ref(py)).unwrap(); - assert_eq!(obj, iter.into()); + assert!(obj.is(iter)); }); } diff --git a/src/types/mod.rs b/src/types/mod.rs index b85c31af9a5..fa3ca231f24 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -64,15 +64,6 @@ macro_rules! pyobject_native_type_base( unsafe { $crate::PyObject::from_borrowed_ptr(py, self.as_ptr()) } } } - - impl<$($generics,)*> ::std::cmp::PartialEq for $name { - #[inline] - fn eq(&self, o: &$name) -> bool { - use $crate::AsPyPointer; - - self.as_ptr() == o.as_ptr() - } - } }; ); diff --git a/src/types/sequence.rs b/src/types/sequence.rs index 844c80b5695..8f96eeb59f6 100644 --- a/src/types/sequence.rs +++ b/src/types/sequence.rs @@ -725,11 +725,11 @@ mod tests { let seq = ob.cast_as::(py).unwrap(); let rep_seq = seq.in_place_repeat(3).unwrap(); assert_eq!(6, seq.len().unwrap()); - assert_eq!(seq, rep_seq); + assert!(seq.is(rep_seq)); let conc_seq = seq.in_place_concat(seq).unwrap(); assert_eq!(12, seq.len().unwrap()); - assert_eq!(seq, conc_seq); + assert!(seq.is(conc_seq)); }); } diff --git a/src/types/string.rs b/src/types/string.rs index 22496e260f7..1b0e2f15d02 100644 --- a/src/types/string.rs +++ b/src/types/string.rs @@ -504,7 +504,7 @@ mod tests { let data = unsafe { s.data().unwrap() }; assert_eq!(data, PyStringData::Ucs1(b"f\xfe")); let err = data.to_string(py).unwrap_err(); - assert_eq!(err.get_type(py), PyUnicodeDecodeError::type_object(py)); + assert!(err.get_type(py).is(PyUnicodeDecodeError::type_object(py))); assert!(err .to_string() .contains("'utf-8' codec can't decode byte 0xfe in position 1")); @@ -546,7 +546,7 @@ mod tests { let data = unsafe { s.data().unwrap() }; assert_eq!(data, PyStringData::Ucs2(&[0xff22, 0xd800])); let err = data.to_string(py).unwrap_err(); - assert_eq!(err.get_type(py), PyUnicodeDecodeError::type_object(py)); + assert!(err.get_type(py).is(PyUnicodeDecodeError::type_object(py))); assert!(err .to_string() .contains("'utf-16' codec can't decode bytes in position 0-3")); @@ -585,7 +585,7 @@ mod tests { let data = unsafe { s.data().unwrap() }; assert_eq!(data, PyStringData::Ucs4(&[0x20000, 0xd800])); let err = data.to_string(py).unwrap_err(); - assert_eq!(err.get_type(py), PyUnicodeDecodeError::type_object(py)); + assert!(err.get_type(py).is(PyUnicodeDecodeError::type_object(py))); assert!(err .to_string() .contains("'utf-32' codec can't decode bytes in position 0-7")); diff --git a/tests/test_sequence.rs b/tests/test_sequence.rs index 2c9b2ef912b..1a1132a23d3 100644 --- a/tests/test_sequence.rs +++ b/tests/test_sequence.rs @@ -279,10 +279,12 @@ fn test_generic_list_set() { let list = PyCell::new(py, GenericList { items: vec![] }).unwrap(); py_run!(py, list, "list.items = [1, 2, 3]"); - assert_eq!( - list.borrow().items, - vec![1.to_object(py), 2.to_object(py), 3.to_object(py)] - ); + assert!(list + .borrow() + .items + .iter() + .zip(&[1u32, 2, 3]) + .all(|(a, b)| a.as_ref(py).eq(&b.into_py(py)).unwrap())); } #[pyclass] diff --git a/tests/test_sequence_pyproto.rs b/tests/test_sequence_pyproto.rs index bb89fe49770..3395d852b3e 100644 --- a/tests/test_sequence_pyproto.rs +++ b/tests/test_sequence_pyproto.rs @@ -263,10 +263,12 @@ fn test_generic_list_set() { let list = PyCell::new(py, GenericList { items: vec![] }).unwrap(); py_run!(py, list, "list.items = [1, 2, 3]"); - assert_eq!( - list.borrow().items, - vec![1.to_object(py), 2.to_object(py), 3.to_object(py)] - ); + assert!(list + .borrow() + .items + .iter() + .zip(&[1u32, 2, 3]) + .all(|(a, b)| a.as_ref(py).eq(&b.into_py(py)).unwrap())); } #[pyclass] diff --git a/tests/test_serde.rs b/tests/test_serde.rs index cb58528ac5f..f9c965982d5 100644 --- a/tests/test_serde.rs +++ b/tests/test_serde.rs @@ -59,12 +59,12 @@ mod test_serde { #[test] fn test_deserialize() { - let serialized = r#"{"username": "danya", "friends": + let serialized = r#"{"username": "danya", "friends": [{"username": "friend", "group": {"name": "danya's friends"}, "friends": []}]}"#; let user: User = serde_json::from_str(serialized).expect("failed to deserialize"); assert_eq!(user.username, "danya"); - assert_eq!(user.group, None); + assert!(user.group.is_none()); assert_eq!(user.friends.len(), 1usize); let friend = user.friends.get(0).unwrap();