[pypy-commit] pypy numpy-multidim: (antocuni, mwp) merge heads; wanted to check in on default, but did it on the branch by mistake

mwp noreply at buildbot.pypy.org
Thu Nov 10 10:48:05 CET 2011


Author: Mark Pearse <mark.pearse at skynet.be>
Branch: numpy-multidim
Changeset: r49115:f95cf09f56dd
Date: 2011-11-10 10:46 +0100
http://bitbucket.org/pypy/pypy/changeset/f95cf09f56dd/

Log:	(antocuni, mwp) merge heads; wanted to check in on default, but did it on
	the branch by mistake

diff --git a/pypy/module/micronumpy/compile.py b/pypy/module/micronumpy/compile.py
--- a/pypy/module/micronumpy/compile.py
+++ b/pypy/module/micronumpy/compile.py
@@ -9,7 +9,7 @@
      descr_new_array, scalar_w, NDimArray)
 from pypy.module.micronumpy import interp_ufuncs
 from pypy.rlib.objectmodel import specialize
-
+import re
 
 class BogusBytecode(Exception):
     pass
@@ -23,6 +23,12 @@
 class WrongFunctionName(Exception):
     pass
 
+class TokenizerError(Exception):
+    pass
+
+class BadToken(Exception):
+    pass
+
 SINGLE_ARG_FUNCTIONS = ["sum", "prod", "max", "min", "all", "any", "unegative"]
 
 class FakeSpace(object):
@@ -192,7 +198,7 @@
         interp.variables[self.name] = self.expr.execute(interp)
 
     def __repr__(self):
-        return "%% = %r" % (self.name, self.expr)
+        return "%r = %r" % (self.name, self.expr)
 
 class ArrayAssignment(Node):
     def __init__(self, name, index, expr):
@@ -214,7 +220,7 @@
 
 class Variable(Node):
     def __init__(self, name):
-        self.name = name
+        self.name = name.strip(" ")
 
     def execute(self, interp):
         return interp.variables[self.name]
@@ -332,7 +338,7 @@
 
 class FunctionCall(Node):
     def __init__(self, name, args):
-        self.name = name
+        self.name = name.strip(" ")
         self.args = args
 
     def __repr__(self):
@@ -375,118 +381,172 @@
         else:
             raise WrongFunctionName
 
+_REGEXES = [
+    ('-?[\d\.]+', 'number'),
+    ('\[', 'array_left'),
+    (':', 'colon'),
+    ('\w+', 'identifier'),
+    ('\]', 'array_right'),
+    ('(->)|[\+\-\*\/]', 'operator'),
+    ('=', 'assign'),
+    (',', 'coma'),
+    ('\|', 'pipe'),
+    ('\(', 'paren_left'),
+    ('\)', 'paren_right'),
+]
+REGEXES = []
+
+for r, name in _REGEXES:
+    REGEXES.append((re.compile(r' *(' + r + ')'), name))
+del _REGEXES
+
+class Token(object):
+    def __init__(self, name, v):
+        self.name = name
+        self.v = v
+
+    def __repr__(self):
+        return '(%s, %s)' % (self.name, self.v)
+
+empty = Token('', '')
+
+class TokenStack(object):
+    def __init__(self, tokens):
+        self.tokens = tokens
+        self.c = 0
+
+    def pop(self):
+        token = self.tokens[self.c]
+        self.c += 1
+        return token
+
+    def get(self, i):
+        if self.c + i >= len(self.tokens):
+            return empty
+        return self.tokens[self.c + i]
+
+    def remaining(self):
+        return len(self.tokens) - self.c
+
+    def push(self):
+        self.c -= 1
+
+    def __repr__(self):
+        return repr(self.tokens[self.c:])
+
 class Parser(object):
-    def parse_identifier(self, id):
-        id = id.strip(" ")
-        #assert id.isalpha()
-        return Variable(id)
+    def tokenize(self, line):
+        tokens = []
+        while True:
+            for r, name in REGEXES:
+                m = r.match(line)
+                if m is not None:
+                    g = m.group(0)
+                    tokens.append(Token(name, g))
+                    line = line[len(g):]
+                    if not line:
+                        return TokenStack(tokens)
+                    break
+            else:
+                raise TokenizerError(line)
 
