[Numpy-discussion] A little about XND
skrah at bytereef.org
Mon Jun 18 10:02:18 EDT 2018
On Sun, Jun 17, 2018 at 08:47:02PM -0400, Marten van Kerkwijk wrote:
> More of a detailed question, but as we are currently thinking about
> extending the signature of gufuncs (i.e., things like `(m,n),(n,p)->(m,p)`
> for matrix multiplication), and as you must have thought about this for
> libgufunc, could you point me to how one would document the signature in
> your new system? (I briefly tried but there's no docs yet and I couldn't
> immediately find it in the code).
The docs are a bit scattered across the three libraries; here is something
about types and pattern matching:
A couple of example signatures:
The function signature for float64-specialized matrix multiplication is:
"... * N * M * float64, ... * M * P * float64 -> ... * N * P * float64"
The function signature for generic matrix multiplication is:
"... * N * M * T, ... * M * P * T -> ... * N * P * T"
A function that only accepts scalars:
"... * N * M * Scalar, ... * M * P * Scalar -> ... * N * P * Scalar"
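To make the matching concrete, here is a toy matcher in Python. It is a sketch under stated assumptions, not libgumath's implementation. Concrete arguments are modelled as (shape, dtype) pairs. Upper-case dimension names such as N, M and P are dimension variables that must bind consistently across arguments. A dtype position holds one of three things: a concrete dtype, "Scalar" (any dtype from a hypothetical SCALAR_DTYPES set), or a type variable such as T. The names match_arg, match_signature and SCALAR_DTYPES are invented for illustration. The sketch also simplifies one point: outer dimensions under "..." must be identical, with no broadcasting.

```python
# Toy matcher for gumath-style signatures (an illustration, not libgumath code).
# Hypothetical set of concrete scalar dtypes known to this sketch.
SCALAR_DTYPES = {"int8", "int16", "int32", "int64", "float32", "float64"}


def match_arg(pattern, shape, dtype, env):
    """Match one pattern such as '... * N * M * T' against a concrete
    (shape, dtype), extending env. Returns the outer dims, or None."""
    *dims, dtype_pat = [tok.strip() for tok in pattern.split("*")]
    has_ellipsis = bool(dims) and dims[0] == "..."
    inner = dims[1:] if has_ellipsis else dims
    n_outer = len(shape) - len(inner)
    if n_outer < 0 or (n_outer > 0 and not has_ellipsis):
        return None  # too few dims, or outer dims without a leading '...'
    for sym, size in zip(inner, shape[n_outer:]):
        if sym.isdigit():
            if int(sym) != size:
                return None  # fixed dimension must match exactly
        elif env.setdefault(sym, size) != size:
            return None  # dimension variable bound to a different size
    if dtype_pat in SCALAR_DTYPES:
        ok = dtype == dtype_pat  # concrete dtype: exact match only
    elif dtype_pat == "Scalar":
        ok = dtype in SCALAR_DTYPES  # any scalar dtype
    else:
        ok = env.setdefault(dtype_pat, dtype) == dtype  # type variable such as T
    return tuple(shape[:n_outer]) if ok else None


def match_signature(sig, args):
    """Match 'in, in -> out' against (shape, dtype) inputs and return the
    inferred output (shape, dtype), or None if the kernel does not apply."""
    lhs, rhs = sig.split("->")
    patterns = [p.strip() for p in lhs.split(",")]
    if len(patterns) != len(args):
        return None
    env, outers = {}, set()
    for pat, (shape, dtype) in zip(patterns, args):
        outer = match_arg(pat, shape, dtype, env)
        if outer is None:
            return None
        outers.add(outer)
    if len(outers) > 1:
        return None  # simplification: outer dims must be identical (no broadcasting)
    outer = outers.pop() if outers else ()
    *out_dims, out_dtype = [tok.strip() for tok in rhs.split("*")]
    inner = tuple(int(d) if d.isdigit() else env[d] for d in out_dims if d != "...")
    # A 'Scalar' output is left unresolved; a type variable resolves via env.
    return outer + inner, env.get(out_dtype, out_dtype)


SIG_F64 = "... * N * M * float64, ... * M * P * float64 -> ... * N * P * float64"
SIG_T = "... * N * M * T, ... * M * P * T -> ... * N * P * T"

print(match_signature(SIG_F64, [((10, 2, 3), "float64"), ((10, 3, 4), "float64")]))
# ((10, 2, 4), 'float64')
print(match_signature(SIG_F64, [((2, 3), "int32"), ((3, 4), "int32")]))
# None
print(match_signature(SIG_T, [((2, 3), "int32"), ((3, 4), "int32")]))
# ((2, 4), 'int32')
print(match_signature(SIG_T, [((2, 3), "int32"), ((3, 4), "float64")]))
# None: T cannot be both int32 and float64
```

The type variable T is bound by the first argument. A second argument with a different dtype is therefore rejected, while the float64 signature rejects anything that is not float64.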
A couple of observations: Functions are multimethods, so function dispatch
on concrete arguments works by trying to locate a matching kernel.
For example, if only the above "float64" kernel is present, all other
dtypes will fail.
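Multimethod dispatch can be pictured as a table lookup keyed by the dtypes of the concrete arguments. The sketch below is a simplified stand-in for the real mechanism: it matches on exact dtype tuples only and skips the dimension matching entirely. KERNELS, kernel and dispatch are hypothetical names, and matrices are plain nested lists. With only the float64 kernel registered, an int64 call finds no match and fails.

```python
from typing import Callable, Dict, List, Tuple

Matrix = List[List[float]]

# Hypothetical kernel table: exact tuple of input dtype names -> kernel.
KERNELS: Dict[Tuple[str, ...], Callable[..., Matrix]] = {}


def kernel(*dtypes: str):
    """Decorator registering a kernel for exactly these input dtypes."""
    def register(fn):
        KERNELS[dtypes] = fn
        return fn
    return register


@kernel("float64", "float64")
def matmul_float64(a: Matrix, b: Matrix) -> Matrix:
    # Plain row-by-column product on nested lists.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def dispatch(*args: Tuple[str, Matrix]) -> Matrix:
    """Look up the kernel whose signature matches the argument dtypes."""
    key = tuple(dtype for dtype, _ in args)
    fn = KERNELS.get(key)
    if fn is None:
        raise TypeError(f"no kernel for input dtypes {key}")
    return fn(*(value for _, value in args))


print(dispatch(("float64", [[1.0, 2.0]]), ("float64", [[3.0], [4.0]])))  # [[11.0]]
try:
    dispatch(("int64", [[1, 2]]), ("int64", [[3], [4]]))
except TypeError as exc:
    print(exc)  # no kernel for input dtypes ('int64', 'int64')
```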
It is still under debate how we handle casting. The current examples in
libgumath/kernels simply generate *all* signatures that allow exact
casting of the inputs for a specific function.
This is feasible for unary and binary kernels, but could lead to a
combinatorial explosion of signatures for functions with many arguments.
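A quick count shows why the signatures explode. Suppose k input dtypes cast exactly to a kernel's dtype. Generating every exact-cast signature for an n-ary kernel then yields k**n input signatures. The cast table below is a hypothetical, simplified subset chosen only for illustration. It deliberately leaves out int64 -> float64, which is not exact.

```python
from itertools import product

# Hypothetical exact-cast table: for each kernel dtype, the input dtypes that
# cast to it without loss (a simplified subset, for illustration only).
EXACT_CASTS = {
    "int64": ["int8", "int16", "int32", "int64"],
    "float64": ["int8", "int16", "int32", "float32", "float64"],
}


def generated_signatures(kernel_dtype, nargs):
    """Every input dtype tuple that casts exactly to an n-ary kernel."""
    return list(product(EXACT_CASTS[kernel_dtype], repeat=nargs))


for nargs in (1, 2, 3, 4):
    print(nargs, len(generated_signatures("float64", nargs)))
# 1 5
# 2 25
# 3 125
# 4 625
```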
The kernel writer, however, is always free to use the above type variable
or Scalar signatures and handle casting inside the kernel.
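The alternative is a single kernel registered under a Scalar-style signature that handles mixed inputs itself. The Python sketch below illustrates that approach. It picks a common dtype from a hypothetical linear promotion order and casts both inputs before computing the product. PROMOTION_ORDER, common_dtype and matmul_scalar are invented names. A real type system would need a proper promotion lattice and fixed-width overflow handling, and this toy models neither.

```python
# Hypothetical promotion order; a real type system needs a proper lattice.
PROMOTION_ORDER = ["int8", "int16", "int32", "int64", "float64"]
PY_TYPE = {"int8": int, "int16": int, "int32": int, "int64": int, "float64": float}


def common_dtype(*dtypes):
    """Pick the 'largest' dtype in the toy promotion order."""
    return max(dtypes, key=PROMOTION_ORDER.index)


def matmul_scalar(a, b):
    """A single Scalar-style kernel that casts its inputs internally.

    a, b: (dtype, nested-list matrix) pairs. Returns (dtype, result).
    Overflow of fixed-width integers is not modelled here.
    """
    (da, va), (db, vb) = a, b
    dt = common_dtype(da, db)
    cast = PY_TYPE[dt]
    va = [[cast(x) for x in row] for row in va]
    vb = [[cast(x) for x in row] for row in vb]
    out = [[sum(x * y for x, y in zip(row, col)) for col in zip(*vb)] for row in va]
    return dt, out


print(matmul_scalar(("int32", [[1, 2]]), ("float64", [[0.5], [0.25]])))
# ('float64', [[1.0]])
print(matmul_scalar(("int8", [[1, 2]]), ("int8", [[3], [4]])))
# ('int8', [[11]])
```

The trade-off is that the table needs only one entry per kernel, but the casting rules move from the dispatcher into each kernel.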
Gufuncs are explicit and require leading ellipses. A signature of
"N * M * float64" is not a gufunc and does not allow outer dimensions.
Outer dimensions can be constrained further with a named ellipsis:
"D... * N * M * float64, D... * M * P * float64 -> D... * N * P * float64"
A named ellipsis such as "D..." matches a sequence of dimensions, so in the
above example all outer dimensions must be exactly the same.
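The dimension rules can be sketched with a small matcher that looks only at shapes. Each argument pattern is a token list whose optional first token is an ellipsis. A plain "..." accepts any outer dimensions. A named ellipsis like "D..." binds the whole outer tuple as one variable, so every argument must repeat it exactly. With no leading ellipsis, outer dimensions are rejected altogether. match_dims and the token-list encoding are hypothetical; this is an illustration of the rules stated above, not gumath code.

```python
def match_dims(patterns, shapes):
    """Toy matcher for the dimension part of gufunc argument patterns.

    patterns: one token list per argument, e.g. ["D...", "N", "M"]
    shapes:   one concrete shape tuple per argument
    Returns the variable bindings, or None if any argument fails.
    """
    if len(patterns) != len(shapes):
        return None
    env = {}
    for dims, shape in zip(patterns, shapes):
        lead = dims[0] if dims and dims[0].endswith("...") else None
        inner = dims[1:] if lead else dims
        n_outer = len(shape) - len(inner)
        if n_outer < 0 or (n_outer > 0 and lead is None):
            return None  # without a leading ellipsis, no outer dims are allowed
        if lead is not None and lead != "...":
            outer = tuple(shape[:n_outer])
            if env.setdefault(lead, outer) != outer:
                return None  # named ellipsis: outer dims must be identical
        for sym, size in zip(inner, shape[n_outer:]):
            if sym.isdigit():
                if int(sym) != size:
                    return None  # fixed dimension must match exactly
            elif env.setdefault(sym, size) != size:
                return None  # dimension variable bound to a different size
    return env


MATMUL = [["D...", "N", "M"], ["D...", "M", "P"]]
print(match_dims(MATMUL, [(5, 2, 3), (5, 3, 4)]))  # {'D...': (5,), 'N': 2, 'M': 3, 'P': 4}
print(match_dims(MATMUL, [(5, 2, 3), (6, 3, 4)]))  # None: outer dims differ
print(match_dims([["N", "M"]], [(5, 2, 3)]))  # None: no leading ellipsis
print(match_dims([["...", "2", "3"]], [(7, 2, 3)]))  # {}
print(match_dims([["...", "2", "3"]], [(7, 3, 2)]))  # None: fixed dims differ
```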
"... * 2 * 3 * int8" only accepts "2 * 3 * int8" as the inner dimensions.
Sorry for the long mail; I hope this clarifies what function signatures
generally look like.