[Python-ideas] overloading chained comparison

Nathaniel Smith njs at pobox.com
Tue Mar 18 22:48:27 CET 2014


Hi python-ideas,

Since Nick is being a killjoy and shutting down bikeshedding in the @
thread [1] until the numpy folks finish bikeshedding over it
themselves [2], I thought I'd throw out another idea for...
brainstorming [3].

Guido has suggested that while PEP 335 (overloading 'and', 'or',
'not') is rejected, he might be amenable to making chained comparisons
like a < b < c overloadable [4]. The idea is that right now,
  a < b < c
always expands to
  (a < b) and (b < c)
and the 'and' forces boolean coercion. When working with arrays, '<'
is often used elementwise, so we want the 'and' to apply elementwise
as well; in numpy, this is done using '&' (since we can't overload
'and'). So if a, b, c are numpy arrays, we'd like this to instead
expand to
  (a < b) & (b < c)
Similar considerations apply to other systems that overload the
comparison operators, e.g. DSLs for generating SQL queries.
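To see concretely where the boolean coercion bites, here's a minimal sketch using two toy classes, Mask and Arr. Both are hypothetical stand-ins for a numpy boolean array and a numpy array, not numpy's actual implementation. Like numpy, Mask refuses to be collapsed to a single bool, so the explicit '&' form works but the chained form raises:

```python
class Mask:
    """Hypothetical elementwise boolean result, mimicking a numpy bool array."""

    def __init__(self, values):
        self.values = list(values)

    def __and__(self, other):
        # Elementwise 'and' against another Mask or a plain bool.
        if isinstance(other, Mask):
            return Mask(a and b for a, b in zip(self.values, other.values))
        return Mask(a and other for a in self.values)

    __rand__ = __and__

    def __bool__(self):
        # Like numpy, refuse to collapse many truth values into one.
        raise ValueError("truth value of a Mask is ambiguous")


class Arr:
    """Hypothetical 1-d array whose comparisons work elementwise."""

    def __init__(self, values):
        self.values = list(values)

    def __lt__(self, other):
        return Mask(v < other for v in self.values)

    def __gt__(self, other):
        return Mask(v > other for v in self.values)


arr = Arr([-0.5, 0.25, 0.75, 2.0])

# Explicit '&' combines elementwise:
print(((0 < arr) & (arr < 1)).values)  # [False, True, True, False]

# The chained form expands to (0 < arr) and (arr < 1); 'and' calls bool():
try:
    0 < arr < 1
except ValueError as exc:
    print("chained form failed:", exc)
```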

This seems like a good and straightforward enough idea to me in
principle, but it's not at all clear to me what the best way to
accomplish it in practice is. I thought of three options, but none is
obviously Right.

To have a way to talk about our options precisely, let's pretend that
there's something called operator.chain_comparison, and that the way
it works is that
  a < b <= c
produces a call like
  operator.chain_comparison([a, b, c], [operator.lt, operator.le])

Right now, the semantics are:

# Current
def chain_comparison(args, ops):
    for i, op in enumerate(ops):
        result = op(args[i], args[i + 1])
        # short-circuit
        if not result:
            return result
    return result

(Of course in reality in CPython the compiler unrolls the loop and
inlines this directly into the bytecode, but whatever, that's an
implementation detail.)
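As a sanity check on those reference semantics, here's a small sketch confirming that the loop agrees with Python's native a < b <= c and short-circuits the same way. The 'traced' wrapper and 'calls' list are illustrative helpers only:

```python
import itertools
import operator


def chain_comparison(args, ops):
    # Reference semantics for today's chained comparisons.
    for i, op in enumerate(ops):
        result = op(args[i], args[i + 1])
        if not result:
            return result  # short-circuit on the first false link
    return result


calls = []


def traced(op):
    """Illustrative helper: wrap a comparison so we can see which ones run."""
    def wrapper(x, y):
        calls.append((x, y))
        return op(x, y)
    return wrapper


# Agrees with Python's own a < b <= c on every small triple.
for a, b, c in itertools.product(range(3), repeat=3):
    assert chain_comparison([a, b, c], [operator.lt, operator.le]) == (a < b <= c)

# Short-circuits: once 3 < 1 is false, 1 <= 2 is never evaluated.
print(chain_comparison([3, 1, 2], [traced(operator.lt), traced(operator.le)]))  # False
print(calls)  # [(3, 1)]
```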

IDEA 1: simple, neat and (sort of) wrong

Let's define a new special method __chain_and__; whenever we do a
chain comparison, we check for the presence of this method, and if
found, we call it instead of using 'and'. The intuition is that if we
have a < b < c then this expands to either
  (a < b) and (b < c)
or
  (a < b).__chain_and__(b < c)
depending on whether hasattr((a < b), "__chain_and__"). Notice that
the first case is short-circuiting, and the second is not. Which seems
totally fine (contra PEP 335), because short-circuiting by definition
requires that you make a boolean decision (quit early/don't quit
early), and the whole point of these overloads is to avoid boolean
coercion. I think in general the semantics here look like:

# __chain_and__
def chain_comparison(args, ops):
    so_far = True
    for i, op in enumerate(ops):
        result = op(args[i], args[i + 1])
        if hasattr(so_far, "__chain_and__"):
            so_far = so_far.__chain_and__(result)
        else:
            so_far = so_far and result
        # short-circuit, but only if the next reduction would use 'and':
        if not hasattr(so_far, "__chain_and__") and not so_far:
            return so_far
    return so_far

So if 'arr' is a numpy array, then code like
  0 < arr < 1
will now work great, because (0 < arr) will return an array of
booleans, and this array will have the __chain_and__ method, so we'll
end up doing
  (0 < arr).__chain_and__(arr < 1)
and successfully return a single array of booleans indicating which
locations in 'arr' are between 0 and 1.
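Here's a minimal sketch of that working end to end, assuming hypothetical Mask/Arr classes in place of numpy and reusing the __chain_and__ semantics above verbatim. Mask.__bool__ raises, which shows that the elementwise path never coerces anything to bool:

```python
import operator


def chain_comparison(args, ops):
    # The __chain_and__ semantics sketched above.
    so_far = True
    for i, op in enumerate(ops):
        result = op(args[i], args[i + 1])
        if hasattr(so_far, "__chain_and__"):
            so_far = so_far.__chain_and__(result)
        else:
            so_far = so_far and result
        # short-circuit, but only if the next reduction would use 'and':
        if not hasattr(so_far, "__chain_and__") and not so_far:
            return so_far
    return so_far


class Mask:
    """Hypothetical elementwise boolean array that opts in via __chain_and__."""

    def __init__(self, values):
        self.values = list(values)

    def __chain_and__(self, other):
        # Elementwise 'and' against another Mask or a plain bool.
        if isinstance(other, Mask):
            return Mask(a and b for a, b in zip(self.values, other.values))
        return Mask(a and other for a in self.values)

    def __bool__(self):
        raise ValueError("truth value of a Mask is ambiguous")


class Arr:
    """Hypothetical 1-d array with elementwise comparisons."""

    def __init__(self, values):
        self.values = list(values)

    def __lt__(self, other):
        return Mask(v < other for v in self.values)

    def __gt__(self, other):
        return Mask(v > other for v in self.values)


arr = Arr([-0.5, 0.25, 0.75, 2.0])
# What 0 < arr < 1 would do under the proposed semantics:
between = chain_comparison([0, arr, 1], [operator.lt, operator.lt])
print(between.values)  # [False, True, True, False]
```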

But -- suppose that we have, say, a GUI displaying a barchart, and we
have a horizontal "threshold" line across the barchart that the user
controls. The idea is that they can move it up and down, and all the
bars that it overlaps will change color, so we can easily see which
bars are above the threshold and which are below it. So given the
current threshold 'x' (an ordinary float or int), and a 1d array 'arr'
holding the various bar heights, we can request the set of overlapping
bars as:
  0 < x < arr
Okay, the first problem is that here we can't do
  (0 < x).__chain_and__(x < arr)
because (0 < x) is just a plain bool, which has no __chain_and__. So okay, let's say we enhance the above
definition to allow for __rchain_and__, and we get
  (x < arr).__rchain_and__(0 < x)
Great! So we implement it and we test it in our program and it works
-- under ordinary conditions, this spits out a nice 1d array of
booleans, our code continues on its merry way, using the values in
this array to decide how each bar in our plot should be colored.

Until the user slides the threshold line down below zero, at which
point (0 < x) starts returning False, and we short-circuit out before
even evaluating (x < arr). And then our code blows up because instead
of getting an array like we expected, we just get False instead. Oops.
__rchain_and__ is useless.

So I guess the solution is to write
  arr > x > 0
which will always work? I'm not sure how common this case is in
practice, but I find it somewhat disturbing that with the
__chain_and__ approach, a < b < c and c > b > a can return completely
different values even for completely well-behaved objects, and there's
nothing the person writing overload methods can do about it. Certainly
the bug in our original code is not obvious to the casual reader. (And
AFAICT the proposal in PEP 335 also has this problem -- in fact PEP
335 basically is the same as this proposal -- so no help there.)
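The asymmetry is easy to reproduce with the same kind of toy classes (again hypothetical stand-ins, not numpy). With the threshold above zero both spellings return a Mask, but once it drops below zero, 0 < x < arr short-circuits to a bare False while arr > x > 0 still returns a Mask:

```python
import operator


def chain_comparison(args, ops):
    # The __chain_and__ semantics from IDEA 1.
    so_far = True
    for i, op in enumerate(ops):
        result = op(args[i], args[i + 1])
        if hasattr(so_far, "__chain_and__"):
            so_far = so_far.__chain_and__(result)
        else:
            so_far = so_far and result
        if not hasattr(so_far, "__chain_and__") and not so_far:
            return so_far
    return so_far


class Mask:
    """Hypothetical elementwise boolean array with __chain_and__."""

    def __init__(self, values):
        self.values = list(values)

    def __chain_and__(self, other):
        if isinstance(other, Mask):
            return Mask(a and b for a, b in zip(self.values, other.values))
        return Mask(a and other for a in self.values)

    def __bool__(self):
        raise ValueError("truth value of a Mask is ambiguous")


class Arr:
    """Hypothetical 1-d array of bar heights."""

    def __init__(self, values):
        self.values = list(values)

    def __lt__(self, other):
        return Mask(v < other for v in self.values)

    def __gt__(self, other):
        return Mask(v > other for v in self.values)


arr = Arr([-0.5, 0.25, 0.75, 2.0])
results = {}
for x in (0.5, -1):
    forward = chain_comparison([0, x, arr], [operator.lt, operator.lt])   # 0 < x < arr
    backward = chain_comparison([arr, x, 0], [operator.gt, operator.gt])  # arr > x > 0
    results[x] = (forward, backward)
    print(x, type(forward).__name__, type(backward).__name__)
# 0.5 Mask Mask
# -1 bool Mask
```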

IDEA 2: the FULLY GENERAL solution (uh oh)

So we started with a call like
  operator.chain_comparison([a, b, c], [operator.lt, operator.le])
Maybe what we should do is treat *this* as the basic operator, and try
calling a special __chain_comparison__ method on a, b, and/or c.

Of course this immediately runs into a problem, because all of Python's
existing operators take only 1 or 2 arguments [5], not an indefinite
and varying number of arguments. So we can't use the standard
__X__/__rX__ dispatch strategy. We need something like multimethod
dispatch. (Cue thunder, ominous music.) I know this has been discussed
to death, and I haven't read the discussion, so I guess people can
have fun educating me if they feel like it. But what I'd suggest in
this case is to do what we did in numpy to solve a similar problem
[6]. Instead of using a "real" multimethod system, just directly
generalize the __X__/__rX__ trick: when looking for a candidate object
to call __chain_comparison__ on, take the leftmost one that (a) hasn't
been tried, and (b) doesn't have another object which is a proper
subclass of its type that also hasn't been tried. If you get
NotImplemented, keep trying until you run out of candidates; then fall
back to the traditional 'and'-based semantics.
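A rough sketch of that dispatch rule follows, under some assumptions of my own: dispatch_order is a hypothetical helper, only objects whose type defines __chain_comparison__ count as candidates, and the hook is looked up on the type the way other special methods are:

```python
import operator


def dispatch_order(candidates):
    """Yield candidates in the proposed order: the leftmost untried object,
    unless an untried object of a proper subclass of its type remains."""
    untried = list(candidates)
    while untried:
        for i, cand in enumerate(untried):
            cls = type(cand)
            outranked = any(
                type(other) is not cls and issubclass(type(other), cls)
                for other in untried
            )
            if not outranked:
                yield untried.pop(i)
                break


def chain_comparison(args, ops):
    # Only objects whose type defines the (proposed) hook are candidates.
    candidates = [a for a in args if hasattr(type(a), "__chain_comparison__")]
    for cand in dispatch_order(candidates):
        result = type(cand).__chain_comparison__(cand, args, ops)
        if result is not NotImplemented:
            return result
    # Nobody took it: fall back to the traditional 'and'-based semantics.
    for i, op in enumerate(ops):
        result = op(args[i], args[i + 1])
        if not result:
            return result
    return result


class Base: pass
class Sub(Base): pass
class Other: pass

order = list(dispatch_order([Base(), Other(), Sub()]))
print([type(x).__name__ for x in order])  # ['Other', 'Sub', 'Base']


class Declines:
    def __chain_comparison__(self, args, ops):
        return NotImplemented  # let the next candidate try


class Handles:
    def __chain_comparison__(self, args, ops):
        return "handled by Handles"


print(chain_comparison([Declines(), Handles()], [operator.lt]))  # handled by Handles
print(chain_comparison([1, 2, 3], [operator.lt, operator.le]))   # True (fallback)
```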

This does solve our problematic case above: only 'arr' implements
__chain_comparison__, so we have
  0 < x < arr
becoming
  arr.__chain_comparison__([0, x, arr], [operator.lt, operator.lt])
and then that can just call the underlying rich comparison operators
in a non-short-circuiting loop, and combine the results using '&'
instead of 'and'.
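A sketch of what the array side might look like, with Mask and Arr again as hypothetical toy classes. Every link is evaluated and the results are combined with '&', so the result is an array whatever the sign of x:

```python
import operator


class Mask:
    """Hypothetical elementwise boolean array supporting '&'."""

    def __init__(self, values):
        self.values = list(values)

    def __and__(self, other):
        # Elementwise 'and' with another Mask or a plain bool.
        if isinstance(other, Mask):
            return Mask(a and b for a, b in zip(self.values, other.values))
        return Mask(a and other for a in self.values)

    __rand__ = __and__  # so that False & mask also works

    def __bool__(self):
        raise ValueError("truth value of a Mask is ambiguous")


class Arr:
    """Hypothetical 1-d array implementing the proposed hook."""

    def __init__(self, values):
        self.values = list(values)

    def __lt__(self, other):
        return Mask(v < other for v in self.values)

    def __gt__(self, other):
        return Mask(v > other for v in self.values)

    def __chain_comparison__(self, args, ops):
        # No short-circuiting: evaluate every link, then combine with '&'.
        combined = ops[0](args[0], args[1])
        for i in range(1, len(ops)):
            combined = combined & ops[i](args[i], args[i + 1])
        return combined


arr = Arr([-0.5, 0.25, 0.75, 2.0])
for x in (0.5, -1):
    # What 0 < x < arr would turn into under IDEA 2:
    print(arr.__chain_comparison__([0, x, arr], [operator.lt, operator.lt]).values)
# [False, False, True, True]
# [False, False, False, False]
```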

But it does require a somewhat odd-looking piece of machinery for the dispatch.

IDEA 3: the FULLY SPECIFIC solution

A nice thing about the __chain_comparison__ method is that it will
actually be identical for all array-like objects, no matter which
library they're defined in. So, if there are multiple array-like
objects from different libraries in a single chain, it doesn't matter
who gets picked to handle the overall evaluation -- any one of them is
good enough, and then the actual interoperability problems are
delegated to the rich comparison methods, which already have to have a
strategy for dealing with them.

In fact, this same __chain_comparison__ could probably also be used by
just about anyone who wants to define it -- e.g., I think most DB
query DSLs are going to overload & to mean "and", right?
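For instance, a toy SQL expression builder along those lines already combines conditions with '&' and, like numpy, has no sensible truth value. Expr and render are hypothetical names for illustration, not any real library's API:

```python
class Expr:
    """Hypothetical SQL expression builder (illustrative, not a real library)."""

    def __init__(self, sql):
        self.sql = sql

    def __lt__(self, other):
        return Expr(f"{self.sql} < {render(other)}")

    def __gt__(self, other):
        return Expr(f"{self.sql} > {render(other)}")

    def __and__(self, other):
        # '&' is overloaded to mean SQL AND.
        return Expr(f"({self.sql}) AND ({render(other)})")

    def __bool__(self):
        # A query condition has no truth value until the database evaluates it.
        raise TypeError("SQL expressions have no truth value")


def render(value):
    """Render a Python value or Expr as SQL text (naively, via repr)."""
    return value.sql if isinstance(value, Expr) else repr(value)


age = Expr("age")
print(((0 < age) & (age < 65)).sql)  # (age > 0) AND (age < 65)
```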

So, maybe all we need is a flag that says: if you see an object with
this flag anywhere in a chain, then switch to these semantics:

# &-flag
def chain_comparison(args, ops):
    for i, op in enumerate(ops):
        # every link is evaluated; no short-circuiting
        result = op(args[i], args[i + 1])
        if i == 0:
            combined = result
        else:
            combined &= result
    return combined

...and otherwise, use the standard semantics.
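Here's a minimal sketch of the flag version, assuming a class attribute named __chain_uses_and__ as the flag (the name is made up for illustration; Mask and Arr are hypothetical toy classes). The &-branch is the loop above; everything else falls through to the standard short-circuiting semantics:

```python
import operator


def chain_comparison(args, ops):
    # Hypothetical flag name; any argument whose type sets it switches modes.
    if any(getattr(type(a), "__chain_uses_and__", False) for a in args):
        # &-flag semantics: evaluate every link, combine with '&'.
        for i, op in enumerate(ops):
            result = op(args[i], args[i + 1])
            if i == 0:
                combined = result
            else:
                combined &= result
        return combined
    # Standard semantics: short-circuit on the first false link.
    for i, op in enumerate(ops):
        result = op(args[i], args[i + 1])
        if not result:
            return result
    return result


class Mask:
    """Hypothetical elementwise boolean array with '&'."""

    def __init__(self, values):
        self.values = list(values)

    def __and__(self, other):
        if isinstance(other, Mask):
            return Mask(a and b for a, b in zip(self.values, other.values))
        return Mask(a and other for a in self.values)

    __rand__ = __and__


class Arr:
    """Hypothetical 1-d array that opts in to the &-flag semantics."""

    __chain_uses_and__ = True

    def __init__(self, values):
        self.values = list(values)

    def __lt__(self, other):
        return Mask(v < other for v in self.values)

    def __gt__(self, other):
        return Mask(v > other for v in self.values)


arr = Arr([-0.5, 0.25, 0.75, 2.0])
print(chain_comparison([0, -1, arr], [operator.lt, operator.lt]).values)
# [False, False, False, False]
print(chain_comparison([1, 2, 3], [operator.lt, operator.lt]))  # True
```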

So that's what I got. None of these approaches seems obviously Right,
but they're all at least sort of viable. Anyone else got any ideas?

-n

[1] https://mail.python.org/pipermail/python-ideas/2014-March/027174.html
[2] http://mail.scipy.org/pipermail/numpy-discussion/2014-March/069444.html
[3] "We're more of the love, bikeshedding, and rhetoric school. Well,
we can do you bikeshedding and love without the rhetoric, and we can
do you bikeshedding and rhetoric without the love, and we can do you
all three concurrent or consecutive. But we can't give you love and
rhetoric without the bikeshedding. Bikeshedding is compulsory."
[4] https://mail.python.org/pipermail/python-dev/2012-March/117510.html
[5] Except pow(), but that doesn't really count because it never
dispatches on the third argument.
[6] The next release of numpy allows non-numpy array-like classes to
handle all numpy math functions in a generic way by dispatching to a
special __numpy_ufunc__ method; I can provide more details on what's
going on here if anyone's curious:
http://docs.scipy.org/doc/numpy-dev/reference/arrays.classes.html#numpy.class.__numpy_ufunc__

-- 
Nathaniel J. Smith
Postdoctoral researcher - Informatics - University of Edinburgh
http://vorpus.org