-    def parse_expression(self, expr):
-        tokens = [i for i in expr.split(" ") if i]
-        if len(tokens) == 1:
-            return self.parse_constant_or_identifier(tokens[0])
+    def parse_number_or_slice(self, tokens):
+        start_tok = tokens.pop()
+        if start_tok.name == 'colon':
+            start = 0
+        else:
+            if tokens.get(0).name != 'colon':
+                return FloatConstant(start_tok.v)
+            start = int(start_tok.v)
+            tokens.pop()
+        if not tokens.get(0).name in ['colon', 'number']:
+            stop = -1
+            step = 1
+        else:
+            next = tokens.pop()
+            if next.name == 'colon':
+                stop = -1
+                step = int(tokens.pop().v)
+            else:
+                stop = int(next.v)
+                if tokens.get(0).name == 'colon':
+                    tokens.pop()
+                    step = int(tokens.pop().v)
+                else:
+                    step = 1
+        return SliceConstant(start, stop, step)
+            
+        
+    def parse_expression(self, tokens):
         stack = []
-        tokens.reverse()
-        while tokens:
+        while tokens.remaining():
             token = tokens.pop()
-            if token == ')':
-                raise NotImplementedError
-            elif self.is_identifier_or_const(token):
-                if stack:
-                    name = stack.pop().name
-                    lhs = stack.pop()
-                    rhs = self.parse_constant_or_identifier(token)
-                    stack.append(Operator(lhs, name, rhs))
+            if token.name == 'identifier':
+                if tokens.remaining() and tokens.get(0).name == 'paren_left':
+                    stack.append(self.parse_function_call(token.v, tokens))
                 else:
-                    stack.append(self.parse_constant_or_identifier(token))
+                    stack.append(Variable(token.v))
+            elif token.name == 'array_left':
+                stack.append(ArrayConstant(self.parse_array_const(tokens)))
+            elif token.name == 'operator':
+                stack.append(Variable(token.v))
+            elif token.name == 'number' or token.name == 'colon':
+                tokens.push()
+                stack.append(self.parse_number_or_slice(tokens))
+            elif token.name == 'pipe':
+                stack.append(RangeConstant(tokens.pop().v))
+                end = tokens.pop()
+                assert end.name == 'pipe'
             else:
-                stack.append(Variable(token))
-        assert len(stack) == 1
-        return stack[-1]
+                tokens.push()
+                break
+        stack.reverse()
+        lhs = stack.pop()
+        while stack:
+            op = stack.pop()
+            assert isinstance(op, Variable)
+            rhs = stack.pop()
+            lhs = Operator(lhs, op.name, rhs)
+        return lhs
 
-    def parse_constant(self, v):
-        lgt = len(v)-1
-        assert lgt >= 0
-        if ':' in v:
-            # a slice
-            if v == ':':
-                return SliceConstant(0, 0, 0)
-            else:
-                l = v.split(':')
-                if len(l) == 2:
-                    one = l[0]
-                    two = l[1]
-                    if not one:
-                        one = 0
-                    else:
-                        one = int(one)
-                    return SliceConstant(int(l[0]), int(l[1]), 1)
-                else:
-                    three = int(l[2])
-                    # all can be empty
-                    if l[0]:
-                        one = int(l[0])
-                    else:
-                        one = 0
-                    if l[1]:
-                        two = int(l[1])
-                    else:
-                        two = -1
-                    return SliceConstant(one, two, three)
-                
-        if v[0] == '[':
-            return ArrayConstant([self.parse_constant(elem)
-                                  for elem in v[1:lgt].split(",")])
-        if v[0] == '|':
-            return RangeConstant(v[1:lgt])
-        return FloatConstant(v)
-
-    def is_identifier_or_const(self, v):
-        c = v[0]
-        if ((c >= 'a' and c <= 'z') or (c >= 'A' and c <= 'Z') or
-            (c >= '0' and c <= '9') or c in '-.[|:'):
-            if v == '-' or v == "->":
-                return False
-            return True
-        return False
-
-    def parse_function_call(self, v):
-        l = v.split('(')
-        assert len(l) == 2
-        name = l[0]
-        cut = len(l[1]) - 1
-        assert cut >= 0
-        args = [self.parse_constant_or_identifier(id)
-                for id in l[1][:cut].split(",")]
+    def parse_function_call(self, name, tokens):
+        args = []
+        tokens.pop() # lparen
+        while tokens.get(0).name != 'paren_right':
+            args.append(self.parse_expression(tokens))
         return FunctionCall(name, args)
 
