
On Tue, 2021-01-26 at 06:11 +0100, Ralf Gommers wrote:
On Tue, Jan 26, 2021 at 2:01 AM Sebastian Berg <sebastian@sipsolutions.net> wrote:
Hi all,
Does anyone have a thought about how user DTypes (i.e. DTypes not currently part of NumPy) should interact with the "value based promotion" logic we currently have? For now I can just do anything, and we will find out later. But I will have to do something for now, basically in the hope that it all turns out all right.
But there are multiple options for both what to offer to user DTypes and where we want to move (I am using `bfloat16` as a potential DType here).
1. The "weak" dtype option (this is what JAX does), where:
np.array([1], dtype=bfloat16) + 4.
returns a bfloat16, because 4. is "lower" than all floating point types. In this scheme the user-defined `bfloat16` knows that the input is a Python float, but it does not know its value (if an overflow occurs during conversion, it could warn or error, but not upcast). For example, `np.array([1], dtype=uint4) + 2**5` will try `uint4(2**5)`, assuming it works. NumPy is different: `2.**300` would ensure the result is a `float64`.
If a DType does not make use of this, it would get the behaviour of option 2.
2. The "default" DType option: `np.array([1], dtype=bfloat16) + 4.` is always the same as `bfloat16 + float64 -> float64`.
3. Use whatever NumPy considers the "smallest appropriate dtype". This will not always work correctly for unsigned integers, and for floats this would be float16, which doesn't help with bfloat16.
4. Try to expose the actual value. (I do not want to do this, but it is probably a plausible extension with most other options, since the other options can be the "default".)
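To make options 1 and 2 concrete, here is a toy sketch in plain Python (no NumPy involved). Dtypes are just strings, `bfloat16` and `uint4` are the hypothetical user DTypes from the message, and the kind ordering, bit widths and helper names (`promote_weak`, `promote_default`, `convert_scalar`) are illustrative assumptions, not NumPy's or JAX's actual machinery.

```python
# Toy model only: dtype names are strings, "bfloat16" and "uint4" are
# hypothetical user DTypes, and the promotion rules are simplified.
PY_KIND = {bool: "bool", int: "integer", float: "floating", complex: "complexfloating"}
KIND_ORDER = ["bool", "integer", "floating", "complexfloating"]

# (kind, bits) for each toy dtype.
DTYPES = {
    "uint4": ("integer", 4),
    "int64": ("integer", 64),
    "bfloat16": ("floating", 16),
    "float64": ("floating", 64),
    "complex128": ("complexfloating", 128),
}
# Dtype a Python scalar of each kind turns into when it is treated as typed.
DEFAULT_FOR_KIND = {"integer": "int64", "floating": "float64", "complexfloating": "complex128"}


def promote_typed(a, b):
    """Two typed dtypes: the higher kind wins, then the wider dtype (toy rule)."""
    (kind_a, bits_a), (kind_b, bits_b) = DTYPES[a], DTYPES[b]
    rank_a, rank_b = KIND_ORDER.index(kind_a), KIND_ORDER.index(kind_b)
    if rank_a != rank_b:
        return a if rank_a > rank_b else b
    return a if bits_a >= bits_b else b


def promote_weak(array_dtype, scalar):
    """Option 1: a Python scalar contributes only its kind, never its value."""
    scalar_kind = PY_KIND[type(scalar)]
    if KIND_ORDER.index(scalar_kind) <= KIND_ORDER.index(DTYPES[array_dtype][0]):
        return array_dtype
    # A scalar of a higher kind (e.g. complex vs. a float array) uses that kind's default.
    return DEFAULT_FOR_KIND[scalar_kind]


def promote_default(array_dtype, scalar):
    """Option 2: a Python scalar always behaves like its default dtype."""
    return promote_typed(array_dtype, DEFAULT_FOR_KIND[PY_KIND[type(scalar)]])


def convert_scalar(dtype, value):
    """Option 1 then converts the value to the result dtype: it may fail, but never upcasts."""
    if dtype == "uint4" and not 0 <= value <= 15:
        raise OverflowError(f"{value!r} does not fit into uint4")
    return value
```

With these definitions `promote_weak("bfloat16", 4.)` returns `"bfloat16"` while `promote_default("bfloat16", 4.)` returns `"float64"`. For `uint4`, `promote_weak("uint4", 2**5)` keeps `"uint4"`, and the subsequent `convert_scalar("uint4", 2**5)` raises `OverflowError` instead of upcasting. Note that a weak `complex` still ends up at the default complex dtype for a `bfloat16` array.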
Within these options, there is one more difficulty. NumPy currently applies the same logic for:
np.array([1], dtype=bfloat16) + np.array(4., dtype=np.float64)
which in my opinion is wrong (the second array is typed). We do have the same issue with deciding what to do in the future for NumPy itself. Right now I feel that new (user) DTypes should live in the future (whatever that future is).
I agree. And I have a preference for option 1. Option 2 is too greedy in upcasting, the value-based casting is problematic in multiple ways (e.g., hard for Numba because the output dtype cannot be predicted from the input dtypes), and for option 4 it is hard to see a rationale (maybe so the user dtype itself can implement option 3?).
Yes, well, the "rationale" for option 4 is that you expose everything that NumPy currently needs (assuming we make no changes). That would be the only way to allow a `bfloat16` to work exactly like a `float16` as currently defined in NumPy. To be clear: it horrifies me, but defining a "better" way is much easier than trying to keep everything as is (at least for now) while also thinking about how it should look in the future (and making sure that user DTypes are ready for that future). My guess is that we can agree on aiming for option 1 and trying to limit it to Python operators. Unfortunately, only time will tell how feasible that will actually be.
I have said previously that we could distinguish this for universal functions. But calls like `np.asarray(4.)` are common, and they would lose the information that `4.` was originally a Python float.
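The typed-vs-untyped distinction (a Python float vs. a typed 0-D `float64` array) and the information lost through `np.asarray(4.)` can be sketched with another toy model. It illustrates the rule argued for here (only genuine Python scalars are weak), not NumPy's behaviour at the time; `ToyArray`, `toy_asarray`, `result_dtype` and the promotion table are hypothetical names.

```python
from dataclasses import dataclass

PY_SCALARS = (bool, int, float, complex)


@dataclass(frozen=True)
class ToyArray:
    """Stand-in for an ndarray: it only records a dtype name and ndim."""
    dtype: str
    ndim: int = 1


# Toy promotion for typed operands; "bfloat16" is a hypothetical user DType.
TYPED_PROMOTION = {
    frozenset({"bfloat16"}): "bfloat16",
    frozenset({"float64"}): "float64",
    frozenset({"bfloat16", "float64"}): "float64",
}


def toy_asarray(obj):
    """Mimics np.asarray(4.): the Python float becomes a typed 0-D float64 array."""
    if isinstance(obj, ToyArray):
        return obj
    if isinstance(obj, float):
        return ToyArray("float64", ndim=0)
    raise TypeError(f"unsupported input: {obj!r}")


def result_dtype(arr, other):
    """Only genuine Python scalars are weak; any array, even 0-D, is typed."""
    if isinstance(other, PY_SCALARS):
        return arr.dtype  # weak Python float defers to the array dtype
    return TYPED_PROMOTION[frozenset({arr.dtype, other.dtype})]
```

For `arr = ToyArray("bfloat16")`, `result_dtype(arr, 4.)` gives `"bfloat16"`, whereas both `result_dtype(arr, ToyArray("float64", ndim=0))` and `result_dtype(arr, toy_asarray(4.))` give `"float64"`: after the `asarray`-like call nothing records that the value started out as a Python float.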
Hopefully the future will have way fewer asarray calls in it. Rejecting scalar input to functions would be nice. This is what most other array/tensor libraries do.
Well, right now NumPy has scalars (both ours and Python's), and I would expect that changing that may well be more disruptive than changing the value-based promotion (assuming we can add good FutureWarnings). I would probably need a bit of convincing that forbidding `np.add(array, 2)` is worth the trouble, but luckily that is probably an orthogonal question. (The fact that we even accept 0-D arrays as "value based" is probably the biggest difficulty.)
So, recently, I was considering that a better option may be to limit this to the Python math operators: +, -, /, **, ...
+1
This discussion may be relevant: https://github.com/data-apis/array-api/issues/14.
I have browsed through it. I guess you were also thinking of limiting scalars to operators (although possibly even more broadly, rather than just for promotion purposes). I am not sure I understand this sentence: "Non-array ("scalar") operands are not permitted to participate in type promotion." They do participate, both in JAX and in what I wrote here. They just participate in an abstract way, i.e. as `Floating` or `Integer`, but not as a specific float or integer.
Those are the places where it may make a difference to write:
arr + 4.          vs.  arr + bfloat16(4.)
int8_arr + 1      vs.  int8_arr + np.int8(1)
arr += 4.         (in-place may be the most significant use-case)
while:
np.add(int8_arr, 1) vs. np.add(int8_arr, np.int8(1))
is maybe less significant. On the other hand, it would add a subtle difference between operators and direct ufunc calls...
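The "limit it to Python operators" idea, and the subtle operator vs. ufunc difference it introduces, can be sketched as follows. This is a toy model: `OpArray`, `toy_add` and the promotion table are hypothetical, the weak path assumes the scalar's kind is not higher than the array's, and `int64` as the default for a Python int is a simplifying choice (NumPy's default integer has been platform dependent).

```python
from dataclasses import dataclass

PY_SCALARS = (bool, int, float, complex)
# Default dtype a Python int becomes on its own (toy choice).
DEFAULT_FOR_PY = {int: "int64"}

# Toy promotion for typed operands; only the pairs used below.
TYPED_PROMOTION = {
    frozenset({"int8"}): "int8",
    frozenset({"int64"}): "int64",
    frozenset({"int8", "int64"}): "int64",
}


@dataclass
class OpArray:
    """Toy array in which only the Python operator treats scalars as weak."""
    dtype: str

    def __add__(self, other):
        if isinstance(other, PY_SCALARS):
            # Weak scalar: keep the array dtype (assumes the scalar's kind is not higher).
            return OpArray(self.dtype)
        return OpArray(TYPED_PROMOTION[frozenset({self.dtype, other.dtype})])


def toy_add(a, b):
    """Stand-in for a direct ufunc call: scalars are converted first and lose weakness."""
    if isinstance(b, PY_SCALARS):
        b = OpArray(DEFAULT_FOR_PY[type(b)])
    return a + b
```

Here `(OpArray("int8") + 1).dtype` is `"int8"`, but `toy_add(OpArray("int8"), 1).dtype` is `"int64"`; with an explicitly typed operand (the `np.int8(1)` analogue, `OpArray("int8")`) both spellings agree on `"int8"`.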
In general, it may not matter: we can choose option 1 (which the bfloat16 does not have to use), and modify it if we ever change the logic in NumPy itself. Basically, I will probably pick option 1 for now and press on, and we can reconsider later. And hope that it does not make things even more complicated than they are now.
Or maybe better just limit it completely to always use the default for user DTypes?
I'm not sure I understand why you like option 1 but want to give user-defined dtypes the choice of opting out of it. Upcasting will rarely make sense for user-defined dtypes anyway.
I never meant this as an opt-out; the question is what to do if the user DType does not opt in to (define) the operation. Basically, we would promote with `Floating` here (or `PyFloating`, but there should be no difference; for now I will use `PyFloating`, but it should probably be changed later). I was hinting at providing a default fallback, so that if:

UserDType + Floating -> Undefined/Error

we automatically try the "default", e.g.:

UserDType + Float64 -> Something

That would mean users don't have to worry about `Floating` itself. But I am not opinionated here; a user DType author should be able to quickly deal with either issue (that Float64 is undesired, or that the error is undesired if no "default" exists). Maybe the error is the more conservative/constructive choice, though.
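A minimal sketch of the two choices for a user DType that does not define promotion with the abstract `Floating`: fall back to the concrete default (`Float64`), or raise. The rule tables, `PromotionError`, `ABSTRACT_DEFAULT` and `promote` are hypothetical names for illustration, not a proposed NumPy API.

```python
class PromotionError(TypeError):
    """Raised when no promotion rule applies (hypothetical)."""


# Default concrete DType for each abstract Python-scalar category (labels only).
ABSTRACT_DEFAULT = {"Floating": "Float64", "Integer": "Int64"}


def promote(user_dtype, other, rules, fallback_to_default=True):
    """Look up `user_dtype + other`; optionally retry an abstract `other` with its default."""
    if (user_dtype, other) in rules:
        return rules[(user_dtype, other)]
    if fallback_to_default and other in ABSTRACT_DEFAULT:
        # UserDType + Floating is undefined: try UserDType + Float64 instead.
        return promote(user_dtype, ABSTRACT_DEFAULT[other], rules, fallback_to_default=False)
    raise PromotionError(f"no promotion defined for {user_dtype} + {other}")
```

With `rules = {("UserDType", "Float64"): "Float64"}`, `promote("UserDType", "Floating", rules)` returns `"Float64"` through the fallback, while `fallback_to_default=False` raises `PromotionError` (the more conservative option). A DType that opts in, e.g. with `{("UserDType", "Floating"): "UserDType"}`, never reaches the fallback.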
But I would be interested in whether "limit to Python operators" is something we should aim for here. This does make a small difference, because user DTypes could "live" in the future if we have an idea of what that future may look like.
A future with:
- no array scalars
- 0-D arrays have the same casting rules as >=1-D arrays
- no value-based casting
would be quite nice.
I don't think array-scalars really matter here, since they are typed and behave identically to 0-D arrays anyway. We can have long opinion pieces on whether they should exist :).
For "same kind" casting like https://data-apis.github.io/array-api/latest/API_specification/type_promotio... . Mixed-kind casting isn't specified there, because it's too different between libraries. The JAX design (https://jax.readthedocs.io/en/latest/type_promotion.html) seems sensible there.
The JAX design is the "weak DType" design (when it comes to Python numbers). Although, the fact that a "weak" `complex` is sorted above all floats means that `bfloat16_arr + 1j` will go to the default complex dtype as well. But yes, I like the "weak" approach; I just think JAX also has some wrinkles to smooth out.

There is a good deal more to this once you get user DTypes, and I add one more important constraint: that

from my_extension_module import uint24

must not change any existing code that does not explicitly use `uint24`. Then my current approach guarantees:

np.result_type(uint24, int48, int64) -> Error

if `uint24` and `int48` do not know each other (`int64` is obviously right here, but it is tricky to be quite certain).

The other tricky example I have is the following, which becomes problematic (order does not matter):

uint24 + int16 + uint32 -> int64
<== (uint24 + int16) + (uint24 + uint32) -> int64
<== int32 + uint32 -> int64

With the addition that `uint24 + int32 -> int48` is defined, the first could be expected to return `int48`, but actually getting there is tricky (and my current code will not). If the promotion result of a user DType with a builtin one can itself be a builtin one, then "amending" the promotion with rules like `uint24 + int32 -> int48` can lead to slightly surprising promotion results. This happens if the result of a promotion with another "category" (builtin) can be either a larger category or a lower one.

- Sebastian
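The reduction problem in the `uint24 + int16 + uint32` example can be reproduced with a toy model. Assumptions: integer dtypes are modelled only by value range, "smallest dtype whose range covers both operands" stands in for the real promotion rules, `uint24`/`int48` are the hypothetical user DTypes, and pairs of builtins never consider user DTypes (the import constraint above). The sketch does not encode the specific `uint24 + int32 -> int48` rule; it only shows that pairwise reduction ends at `int64` while the smallest type holding all three operands is `int48`.

```python
from functools import reduce
from itertools import permutations

# Integer dtypes as (signed, bits). uint24 and int48 stand in for hypothetical
# user DTypes from an extension module; NumPy itself knows only BUILTIN.
BUILTIN = {
    "int8": (True, 8), "int16": (True, 16), "int32": (True, 32), "int64": (True, 64),
    "uint8": (False, 8), "uint16": (False, 16), "uint32": (False, 32), "uint64": (False, 64),
}
USER = {"uint24": (False, 24), "int48": (True, 48)}
ALL = {**BUILTIN, **USER}


def value_range(name):
    signed, bits = ALL[name]
    if signed:
        return -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    return 0, 2 ** bits - 1


def smallest_holding(names, candidates):
    """Smallest candidate whose value range covers every dtype in `names`."""
    lo = min(value_range(n)[0] for n in names)
    hi = max(value_range(n)[1] for n in names)
    fits = [c for c in candidates if value_range(c)[0] <= lo and value_range(c)[1] >= hi]
    if not fits:
        raise TypeError(f"no common dtype for {names}")
    return min(fits, key=lambda c: ALL[c][1])


def pairwise(a, b):
    """Builtin pairs never see user DTypes, so importing uint24 cannot change them."""
    if a in BUILTIN and b in BUILTIN:
        return smallest_holding([a, b], BUILTIN)
    return smallest_holding([a, b], ALL)  # pairs with a user DType may use user results


def results_for_all_orders(names):
    """Set of pairwise-reduction results over every operand order."""
    return {reduce(pairwise, order) for order in permutations(names)}
```

In this model `pairwise("uint24", "int16")` is the builtin `"int32"`, after which `int32 + uint32` stays in builtin territory and gives `"int64"`; `results_for_all_orders(["uint24", "int16", "uint32"])` is `{"int64"}` for every order, whereas `smallest_holding(["uint24", "int16", "uint32"], ALL)` is `"int48"`.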
Cheers,
Ralf

_______________________________________________
NumPy-Discussion mailing list
NumPy-Discussion@python.org
https://mail.python.org/mailman/listinfo/numpy-discussion