Hi all,
does anyone have thoughts on how user DTypes (i.e. DTypes that are not
currently part of NumPy) should interact with the "value-based
promotion" logic we currently have?
For now I could do just about anything, and we will find out later. I
do have to pick something now, basically hoping that it all turns out
all right.
But there are multiple options, both for what to offer user DTypes and
for where we want to move in the future (I am using `bfloat16` as an
example DType here).
1. The "weak" dtype option (this is what JAX does), where:
       np.array([1], dtype=bfloat16) + 4.
   returns a bfloat16, because `4.` is "lower" than all floating-point
   types.
   In this scheme, the user-defined `bfloat16` knows that the input is a
   Python float, but it does not know its value (if an overflow occurs
   during conversion, it could warn or error, but not upcast). For
   example, `np.array([1], dtype=uint4) + 2**5` will try `uint4(2**5)`,
   assuming it works, even though 32 does not fit into a 4-bit unsigned
   integer.
   NumPy is different: `2.**300` would ensure the result is a `float64`,
   because the value overflows `float16` and `float32`.
   If a DType does not make use of this, it would get the behaviour of
   option 2.
2. The "default" DType option: `np.array([1], dtype=bfloat16) + 4.` is
   always the same as `bfloat16 + float64 -> float64`.
3. Use whatever NumPy considers the "smallest appropriate dtype".
   This will not always work correctly for unsigned integers (NumPy
   considers `uint8` the smallest dtype for a positive integer such as
   `1`), and for floats this would be `float16`, which doesn't help with
   `bfloat16`, since neither of the two can hold all values of the other.
4. Try to expose the actual value. (I do not want to do this, but it is
   probably a plausible extension to most of the other options, since
   those can serve as the "default".)
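
To make the differences concrete, here is a rough sketch of what options 1-3 would mean for the result dtype of `arr + scalar` with a Python scalar. It is only a toy model under stated assumptions, not how NumPy implements promotion: NumPy has no `bfloat16`, so `float16` stands in for the user DType; the function names are hypothetical; the default integer is assumed to be `int64`; and option 4 is not modelled. Option 3 uses `np.min_scalar_type`, which returns the smallest dtype NumPy considers able to hold a value.

```python
import numpy as np

# Hypothetical toy model: result dtype of `arr + scalar` for a Python
# int/float `scalar` under options 1-3. float16 stands in for bfloat16,
# and NumPy's default integer is assumed to be int64.


def weak_result(arr_dtype, scalar):
    """Option 1: the Python scalar is "weak" and takes the array's dtype,
    unless a float meets a non-inexact array (then fall back to option 2)."""
    if isinstance(scalar, float) and arr_dtype.kind not in "fc":
        return default_result(arr_dtype, scalar)
    return arr_dtype


def default_result(arr_dtype, scalar):
    """Option 2: treat the scalar as NumPy's default dtype for its type."""
    default = np.float64 if isinstance(scalar, float) else np.int64
    return np.result_type(arr_dtype, default)


def smallest_result(arr_dtype, scalar):
    """Option 3: use the smallest dtype NumPy considers able to hold the value."""
    return np.result_type(arr_dtype, np.min_scalar_type(scalar))


f16 = np.dtype(np.float16)
i8 = np.dtype(np.int8)
for scalar in (4., 2.**300):
    print(scalar, weak_result(f16, scalar), default_result(f16, scalar),
          smallest_result(f16, scalar))
print(1, weak_result(i8, 1), default_result(i8, 1), smallest_result(i8, 1))
```

Under option 3, `int8 + 1` comes out as `int16`, because `np.min_scalar_type(1)` is `uint8`; that is presumably the kind of unsigned-integer problem meant above. For `4.` it picks `float16`, and `2.**300` ends up as `float64` under options 2 and 3 but stays `float16` (and would overflow on conversion) under option 1.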

Whichever option we choose, there is one more difficulty. NumPy currently
applies the same value-based logic to:
    np.array([1], dtype=bfloat16) + np.array(4., dtype=np.float64)
which in my opinion is wrong, since the second array is explicitly typed.
We face the same question when deciding what NumPy itself should do in
the future. Right now I feel that new (user) DTypes should already live
in that future (whatever it turns out to be).

I have said previously that we could make this distinction for universal
functions. But calls like `np.asarray(4.)` are common, and they would
lose the information that `4.` was originally a Python float.
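
A minimal sketch of why `np.asarray(4.)` is a problem: a rule that treats only genuine Python scalars as "weak" (the helper `is_weak` is hypothetical, not a NumPy function) can no longer recognize the value once it has been converted, because the result is just a 0-D `float64` array, the same as an explicitly typed one.

```python
import numpy as np


def is_weak(operand):
    """Hypothetical rule: only genuine Python int/float objects are "weak".
    An exact type check is needed because np.float64 subclasses float."""
    return type(operand) in (int, float)


python_float = 4.
converted = np.asarray(4.)
typed = np.array(4., dtype=np.float64)

print(is_weak(python_float))      # True
print(is_weak(np.float64(4.)))    # False, although isinstance(..., float) is True
print(is_weak(converted))         # False: it is now a 0-D float64 array
# Nothing is left to distinguish the converted value from the typed array:
print(converted.dtype == typed.dtype, converted.shape == typed.shape)  # True True
```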

So recently I have been considering that a better option may be to limit
this to Python's math operators: `+`, `-`, `/`, `**`, ...
Those are the places where it may make a difference whether one writes:
    arr + 4.         vs.  arr + bfloat16(4.)
    int8_arr + 1     vs.  int8_arr + np.int8(1)
    arr += 4.        (in-place may be the most significant use case)
while:
    np.add(int8_arr, 1)  vs.  np.add(int8_arr, np.int8(1))
is maybe less significant. On the other hand, it would add a subtle
difference between operators and direct ufunc calls...
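
Limiting weak promotion to operators could look roughly like the following sketch, where `__add__`/`__iadd__` treat Python scalars as weak (option 1) while a direct function call gives them the default dtype (option 2). `OperatorWeakArray`, `ufunc_add` and `_is_weak_for` are hypothetical names used only to illustrate the split; this is not how `ndarray` would implement it.

```python
import numpy as np


def _is_weak_for(dtype, other):
    # Hypothetical rule (option 1): a Python int is weak for numeric arrays,
    # a Python float only for float/complex arrays.
    if type(other) is int:
        return dtype.kind in "iufc"
    if type(other) is float:
        return dtype.kind in "fc"
    return False


class OperatorWeakArray:
    """Hypothetical wrapper in which Python scalars are weak only in operators."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def _convert(self, other):
        if _is_weak_for(self.data.dtype, other):
            # Cast the scalar to the array's own dtype instead of promoting.
            # An out-of-range value may warn or raise, depending on the NumPy version.
            return np.asarray(other, dtype=self.data.dtype)
        return other

    def __add__(self, other):
        return OperatorWeakArray(np.add(self.data, self._convert(other)))

    def __iadd__(self, other):
        np.add(self.data, self._convert(other), out=self.data)
        return self


def ufunc_add(arr, other):
    """Hypothetical direct-call path: Python scalars get the default dtype (option 2)."""
    other = np.asarray(other)
    return np.add(arr, other, dtype=np.result_type(arr.dtype, other.dtype))


int8_arr = OperatorWeakArray(np.array([1, 2], dtype=np.int8))
print((int8_arr + 1).data.dtype)            # int8
int8_arr += 1
print(int8_arr.data)                        # [2 3], still int8
print(ufunc_add(int8_arr.data, 1).dtype)    # the default integer dtype, not int8
```

Here `int8_arr + 1` and `int8_arr += 1` keep `int8`, while `ufunc_add(int8_arr.data, 1)` does not; that divergence is the subtle operator-vs-ufunc difference described above.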

In general, it may not matter: we can choose option 1 (which bfloat16
does not have to use) and modify it if we ever change the logic in NumPy
itself. Basically, I will probably pick option 1 for now, press on, and
reconsider later, hoping that it does not make things even more
complicated than they already are.
Or would it be better to limit it completely and always use the default
for user DTypes?

But I would be interested to hear whether "limit to Python operators" is
something we should aim for here. It does make a small difference,
because user DTypes could already "live" in the future if we have an idea
of what that future may look like.
Cheers,
Sebastian