[Numpy-svn] r5044 - in trunk/numpy/core: include/numpy src tests

numpy-svn at scipy.org numpy-svn at scipy.org
Fri Apr 18 08:44:16 EDT 2008


Author: stefan
Date: 2008-04-18 07:43:33 -0500 (Fri, 18 Apr 2008)
New Revision: 5044

Modified:
   trunk/numpy/core/include/numpy/ndarrayobject.h
   trunk/numpy/core/src/arraytypes.inc.src
   trunk/numpy/core/src/multiarraymodule.c
   trunk/numpy/core/tests/test_multiarray.py
Log:
Fast implementation of take [patch by Eric Firing].


Modified: trunk/numpy/core/include/numpy/ndarrayobject.h
===================================================================
--- trunk/numpy/core/include/numpy/ndarrayobject.h	2008-04-18 06:42:49 UTC (rev 5043)
+++ trunk/numpy/core/include/numpy/ndarrayobject.h	2008-04-18 12:43:33 UTC (rev 5044)
@@ -1052,6 +1052,10 @@
                                     void *max, void *out);
 typedef void (PyArray_FastPutmaskFunc)(void *in, void *mask, npy_intp n_in,
                                        void *values, npy_intp nv);
+typedef int  (PyArray_FastTakeFunc)(void *dest, void *src, npy_intp *indarray,
+                                       npy_intp nindarray, npy_intp n_outer,
+                                       npy_intp m_middle, npy_intp nelem,
+                                       NPY_CLIPMODE clipmode);
 
 typedef struct {
         npy_intp *ptr;
@@ -1130,6 +1134,7 @@
 
         PyArray_FastClipFunc *fastclip;
         PyArray_FastPutmaskFunc *fastputmask;
+        PyArray_FastTakeFunc *fasttake;
 } PyArray_ArrFuncs;
 
 #define NPY_ITEM_REFCOUNT   0x01  /* The item must be reference counted
@@ -1752,7 +1757,7 @@
     /* FIXME: This should check for a flag on the data-type
        that states whether or not it is variable length.
        Because the ISFLEXIBLE check is hard-coded to the
-       built-in data-types.  
+       built-in data-types.
      */
 #define PyArray_ISVARIABLE(obj) PyTypeNum_ISFLEXIBLE(PyArray_TYPE(obj))
 

Modified: trunk/numpy/core/src/arraytypes.inc.src
===================================================================
--- trunk/numpy/core/src/arraytypes.inc.src	2008-04-18 06:42:49 UTC (rev 5043)
+++ trunk/numpy/core/src/arraytypes.inc.src	2008-04-18 12:43:33 UTC (rev 5044)
@@ -2179,6 +2179,89 @@
 #define OBJECT_fastputmask NULL
 
 
