NEP draft for the future behaviour of scalar promotion

Hi all,

NumPy has awkward behaviour when it comes to promotion with 0-D arrays and Python scalars. This is both a technical challenge (NumPy needs to inspect values where it shouldn't) and surprising for users. Roughly speaking, the proposal rests on three points:

* NumPy scalars and NumPy arrays always behave the same.
* A NumPy array always respects its dtype.
* A Python scalar is "weak", so that uint8_arr + 3 returns a uint8 array.

The NEP is here: https://25105-908607-gh.circle-artifacts.com/0/doc/neps/_build/html/nep-0050...

But please refer to the PR, since the above link may go away or get outdated: https://github.com/numpy/numpy/pull/21103

Note that I have not 100% made up my mind on these points, because some alternatives exist that may give a somewhat easier transition. This is therefore a very early draft, and I do expect a large refactor or rewrite will be necessary, but some feedback/input now may go a long way towards keeping this project moving. For those aware of the issues, it probably makes sense to skip ahead to the "Alternatives" section.

I had sent out a poll recently: https://discuss.scientific-python.org/t/poll-future-numpy-behavior-when-mixi...

Just to say, I have not completely ignored it, although (as expected) the results do not give a very simple answer. Many agree with the choices I made, but some also seem to prefer "strong" Python types, or more special handling of NumPy scalars.

Please do not hesitate to give opinions! I am not sure we can find a clear, "obviously right" solution, especially since there are tough backwards-compatibility choices (even if most users are likely not to notice). So any input is appreciated.

Cheers,

Sebastian
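A minimal sketch of what the three points would mean in practice, assuming the draft's rules as currently written (the details may still change per the PR):

    import numpy as np

    uint8_arr = np.array([253, 254, 255], dtype=np.uint8)

    # A Python int is "weak", so the array's dtype wins, and the array
    # always respects its dtype, so the result wraps:
    uint8_arr + 3                # -> array([0, 1, 2], dtype=uint8)

    # A NumPy scalar behaves exactly like a 0-D array: its dtype is
    # honoured instead of its value being inspected, so this promotes:
    uint8_arr + np.int64(3)      # -> int64 array under the proposal
                                 # (value-based casting keeps uint8 today,
                                 #  because the value 3 happens to fit)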

I added a few comments on the PR. The main comments of substance I had boil down to:

- Consistency with other programming languages and major frameworks: perhaps a few more "examples of consistency" between the new approach and others would help strengthen the argument. I know JAX was mentioned, and their dtype promotion docs are quite nice.
- In deciding whether some of the "new behaviors" were nicer, I struggled with the tension between protecting against accidental overflow and a more "purist" view that types should be preserved strictly. The latter seems consistent with the "principle of least surprise" for someone moving from a typed language to NumPy, though it is arguably slightly less friendly to a user with a less formal view of typing (a new Python user messing around with NumPy, say).

On Mon, 21 Feb 2022 at 16:35, Sebastian Berg <sebastian@sipsolutions.net> wrote:

fwiw, my rationale here is that many (most?) beginners will eventually become intermediate-to-advanced, at which point purity becomes increasingly important. It is often easier to explain a "pure" principle to a beginner than it is to navigate around magic behaviour as an expert. At scikit-image tutorials we often begin by having users overflow a uint8 image, then we explain why that happens and how to work around it. We have also increasingly encountered users surprised/annoyed that scikit-image blew up their uint8 to a float64, using 8x the RAM.
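(For concreteness, that overflow exercise looks roughly like this; the image values here are invented:)

    import numpy as np

    image = np.array([[250, 251], [252, 253]], dtype=np.uint8)

    image + 10                    # wraps: 250 + 10 -> 4 in uint8 arithmetic

    # The usual fix we then teach: promote explicitly before arithmetic.
    image.astype(np.uint16) + 10  # -> values 260..263, no wrap-around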
JAX was mentioned, and their dtype promotion docs are quite nice
My God! They are awesome. I hadn't seen them before. For reference: https://jax.readthedocs.io/en/latest/design_notes/type_promotion.html

I certainly wouldn't mind if NumPy adopted these wholesale.

Juan.

On Mon, 21 Feb 2022, at 9:39 PM, Tyler Reddy wrote:

On Mon, Feb 21, 2022, at 20:56, Juan Nunez-Iglesias wrote:
Just to play a bit of devil's advocate here, I'd have to say that most people will not expect

    x[0] + 200

to often yield a number less than 200! I think uint8s are especially problematic because they overflow so quickly (you won't easily run into the same behavior with uint16 and higher). Of course, there is no way to pretend that NumPy integers are Python integers, but by changing the casting table a bit for uint8 we may be able to avoid many common errors. Besides, coming from value-based casting, users already have this expectation:

In [1]: np.uint8(255) + 1
Out[1]: 256

Currently, NumPy scalars and arrays are treated differently. Arrays have stronger types than scalars, in that users expect:

In [3]: np.array([253, 254, 255], dtype=np.uint8) + 3
Out[3]: array([0, 1, 2], dtype=uint8)

So perhaps the real question is: how important is it to us that arrays and scalars behave the same in the new casting scheme? (JAX, from the docs you linked, also makes the scalar vs array distinction.)
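To spell out the asymmetry as runnable code (the exact result dtype of the scalar case is whatever legacy value-based promotion picks; the key point is only that it is wider than uint8):

    import numpy as np

    # Today, the scalar case promotes to a wider integer, so no wrap:
    np.uint8(255) + 1                     # -> 256

    # The array case keeps uint8 and wraps:
    np.array([255], dtype=np.uint8) + 1   # -> array([0], dtype=uint8)

    # Under the draft, the scalar case would match the array case and
    # give a uint8 result of 0.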
We have also increasingly encountered users surprised/annoyed that scikit-image blew up their uint8 to a float64, using 8x the RAM.
I know this used to be true, but my sense is that it is less and less so, especially now that almost all skimage functions use floatx internally.

Stéfan

On Mon, 21 Feb 2022, at 11:50 PM, Stefan van der Walt wrote:
It's tricky, though, because I would expect np.uint8(255) + 1 to equal 0. (As does JAX, see below.) I.e., someone, somewhere, is going to be surprised; I don't think we can help that at all. So my argument is that we should prefer the surprising behaviour that is at least consistent within some overarching framework, and the framework itself should be as parsimonious as possible. I'd prefer not to have to write "except for scalars" in a bunch of places in the docs.
I think uint8's are especially problematic because they overflow so quickly (you won't easily run into the same behavior with uint16 and higher). Of course, there is no way to pretend that NumPy integers are Python integers, but by changing the casting table for uint8 a bit we may be able to avoid many common errors.
See, I kinda hate the idea of special-casing one dtype. Common errors might be a good thing — people can very quickly learn to be careful with uint8s. If we try really hard to hide this reality, people will be surprised *later*, or indeed errors may go unnoticed.
I think the set of users who expect *both* of those behaviours is small.
So perhaps the real question is: how important is it to us that arrays and scalars behave the same in the new casting scheme? (JAX, from the docs you linked, also makes the scalar vs array distinction.)
No, as far as I can tell, they distinguish between *Python* scalars and arrays, not between JAX scalars and arrays. They do have a concept of weakly typed arrays, but I don't think that's what you get when you do jnp.uint8(x). Indeed, I just checked that jnp.uint8(255) + 1 returns a uint8 scalar with value 0 (or a 0-dimensional array? I'm not sure how JAX handles scalars; the exact repr returned is DeviceArray(0, dtype=uint8)).
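A small sketch of my reading of the JAX rules (based on their design note and the repr above; exact reprs and default widths may vary with the JAX version and x64 settings):

    import jax.numpy as jnp

    # A JAX uint8 scalar is strongly typed: arithmetic stays uint8 and wraps.
    jnp.uint8(255) + 1                       # -> DeviceArray(0, dtype=uint8)

    # Python scalars are the weak ones; the array operand's dtype wins:
    jnp.array([255], dtype=jnp.uint8) + 1    # stays uint8 and wraps to 0

    # Weak typing appears when a bare Python scalar is lifted to an array:
    jnp.asarray(1)                           # int32, with weak_type=True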
We have also increasingly encountered users surprised/annoyed that scikit-image blew up their uint8 to a float64, using 8x the RAM.
I know this used to be true, but my sense is that it is less and less so, especially now that almost all skimage functions use floatx internally.
Greg spent a long time last year making sure that we didn't promote float32 to float64 for exactly this reason. That has reduced some of the burden, but not all of it, and my point is broader: users will not be happy to have uint8 + Python int return an int64 array implicitly (see the sketch after this message for the memory cost). And to quote from the JAX document, which to me seems to be the nail in the coffin for the alternatives:
Juan.
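To make the memory cost concrete, a minimal sketch (the image size is invented):

    import numpy as np

    img = np.zeros((4096, 4096), dtype=np.uint8)
    img.nbytes                           # 16_777_216 bytes (~16 MiB)

    # If uint8 + Python int implicitly returned int64, the same
    # operation would silently allocate eight times as much:
    (img.astype(np.int64) + 1).nbytes    # 134_217_728 bytes (~128 MiB)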
