numpy where and dtype in 1.9
Hi,

The numpy.where() function was rewritten in numpy 1.9 to speed it up. I traced the change to this changeset: https://github.com/numpy/numpy/commit/593e3c30c24f0c61a271dc883c614724d7a57e...

The weird thing is that the 1.9 version changed the resulting dtype in some situations when scalar values are passed as the second or third argument. To illustrate, I wrote a simple test script and ran it against both numpy 1.7 and 1.9. Here are the results:

2.7.9 (default, Jul 25 2015, 03:06:43)
[GCC 4.4.7 20120313 (Red Hat 4.4.7-3)]
***** numpy version 1.7.2 *****
=== testing numpy.where with NaNs ===
numpy.where([True], numpy.float32(1.0), numpy.NaN).dtype
float64
numpy.where([True], [numpy.float32(1.0)], numpy.NaN).dtype
float32
numpy.where([True], numpy.float32(1.0), [numpy.NaN]).dtype
float64
numpy.where([True], [numpy.float32(1.0)], [numpy.NaN]).dtype
float64
=== testing numpy.where with integers ===
numpy.where([True], [numpy.float32(1.0)], 65535).dtype
float32
numpy.where([True], [numpy.float32(1.0)], 65536).dtype
float32
numpy.where([True], [numpy.float32(1.0)], -32768).dtype
float32
numpy.where([True], [numpy.float32(1.0)], -32769).dtype
float32

2.7.9 (default, Mar 10 2015, 09:26:44)
[GCC 4.4.7 20120313 (Red Hat 4.4.7-3)]
***** numpy version 1.9.2 *****
=== testing numpy.where with NaNs ===
numpy.where([True], numpy.float32(1.0), numpy.NaN).dtype
float64
numpy.where([True], [numpy.float32(1.0)], numpy.NaN).dtype
float32
numpy.where([True], numpy.float32(1.0), [numpy.NaN]).dtype
float64
numpy.where([True], [numpy.float32(1.0)], [numpy.NaN]).dtype
float64
=== testing numpy.where with integers ===
numpy.where([True], [numpy.float32(1.0)], 65535).dtype
float32
numpy.where([True], [numpy.float32(1.0)], 65536).dtype
float64
numpy.where([True], [numpy.float32(1.0)], -32768).dtype
float32
numpy.where([True], [numpy.float32(1.0)], -32769).dtype
float64

Regarding the NaNs, the behavior of where() does not differ between 1.7 and 1.9. But it's a little odd that one scenario returns a dtype of float32 while the other three return float64. I'm not sure whether that was intentional or a bug.

Regarding ints, in 1.7 the resulting dtype is consistent, but in 1.9 it depends on the value of the int. It appears to be related to whether the value fits in a 16-bit integer (signed or unsigned). I'm not sure whether this was a side effect of the performance improvement or intentional.

At the very least, I think this change in where() should be noted in the 1.9 release notes. Our project saw an increase in memory usage with 1.9 because where(cond, array, scalar) returns arrays of dtype float64 when the scalar falls outside that limited range.

I've attached my simple script if you're interested in running it.
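The cutoffs in the 1.9 output line up with NumPy's value-based scalar casting: a scalar is treated as the smallest dtype that can hold its value, and that dtype is then promoted with the array's float32. The sketch below reproduces that reasoning with numpy.min_scalar_type and numpy.promote_types. It is an assumption that the 1.9 where() follows exactly these rules; NumPy 2.0 also changed scalar promotion (NEP 50), so a current install may not reproduce the table above. The helper name promoted_dtype is hypothetical, and the sketch uses numpy.nan because the numpy.NaN alias was removed in NumPy 2.0.

```python
import numpy as np

def promoted_dtype(value, array_dtype=np.float32):
    """Hypothetical helper: promote array_dtype with the smallest dtype
    that can hold `value` (value-based scalar casting)."""
    smallest = np.min_scalar_type(value)  # e.g. 65535 -> uint16, 65536 -> uint32
    return np.promote_types(array_dtype, smallest)

for value in (65535, 65536, -32768, -32769):
    print(value, np.min_scalar_type(value), promoted_dtype(value))
# 65535 uint16 float32
# 65536 uint32 float64
# -32768 int16 float32
# -32769 int32 float64

# NaN cases: a Python float does not upcast a float32 array, but a list
# such as [np.nan] becomes a float64 array, which does.
f32_array = np.array([1.0], dtype=np.float32)
print(np.result_type(f32_array, np.nan))              # float32
print(np.result_type(f32_array, np.array([np.nan])))  # float64
```

Under this reading, the cutoff is 16 bits of either signedness (65535 fits in uint16, -32768 in int16), and the NaN case that stays float32 is the only one pairing a float32 array with a bare float scalar. The two-scalar case returning float64 is consistent with the 1.x rule that value-based casting only applies when scalars are mixed with arrays.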
What a coincidence! A very related bug just got re-opened today at my behest: https://github.com/numpy/numpy/issues/5095

Not the same, but I wouldn't be surprised if it stems from the same source. The short of it: np.where(x, 0, x), where x is a masked array, returns a masked array in 1.8.2 and earlier, but returns a regular numpy array in 1.9 and above (the mask is dropped). That bug took a long time for me to track down!

Ben Root

On Wed, Jul 29, 2015 at 5:16 PM, Nathan Jensen <ndjensen@gmail.com> wrote:
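For the masked-array case, numpy.ma.where is the masked-array counterpart of numpy.where: it returns a MaskedArray whose entries are masked wherever the condition or the chosen input is masked. A minimal sketch of the np.where(x, 0, x) call in that form, with hypothetical data; what plain np.where does with masked input depends on the NumPy version, so it is not shown here.

```python
import numpy as np

# Hypothetical data: the last element is masked (missing).
x = np.ma.array([1.0, 0.0, 3.0], mask=[False, False, True])

# numpy.ma.where picks values from the data of its arguments and
# combines their masks, so masked entries of x stay masked.
result = np.ma.where(x, 0, x)
print(type(result).__name__)   # MaskedArray
print(result.mask.tolist())    # [False, False, True]
```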
Thanks for the link. I'm glad I'm not the only one tripping over the where() changes. Should I open a new ticket for what I've encountered, or just add a comment to 5095 noting that the dtype of the output has also changed?

It doesn't sound like it's going to be fixed in 1.9, so I'm not sure what path forward my software team will take. We'll either have to move to 1.8 or explicitly cast the results to float32 throughout the software. Given the size of the ndarrays the software works with, we really need to save those extra 32 bits per element. :-)

On Wed, Jul 29, 2015 at 4:33 PM, Benjamin Root <ben.root@ou.edu> wrote:
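Casting the result with .astype(np.float32) after the call still allocates the float64 array first. One way to keep float32 throughout is to convert the scalar to the array's dtype before calling numpy.where, so both value arguments are already float32. The helper where_keep_dtype is hypothetical, and numpy.copyto with a where= mask is shown as an in-place alternative; this is a sketch, not a drop-in for any particular code base. Note that converting a very large int to float32 can lose precision.

```python
import numpy as np

def where_keep_dtype(cond, arr, fill_value):
    """Hypothetical helper: np.where(cond, arr, fill_value) with the scalar
    converted to arr's dtype first, so it cannot upcast the result."""
    arr = np.asarray(arr)
    # e.g. 65536 -> np.float32(65536.0)
    fill = arr.dtype.type(fill_value)
    return np.where(cond, arr, fill)

data = np.ones(4, dtype=np.float32)
cond = np.array([True, False, True, False])

out = where_keep_dtype(cond, data, 65536)
print(out.dtype)  # float32

# In-place alternative: copy once, then overwrite only the positions
# where cond is False. The destination keeps its float32 dtype.
out2 = data.copy()
np.copyto(out2, 65536, where=~cond)
print(out2.dtype)  # float32
```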
participants (2)
- Benjamin Root
- Nathan Jensen