[pypy-commit] pypy default: merge heads

arigo noreply at buildbot.pypy.org
Fri Apr 20 19:02:41 CEST 2012


Author: Armin Rigo <arigo at tunes.org>
Branch: 
Changeset: r54590:ae2bca32fdd5
Date: 2012-04-20 19:01 +0200
http://bitbucket.org/pypy/pypy/changeset/ae2bca32fdd5/

Log:	merge heads

diff --git a/pypy/module/micronumpy/__init__.py b/pypy/module/micronumpy/__init__.py
--- a/pypy/module/micronumpy/__init__.py
+++ b/pypy/module/micronumpy/__init__.py
@@ -5,6 +5,7 @@
     interpleveldefs = {
         'debug_repr': 'interp_extras.debug_repr',
         'remove_invalidates': 'interp_extras.remove_invalidates',
+        'set_invalidation': 'interp_extras.set_invalidation',
     }
     appleveldefs = {}
 
diff --git a/pypy/module/micronumpy/interp_arrayops.py b/pypy/module/micronumpy/interp_arrayops.py
--- a/pypy/module/micronumpy/interp_arrayops.py
+++ b/pypy/module/micronumpy/interp_arrayops.py
@@ -4,7 +4,7 @@
 from pypy.module.micronumpy import signature
 
 class WhereArray(VirtualArray):
-    def __init__(self, arr, x, y):
+    def __init__(self, space, arr, x, y):
         self.arr = arr
         self.x = x
         self.y = y
@@ -87,4 +87,4 @@
     arr = convert_to_array(space, w_arr)
     x = convert_to_array(space, w_x)
     y = convert_to_array(space, w_y)
-    return WhereArray(arr, x, y)
+    return WhereArray(space, arr, x, y)
diff --git a/pypy/module/micronumpy/interp_extras.py b/pypy/module/micronumpy/interp_extras.py
--- a/pypy/module/micronumpy/interp_extras.py
+++ b/pypy/module/micronumpy/interp_extras.py
@@ -1,5 +1,5 @@
 from pypy.interpreter.gateway import unwrap_spec
-from pypy.module.micronumpy.interp_numarray import BaseArray
+from pypy.module.micronumpy.interp_numarray import BaseArray, get_numarray_cache
 
 
 @unwrap_spec(array=BaseArray)
@@ -13,3 +13,7 @@
     """
     del array.invalidates[:]
     return space.w_None
+
+@unwrap_spec(arg=bool)
+def set_invalidation(space, arg):
+    get_numarray_cache(space).enable_invalidation = arg
diff --git a/pypy/module/micronumpy/interp_numarray.py b/pypy/module/micronumpy/interp_numarray.py
--- a/pypy/module/micronumpy/interp_numarray.py
+++ b/pypy/module/micronumpy/interp_numarray.py
@@ -72,9 +72,10 @@
             arr.force_if_needed()
         del self.invalidates[:]
 
-    def add_invalidates(self, other):
-        self.invalidates.append(other)
-
+    def add_invalidates(self, space, other):
+        if get_numarray_cache(space).enable_invalidation:
+            self.invalidates.append(other)
+        
     def descr__new__(space, w_subtype, w_size, w_dtype=None):
         dtype = space.interp_w(interp_dtype.W_Dtype,
             space.call_function(space.gettypefor(interp_dtype.W_Dtype), w_dtype)
@@ -1583,3 +1584,10 @@
         arr.fill(space, space.wrap(False))
         return arr
     return space.wrap(False)
+
+class NumArrayCache(object):
+    def __init__(self, space):
+        self.enable_invalidation = True
+
+def get_numarray_cache(space):
+    return space.fromcache(NumArrayCache)
diff --git a/pypy/module/micronumpy/interp_ufuncs.py b/pypy/module/micronumpy/interp_ufuncs.py
--- a/pypy/module/micronumpy/interp_ufuncs.py
+++ b/pypy/module/micronumpy/interp_ufuncs.py
@@ -278,7 +278,7 @@
         else:
             w_res = Call1(self.func, self.name, w_obj.shape, calc_dtype,
                                          res_dtype, w_obj)
-        w_obj.add_invalidates(w_res)
+        w_obj.add_invalidates(space, w_res)
         return w_res
 
 
@@ -347,8 +347,8 @@
         w_res = Call2(self.func, self.name,
                       new_shape, calc_dtype,
                       res_dtype, w_lhs, w_rhs, out)
-        w_lhs.add_invalidates(w_res)
-        w_rhs.add_invalidates(w_res)
+        w_lhs.add_invalidates(space, w_res)
+        w_rhs.add_invalidates(space, w_res)
         if out:
             w_res.get_concrete()
         return w_res
diff --git a/pypy/module/micronumpy/test/test_arrayops.py b/pypy/module/micronumpy/test/test_arrayops.py
--- a/pypy/module/micronumpy/test/test_arrayops.py
+++ b/pypy/module/micronumpy/test/test_arrayops.py
@@ -7,3 +7,10 @@
         a = [1, 2, 3, 0, -3]
         a = where(array(a) > 0, ones(5), zeros(5))
         assert (a == [1, 1, 1, 0, 0]).all()
+
+    def test_where_invalidates(self):
+        from _numpypy import where, ones, zeros, array
+        a = array([1, 2, 3, 0, -3])
+        b = where(a > 0, ones(5), zeros(5))
+        a[0] = 0
+        assert (b == [1, 1, 1, 0, 0]).all()


More information about the pypy-commit mailing list