-    def parse_constant_or_identifier(self, v):
-        c = v[0]
-        if (c >= 'a' and c <= 'z') or (c >= 'A' and c <= 'Z'):
-            if '(' in v:
-                return self.parse_function_call(v)
-            return self.parse_identifier(v)
-        return self.parse_constant(v)
-
-    def parse_array_subscript(self, v):
-        v = v.strip(" ")
-        l = v.split("[")
-        lgt = len(l[1]) - 1
-        assert lgt >= 0
-        rhs = self.parse_constant_or_identifier(l[1][:lgt])
-        return l[0], rhs
+    def parse_array_const(self, tokens):
+        elems = []
+        while True:
+            token = tokens.pop()
+            if token.name == 'number':
+                elems.append(FloatConstant(token.v))
+            elif token.name == 'array_left':
+                elems.append(ArrayConstant(self.parse_array_const(tokens)))
+            else:
+                raise BadToken()
+            token = tokens.pop()
+            if token.name == 'array_right':
+                return elems
+            assert token.name == 'coma'
         
-    def parse_statement(self, line):
-        if '=' in line:
-            lhs, rhs = line.split("=")
-            lhs = lhs.strip(" ")
-            if '[' in lhs:
-                name, index = self.parse_array_subscript(lhs)
-                return ArrayAssignment(name, index, self.parse_expression(rhs))
-            else: 
-                return Assignment(lhs, self.parse_expression(rhs))
-        else:
-            return Execute(self.parse_expression(line))
+    def parse_statement(self, tokens):
+        if (tokens.get(0).name == 'identifier' and
+            tokens.get(1).name == 'assign'):
+            lhs = tokens.pop().v
+            tokens.pop()
+            rhs = self.parse_expression(tokens)
+            return Assignment(lhs, rhs)
+        elif (tokens.get(0).name == 'identifier' and
+              tokens.get(1).name == 'array_left'):
+            name = tokens.pop().v
+            tokens.pop()
+            index = self.parse_expression(tokens)
+            tokens.pop()
+            tokens.pop()
+            return ArrayAssignment(name, index, self.parse_expression(tokens))
+        return Execute(self.parse_expression(tokens))
 
     def parse(self, code):
         statements = []
@@ -495,7 +555,8 @@
                 line = line.split('#', 1)[0]
             line = line.strip(" ")
             if line:
-                statements.append(self.parse_statement(line))
+                tokens = self.tokenize(line)
+                statements.append(self.parse_statement(tokens))
         return Code(statements)
 
 def numpy_compile(code):
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -6,7 +6,7 @@
 from pypy.rlib import jit
 from pypy.rpython.lltypesystem import lltype
 from pypy.tool.sourcetools import func_with_new_name
-
+from pypy.rlib.rstring import StringBuilder
 
 numpy_driver = jit.JitDriver(greens = ['signature'],
                              reds = ['result_size', 'i', 'self', 'result'])
@@ -68,6 +68,14 @@
         dtype.setitem_w(space, arr.storage, i, w_elem)
     return arr
 
+class ArrayIndex(object):
+    """ An index into an array or view. Offset is a data offset, indexes
+    are respective indexes in dimensions
+    """
+    def __init__(self, indexes, offset):
+        self.indexes = indexes
+        self.offset = offset
+
 class BaseArray(Wrappable):
     _attrs_ = ["invalidates", "signature", "shape"]
 
@@ -209,25 +217,6 @@
             assert isinstance(w_res, BaseArray)
             return w_res.descr_sum(space)
 
-    def _getnums(self, comma):
-        dtype = self.find_dtype()
-        if self.find_size() > 1000:
-            nums = [
-                dtype.str_format(self.eval(index))
-                for index in range(3)
-            ]
-            nums.append("..." + "," * comma)
-            nums.extend([
-                dtype.str_format(self.eval(index))
-                for index in range(self.find_size() - 3, self.find_size())
-            ])
-        else:
-            nums = [
-                dtype.str_format(self.eval(index))
-                for index in range(self.find_size())
-            ]
-        return nums
-
     def get_concrete(self):
         raise NotImplementedError
 
@@ -246,26 +235,35 @@
     def descr_repr(self, space):
         # Simple implementation so that we can see the array.
         # Since what we want is to print a plethora of 2d views, 
-        # use recursive calls to  tostr() to do the work.
+        # use recursive calls to  to_str() to do the work.
         concrete = self.get_concrete()
