[Python-checkins] cpython (merge 3.3 -> default): Issue #17812: Fixed quadratic complexity of base64.b32encode().

serhiy.storchaka python-checkins at python.org
Sun May 19 10:50:03 CEST 2013


http://hg.python.org/cpython/rev/1b5ef05d6ced
changeset:   83837:1b5ef05d6ced
parent:      83835:5abe85aefe29
parent:      83836:4b5467d997f1
user:        Serhiy Storchaka <storchaka at gmail.com>
date:        Sun May 19 11:49:32 2013 +0300
summary:
  Issue #17812: Fixed quadratic complexity of base64.b32encode().
Optimize base64.b32encode() and base64.b32decode() (speed up to 3x).

files:
  Lib/base64.py |  125 ++++++++++++++-----------------------
  Misc/NEWS     |    3 +
  2 files changed, 51 insertions(+), 77 deletions(-)


diff --git a/Lib/base64.py b/Lib/base64.py
--- a/Lib/base64.py
+++ b/Lib/base64.py
@@ -138,21 +138,10 @@
 
 
 # Base32 encoding/decoding must be done in Python
-_b32alphabet = {
-    0: b'A',  9: b'J', 18: b'S', 27: b'3',
-    1: b'B', 10: b'K', 19: b'T', 28: b'4',
-    2: b'C', 11: b'L', 20: b'U', 29: b'5',
-    3: b'D', 12: b'M', 21: b'V', 30: b'6',
-    4: b'E', 13: b'N', 22: b'W', 31: b'7',
-    5: b'F', 14: b'O', 23: b'X',
-    6: b'G', 15: b'P', 24: b'Y',
-    7: b'H', 16: b'Q', 25: b'Z',
-    8: b'I', 17: b'R', 26: b'2',
-    }
-
-_b32tab = [v[0] for k, v in sorted(_b32alphabet.items())]
-_b32rev = dict([(v[0], k) for k, v in _b32alphabet.items()])
-
+_b32alphabet = b'ABCDEFGHIJKLMNOPQRSTUVWXYZ234567'
+_b32tab = [bytes([i]) for i in _b32alphabet]
+_b32tab2 = [a + b for a in _b32tab for b in _b32tab]
+_b32rev = {v: k for k, v in enumerate(_b32alphabet)}
 
 def b32encode(s):
     """Encode a byte string using Base32.
@@ -161,41 +150,30 @@
     """
     if not isinstance(s, bytes_types):
         raise TypeError("expected bytes, not %s" % s.__class__.__name__)
-    quanta, leftover = divmod(len(s), 5)
+    leftover = len(s) % 5
     # Pad the last quantum with zero bits if necessary
     if leftover:
         s = s + bytes(5 - leftover)  # Don't use += !
-        quanta += 1
-    encoded = bytes()
-    for i in range(quanta):
-        # c1 and c2 are 16 bits wide, c3 is 8 bits wide.  The intent of this
-        # code is to process the 40 bits in units of 5 bits.  So we take the 1
-        # leftover bit of c1 and tack it onto c2.  Then we take the 2 leftover
-        # bits of c2 and tack them onto c3.  The shifts and masks are intended
-        # to give us values of exactly 5 bits in width.
-        c1, c2, c3 = struct.unpack('!HHB', s[i*5:(i+1)*5])
-        c2 += (c1 & 1) << 16 # 17 bits wide
-        c3 += (c2 & 3) << 8  # 10 bits wide
-        encoded += bytes([_b32tab[c1 >> 11],         # bits 1 - 5
-                          _b32tab[(c1 >> 6) & 0x1f], # bits 6 - 10
-                          _b32tab[(c1 >> 1) & 0x1f], # bits 11 - 15
-                          _b32tab[c2 >> 12],         # bits 16 - 20 (1 - 5)
-                          _b32tab[(c2 >> 7) & 0x1f], # bits 21 - 25 (6 - 10)
-                          _b32tab[(c2 >> 2) & 0x1f], # bits 26 - 30 (11 - 15)
-                          _b32tab[c3 >> 5],          # bits 31 - 35 (1 - 5)
-                          _b32tab[c3 & 0x1f],        # bits 36 - 40 (1 - 5)
-                          ])
+    encoded = bytearray()
+    from_bytes = int.from_bytes
+    b32tab2 = _b32tab2
+    for i in range(0, len(s), 5):
+        c = from_bytes(s[i: i + 5], 'big')
+        encoded += (b32tab2[c >> 30] +           # bits 1 - 10
+                    b32tab2[(c >> 20) & 0x3ff] + # bits 11 - 20
+                    b32tab2[(c >> 10) & 0x3ff] + # bits 21 - 30
+                    b32tab2[c & 0x3ff]           # bits 31 - 40
+                   )
     # Adjust for any leftover partial quanta
     if leftover == 1:
