[pypy-svn] r79505 - pypy/trunk/pypy/module/itertools

arigo at codespeak.net arigo at codespeak.net
Thu Nov 25 13:09:13 CET 2010


Author: arigo
Date: Thu Nov 25 13:09:12 2010
New Revision: 79505

Modified:
   pypy/trunk/pypy/module/itertools/interp_itertools.py
Log:
JIT-friendly rewrites: where possible, move loops into separate helper
functions so that the JIT does not see them as part of the whole function.


Modified: pypy/trunk/pypy/module/itertools/interp_itertools.py
==============================================================================
--- pypy/trunk/pypy/module/itertools/interp_itertools.py	(original)
+++ pypy/trunk/pypy/module/itertools/interp_itertools.py	Thu Nov 25 13:09:12 2010
@@ -391,7 +391,8 @@
         self.iterators_w = iterators_w
         self.current_iterator = 0
         self.num_iterators = len(iterators_w)
-        self.started = False
+        if self.num_iterators > 0:
+            self.w_it = iterators_w[0]
 
     def iter_w(self):
         return self.space.wrap(self)
@@ -399,26 +400,23 @@
     def next_w(self):
         if self.current_iterator >= self.num_iterators:
             raise OperationError(self.space.w_StopIteration, self.space.w_None)
-        if not self.started:
-            self.current_iterator = 0
-            self.w_it = self.iterators_w[self.current_iterator]
-            self.started = True
+        try:
+            return self.space.next(self.w_it)
+        except OperationError, e:
+            return self._handle_error(e)
 
+    def _handle_error(self, e):
         while True:
+            if not e.match(self.space, self.space.w_StopIteration):
+                raise e
+            self.current_iterator += 1
+            if self.current_iterator >= self.num_iterators:
+                raise e
+            self.w_it = self.iterators_w[self.current_iterator]
             try:
-                w_obj = self.space.next(self.w_it)
+                return self.space.next(self.w_it)
             except OperationError, e:
-                if e.match(self.space, self.space.w_StopIteration):
-                    self.current_iterator += 1
-                    if self.current_iterator >= self.num_iterators:
-                        raise OperationError(self.space.w_StopIteration, self.space.w_None)
-                    else:
-                        self.w_it = self.iterators_w[self.current_iterator]
-                else:
-                    raise
-            else:
-                break
-        return w_obj
+                pass   # loop back to the start of _handle_error(e)
 
 def W_Chain___new__(space, w_subtype, args_w):
     return space.wrap(W_Chain(space, args_w))
@@ -446,8 +444,10 @@
 
     def __init__(self, space, w_fun, args_w):
         self.space = space
-        self.identity_fun = (self.space.is_w(w_fun, space.w_None))
-        self.w_fun = w_fun
+        if self.space.is_w(w_fun, space.w_None):
+            self.w_fun = None
+        else:
+            self.w_fun = w_fun
 
         iterators_w = []
         i = 0
@@ -470,12 +470,26 @@
         return self.space.wrap(self)
 
     def next_w(self):
-        w_objects = self.space.newtuple([self.space.next(w_it) for w_it in self.iterators_w])
-        if self.identity_fun:
+        # common case: 1 or 2 arguments
+        iterators_w = self.iterators_w
+        length = len(iterators_w)
+        if length == 1:
+            objects = [self.space.next(iterators_w[0])]
+        elif length == 2:
+            objects = [self.space.next(iterators_w[0]),
+                       self.space.next(iterators_w[1])]
+        else:
+            objects = self._get_objects()
+        w_objects = self.space.newtuple(objects)
+        if self.w_fun is None:
             return w_objects
         else:
             return self.space.call(self.w_fun, w_objects)
 
+    def _get_objects(self):
+        # the loop is out of the way of the JIT
+        return [self.space.next(w_elem) for w_elem in self.iterators_w]
+
 
 def W_IMap___new__(space, w_subtype, w_fun, args_w):
     if len(args_w) == 0:
@@ -769,15 +783,7 @@
             raise OperationError(self.space.w_StopIteration, self.space.w_None)
 
         if not self.new_group:
-            # Consume unwanted input until we reach the next group
-            try:
-                while True:
-                    self.group_next(self.index)
-
-            except StopIteration:
-                pass
-            if self.exhausted:
-                raise OperationError(self.space.w_StopIteration, self.space.w_None)
+            self._consume_unwanted_input()
 
         if not self.started:
             self.started = True
@@ -799,6 +805,16 @@
         w_iterator = self.space.wrap(W_GroupByIterator(self.space, self.index, self))
         return self.space.newtuple([self.w_key, w_iterator])
 
+    def _consume_unwanted_input(self):
+        # Consume unwanted input until we reach the next group
+        try:
+            while True:
+                self.group_next(self.index)
+        except StopIteration:
+            pass
+        if self.exhausted:
+            raise OperationError(self.space.w_StopIteration, self.space.w_None)
+
     def group_next(self, group_index):
         if group_index < self.index:
             raise StopIteration



More information about the Pypy-commit mailing list