[Numpy-discussion] numpy where and dtype in 1.9

Nathan Jensen ndjensen at gmail.com
Wed Jul 29 17:16:04 EDT 2015


Hi,

The numpy.where() function was rewritten in numpy 1.9 to speed it up.  I
traced it to this changeset.
https://github.com/numpy/numpy/commit/593e3c30c24f0c61a271dc883c614724d7a57e1e

The odd thing is that the 1.9 behavior changes the resulting dtype in some
situations when a scalar is passed as the second or third argument.  To
illustrate, I wrote a simple test script and ran it against both numpy 1.7
and 1.9.  Here are the results:

2.7.9 (default, Jul 25 2015, 03:06:43)
[GCC 4.4.7 20120313 (Red Hat 4.4.7-3)]
***** numpy version 1.7.2 *****

=== testing numpy.where with NaNs ===
numpy.where([True], numpy.float32(1.0), numpy.NaN).dtype
float64
numpy.where([True], [numpy.float32(1.0)], numpy.NaN).dtype
float32
numpy.where([True], numpy.float32(1.0), [numpy.NaN]).dtype
float64
numpy.where([True], [numpy.float32(1.0)], [numpy.NaN]).dtype
float64


=== testing numpy.where with integers ===
numpy.where([True], [numpy.float32(1.0)], 65535).dtype
float32
numpy.where([True], [numpy.float32(1.0)], 65536).dtype
float32
numpy.where([True], [numpy.float32(1.0)], -32768).dtype
float32
numpy.where([True], [numpy.float32(1.0)], -32769).dtype
float32



2.7.9 (default, Mar 10 2015, 09:26:44)
[GCC 4.4.7 20120313 (Red Hat 4.4.7-3)]
***** numpy version 1.9.2 *****

=== testing numpy.where with NaNs ===
numpy.where([True], numpy.float32(1.0), numpy.NaN).dtype
float64
numpy.where([True], [numpy.float32(1.0)], numpy.NaN).dtype
float32
numpy.where([True], numpy.float32(1.0), [numpy.NaN]).dtype
float64
numpy.where([True], [numpy.float32(1.0)], [numpy.NaN]).dtype
float64


=== testing numpy.where with integers ===
numpy.where([True], [numpy.float32(1.0)], 65535).dtype
float32
numpy.where([True], [numpy.float32(1.0)], 65536).dtype
float64
numpy.where([True], [numpy.float32(1.0)], -32768).dtype
float32
numpy.where([True], [numpy.float32(1.0)], -32769).dtype
float64

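The attached script is not reproduced in this archive. Below is a minimal sketch of an equivalent test. It is a reconstruction from the output above, not the original testNumpyWhere.py. It uses numpy.nan because the numpy.NaN alias was removed in NumPy 2.0. Newer NumPy versions may print dtypes that differ from both columns above.

```python
import sys

import numpy as np

f32 = np.float32(1.0)

# (label, zero-argument callable) pairs mirroring the eight cases above.
# np.nan is used because the np.NaN alias was removed in NumPy 2.0.
CASES = [
    ("where([True], float32(1.0), nan)", lambda: np.where([True], f32, np.nan)),
    ("where([True], [float32(1.0)], nan)", lambda: np.where([True], [f32], np.nan)),
    ("where([True], float32(1.0), [nan])", lambda: np.where([True], f32, [np.nan])),
    ("where([True], [float32(1.0)], [nan])", lambda: np.where([True], [f32], [np.nan])),
] + [
    # The default argument binds each integer value at definition time.
    ("where([True], [float32(1.0)], %d)" % v, lambda v=v: np.where([True], [f32], v))
    for v in (65535, 65536, -32768, -32769)
]


def run_cases():
    """Return (label, result dtype name) for each case."""
    return [(label, fn().dtype.name) for label, fn in CASES]


if __name__ == "__main__":
    print(sys.version)
    print("***** numpy version %s *****" % np.__version__)
    for label, dtype_name in run_cases():
        print(label)
        print(dtype_name)
```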


Regarding the NaN cases, the behavior does not differ between 1.7 and 1.9.
But it's a little odd that one scenario (a float32 list with a scalar NaN)
returns float32, while the other three return float64.  Was that
intentional, or is it a bug?
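
The lone float32 result appears consistent with NumPy's value-based casting for scalars. Under that rule, a float scalar does not upcast an array of the same kind, while a float64 array does. np.result_type applies those rules directly, so the following sketch checks the second and fourth cases. NumPy 2.x replaced value-based casting with the NEP 50 rules, so the scalar-only first case may behave differently there.

```python
import numpy as np

arr32 = np.array([np.float32(1.0)])  # float32 array, as in the second case
nan_arr = np.array([np.nan])         # float64 array, as in the fourth case

# A float scalar (NaN) does not upcast a float32 array...
scalar_case = np.result_type(arr32, np.nan)
# ...but a float64 array does.
array_case = np.result_type(arr32, nan_arr)

print(scalar_case, array_case)  # float32 float64
```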

Regarding integers with where, in 1.7 the resulting dtype is consistent,
but in 1.9 it depends on the value of the int.  It appears to depend on
whether the value fits in a 16-bit integer: 65535 (uint16) and -32768
(int16) give float32, while 65536 and -32769 give float64.  Was this a side
effect of the performance improvement, or was it intentional?

At the very least, I think this change in where() should be noted in the
1.9 release notes.  Our project saw increased memory usage with 1.9 because
where(cond, array, scalar) returns float64 arrays (instead of float32)
whenever the scalar falls outside that limited range.
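
One way to avoid the float64 result is to convert the scalar to the array's dtype before calling where(), so its value can no longer affect promotion. The helper where_keep_dtype below is hypothetical and is only a sketch of this idea. A fill value that does not fit the array's dtype will be rounded, or, for integer arrays, may overflow or raise depending on the NumPy version.

```python
import numpy as np


def where_keep_dtype(cond, arr, fill):
    """Hypothetical helper: np.where(cond, arr, fill) with the result kept in arr's dtype.

    The scalar `fill` is converted to arr.dtype before the call, so the
    scalar's value can no longer influence the result dtype.
    """
    arr = np.asarray(arr)
    fill = np.asarray(fill, dtype=arr.dtype)  # 0-d array with the same dtype
    return np.where(cond, arr, fill)


data = np.array([1.0, 2.0, 3.0], dtype=np.float32)
mask = np.array([True, False, True])
out = where_keep_dtype(mask, data, 65536)
print(out.dtype)  # float32
```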

I've attached my simple script if you're interested in running it.
-------------- next part --------------
Attachment: testNumpyWhere.py (text/x-python, 1247 bytes)
URL: <http://mail.python.org/pipermail/numpy-discussion/attachments/20150729/8b17ef98/attachment.py>
