On Wed, 2021-09-29 at 15:09 -0400, Aaron Watters wrote:
Hi folks!
The np.choose function raises a ValueError if called with more than 31 choices. This PR adds an alternate implementation np.extended_choose (which uses the base implementation) that supports any number of choices.
https://github.com/numpy/numpy/pull/20001
FYI, I needed this functionality for a mouse embryo microscopy tool I'm building. I'm attempting to contribute it because I thought it might be generally useful.
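To make the idea of building on the base implementation concrete, here is a minimal sketch of one way a wrapper could lift the limit by calling `np.choose` on blocks of at most 31 choices. It is not the code from the PR: the name `chunked_choose` and its `chunk_size` parameter are hypothetical, and the default of 31 is simply the limit quoted above.

```python
from functools import reduce

import numpy as np


def chunked_choose(a, choices, chunk_size=31):
    """Hypothetical helper: np.choose semantics for any number of choices."""
    a = np.asarray(a)
    if not np.issubdtype(a.dtype, np.integer):
        raise TypeError("index array must have an integer dtype")
    a = a.astype(np.intp)
    choices = [np.asarray(c) for c in choices]
    if not choices:
        raise ValueError("need at least one choice")
    if a.size and (a.min() < 0 or a.max() >= len(choices)):
        raise ValueError("invalid entry in index array")

    # Common dtype of all choices, computed pairwise.
    out_dtype = reduce(np.promote_types, (c.dtype for c in choices))

    result = None
    for lo in range(0, len(choices), chunk_size):
        block = choices[lo:lo + chunk_size]
        # Elements whose index falls inside this block of choices.
        in_block = (a >= lo) & (a < lo + len(block))
        # Index relative to the block; 0 is a harmless placeholder elsewhere.
        local = np.where(in_block, a - lo, 0)
        picked = np.choose(local, block)
        if result is None:
            result = picked.astype(out_dtype)
        else:
            result = np.where(in_block, picked, result)
    return result


choices = [np.full((4, 5), i * 10) for i in range(100)]
a = np.arange(20).reshape(4, 5) * 5
out = chunked_choose(a, choices)  # out[i, j] == choices[a[i, j]][i, j]
```

Each block costs a full pass over the index array, so this sketch scales linearly with the number of blocks; it is meant to illustrate the wrapper idea, not to be fast.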
Thanks for the effort of upstreaming your code! My inclination is against adding it, though. The limitation of `choose` to 32 arguments is unfortunate, but adding a new function as a workaround does not seem great, either. Maybe it would be possible to fix `choose` instead?

Unfortunately, it seems likely that the current `choose` code is great for few choices but bad for many, so that might require switching between different strategies.

Importantly: If your data (choices) is an array and not a sequence of arrays, you should use `np.take_along_axis` instead, which is far superior! For small to mid-sized arrays, it may even be fastest to use it with `np.asarray(choices)`, because it avoids many overheads.

Not happy with the idea of extending the way `choose` works to many choices, I cooked up the approach below. My expectation is that it should be much faster for many choices, at least for larger arrays. It moves the work of determining which element to pick from which choice into a (fairly involved) pre-processing step, to make the final assignment more streamlined.

Cheers,

Sebastian

```
from itertools import chain

import numpy as np


def choose(a, choices):
    # Make sure we work with the correct result shape.
    # (this is not great if `a` ends up being broadcast)
    a_bc, *choices = np.broadcast_arrays(a, *choices)
    a = a_bc.ravel()

    # Sort the flat selector so equal choice numbers form contiguous runs.
    sorter = np.argsort(a, axis=None)
    which = a[sorter]

    # Per-axis coordinates of every element, in the same sorted order.
    # indexing="ij" keeps the layout consistent with `a_bc.ravel()`.
    indices = np.meshgrid(*[np.arange(s) for s in a_bc.shape], indexing="ij")
    indices = [i.ravel()[sorter] for i in indices]

    out_dtype = np.result_type(*choices)
    result = np.empty(choices[0].shape, dtype=out_dtype)

    # `ends` holds the last position of every run except the final one.
    ends = np.flatnonzero(which[1:] != which[:-1])
    start = 0
    for end in chain(ends, [len(which) - 1]):
        end += 1
        choice = choices[which[start]]
        ind = tuple(i[start:end] for i in indices)
        result[ind] = choice[ind]
        start = end

    return result
```
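As an illustration of the `np.take_along_axis` suggestion, here is a minimal sketch for the case where the choices are already stacked into one array along axis 0. The helper name `choose_from_stack` is made up for this example, and it assumes the index array can be broadcast to the shape of a single choice.

```python
import numpy as np


def choose_from_stack(a, stacked):
    """Hypothetical helper: result[i, j, ...] = stacked[a[i, j, ...], i, j, ...]."""
    stacked = np.asarray(stacked)
    # Broadcast the selector to the shape of a single choice.
    a = np.broadcast_to(np.asarray(a), stacked.shape[1:])
    # take_along_axis needs matching ndim, so add a length-1 leading axis.
    picked = np.take_along_axis(stacked, a[np.newaxis, ...], axis=0)
    return picked[0]


stacked = np.arange(12).reshape(3, 2, 2)  # three 2x2 choices
a = np.array([[0, 1], [2, 0]])
print(choose_from_stack(a, stacked))
# [[ 0  5]
#  [10  3]]
```

For a list of equally shaped arrays, `choose_from_stack(a, np.asarray(choices))` corresponds to the `np.asarray(choices)` variant mentioned above, at the cost of copying the choices into one array.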
All comments, complaints, suggestions, or code reviews are appreciated.
thanks! -- Aaron Watters