Boilerplate in rich comparison methods

Dan Bishop danb_83 at yahoo.com
Sat Jan 13 13:43:24 EST 2007


On Jan 13, 12:52 am, Steven D'Aprano
<s... at REMOVE.THIS.cybersource.com.au> wrote:
> I'm writing a class that implements rich comparisons, and I find myself
> writing a lot of very similar code. If the calculation is short and
> simple, I do something like this:
>
> class Parrot:
>     def __eq__(self, other):
>         return self.plumage() == other.plumage()
>     def __ne__(self, other):
>         return self.plumage() != other.plumage()
>     def __lt__(self, other):
>         return self.plumage() < other.plumage()
>     def __gt__(self, other):
>         return self.plumage() > other.plumage()
>     def __le__(self, other):
>         return self.plumage() <= other.plumage()
>     def __ge__(self, other):
>         return self.plumage() >= other.plumage()
>
> If the comparison requires a lot of work, I'll do something like this:
>
> class Aardvark:
>     def __le__(self, other):
>         return lots_of_work(self, other)
>     def __gt__(self, other):
>         return not self <= other
>     # etc.
>
> But I can't help feeling that there is a better way. What do others do?

Typically, I write only two kinds of classes that use comparison
operators: (1) ones that can get by with __cmp__ and (2) ones that
define __eq__ and __ne__ without any of the other four.
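Here is a minimal sketch of the second kind, written for Python 3, where __cmp__ no longer exists. The Point class and its x/y fields are hypothetical, chosen only for illustration. Returning NotImplemented for unrelated types lets Python try the other operand instead of raising AttributeError:

```python
class Point:
    """Hypothetical value class: supports == and != but no ordering."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __eq__(self, other):
        # Defer to the other operand for unrelated types instead of raising.
        if not isinstance(other, Point):
            return NotImplemented
        return (self.x, self.y) == (other.x, other.y)

    def __ne__(self, other):
        # Python 3 derives != from __eq__ automatically; spelled out here
        # because Python 2 did not.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Equal objects must hash equally; defining __eq__ without
        # __hash__ makes instances unhashable in Python 3.
        return hash((self.x, self.y))


p, q = Point(1, 2), Point(1, 2)
print(p == q, p != q)      # True False
print(p == "not a point")  # False
```

Ordering operators such as p < q are deliberately left undefined, so they raise TypeError.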

But for your case, I'd say you're doing it the right way.  If you
define a lot of classes like Parrot, you might want to try moving the
six operators to a common base class:

class Comparable:
    """
    Abstract base class for classes using rich comparisons.
    Objects are compared using their cmp_key() method.
    """
    def __eq__(self, other):
        return (self is other) or (self.cmp_key() == other.cmp_key())
    def __ne__(self, other):
        return (self is not other) and (self.cmp_key() != other.cmp_key())
    def __lt__(self, other):
        return self.cmp_key() < other.cmp_key()
    def __le__(self, other):
        return self.cmp_key() <= other.cmp_key()
    def __gt__(self, other):
        return self.cmp_key() > other.cmp_key()
    def __ge__(self, other):
        return self.cmp_key() >= other.cmp_key()
    def cmp_key(self):
        """Overridden by derived classes to define a comparison key."""
        raise NotImplementedError()

class Parrot(Comparable):
    def cmp_key(self):
        return self.plumage()
    # ...
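In later Python versions (2.7 and 3.2 onward, after this thread), the standard library's functools.total_ordering decorator automates the Aardvark approach: you define __eq__ plus one ordering method, and it supplies the other comparisons. The sketch below is one way to apply it to Parrot. It assumes plumage() returns mutually comparable values, and the plumage strings are made up for illustration:

```python
from functools import total_ordering


# total_ordering derives __le__, __gt__ and __ge__ from __lt__ and __eq__;
# Python 3 itself derives __ne__ from __eq__.
@total_ordering
class Parrot:
    """Hypothetical Parrot whose plumage is a sortable value (assumed)."""

    def __init__(self, plumage):
        self._plumage = plumage

    def plumage(self):
        return self._plumage

    def __eq__(self, other):
        if not isinstance(other, Parrot):
            return NotImplemented
        return self.plumage() == other.plumage()

    def __lt__(self, other):
        if not isinstance(other, Parrot):
            return NotImplemented
        return self.plumage() < other.plumage()


birds = [Parrot("red"), Parrot("blue"), Parrot("green")]
print([b.plumage() for b in sorted(birds)])  # ['blue', 'green', 'red']
print(Parrot("blue") >= Parrot("blue"))      # True
```

Derived methods add an extra call per comparison, so hand-writing all six, as in the Comparable base class, may still pay off when comparisons are performance-critical. As with any class that defines __eq__, instances are unhashable in Python 3 unless __hash__ is also defined.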




More information about the Python-list mailing list