+
+/************************
+ * Fast take functions
+ *************************/
+
+/**begin repeat
+#name=BOOL,BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG, LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE,CFLOAT, CDOUBLE, CLONGDOUBLE#
+#type= Bool, byte, ubyte, short, ushort, int, uint, long, ulong, longlong, ulonglong, float, double, longdouble,cfloat, cdouble, clongdouble#
+*/
+static int
+@name@_fasttake(@type@ *dest, @type@ *src, intp *indarray,
+                    intp nindarray, intp n_outer,
+                    intp m_middle, intp nelem,
+                    NPY_CLIPMODE clipmode)
+{
+    intp i, j, k, tmp;
+
+    switch(clipmode) {
+    case NPY_RAISE:
+        for(i=0; i<n_outer; i++) {
+            for(j=0; j<m_middle; j++) {
+                tmp = indarray[j];
+                if (tmp < 0) tmp = tmp+nindarray;
+                if ((tmp < 0) || (tmp >= nindarray)) {
+                    PyErr_SetString(PyExc_IndexError,
+                                    "index out of range "\
+                                    "for array");
+                    return 1;
+                }
+                if (nelem == 1) *dest++ = *(src+tmp);
+                else {
+                    for(k=0; k<nelem; k++) {
+                        *dest++ = *(src+tmp*nelem+k);
+                    }
+                }
+            }
+            src += nelem*nindarray;
+        }
+        break;
+    case NPY_WRAP:
+        for(i=0; i<n_outer; i++) {
+            for(j=0; j<m_middle; j++) {
+                tmp = indarray[j];
+                if (tmp < 0) while (tmp < 0) tmp += nindarray;
+                else if (tmp >= nindarray)
+                    while (tmp >= nindarray)
+                        tmp -= nindarray;
+                if (nelem == 1) *dest++ = *(src+tmp);
+                else {
+                    for(k=0; k<nelem; k++) {
+                        *dest++ = *(src+tmp*nelem+k);
+                    }
+                }
+            }
+            src += nelem*nindarray;
+        }
+        break;
+    case NPY_CLIP:
+        for(i=0; i<n_outer; i++) {
+            for(j=0; j<m_middle; j++) {
+                tmp = indarray[j];
+                if (tmp < 0)
+                    tmp = 0;
+                else if (tmp >= nindarray)
+                    tmp = nindarray-1;
+                if (nelem == 1) *dest++ = *(src+tmp);
+                else {
+                    for(k=0; k<nelem; k++) {
+                        *dest++ = *(src+tmp*nelem+k);
+                    }
+                }
+            }
+            src += nelem*nindarray;
+        }
+        break;
+    }
+    return 0;
+}
+/**end repeat**/
+
+#define OBJECT_fasttake NULL
+
+
 #define _ALIGN(type) offsetof(struct {char c; type v;},v)
 
 /* Disable harmless compiler warning "4116: unnamed type definition in
@@ -2244,7 +2327,8 @@
     NULL,
     NULL,
     (PyArray_FastClipFunc *)NULL,
-    (PyArray_FastPutmaskFunc *)NULL
+    (PyArray_FastPutmaskFunc *)NULL,
+    (PyArray_FastTakeFunc *)NULL
 };
 
 static PyArray_Descr @from@_Descr = {
@@ -2322,7 +2406,8 @@
     NULL,
     NULL,
     (PyArray_FastClipFunc*)@from@_fastclip,
-    (PyArray_FastPutmaskFunc*)@from@_fastputmask
+    (PyArray_FastPutmaskFunc*)@from@_fastputmask,
+    (PyArray_FastTakeFunc*)@from@_fasttake
 };
 
 static PyArray_Descr @from@_Descr = {

Modified: trunk/numpy/core/src/multiarraymodule.c
===================================================================
--- trunk/numpy/core/src/multiarraymodule.c	2008-04-18 06:42:49 UTC (rev 5043)
+++ trunk/numpy/core/src/multiarraymodule.c	2008-04-18 12:43:33 UTC (rev 5044)
@@ -1136,12 +1136,12 @@
 }
 
 static PyObject *
-_GenericBinaryOutFunction(PyArrayObject *m1, PyObject *m2, PyArrayObject *out, 
+_GenericBinaryOutFunction(PyArrayObject *m1, PyObject *m2, PyArrayObject *out,
 			  PyObject *op)
 {
     if (out == NULL)
 	return PyObject_CallFunction(op, "OO", m1, m2);
-    else 
+    else
 	return PyObject_CallFunction(op, "OOO", m1, m2, out);
 }
 
@@ -1160,7 +1160,7 @@
     }
 
     if (min != NULL) {
-	res2 = _GenericBinaryOutFunction((PyArrayObject *)res1, 
+	res2 = _GenericBinaryOutFunction((PyArrayObject *)res1,
 					 min, out, n_ops.maximum);
 	if (res2 == NULL) {Py_XDECREF(res1); return NULL;}
     }
@@ -1168,7 +1168,7 @@
 	res2 = res1;
 	Py_INCREF(res2);
     }
-    Py_DECREF(res1);    
+    Py_DECREF(res1);
     return res2;
 }
 
@@ -1214,8 +1214,8 @@
     else {
 	newdescr = indescr; /* Steal the reference */
     }
-	
 
