[Scipy-svn] r2295 - in trunk/Lib/sandbox/pyem: . tests
scipy-svn at scipy.org
scipy-svn at scipy.org
Mon Oct 23 06:52:32 EDT 2006
Author: cdavid
Date: 2006-10-23 05:52:27 -0500 (Mon, 23 Oct 2006)
New Revision: 2295
Modified:
trunk/Lib/sandbox/pyem/online_em.py
trunk/Lib/sandbox/pyem/tests/test_online_em.py
Log:
* pyem: minor API change for OGMM (sufficient_statistics renamed to compute_sufficient_statistics; update_em no longer takes arguments)
Modified: trunk/Lib/sandbox/pyem/online_em.py
===================================================================
--- trunk/Lib/sandbox/pyem/online_em.py 2006-10-23 09:52:04 UTC (rev 2294)
+++ trunk/Lib/sandbox/pyem/online_em.py 2006-10-23 10:52:27 UTC (rev 2295)
@@ -1,5 +1,5 @@
# /usr/bin/python
-# Last Change: Fri Oct 20 12:00 PM 2006 J
+# Last Change: Mon Oct 23 07:00 PM 2006 J
#---------------------------------------------
# This is not meant to be used yet !!!! I am
@@ -135,7 +135,7 @@
self.init = init_methods[init]
- def sufficient_statistics(self, frame, nu):
+ def compute_sufficient_statistics(self, frame, nu):
""" sufficient_statistics(frame, nu)
frame has to be rank 2 !"""
@@ -147,13 +147,12 @@
self.cw *= (1 - nu)
self.cw += nu * gamma
- return gamma
-
- def update_em(self, frame, gamma, nu):
for k in range(self.gm.k):
self.cx[k] = (1 - nu) * self.cx[k] + nu * frame * gamma[k]
self.cxx[k] = (1 - nu) * self.cxx[k] + nu * frame ** 2 * gamma[k]
+ def update_em(self):
+ for k in range(self.gm.k):
self.cmu[k] = self.cx[k] / self.cw[k]
self.cva[k] = self.cxx[k] / self.cw[k] - self.cmu[k] ** 2
@@ -209,8 +208,8 @@
# object version of online EM
for t in range(nframes):
- gamma = ogmm.sufficient_statistics(data[t:t+1, :], nu[t])
- ogmm.update_em(data[t, :], gamma, nu[t])
+ ogmm.compute_sufficient_statistics(data[t:t+1, :], nu[t])
+ ogmm.update_em()
ogmm.gm.set_param(ogmm.cw, ogmm.cmu, ogmm.cva)
Modified: trunk/Lib/sandbox/pyem/tests/test_online_em.py
===================================================================
--- trunk/Lib/sandbox/pyem/tests/test_online_em.py 2006-10-23 09:52:04 UTC (rev 2294)
+++ trunk/Lib/sandbox/pyem/tests/test_online_em.py 2006-10-23 10:52:27 UTC (rev 2295)
@@ -1,5 +1,5 @@
#! /usr/bin/env python
-# Last Change: Fri Oct 20 12:00 PM 2006 J
+# Last Change: Mon Oct 23 07:00 PM 2006 J
import copy
@@ -116,8 +116,8 @@
ogmm.pva = ogmm.cva.copy()
for e in range(emiter):
for t in range(nframes):
- gamma = ogmm.sufficient_statistics(self.data[t:t+1, :], nu[t])
- ogmm.update_em(self.data[t, :], gamma, nu[t])
+ ogmm.compute_sufficient_statistics(self.data[t:t+1, :], nu[t])
+ ogmm.update_em()
# Change pw args only a each epoch
ogmm.pw = ogmm.cw.copy()
@@ -179,8 +179,8 @@
assert ogmm.pw is ogmm.cw
assert ogmm.pmu is ogmm.cmu
assert ogmm.pva is ogmm.cva
- gamma = ogmm.sufficient_statistics(self.data[t:t+1, :], nu[t])
- ogmm.update_em(self.data[t, :], gamma, nu[t])
+ ogmm.compute_sufficient_statistics(self.data[t:t+1, :], nu[t])
+ ogmm.update_em()
ogmm.gm.set_param(ogmm.cw, ogmm.cmu, ogmm.cva)
More information about the Scipy-svn mailing list