[Python-checkins] r87671 - in python/branches/py3k: Lib/test/test_collections.py Modules/_collectionsmodule.c

raymond.hettinger python-checkins at python.org
Mon Jan 3 03:12:02 CET 2011


Author: raymond.hettinger
Date: Mon Jan  3 03:12:02 2011
New Revision: 87671

Log:
Make the C helper function _count_elements() more closely match the pure Python version, and add tests.

Modified:
   python/branches/py3k/Lib/test/test_collections.py
   python/branches/py3k/Modules/_collectionsmodule.c

Modified: python/branches/py3k/Lib/test/test_collections.py
==============================================================================
--- python/branches/py3k/Lib/test/test_collections.py	(original)
+++ python/branches/py3k/Lib/test/test_collections.py	Mon Jan  3 03:12:02 2011
@@ -3,7 +3,7 @@
 import unittest, doctest, operator
 import inspect
 from test import support
-from collections import namedtuple, Counter, OrderedDict
+from collections import namedtuple, Counter, OrderedDict, _count_elements
 from test import mapping_tests
 import pickle, copy
 from random import randrange, shuffle
@@ -775,6 +775,19 @@
         c.subtract('aaaabbcce')
         self.assertEqual(c, Counter(a=-1, b=0, c=-1, d=1, e=-1))
 
+    def test_helper_function(self):
+        # _count_elements has two code paths: a fast one for exact dicts and a generic getitem/setitem one for any other mapping; exercise both
+        elems = list('abracadabra')
+
+        d = dict()  # exact dict -> PyDict_* fast path
+        _count_elements(d, elems)
+        self.assertEqual(d, {'a': 5, 'r': 2, 'b': 2, 'c': 1, 'd': 1})
+
+        m = OrderedDict()  # not an exact dict -> generic PyObject_GetItem/SetItem path
+        _count_elements(m, elems)
+        self.assertEqual(m,  # OrderedDict-vs-OrderedDict equality also checks key order (first-seen order here)
+             OrderedDict([('a', 5), ('b', 2), ('r', 2), ('c', 1), ('d', 1)]))
+
 class TestOrderedDict(unittest.TestCase):
 
     def test_init(self):

Modified: python/branches/py3k/Modules/_collectionsmodule.c
==============================================================================
--- python/branches/py3k/Modules/_collectionsmodule.c	(original)
+++ python/branches/py3k/Modules/_collectionsmodule.c	Mon Jan  3 03:12:02 2011
@@ -1536,41 +1536,68 @@
     if (!PyArg_UnpackTuple(args, "_count_elements", 2, 2, &mapping, &iterable))
         return NULL;
 
-    if (!PyDict_Check(mapping)) {
-        PyErr_SetString(PyExc_TypeError,
-            "Expected mapping argument to be a dictionary");
-        return NULL;
-    }
-
     it = PyObject_GetIter(iterable);
     if (it == NULL)
         return NULL;
+
     one = PyLong_FromLong(1);
     if (one == NULL) {
         Py_DECREF(it);
         return NULL;
     }
-    while (1) {
-        key = PyIter_Next(it);
-        if (key == NULL) {
-            if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration))
-                PyErr_Clear();
-            break;
+
+    if (PyDict_CheckExact(mapping)) {
+        while (1) {
+            key = PyIter_Next(it);
+            if (key == NULL) {
+                if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration))
+                    PyErr_Clear();
+                else
+                    break;
+            }
+            oldval = PyDict_GetItem(mapping, key);
+            if (oldval == NULL) {
+                if (PyDict_SetItem(mapping, key, one) == -1)
+                    break;
+            } else {
+                newval = PyNumber_Add(oldval, one);
+                if (newval == NULL)
+                    break;
+                if (PyDict_SetItem(mapping, key, newval) == -1)
+                    break;
+                Py_CLEAR(newval);
+            }
+            Py_DECREF(key);
         }
-        oldval = PyDict_GetItem(mapping, key);
-        if (oldval == NULL) {
-            if (PyDict_SetItem(mapping, key, one) == -1)
-                break;
-        } else {
-            newval = PyNumber_Add(oldval, one);
-            if (newval == NULL)
-                break;
-            if (PyDict_SetItem(mapping, key, newval) == -1)
+    } else {
+        while (1) {
+            key = PyIter_Next(it);
+            if (key == NULL) {
+                if (PyErr_Occurred() && PyErr_ExceptionMatches(PyExc_StopIteration))
+                    PyErr_Clear();
+                else
+                    break;
+            }
+            oldval = PyObject_GetItem(mapping, key);
+            if (oldval == NULL) {
+                if (!PyErr_Occurred() || !PyErr_ExceptionMatches(PyExc_KeyError))
+                    break;
+                PyErr_Clear();
+                Py_INCREF(one);
+                newval = one;
+            } else {
+                newval = PyNumber_Add(oldval, one);
+                Py_DECREF(oldval);
+                if (newval == NULL)
+                    break;
+            }
+            if (PyObject_SetItem(mapping, key, newval) == -1)
                 break;
             Py_CLEAR(newval);
+            Py_DECREF(key);
         }
-        Py_DECREF(key);
     }
+
     Py_DECREF(it);
     Py_XDECREF(key);
     Py_XDECREF(newval);


More information about the Python-checkins mailing list