[Scipy-svn] r3922 - in trunk/scipy/stats/models: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Tue Feb 12 14:47:46 EST 2008


Author: jonathan.taylor
Date: 2008-02-12 13:47:39 -0600 (Tue, 12 Feb 2008)
New Revision: 3922

Modified:
   trunk/scipy/stats/models/formula.py
   trunk/scipy/stats/models/tests/test_formula.py
Log:
added __getitem__ method to Factor 

fixed bug in ordinal factors, added two tests


Modified: trunk/scipy/stats/models/formula.py
===================================================================
--- trunk/scipy/stats/models/formula.py	2008-02-12 17:35:40 UTC (rev 3921)
+++ trunk/scipy/stats/models/formula.py	2008-02-12 19:47:39 UTC (rev 3922)
@@ -136,18 +136,23 @@
 
     def __init__(self, termname, keys, ordinal=False):
         """
-        factor is initialized with keys, representing all valid
+        Factor is initialized with keys, representing all valid
         levels of the factor.
+
+        If ordinal is True, the order is taken from the keys.
         """
 
-        self.keys = list(set(keys))
-        self.keys.sort()
+        if not ordinal:
+            self.keys = list(set(keys))
+            self.keys.sort()
+        else:
+            self.keys = keys
         self._name = termname
         self.termname = termname
         self.ordinal = ordinal
 
         if self.ordinal:
-            name = self.name
+            name = self.termname
         else:
             name = ['(%s==%s)' % (self.termname, str(key)) for key in self.keys]
 
@@ -166,12 +171,13 @@
                 v = v(*args, **kw)
             else: break
 
+        n = len(v)
+
         if self.ordinal:
-            col = [float(self.keys.index(v[i])) for i in range(len(self.keys))]
+            col = [float(self.keys.index(v[i])) for i in range(n)]
             return N.array(col)
 
         else:
-            n = len(v)
             value = []
             for key in self.keys:
                 col = [float((v[i] == key)) for i in range(n)]
@@ -241,6 +247,24 @@
         value.namespace = self.namespace
         return value
 
+    def __getitem__(self, key):
+        """
+        Retrieve the column corresponding to key in a Formula.
+        
+        :Parameters:
+            key : one of the Factor's keys
+        
+        :Returns: ndarray corresponding to key, when evaluated in
+                  current namespace
+        """
+        if not self.ordinal:
+            i = self.names().index('(%s==%s)' % (self.termname, str(key)))
+            return self()[i]
+        else:
+            v = self.namespace[self._name]
+            return N.array([(vv == key) for vv in v]).astype(N.float)
+
+
 class Quantitative(Term):
     """
     A subclass of term that can be used to apply point transformations

Modified: trunk/scipy/stats/models/tests/test_formula.py
===================================================================
--- trunk/scipy/stats/models/tests/test_formula.py	2008-02-12 17:35:40 UTC (rev 3921)
+++ trunk/scipy/stats/models/tests/test_formula.py	2008-02-12 19:47:39 UTC (rev 3922)
@@ -215,6 +215,39 @@
         _m = N.array([r[0]-r[2],r[1]-r[2]])
         assert_almost_equal(_m, m())
 
+    def test_factor5(self):
+        f = ['a','b','c']*3
+        fac = formula.Factor('ff', f)
+        fac.namespace = {'ff':f}
+
+        assert_equal(fac(), [[1,0,0]*3,
+                             [0,1,0]*3,
+                             [0,0,1]*3])
+        assert_equal(fac['a'], [1,0,0]*3)
+        assert_equal(fac['b'], [0,1,0]*3)
+        assert_equal(fac['c'], [0,0,1]*3)
+
+
+    def test_ordinal_factor(self):
+        f = ['a','b','c']*3
+        fac = formula.Factor('ff', f, ordinal=True)
+        fac.namespace = {'ff':f}
+
+        assert_equal(fac(), [0,1,2]*3)
+        assert_equal(fac['a'], [1,0,0]*3)
+        assert_equal(fac['b'], [0,1,0]*3)
+        assert_equal(fac['c'], [0,0,1]*3)
+
+    def test_ordinal_factor2(self):
+        f = ['b','c', 'a']*3
+        fac = formula.Factor('ff', ['a','b','c'], ordinal=True)
+        fac.namespace = {'ff':f}
+
+        assert_equal(fac(), [1,2,0]*3)
+        assert_equal(fac['a'], [0,0,1]*3)
+        assert_equal(fac['b'], [1,0,0]*3)
+        assert_equal(fac['c'], [0,1,0]*3)
+
     def test_contrast4(self):
 
         f = self.formula + self.terms[5] + self.terms[5]




More information about the Scipy-svn mailing list