-        return encoded[:-6] + b'======'
+        encoded[-6:] = b'======'
     elif leftover == 2:
-        return encoded[:-4] + b'===='
+        encoded[-4:] = b'===='
     elif leftover == 3:
-        return encoded[:-3] + b'==='
+        encoded[-3:] = b'==='
     elif leftover == 4:
-        return encoded[:-1] + b'='
-    return encoded
-
+        encoded[-1:] = b'='
+    return bytes(encoded)
 
 def b32decode(s, casefold=False, map01=None):
     """Decode a Base32 encoded byte string.
@@ -217,8 +195,7 @@
     characters present in the input.
     """
     s = _bytes_from_decode_data(s)
-    quanta, leftover = divmod(len(s), 8)
-    if leftover:
+    if len(s) % 8:
         raise binascii.Error('Incorrect padding')
     # Handle section 2.4 zero and one mapping.  The flag map01 will be either
     # False, or the character to map the digit 1 (one) to.  It should be
@@ -232,42 +209,36 @@
     # Strip off pad characters from the right.  We need to count the pad
     # characters because this will tell us how many null bytes to remove from
     # the end of the decoded string.
-    padchars = 0
-    mo = re.search(b'(?P<pad>[=]*)$', s)
-    if mo:
-        padchars = len(mo.group('pad'))
-        if padchars > 0:
-            s = s[:-padchars]
+    l = len(s)
+    s = s.rstrip(b'=')
+    padchars = l - len(s)
     # Now decode the full quanta
-    parts = []
-    acc = 0
-    shift = 35
-    for c in s:
-        val = _b32rev.get(c)
-        if val is None:
+    decoded = bytearray()
+    b32rev = _b32rev
+    for i in range(0, len(s), 8):
+        quanta = s[i: i + 8]
+        acc = 0
+        try:
+            for c in quanta:
+                acc = (acc << 5) + b32rev[c]
+        except KeyError:
             raise TypeError('Non-base32 digit found')
-        acc += _b32rev[c] << shift
-        shift -= 5
-        if shift < 0:
-            parts.append(binascii.unhexlify(bytes('%010x' % acc, "ascii")))
-            acc = 0
-            shift = 35
+        decoded += acc.to_bytes(5, 'big')
     # Process the last, partial quanta
-    last = binascii.unhexlify(bytes('%010x' % acc, "ascii"))
-    if padchars == 0:
-        last = b''                      # No characters
-    elif padchars == 1:
-        last = last[:-1]
-    elif padchars == 3:
-        last = last[:-2]
-    elif padchars == 4:
-        last = last[:-3]
-    elif padchars == 6:
-        last = last[:-4]
-    else:
-        raise binascii.Error('Incorrect padding')
-    parts.append(last)
-    return b''.join(parts)
+    if padchars:
+        acc <<= 5 * padchars
+        last = acc.to_bytes(5, 'big')
+        if padchars == 1:
+            decoded[-5:] = last[:-1]
+        elif padchars == 3:
+            decoded[-5:] = last[:-2]
+        elif padchars == 4:
+            decoded[-5:] = last[:-3]
+        elif padchars == 6:
+            decoded[-5:] = last[:-4]
+        else:
+            raise binascii.Error('Incorrect padding')
+    return bytes(decoded)
 
 
 
diff --git a/Misc/NEWS b/Misc/NEWS
--- a/Misc/NEWS
+++ b/Misc/NEWS
@@ -10,6 +10,9 @@
 Core and Builtins
 -----------------
 
+- Issue #17812: Fixed quadratic complexity of base64.b32encode().
+  Optimize base64.b32encode() and base64.b32decode() (speed up to 3x).
+
 - Issue #17937: Try harder to collect cyclic garbage at shutdown.
 
 - Issue #12370: Prevent class bodies from interfering with the __class__

-- 
Repository URL: http://hg.python.org/cpython


More information about the Python-checkins mailing list