Simple Python implementation of bag/multiset

MRAB google at mrabarnett.plus.com
Thu Jun 19 15:58:42 CEST 2008


While another thread is discussing an ordered dict, I thought I'd try a
simple implementation of a bag/multiset in Python. Comments and
suggestions are welcome.

class bag(object):
    """A multiset: an unordered collection that counts repeated items.

    Items must be hashable.  Counts live in a private dict mapping each
    item to a positive integer; an item whose count drops to zero is
    deleted, so ``item in b`` is true only for items actually present.

    Binary operators follow multiset semantics (the right-hand operand is
    expected to be a bag):
      ``a + b``  sum of counts
      ``a & b``  minimum of counts (intersection)
      ``a | b``  maximum of counts (union)
      ``a - b``  counts of ``a`` minus counts of ``b``, floored at zero
      ``a ^ b``  absolute difference of counts (symmetric difference)
    Each has an in-place form (``+=``, ``&=``, ``|=``, ``-=``, ``^=``).
    """

    # Bags are mutable and compare by contents, so they must not be
    # hashable (like the built-in set).  Python 3 does this implicitly once
    # __eq__ is defined; Python 2 would otherwise keep identity hashing.
    __hash__ = None

    def __add__(self, other):
        """Return a new bag whose counts are the sums of both bags' counts."""
        result = self.copy()
        for item, count in other.iteritems():
            result._items[item] = result._items.get(item, 0) + count
        return result

    def __and__(self, other):
        """Return a new bag holding the minimum count of each shared item."""
        result = bag()
        for item, count in other.iteritems():
            new_count = min(self._items.get(item, 0), count)
            if new_count > 0:
                result._items[item] = new_count
        return result

    def __contains__(self, item):
        """Return True if *item* occurs at least once."""
        return item in self._items

    def __eq__(self, other):
        """Return True if both bags hold the same items with equal counts.

        Comparing with a non-bag returns NotImplemented, so ``==`` falls
        back to the default comparison (False) instead of raising.
        """
        if not isinstance(other, bag):
            return NotImplemented
        return self._items == other._items

    def __getitem__(self, item):
        """Return the count of *item*; raise KeyError if it is absent."""
        return self._items[item]

    def __iadd__(self, other):
        """In-place ``+``: add *other*'s counts to this bag."""
        self._items = self.__add__(other)._items
        return self

    def __iand__(self, other):
        """In-place ``&``: keep only the minimum counts shared with *other*."""
        self._items = self.__and__(other)._items
        return self

    def __init__(self, iterable=None):
        """Create a bag, counting each element of *iterable* if given."""
        self._items = {}
        if iterable is not None:
            for item in iterable:
                self._items[item] = self._items.get(item, 0) + 1

    def __ior__(self, other):
        """In-place ``|``: raise each count to the maximum with *other*."""
        self._items = self.__or__(other)._items
        return self

    def __isub__(self, other):
        """In-place ``-``: subtract *other*'s counts, flooring at zero."""
        self._items = self.__sub__(other)._items
        return self

    def __iter__(self):
        """Yield each item as many times as it occurs."""
        # repeat() avoids building a list of length `count` (which range()
        # would do on Python 2) and works on both Python 2 and 3.
        from itertools import repeat
        for item, count in self.iteritems():
            for element in repeat(item, count):
                yield element

    def __ixor__(self, other):
        """In-place ``^``: replace counts with their absolute difference."""
        self._items = self.__xor__(other)._items
        return self

    def __len__(self):
        """Return the total number of items, counting repeats."""
        return sum(self._items.values())

    def __ne__(self, other):
        """Inverse of ``__eq__``; returns NotImplemented for non-bags."""
        if not isinstance(other, bag):
            return NotImplemented
        return self._items != other._items

    def __or__(self, other):
        """Return a new bag holding the maximum count of each item."""
        result = self.copy()
        for item, count in other.iteritems():
            result._items[item] = max(result._items.get(item, 0), count)
        return result

    def __repr__(self):
        """Return ``bag([...])`` listing every occurrence of every item."""
        result = []
        for item, count in self.iteritems():
            result += [repr(item)] * count
        return 'bag([%s])' % ', '.join(result)

    def __setitem__(self, item, count):
        """Set the count of *item*; a count of 0 removes it.

        Raises ValueError if *count* is not a non-negative int.
        """
        if not isinstance(count, int) or count < 0:
            raise ValueError(
                'bag count must be a non-negative int, got %r' % (count,))
        if count > 0:
            self._items[item] = count
        elif item in self._items:
            del self._items[item]

    def __sub__(self, other):
        """Return a new bag of this bag's counts minus *other*'s (min 0)."""
        result = bag()
        for item, count in self.iteritems():
            new_count = count - other._items.get(item, 0)
            if new_count > 0:
                result._items[item] = new_count
        return result

    def __xor__(self, other):
        """Return a new bag of the absolute differences of the counts."""
        result = self.copy()
        for item, count in other.iteritems():
            new_count = abs(result._items.get(item, 0) - count)
            if new_count > 0:
                result._items[item] = new_count
            elif item in result._items:
                # Counts cancelled out: drop the item so no zero is stored.
                del result._items[item]
        return result

    def add(self, item):
        """Add one occurrence of *item*."""
        self._items[item] = self._items.get(item, 0) + 1

    def discard(self, item):
        """Remove one occurrence of *item* if present; otherwise do nothing."""
        new_count = self._items.get(item, 0) - 1
        if new_count > 0:
            self._items[item] = new_count
        elif new_count == 0:
            del self._items[item]

    def clear(self):
        """Remove all items."""
        self._items = {}

    def copy(self):
        """Return a shallow copy of this bag."""
        result = bag()
        result._items = self._items.copy()
        return result

    def difference(self, other):
        """Same as ``self - other``."""
        return self.__sub__(other)

    def difference_update(self, other):
        """Same as ``self -= other``."""
        self._items = self.__sub__(other)._items

    def get(self, item, default=0):
        """Return the count of *item*, or *default* if it is absent."""
        return self._items.get(item, default)

    def intersection(self, other):
        """Same as ``self & other``."""
        return self.__and__(other)

    def intersection_update(self, other):
        """Same as ``self &= other``."""
        self.__iand__(other)

    def items(self):
        """Return the underlying dict's ``items()`` of (item, count)."""
        return self._items.items()

    def iteritems(self):
        """Return an iterator over (item, count) pairs."""
        return iter(self._items.items())

    def iterkeys(self):
        """Return an iterator over the distinct items."""
        return iter(self._items)

    def itervalues(self):
        """Return an iterator over the counts."""
        return iter(self._items.values())

    def keys(self):
        """Return the underlying dict's ``keys()`` (distinct items)."""
        return self._items.keys()

    def pop(self):
        """Remove one occurrence of an arbitrary item and return it.

        Raises IndexError if the bag is empty.
        """
        if not self._items:
            raise IndexError('pop from an empty bag')
        # next(iter(...)) is O(1); keys()[0] built a full list on Python 2
        # and is not indexable at all on Python 3.
        item = next(iter(self._items))
        self._items[item] -= 1
        if self._items[item] == 0:
            del self._items[item]
        return item

    def remove(self, item):
        """Remove one occurrence of *item*; raise KeyError if it is absent."""
        new_count = self._items[item] - 1
        if new_count > 0:
            self._items[item] = new_count
        else:
            del self._items[item]

    def symmetric_difference(self, other):
        """Same as ``self ^ other``."""
        return self.__xor__(other)

    def symmetric_difference_update(self, other):
        """Same as ``self ^= other``."""
        self.__ixor__(other)

    def union(self, other):
        """Same as ``self | other`` (maximum of counts)."""
        return self.__or__(other)

    def update(self, other):
        """Same as ``self |= other`` (maximum of counts, not their sum)."""
        self.__ior__(other)

    def values(self):
        """Return the underlying dict's ``values()`` (the counts)."""
        return self._items.values()



More information about the Python-list mailing list