+
     /* Use the scalar descriptor only if it is of a bigger
        KIND than the input array (and then find the
        type that matches both).
@@ -1284,8 +1284,8 @@
 	Py_DECREF(min);
 	if (mina == NULL) goto fail;
     }
-	
 
+
     /* Check to see if input is single-segment, aligned,
        and in native byteorder */
     if (PyArray_ISONESEGMENT(self) && PyArray_CHKFLAGS(self, ALIGNED) &&
@@ -1380,7 +1380,7 @@
 	min_data = mina->data;
     if (maxa != NULL)
 	max_data = maxa->data;
-	
+
     func(newin->data, PyArray_SIZE(newin), min_data, max_data,
          newout->data);
 
@@ -3107,10 +3107,10 @@
     NPY_BEGIN_THREADS_DEF
 
     dtype = PyArray_DescrFromObject((PyObject *)op2, op1->descr);
-    
+
     /* need ap1 as contiguous array and of right type */
     Py_INCREF(dtype);
-    ap1 = (PyArrayObject *)PyArray_FromAny((PyObject *)op1, dtype, 
+    ap1 = (PyArrayObject *)PyArray_FromAny((PyObject *)op1, dtype,
 					   1, 1, NPY_DEFAULT, NULL);
 
     if (ap1 == NULL) {
@@ -3746,10 +3746,10 @@
         op = ap;
     }
 
-    /* Will get native-byte order contiguous copy. 
+    /* Will get native-byte order contiguous copy.
      */
     ap = (PyArrayObject *)\
-        PyArray_ContiguousFromAny((PyObject *)op, 
+        PyArray_ContiguousFromAny((PyObject *)op,
                                   op->descr->type_num, 1, 0);
 
     Py_DECREF(op);
@@ -3824,11 +3824,13 @@
 PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
                  PyArrayObject *ret, NPY_CLIPMODE clipmode)
 {
+    PyArray_FastTakeFunc *func;
     PyArrayObject *self, *indices;
-    intp nd, i, j, n, m, max_item, tmp, chunk;
+    intp nd, i, j, n, m, max_item, tmp, chunk, nelem;
     intp shape[MAX_DIMS];
     char *src, *dest;
     int copyret=0;
+    int err;
 
     indices = NULL;
     self = (PyAO *)_check_axis(self0, &axis, CARRAY);
@@ -3892,10 +3894,14 @@
     }
 
     max_item = self->dimensions[axis];
+    nelem = chunk;
     chunk = chunk * ret->descr->elsize;
     src = self->data;
     dest = ret->data;
 
+    func = self->descr->f->fasttake;
+    if (func == NULL) {
+
     switch(clipmode) {
     case NPY_RAISE:
         for(i=0; i<n; i++) {
@@ -3943,6 +3949,12 @@
         }
         break;
     }
+    }
+    else {
+        err = func(dest, src, (intp *)(indices->data),
+                    max_item, n, m, nelem, clipmode);
+        if (err) goto fail;
+    }
 
     PyArray_INCREF(ret);
 
@@ -5666,8 +5678,8 @@
             Py_XDECREF(type);
             return NULL;
     }
-            
 
+
     /* fast exit if simple call */
     if ((subok && PyArray_Check(op)) ||
         (!subok && PyArray_CheckExact(op))) {
@@ -5787,7 +5799,7 @@
     return NULL;
 }
 
-/* This function is needed for supporting Pickles of 
+/* This function is needed for supporting Pickles of
    numpy scalar objects.
 */
 static PyObject *
@@ -6363,7 +6375,7 @@
 	 */
         if ((tmp == NULL) || (nread == 0)) {
 	    Py_DECREF(ret);
-	    return PyErr_NoMemory();	    
+	    return PyErr_NoMemory();
 	}
 	ret->data = tmp;
         PyArray_DIM(ret,0) = nread;
@@ -6433,7 +6445,7 @@
     if ((elsize=dtype->elsize) == 0) {
         PyErr_SetString(PyExc_ValueError, "Must specify length "\
                         "when using variable-size data-type.");
-        goto done;                        
+        goto done;
     }
 
     /* We would need to alter the memory RENEW code to decrement any
@@ -7126,7 +7138,7 @@
     retobj = (ret ? Py_True : Py_False);
     Py_INCREF(retobj);
 
- finish:    
+ finish:
     Py_XDECREF(d1);
     Py_XDECREF(d2);
     return retobj;

Modified: trunk/numpy/core/tests/test_multiarray.py
===================================================================
--- trunk/numpy/core/tests/test_multiarray.py	2008-04-18 06:42:49 UTC (rev 5043)
+++ trunk/numpy/core/tests/test_multiarray.py	2008-04-18 12:43:33 UTC (rev 5044)
@@ -699,6 +699,56 @@
         ## np.putmask(z,[True,True,True],3)
         pass
 
+class TestTake(ParametricTestCase):
+    def tst_basic(self,x):
+        ind = range(x.shape[0])
+        assert_array_equal(x.take(ind, axis=0), x)
+
+    def testip_types(self):
+        unchecked_types = [str, unicode, np.void, object]
+
+        x = np.random.random(24)*100
+        x.shape = 2,3,4
+        tests = []
+        for types in np.sctypes.itervalues():
+            tests.extend([(self.tst_basic,x.copy().astype(T))
+                          for T in types if T not in unchecked_types])
+        return tests
+
+    def test_raise(self):
+        x = np.random.random(24)*100
+        x.shape = 2,3,4
+        self.failUnlessRaises(IndexError, x.take, [0,1,2], axis=0)
+        self.failUnlessRaises(IndexError, x.take, [-3], axis=0)
+        assert_array_equal(x.take([-1], axis=0)[0], x[1])
+
+    def test_clip(self):
+        x = np.random.random(24)*100
+        x.shape = 2,3,4
+        assert_array_equal(x.take([-1], axis=0, mode='clip')[0], x[0])
+        assert_array_equal(x.take([2], axis=0, mode='clip')[0], x[1])
+
+    def test_wrap(self):
+        x = np.random.random(24)*100
+        x.shape = 2,3,4
+        assert_array_equal(x.take([-1], axis=0, mode='wrap')[0], x[1])
+        assert_array_equal(x.take([2], axis=0, mode='wrap')[0], x[0])
+        assert_array_equal(x.take([3], axis=0, mode='wrap')[0], x[1])
+
+    def tst_byteorder(self,dtype):
+        x = np.array([1,2,3],dtype)
+        assert_array_equal(x.take([0,2,1]),[1,3,2])
+
+    def testip_byteorder(self):
+        return [(self.tst_byteorder,dtype) for dtype in ('>i4','<i4')]
+
+    def test_record_array(self):
+        # Note mixed byteorder.
+        rec = np.array([(-5, 2.0, 3.0), (5.0, 4.0, 3.0)],
+                      dtype=[('x', '<f8'), ('y', '>f8'), ('z', '<f8')])
+        rec1 = rec.take([1])
+        assert rec1['x'] == 5.0 and rec1['y'] == 4.0
+
 class TestLexsort(NumpyTestCase):
     def test_basic(self):
         a = [1,2,1,3,1,5]




More information about the Numpy-svn mailing list