-        res = "array("
-        res0 = NDimSlice(concrete, self.signature, [], self.shape).tostr(True, indent='       ')
-        if res0=="[]" and isinstance(self,NDimSlice):
-            res0 += ", shape=%s"%(tuple(self.shape),)
-        res += res0
+        res = StringBuilder()
+        res.append("array(")
+        myview = NDimSlice(concrete, self.signature, [], self.shape)
+        res0 = myview.to_str(True, indent='       ')
+        #This is for numpy compliance: an empty slice reports its shape
+        if res0 == "[]" and isinstance(self, NDimSlice):
+            res.append("[], shape=(")
+            self_shape = str(self.shape)
+            res.append_slice(str(self_shape), 1, len(self_shape)-1)
+            res.append(')')
+        else:
+            res.append(res0)
         dtype = concrete.find_dtype()
         if (dtype is not space.fromcache(interp_dtype.W_Float64Dtype) and
-            dtype is not space.fromcache(interp_dtype.W_Int64Dtype)) or not self.find_size():
-            res += ", dtype=" + dtype.name
-        res += ")"
-        return space.wrap(res)
+            dtype is not space.fromcache(interp_dtype.W_Int64Dtype)) or \
+            not self.find_size():
+            res.append(", dtype=" + dtype.name)
+        res.append(")")
+        return space.wrap(res.build())
 
     def descr_str(self, space):
         # Simple implementation so that we can see the array. 
         # Since what we want is to print a plethora of 2d views, let
         # a slice do the work for us.
         concrete = self.get_concrete()
-        return space.wrap(NDimSlice(concrete, self.signature, [], self.shape).tostr(False))
+        r = NDimSlice(concrete, self.signature, [], self.shape).to_str(False)
+        return space.wrap(r)
 
     def _index_of_single_item(self, space, w_idx):
         # we assume C ordering for now
@@ -297,9 +295,6 @@
             item += v
         return item
 
-    def len_of_shape(self):
-        return len(self.shape)
-
     def get_root_shape(self):
         return self.shape
 
@@ -307,7 +302,7 @@
         """ The result of getitem/setitem is a single item if w_idx
         is a list of scalars that match the size of shape
         """
