[Python-checkins] bpo-35431: Refactor math.comb() implementation. (GH-13725)

Serhiy Storchaka webhook-mailer at python.org
Sat Jun 1 15:09:06 EDT 2019


https://github.com/python/cpython/commit/2b843ac0ae745026ce39514573c5d075137bef65
commit: 2b843ac0ae745026ce39514573c5d075137bef65
branch: master
author: Serhiy Storchaka <storchaka at gmail.com>
committer: GitHub <noreply at github.com>
date: 2019-06-01T22:09:02+03:00
summary:

bpo-35431: Refactor math.comb() implementation. (GH-13725)

* Fixed some bugs.
* Added support for index-like objects.
* Improved error messages.
* Cleaned up and optimized the code.
* Added more tests.

files:
M Doc/library/math.rst
M Lib/test/test_math.py
M Modules/clinic/mathmodule.c.h
M Modules/mathmodule.c

diff --git a/Doc/library/math.rst b/Doc/library/math.rst
index 5243970df806..206b06edd2a2 100644
--- a/Doc/library/math.rst
+++ b/Doc/library/math.rst
@@ -238,11 +238,11 @@ Number-theoretic and representation functions
    and without order.
 
    Also called the binomial coefficient. It is mathematically equal to the expression
-   ``n! / (k! (n - k)!)``. It is equivalent to the coefficient of k-th term in
+   ``n! / (k! (n - k)!)``. It is equivalent to the coefficient of the *k*-th term in the
    polynomial expansion of the expression ``(1 + x) ** n``.
 
    Raises :exc:`TypeError` if the arguments not integers.
-   Raises :exc:`ValueError` if the arguments are negative or if k > n.
+   Raises :exc:`ValueError` if the arguments are negative or if *k* > *n*.
 
    .. versionadded:: 3.8
 
diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py
index 9da7f7c4e6e2..e27092eefd6e 100644
--- a/Lib/test/test_math.py
+++ b/Lib/test/test_math.py
@@ -1893,9 +1893,11 @@ def testComb(self):
         # Raises TypeError if any argument is non-integer or argument count is
         # not 2
         self.assertRaises(TypeError, comb, 10, 1.0)
+        self.assertRaises(TypeError, comb, 10, decimal.Decimal(1.0))
         self.assertRaises(TypeError, comb, 10, "1")
-        self.assertRaises(TypeError, comb, "10", 1)
         self.assertRaises(TypeError, comb, 10.0, 1)
+        self.assertRaises(TypeError, comb, decimal.Decimal(10.0), 1)
+        self.assertRaises(TypeError, comb, "10", 1)
 
         self.assertRaises(TypeError, comb, 10)
         self.assertRaises(TypeError, comb, 10, 1, 3)
@@ -1903,15 +1905,28 @@ def testComb(self):
 
         # Raises Value error if not k or n are negative numbers
         self.assertRaises(ValueError, comb, -1, 1)
-        self.assertRaises(ValueError, comb, -10*10, 1)
+        self.assertRaises(ValueError, comb, -2**1000, 1)
         self.assertRaises(ValueError, comb, 1, -1)
-        self.assertRaises(ValueError, comb, 1, -10*10)
+        self.assertRaises(ValueError, comb, 1, -2**1000)
 
         # Raises value error if k is greater than n
