Hypothetically, JAX could be written on top of a "restricted NumPy" instead, which in turn could have an implementation written in LAX. This would facilitate reusing JAX's higher level functions for automatic differentiation and vectorization on top of different array backends.
I would also be happy to see guidance for NumPy API re-implementers, both for those starting from scratch (e.g., in a new language) and for those who plan to copy NumPy's Python API (e.g., with __array_function__).
I would focus on:
1. Describing the tradeoffs of challenging design decisions that NumPy may have gotten wrong, e.g., scalars and indexing.
2. Describing common "gotchas" where it's easy to deviate from NumPy's semantics unintentionally, e.g., with scalar arithmetic dtypes or indexing edge cases.
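To make a few of these gotchas concrete, here is a small sketch. It assumes a reasonably recent NumPy, and the variable names are only for illustration. Promotion rules for NumPy scalars changed under NEP 50 in NumPy 2.0, so the dtype case below deliberately uses a Python int, where the result dtype is the same before and after that change:

```python
import numpy as np

x = np.arange(3.0)

# Integer indexing of a 1-D array yields a NumPy scalar, not a 0-d array.
# Adding an Ellipsis keeps the result an ndarray (0-d).
print(type(x[0]))                        # <class 'numpy.float64'>
print(type(x[0, ...]), x[0, ...].ndim)   # <class 'numpy.ndarray'> 0

# Basic (slice) indexing returns a view; advanced (integer-array) indexing copies.
view = x[1:3]
copy = x[[1, 2]]
view[0] = 100.0   # writes through to x
copy[1] = -1.0    # leaves x untouched
print(x.tolist())  # [0.0, 100.0, 2.0]

# Advanced indices separated by a slice put their broadcast dimension first.
y = np.zeros((2, 3, 4))
print(y[:, [0, 1], [0, 1]].shape)  # (2, 2)
print(y[[0, 1], :, [0, 1]].shape)  # (2, 3)

# A Python int does not upcast a small-integer array, even if the result
# overflows: the arithmetic stays in int8 and silently wraps around.
z = np.array([100, 120], dtype=np.int8)
print((z + 100).dtype)  # int8
```

Each of these is easy to get "wrong" in a re-implementation simply by choosing the more obvious behavior.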
I would *not* try to identify a "core" list of methods/functionality to implement. Everyone uses their own slice of NumPy's API, so the rational approach for anyone trying to reimplement NumPy's API exactly (i.e., with __array_function__) is to start with a minimal subset and add functionality on demand to meet users' needs. Also, many of the choices involved in making an array library don't really have objectively right or wrong answers, and authors are going to make intentional deviations from NumPy's semantics when it makes sense for them.
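One way to start with a minimal subset under __array_function__ is the registration pattern illustrated in NEP 18: keep a table of the NumPy functions you support and return NotImplemented for everything else, in which case NumPy raises a TypeError. This is a minimal sketch, not a full duck array; MyArray, HANDLED_FUNCTIONS and implements are hypothetical names:

```python
import numpy as np

# Hypothetical registry mapping NumPy functions to our implementations.
HANDLED_FUNCTIONS = {}


def implements(np_function):
    """Register an implementation of np_function for MyArray."""
    def decorator(func):
        HANDLED_FUNCTIONS[np_function] = func
        return func
    return decorator


class MyArray:
    """Hypothetical duck array wrapping an ndarray."""

    def __init__(self, data):
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        # Anything not registered yet is reported as unsupported.
        if func not in HANDLED_FUNCTIONS:
            return NotImplemented
        # Only handle calls where every overriding type is ours.
        if not all(issubclass(t, MyArray) for t in types):
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


@implements(np.sum)
def _sum(a, axis=None):
    return MyArray(np.sum(a.data, axis=axis))


@implements(np.concatenate)
def _concatenate(arrays, axis=0):
    return MyArray(np.concatenate([arr.data for arr in arrays], axis=axis))


a = MyArray([1, 2, 3])
print(np.sum(a).data)  # 6
try:
    np.mean(a)  # not registered, so NumPy raises TypeError
except TypeError as err:
    print("unsupported:", err)
```

Adding support for another function is then just one more @implements entry, added when a user actually needs it.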
Cheers,
Stephan