-        shape_len = self.len_of_shape()
+        shape_len = len(self.shape)
         if shape_len == 0:
             if not space.isinstance_w(w_idx, space.w_int):
                 raise OperationError(space.w_IndexError, space.wrap(
@@ -409,6 +404,7 @@
         return scalar_w(space, dtype, w_obj)
 
 def scalar_w(space, dtype, w_obj):
+    assert isinstance(dtype, interp_dtype.W_Dtype)
     return Scalar(dtype, dtype.unwrap(space, w_obj))
 
 class Scalar(BaseArray):
@@ -586,16 +582,12 @@
 
 class NDimSlice(ViewArray):
     signature = signature.BaseSignature()
-    
+
     _immutable_fields_ = ['shape[*]', 'chunks[*]']
 
     def __init__(self, parent, signature, chunks, shape):
         ViewArray.__init__(self, parent, signature, shape)
         self.chunks = chunks
-        self.shape_reduction = 0
-        for chunk in chunks:
-            if chunk[-2] == 0:
-                self.shape_reduction += 1
 
     def get_root_storage(self):
         return self.parent.get_concrete().get_root_storage()
@@ -624,9 +616,6 @@
     def setitem(self, item, value):
         self.parent.setitem(self.calc_index(item), value)
 
-    def len_of_shape(self):
-        return self.parent.len_of_shape() - self.shape_reduction
-
     def get_root_shape(self):
         return self.parent.get_root_shape()
 
@@ -636,7 +625,6 @@
     @jit.unroll_safe
     def calc_index(self, item):
         index = []
-        __item = item
         _item = item
         for i in range(len(self.shape) -1, 0, -1):
             s = self.shape[i]
@@ -666,46 +654,57 @@
             item += index[i]
             i += 1
         return item
-    def tostr(self, comma,indent=' '):
-        ret = ''
+
+    def to_str(self, comma, indent=' '):
+        ret = StringBuilder()
         dtype = self.find_dtype()
-        ndims = len(self.shape)#-self.shape_reduction
-        if any([s==0 for s in self.shape]):
-            ret += '[]'
-            return ret
-        if ndims>2:
-            ret += '['
+        ndims = len(self.shape)
+        for s in self.shape:
+            if s == 0:
+                ret.append('[]')
+                return ret.build()
+        if ndims > 2:
+            ret.append('[')
             for i in range(self.shape[0]):
-                ret += NDimSlice(self.parent, self.signature, [(i,0,0,1)], self.shape[1:]).tostr(comma,indent=indent+' ')
-                if i+1<self.shape[0]:
-                    ret += ',\n\n'+ indent
-            ret += ']'
-        elif ndims==2:
-            ret += '['
+                smallerview = NDimSlice(self.parent, self.signature,
+                                        [(i, 0, 0, 1)], self.shape[1:])
+                ret.append(smallerview.to_str(comma, indent=indent + ' '))
+                if i + 1 < self.shape[0]:
+                    ret.append(',\n\n' + indent)
+            ret.append(']')
+        elif ndims == 2:
+            ret.append('[')
             for i in range(self.shape[0]):
-                ret += '['
-                ret += (','*comma + ' ' ).join([dtype.str_format(self.eval(i*self.shape[1]+j)) \
-                                                    for j in range(self.shape[1])])
-                ret += ']'
-                if i+1< self.shape[0]:
-                    ret += ',\n' + indent
-            ret += ']'
-        elif ndims==1:
-            ret += '['
-            if self.shape[0]>1000:
-                ret += (','*comma + ' ').join([dtype.str_format(self.eval(j)) \
-                                                    for j in range(3)])
-                ret += ','*comma + ' ..., '
-                ret += (','*comma + ' ').join([dtype.str_format(self.eval(j)) \
-                                                    for j in range(self.shape[0]-3,self.shape[0])])
+                ret.append('[')
+                spacer = ',' * comma + ' '
+                ret.append(spacer.join(\
+                    [dtype.str_format(self.eval(i * self.shape[1] + j)) \
+                    for j in range(self.shape[1])]))
+                ret.append(']')
+                if i + 1 < self.shape[0]:
+                    ret.append(',\n' + indent)
+            ret.append(']')
+        elif ndims == 1:
+            ret.append('[')
+            spacer = ',' * comma + ' '
+            if self.shape[0] > 1000:
+                ret.append(spacer.join([dtype.str_format(self.eval(j)) \
+                           for j in range(3)]))
+                ret.append(',' * comma + ' ..., ')
+                ret.append(spacer.join([dtype.str_format(self.eval(j)) \
+                           for j in range(self.shape[0] - 3, self.shape[0])]))
             else:
-                ret += (','*comma + ' ').join([dtype.str_format(self.eval(j)) \
-                                                    for j in range(self.shape[0])])
-            ret += ']'
+                ret.append(spacer.join([dtype.str_format(self.eval(j)) \
+                           for j in range(self.shape[0])]))
+            ret.append(']')
         else:
-            ret += dtype.str_format(self.eval(0))
-        return ret
+            ret.append(dtype.str_format(self.eval(0)))
+        return ret.build()
+
 class NDimArray(BaseArray):
+    """ A class representing contiguous array. We know that each iteration
+    by say ufunc will increase the data index by one
+    """
     def __init__(self, size, shape, dtype):
         BaseArray.__init__(self, shape)
         self.size = size
diff --git a/pypy/module/micronumpy/test/test_compile.py b/pypy/module/micronumpy/test/test_compile.py
--- a/pypy/module/micronumpy/test/test_compile.py
+++ b/pypy/module/micronumpy/test/test_compile.py
@@ -102,10 +102,11 @@
         code = """
         a = [1,2,3,4]
         b = [4,5,6,5]
-        a + b
+        c = a + b
+        c -> 3
         """
         interp = self.run(code)
-        assert interp.results[0]._getnums(False) == ["5.0", "7.0", "9.0", "9.0"]
+        assert interp.results[-1].value.val == 9
 
     def test_array_getitem(self):
         code = """
@@ -176,3 +177,17 @@
         """)
         assert interp.results[0].value.val == 6
         
+    def test_multidim_getitem(self):
+        interp = self.run("""
+        a = [[1,2]]
+        a -> 0 -> 1
+        """)
+        assert interp.results[0].value.val == 2
+
+    def test_multidim_getitem_2(self):
+        interp = self.run("""
+        a = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
+        b = a + a
+        b -> 1 -> 1
+        """)
+        assert interp.results[0].value.val == 8
diff --git a/pypy/module/micronumpy/test/test_numarray.py b/pypy/module/micronumpy/test/test_numarray.py
--- a/pypy/module/micronumpy/test/test_numarray.py
+++ b/pypy/module/micronumpy/test/test_numarray.py
@@ -737,6 +737,19 @@
         a = array([[1, 2], [3, 4], [5, 6]])
         assert ((a + a) == array([[1+1, 2+2], [3+3, 4+4], [5+5, 6+6]])).all()
 
+    def test_getitem_add(self):
+        from numpy import array
+        a = array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]])
+        assert (a + a)[1, 1] == 8
+
+    def test_broadcast(self):
+        skip("not working")
+        import numpy
+        a = numpy.zeros((100, 100))
+        b = numpy.ones(100)
+        a[:,:] = b
+        assert a[13,15] == 1
+
 class AppTestSupport(object):
     def setup_class(cls):
         import struct
