[Numpy-svn] r5044 - in trunk/numpy/core: include/numpy src tests
numpy-svn at scipy.org
numpy-svn at scipy.org
Fri Apr 18 08:44:16 EDT 2008
Author: stefan
Date: 2008-04-18 07:43:33 -0500 (Fri, 18 Apr 2008)
New Revision: 5044
Modified:
trunk/numpy/core/include/numpy/ndarrayobject.h
trunk/numpy/core/src/arraytypes.inc.src
trunk/numpy/core/src/multiarraymodule.c
trunk/numpy/core/tests/test_multiarray.py
Log:
Fast implementation of take [patch by Eric Firing].
Modified: trunk/numpy/core/include/numpy/ndarrayobject.h
===================================================================
--- trunk/numpy/core/include/numpy/ndarrayobject.h 2008-04-18 06:42:49 UTC (rev 5043)
+++ trunk/numpy/core/include/numpy/ndarrayobject.h 2008-04-18 12:43:33 UTC (rev 5044)
@@ -1052,6 +1052,10 @@
void *max, void *out);
typedef void (PyArray_FastPutmaskFunc)(void *in, void *mask, npy_intp n_in,
void *values, npy_intp nv);
+typedef int (PyArray_FastTakeFunc)(void *dest, void *src, npy_intp *indarray,
+ npy_intp nindarray, npy_intp n_outer,
+ npy_intp m_middle, npy_intp nelem,
+ NPY_CLIPMODE clipmode); /*
+ * Type-specific fast path for take; PyArray_TakeFrom uses it whenever a
+ * dtype's fasttake slot is non-NULL.  As called from there:
+ *   dest, src -- output and input data buffers
+ *   indarray  -- the index values (m_middle of them)
+ *   nindarray -- length of the axis being indexed, i.e. the bound the
+ *                index values are checked against (not the index count)
+ *   n_outer   -- outer iteration count
+ *   nelem     -- items (not bytes) copied per index
+ *   clipmode  -- how out-of-range index values are treated
+ * A non-zero return signals failure and sends the caller to its error path.
+ */
typedef struct {
npy_intp *ptr;
@@ -1130,6 +1134,7 @@
PyArray_FastClipFunc *fastclip;
PyArray_FastPutmaskFunc *fastputmask;
+ PyArray_FastTakeFunc *fasttake;
} PyArray_ArrFuncs;
#define NPY_ITEM_REFCOUNT 0x01 /* The item must be reference counted
@@ -1752,7 +1757,7 @@
/* FIXME: This should check for a flag on the data-type
that states whether or not it is variable length.
Because the ISFLEXIBLE check is hard-coded to the
- built-in data-types.
+ built-in data-types.
*/
#define PyArray_ISVARIABLE(obj) PyTypeNum_ISFLEXIBLE(PyArray_TYPE(obj))
Modified: trunk/numpy/core/src/arraytypes.inc.src
===================================================================
--- trunk/numpy/core/src/arraytypes.inc.src 2008-04-18 06:42:49 UTC (rev 5043)
+++ trunk/numpy/core/src/arraytypes.inc.src 2008-04-18 12:43:33 UTC (rev 5044)
@@ -2179,6 +2179,89 @@
#define OBJECT_fastputmask NULL
+
+/************************
+ * Fast take functions
+ *************************/
+
+/**begin repeat
+#name=BOOL,BYTE, UBYTE, SHORT, USHORT, INT, UINT, LONG, ULONG, LONGLONG, ULONGLONG, FLOAT, DOUBLE, LONGDOUBLE,CFLOAT, CDOUBLE, CLONGDOUBLE#
+#type= Bool, byte, ubyte, short, ushort, int, uint, long, ulong, longlong, ulonglong, float, double, longdouble,cfloat, cdouble, clongdouble#
+*/
+/*
+ * Fast take along one axis for a fixed item type.
+ *
+ * For each of the n_outer outer slices, m_middle chunks of nelem items are
+ * copied from src to dest; chunk j is the one selected by indarray[j].
+ *
+ *   dest      - output buffer, written sequentially
+ *   src       - input buffer; advanced by nelem*nindarray items per
+ *               outer slice
+ *   indarray  - the m_middle index values
+ *   nindarray - length of the indexed axis: valid indices lie in
+ *               [0, nindarray).  Despite the name this is NOT the number
+ *               of entries in indarray (that is m_middle).
+ *   nelem     - number of items (not bytes) copied per index
+ *   clipmode  - NPY_RAISE: a negative index is offset once by nindarray,
+ *                          anything still out of range is an IndexError
+ *               NPY_WRAP:  indices are wrapped into range
+ *               NPY_CLIP:  indices are clamped to [0, nindarray-1]
+ *
+ * Returns 0 on success and 1 with IndexError set on failure.
+ */
+static int
+@name@_fasttake(@type@ *dest, @type@ *src, intp *indarray,
+                    intp nindarray, intp n_outer,
+                    intp m_middle, intp nelem,
+                    NPY_CLIPMODE clipmode)
+{
+    intp i, j, k, tmp;
+
+    /*
+     * An empty axis has no valid index.  Without this guard NPY_WRAP would
+     * spin forever (adding or subtracting zero) and NPY_CLIP would read
+     * outside of src.  NPY_RAISE already rejects every index in this case.
+     */
+    if (nindarray <= 0 && n_outer > 0 && m_middle > 0 &&
+            (clipmode == NPY_WRAP || clipmode == NPY_CLIP)) {
+        if (nelem <= 0) {
+            /* nothing would be copied, so there is nothing to reject */
+            return 0;
+        }
+        PyErr_SetString(PyExc_IndexError,
+                        "cannot do a non-empty take "
+                        "from an empty axis");
+        return 1;
+    }
+
+    switch(clipmode) {
+    case NPY_RAISE:
+        for(i=0; i<n_outer; i++) {
+            for(j=0; j<m_middle; j++) {
+                tmp = indarray[j];
+                if (tmp < 0) {
+                    tmp = tmp+nindarray;
+                }
+                if ((tmp < 0) || (tmp >= nindarray)) {
+                    PyErr_SetString(PyExc_IndexError,
+                                    "index out of range "
+                                    "for array");
+                    return 1;
+                }
+                if (nelem == 1) {
+                    *dest++ = *(src+tmp);
+                }
+                else {
+                    for(k=0; k<nelem; k++) {
+                        *dest++ = *(src+tmp*nelem+k);
+                    }
+                }
+            }
+            src += nelem*nindarray;
+        }
+        break;
+    case NPY_WRAP:
+        for(i=0; i<n_outer; i++) {
+            for(j=0; j<m_middle; j++) {
+                tmp = indarray[j];
+                /* nindarray > 0 here, so both loops terminate */
+                if (tmp < 0) {
+                    while (tmp < 0) {
+                        tmp += nindarray;
+                    }
+                }
+                else if (tmp >= nindarray) {
+                    while (tmp >= nindarray) {
+                        tmp -= nindarray;
+                    }
+                }
+                if (nelem == 1) {
+                    *dest++ = *(src+tmp);
+                }
+                else {
+                    for(k=0; k<nelem; k++) {
+                        *dest++ = *(src+tmp*nelem+k);
+                    }
+                }
+            }
+            src += nelem*nindarray;
+        }
+        break;
+    case NPY_CLIP:
+        for(i=0; i<n_outer; i++) {
+            for(j=0; j<m_middle; j++) {
+                tmp = indarray[j];
+                if (tmp < 0) {
+                    tmp = 0;
+                }
+                else if (tmp >= nindarray) {
+                    tmp = nindarray-1;
+                }
+                if (nelem == 1) {
+                    *dest++ = *(src+tmp);
+                }
+                else {
+                    for(k=0; k<nelem; k++) {
+                        *dest++ = *(src+tmp*nelem+k);
+                    }
+                }
+            }
+            src += nelem*nindarray;
+        }
+        break;
+    }
+    return 0;
+}
+/**end repeat**/
+
+#define OBJECT_fasttake NULL
+
+
#define _ALIGN(type) offsetof(struct {char c; type v;},v)
/* Disable harmless compiler warning "4116: unnamed type definition in
@@ -2244,7 +2327,8 @@
NULL,
NULL,
(PyArray_FastClipFunc *)NULL,
- (PyArray_FastPutmaskFunc *)NULL
+ (PyArray_FastPutmaskFunc *)NULL,
+ (PyArray_FastTakeFunc *)NULL
};
static PyArray_Descr @from@_Descr = {
@@ -2322,7 +2406,8 @@
NULL,
NULL,
(PyArray_FastClipFunc*)@from@_fastclip,
- (PyArray_FastPutmaskFunc*)@from@_fastputmask
+ (PyArray_FastPutmaskFunc*)@from@_fastputmask,
+ (PyArray_FastTakeFunc*)@from@_fasttake
};
static PyArray_Descr @from@_Descr = {
Modified: trunk/numpy/core/src/multiarraymodule.c
===================================================================
--- trunk/numpy/core/src/multiarraymodule.c 2008-04-18 06:42:49 UTC (rev 5043)
+++ trunk/numpy/core/src/multiarraymodule.c 2008-04-18 12:43:33 UTC (rev 5044)
@@ -1136,12 +1136,12 @@
}
static PyObject *
-_GenericBinaryOutFunction(PyArrayObject *m1, PyObject *m2, PyArrayObject *out,
+_GenericBinaryOutFunction(PyArrayObject *m1, PyObject *m2, PyArrayObject *out,
PyObject *op)
{
if (out == NULL)
return PyObject_CallFunction(op, "OO", m1, m2);
- else
+ else
return PyObject_CallFunction(op, "OOO", m1, m2, out);
}
@@ -1160,7 +1160,7 @@
}
if (min != NULL) {
- res2 = _GenericBinaryOutFunction((PyArrayObject *)res1,
+ res2 = _GenericBinaryOutFunction((PyArrayObject *)res1,
min, out, n_ops.maximum);
if (res2 == NULL) {Py_XDECREF(res1); return NULL;}
}
@@ -1168,7 +1168,7 @@
res2 = res1;
Py_INCREF(res2);
}
- Py_DECREF(res1);
+ Py_DECREF(res1);
return res2;
}
@@ -1214,8 +1214,8 @@
else {
newdescr = indescr; /* Steal the reference */
}
-
+
/* Use the scalar descriptor only if it is of a bigger
KIND than the input array (and then find the
type that matches both).
@@ -1284,8 +1284,8 @@
Py_DECREF(min);
if (mina == NULL) goto fail;
}
-
+
/* Check to see if input is single-segment, aligned,
and in native byteorder */
if (PyArray_ISONESEGMENT(self) && PyArray_CHKFLAGS(self, ALIGNED) &&
@@ -1380,7 +1380,7 @@
min_data = mina->data;
if (maxa != NULL)
max_data = maxa->data;
-
+
func(newin->data, PyArray_SIZE(newin), min_data, max_data,
newout->data);
@@ -3107,10 +3107,10 @@
NPY_BEGIN_THREADS_DEF
dtype = PyArray_DescrFromObject((PyObject *)op2, op1->descr);
-
+
/* need ap1 as contiguous array and of right type */
Py_INCREF(dtype);
- ap1 = (PyArrayObject *)PyArray_FromAny((PyObject *)op1, dtype,
+ ap1 = (PyArrayObject *)PyArray_FromAny((PyObject *)op1, dtype,
1, 1, NPY_DEFAULT, NULL);
if (ap1 == NULL) {
@@ -3746,10 +3746,10 @@
op = ap;
}
- /* Will get native-byte order contiguous copy.
+ /* Will get native-byte order contiguous copy.
*/
ap = (PyArrayObject *)\
- PyArray_ContiguousFromAny((PyObject *)op,
+ PyArray_ContiguousFromAny((PyObject *)op,
op->descr->type_num, 1, 0);
Py_DECREF(op);
@@ -3824,11 +3824,13 @@
PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
PyArrayObject *ret, NPY_CLIPMODE clipmode)
{
+ PyArray_FastTakeFunc *func;
PyArrayObject *self, *indices;
- intp nd, i, j, n, m, max_item, tmp, chunk;
+ intp nd, i, j, n, m, max_item, tmp, chunk, nelem;
intp shape[MAX_DIMS];
char *src, *dest;
int copyret=0;
+ int err;
indices = NULL;
self = (PyAO *)_check_axis(self0, &axis, CARRAY);
@@ -3892,10 +3894,14 @@
}
max_item = self->dimensions[axis];
+ nelem = chunk;
chunk = chunk * ret->descr->elsize;
src = self->data;
dest = ret->data;
+ func = self->descr->f->fasttake;
+ if (func == NULL) {
+
switch(clipmode) {
case NPY_RAISE:
for(i=0; i<n; i++) {
@@ -3943,6 +3949,12 @@
}
break;
}
+ }
+ else {
+ err = func(dest, src, (intp *)(indices->data),
+ max_item, n, m, nelem, clipmode);
+ if (err) goto fail;
+ }
PyArray_INCREF(ret);
@@ -5666,8 +5678,8 @@
Py_XDECREF(type);
return NULL;
}
-
+
/* fast exit if simple call */
if ((subok && PyArray_Check(op)) ||
(!subok && PyArray_CheckExact(op))) {
@@ -5787,7 +5799,7 @@
return NULL;
}
-/* This function is needed for supporting Pickles of
+/* This function is needed for supporting Pickles of
numpy scalar objects.
*/
static PyObject *
@@ -6363,7 +6375,7 @@
*/
if ((tmp == NULL) || (nread == 0)) {
Py_DECREF(ret);
- return PyErr_NoMemory();
+ return PyErr_NoMemory();
}
ret->data = tmp;
PyArray_DIM(ret,0) = nread;
@@ -6433,7 +6445,7 @@
if ((elsize=dtype->elsize) == 0) {
PyErr_SetString(PyExc_ValueError, "Must specify length "\
"when using variable-size data-type.");
- goto done;
+ goto done;
}
/* We would need to alter the memory RENEW code to decrement any
@@ -7126,7 +7138,7 @@
retobj = (ret ? Py_True : Py_False);
Py_INCREF(retobj);
- finish:
+ finish:
Py_XDECREF(d1);
Py_XDECREF(d2);
return retobj;
Modified: trunk/numpy/core/tests/test_multiarray.py
===================================================================
--- trunk/numpy/core/tests/test_multiarray.py 2008-04-18 06:42:49 UTC (rev 5043)
+++ trunk/numpy/core/tests/test_multiarray.py 2008-04-18 12:43:33 UTC (rev 5044)
@@ -699,6 +699,56 @@
## np.putmask(z,[True,True,True],3)
pass
+class TestTake(ParametricTestCase):
+ def tst_basic(self,x):
+ ind = range(x.shape[0])
+ assert_array_equal(x.take(ind, axis=0), x)
+
+ def testip_types(self):
+ unchecked_types = [str, unicode, np.void, object]
+
+ x = np.random.random(24)*100
+ x.shape = 2,3,4
+ tests = []
+ for types in np.sctypes.itervalues():
+ tests.extend([(self.tst_basic,x.copy().astype(T))
+ for T in types if T not in unchecked_types])
+ return tests
+
+ def test_raise(self):
+ x = np.random.random(24)*100
+ x.shape = 2,3,4
+ self.failUnlessRaises(IndexError, x.take, [0,1,2], axis=0)
+ self.failUnlessRaises(IndexError, x.take, [-3], axis=0)
+ assert_array_equal(x.take([-1], axis=0)[0], x[1])
+
+ def test_clip(self):
+ x = np.random.random(24)*100
+ x.shape = 2,3,4
+ assert_array_equal(x.take([-1], axis=0, mode='clip')[0], x[0])
+ assert_array_equal(x.take([2], axis=0, mode='clip')[0], x[1])
+
+ def test_wrap(self):
+ x = np.random.random(24)*100
+ x.shape = 2,3,4
+ assert_array_equal(x.take([-1], axis=0, mode='wrap')[0], x[1])
+ assert_array_equal(x.take([2], axis=0, mode='wrap')[0], x[0])
+ assert_array_equal(x.take([3], axis=0, mode='wrap')[0], x[1])
+
+ def tst_byteorder(self,dtype):
+ x = np.array([1,2,3],dtype)
+ assert_array_equal(x.take([0,2,1]),[1,3,2])
+
+ def testip_byteorder(self):
+ return [(self.tst_byteorder,dtype) for dtype in ('>i4','<i4')]
+
+ def test_record_array(self):
+ # Note mixed byteorder.
+ rec = np.array([(-5, 2.0, 3.0), (5.0, 4.0, 3.0)],
+ dtype=[('x', '<f8'), ('y', '>f8'), ('z', '<f8')])
+ rec1 = rec.take([1])
+ assert rec1['x'] == 5.0 and rec1['y'] == 4.0
+
class TestLexsort(NumpyTestCase):
def test_basic(self):
a = [1,2,1,3,1,5]
More information about the Numpy-svn
mailing list