[Python-checkins] bpo-36957: Speed up math.isqrt (#13405)

Mark Dickinson webhook-mailer at python.org
Sun May 19 12:52:00 EDT 2019


https://github.com/python/cpython/commit/5c08ce9bf712acbb3f05a3a57baf51fcb534cdf0
commit: 5c08ce9bf712acbb3f05a3a57baf51fcb534cdf0
branch: master
author: Mark Dickinson <mdickinson at enthought.com>
committer: GitHub <noreply at github.com>
date: 2019-05-19T17:51:56+01:00
summary:

bpo-36957: Speed up math.isqrt (#13405)

* Add math.isqrt function computing the integer square root.

* Code cleanup: remove redundant comments, rename some variables.

* Tighten up code a bit more; use Py_XDECREF to simplify error handling.

* Update Modules/mathmodule.c

Co-Authored-By: Serhiy Storchaka <storchaka at gmail.com>

* Update Modules/mathmodule.c

Use real argument clinic type instead of an alias

Co-Authored-By: Serhiy Storchaka <storchaka at gmail.com>

* Add proof sketch

* Updates from review.

* Correct and expand documentation.

* Fix bad reference handling on error; make some variables block-local; other tidying.

* Style and consistency fixes.

* Add missing error check; don't try to DECREF `a` when it is NULL.

* Simplify some error returns.

* Another two test cases:

- clarify that floats are rejected even if they happen to be
  squares of small integers
- TypeError beats ValueError for a negative float

* Add fast path for small inputs. Needs tests.

* Speed up isqrt for n >= 2**64 as well; add extra tests.

* Reduce number of test-cases to avoid dominating the run-time of test_math.

* Don't perform unnecessary extra iterations when computing c_bit_length.

* Abstract common uint64_t code out into a separate function.

* Cleanup.

* Add a missing Py_DECREF in an error branch. More cleanup.

* Update Modules/mathmodule.c

Add missing `static` declaration to helper function.

Co-Authored-By: Serhiy Storchaka <storchaka at gmail.com>

* Add missing backtick.

files:
M Lib/test/test_math.py
M Modules/mathmodule.c

diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py
index a11a34478564..853a0e62f823 100644
--- a/Lib/test/test_math.py
+++ b/Lib/test/test_math.py
@@ -917,6 +917,7 @@ def testIsqrt(self):
         test_values = (
             list(range(1000))
             + list(range(10**6 - 1000, 10**6 + 1000))
+            + [2**e + i for e in range(60, 200) for i in range(-40, 40)]
             + [3**9999, 10**5001]
         )
 
diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c
index 7a0044a9fcf0..a153e984ca59 100644
--- a/Modules/mathmodule.c
+++ b/Modules/mathmodule.c
@@ -1620,6 +1620,22 @@ completes the proof sketch.
 
 */
 
+
+/* Approximate square root of a large 64-bit integer.
+
+   Given `n` satisfying `2**62 <= n < 2**64`, return `a`
+   satisfying `(a - 1)**2 < n < (a + 1)**2`. */
+
+static uint64_t
+_approximate_isqrt(uint64_t n)
+{
+    uint32_t u = 1U + (n >> 62);       /* seed from the top 2 bits of n; 2 <= u <= 4 given n >= 2**62 */
+    u = (u << 1) + (n >> 59) / u;      /* refine using the top 5 bits of n */
+    u = (u << 3) + (n >> 53) / u;      /* refine using the top 11 bits of n */
+    u = (u << 7) + (n >> 41) / u;      /* refine using the top 23 bits of n */
+    return (u << 15) + (n >> 17) / u;  /* top 47 bits; the quotient is uint64_t, so the sum is formed in 64 bits */
+}
+
 /*[clinic input]
 math.isqrt
 
@@ -1633,8 +1649,9 @@ static PyObject *
 math_isqrt(PyObject *module, PyObject *n)
 /*[clinic end generated code: output=35a6f7f980beab26 input=5b6e7ae4fa6c43d6]*/
 {
-    int a_too_large, s;
+    int a_too_large, c_bit_length;
     size_t c, d;
+    uint64_t m, u;
     PyObject *a = NULL, *b;
 
     n = PyNumber_Index(n);
@@ -1653,24 +1670,55 @@ math_isqrt(PyObject *module, PyObject *n)
         return PyLong_FromLong(0);
     }
 
+    /* c = (n.bit_length() - 1) // 2 */
     c = _PyLong_NumBits(n);
     if (c == (size_t)(-1)) {
         goto error;
     }
     c = (c - 1U) / 2U;
 
-    /* s = c.bit_length() */
-    s = 0;
-    while ((c >> s) > 0) {
-        ++s;
+    /* Fast path: if c <= 31 then n < 2**64 and we can compute directly with a
+       fast, almost branch-free algorithm. In the final correction, we use `u*u
+       - 1 >= m` instead of the simpler `u*u > m` in order to get the correct
+       result in the corner case where `u=2**32`. */
+    if (c <= 31U) {
+        m = (uint64_t)PyLong_AsUnsignedLongLong(n);
+        Py_DECREF(n);
+        if (m == (uint64_t)(-1) && PyErr_Occurred()) {
+            return NULL;
+        }
+        u = _approximate_isqrt(m << (62U - 2U*c)) >> (31U - c);
+        u -= u * u - 1U >= m;
+        return PyLong_FromUnsignedLongLong((unsigned long long)u);
     }
 
-    a = PyLong_FromLong(1);
+    /* Slow path: n >= 2**64. We perform the first five iterations in C integer
+       arithmetic, then switch to using Python long integers. */
+
+    /* From n >= 2**64 it follows that c.bit_length() >= 6. */
+    c_bit_length = 6;
+    while ((c >> c_bit_length) > 0U) {
+        ++c_bit_length;
+    }
+
+    /* Initialise d and a. */
+    d = c >> (c_bit_length - 5);
+    b = _PyLong_Rshift(n, 2U*c - 62U);
+    if (b == NULL) {
+        goto error;
+    }
+    m = (uint64_t)PyLong_AsUnsignedLongLong(b);
+    Py_DECREF(b);
+    if (m == (uint64_t)(-1) && PyErr_Occurred()) {
+        goto error;
+    }
+    u = _approximate_isqrt(m) >> (31U - d);
+    a = PyLong_FromUnsignedLongLong((unsigned long long)u);
     if (a == NULL) {
         goto error;
     }
-    d = 0;
-    while (--s >= 0) {
+
+    for (int s = c_bit_length - 6; s >= 0; --s) {
         PyObject *q;
         size_t e = d;
 



More information about the Python-checkins mailing list