[Python-checkins] cpython (merge 3.3 -> default): merge

raymond.hettinger python-checkins at python.org
Sat Oct 5 01:53:38 CEST 2013


http://hg.python.org/cpython/rev/50e0ed353c7f
changeset:   85961:50e0ed353c7f
parent:      85959:f0416b2b5654
parent:      85960:e4cec1116e5c
user:        Raymond Hettinger <python at rcn.com>
date:        Fri Oct 04 16:52:39 2013 -0700
summary:
  merge

files:
  Lib/test/test_collections.py |  24 +++++++++++++++++++
  Modules/_collectionsmodule.c |  29 ++++++++++++-----------
  2 files changed, 39 insertions(+), 14 deletions(-)


diff --git a/Lib/test/test_collections.py b/Lib/test/test_collections.py
--- a/Lib/test/test_collections.py
+++ b/Lib/test/test_collections.py
@@ -852,6 +852,24 @@
 ### Counter
 ################################################################################
 
+class CounterSubclassWithSetItem(Counter):
+    # Test a counter subclass that overrides __setitem__
+    def __init__(self, *args, **kwds):
+        self.called = False
+        Counter.__init__(self, *args, **kwds)
+    def __setitem__(self, key, value):
+        self.called = True
+        Counter.__setitem__(self, key, value)
+
+class CounterSubclassWithGet(Counter):
+    # Test a counter subclass that overrides get()
+    def __init__(self, *args, **kwds):
+        self.called = False
+        Counter.__init__(self, *args, **kwds)
+    def get(self, key, default):
+        self.called = True
+        return Counter.get(self, key, default)
+
 class TestCounter(unittest.TestCase):
 
     def test_basics(self):
@@ -1059,6 +1077,12 @@
         self.assertEqual(m,
              OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)]))
 
+        # test fidelity to the pure python version
+        c = CounterSubclassWithSetItem('abracadabra')
+        self.assertTrue(c.called)
+        c = CounterSubclassWithGet('abracadabra')
+        self.assertTrue(c.called)
+
 
 ################################################################################
 ### OrderedDict
diff --git a/Modules/_collectionsmodule.c b/Modules/_collectionsmodule.c
--- a/Modules/_collectionsmodule.c
+++ b/Modules/_collectionsmodule.c
@@ -1763,17 +1763,17 @@
 static PyObject *
 _count_elements(PyObject *self, PyObject *args)
 {
-    _Py_IDENTIFIER(__getitem__);
+    _Py_IDENTIFIER(get);
     _Py_IDENTIFIER(__setitem__);
     PyObject *it, *iterable, *mapping, *oldval;
     PyObject *newval = NULL;
     PyObject *key = NULL;
     PyObject *zero = NULL;
     PyObject *one = NULL;
-    PyObject *mapping_get = NULL;
-    PyObject *mapping_getitem;
+    PyObject *bound_get = NULL;
+    PyObject *mapping_get;
+    PyObject *dict_get;
     PyObject *mapping_setitem;
-    PyObject *dict_getitem;
     PyObject *dict_setitem;
 
     if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable))
@@ -1787,15 +1787,16 @@
     if (one == NULL)
         goto done;
 
-    mapping_getitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___getitem__);
-    dict_getitem = _PyType_LookupId(&PyDict_Type, &PyId___getitem__);
+    /* Only take the fast path when get() and __setitem__()
+     * have not been overridden.
+     */
+    mapping_get = _PyType_LookupId(Py_TYPE(mapping), &PyId_get);
+    dict_get = _PyType_LookupId(&PyDict_Type, &PyId_get);
     mapping_setitem = _PyType_LookupId(Py_TYPE(mapping), &PyId___setitem__);
     dict_setitem = _PyType_LookupId(&PyDict_Type, &PyId___setitem__);
 
-    if (mapping_getitem != NULL &&
-        mapping_getitem == dict_getitem &&
-        mapping_setitem != NULL &&
-        mapping_setitem == dict_setitem) {
+    if (mapping_get != NULL && mapping_get == dict_get &&
+        mapping_setitem != NULL && mapping_setitem == dict_setitem) {
         while (1) {
             key = PyIter_Next(it);
             if (key == NULL)
@@ -1815,8 +1816,8 @@
             Py_DECREF(key);
         }
     } else {
-        mapping_get = PyObject_GetAttrString(mapping, "get");
-        if (mapping_get == NULL)
+        bound_get = PyObject_GetAttrString(mapping, "get");
+        if (bound_get == NULL)
             goto done;
 
         zero = PyLong_FromLong(0);
@@ -1827,7 +1828,7 @@
             key = PyIter_Next(it);
             if (key == NULL)
                 break;
-            oldval = PyObject_CallFunctionObjArgs(mapping_get, key, zero, NULL);
+            oldval = PyObject_CallFunctionObjArgs(bound_get, key, zero, NULL);
             if (oldval == NULL)
                 break;
             newval = PyNumber_Add(oldval, one);
@@ -1845,7 +1846,7 @@
     Py_DECREF(it);
     Py_XDECREF(key);
     Py_XDECREF(newval);
-    Py_XDECREF(mapping_get);
+    Py_XDECREF(bound_get);
     Py_XDECREF(zero);
     Py_XDECREF(one);
     if (PyErr_Occurred())

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list