[pypy-commit] pypy py3.3: Introduce gateway.Unwrapper, a convenient way to write custom unwrap_spec functions,

amauryfa noreply at buildbot.pypy.org
Sat Apr 12 21:24:24 CEST 2014


Author: Amaury Forgeot d'Arc <amauryfa at gmail.com>
Branch: py3.3
Changeset: r70605:fb069da00160
Date: 2014-04-12 17:57 +0200
http://bitbucket.org/pypy/pypy/changeset/fb069da00160/

Log:	Introduce gateway.Unwrapper, a convenient way to write custom
	unwrap_spec functions, similar to the "O&" spec in PyArg_ParseTuple.

	Use it in binascii: "a2b" functions now accept ASCII-only unicode
	strings. (CPython Issue #13637)

diff --git a/pypy/interpreter/gateway.py b/pypy/interpreter/gateway.py
--- a/pypy/interpreter/gateway.py
+++ b/pypy/interpreter/gateway.py
@@ -53,10 +53,24 @@
 
 #________________________________________________________________
 
+
+class Unwrapper(object):
+    """A base class for custom unwrap_spec items.
+
+    Subclasses must override unwrap().
+    """
+    def _freeze_(self):
+        return True
+
+    def unwrap(self, space, w_value):
+        """NOT_RPYTHON"""
+        raise NotImplementedError
+
+
 class UnwrapSpecRecipe(object):
     "NOT_RPYTHON"
 
-    bases_order = [W_Root, ObjSpace, Arguments, object]
+    bases_order = [W_Root, ObjSpace, Arguments, Unwrapper, object]
 
     def dispatch(self, el, *args):
         if isinstance(el, str):
@@ -159,6 +173,9 @@
     def visit_truncatedint_w(self, el, app_sig):
         self.checked_space_method(el, app_sig)
 
+    def visit__Unwrapper(self, el, app_sig):
+        self.checked_space_method(el, app_sig)
+
     def visit__ObjSpace(self, el, app_sig):
         self.orig_arg()
 
@@ -218,6 +235,10 @@
         self.run_args.append("space.descr_self_interp_w(%s, %s)" %
                              (self.use(typ), self.scopenext()))
 
+    def visit__Unwrapper(self, typ):
+        self.run_args.append("%s().unwrap(space, %s)" %
+                             (self.use(typ), self.scopenext()))
+
     def visit__ObjSpace(self, el):
         self.run_args.append('space')
 
@@ -364,6 +385,10 @@
         self.unwrap.append("space.descr_self_interp_w(%s, %s)" %
                            (self.use(typ), self.nextarg()))
 
+    def visit__Unwrapper(self, typ):
+        self.unwrap.append("%s().unwrap(space, %s)" %
+                           (self.use(typ), self.nextarg()))
+
     def visit__ObjSpace(self, el):
         if self.finger > 1:
             raise FastFuncNotSupported
diff --git a/pypy/interpreter/test/test_gateway.py b/pypy/interpreter/test/test_gateway.py
--- a/pypy/interpreter/test/test_gateway.py
+++ b/pypy/interpreter/test/test_gateway.py
@@ -531,6 +531,23 @@
         raises(gateway.OperationError, space.call_function, w_app_g3_u,
                w(42))
 
+    def test_interp2app_unwrap_spec_unwrapper(self):
+        space = self.space
+        class Unwrapper(gateway.Unwrapper):
+            def unwrap(self, space, w_value):
+                return space.int_w(w_value)
+
+        w = space.wrap
+        def g3_u(space, value):
+            return space.wrap(value + 1)
+        app_g3_u = gateway.interp2app_temp(g3_u,
+                                         unwrap_spec=[gateway.ObjSpace,
+                                                      Unwrapper])
+        assert self.space.eq_w(
+            space.call_function(w(app_g3_u), w(42)), w(43))
+        raises(gateway.OperationError, space.call_function,
+               w(app_g3_u), w(None))
+
     def test_interp2app_classmethod(self):
         space = self.space
         w = space.wrap
diff --git a/pypy/module/binascii/interp_base64.py b/pypy/module/binascii/interp_base64.py
--- a/pypy/module/binascii/interp_base64.py
+++ b/pypy/module/binascii/interp_base64.py
@@ -2,6 +2,7 @@
 from pypy.interpreter.gateway import unwrap_spec
 from rpython.rlib.rstring import StringBuilder
 from pypy.module.binascii.interp_binascii import raise_Error
+from pypy.module.binascii.interp_binascii import AsciiBufferUnwrapper
 from rpython.rlib.rarithmetic import ovfcheck
 
 # ____________________________________________________________
@@ -34,8 +35,7 @@
 table_a2b_base64 = ''.join(map(_transform, table_a2b_base64))
 assert len(table_a2b_base64) == 256
 
-
-@unwrap_spec(ascii='bufferstr')
+@unwrap_spec(ascii=AsciiBufferUnwrapper)
 def a2b_base64(space, ascii):
     "Decode a line of base64 data."
 
diff --git a/pypy/module/binascii/interp_binascii.py b/pypy/module/binascii/interp_binascii.py
--- a/pypy/module/binascii/interp_binascii.py
+++ b/pypy/module/binascii/interp_binascii.py
@@ -1,4 +1,5 @@
 from pypy.interpreter.error import OperationError
+from pypy.interpreter.gateway import Unwrapper
 
 class Cache:
     def __init__(self, space):
@@ -13,3 +14,11 @@
 def raise_Incomplete(space, msg):
     w_error = space.fromcache(Cache).w_incomplete
     raise OperationError(w_error, space.wrap(msg))
+
+# a2b functions accept bytes and buffers, but also ASCII strings.
+class AsciiBufferUnwrapper(Unwrapper):
+    def unwrap(self, space, w_value):
+        if space.isinstance_w(w_value, space.w_unicode):
+            w_value = space.call_method(w_value, "encode", space.wrap("ascii"))
+        return space.bufferstr_w(w_value)
+
diff --git a/pypy/module/binascii/interp_hexlify.py b/pypy/module/binascii/interp_hexlify.py
--- a/pypy/module/binascii/interp_hexlify.py
+++ b/pypy/module/binascii/interp_hexlify.py
@@ -3,6 +3,7 @@
 from rpython.rlib.rstring import StringBuilder
 from rpython.rlib.rarithmetic import ovfcheck
 from pypy.module.binascii.interp_binascii import raise_Error
+from pypy.module.binascii.interp_binascii import AsciiBufferUnwrapper
 
 # ____________________________________________________________
 
@@ -42,7 +43,7 @@
     raise_Error(space, 'Non-hexadecimal digit found')
 _char2value._always_inline_ = True
 
-@unwrap_spec(hexstr='bufferstr')
+@unwrap_spec(hexstr=AsciiBufferUnwrapper)
 def unhexlify(space, hexstr):
     '''Binary data of hexadecimal representation.
 hexstr must contain an even number of hex digits (upper or lower case).
diff --git a/pypy/module/binascii/interp_hqx.py b/pypy/module/binascii/interp_hqx.py
--- a/pypy/module/binascii/interp_hqx.py
+++ b/pypy/module/binascii/interp_hqx.py
@@ -2,6 +2,7 @@
 from pypy.interpreter.gateway import unwrap_spec
 from rpython.rlib.rstring import StringBuilder
 from pypy.module.binascii.interp_binascii import raise_Error, raise_Incomplete
+from pypy.module.binascii.interp_binascii import AsciiBufferUnwrapper
 from rpython.rlib.rarithmetic import ovfcheck
 
 # ____________________________________________________________
@@ -62,7 +63,7 @@
 ]
 table_a2b_hqx = ''.join(map(chr, table_a2b_hqx))
 
-@unwrap_spec(ascii='bufferstr')
+@unwrap_spec(ascii=AsciiBufferUnwrapper)
 def a2b_hqx(space, ascii):
     """Decode .hqx coding.  Returns (bin, done)."""
 
diff --git a/pypy/module/binascii/interp_qp.py b/pypy/module/binascii/interp_qp.py
--- a/pypy/module/binascii/interp_qp.py
+++ b/pypy/module/binascii/interp_qp.py
@@ -1,5 +1,6 @@
 from pypy.interpreter.gateway import unwrap_spec
 from rpython.rlib.rstring import StringBuilder
+from pypy.module.binascii.interp_binascii import AsciiBufferUnwrapper
 
 MAXLINESIZE = 76
 
@@ -14,7 +15,7 @@
         return ord(c) - (ord('a') - 10)
 hexval._always_inline_ = True
 
-@unwrap_spec(data='bufferstr', header=int)
+@unwrap_spec(data=AsciiBufferUnwrapper, header=int)
 def a2b_qp(space, data, header=0):
     "Decode a string of qp-encoded data."
 
diff --git a/pypy/module/binascii/interp_uu.py b/pypy/module/binascii/interp_uu.py
--- a/pypy/module/binascii/interp_uu.py
+++ b/pypy/module/binascii/interp_uu.py
@@ -1,6 +1,7 @@
 from pypy.interpreter.gateway import unwrap_spec
 from rpython.rlib.rstring import StringBuilder
 from pypy.module.binascii.interp_binascii import raise_Error
+from pypy.module.binascii.interp_binascii import AsciiBufferUnwrapper
 
 # ____________________________________________________________
 
@@ -29,7 +30,7 @@
 _a2b_write._always_inline_ = True
 
 
-@unwrap_spec(ascii='bufferstr')
+@unwrap_spec(ascii=AsciiBufferUnwrapper)
 def a2b_uu(space, ascii):
     "Decode a line of uuencoded data."
 
diff --git a/pypy/module/binascii/test/test_binascii.py b/pypy/module/binascii/test/test_binascii.py
--- a/pypy/module/binascii/test/test_binascii.py
+++ b/pypy/module/binascii/test/test_binascii.py
@@ -58,6 +58,9 @@
             raises(self.binascii.Error, self.binascii.a2b_uu, bogus + b'\n')
             raises(self.binascii.Error, self.binascii.a2b_uu, bogus + b'\r\n')
             raises(self.binascii.Error, self.binascii.a2b_uu, bogus + b'  \r\n')
+        #
+        assert self.binascii.a2b_uu(u"!6") == b"X"
+        raises(UnicodeEncodeError, self.binascii.a2b_uu, u"caf\xe9")
 
     def test_b2a_uu(self):
         for input, expected in [
@@ -111,6 +114,9 @@
             b"abcdefg",
             ]:
             raises(self.binascii.Error, self.binascii.a2b_base64, bogus)
+        #
+        assert self.binascii.a2b_base64(u"Yg==\n") == b"b"
+        raises(UnicodeEncodeError, self.binascii.a2b_base64, u"caf\xe9")
 
     def test_b2a_base64(self):
         for input, expected in [
@@ -149,6 +155,9 @@
             (b"a_b", b"a b"),
             ]:
             assert self.binascii.a2b_qp(input, header=True) == expected
+        #
+        assert self.binascii.a2b_qp(u"a_b", header=True) == b"a b"
+        raises(UnicodeEncodeError, self.binascii.a2b_qp, u"caf\xe9")
 
     def test_b2a_qp(self):
         for input, flags, expected in [
@@ -230,6 +239,9 @@
             b"AAA AAAAAA:",
             ]:
             raises(self.binascii.Error, self.binascii.a2b_hqx, bogus)
+        #
+        assert self.binascii.a2b_hqx("AAA:") == (b"]u", 1)
+        raises(UnicodeEncodeError, self.binascii.a2b_hqx, u"caf\xe9")
 
     def test_b2a_hqx(self):
         for input, expected in [
@@ -410,6 +422,9 @@
             ]:
             assert self.binascii.unhexlify(input) == expected
             assert self.binascii.a2b_hex(input) == expected
+            assert self.binascii.unhexlify(input.decode('ascii')) == expected
+            assert self.binascii.a2b_hex(input.decode('ascii')) == expected
+        raises(UnicodeEncodeError, self.binascii.a2b_hex, u"caf\xe9")
 
     def test_errors(self):
         binascii = self.binascii


More information about the pypy-commit mailing list