-        self.assertRaises(ValueError, comb, 1, 10**10)
-        self.assertRaises(ValueError, comb, 0, 1)
-
-
+        self.assertRaises(ValueError, comb, 1, 2)
+        self.assertRaises(ValueError, comb, 1, 2**1000)
+
+        n = 2**1000
+        self.assertEqual(comb(n, 0), 1)
+        self.assertEqual(comb(n, 1), n)
+        self.assertEqual(comb(n, 2), n * (n-1) // 2)
+        self.assertEqual(comb(n, n), 1)
+        self.assertEqual(comb(n, n-1), n)
+        self.assertEqual(comb(n, n-2), n * (n-1) // 2)
+        self.assertRaises((OverflowError, MemoryError), comb, n, n//2)
+
+        for n, k in (True, True), (True, False), (False, False):
+            self.assertEqual(comb(n, k), 1)
+            self.assertIs(type(comb(n, k)), int)
+        self.assertEqual(comb(MyIndexable(5), MyIndexable(2)), 10)
+        self.assertIs(type(comb(MyIndexable(5), MyIndexable(2))), int)
 
 
 def test_main():
diff --git a/Modules/clinic/mathmodule.c.h b/Modules/clinic/mathmodule.c.h
index cba791e2098f..92ec4bec9bf1 100644
--- a/Modules/clinic/mathmodule.c.h
+++ b/Modules/clinic/mathmodule.c.h
@@ -639,10 +639,10 @@ math_prod(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *k
 }
 
 PyDoc_STRVAR(math_comb__doc__,
-"comb($module, /, n, k)\n"
+"comb($module, n, k, /)\n"
 "--\n"
 "\n"
-"Number of ways to choose *k* items from *n* items without repetition and without order.\n"
+"Number of ways to choose k items from n items without repetition and without order.\n"
 "\n"
 "Also called the binomial coefficient. It is mathematically equal to the expression\n"
 "n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in\n"
@@ -652,38 +652,26 @@ PyDoc_STRVAR(math_comb__doc__,
 "Raises ValueError if the arguments are negative or if k > n.");
 
 #define MATH_COMB_METHODDEF    \
-    {"comb", (PyCFunction)(void(*)(void))math_comb, METH_FASTCALL|METH_KEYWORDS, math_comb__doc__},
+    {"comb", (PyCFunction)(void(*)(void))math_comb, METH_FASTCALL, math_comb__doc__},
 
 static PyObject *
 math_comb_impl(PyObject *module, PyObject *n, PyObject *k);
 
 static PyObject *
-math_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs, PyObject *kwnames)
+math_comb(PyObject *module, PyObject *const *args, Py_ssize_t nargs)
 {
     PyObject *return_value = NULL;
-    static const char * const _keywords[] = {"n", "k", NULL};
-    static _PyArg_Parser _parser = {NULL, _keywords, "comb", 0};
-    PyObject *argsbuf[2];
     PyObject *n;
     PyObject *k;
 
-    args = _PyArg_UnpackKeywords(args, nargs, NULL, kwnames, &_parser, 2, 2, 0, argsbuf);
-    if (!args) {
-        goto exit;
-    }
-    if (!PyLong_Check(args[0])) {
-        _PyArg_BadArgument("comb", 1, "int", args[0]);
+    if (!_PyArg_CheckPositional("comb", nargs, 2, 2)) {
         goto exit;
     }
     n = args[0];
-    if (!PyLong_Check(args[1])) {
-        _PyArg_BadArgument("comb", 2, "int", args[1]);
-        goto exit;
-    }
     k = args[1];
     return_value = math_comb_impl(module, n, k);
 
 exit:
     return return_value;
 }
-/*[clinic end generated code: output=00aa76356759617a input=a9049054013a1b77]*/
+/*[clinic end generated code: output=6709521e5e1d90ec input=a9049054013a1b77]*/
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c
index 007a8801429c..bea4607b9be1 100644
--- a/Modules/mathmodule.c
+++ b/Modules/mathmodule.c
@@ -3001,10 +3001,11 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
 /*[clinic input]
 math.comb
 
-    n: object(subclass_of='&PyLong_Type')
-    k: object(subclass_of='&PyLong_Type')
+    n: object
+    k: object
+    /
 
-Number of ways to choose *k* items from *n* items without repetition and without order.
+Number of ways to choose k items from n items without repetition and without order.
 
 Also called the binomial coefficient. It is mathematically equal to the expression
 n! / (k! * (n - k)!). It is equivalent to the coefficient of k-th term in
@@ -3017,103 +3018,109 @@ Raises ValueError if the arguments are negative or if k > n.
 
 static PyObject *
 math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
-/*[clinic end generated code: output=bd2cec8d854f3493 input=565f340f98efb5b5]*/
+/*[clinic end generated code: output=bd2cec8d854f3493 input=2f336ac9ec8242f9]*/
 {
-    PyObject *val = NULL,
-        *temp_obj1 = NULL,
-        *temp_obj2 = NULL,
-        *dump_var = NULL;
+    PyObject *result = NULL, *factor = NULL, *temp;
     int overflow, cmp;
-    long long i, terms;
+    long long i, factors;
 
-    cmp = PyObject_RichCompareBool(n, k, Py_LT);
-    if (cmp < 0) {
-        goto fail_comb;
+    n = PyNumber_Index(n);
+    if (n == NULL) {
+        return NULL;
     }
-    else if (cmp > 0) {
-        PyErr_Format(PyExc_ValueError,
-                     "n must be an integer greater than or equal to k");
-        goto fail_comb;
+    k = PyNumber_Index(k);
+    if (k == NULL) {
+        Py_DECREF(n);
+        return NULL;
     }
 
-    /* b = min(b, a - b) */
-    dump_var = PyNumber_Subtract(n, k);
-    if (dump_var == NULL) {
-        goto fail_comb;
+    if (Py_SIZE(n) < 0) {
+        PyErr_SetString(PyExc_ValueError,
+                        "n must be a non-negative integer");
+        goto error;
     }
-    cmp = PyObject_RichCompareBool(k, dump_var, Py_GT);
-    if (cmp < 0) {
-        goto fail_comb;
+    /* k = min(k, n - k) */
+    temp = PyNumber_Subtract(n, k);
+    if (temp == NULL) {
+        goto error;
     }
-    else if (cmp > 0) {
-        k = dump_var;
-        dump_var = NULL;
+    if (Py_SIZE(temp) < 0) {
+        Py_DECREF(temp);
+        PyErr_SetString(PyExc_ValueError,
+                        "k must be an integer less than or equal to n");
+        goto error;
+    }
+    cmp = PyObject_RichCompareBool(k, temp, Py_GT);
+    if (cmp > 0) {
+        Py_SETREF(k, temp);
     }
     else {
-        Py_DECREF(dump_var);
-        dump_var = NULL;
+        Py_DECREF(temp);
+        if (cmp < 0) {
+            goto error;
+        }
     }
 
-    terms = PyLong_AsLongLongAndOverflow(k, &overflow);
-    if (terms < 0 && PyErr_Occurred()) {
-        goto fail_comb;
-    }
-    else if (overflow > 0) {
+    factors = PyLong_AsLongLongAndOverflow(k, &overflow);
+    if (overflow > 0) {
         PyErr_Format(PyExc_OverflowError,
-                     "minimum(n - k, k) must not exceed %lld",
+                     "min(n - k, k) must not exceed %lld",
                      LLONG_MAX);
-        goto fail_comb;
+        goto error;
     }
-    else if (overflow < 0 || terms < 0) {
-        PyErr_Format(PyExc_ValueError,
-                     "k must be a positive integer");
-        goto fail_comb;
+    else if (overflow < 0 || factors < 0) {
+        if (!PyErr_Occurred()) {
+            PyErr_SetString(PyExc_ValueError,
+                            "k must be a non-negative integer");
+        }
+        goto error;
     }
 
-    if (terms == 0) {
-        return PyNumber_Long(_PyLong_One);
+    if (factors == 0) {
+        result = PyLong_FromLong(1);
+        goto done;
     }
 
-    val = PyNumber_Long(n);
-    for (i = 1; i < terms; ++i) {
-        temp_obj1 = PyLong_FromSsize_t(i);
-        if (temp_obj1 == NULL) {
-            goto fail_comb;
-        }
-        temp_obj2 = PyNumber_Subtract(n, temp_obj1);
-        if (temp_obj2 == NULL) {
-            goto fail_comb;
+    result = n;
+    Py_INCREF(result);
+    if (factors == 1) {
+        goto done;
+    }
+
+    factor = n;
+    Py_INCREF(factor);
+    for (i = 1; i < factors; ++i) {
+        Py_SETREF(factor, PyNumber_Subtract(factor, _PyLong_One));
+        if (factor == NULL) {
+            goto error;
         }
-        dump_var = val;
-        val = PyNumber_Multiply(val, temp_obj2);
-        if (val == NULL) {
-            goto fail_comb;
+        Py_SETREF(result, PyNumber_Multiply(result, factor));
+        if (result == NULL) {
+            goto error;
         }
-        Py_DECREF(dump_var);
-        dump_var = NULL;
-        Py_DECREF(temp_obj2);
-        temp_obj2 = PyLong_FromUnsignedLongLong((unsigned long long)(i + 1));
-        if (temp_obj2 == NULL) {
-            goto fail_comb;
+
+        temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1);
+        if (temp == NULL) {
+            goto error;
         }
-        dump_var = val;
-        val = PyNumber_FloorDivide(val, temp_obj2);
-        if (val == NULL) {
-            goto fail_comb;
+        Py_SETREF(result, PyNumber_FloorDivide(result, temp));
+        Py_DECREF(temp);
+        if (result == NULL) {
+            goto error;
         }
-        Py_DECREF(dump_var);
-        Py_DECREF(temp_obj1);
-        Py_DECREF(temp_obj2);
     }
+    Py_DECREF(factor);
 
-    return val;
-
-fail_comb:
-    Py_XDECREF(val);
-    Py_XDECREF(dump_var);
-    Py_XDECREF(temp_obj1);
-    Py_XDECREF(temp_obj2);
+done:
+    Py_DECREF(n);
+    Py_DECREF(k);
+    return result;
 
+error:
+    Py_XDECREF(factor);
+    Py_XDECREF(result);
+    Py_DECREF(n);
+    Py_DECREF(k);
     return NULL;
 }
 



More information about the Python-checkins mailing list