diff --git a/pypy/module/micronumpy/test/test_zjit.py b/pypy/module/micronumpy/test/test_zjit.py
--- a/pypy/module/micronumpy/test/test_zjit.py
+++ b/pypy/module/micronumpy/test/test_zjit.py
@@ -8,7 +8,7 @@
 from pypy.jit.metainterp.test.support import LLJitMixin
 from pypy.module.micronumpy import interp_ufuncs, signature
 from pypy.module.micronumpy.compile import (numpy_compile, FakeSpace,
-    FloatObject, IntObject, BoolObject)
+    FloatObject, IntObject, BoolObject, Parser, InterpreterState)
 from pypy.module.micronumpy.interp_numarray import NDimArray, NDimSlice
 from pypy.rlib.nonconst import NonConstant
 from pypy.rpython.annlowlevel import llstr, hlstr
@@ -18,12 +18,33 @@
 class TestNumpyJIt(LLJitMixin):
     graph = None
     interp = None
+
+    def setup_class(cls):
+        default = """
+        a = [1,2,3,4]
+        c = a + b
+        sum(c) -> 1::1
+        a -> 3:1:2
+        """
+
+        d = {}
+        p = Parser()
+        allcodes = [p.parse(default)]
+        for name, meth in cls.__dict__.iteritems():
+            if name.startswith("define_"):
+                code = meth()
+                d[name[len("define_"):]] = len(allcodes)
+                allcodes.append(p.parse(code))
+        cls.code_mapping = d
+        cls.codes = allcodes
         
-    def run(self, code):
+    def run(self, name):
         space = FakeSpace()
+        i = self.code_mapping[name]
+        codes = self.codes
         
-        def f(code):
-            interp = numpy_compile(hlstr(code))
+        def f(i):
+            interp = InterpreterState(codes[i])
             interp.run(space)
             res = interp.results[-1]
             w_res = res.eval(0).wrap(interp.space)
@@ -37,55 +58,66 @@
                 return -42.
 
         if self.graph is None:
