[Scipy-svn] r2358 - trunk/Lib/sandbox/pyem/profile_data
scipy-svn at scipy.org
scipy-svn at scipy.org
Wed Dec 6 07:48:15 EST 2006
Author: cdavid
Date: 2006-12-06 06:46:23 -0600 (Wed, 06 Dec 2006)
New Revision: 2358
Added:
trunk/Lib/sandbox/pyem/profile_data/profile_online_em.py
Log:
Add a script to profile online em
Added: trunk/Lib/sandbox/pyem/profile_data/profile_online_em.py
===================================================================
--- trunk/Lib/sandbox/pyem/profile_data/profile_online_em.py 2006-12-06 12:27:51 UTC (rev 2357)
+++ trunk/Lib/sandbox/pyem/profile_data/profile_online_em.py 2006-12-06 12:46:23 UTC (rev 2358)
@@ -0,0 +1,241 @@
+# /usr/bin/python
+# Last Change: Wed Dec 06 08:00 PM 2006 J
+import copy
+
+import numpy as N
+
+from gauss_mix import GM
+from gmm_em import GMM
+
+def _generate_data(nframes, d, k, mode = 'diag'):
+ N.random.seed(0)
+ w, mu, va = GM.gen_param(d, k, mode, spread = 1.5)
+ gm = GM.fromvalues(w, mu, va)
+ # Sample nframes frames from the model
+ data = gm.sample(nframes)
+
+ #++++++++++++++++++++++++++++++++++++++++++
+ # Approximate the models with classical EM
+ #++++++++++++++++++++++++++++++++++++++++++
+ emiter = 5
+ # Init the model
+ lgm = GM(d, k, mode)
+ gmm = GMM(lgm, 'kmean')
+ gmm.init(data)
+
+ gm0 = copy.copy(gmm.gm)
+ # The actual EM, with likelihood computation
+ like = N.zeros(emiter)
+ for i in range(emiter):
+ g, tgd = gmm.sufficient_statistics(data)
+ like[i] = N.sum(N.log(N.sum(tgd, 1)), axis = 0)
+ gmm.update_em(data, g)
+
+ return data, gm
+
+nframes = int(5e3)
+d = 1
+k = 2
+niter = 1
+
+def test_v1():
+ # Generate test data
+ data, gm = _generate_data(nframes, d, k)
+ for i in range(niter):
+ iter_1(data, gm)
+
+def test_v2():
+ # Generate test data
+ data, gm = _generate_data(nframes, d, k)
+ for i in range(niter):
+ iter_2(data, gm)
+
+def test_v3():
+ # Generate test data
+ data, gm = _generate_data(nframes, d, k)
+ for i in range(niter):
+ iter_3(data, gm)
+
+def test_v4():
+ # Generate test data
+ data, gm = _generate_data(nframes, d, k)
+ for i in range(niter):
+ iter_4(data, gm)
+
+def iter_1(data, gm):
+ """Online EM with original densities + original API"""
+ from online_em import OnGMM
+
+ nframes = data.shape[0]
+ ogm = copy.copy(gm)
+ ogmm = OnGMM(ogm, 'kmean')
+ init_data = data[0:nframes / 20, :]
+ ogmm.init(init_data)
+
+ # Forgetting param
+ ku = 0.005
+ t0 = 200
+ lamb = 1 - 1/(N.arange(-1, nframes-1) * ku + t0)
+ nu0 = 0.2
+ nu = N.zeros((len(lamb), 1))
+ nu[0] = nu0
+ for i in range(1, len(lamb)):
+ nu[i] = 1./(1 + lamb[i] / nu[i-1])
+
+ # object version of online EM
+ for t in range(nframes):
+ ogmm.compute_sufficient_statistics_frame(data[t], nu[t])
+ ogmm.update_em_frame()
+
+ ogmm.gm.set_param(ogmm.cw, ogmm.cmu, ogmm.cva)
+ print ogmm.cw
+ print ogmm.cmu
+ print ogmm.cva
+
+def iter_2(data, gm):
+ """Online EM with densities2 + original API"""
+ from online_em2 import OnGMM
+
+ nframes = data.shape[0]
+ ogm = copy.copy(gm)
+ ogmm = OnGMM(ogm, 'kmean')
+ init_data = data[0:nframes / 20, :]
+ ogmm.init(init_data)
+
+ # Forgetting param
+ ku = 0.005
+ t0 = 200
+ lamb = 1 - 1/(N.arange(-1, nframes-1) * ku + t0)
+ nu0 = 0.2
+ nu = N.zeros((len(lamb), 1))
+ nu[0] = nu0
+ for i in range(1, len(lamb)):
+ nu[i] = 1./(1 + lamb[i] / nu[i-1])
+
+ # object version of online EM
+ for t in range(nframes):
+ ogmm.compute_sufficient_statistics_frame(data[t], nu[t])
+ ogmm.update_em_frame()
+
+ ogmm.gm.set_param(ogmm.cw, ogmm.cmu, ogmm.cva)
+ print ogmm.cw
+ print ogmm.cmu
+ print ogmm.cva
+
+def iter_3(data, gm):
+ """Online EM with densities + 1d API"""
+ from online_em import OnGMM1d
+
+ #def blop(self, frame, nu):
+ # self.compute_sufficient_statistics_frame(frame, nu)
+ #OnGMM.blop = blop
+
+ nframes = data.shape[0]
+ ogm = copy.copy(gm)
+ ogmm = OnGMM1d(ogm, 'kmean')
+ init_data = data[0:nframes / 20, :]
+ ogmm.init(init_data[:, 0])
+
+ # Forgetting param
+ ku = 0.005
+ t0 = 200
+ lamb = 1 - 1/(N.arange(-1, nframes-1) * ku + t0)
+ nu0 = 0.2
+ nu = N.zeros((len(lamb), 1))
+ nu[0] = nu0
+ for i in range(1, len(lamb)):
+ nu[i] = 1./(1 + lamb[i] / nu[i-1])
+
+ # object version of online EM
+ for t in range(nframes):
+ #assert ogmm.cw is ogmm.pw
+ #assert ogmm.cva is ogmm.pva
+ #assert ogmm.cmu is ogmm.pmu
+ a, b, c = ogmm.compute_sufficient_statistics_frame(data[t, 0], nu[t])
+ ##ogmm.blop(data[t,0], nu[t])
+ ogmm.update_em_frame(a, b, c)
+
+ #ogmm.gm.set_param(ogmm.cw, ogmm.cmu, ogmm.cva)
+ print ogmm.cw
+ print ogmm.cmu
+ print ogmm.cva
+
+def iter_4(data, gm):
+ """Online EM with densities2 + 1d API"""
+ from online_em2 import OnGMM1d
+
+ #def blop(self, frame, nu):
+ # self.compute_sufficient_statistics_frame(frame, nu)
+ #OnGMM.blop = blop
+
+ nframes = data.shape[0]
+ ogm = copy.copy(gm)
+ ogmm = OnGMM1d(ogm, 'kmean')
+ init_data = data[0:nframes / 20, :]
+ ogmm.init(init_data[:, 0])
+
+ # Forgetting param
+ ku = 0.005
+ t0 = 200
+ lamb = 1 - 1/(N.arange(-1, nframes-1) * ku + t0)
+ nu0 = 0.2
+ nu = N.zeros((len(lamb), 1))
+ nu[0] = nu0
+ for i in range(1, len(lamb)):
+ nu[i] = 1./(1 + lamb[i] / nu[i-1])
+
+ # object version of online EM
+ def blop():
+ #for t in range(nframes):
+ # #assert ogmm.cw is ogmm.pw
+ # #assert ogmm.cva is ogmm.pva
+ # #assert ogmm.cmu is ogmm.pmu
+ # #a, b, c = ogmm.compute_sufficient_statistics_frame(data[t, 0], nu[t])
+ # ###ogmm.blop(data[t,0], nu[t])
+ # #ogmm.update_em_frame(a, b, c)
+ # ogmm.compute_em_frame(data[t, 0], nu[t])
+ [ogmm.compute_em_frame(data[t, 0], nu[t]) for t in range(nframes)]
+ blop()
+
+ #ogmm.gm.set_param(ogmm.cw, ogmm.cmu, ogmm.cva)
+ print ogmm.cw
+ print ogmm.cmu
+ print ogmm.cva
+
+
+
+if __name__ == '__main__':
+ #import hotshot, hotshot.stats
+ #profile_file = 'onem1.prof'
+ #prof = hotshot.Profile(profile_file, lineevents=1)
+ #prof.runcall(test_v1)
+ #p = hotshot.stats.load(profile_file)
+ #print p.sort_stats('cumulative').print_stats(20)
+ #prof.close()
+
+ #import hotshot, hotshot.stats
+ #profile_file = 'onem2.prof'
+ #prof = hotshot.Profile(profile_file, lineevents=1)
+ #prof.runcall(test_v2)
+ #p = hotshot.stats.load(profile_file)
+ #print p.sort_stats('cumulative').print_stats(20)
+ #prof.close()
+
+ import hotshot, hotshot.stats
+ profile_file = 'onem3.prof'
+ prof = hotshot.Profile(profile_file, lineevents=1)
+ prof.runcall(test_v3)
+ p = hotshot.stats.load(profile_file)
+ print p.sort_stats('cumulative').print_stats(20)
+ prof.close()
+
+ import hotshot, hotshot.stats
+ profile_file = 'onem4.prof'
+ prof = hotshot.Profile(profile_file, lineevents=1)
+ prof.runcall(test_v4)
+ p = hotshot.stats.load(profile_file)
+ print p.sort_stats('cumulative').print_stats(20)
+ prof.close()
+ #test_v1()
+ #test_v2()
+ #test_v3()
More information about the Scipy-svn
mailing list