[Jython-checkins] jython: Fix race condition calling defaultdict#__missing__ for derived classes.
jim.baker
jython-checkins at python.org
Thu May 22 22:02:42 CEST 2014
http://hg.python.org/jython/rev/1a7fdc3681b2
changeset: 7266:1a7fdc3681b2
user: Indra Talip <indra.talip at gmail.com>
date: Wed May 21 11:05:26 2014 +1000
summary:
Fix race condition calling defaultdict#__missing__ for derived classes.
Moves the test for deferring to the subclass's __missing__ into
CacheLoader#load in order to have atomic behaviour with regard to
calling __missing__.
Adds a test to ThreadSafetyTestCase for subclasses of defaultdict.
files:
Lib/test/test_defaultdict_jy.py | 114 ++++++---
src/org/python/modules/_collections/PyDefaultDict.java | 17 +-
2 files changed, 79 insertions(+), 52 deletions(-)
diff --git a/Lib/test/test_defaultdict_jy.py b/Lib/test/test_defaultdict_jy.py
--- a/Lib/test/test_defaultdict_jy.py
+++ b/Lib/test/test_defaultdict_jy.py
@@ -38,30 +38,29 @@
for t in threads:
self.assertFalse(t.isAlive())
+ class Counter(object):
+ def __init__(self, initial=0):
+ self.atomic = AtomicInteger(initial)
+ # waiting is important here to ensure that
+ # defaultdict factories can step on each other
+ time.sleep(0.001)
+
+ def decrementAndGet(self):
+ return self.atomic.decrementAndGet()
+
+ def incrementAndGet(self):
+ return self.atomic.incrementAndGet()
+
+ def get(self):
+ return self.atomic.get()
+
+ def __repr__(self):
+ return "Counter<%s>" % (self.atomic.get())
+
def test_inc_dec(self):
+ counters = defaultdict(ThreadSafetyTestCase.Counter)
+ size = 17
- class Counter(object):
- def __init__(self):
- self.atomic = AtomicInteger()
- # waiting is important here to ensure that
- # defaultdict factories can step on each other
- time.sleep(0.001)
-
- def decrementAndGet(self):
- return self.atomic.decrementAndGet()
-
- def incrementAndGet(self):
- return self.atomic.incrementAndGet()
-
- def get(self):
- return self.atomic.get()
-
- def __repr__(self):
- return "Counter<%s>" % (self.atomic.get())
-
- counters = defaultdict(Counter)
- size = 17
-
def tester():
for i in xrange(1000):
j = (i + randint(0, size)) % size
@@ -70,10 +69,36 @@
counters[j].incrementAndGet()
self.run_threads(tester, 20)
-
+
for i in xrange(size):
self.assertEqual(counters[i].get(), 0, counters)
+ def test_derived_inc_dec(self):
+ class DerivedDefaultDict(defaultdict):
+ def __missing__(self, key):
+ if self.default_factory is None:
+ raise KeyError("Invalid key '{0}' and no default factory was set")
+
+ val = self.default_factory(key)
+
+ self[key] = val
+ return val
+
+ counters = DerivedDefaultDict(lambda key: ThreadSafetyTestCase.Counter(key))
+ size = 17
+
+ def tester():
+ for i in xrange(1000):
+ j = (i + randint(0, size)) % size
+ counters[j].decrementAndGet()
+ time.sleep(0.0001)
+ counters[j].incrementAndGet()
+
+ self.run_threads(tester, 20)
+
+ for i in xrange(size):
+ self.assertEqual(counters[i].get(), i, counters)
+
class GetVariantsTestCase(unittest.TestCase):
#http://bugs.jython.org/issue2133
@@ -94,34 +119,35 @@
self.assertEquals(d.items(), [("vivify", [])])
-class KeyDefaultDict(defaultdict):
- """defaultdict to pass the requested key to factory function."""
- def __missing__(self, key):
- if self.default_factory is None:
- raise KeyError("Invalid key '{0}' and no default factory was set")
- else:
- val = self.default_factory(key)
-
- self[key] = val
- return val
-
- @classmethod
- def double(cls, k):
- return k + k
class OverrideMissingTestCase(unittest.TestCase):
+ class KeyDefaultDict(defaultdict):
+ """defaultdict to pass the requested key to factory function."""
+ def __missing__(self, key):
+ if self.default_factory is None:
+ raise KeyError("Invalid key '{0}' and no default factory was set")
+ else:
+ val = self.default_factory(key)
+
+ self[key] = val
+ return val
+
+ @classmethod
+ def double(cls, k):
+ return k + k
+
+ def setUp(self):
+ self.kdd = OverrideMissingTestCase.KeyDefaultDict(OverrideMissingTestCase.KeyDefaultDict.double)
+
def test_dont_call_derived_missing(self):
- kdd = KeyDefaultDict(KeyDefaultDict.double)
- kdd[3] = 5
- self.assertEquals(kdd[3], 5)
+ self.kdd[3] = 5
+ self.assertEquals(self.kdd[3], 5)
#http://bugs.jython.org/issue2088
def test_override_missing(self):
-
- kdd = KeyDefaultDict(KeyDefaultDict.double)
# line below causes KeyError in Jython, ignoring overridden __missing__ method
- self.assertEquals(kdd[3], 6)
- self.assertEquals(kdd['ab'], 'abab')
+ self.assertEquals(self.kdd[3], 6)
+ self.assertEquals(self.kdd['ab'], 'abab')
def test_main():
diff --git a/src/org/python/modules/_collections/PyDefaultDict.java b/src/org/python/modules/_collections/PyDefaultDict.java
--- a/src/org/python/modules/_collections/PyDefaultDict.java
+++ b/src/org/python/modules/_collections/PyDefaultDict.java
@@ -58,6 +58,15 @@
backingMap = CacheBuilder.newBuilder().build(
new CacheLoader<PyObject, PyObject>() {
public PyObject load(PyObject key) {
+ PyType self_type = getType();
+ if (self_type != TYPE) {
+ // Is a subclass. If it exists call the subclasses __missing__.
+ // Otherwise PyDefaultDic.defaultdict___missing__() will
+ // be invoked.
+ return PyDefaultDict.this.invoke("__missing__", key);
+ }
+
+ // in-lined __missing__
if (defaultFactory == Py.None) {
throw Py.KeyError(key);
}
@@ -167,14 +176,6 @@
@ExposedMethod(doc = BuiltinDocs.dict___getitem___doc)
protected final PyObject defaultdict___getitem__(PyObject key) {
try {
- PyType type = getType();
- if (!getMap().containsKey(key) && type != TYPE) {
- // is a subclass. if it exists call the subclasses __missing__
- PyObject missing = type.lookup("__missing__");
- if (missing != null) {
- return missing.__get__(this, type).__call__(key);
- }
- }
return backingMap.get(key);
} catch (Exception ex) {
throw Py.KeyError(key);
--
Repository URL: http://hg.python.org/jython
More information about the Jython-checkins
mailing list