-            interp, graph = self.meta_interp(f, [llstr(code)],
+            interp, graph = self.meta_interp(f, [i],
                                              listops=True,
                                              backendopt=True,
                                              graph_and_interp_only=True)
             self.__class__.interp = interp
             self.__class__.graph = graph
-
         reset_stats()
         pyjitpl._warmrunnerdesc.memory_manager.alive_loops.clear()
-        return self.interp.eval_graph(self.graph, [llstr(code)])
+        return self.interp.eval_graph(self.graph, [i])
 
-    def test_add(self):
-        result = self.run("""
+    def define_add():
+        return """
         a = |30|
         b = a + a
         b -> 3
-        """)
+        """
+
+    def test_add(self):
+        result = self.run("add")
         self.check_loops({'getarrayitem_raw': 2, 'float_add': 1,
                           'setarrayitem_raw': 1, 'int_add': 1,
                           'int_lt': 1, 'guard_true': 1, 'jump': 1})
         assert result == 3 + 3
 
-    def test_floatadd(self):
-        result = self.run("""
+    def define_float_add():
+        return """
         a = |30| + 3
         a -> 3
-        """)
+        """
+
+    def test_floatadd(self):
+        result = self.run("float_add")
         assert result == 3 + 3
         self.check_loops({"getarrayitem_raw": 1, "float_add": 1,
                           "setarrayitem_raw": 1, "int_add": 1,
                           "int_lt": 1, "guard_true": 1, "jump": 1})
 
-    def test_sum(self):
-        result = self.run("""
+    def define_sum():
+        return """
         a = |30|
         b = a + a
         sum(b)
-        """)
+        """
+
+    def test_sum(self):
+        result = self.run("sum")
         assert result == 2 * sum(range(30))
         self.check_loops({"getarrayitem_raw": 2, "float_add": 2,
                           "int_add": 1,
                           "int_lt": 1, "guard_true": 1, "jump": 1})
 
-    def test_prod(self):
-        result = self.run("""
+    def define_prod():
+        return """
         a = |30|
         b = a + a
         prod(b)
-        """)
+        """
+
+    def test_prod(self):
+        result = self.run("prod")
         expected = 1
         for i in range(30):
             expected *= i * 2
@@ -120,27 +152,33 @@
                           "float_mul": 1, "int_add": 1,
                           "int_lt": 1, "guard_true": 1, "jump": 1})
 
-    def test_any(self):
-        result = self.run("""
+    def define_any():
+        return """
         a = [0,0,0,0,0,0,0,0,0,0,0]
         a[8] = -12
         b = a + a
         any(b)
-        """)
+        """
+
+    def test_any(self):
+        result = self.run("any")
         assert result == 1
         self.check_loops({"getarrayitem_raw": 2, "float_add": 1,
                           "float_ne": 1, "int_add": 1,
                           "int_lt": 1, "guard_true": 1, "jump": 1,
                           "guard_false": 1})
 
-    def test_already_forced(self):
-        result = self.run("""
+    def define_already_forced():
+        return """
         a = |30|
         b = a + 4.5
         b -> 5 # forces
         c = b * 8
         c -> 5
-        """)
+        """
+
+    def test_already_forced(self):
+        result = self.run("already_forced")
         assert result == (5 + 4.5) * 8
         # This is the sum of the ops for both loops, however if you remove the
         # optimization then you end up with 2 float_adds, so we can still be
@@ -149,21 +187,24 @@
                            "setarrayitem_raw": 2, "int_add": 2,
                            "int_lt": 2, "guard_true": 2, "jump": 2})
 
-    def test_ufunc(self):
-        result = self.run("""
+    def define_ufunc():
+        return """
         a = |30|
         b = a + a
         c = unegative(b)
         c -> 3
-        """)
+        """
+
+    def test_ufunc(self):
+        result = self.run("ufunc")
         assert result == -6
         self.check_loops({"getarrayitem_raw": 2, "float_add": 1, "float_neg": 1,
                           "setarrayitem_raw": 1, "int_add": 1,
                           "int_lt": 1, "guard_true": 1, "jump": 1,
         })
 
-    def test_specialization(self):
-        self.run("""
+    def define_specialization():
+        return """
         a = |30|
         b = a + a
         c = unegative(b)
@@ -180,22 +221,57 @@
         d = a * a
         unegative(d)
         d -> 3
-        """)
+        """
+
+    def test_specialization(self):
+        self.run("specialization")
         # This is 3, not 2 because there is a bridge for the exit.
         self.check_loop_count(3)
 
-    def test_slice(self):
-        result = self.run("""
+    def define_slice():
+        return """
         a = |30|
         b = a -> ::3
         c = b + b
         c -> 3
-        """)
+        """
+
+    def test_slice(self):
+        result = self.run("slice")
         assert result == 18
         self.check_loops({'int_mul': 2, 'getarrayitem_raw': 2, 'float_add': 1,
                           'setarrayitem_raw': 1, 'int_add': 3,
                           'int_lt': 1, 'guard_true': 1, 'jump': 1})
 
+    def define_multidim():
+        return """
+        a = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
+        b = a + a
+        b -> 1 -> 1
+        """
+
+    def test_multidim(self):
+        result = self.run('multidim')
+        assert result == 8
+        self.check_loops({'float_add': 1, 'getarrayitem_raw': 2,
+                          'guard_true': 1, 'int_add': 1, 'int_lt': 1,
+                          'jump': 1, 'setarrayitem_raw': 1})
+
+    def define_multidim_slice():
+        return """
+        a = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]]
+        b = a -> ::2
+        c = b + b
+        c -> 1 -> 1
+        """
+
+    def test_multidim_slice(self):
+        result = self.run('multidim_slice')
+        assert result == 12
+        py.test.skip("improve")
+        self.check_loops({})
+    
+
 class TestNumpyOld(LLJitMixin):
     def setup_class(cls):
         py.test.skip("old")
diff --git a/pypy/rlib/rsre/rpy.py b/pypy/rlib/rsre/rpy.py
new file mode 100644
--- /dev/null
+++ b/pypy/rlib/rsre/rpy.py
@@ -0,0 +1,49 @@
+
+from pypy.rlib.rsre import rsre_char
+from pypy.rlib.rsre.rsre_core import match
+
+def get_hacked_sre_compile(my_compile):
+    """Return a copy of the sre_compile module for which the _sre
+    module is a custom module that has _sre.compile == my_compile
+    and CODESIZE == rsre_char.CODESIZE.
+    """
+    import sre_compile, __builtin__, new
+    sre_hacked = new.module("_sre_hacked")
+    sre_hacked.compile = my_compile
+    sre_hacked.MAGIC = sre_compile.MAGIC
+    sre_hacked.CODESIZE = rsre_char.CODESIZE
+    sre_hacked.getlower = rsre_char.getlower
+    def my_import(name, *args):
+        if name == '_sre':
+            return sre_hacked
+        else:
+            return default_import(name, *args)
+    src = sre_compile.__file__
+    if src.lower().endswith('.pyc') or src.lower().endswith('.pyo'):
+        src = src[:-1]
+    mod = new.module("sre_compile_hacked")
+    default_import = __import__
+    try:
+        __builtin__.__import__ = my_import
+        execfile(src, mod.__dict__)
+    finally:
+        __builtin__.__import__ = default_import
+    return mod
+
+class GotIt(Exception):
+    pass
+def my_compile(pattern, flags, code, *args):
+    raise GotIt(code, flags, args)
+sre_compile_hacked = get_hacked_sre_compile(my_compile)
+
+def get_code(regexp, flags=0, allargs=False):
+    try:
+        sre_compile_hacked.compile(regexp, flags)
+    except GotIt, e:
+        pass
+    else:
+        raise ValueError("did not reach _sre.compile()!")
+    if allargs:
+        return e.args
+    else:
+        return e.args[0]
diff --git a/pypy/rlib/rsre/rsre_core.py b/pypy/rlib/rsre/rsre_core.py
--- a/pypy/rlib/rsre/rsre_core.py
+++ b/pypy/rlib/rsre/rsre_core.py
@@ -154,7 +154,6 @@
         return (fmarks[groupnum], fmarks[groupnum+1])
 
     def group(self, groupnum=0):
-        "NOT_RPYTHON"   # compatibility
         frm, to = self.span(groupnum)
         if 0 <= frm <= to:
             return self._string[frm:to]
diff --git a/pypy/rlib/rsre/test/test_match.py b/pypy/rlib/rsre/test/test_match.py
--- a/pypy/rlib/rsre/test/test_match.py
+++ b/pypy/rlib/rsre/test/test_match.py
@@ -1,54 +1,8 @@
 import re
-from pypy.rlib.rsre import rsre_core, rsre_char
+from pypy.rlib.rsre import rsre_core
+from pypy.rlib.rsre.rpy import get_code
 
 
-def get_hacked_sre_compile(my_compile):
-    """Return a copy of the sre_compile module for which the _sre
-    module is a custom module that has _sre.compile == my_compile
-    and CODESIZE == rsre_char.CODESIZE.
-    """
-    import sre_compile, __builtin__, new
-    sre_hacked = new.module("_sre_hacked")
-    sre_hacked.compile = my_compile
-    sre_hacked.MAGIC = sre_compile.MAGIC
-    sre_hacked.CODESIZE = rsre_char.CODESIZE
-    sre_hacked.getlower = rsre_char.getlower
-    def my_import(name, *args):
-        if name == '_sre':
-            return sre_hacked
-        else:
-            return default_import(name, *args)
-    src = sre_compile.__file__
-    if src.lower().endswith('.pyc') or src.lower().endswith('.pyo'):
-        src = src[:-1]
-    mod = new.module("sre_compile_hacked")
-    default_import = __import__
-    try:
-        __builtin__.__import__ = my_import
-        execfile(src, mod.__dict__)
-    finally:
-        __builtin__.__import__ = default_import
-    return mod
-
-class GotIt(Exception):
-    pass
-def my_compile(pattern, flags, code, *args):
-    print code
-    raise GotIt(code, flags, args)
-sre_compile_hacked = get_hacked_sre_compile(my_compile)
-
-def get_code(regexp, flags=0, allargs=False):
-    try:
-        sre_compile_hacked.compile(regexp, flags)
-    except GotIt, e:
-        pass
-    else:
-        raise ValueError("did not reach _sre.compile()!")
-    if allargs:
-        return e.args
-    else:
-        return e.args[0]
-
 def get_code_and_re(regexp):
     return get_code(regexp), re.compile(regexp)
 


More information about the pypy-commit mailing list