[Scipy-svn] r2295 - in trunk/Lib/sandbox/pyem: . tests

scipy-svn at scipy.org scipy-svn at scipy.org
Mon Oct 23 06:52:32 EDT 2006


Author: cdavid
Date: 2006-10-23 05:52:27 -0500 (Mon, 23 Oct 2006)
New Revision: 2295

Modified:
   trunk/Lib/sandbox/pyem/online_em.py
   trunk/Lib/sandbox/pyem/tests/test_online_em.py
Log:
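 * pyem: trivial API change for OGMM: sufficient_statistics is renamed to compute_sufficient_statistics, which now also accumulates cx/cxx and no longer returns gamma; update_em now takes no arguments and only recomputes cmu/cva from the accumulated statistics.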



Modified: trunk/Lib/sandbox/pyem/online_em.py
===================================================================
--- trunk/Lib/sandbox/pyem/online_em.py	2006-10-23 09:52:04 UTC (rev 2294)
+++ trunk/Lib/sandbox/pyem/online_em.py	2006-10-23 10:52:27 UTC (rev 2295)
@@ -1,5 +1,5 @@
 # /usr/bin/python
-# Last Change: Fri Oct 20 12:00 PM 2006 J
+# Last Change: Mon Oct 23 07:00 PM 2006 J
 
 #---------------------------------------------
 # This is not meant to be used yet !!!! I am 
@@ -135,7 +135,7 @@
 
         self.init   = init_methods[init]
 
-    def sufficient_statistics(self, frame, nu):
+    def compute_sufficient_statistics(self, frame, nu):
         """ sufficient_statistics(frame, nu)
         
         frame has to be rank 2 !"""
@@ -147,13 +147,12 @@
         self.cw	*= (1 - nu)
         self.cw += nu * gamma
 
-        return gamma
-
-    def update_em(self, frame, gamma, nu):
         for k in range(self.gm.k):
             self.cx[k]   = (1 - nu) * self.cx[k] + nu * frame * gamma[k]
             self.cxx[k]  = (1 - nu) * self.cxx[k] + nu * frame ** 2 * gamma[k]
 
+    def update_em(self):
+        for k in range(self.gm.k):
             self.cmu[k]  = self.cx[k] / self.cw[k]
             self.cva[k]  = self.cxx[k] / self.cw[k] - self.cmu[k] ** 2
     
@@ -209,8 +208,8 @@
 
     # object version of online EM
     for t in range(nframes):
-        gamma   = ogmm.sufficient_statistics(data[t:t+1, :], nu[t])
-        ogmm.update_em(data[t, :], gamma, nu[t])
+        ogmm.compute_sufficient_statistics(data[t:t+1, :], nu[t])
+        ogmm.update_em()
 
     ogmm.gm.set_param(ogmm.cw, ogmm.cmu, ogmm.cva)
 

Modified: trunk/Lib/sandbox/pyem/tests/test_online_em.py
===================================================================
--- trunk/Lib/sandbox/pyem/tests/test_online_em.py	2006-10-23 09:52:04 UTC (rev 2294)
+++ trunk/Lib/sandbox/pyem/tests/test_online_em.py	2006-10-23 10:52:27 UTC (rev 2295)
@@ -1,5 +1,5 @@
 #! /usr/bin/env python
-# Last Change: Fri Oct 20 12:00 PM 2006 J
+# Last Change: Mon Oct 23 07:00 PM 2006 J
 
 import copy
 
@@ -116,8 +116,8 @@
         ogmm.pva   = ogmm.cva.copy()
         for e in range(emiter):
             for t in range(nframes):
-                gamma   = ogmm.sufficient_statistics(self.data[t:t+1, :], nu[t])
-                ogmm.update_em(self.data[t, :], gamma, nu[t])
+                ogmm.compute_sufficient_statistics(self.data[t:t+1, :], nu[t])
+                ogmm.update_em()
 
             # Change pw args only a each epoch 
             ogmm.pw  = ogmm.cw.copy()
@@ -179,8 +179,8 @@
             assert ogmm.pw is ogmm.cw
             assert ogmm.pmu is ogmm.cmu
             assert ogmm.pva is ogmm.cva
-            gamma   = ogmm.sufficient_statistics(self.data[t:t+1, :], nu[t])
-            ogmm.update_em(self.data[t, :], gamma, nu[t])
+            ogmm.compute_sufficient_statistics(self.data[t:t+1, :], nu[t])
+            ogmm.update_em()
 
         ogmm.gm.set_param(ogmm.cw, ogmm.cmu, ogmm.cva)
 




More information about the Scipy-svn mailing list