diff --git a/Lib/test/test_set.py b/Lib/test/test_set.py index 554716aed1e98b..9bfd4bc7d63669 100644 --- a/Lib/test/test_set.py +++ b/Lib/test/test_set.py @@ -188,7 +188,10 @@ def test_symmetric_difference(self): self.assertEqual(type(i), self.basetype) self.assertRaises(PassThru, self.s.symmetric_difference, check_pass_thru()) self.assertRaises(TypeError, self.s.symmetric_difference, [[]]) - for C in set, frozenset, dict.fromkeys, str, list, tuple: + constructors = (set, frozenset, + dict.fromkeys, frozendict.fromkeys, + str, list, tuple) + for C in constructors: self.assertEqual(self.thetype('abcba').symmetric_difference(C('cdc')), set('abd')) self.assertEqual(self.thetype('abcba').symmetric_difference(C('efgfe')), set('abcefg')) self.assertEqual(self.thetype('abcba').symmetric_difference(C('ccb')), set('a')) @@ -1591,6 +1594,14 @@ def setUp(self): #------------------------------------------------------------------------------ +class TestOnlySetsFrozenDict(TestOnlySetsInBinaryOps, unittest.TestCase): + def setUp(self): + self.set = set((1, 2, 3)) + self.other = frozendict({1:2, 3:4}) + self.otherIsIterable = True + +#------------------------------------------------------------------------------ + class TestOnlySetsOperator(TestOnlySetsInBinaryOps, unittest.TestCase): def setUp(self): self.set = set((1, 2, 3)) diff --git a/Objects/setobject.c b/Objects/setobject.c index 5d4d1812282eed..8c64d12cbc7a6e 100644 --- a/Objects/setobject.c +++ b/Objects/setobject.c @@ -1183,10 +1183,14 @@ set_iter(PyObject *so) static int set_update_dict_lock_held(PySetObject *so, PyObject *other) { - assert(PyDict_CheckExact(other)); + assert(PyAnyDict_CheckExact(other)); _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(so); - _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(other); +#ifdef Py_DEBUG + if (!PyFrozenDict_CheckExact(other)) { + _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(other); + } +#endif /* Do one big resize at the start, rather than * incrementally resizing as we insert new keys. Expect @@ -1242,7 +1246,7 @@ set_update_lock_held(PySetObject *so, PyObject *other) if (PyAnySet_Check(other)) { return set_merge_lock_held(so, other); } - else if (PyDict_CheckExact(other)) { + else if (PyAnyDict_CheckExact(other)) { return set_update_dict_lock_held(so, other); } return set_update_iterable_lock_held(so, other); @@ -1267,6 +1271,9 @@ set_update_local(PySetObject *so, PyObject *other) Py_END_CRITICAL_SECTION(); return rv; } + else if (PyFrozenDict_CheckExact(other)) { + return set_update_dict_lock_held(so, other); + } return set_update_iterable_lock_held(so, other); } @@ -1290,6 +1297,13 @@ set_update_internal(PySetObject *so, PyObject *other) Py_END_CRITICAL_SECTION2(); return rv; } + else if (PyFrozenDict_CheckExact(other)) { + int rv; + Py_BEGIN_CRITICAL_SECTION(so); + rv = set_update_dict_lock_held(so, other); + Py_END_CRITICAL_SECTION(); + return rv; + } else { int rv; Py_BEGIN_CRITICAL_SECTION(so); @@ -2030,7 +2044,7 @@ set_difference(PySetObject *so, PyObject *other) if (PyAnySet_Check(other)) { other_size = PySet_GET_SIZE(other); } - else if (PyDict_CheckExact(other)) { + else if (PyAnyDict_CheckExact(other)) { other_size = PyDict_GET_SIZE(other); } else { @@ -2047,7 +2061,7 @@ set_difference(PySetObject *so, PyObject *other) if (result == NULL) return NULL; - if (PyDict_CheckExact(other)) { + if (PyAnyDict_CheckExact(other)) { while (set_next(so, &pos, &entry)) { key = entry->key; hash = entry->hash; @@ -2169,7 +2183,11 @@ static int set_symmetric_difference_update_dict(PySetObject *so, PyObject *other) { _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(so); - _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(other); +#ifdef Py_DEBUG + if (!PyFrozenDict_CheckExact(other)) { + _Py_CRITICAL_SECTION_ASSERT_OBJECT_LOCKED(other); + } +#endif Py_ssize_t pos = 0; PyObject *key, *value; @@ -2243,6 +2261,11 @@ set_symmetric_difference_update_impl(PySetObject *so, PyObject *other) rv = set_symmetric_difference_update_dict(so, other); Py_END_CRITICAL_SECTION2(); } + else if (PyFrozenDict_CheckExact(other)) { + Py_BEGIN_CRITICAL_SECTION(so); + rv = set_symmetric_difference_update_dict(so, other); + Py_END_CRITICAL_SECTION(); + } else if (PyAnySet_Check(other)) { Py_BEGIN_CRITICAL_SECTION2(so, other); rv = set_symmetric_difference_update_set(so, (PySetObject *)other);